From 42e9b748dd782f6d8e239fc51066dfcac68da80a Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Sat, 4 Jan 2025 22:21:13 -0500
Subject: [PATCH] Compiling version of backtracking line search, needs debugging

---
 src/gradient_descent/line_search.rs      | 43 ++++++++++++++++++++++--
 src/gradient_descent/steepest_descent.rs | 33 +++++++++++++++-------
 src/traits.rs                            | 38 ++++++++++++++++++++-
 3 files changed, 98 insertions(+), 16 deletions(-)

diff --git a/src/gradient_descent/line_search.rs b/src/gradient_descent/line_search.rs
index 9762fbc..8eff048 100644
--- a/src/gradient_descent/line_search.rs
+++ b/src/gradient_descent/line_search.rs
@@ -1,12 +1,49 @@
+use crate::{objective_function::ObjectiveFun, traits::XVar};
+
 pub enum LineSearch {
-    ConstAlpha { learning_rate: f64 },
+    ConstAlpha {
+        learning_rate: f64,
+    },
+    BackTrack {
+        max_iterations: usize,
+        gamma: f64,
+        c: f64,
+    },
 }
 
 impl LineSearch {
-    pub fn get_learning_rate(&self) -> f64 {
+    pub fn get_learning_rate<T, E>(
+        &self,
+        fun: &dyn ObjectiveFun<T, E>,
+        xs: &T,
+        direction: &E,
+    ) -> f64
+    where
+        T: XVar<E> + Clone,
+        E:,
+    {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,
+            LineSearch::BackTrack {
+                max_iterations,
+                gamma,
+                c,
+            } => {
+                let prime = fun.prime(xs);
+                let mut del_f = T::scale_prime(&prime, *c);
+                let mut new_f = fun.eval(&xs.update(-1.0, &prime));
+                let mut t = 1.0;
+                for _ in 0..*max_iterations {
+                    if new_f > T::prime_inner_product(&T::scale_prime(&del_f, t), direction) {
+                        break;
+                    }
+                    t *= gamma;
+                    let new_x = xs.update(-t, &prime);
+                    new_f = fun.eval(&new_x);
+                    del_f = fun.prime(&new_x);
+                }
+                t
+            }
         }
     }
 }
-
diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs
index 1447aab..d3ca127 100644
--- a/src/gradient_descent/steepest_descent.rs
+++ b/src/gradient_descent/steepest_descent.rs
@@ -23,7 +23,7 @@ pub fn steepest_descent<T: XVar<E> + Clone, E>(
     let mut i = 0;
     for _ in 0..max_iters {
         let primes = fun.prime(&xs);
-        let learning_rate = line_search.get_learning_rate(fun, &xs);
+        let learning_rate = line_search.get_learning_rate(fun, &xs, &T::scale_prime(&primes, -1.0));
         xs = xs.update(direction * learning_rate, &primes);
         f = fun.eval(&xs);
 
@@ -60,18 +60,27 @@ mod test {
         let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
         let obj = Fun::new(fun, prime);
 
-        let line_search = LineSearch::ConstAlpha {
-            learning_rate: 0.25,
-        };
-        let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
+        let line_searches = vec![
+            LineSearch::ConstAlpha {
+                learning_rate: 0.25,
+            },
+            LineSearch::BackTrack {
+                max_iterations: 100,
+                gamma: 0.5,
+                c: 0.1,
+            },
+        ];
+        for line_search in line_searches {
+            let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
 
-        if let ExitCondition::MaxIter = res.exit_con {
-            panic!("Failed to converge to minima");
+            if let ExitCondition::MaxIter = res.exit_con {
+                panic!("Failed to converge to minima");
+            }
+            println!(
+                "{:?} on iteration {}\n{}",
+                res.best_xs, res.iters, res.best_fun_val
+            );
+            assert!(res.best_fun_val < 1e-8);
         }
-        println!(
-            "{:?} on iteration {}\n{}",
-            res.best_xs, res.iters, res.best_fun_val
-        );
-        assert!(res.best_fun_val < 1e-8);
     }
 }
diff --git a/src/traits.rs b/src/traits.rs
index fbd7616..b2fd3b8 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -1,9 +1,18 @@
 /// Trait defining the data structure that must be implemented for the independent variables used
 /// in the objective function. The generic type denotes the type of the prime of that variable
-///
+/// NOTE: This trait also defines some functions that are required to operate on the prime data
+/// type. It should be noted that we are unable to just require T to implement Mul or Add
+/// because then we wouldn't be able to implement XVar for plain Vec types, which seems
+/// inconvenient
 pub trait XVar<T>: Clone {
     /// Update the current X variable based on the prime
     fn update(&self, alpha: f64, prime: &T) -> Self;
+    /// Multiply the prime by a float
+    fn scale_prime(prime: &T, rhs: f64) -> T;
+    /// Add a float to the prime
+    fn add_prime(prime: &T, rhs: f64) -> T;
+    /// Inner product of two primes
+    fn prime_inner_product(prime: &T, rhs: &T) -> f64;
 }
 
 /// Implementation of XVar for an f64 type
@@ -11,6 +20,18 @@ impl XVar<f64> for f64 {
     fn update(&self, alpha: f64, prime: &f64) -> Self {
         self + alpha * prime
     }
+
+    fn scale_prime(prime: &f64, rhs: f64) -> f64 {
+        prime * rhs
+    }
+
+    fn add_prime(prime: &f64, rhs: f64) -> f64 {
+        prime + rhs
+    }
+
+    fn prime_inner_product(prime: &f64, rhs: &f64) -> f64 {
+        prime * rhs
+    }
 }
 
 /// Implementation of XVar for a Vec<f64> type
@@ -21,4 +42,19 @@ impl XVar<Vec<f64>> for Vec<f64> {
             .map(|(x, xprime)| x + alpha * xprime)
             .collect()
     }
+
+    fn scale_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
+        prime.iter().map(|val| val * rhs).collect()
+    }
+
+    fn add_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
+        prime.iter().map(|val| val + rhs).collect()
+    }
+
+    fn prime_inner_product(prime: &Vec<f64>, rhs: &Vec<f64>) -> f64 {
+        prime
+            .iter()
+            .zip(rhs.iter())
+            .fold(0.0, |acc, a| acc + a.0 * a.1)
+    }
 }
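Note on the BackTrack arm: the subject line already flags it as needing debugging, and the sufficient-decrease test looks inverted. It breaks out of the loop when the new function value exceeds the threshold, and the threshold omits the current objective value f(x). Standard Armijo backtracking accepts a trial step t once f(x + t*d) <= f(x) + c*t*<grad f(x), d> and otherwise shrinks t by gamma, keeping f(x) and grad f(x) fixed at the current iterate. The sketch below is illustrative rather than repository code: it assumes the ObjectiveFun<T, E> and XVar<E> traits from this patch, and the free function name armijo_backtrack is made up for the example.

    /// Sketch of standard Armijo backtracking written against this patch's
    /// ObjectiveFun<T, E> and XVar<E> traits (illustrative, not repository code).
    /// `direction` is the descent direction d (for steepest descent, the negated
    /// gradient), `gamma` in (0, 1) is the shrink factor, and `c` in (0, 1) is
    /// the sufficient-decrease constant (commonly 1e-4).
    fn armijo_backtrack<T, E>(
        fun: &dyn ObjectiveFun<T, E>,
        xs: &T,
        direction: &E,
        max_iterations: usize,
        gamma: f64,
        c: f64,
    ) -> f64
    where
        T: XVar<E> + Clone,
    {
        let f0 = fun.eval(xs); // f(x), fixed for the whole search
        // <grad f(x), d>; negative whenever d is a descent direction
        let slope = T::prime_inner_product(&fun.prime(xs), direction);
        let mut t = 1.0;
        for _ in 0..max_iterations {
            // Accept t once f(x + t*d) <= f(x) + c*t*<grad f(x), d>.
            if fun.eval(&xs.update(t, direction)) <= f0 + c * t * slope {
                break;
            }
            t *= gamma; // insufficient decrease: shrink the step and retry
        }
        t
    }

Compared with the loop in the patch, f(x) appears on the right-hand side of the test, the comparison is flipped (shrink while the decrease is insufficient, accept once it holds), and del_f need not be recomputed at each trial point, since the Armijo threshold is always measured at the current iterate.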