//! Line-search strategies that choose the step length used by descent-based optimizers.
use core::fmt;
|
|
|
|
use crate::{objective_function::ObjectiveFun, traits::XVar};
|
|
|
|
use super::steepest_descent::SteepestDescentPrime;
|
|
|
|
/// Strategy used to pick the step length (learning rate) along a search direction.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LineSearch {
    /// Always use the fixed step `learning_rate`, independent of the objective.
    ConstAlpha { learning_rate: f64 },
    /// Backtracking line search: start from a step of `1.0` and multiply it by
    /// `gamma` until the sufficient-decrease test parameterised by `c` passes.
    /// `gamma` is expected to lie strictly between 0 and 1.
    BackTrack { gamma: f64, c: f64 },
}
|
|
|
|
impl LineSearch {
|
|
pub fn get_learning_rate<T, E>(
|
|
&self,
|
|
fun: &dyn ObjectiveFun<T, E>,
|
|
xs: &T,
|
|
direction: &E,
|
|
) -> f64
|
|
where
|
|
T: XVar<E> + Clone,
|
|
E: fmt::Debug + SteepestDescentPrime,
|
|
{
|
|
match self {
|
|
LineSearch::ConstAlpha { learning_rate } => *learning_rate,
|
|
LineSearch::BackTrack { gamma, c } => {
|
|
let prime = fun.prime(xs);
|
|
let fk = fun.eval(xs);
|
|
let mut new_f = fun.eval(&xs.update(1.0, &prime));
|
|
let mut t = 1.0;
|
|
while fk < new_f + t * c * prime.scale(-1.0).inner_product(direction) {
|
|
t *= gamma;
|
|
let new_x = xs.update(t, direction);
|
|
new_f = fun.eval(&new_x);
|
|
}
|
|
t
|
|
}
|
|
}
|
|
}
|
|
}
|