Add Adam optimizer, move traits around, and fix issue with conjugate gradient
This commit is contained in:
parent c9205415f2
commit 2993580861
src/gradient_descent/adam.rs (new file, 161 lines)

@@ -0,0 +1,161 @@
+use crate::{
+    gradient_descent::consts::EPS,
+    minimize::{ExitCondition, OptimizationResult},
+    objective_function::ObjectiveFun,
+    traits::XVar,
+};
+use std::fmt::Debug;
+
+use super::conjugate_gradient::ConjGradPrime;
+
+pub struct AdamParameters {
+    alpha0: f64,
+    beta1: f64,
+    beta2: f64,
+}
+
+pub trait AdamPrime: ConjGradPrime {
+    fn zero(&self) -> Self;
+    fn sqrt(&self) -> Self;
+}
+
+impl AdamPrime for f64 {
+    fn zero(&self) -> Self {
+        0.0
+    }
+
+    fn sqrt(&self) -> Self {
+        f64::sqrt(*self)
+    }
+}
+
+impl AdamPrime for Vec<f64> {
+    fn zero(&self) -> Self {
+        (0..self.len()).map(|_| 0.0).collect()
+    }
+
+    fn sqrt(&self) -> Self {
+        self.iter().map(|val| val.sqrt()).collect()
+    }
+}
+pub fn adam<T: XVar<E> + Clone, E: Debug + AdamPrime>(
+    fun: &dyn ObjectiveFun<T, E>,
+    x0: &T,
+    max_iters: usize,
+    tolerance: f64,
+    params: &AdamParameters,
+) -> OptimizationResult<T> {
+    // Make a mutable copy of x0 to work with
+    let mut xs = x0.clone();
+
+    // Perform the iteration
+    let mut t = 0;
+    let mut prime = fun.prime(x0);
+    let mut m = prime.zero();
+    let mut v = prime.zero();
+    let mut old_f = fun.eval(x0);
+    let mut f = old_f;
+    for _ in 0..max_iters {
+        // Do an adam step
+        m = m.scale(params.beta1).add(&prime.scale(1.0 - params.beta1));
+        v = (v.scale(params.beta2)).add(&prime.mul(&prime.scale(1.0 - params.beta2)));
+        let mhat = m.scale(1.0 / (1.0 - params.beta1.powi(t as i32 + 1)));
+        let vhat = v.scale(1.0 / (1.0 - params.beta2.powi(t as i32 + 1)));
+        let update_direction = mhat.div(&vhat.sqrt().add_float(EPS)).scale(-1.0);
+
+        xs = xs.update(params.alpha0, &update_direction);
+        prime = fun.prime(&xs);
+
+        // Check convergence
+        f = fun.eval(&xs);
+        if f.is_nan() {
+            break;
+        }
+        if (f - old_f).abs() < tolerance {
+            break;
+        }
+        old_f = f;
+        t += 1;
+    }
+
+    let exit_con = if t == max_iters {
+        ExitCondition::MaxIter
+    } else {
+        ExitCondition::Converged
+    };
+    OptimizationResult {
+        best_xs: xs,
+        best_fun_val: f,
+        exit_con,
+        iters: t,
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::objective_function::Fun;
+
+    use super::*;
+
+    #[test]
+    pub fn simple_adam_test() {
+        let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
+        let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
+
+        let obj = Fun::new(fun, prime);
+        let params = AdamParameters {
+            alpha0: 0.1,
+            beta1: 0.9,
+            beta2: 0.999,
+        };
+        let res = adam(&obj, &vec![10.0, 10.0], 1000, 1e-12, &params);
+
+        println!(
+            "Best val is {:?} for xs {:?}",
+            res.best_fun_val, res.best_xs
+        );
+
+        println!("Exited with {:?}", res.exit_con);
+        if let ExitCondition::MaxIter = res.exit_con {
+            panic!("Failed to converge to minimum");
+        }
+        println!(
+            "{:?} on iteration {} has value:\n{}",
+            res.best_xs, res.iters, res.best_fun_val
+        );
+        assert!(res.best_fun_val < 1e-8);
+    }
+
+    #[test]
+    pub fn basic_beale_test() {
+        let fun = Box::new(|x: &Vec<f64>| {
+            (1.5 - x[0] + x[0] * x[1]).powi(2)
+                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
+                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
+        });
+        let prime = Box::new(|x: &Vec<f64>| {
+            vec![
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
+            ]
+        });
+        let obj = Fun::new(fun, prime);
+        let params = AdamParameters {
+            alpha0: 0.1,
+            beta1: 0.9,
+            beta2: 0.999,
+        };
+        let res = adam(&obj, &vec![4.0, 1.00], 1000, 1e-12, &params);
+        println!(
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters
+        );
+
+        println!("Exit condition is: {:?}", res.exit_con);
+        assert!(res.best_fun_val < 1e-7);
+    }
+}
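
For reference, the loop above is the standard bias-corrected Adam recurrence, with g_t the value of fun.prime at the current iterate, \alpha_0 the alpha0 step size, and \epsilon the new EPS constant:

    m_t = \beta_1 m_{t-1} + (1 - \beta_1)\, g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2
    \hat{m}_t = m_t / (1 - \beta_1^t), \qquad \hat{v}_t = v_t / (1 - \beta_2^t)
    x_{t+1} = x_t - \alpha_0\, \hat{m}_t / (\sqrt{\hat{v}_t} + \epsilon)

For Vec<f64> primes every operation above is elementwise, matching the mul, div, and sqrt implementations in this file.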
src/gradient_descent/base.rs (0 lines changed)
src/gradient_descent/conjugate_gradient.rs

@@ -1,14 +1,15 @@
 use crate::{
+    gradient_descent::consts::EPS,
     minimize::{ExitCondition, OptimizationResult},
     objective_function::ObjectiveFun,
     traits::XVar,
 };
 use std::fmt::Debug;
 
-use super::line_search::LineSearch;
+use super::{line_search::LineSearch, steepest_descent::SteepestDescentPrime};
 
 /// Trait that should be implemented by the Prime type for conjugate gradient
-pub trait ConjGradPrime: Clone + Debug {
+pub trait ConjGradPrime: Clone + Debug + SteepestDescentPrime {
     /// Multiply primes by each other
     fn mul(&self, rhs: &Self) -> Self;
     /// Subtract primes from each other

@@ -19,6 +20,8 @@ pub trait ConjGradPrime: Clone + Debug {
     fn div(&self, denominator: &Self) -> Self;
     /// Max between the prime and a float
     fn max(&self, rhs: f64) -> Self;
+    /// Add a float to the prime
+    fn add_float(&self, rhs: f64) -> Self;
 }
 
 impl ConjGradPrime for f64 {

@@ -41,6 +44,10 @@ impl ConjGradPrime for f64 {
     fn add(&self, rhs: &Self) -> Self {
         self + rhs
     }
+
+    fn add_float(&self, rhs: f64) -> Self {
+        self + rhs
+    }
 }
 
 impl ConjGradPrime for Vec<f64> {

@@ -72,9 +79,13 @@ impl ConjGradPrime for Vec<f64> {
     fn add(&self, rhs: &Self) -> Self {
         self.iter()
             .zip(rhs.iter())
-            .map(|(lhs, rhs)| lhs - rhs)
+            .map(|(lhs, rhs)| lhs + rhs)
             .collect()
     }
+
+    fn add_float(&self, rhs: f64) -> Self {
+        self.iter().map(|val| val + rhs).collect()
+    }
 }
 
 pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(

@@ -92,7 +103,7 @@ pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
     let mut f = 0.0;
     let mut i = 0;
     let mut prev_residual = fun.prime(&xs);
-    let mut direction = T::scale_prime(&prev_residual, -1.0);
+    let mut direction = prev_residual.scale(-1.0);
     for _ in 0..max_iters {
         let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
         xs = xs.update(learning_rate, &direction);

@@ -105,11 +116,11 @@ pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
         }
 
         // Update using Polak-Ribière
-        let new_residual = T::scale_prime(&fun.prime(&xs), -1.0);
+        let new_residual = fun.prime(&xs).scale(-1.0);
         let beta = (new_residual.mul(&new_residual.sub(&prev_residual)))
-            .div(&prev_residual.mul(&prev_residual));
+            .div(&prev_residual.mul(&prev_residual).add_float(EPS));
         let beta = beta.max(0.0);
-        direction = new_residual.add(&beta.mul(&direction));
+        direction = new_residual.sub(&beta.mul(&direction));
         prev_residual = new_residual.clone();
         i += 1;
     }

@@ -140,8 +151,8 @@ mod test {
 
         let obj = Fun::new(fun, prime);
         let line_searches = vec![LineSearch::BackTrack {
-            gamma: 0.9,
-            c: 0.01,
+            gamma: 0.5,
+            c: 0.001,
         }];
         for line_search in line_searches {
             let res = conjugate_gradient(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);

@@ -185,10 +196,10 @@ mod test {
             gamma: 0.9,
             c: 0.01,
         };
-        let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 10000, 1e-12, &line_search);
+        let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 1000, 1e-12, &line_search);
         println!(
-            "Best val is {:?} for xs {:?}",
-            res.best_fun_val, res.best_xs
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters
         );
 
         println!("Exit condition is: {:?}", res.exit_con);
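
The beta computed above is the Polak-Ribière coefficient; the clamp to non-negative values was already present, and the EPS term added by this commit is what keeps the denominator away from zero when the previous residual vanishes. Note that mul and div are elementwise for Vec<f64>, so beta is a per-component quantity here rather than the scalar inner-product ratio of the textbook method. With residual r = -\nabla f, each component of the guarded coefficient is

    \beta_i = \max\!\left(0,\; \frac{r_i^{k+1}\,(r_i^{k+1} - r_i^{k})}{(r_i^{k})^2 + \epsilon}\right)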
src/gradient_descent/consts.rs (new file, 1 line)

@@ -0,0 +1 @@
+pub const EPS: f64 = 1e-12;
src/gradient_descent/line_search.rs

@@ -2,6 +2,8 @@ use core::fmt
 
 use crate::{objective_function::ObjectiveFun, traits::XVar};
 
+use super::steepest_descent::SteepestDescentPrime;
+
 pub enum LineSearch {
     ConstAlpha { learning_rate: f64 },
     BackTrack { gamma: f64, c: f64 },

@@ -16,7 +18,7 @@ impl LineSearch {
     ) -> f64
     where
         T: XVar<E> + Clone,
-        E: fmt::Debug,
+        E: fmt::Debug + SteepestDescentPrime,
     {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,

@@ -25,10 +27,7 @@ impl LineSearch {
                 let fk = fun.eval(xs);
                 let mut new_f = fun.eval(&xs.update(1.0, &prime));
                 let mut t = 1.0;
-                while fk
-                    < new_f
-                        + t * c * T::prime_inner_product(&T::scale_prime(&prime, -1.0), direction)
-                {
+                while fk < new_f + t * c * prime.scale(-1.0).inner_product(direction) {
                     t *= gamma;
                     let new_x = xs.update(t, direction);
                     new_f = fun.eval(&new_x);
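
The rewritten while-condition is a single-line, equivalent form of the Armijo backtracking rule: with descent direction d (here d = -\nabla f(x_k)), the loop keeps shrinking t by the factor gamma until

    f(x_k + t\,d) \le f(x_k) + c\,t\,\nabla f(x_k)^\top d

Rearranged, backtracking continues while fk < new_f + t * c * \langle -\nabla f, d \rangle, which is exactly what prime.scale(-1.0).inner_product(direction) computes.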
src/gradient_descent/mod.rs

@@ -1,4 +1,6 @@
+pub mod adam;
 pub mod base;
 pub mod conjugate_gradient;
+pub mod consts;
 pub mod line_search;
 pub mod steepest_descent;
src/gradient_descent/steepest_descent.rs

@@ -3,10 +3,37 @@ use crate::{
     objective_function::ObjectiveFun,
     traits::XVar,
 };
+use std::fmt::Debug;
 
 use super::line_search::LineSearch;
 
-pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
+pub trait SteepestDescentPrime {
+    fn scale(&self, factor: f64) -> Self;
+    fn inner_product(&self, rhs: &Self) -> f64;
+}
+
+impl SteepestDescentPrime for f64 {
+    fn scale(&self, factor: f64) -> Self {
+        self * factor
+    }
+
+    fn inner_product(&self, rhs: &Self) -> f64 {
+        self * rhs
+    }
+}
+
+impl SteepestDescentPrime for Vec<f64> {
+    fn scale(&self, factor: f64) -> Self {
+        self.iter().map(|val| val * factor).collect()
+    }
+
+    fn inner_product(&self, rhs: &Self) -> f64 {
+        self.iter()
+            .zip(rhs)
+            .fold(0.0, |acc, (lhs, rhs)| acc + lhs * rhs)
+    }
+}
+pub fn steepest_descent<T: XVar<E> + Clone, E: Debug + SteepestDescentPrime>(
     fun: &dyn ObjectiveFun<T, E>,
     x0: &T,
     max_iters: usize,

@@ -21,9 +48,9 @@ pub fn steepest_descent<T: XVar<E> + Clone, E: Debug + SteepestDescentPrime>(
     let mut f = 0.0;
     let mut i = 0;
     for _ in 0..max_iters {
-        let direction = T::scale_prime(&fun.prime(&xs), -1.0);
-        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
-        xs = xs.update(learning_rate, &direction);
+        let direction = &fun.prime(&xs).scale(-1.0);
+        let learning_rate = line_search.get_learning_rate(fun, &xs, direction);
+        xs = xs.update(learning_rate, direction);
         f = fun.eval(&xs);
         if (f - f_iminus1).abs() < tolerance {
             break;

@@ -62,7 +89,10 @@ mod test {
         LineSearch::ConstAlpha {
             learning_rate: 0.25,
         },
-        LineSearch::BackTrack { gamma: 0.9, c: 0.3 },
+        LineSearch::BackTrack {
+            gamma: 0.5,
+            c: 0.001,
+        },
     ];
     for line_search in line_searches {
         let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);

@@ -102,8 +132,8 @@ mod test {
         };
         let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
         println!(
-            "Best val is {:?} for xs {:?}",
-            res.best_fun_val, res.best_xs
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters,
         );
         assert!(res.best_fun_val < 1e-7);
     }
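
With this move, the prime traits form a layered hierarchy: SteepestDescentPrime (scale, inner_product) is the base, ConjGradPrime extends it, and the new AdamPrime extends ConjGradPrime. A minimal sketch of what that buys, placed anywhere inside the crate (the helper function is illustrative, not part of this commit):

    use std::fmt::Debug;

    use crate::gradient_descent::adam::AdamPrime;

    // AdamPrime: ConjGradPrime: SteepestDescentPrime, so a type meeting the
    // strictest bound works with adam, conjugate_gradient, and
    // steepest_descent alike.
    fn usable_by_all_optimizers<E: Debug + AdamPrime>() {}

    #[test]
    fn prime_trait_stack() {
        // Both built-in prime types implement the whole stack.
        usable_by_all_optimizers::<f64>();
        usable_by_all_optimizers::<Vec<f64>>();
    }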
src/traits.rs

@@ -9,12 +9,6 @@ use std::fmt::Debug;
 pub trait XVar<T>: Clone + Debug {
     /// Update the current Xvariable based on the prime
     fn update(&self, alpha: f64, prime: &T) -> Self;
-    /// Multiply the prime by a float
-    fn scale_prime(prime: &T, rhs: f64) -> T;
-    /// Add a float to the prime
-    fn add_prime(prime: &T, rhs: f64) -> T;
-    /// Inner Product of prime
-    fn prime_inner_product(prime: &T, rhs: &T) -> f64;
 }
 
 /// Implementation of XVar for an f64 type

@@ -22,18 +16,6 @@ impl XVar<f64> for f64 {
     fn update(&self, alpha: f64, prime: &f64) -> Self {
         self + alpha * prime
     }
-
-    fn scale_prime(prime: &f64, rhs: f64) -> f64 {
-        prime * rhs
-    }
-
-    fn add_prime(prime: &f64, rhs: f64) -> f64 {
-        prime + rhs
-    }
-
-    fn prime_inner_product(prime: &f64, rhs: &f64) -> f64 {
-        prime * rhs
-    }
 }
 
 /// Implementation of XVar for a Vec<f64> type

@@ -44,19 +26,4 @@ impl XVar<Vec<f64>> for Vec<f64> {
             .map(|(x, xprime)| x + alpha * xprime)
             .collect()
     }
-
-    fn scale_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
-        prime.iter().map(|val| val * rhs).collect()
-    }
-
-    fn add_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
-        prime.iter().map(|val| val + rhs).collect()
-    }
-
-    fn prime_inner_product(prime: &Vec<f64>, rhs: &Vec<f64>) -> f64 {
-        prime
-            .iter()
-            .zip(rhs.iter())
-            .fold(0.0, |acc, a| acc + a.0 * a.1)
-    }
 }
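
After this refactor XVar is down to a single method: a custom iterate type only has to say how it moves along a prime, while all prime arithmetic now lives on the SteepestDescentPrime/ConjGradPrime/AdamPrime stack. A minimal sketch with a hypothetical Params2D type (not part of the crate):

    use crate::traits::XVar;

    #[derive(Clone, Debug)]
    struct Params2D {
        a: f64,
        b: f64,
    }

    impl XVar<Vec<f64>> for Params2D {
        /// prime[0] and prime[1] hold the partials with respect to a and b
        fn update(&self, alpha: f64, prime: &Vec<f64>) -> Self {
            Params2D {
                a: self.a + alpha * prime[0],
                b: self.b + alpha * prime[1],
            }
        }
    }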