diff --git a/src/gradient_descent/line_search.rs b/src/gradient_descent/line_search.rs
index 8eff048..fb3d287 100644
--- a/src/gradient_descent/line_search.rs
+++ b/src/gradient_descent/line_search.rs
@@ -1,3 +1,5 @@
+use core::fmt;
+
 use crate::{objective_function::ObjectiveFun, traits::XVar};
 
 pub enum LineSearch {
@@ -20,7 +22,7 @@ impl LineSearch {
     ) -> f64
     where
         T: XVar<E> + Clone,
-        E:,
+        E: fmt::Debug,
     {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,
@@ -30,17 +32,21 @@ impl LineSearch {
                 c,
             } => {
                 let prime = fun.prime(xs);
-                let mut del_f = T::scale_prime(&prime, *c);
-                let mut new_f = fun.eval(&xs.update(-1.0, &prime));
+                let fk = fun.eval(xs);
+                let mut new_f = fun.eval(&xs.update(1.0, direction));
                 let mut t = 1.0;
-                for i in 0..*max_iterations {
-                    if new_f > T::prime_inner_product(&T::scale_prime(&del_f, t), direction) {
-                        break;
-                    }
+                let mut iters = 0;
+                // Shrink t until the Armijo sufficient-decrease condition
+                // f(x + t*d) <= f(x) + c*t*<grad f, d> holds, capped at max_iterations.
+                while iters < *max_iterations
+                    && fk
+                        < new_f
+                            + t * c * T::prime_inner_product(&T::scale_prime(&prime, -1.0), direction)
+                {
                     t *= gamma;
-                    let new_x = xs.update(-t, &prime);
+                    let new_x = xs.update(t, direction);
                     new_f = fun.eval(&new_x);
-                    del_f = fun.prime(&new_x);
+                    iters += 1;
                 }
                 t
             }
diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs
index d3ca127..d3cefa2 100644
--- a/src/gradient_descent/steepest_descent.rs
+++ b/src/gradient_descent/steepest_descent.rs
@@ -6,7 +6,7 @@ use crate::{
 
 use super::line_search::LineSearch;
 
-pub fn steepest_descent<T: XVar<E> + Clone, E>(
+pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
     fun: &dyn ObjectiveFun<T, E>,
     x0: &T,
     max_iters: usize,
@@ -22,11 +22,10 @@ pub fn steepest_descent<T: XVar<E> + Clone, E>(
     let mut f = 0.0;
     let mut i = 0;
     for _ in 0..max_iters {
-        let primes = fun.prime(&xs);
-        let learning_rate = line_search.get_learning_rate(fun, &xs, &T::scale_prime(&primes, -1.0));
-        xs = xs.update(direction * learning_rate, &primes);
+        let direction = T::scale_prime(&fun.prime(&xs), -1.0);
+        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
+        xs = xs.update(learning_rate, &direction);
         f = fun.eval(&xs);
-
         if (f - f_iminus1).abs() < tolerance {
             break;
         } else {
@@ -66,21 +65,53 @@ mod test {
             },
             LineSearch::BackTrack {
                 max_iterations: 100,
-                gamma: 0.5,
-                c: 0.1,
+                gamma: 0.9,
+                c: 0.3,
             },
         ];
 
         for line_search in line_searches {
-            let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
+            let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search, -1.0);
             if let ExitCondition::MaxIter = res.exit_con {
                 panic!("Failed to converge to minima");
             }
             println!(
-                "{:?} on iteration {}\n{}",
+                "{:?} on iteration {} has value:\n{}",
                 res.best_xs, res.iters, res.best_fun_val
             );
             assert!(res.best_fun_val < 1e-8);
         }
     }
+
+    #[test]
+    pub fn basic_beale_test() {
+        // Beale function: global minimum is f = 0 at (x0, x1) = (3.0, 0.5).
+        let fun = Box::new(|x: &Vec<f64>| {
+            (1.5 - x[0] + x[0] * x[1]).powi(2)
+                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
+                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
+        });
+        let prime = Box::new(|x: &Vec<f64>| {
+            vec![
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
+            ]
+        });
+        let obj = Fun::new(fun, prime);
+        let line_search = LineSearch::BackTrack {
+            max_iterations: 1000,
+            gamma: 0.9,
+            c: 0.01,
+        };
+        let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search, -1.0);
+        println!(
+            "Best val is {:?} for xs {:?}",
+            res.best_fun_val, res.best_xs
+        );
+        assert!(res.best_fun_val < 1e-7);
+    }
 }
diff --git a/src/traits.rs b/src/traits.rs
index b2fd3b8..8f73485 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -1,10 +1,12 @@
+use std::fmt::Debug;
+
 /// Trait defining the data structure that must be implemented for the independent variables used
 /// in the objective function. The generic type denotes the type of the prime of that variable
 /// NOTE: This trait also defines some functions that are required to operate on the prime data
 /// type. It should be noted that we are unable to just require T to implement Mul or Add
 /// because then we wouldn't be able to implement XVar for plain Vec types which seems
 /// inconvenient
-pub trait XVar<T>: Clone {
+pub trait XVar<T>: Clone + Debug {
     /// Update the current Xvariable based on the prime
     fn update(&self, alpha: f64, prime: &T) -> Self;
     /// Multiply the prime by a float
@@ -38,7 +40,7 @@ impl XVar<f64> for f64 {
 impl XVar<Vec<f64>> for Vec<f64> {
     fn update(&self, alpha: f64, prime: &Vec<f64>) -> Self {
         self.iter()
-            .zip(prime)
+            .zip(prime.iter())
             .map(|(x, xprime)| x + alpha * xprime)
             .collect()
     }
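
For reviewers who want to exercise the new Armijo backtracking out of tree, the sketch below shows the intended call pattern. It is a minimal example, not part of the patch: the crate root `optimize` and the `main` driver are assumptions for illustration, while `Fun`, `LineSearch::BackTrack`, `steepest_descent`, and the result fields mirror the tests in this diff.

// Hypothetical driver; adjust `optimize` to the actual crate name and paths.
use optimize::gradient_descent::line_search::LineSearch;
use optimize::gradient_descent::steepest_descent::steepest_descent;
use optimize::objective_function::Fun;

fn main() {
    // Convex quadratic with minimum f = 0 at (1.0, -2.0).
    let fun = Box::new(|x: &Vec<f64>| (x[0] - 1.0).powi(2) + (x[1] + 2.0).powi(2));
    // Analytic gradient of the quadratic above.
    let prime = Box::new(|x: &Vec<f64>| vec![2.0 * (x[0] - 1.0), 2.0 * (x[1] + 2.0)]);
    let obj = Fun::new(fun, prime);

    // Backtracking starts at t = 1 and multiplies by gamma until the
    // sufficient-decrease condition holds, so gamma near 1.0 shrinks slowly.
    let line_search = LineSearch::BackTrack {
        max_iterations: 100,
        gamma: 0.9,
        c: 0.3,
    };
    let res = steepest_descent(&obj, &vec![10.0, -10.0], 1000, 1e-12, &line_search, -1.0);
    println!(
        "best f = {} at {:?} after {} iterations",
        res.best_fun_val, res.best_xs, res.iters
    );
}

A note on the constants: gamma sets how aggressively the step shrinks per backtracking iteration and c sets how much decrease counts as sufficient, so the values used in the tests (gamma = 0.9 with c = 0.3 or 0.01) spend extra function evaluations to land on a step length close to the largest acceptable one.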