//! Line-search strategies for choosing the step size in gradient descent.
use core::fmt;
use crate::{objective_function::ObjectiveFun, traits::XVar};
use super::steepest_descent::SteepestDescentPrime;
/// Strategy for selecting the step size (learning rate) used when moving
/// along a descent direction in a gradient-descent iteration.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LineSearch {
    /// Always take a fixed step of size `learning_rate`.
    ConstAlpha { learning_rate: f64 },
    /// Backtracking line search: repeatedly shrink the step by the factor
    /// `gamma` (expected in (0, 1)) until the Armijo sufficient-decrease
    /// condition with parameter `c` is satisfied.
    BackTrack { gamma: f64, c: f64 },
}
impl LineSearch {
    /// Returns the step size to take from `xs` along `direction`.
    ///
    /// * `ConstAlpha` simply returns the configured fixed learning rate.
    /// * `BackTrack` performs a backtracking line search: starting from a
    ///   unit step `t = 1.0`, the step is repeatedly multiplied by `gamma`
    ///   until the Armijo sufficient-decrease condition
    ///   `f(x + t*d) <= f(x) + c * t * <grad f(x), d>` holds.
    ///
    /// NOTE(review): `BackTrack` assumes `0 < gamma < 1`; with `gamma >= 1`
    /// the loop can fail to terminate when the condition is violated at
    /// `t = 1` — confirm callers validate this.
    pub fn get_learning_rate<T, E>(
        &self,
        fun: &dyn ObjectiveFun<T, E>,
        xs: &T,
        direction: &E,
    ) -> f64
    where
        T: XVar<E> + Clone,
        E: fmt::Debug + SteepestDescentPrime,
    {
        match self {
            LineSearch::ConstAlpha { learning_rate } => *learning_rate,
            LineSearch::BackTrack { gamma, c } => {
                let prime = fun.prime(xs);
                let fk = fun.eval(xs);
                // Loop-invariant part of the Armijo test,
                // -c * <grad f(x), d>; hoisted so the directional
                // derivative is computed once rather than per iteration.
                let decrease = c * prime.scale(-1.0).inner_product(direction);
                let mut t = 1.0;
                // Bug fix: the initial trial point must step along
                // `direction` (as every later iteration does), not along
                // the raw gradient `prime`.
                let mut new_f = fun.eval(&xs.update(t, direction));
                // Shrink t until f(x + t*d) satisfies sufficient decrease.
                while fk < new_f + t * decrease {
                    t *= gamma;
                    new_f = fun.eval(&xs.update(t, direction));
                }
                t
            }
        }
    }
}