use crate::{ minimize::{Direction, ExitCondition, OptimizationResult}, objective_function::ObjectiveFun, traits::XVar, }; use super::line_search::LineSearch; pub fn steepest_descent + Clone, E>( fun: &dyn ObjectiveFun, x0: &[T], max_iters: usize, tolerance: f64, line_search: &LineSearch, direction: f64, ) -> OptimizationResult { // Make a mutable copy of x0 to work with let mut xs = Vec::new(); xs.extend_from_slice(x0); // Perform the iteration let mut f_iminus1 = f64::INFINITY; let mut f = 0.0; let mut i = 0; for _ in 0..max_iters { let primes = fun.prime(&xs); xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| { *x = x.update(direction * line_search.get_learning_rate(), prime) }); f = fun.eval(&xs); if (f - f_iminus1).abs() < tolerance { break; } else { f_iminus1 = f; } i += 1; } let exit_con = if i == max_iters { ExitCondition::MaxIter } else { ExitCondition::Converged }; OptimizationResult { best_xs: xs, best_fun_val: f, exit_con, iters: i, } } #[cfg(test)] mod test { use crate::objective_function::Fun; use super::*; #[test] pub fn simple_steepest_descent_test() { let fun = Box::new(|xs: &[f64]| xs.iter().fold(0.0, |acc, x| acc + x.powi(2))); let prime = Box::new(|xs: &[f64]| xs.iter().copied().collect::>()); let obj = Fun::new(fun, prime); let line_search = LineSearch::ConstAlpha { learning_rate: 0.25, }; let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0); if let ExitCondition::MaxIter = res.exit_con { panic!("Failed to converge to minima"); } println!( "{:?} on iteration {}\n{}", res.best_xs, res.iters, res.best_fun_val ); assert!(res.best_fun_val < 1e-8); } }