From f9bc5adb71cadb656e7d7f9fe0d22a9ece1a92a9 Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Thu, 12 Dec 2024 23:03:31 -0500
Subject: [PATCH] Setup code architecture, some basic traits and types, and
 steepest descent algorithm

---
 src/gradient_descent/line_search.rs      | 12 ++++
 src/gradient_descent/mod.rs              |  2 +
 src/gradient_descent/steepest_descent.rs | 79 ++++++++++++++++++++++
 src/heuristics/mod.rs                    |  0
 src/lib.rs                               |  5 ++
 src/minimize.rs                          | 28 ++++++++
 src/objective_function.rs                | 84 ++++++++++++++++++++++++
 src/traits.rs                            | 10 +++
 8 files changed, 220 insertions(+)
 create mode 100644 src/gradient_descent/line_search.rs
 create mode 100644 src/gradient_descent/mod.rs
 create mode 100644 src/gradient_descent/steepest_descent.rs
 create mode 100644 src/heuristics/mod.rs
 create mode 100644 src/lib.rs
 create mode 100644 src/minimize.rs
 create mode 100644 src/objective_function.rs
 create mode 100644 src/traits.rs

diff --git a/src/gradient_descent/line_search.rs b/src/gradient_descent/line_search.rs
new file mode 100644
index 0000000..9762fbc
--- /dev/null
+++ b/src/gradient_descent/line_search.rs
@@ -0,0 +1,12 @@
+pub enum LineSearch {
+    ConstAlpha { learning_rate: f64 },
+}
+
+impl LineSearch {
+    pub fn get_learning_rate(&self) -> f64 {
+        match self {
+            LineSearch::ConstAlpha { learning_rate } => *learning_rate,
+        }
+    }
+}
+
diff --git a/src/gradient_descent/mod.rs b/src/gradient_descent/mod.rs
new file mode 100644
index 0000000..b5d5dc0
--- /dev/null
+++ b/src/gradient_descent/mod.rs
@@ -0,0 +1,2 @@
+pub mod line_search;
+pub mod steepest_descent;
diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs
new file mode 100644
index 0000000..3d37471
--- /dev/null
+++ b/src/gradient_descent/steepest_descent.rs
@@ -0,0 +1,79 @@
+use crate::{
+    minimize::{Direction, ExitCondition, OptimizationResult},
+    objective_function::ObjectiveFun,
+    traits::XVar,
+};
+
+use super::line_search::LineSearch;
+
+pub fn steepest_descent<T: XVar<E> + Clone, E>(
+    fun: &dyn ObjectiveFun<T, E>,
+    x0: &[T],
+    max_iters: usize,
+    tolerance: f64,
+    line_search: &LineSearch,
+    direction: f64,
+) -> OptimizationResult<T> {
+    // Make a mutable copy of x0 to work with
+    let mut xs = Vec::new();
+    xs.extend_from_slice(x0);
+
+    // Perform the iteration
+    let mut f_iminus1 = f64::INFINITY;
+    let mut f = 0.0;
+    let mut i = 0;
+    for _ in 0..max_iters {
+        let primes = fun.prime(&xs);
+        xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| {
+            *x = x.update(direction * line_search.get_learning_rate(), prime)
+        });
+        f = fun.eval(&xs);
+
+        if (f - f_iminus1).abs() < tolerance {
+            break;
+        } else {
+            f_iminus1 = f;
+        }
+        i += 1;
+    }
+
+    let exit_con = if i == max_iters {
+        ExitCondition::MaxIter
+    } else {
+        ExitCondition::Converged
+    };
+    OptimizationResult {
+        best_xs: xs,
+        best_fun_val: f,
+        exit_con,
+        iters: i,
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use crate::objective_function::Fun;
+
+    use super::*;
+
+    #[test]
+    pub fn simple_steepest_descent_test() {
+        let fun = Box::new(|xs: &[f64]| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
+        let prime = Box::new(|xs: &[f64]| xs.iter().copied().collect::<Vec<f64>>());
+
+        let obj = Fun::new(fun, prime);
+        let line_search = LineSearch::ConstAlpha {
+            learning_rate: 0.25,
+        };
+        let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0);
+
+        if let ExitCondition::MaxIter = res.exit_con {
+            panic!("Failed to converge to minimum");
+        }
+        println!(
+            "{:?} on iteration {}\n{}",
+            res.best_xs, res.iters, res.best_fun_val
+        );
+        assert!(res.best_fun_val < 1e-8);
+    }
+}
diff --git a/src/heuristics/mod.rs b/src/heuristics/mod.rs
new file mode 100644
index 0000000..e69de29
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..40ddc0b
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,5 @@
+pub mod gradient_descent;
+pub mod heuristics;
+pub mod minimize;
+pub mod objective_function;
+pub mod traits;
diff --git a/src/minimize.rs b/src/minimize.rs
new file mode 100644
index 0000000..09d6143
--- /dev/null
+++ b/src/minimize.rs
@@ -0,0 +1,28 @@
+/// Enum dictating the exit condition for an optimization call
+pub enum ExitCondition {
+    /// Optimization has converged to user specified tolerance
+    Converged,
+    /// Optimization has exceeded user specified max iteration count
+    MaxIter,
+}
+
+pub enum Direction {
+    Minimize,
+    Maximize,
+}
+
+impl Direction {
+    pub fn factor(&self) -> f64 {
+        match self {
+            Direction::Minimize => -1.0,
+            Direction::Maximize => 1.0,
+        }
+    }
+}
+/// Struct holding the results for a minimization call
+pub struct OptimizationResult<T> {
+    pub best_xs: Vec<T>,
+    pub best_fun_val: f64,
+    pub exit_con: ExitCondition,
+    pub iters: usize,
+}
diff --git a/src/objective_function.rs b/src/objective_function.rs
new file mode 100644
index 0000000..2b10350
--- /dev/null
+++ b/src/objective_function.rs
@@ -0,0 +1,84 @@
+use crate::traits::XVar;
+
+/// Trait that should be implemented for objects that will be minimized
+pub trait ObjectiveFun<T: XVar<E> + Clone, E> {
+    /// Return the objective function value at a specified coordinate
+    fn eval(&self, xs: &[T]) -> f64;
+    /// Return the gradients of the objective function value for specified coordinates
+    fn prime(&self, xs: &[T]) -> Vec<E>;
+}
+
+/// Enum allowing for selection of style of numerical differentiation
+pub enum DiffStyle {
+    ForwardDifference,
+    BackwardDifference,
+    CentralDifference,
+}
+
+/// Struct that wraps a lambda and provides a numerical derivative for it for use in gradient
+/// descent algorithms
+pub struct FunWithNumericalDiff {
+    function: Box<dyn Fn(&[f64]) -> f64>,
+    dx: f64,
+    style: DiffStyle,
+}
+
+impl ObjectiveFun<f64, f64> for FunWithNumericalDiff {
+    fn eval(&self, xs: &[f64]) -> f64 {
+        (self.function)(xs)
+    }
+
+    fn prime(&self, xs: &[f64]) -> Vec<f64> {
+        let mut xs_local = Vec::new();
+        xs_local.extend_from_slice(xs);
+        let f: Box<dyn FnMut((usize, &f64)) -> f64 + '_> = match self.style {
+            DiffStyle::ForwardDifference => Box::new(move |(i, x)| -> f64 {
+                xs_local[i] = x + self.dx;
+                let xprime = ((self.function)(&xs_local) - (self.function)(xs)) / (self.dx);
+                xs_local[i] = *x;
+                xprime
+            }),
+            DiffStyle::BackwardDifference => Box::new(move |(i, x)| -> f64 {
+                xs_local[i] = x - self.dx;
+                let xprime = ((self.function)(xs) - (self.function)(&xs_local)) / (self.dx);
+                xs_local[i] = *x;
+                xprime
+            }),
+            DiffStyle::CentralDifference => Box::new(move |(i, x)| -> f64 {
+                xs_local[i] = x - (0.5 * self.dx);
+                let f1 = (self.function)(&xs_local);
+                xs_local[i] = x + (0.5 * self.dx);
+                let f2 = (self.function)(&xs_local);
+                xs_local[i] = *x;
+                (f2 - f1) / self.dx
+            }),
+        };
+        xs.iter().enumerate().map(f).collect()
+    }
+}
+
+/// Struct that wraps two lambdas, with one providing the objective function evaluation and the
+/// other providing the gradient value
+pub struct Fun<T: XVar<E>, E> {
+    function: Box<dyn Fn(&[T]) -> f64>,
+    prime: Box<dyn Fn(&[T]) -> Vec<E>>,
+}
+
+// Simple type to remove the generics
+pub type F64Fun = Fun<f64, f64>;
+
+impl<T: XVar<E>, E> ObjectiveFun<T, E> for Fun<T, E> {
+    fn eval(&self, xs: &[T]) -> f64 {
+        (self.function)(xs)
+    }
+
+    fn prime(&self, xs: &[T]) -> Vec<E> {
+        (self.prime)(xs)
+    }
+}
+
+impl<T: XVar<E>, E> Fun<T, E> {
+    pub fn new(function: Box<dyn Fn(&[T]) -> f64>, prime: Box<dyn Fn(&[T]) -> Vec<E>>) -> Self {
+        Fun { function, prime }
+    }
+}
diff --git a/src/traits.rs b/src/traits.rs
new file mode 100644
index 0000000..6b3d3f2
--- /dev/null
+++ b/src/traits.rs
@@ -0,0 +1,10 @@
+pub trait XVar<E>: Clone {
+    fn update(&self, alpha: f64, prime: &E) -> Self;
+}
+
+/// Implementation of XVar for an f64 type
+impl XVar<f64> for f64 {
+    fn update(&self, alpha: f64, prime: &f64) -> Self {
+        self + alpha * prime
+    }
+}
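
For orientation, the sketch below shows how the pieces introduced by this patch are meant to compose, mirroring the unit test in steepest_descent.rs: an objective and its analytic gradient are wrapped with Fun::new, a constant-step LineSearch is chosen, and steepest_descent drives plain f64 variables through the XVar<f64> impl. Paths are written as used from inside the crate, since the crate's published name does not appear in the patch; the final argument is the sign factor, here taken from Direction::Minimize.factor() (-1.0) to step against the gradient.

use crate::gradient_descent::line_search::LineSearch;
use crate::gradient_descent::steepest_descent::steepest_descent;
use crate::minimize::{Direction, ExitCondition};
use crate::objective_function::Fun;

fn minimize_quadratic() {
    // f(x) = sum(x_i^2) with its analytic gradient 2 * x_i
    let fun = Box::new(|xs: &[f64]| xs.iter().map(|x| x.powi(2)).sum::<f64>());
    let prime = Box::new(|xs: &[f64]| xs.iter().map(|x| 2.0 * x).collect::<Vec<f64>>());
    let obj = Fun::new(fun, prime);

    // Fixed step size; Direction::Minimize.factor() == -1.0 steps downhill
    let line_search = LineSearch::ConstAlpha { learning_rate: 0.1 };
    let res = steepest_descent(
        &obj,
        &[3.0, -4.0],
        1000,
        1e-12,
        &line_search,
        Direction::Minimize.factor(),
    );

    match res.exit_con {
        ExitCondition::Converged => println!("best x = {:?}, f = {}", res.best_xs, res.best_fun_val),
        ExitCondition::MaxIter => println!("stopped at the iteration cap"),
    }
}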
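
The XVar trait is what lets the same solver drive state types other than f64. As a hedged sketch (the Point2 type below is hypothetical and not part of this patch), a custom variable only has to describe how it moves under a scaled gradient step; an ObjectiveFun<Point2, (f64, f64)> built analogously to Fun would then plug directly into steepest_descent.

use crate::traits::XVar;

// Hypothetical 2-D point type, used only to illustrate the trait; not in this patch.
#[derive(Clone, Debug)]
struct Point2 {
    x: f64,
    y: f64,
}

// The gradient for Point2 is carried as a (df/dx, df/dy) pair.
impl XVar<(f64, f64)> for Point2 {
    fn update(&self, alpha: f64, prime: &(f64, f64)) -> Self {
        Point2 {
            x: self.x + alpha * prime.0,
            y: self.y + alpha * prime.1,
        }
    }
}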