Fix steepest descent and backtracking line search

Branch: master
Alex Selimov, 4 weeks ago
parent 42e9b748dd
commit 08faf76ea3
@@ -1,3 +1,5 @@
+use core::fmt;
+
 use crate::{objective_function::ObjectiveFun, traits::XVar};
 
 pub enum LineSearch {

@@ -20,7 +22,7 @@ impl LineSearch {
     ) -> f64
     where
         T: XVar<E> + Clone,
-        E:,
+        E: fmt::Debug,
     {
         match self {
             LineSearch::ConstAlpha { learning_rate } => *learning_rate,
@@ -30,17 +32,16 @@ impl LineSearch {
                 c,
             } => {
                 let prime = fun.prime(xs);
-                let mut del_f = T::scale_prime(&prime, *c);
-                let mut new_f = fun.eval(&xs.update(-1.0, &prime));
+                let fk = fun.eval(xs);
+                let mut new_f = fun.eval(&xs.update(1.0, &prime));
                 let mut t = 1.0;
-                for i in 0..*max_iterations {
-                    if new_f > T::prime_inner_product(&T::scale_prime(&del_f, t), direction) {
-                        break;
-                    }
+                while fk
+                    < new_f
+                        + t * c * T::prime_inner_product(&T::scale_prime(&prime, -1.0), direction)
+                {
                     t *= gamma;
-                    let new_x = xs.update(-t, &prime);
+                    let new_x = xs.update(t, direction);
                     new_f = fun.eval(&new_x);
-                    del_f = fun.prime(&new_x);
                 }
                 t
             }
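For context on the fix: the rewritten loop is the standard Armijo backtracking rule. Starting from t = 1, the step is shrunk by gamma until the sufficient-decrease condition

    f(x + t\,d) \le f(x) + c\,t\,\nabla f(x)^\top d

holds for the search direction d. With the steepest-descent choice d = -\nabla f(x), the term prime_inner_product(scale_prime(prime, -1.0), direction) evaluates to -\nabla f(x)^\top d = \|\nabla f(x)\|^2, so the while guard above is exactly the negation of this condition. One caveat: the new loop no longer consults max_iterations, so it relies on the condition eventually holding, which standard analysis guarantees for small enough t when d is a descent direction and f is continuously differentiable.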

@@ -6,7 +6,7 @@ use crate::{
 use super::line_search::LineSearch;
 
-pub fn steepest_descent<T: XVar<E> + Clone, E>(
+pub fn steepest_descent<T: XVar<E> + Clone, E: std::fmt::Debug>(
     fun: &dyn ObjectiveFun<T, E>,
     x0: &T,
     max_iters: usize,
@@ -22,11 +22,10 @@ pub fn steepest_descent<T: XVar<E> + Clone, E>(
     let mut f = 0.0;
     let mut i = 0;
     for _ in 0..max_iters {
-        let primes = fun.prime(&xs);
-        let learning_rate = line_search.get_learning_rate(fun, &xs, &T::scale_prime(&primes, -1.0));
-        xs = xs.update(direction * learning_rate, &primes);
+        let direction = T::scale_prime(&fun.prime(&xs), -1.0);
+        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
+        xs = xs.update(learning_rate, &direction);
         f = fun.eval(&xs);
         if (f - f_iminus1).abs() < tolerance {
             break;
         } else {
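Taken together, the two fixes make the solver build the descent direction d = -grad f(x) once per iteration, hand that same direction to the line search, and then step with x <- x + t d. A minimal, self-contained sketch of that corrected flow (hypothetical, simplified signatures over plain slices; not this crate's API):

fn backtracking_step(
    f: &dyn Fn(&[f64]) -> f64,
    grad: &dyn Fn(&[f64]) -> Vec<f64>,
    x: &[f64],
    gamma: f64, // backtracking shrink factor in (0, 1)
    c: f64,     // sufficient-decrease constant in (0, 1)
) -> Vec<f64> {
    let g = grad(x);
    // Steepest-descent direction d = -grad f(x), built once per step.
    let d: Vec<f64> = g.iter().map(|gi| -gi).collect();
    let fx = f(x);
    // Directional derivative grad f(x) . d = -||grad f(x)||^2 <= 0.
    let slope: f64 = g.iter().zip(&d).map(|(gi, di)| gi * di).sum();
    let step = |t: f64| -> Vec<f64> {
        x.iter().zip(&d).map(|(xi, di)| xi + t * di).collect()
    };
    // Armijo backtracking: shrink t until f(x + t d) <= f(x) + c t (grad . d).
    let mut t = 1.0;
    while f(&step(t)) > fx + c * t * slope {
        t *= gamma;
    }
    step(t)
}

fn main() {
    // Minimize f(x, y) = x^2 + y^2 from (20, 20); the minimizer is the origin.
    let f = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
    let grad = |x: &[f64]| vec![2.0 * x[0], 2.0 * x[1]];
    let mut x = vec![20.0, 20.0];
    for _ in 0..200 {
        x = backtracking_step(&f, &grad, &x, 0.9, 0.3);
    }
    println!("x = {:?}, f(x) = {:e}", x, f(&x));
}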
@ -66,21 +65,52 @@ mod test {
}, },
LineSearch::BackTrack { LineSearch::BackTrack {
max_iterations: 100, max_iterations: 100,
gamma: 0.5, gamma: 0.9,
c: 0.1, c: 0.3,
}, },
]; ];
for line_search in line_searches { for line_search in line_searches {
let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0); let res = steepest_descent(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search, -1.0);
if let ExitCondition::MaxIter = res.exit_con { if let ExitCondition::MaxIter = res.exit_con {
panic!("Failed to converge to minima"); panic!("Failed to converge to minima");
} }
println!( println!(
"{:?} on iteration {}\n{}", "{:?} on iteration {} has value:\n{}",
res.best_xs, res.iters, res.best_fun_val res.best_xs, res.iters, res.best_fun_val
); );
assert!(res.best_fun_val < 1e-8); assert!(res.best_fun_val < 1e-8);
} }
} }
#[test]
pub fn basic_beale_test() {
let fun = Box::new(|x: &Vec<f64>| {
(1.5 - x[0] + x[0] * x[1]).powi(2)
+ (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
+ (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
});
let prime = Box::new(|x: &Vec<f64>| {
vec![
2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
+ 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
+ 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
+ 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
+ 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(3)),
]
});
let obj = Fun::new(fun, prime);
let line_search = LineSearch::BackTrack {
max_iterations: 1000,
gamma: 0.9,
c: 0.01,
};
let res = steepest_descent(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search, -1.0);
println!(
"Best val is {:?} for xs {:?}",
res.best_fun_val, res.best_xs
);
assert!(res.best_fun_val < 1e-7);
}
 }
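A note on the new test: the Beale function has its global minimum f(3, 0.5) = 0, so the starting point (3.1, 0.5) lies close to the minimizer and the final assertion checks convergence to a known optimum. A standalone sanity check of that reference value (not part of the crate):

fn main() {
    // Beale function; global minimum f = 0 at (x, y) = (3.0, 0.5).
    let beale = |x: f64, y: f64| {
        (1.5 - x + x * y).powi(2)
            + (2.25 - x + x * y.powi(2)).powi(2)
            + (2.625 - x + x * y.powi(3)).powi(2)
    };
    assert!(beale(3.0, 0.5).abs() < 1e-12);
    println!("f(3.0, 0.5) = {}", beale(3.0, 0.5));
}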

@@ -1,10 +1,12 @@
+use std::fmt::Debug;
+
 /// Trait defining the data structure that must be implemented for the independent variables used
 /// in the objective function. The generic type denotes the type of the prime of that variable.
 /// NOTE: This trait also defines some functions that are required to operate on the prime data
 /// type. It should be noted that we are unable to just require T to implement Mul<f64> or Add<f64>,
 /// because then we wouldn't be able to implement XVar for plain Vec types, which seems
 /// inconvenient
-pub trait XVar<T>: Clone {
+pub trait XVar<T>: Clone + Debug {
     /// Update the current Xvariable based on the prime
     fn update(&self, alpha: f64, prime: &T) -> Self;
     /// Multiply the prime by a float
@@ -38,7 +40,7 @@ impl XVar<f64> for f64 {
 impl XVar<Vec<f64>> for Vec<f64> {
     fn update(&self, alpha: f64, prime: &Vec<f64>) -> Self {
         self.iter()
-            .zip(prime)
+            .zip(prime.iter())
             .map(|(x, xprime)| x + alpha * xprime)
             .collect()
     }
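For reference, these Vec<f64> impls give update the elementwise meaning x + alpha * prime, which is why the solver can pass an already-negated direction together with a positive step length. A standalone illustration of that update rule (hypothetical free function mirroring the trait method):

fn update(x: &[f64], alpha: f64, prime: &[f64]) -> Vec<f64> {
    // Elementwise x + alpha * prime, matching the Vec<f64> impl above.
    x.iter()
        .zip(prime.iter())
        .map(|(xi, pi)| xi + alpha * pi)
        .collect()
}

fn main() {
    let x = vec![1.0, 2.0];
    // A steepest-descent step: the direction is the negated gradient.
    let grad = vec![0.5, -1.0];
    let direction: Vec<f64> = grad.iter().map(|g| -g).collect();
    let stepped = update(&x, 0.1, &direction);
    assert!((stepped[0] - 0.95).abs() < 1e-12);
    assert!((stepped[1] - 2.1).abs() < 1e-12);
    println!("{:?}", stepped);
}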
