From fdc3ffd1645506d1245787520678b80bdc2304f7 Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Fri, 17 Jan 2025 23:53:43 -0500
Subject: [PATCH] Add initial conjugate gradient. WIP

---
 src/gradient_descent/conjugate_gradient.rs | 197 +++++++++++++++++++++
 src/gradient_descent/line_search.rs        |  16 +-
 src/gradient_descent/mod.rs                |   2 +
 src/gradient_descent/steepest_descent.rs   |  12 +-
 src/minimize.rs                            |   1 +
 src/traits.rs                              |   2 +-
 6 files changed, 207 insertions(+), 23 deletions(-)
 create mode 100644 src/gradient_descent/conjugate_gradient.rs

diff --git a/src/gradient_descent/conjugate_gradient.rs b/src/gradient_descent/conjugate_gradient.rs
new file mode 100644
index 0000000..c814fca
--- /dev/null
+++ b/src/gradient_descent/conjugate_gradient.rs
@@ -0,0 +1,197 @@
+use crate::{
+    minimize::{ExitCondition, OptimizationResult},
+    objective_function::ObjectiveFun,
+    traits::XVar,
+};
+use std::fmt::Debug;
+
+use super::line_search::LineSearch;
+
+/// Trait that should be implemented by the Prime type for conjugate gradient
+pub trait ConjGradPrime: Clone + Debug {
+    /// Multiply primes by each other
+    fn mul(&self, rhs: &Self) -> Self;
+    /// Subtract primes from each other
+    fn sub(&self, rhs: &Self) -> Self;
+    /// Add primes to each other
+    fn add(&self, rhs: &Self) -> Self;
+    /// Divide prime by another prime (numerator/denominator)
+    fn div(&self, denominator: &Self) -> Self;
+    /// Max between the prime and a float
+    fn max(&self, rhs: f64) -> Self;
+}
+
+impl ConjGradPrime for f64 {
+    fn mul(&self, rhs: &f64) -> f64 {
+        self * rhs
+    }
+
+    fn sub(&self, rhs: &f64) -> f64 {
+        self - rhs
+    }
+
+    fn div(&self, denominator: &f64) -> f64 {
+        self / denominator
+    }
+
+    fn max(&self, rhs: f64) -> Self {
+        f64::max(*self, rhs)
+    }
+
+    fn add(&self, rhs: &Self) -> Self {
+        self + rhs
+    }
+}
+
+impl ConjGradPrime for Vec<f64> {
+    fn mul(&self, rhs: &Vec<f64>) -> Vec<f64> {
+        self.iter()
+            .zip(rhs.iter())
+            .map(|(lhs, rhs)| lhs * rhs)
+            .collect()
+    }
+
+    fn sub(&self, rhs: &Vec<f64>) -> Vec<f64> {
+        self.iter()
+            .zip(rhs.iter())
+            .map(|(lhs, rhs)| lhs - rhs)
+            .collect()
+    }
+
+    fn div(&self, denominator: &Vec<f64>) -> Vec<f64> {
+        self.iter()
+            .zip(denominator.iter())
+            .map(|(num, denom)| num / denom)
+            .collect()
+    }
+
+    fn max(&self, rhs: f64) -> Self {
+        self.iter().map(|val| val.max(rhs)).collect()
+    }
+
+    fn add(&self, rhs: &Self) -> Self {
+        self.iter()
+            .zip(rhs.iter())
+            .map(|(lhs, rhs)| lhs + rhs)
+            .collect()
+    }
+}
+
+pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
+    fun: &dyn ObjectiveFun<T, E>,
+    x0: &T,
+    max_iters: usize,
+    tolerance: f64,
+    line_search: &LineSearch,
+) -> OptimizationResult<T> {
+    // Make a mutable copy of x0 to work with
+    let mut xs = x0.clone();
+
+    // Perform the iteration
+    let mut f_iminus1 = f64::INFINITY;
+    let mut f = 0.0;
+    let mut i = 0;
+    let mut prev_residual = fun.prime(&xs);
+    let mut direction = T::scale_prime(&prev_residual, -1.0);
+    for _ in 0..max_iters {
+        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
+        xs = xs.update(learning_rate, &direction);
+        // Check for convergence
+        f = fun.eval(&xs);
+        if (f - f_iminus1).abs() < tolerance {
+            println!("{f} {f_iminus1}");
+            break;
+        } else {
+            f_iminus1 = f;
+        }
+
+        // Update the search direction using the Polak-Ribiere rule
+        let new_residual = fun.prime(&xs);
+        let beta = new_residual
+            .mul(&new_residual.sub(&prev_residual))
+            .div(&prev_residual.mul(&prev_residual));
+        let beta = beta.max(0.0);
+        direction = T::scale_prime(&new_residual, -1.0).add(&beta.mul(&direction));
+        prev_residual = new_residual;
+        i += 1;
+    }
+
+    let exit_con = if i == max_iters {
+        ExitCondition::MaxIter
+    } else {
+        ExitCondition::Converged
+    };
+    OptimizationResult {
+        best_xs: xs,
+        best_fun_val: f,
+        exit_con,
+        iters: i,
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::objective_function::Fun;
+
+    use super::*;
+
+    #[test]
+    pub fn simple_conjugate_gradient_test() {
+        let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
+        let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
+
+        let obj = Fun::new(fun, prime);
+        let line_searches = vec![LineSearch::BackTrack {
+            gamma: 0.9,
+            c: 0.01,
+        }];
+        for line_search in line_searches {
+            let res = conjugate_gradient(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);
+
+            println!(
+                "Best val is {:?} for xs {:?}",
+                res.best_fun_val, res.best_xs
+            );
+
+            if let ExitCondition::MaxIter = res.exit_con {
+                panic!("Failed to converge to minima");
+            }
+            println!(
+                "{:?} on iteration {} has value:\n{}",
+                res.best_xs, res.iters, res.best_fun_val
+            );
+            assert!(res.best_fun_val < 1e-8);
+        }
+    }
+
+    #[test]
+    pub fn basic_beale_test() {
+        let fun = Box::new(|x: &Vec<f64>| {
+            (1.5 - x[0] + x[0] * x[1]).powi(2)
+                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
+                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
+        });
+        let prime = Box::new(|x: &Vec<f64>| {
+            vec![
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
+                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
+                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
+                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
+            ]
+        });
+        let obj = Fun::new(fun, prime);
+        let line_search = LineSearch::BackTrack {
+            gamma: 0.9,
+            c: 0.01,
+        };
+        let res = conjugate_gradient(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
+        println!(
+            "Best val is {:?} for xs {:?}",
+            res.best_fun_val, res.best_xs
+        );
+        println!("Exit condition is: {:?}", res.exit_con);
+        assert!(res.best_fun_val < 1e-7);
+    }
+}
diff --git a/src/gradient_descent/line_search.rs b/src/gradient_descent/line_search.rs
index fb3d287..9087ffb 100644
--- a/src/gradient_descent/line_search.rs
+++ b/src/gradient_descent/line_search.rs
@@ -3,14 +3,8 @@ use core::fmt;
 use crate::{objective_function::ObjectiveFun, traits::XVar};
 
 pub enum LineSearch {
-    ConstAlpha {
-        learning_rate: f64,
-    },
-    BackTrack {
-        max_iterations: usize,
-        gamma: f64,
-        c: f64,
-    },
+    ConstAlpha { learning_rate: f64 },
+    BackTrack { gamma: f64, c: f64 },
 }
 
 impl LineSearch {
@@ -26,11 +20,7 @@ impl LineSearch {
     {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,
-            LineSearch::BackTrack {
-                max_iterations,
-                gamma,
-                c,
-            } => {
+            LineSearch::BackTrack { gamma, c } => {
                 let prime = fun.prime(xs);
                 let fk = fun.eval(xs);
                 let mut new_f = fun.eval(&xs.update(1.0, &prime));
diff --git a/src/gradient_descent/mod.rs b/src/gradient_descent/mod.rs
index b5d5dc0..1b59051 100644
--- a/src/gradient_descent/mod.rs
+++ b/src/gradient_descent/mod.rs
@@ -1,2 +1,4 @@
+pub mod base;
+pub mod conjugate_gradient;
 pub mod line_search;
 pub mod steepest_descent;
diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs
index d3cefa2..9ef63f9 100644
--- a/src/gradient_descent/steepest_descent.rs
+++ b/src/gradient_descent/steepest_descent.rs
@@ -12,7 +12,6 @@ pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
     max_iters: usize,
     tolerance: f64,
     line_search: &LineSearch,
-    direction: f64,
 ) -> OptimizationResult<T> {
     // Make a mutable copy of x0 to work with
     let mut xs = x0.clone();
@@ -63,14 +62,10 @@ mod test {
             LineSearch::ConstAlpha {
                 learning_rate: 0.25,
             },
-            LineSearch::BackTrack {
-                max_iterations: 100,
-                gamma: 0.9,
-                c: 0.3,
-            },
+            LineSearch::BackTrack { gamma: 0.9, c: 0.3 },
         ];
         for line_search in line_searches {
-            let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search, -1.0);
+            let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);
 
             if let ExitCondition::MaxIter = res.exit_con {
                 panic!("Failed to converge to minima");
@@ -102,11 +97,10 @@ mod test {
         });
         let obj = Fun::new(fun, prime);
         let line_search = LineSearch::BackTrack {
-            max_iterations: 1000,
             gamma: 0.9,
             c: 0.01,
         };
-        let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search, -1.0);
+        let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
         println!(
             "Best val is {:?} for xs {:?}",
             res.best_fun_val, res.best_xs
diff --git a/src/minimize.rs b/src/minimize.rs
index 9772add..848dfe9 100644
--- a/src/minimize.rs
+++ b/src/minimize.rs
@@ -1,4 +1,5 @@
 /// Result Enum dictating the exit condition for an optimization call
+#[derive(Debug)]
 pub enum ExitCondition {
     /// Optimization has converged to user specified tolerance
     Converged,
diff --git a/src/traits.rs b/src/traits.rs
index 8f73485..6631779 100644
--- a/src/traits.rs
+++ b/src/traits.rs
@@ -13,7 +13,7 @@ pub trait XVar<T>: Clone + Debug {
     fn scale_prime(prime: &T, rhs: f64) -> T;
     /// Add a float to the prime
    fn add_prime(prime: &T, rhs: f64) -> T;
-    /// Inner Produce
+    /// Inner Product of prime
     fn prime_inner_product(prime: &T, rhs: &T) -> f64;
 }
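Note on the update rule used in conjugate_gradient.rs: the Polak-Ribiere step computes
beta_k = max(0, g_{k+1}^T (g_{k+1} - g_k) / (g_k^T g_k)) and then sets
d_{k+1} = -g_{k+1} + beta_k * d_k, falling back to steepest descent whenever beta_k would
go negative. The sketch below is illustrative only: it shows one such update for plain f64
slices using ordinary inner products, whereas the patch applies the update element-wise
through the ConjGradPrime trait; the helper names dot and pr_plus_update are not part of
this crate.

// Illustrative sketch: one Polak-Ribiere+ direction update on f64 slices.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// Returns (beta, new_direction) for gradients g_new, g_old and previous direction d_old.
fn pr_plus_update(g_new: &[f64], g_old: &[f64], d_old: &[f64]) -> (f64, Vec<f64>) {
    // Polak-Ribiere beta with the "+" restart: clamp negative values to zero.
    let beta = (dot(g_new, g_new) - dot(g_new, g_old)) / dot(g_old, g_old);
    let beta = beta.max(0.0);
    // d_{k+1} = -g_{k+1} + beta * d_k
    let d_new = g_new
        .iter()
        .zip(d_old)
        .map(|(&g, &d)| -g + beta * d)
        .collect();
    (beta, d_new)
}

fn main() {
    // Gradients of f(x) = x1^2 + x2^2 at two consecutive iterates, as in the simple test.
    let g_old = [40.0, 40.0];
    let g_new = [4.0, 4.0];
    let d_old = [-40.0, -40.0];
    let (beta, d_new) = pr_plus_update(&g_new, &g_old, &d_old);
    println!("beta = {beta}, new direction = {d_new:?}");
}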