use core::fmt; use crate::{objective_function::ObjectiveFun, traits::XVar}; use super::steepest_descent::SteepestDescentPrime; pub enum LineSearch { ConstAlpha { learning_rate: f64 }, BackTrack { gamma: f64, c: f64 }, } impl LineSearch { pub fn get_learning_rate( &self, fun: &dyn ObjectiveFun, xs: &T, direction: &E, ) -> f64 where T: XVar + Clone, E: fmt::Debug + SteepestDescentPrime, { match self { LineSearch::ConstAlpha { learning_rate } => *learning_rate, LineSearch::BackTrack { gamma, c } => { let prime = fun.prime(xs); let fk = fun.eval(xs); let mut new_f = fun.eval(&xs.update(1.0, &prime)); let mut t = 1.0; while fk < new_f + t * c * prime.scale(-1.0).inner_product(direction) { t *= gamma; let new_x = xs.update(t, direction); new_f = fun.eval(&new_x); } t } } } }