diff --git a/src/gradient_descent/adam.rs b/src/gradient_descent/adam.rs
new file mode 100644
index 0000000..653286e
--- /dev/null
+++ b/src/gradient_descent/adam.rs
@@ -0,0 +1,161 @@
+use crate::{
+    gradient_descent::consts::EPS,
+    minimize::{ExitCondition, OptimizationResult},
+    objective_function::ObjectiveFun,
+    traits::XVar,
+};
+use std::fmt::Debug;
+
+use super::conjugate_gradient::ConjGradPrime;
+
+pub struct AdamParameters {
+    alpha0: f64,
+    beta1: f64,
+    beta2: f64,
+}
+
+pub trait AdamPrime: ConjGradPrime {
+    fn zero(&self) -> Self;
+    fn sqrt(&self) -> Self;
+}
+
+impl AdamPrime for f64 {
+    fn zero(&self) -> Self {
+        0.0
+    }
+
+    fn sqrt(&self) -> Self {
+        f64::sqrt(*self)
+    }
+}
+
+impl AdamPrime for Vec<f64> {
+    fn zero(&self) -> Self {
+        (0..self.len()).map(|_| 0.0).collect()
+    }
+
+    fn sqrt(&self) -> Self {
+        self.iter().map(|val| val.sqrt()).collect()
+    }
+}
+pub fn adam<T: XVar<E> + Clone, E: Debug + AdamPrime>(
+    fun: &dyn ObjectiveFun<T, E>,
+    x0: &T,
+    max_iters: usize,
+    tolerance: f64,
+    params: &AdamParameters,
+) -> OptimizationResult<T> {
+    // Make a mutable copy of x0 to work with
+    let mut xs = x0.clone();
+
+    // Perform the iteration
+    let mut t = 0;
+    let mut prime = fun.prime(x0);
+    let mut m = prime.zero();
+    let mut v = prime.zero();
+    let mut old_f = fun.eval(x0);
+    let mut f = old_f;
+    for _ in 0..max_iters {
+        // Do an adam step
+        m = m.scale(params.beta1).add(&prime.scale(1.0 - params.beta1));
+        v = (v.scale(params.beta2)).add(&prime.mul(&prime.scale(1.0 - params.beta2)));
+        let mhat = m.scale(1.0 / (1.0 - params.beta1.powi(t as i32 + 1)));
+        let vhat = v.scale(1.0 / (1.0 - params.beta2.powi(t as i32 + 1)));
+        let update_direction = mhat.div(&vhat.sqrt().add_float(EPS)).scale(-1.0);
+
+        xs = xs.update(params.alpha0, &update_direction);
+        prime = fun.prime(&xs);
+
+        // Check convergence
+        f = fun.eval(&xs);
+        if f.is_nan() {
+            break;
+        }
+        if (f - old_f).abs() < tolerance {
+            break;
+        }
+        old_f = f;
+        t += 1;
+    }
+
+    let exit_con = if t == max_iters {
+        ExitCondition::MaxIter
+    } else {
+        ExitCondition::Converged
+    };
+    OptimizationResult {
+        best_xs: xs,
+        best_fun_val: f,
+        exit_con,
+        iters: t,
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::objective_function::Fun;
+
+    use super::*;
+
+    #[test]
+    pub fn simple_adam_test() {
+        let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
+        let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
+
+        let obj = Fun::new(fun, prime);
+        let params = AdamParameters {
+            alpha0: 0.1,
+            beta1: 0.9,
+            beta2: 0.999,
+        };
+        let res = adam(&obj, &vec![10.0, 10.0], 1000, 1e-12, &params);
+
+        println!(
+            "Best val is {:?} for xs {:?}",
+            res.best_fun_val, res.best_xs
+        );
+
+        println!("Exited with {:?}", res.exit_con);
+        if let ExitCondition::MaxIter = res.exit_con {
+            panic!("Failed to converge to minima");
+        }
+        println!(
+            "{:?} on iteration {} has value:\n{}",
+            res.best_xs, res.iters, res.best_fun_val
+        );
+        assert!(res.best_fun_val < 1e-8);
+    }
+
+    #[test]
+    pub fn basic_beale_test() {
+        let fun = Box::new(|x: &Vec<f64>| {
+            (1.5 - x[0] + x[0] * x[1]).powi(2)
+                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
+                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
+        });
+        let prime = Box::new(|x: &Vec<f64>| {
+            vec![
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
+            ]
+        });
+        let obj = Fun::new(fun, prime);
+        let params = AdamParameters {
+            alpha0: 0.1,
+            beta1: 0.9,
+            beta2: 0.999,
+        };
+        let res = adam(&obj, &vec![4.0, 1.00], 1000, 1e-12, &params);
+        println!(
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters
+        );
+
+        println!("Exit condition is: {:?}", res.exit_con);
+        assert!(res.best_fun_val < 1e-7);
+    }
+}
diff --git a/src/gradient_descent/base.rs b/src/gradient_descent/base.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/gradient_descent/conjugate_gradient.rs b/src/gradient_descent/conjugate_gradient.rs
index 5c075b8..30bad31 100644
--- a/src/gradient_descent/conjugate_gradient.rs
+++ b/src/gradient_descent/conjugate_gradient.rs
@@ -1,14 +1,15 @@
 use crate::{
+    gradient_descent::consts::EPS,
     minimize::{ExitCondition, OptimizationResult},
     objective_function::ObjectiveFun,
     traits::XVar,
 };
 use std::fmt::Debug;
 
-use super::line_search::LineSearch;
+use super::{line_search::LineSearch, steepest_descent::SteepestDescentPrime};
 
 /// Trait that should be implemented by the Prime type for conjugate gradient
-pub trait ConjGradPrime: Clone + Debug {
+pub trait ConjGradPrime: Clone + Debug + SteepestDescentPrime {
     /// Multiply primes by each other
     fn mul(&self, rhs: &Self) -> Self;
     /// Subtract primes from each other
@@ -19,6 +20,8 @@ pub trait ConjGradPrime: Clone + Debug {
     fn div(&self, denominator: &Self) -> Self;
     /// Max between the prime and a float
     fn max(&self, rhs: f64) -> Self;
+    /// Add a float to the prime
+    fn add_float(&self, rhs: f64) -> Self;
 }
 
 impl ConjGradPrime for f64 {
@@ -41,6 +44,10 @@ impl ConjGradPrime for f64 {
     fn add(&self, rhs: &Self) -> Self {
         self + rhs
     }
+
+    fn add_float(&self, rhs: f64) -> Self {
+        self + rhs
+    }
 }
 
 impl ConjGradPrime for Vec<f64> {
@@ -72,9 +79,13 @@ impl ConjGradPrime for Vec<f64> {
     fn add(&self, rhs: &Self) -> Self {
         self.iter()
             .zip(rhs.iter())
-            .map(|(lhs, rhs)| lhs - rhs)
+            .map(|(lhs, rhs)| lhs + rhs)
             .collect()
     }
+
+    fn add_float(&self, rhs: f64) -> Self {
+        self.iter().map(|val| val + rhs).collect()
+    }
 }
 
 pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
@@ -92,7 +103,7 @@ pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
     let mut f = 0.0;
     let mut i = 0;
     let mut prev_residual = fun.prime(&xs);
-    let mut direction = T::scale_prime(&prev_residual, -1.0);
+    let mut direction = prev_residual.scale(-1.0);
     for _ in 0..max_iters {
         let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
         xs = xs.update(learning_rate, &direction);
@@ -105,11 +116,11 @@ pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
         }
 
         // Update using polack-ribiere
-        let new_residual = T::scale_prime(&fun.prime(&xs), -1.0);
+        let new_residual = fun.prime(&xs).scale(-1.0);
         let beta = (new_residual.mul(&new_residual.sub(&prev_residual)))
-            .div(&prev_residual.mul(&prev_residual));
+            .div(&prev_residual.mul(&prev_residual).add_float(EPS));
         let beta = beta.max(0.0);
-        direction = new_residual.add(&beta.mul(&direction));
+        direction = new_residual.sub(&beta.mul(&direction));
         prev_residual = new_residual.clone();
         i += 1;
     }
@@ -140,8 +151,8 @@ mod test {
         let obj = Fun::new(fun, prime);
 
         let line_searches = vec![LineSearch::BackTrack {
-            gamma: 0.9,
-            c: 0.01,
+            gamma: 0.5,
+            c: 0.001,
         }];
         for line_search in line_searches {
            let res = conjugate_gradient(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);
@@ -185,10 +196,10 @@ mod test {
             gamma: 0.9,
             c: 0.01,
         };
-        let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 10000, 1e-12, &line_search);
+        let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 1000, 1e-12, &line_search);
         println!(
-            "Best val is {:?} for xs {:?}",
-            res.best_fun_val, res.best_xs
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters
         );
 
         println!("Exit condition is: {:?}", res.exit_con);
diff --git a/src/gradient_descent/consts.rs b/src/gradient_descent/consts.rs
new file mode 100644
index 0000000..c360bed
--- /dev/null
+++ b/src/gradient_descent/consts.rs
@@ -0,0 +1 @@
+pub const EPS: f64 = 1e-12;
diff --git a/src/gradient_descent/line_search.rs b/src/gradient_descent/line_search.rs
index 9087ffb..e295004 100644
--- a/src/gradient_descent/line_search.rs
+++ b/src/gradient_descent/line_search.rs
@@ -2,6 +2,8 @@ use core::fmt;
 
 use crate::{objective_function::ObjectiveFun, traits::XVar};
 
+use super::steepest_descent::SteepestDescentPrime;
+
 pub enum LineSearch {
     ConstAlpha { learning_rate: f64 },
     BackTrack { gamma: f64, c: f64 },
@@ -16,7 +18,7 @@ impl LineSearch {
     ) -> f64
     where
         T: XVar<E> + Clone,
-        E: fmt::Debug,
+        E: fmt::Debug + SteepestDescentPrime,
     {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,
@@ -25,10 +27,7 @@ impl LineSearch {
                 let fk = fun.eval(xs);
                 let mut new_f = fun.eval(&xs.update(1.0, &prime));
                 let mut t = 1.0;
-                while fk
-                    < new_f
-                        + t * c * T::prime_inner_product(&T::scale_prime(&prime, -1.0), direction)
-                {
+                while fk < new_f + t * c * prime.scale(-1.0).inner_product(direction) {
                     t *= gamma;
                     let new_x = xs.update(t, direction);
                     new_f = fun.eval(&new_x);
diff --git a/src/gradient_descent/mod.rs b/src/gradient_descent/mod.rs
index 1b59051..71a0348 100644
--- a/src/gradient_descent/mod.rs
+++ b/src/gradient_descent/mod.rs
@@ -1,4 +1,6 @@
+pub mod adam;
 pub mod base;
 pub mod conjugate_gradient;
+pub mod consts;
 pub mod line_search;
 pub mod steepest_descent;
diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs
index 9ef63f9..85c910e 100644
--- a/src/gradient_descent/steepest_descent.rs
+++ b/src/gradient_descent/steepest_descent.rs
@@ -3,10 +3,37 @@ use crate::{
     objective_function::ObjectiveFun,
     traits::XVar,
 };
+use std::fmt::Debug;
 
 use super::line_search::LineSearch;
 
-pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
+pub trait SteepestDescentPrime {
+    fn scale(&self, factor: f64) -> Self;
+    fn inner_product(&self, rhs: &Self) -> f64;
+}
+
+impl SteepestDescentPrime for f64 {
+    fn scale(&self, factor: f64) -> Self {
+        self * factor
+    }
+
+    fn inner_product(&self, rhs: &Self) -> f64 {
+        self * rhs
+    }
+}
+
+impl SteepestDescentPrime for Vec<f64> {
+    fn scale(&self, factor: f64) -> Self {
+        self.iter().map(|val| val * factor).collect()
+    }
+
+    fn inner_product(&self, rhs: &Self) -> f64 {
+        self.iter()
+            .zip(rhs)
+            .fold(0.0, |acc, (lhs, rhs)| acc + lhs * rhs)
+    }
+}
+pub fn steepest_descent<T: XVar<E> + Clone, E: Debug + SteepestDescentPrime>(
     fun: &dyn ObjectiveFun<T, E>,
     x0: &T,
     max_iters: usize,
@@ -21,9 +48,9 @@ pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
     let mut f = 0.0;
     let mut i = 0;
     for _ in 0..max_iters {
-        let direction = T::scale_prime(&fun.prime(&xs), -1.0);
-        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
-        xs = xs.update(learning_rate, &direction);
+        let direction = &fun.prime(&xs).scale(-1.0);
+        let learning_rate = line_search.get_learning_rate(fun, &xs, direction);
+        xs = xs.update(learning_rate, direction);
         f = fun.eval(&xs);
         if (f - f_iminus1).abs() < tolerance {
             break;
         }
@@ -62,7 +89,10 @@ mod test {
            LineSearch::ConstAlpha {
                learning_rate: 0.25,
            },
-           LineSearch::BackTrack { gamma: 0.9, c: 0.3 },
+           LineSearch::BackTrack {
+               gamma: 0.5,
+               c: 0.001,
+           },
        ];
        for line_search in line_searches {
            let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);
@@ -102,8 +132,8 @@ mod test {
         };
         let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
         println!(
-            "Best val is {:?} for xs {:?}",
-            res.best_fun_val, res.best_xs
+            "Best val is {:?} for xs {:?} in {} iterations",
+            res.best_fun_val, res.best_xs, res.iters,
         );
         assert!(res.best_fun_val < 1e-7);
     }
diff --git a/src/traits.rs b/src/traits.rs
index 6631779..5d51354 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -9,12 +9,6 @@ use std::fmt::Debug;
 pub trait XVar<T>: Clone + Debug {
     /// Update the current Xvariable based on the prime
     fn update(&self, alpha: f64, prime: &T) -> Self;
-    /// Multiply the prime by a float
-    fn scale_prime(prime: &T, rhs: f64) -> T;
-    /// Add a float to the prime
-    fn add_prime(prime: &T, rhs: f64) -> T;
-    /// Inner Product of prime
-    fn prime_inner_product(prime: &T, rhs: &T) -> f64;
 }
 
 /// Implementation of XVar for an f64 type
@@ -22,18 +16,6 @@ impl XVar<f64> for f64 {
     fn update(&self, alpha: f64, prime: &f64) -> Self {
         self + alpha * prime
     }
-
-    fn scale_prime(prime: &f64, rhs: f64) -> f64 {
-        prime * rhs
-    }
-
-    fn add_prime(prime: &f64, rhs: f64) -> f64 {
-        prime + rhs
-    }
-
-    fn prime_inner_product(prime: &f64, rhs: &f64) -> f64 {
-        prime * rhs
-    }
 }
 
 /// Implementation of XVar for a Vec type
@@ -44,19 +26,4 @@ impl XVar<Vec<f64>> for Vec<f64> {
             .map(|(x, xprime)| x + alpha * xprime)
             .collect()
     }
-
-    fn scale_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
-        prime.iter().map(|val| val * rhs).collect()
-    }
-
-    fn add_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
-        prime.iter().map(|val| val + rhs).collect()
-    }
-
-    fn prime_inner_product(prime: &Vec<f64>, rhs: &Vec<f64>) -> f64 {
-        prime
-            .iter()
-            .zip(rhs.iter())
-            .fold(0.0, |acc, a| acc + a.0 * a.1)
-    }
 }
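For orientation, below is a minimal usage sketch of the `adam` routine added in this patch, written as if it sat alongside the existing tests in `adam.rs` (the `AdamParameters` fields are private here, so construction is only possible from that module or its child test module). The quadratic objective, starting point, and looser tolerance are illustrative assumptions, not part of the change.

// Sketch only (not part of the patch): minimize f(x, y) = (x - 1)^2 + (y + 2)^2
// with the new adam() routine, mirroring the setup of simple_adam_test above.
#[test]
fn adam_usage_sketch() {
    use crate::objective_function::Fun;

    // Objective and its gradient as boxed closures, the form Fun::new takes in the tests.
    let fun = Box::new(|x: &Vec<f64>| (x[0] - 1.0).powi(2) + (x[1] + 2.0).powi(2));
    let prime = Box::new(|x: &Vec<f64>| vec![2.0 * (x[0] - 1.0), 2.0 * (x[1] + 2.0)]);
    let obj = Fun::new(fun, prime);

    // Step size alpha0 and first/second moment decay rates; same values as the other tests.
    let params = AdamParameters {
        alpha0: 0.1,
        beta1: 0.9,
        beta2: 0.999,
    };

    // Start at (10, 10), cap at 1000 iterations, stop once |f_k - f_{k-1}| < 1e-12.
    let res = adam(&obj, &vec![10.0, 10.0], 1000, 1e-12, &params);
    println!("Converged to {:?} with f = {}", res.best_xs, res.best_fun_val);
    assert!(res.best_fun_val < 1e-6);
}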