diff --git a/src/gradient_descent/steepest_descent.rs b/src/gradient_descent/steepest_descent.rs index 3d37471..1447aab 100644 --- a/src/gradient_descent/steepest_descent.rs +++ b/src/gradient_descent/steepest_descent.rs @@ -1,5 +1,5 @@ use crate::{ - minimize::{Direction, ExitCondition, OptimizationResult}, + minimize::{ExitCondition, OptimizationResult}, objective_function::ObjectiveFun, traits::XVar, }; @@ -8,15 +8,14 @@ use super::line_search::LineSearch; pub fn steepest_descent + Clone, E>( fun: &dyn ObjectiveFun, - x0: &[T], + x0: &T, max_iters: usize, tolerance: f64, line_search: &LineSearch, direction: f64, ) -> OptimizationResult { // Make a mutable copy of x0 to work with - let mut xs = Vec::new(); - xs.extend_from_slice(x0); + let mut xs = x0.clone(); // Perform the iteration let mut f_iminus1 = f64::INFINITY; @@ -24,9 +23,8 @@ pub fn steepest_descent + Clone, E>( let mut i = 0; for _ in 0..max_iters { let primes = fun.prime(&xs); - xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| { - *x = x.update(direction * line_search.get_learning_rate(), prime) - }); + let learning_rate = line_search.get_learning_rate(fun, &xs); + xs = xs.update(direction * learning_rate, &primes); f = fun.eval(&xs); if (f - f_iminus1).abs() < tolerance { @@ -58,14 +56,14 @@ mod test { #[test] pub fn simple_steepest_descent_test() { - let fun = Box::new(|xs: &[f64]| xs.iter().fold(0.0, |acc, x| acc + x.powi(2))); - let prime = Box::new(|xs: &[f64]| xs.iter().copied().collect::>()); + let fun = Box::new(|xs: &Vec| xs.iter().fold(0.0, |acc, x| acc + x.powi(2))); + let prime = Box::new(|xs: &Vec| xs.iter().map(|x| 2.0 * x).collect()); let obj = Fun::new(fun, prime); let line_search = LineSearch::ConstAlpha { learning_rate: 0.25, }; - let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0); + let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0); if let ExitCondition::MaxIter = res.exit_con { panic!("Failed to converge to minima"); diff --git a/src/minimize.rs b/src/minimize.rs index 09d6143..9772add 100644 --- a/src/minimize.rs +++ b/src/minimize.rs @@ -21,7 +21,7 @@ impl Direction { } /// Struct holding the results for a minimization call pub struct OptimizationResult { - pub best_xs: Vec, + pub best_xs: T, pub best_fun_val: f64, pub exit_con: ExitCondition, pub iters: usize, diff --git a/src/objective_function.rs b/src/objective_function.rs index 2b10350..1ae411e 100644 --- a/src/objective_function.rs +++ b/src/objective_function.rs @@ -3,9 +3,9 @@ use crate::traits::XVar; /// Trait that should be implemented for objects that will be minimzed pub trait ObjectiveFun + Clone, E> { /// Return the objective function value at a specified coordinate - fn eval(&self, xs: &[T]) -> f64; + fn eval(&self, xs: &T) -> f64; /// Return the gradients of the objective function value for specified coordinates - fn prime(&self, xs: &[T]) -> Vec; + fn prime(&self, xs: &T) -> E; } /// Enum allowing for selection of style of numerical differentiation @@ -23,12 +23,12 @@ pub struct FunWithNumericalDiff { style: DiffStyle, } -impl ObjectiveFun for FunWithNumericalDiff { - fn eval(&self, xs: &[f64]) -> f64 { +impl ObjectiveFun, Vec> for FunWithNumericalDiff { + fn eval(&self, xs: &Vec) -> f64 { (self.function)(xs) } - fn prime(&self, xs: &[f64]) -> Vec { + fn prime(&self, xs: &Vec) -> Vec { let mut xs_local = Vec::new(); xs_local.extend_from_slice(xs); let f: Box f64> = match self.style { @@ -60,25 +60,25 @@ impl ObjectiveFun for FunWithNumericalDiff { /// Struct that wraps two lambda with one providing the objective function evaluation and the other /// providing the gradient value pub struct Fun, E> { - function: Box f64>, - prime: Box Vec>, + function: Box f64>, + prime: Box E>, } // Simple type to remove the generics pub type F64Fun = Fun; impl, E> ObjectiveFun for Fun { - fn eval(&self, xs: &[T]) -> f64 { + fn eval(&self, xs: &T) -> f64 { (self.function)(xs) } - fn prime(&self, xs: &[T]) -> Vec { + fn prime(&self, xs: &T) -> E { (self.prime)(xs) } } impl, E> Fun { - pub fn new(function: Box f64>, prime: Box Vec>) -> Self { + pub fn new(function: Box f64>, prime: Box E>) -> Self { Fun { function, prime } } } diff --git a/src/traits.rs b/src/traits.rs index 6b3d3f2..fbd7616 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -1,5 +1,9 @@ -pub trait XVar: Clone { - fn update(&self, alpha: f64, prime: &E) -> Self; +/// Trait defining the data structure that must be implemented for the independent variables used +/// in the objective function. The generic type denotes the type of the prime of that variable +/// +pub trait XVar: Clone { + /// Update the current Xvariable based on the prime + fn update(&self, alpha: f64, prime: &T) -> Self; } /// Implementation of XVar for an f64 type @@ -8,3 +12,13 @@ impl XVar for f64 { self + alpha * prime } } + +/// Implementation of XVar for a Vec type +impl XVar> for Vec { + fn update(&self, alpha: f64, prime: &Vec) -> Self { + self.iter() + .zip(prime) + .map(|(x, xprime)| x + alpha * xprime) + .collect() + } +}