Compiling version of backtracking line search, needs debugging
This commit is contained in:
parent
9e6d750c2d
commit
42e9b748dd
@ -1,12 +1,49 @@
|
|||||||
|
use crate::{objective_function::ObjectiveFun, traits::XVar};
|
||||||
|
|
||||||
/// Strategy used to pick the step size (learning rate) for each iteration
/// of a descent method.
pub enum LineSearch {
    /// Fixed step size: always returns `learning_rate`.
    ConstAlpha { learning_rate: f64 },
    /// Backtracking (Armijo) line search: start from a unit step and shrink
    /// it by the factor `gamma` until the sufficient-decrease condition with
    /// slope coefficient `c` is satisfied, giving up after `max_iterations`
    /// shrink steps.
    BackTrack {
        max_iterations: usize,
        gamma: f64,
        c: f64,
    },
}
|
||||||
|
|
||||||
impl LineSearch {
|
impl LineSearch {
|
||||||
pub fn get_learning_rate(&self) -> f64 {
|
pub fn get_learning_rate<T, E>(
|
||||||
|
&self,
|
||||||
|
fun: &dyn ObjectiveFun<T, E>,
|
||||||
|
xs: &T,
|
||||||
|
direction: &E,
|
||||||
|
) -> f64
|
||||||
|
where
|
||||||
|
T: XVar<E> + Clone,
|
||||||
|
E:,
|
||||||
|
{
|
||||||
match self {
|
match self {
|
||||||
LineSearch::ConstAlpha { learning_rate } => *learning_rate,
|
LineSearch::ConstAlpha { learning_rate } => *learning_rate,
|
||||||
|
LineSearch::BackTrack {
|
||||||
|
max_iterations,
|
||||||
|
gamma,
|
||||||
|
c,
|
||||||
|
} => {
|
||||||
|
let prime = fun.prime(xs);
|
||||||
|
let mut del_f = T::scale_prime(&prime, *c);
|
||||||
|
let mut new_f = fun.eval(&xs.update(-1.0, &prime));
|
||||||
|
let mut t = 1.0;
|
||||||
|
for i in 0..*max_iterations {
|
||||||
|
if new_f > T::prime_inner_product(&T::scale_prime(&del_f, t), direction) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
t *= gamma;
|
||||||
|
let new_x = xs.update(-t, &prime);
|
||||||
|
new_f = fun.eval(&new_x);
|
||||||
|
del_f = fun.prime(&new_x);
|
||||||
|
}
|
||||||
|
t
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ pub fn steepest_descent<T: XVar<E> + Clone, E>(
|
|||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
for _ in 0..max_iters {
|
for _ in 0..max_iters {
|
||||||
let primes = fun.prime(&xs);
|
let primes = fun.prime(&xs);
|
||||||
let learning_rate = line_search.get_learning_rate(fun, &xs);
|
let learning_rate = line_search.get_learning_rate(fun, &xs, &T::scale_prime(&primes, -1.0));
|
||||||
xs = xs.update(direction * learning_rate, &primes);
|
xs = xs.update(direction * learning_rate, &primes);
|
||||||
f = fun.eval(&xs);
|
f = fun.eval(&xs);
|
||||||
|
|
||||||
@ -60,9 +60,17 @@ mod test {
|
|||||||
let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
|
let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
|
||||||
|
|
||||||
let obj = Fun::new(fun, prime);
|
let obj = Fun::new(fun, prime);
|
||||||
let line_search = LineSearch::ConstAlpha {
|
let line_searches = vec![
|
||||||
|
LineSearch::ConstAlpha {
|
||||||
learning_rate: 0.25,
|
learning_rate: 0.25,
|
||||||
};
|
},
|
||||||
|
LineSearch::BackTrack {
|
||||||
|
max_iterations: 100,
|
||||||
|
gamma: 0.5,
|
||||||
|
c: 0.1,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
for line_search in line_searches {
|
||||||
let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
|
let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
|
||||||
|
|
||||||
if let ExitCondition::MaxIter = res.exit_con {
|
if let ExitCondition::MaxIter = res.exit_con {
|
||||||
@ -75,3 +83,4 @@ mod test {
|
|||||||
assert!(res.best_fun_val < 1e-8);
|
assert!(res.best_fun_val < 1e-8);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
@ -1,9 +1,18 @@
|
|||||||
/// Trait defining the data structure that must be implemented for the
/// independent variables used in the objective function. The generic type
/// denotes the type of the prime (derivative) of that variable.
///
/// NOTE: This trait also defines some functions that are required to operate
/// on the prime data type. We cannot simply require `T` to implement
/// `Mul<f64>` or `Add<f64>` because then we would not be able to implement
/// `XVar` for plain `Vec` types, which seems inconvenient.
pub trait XVar<T>: Clone {
    /// Update the current x variable based on the prime.
    fn update(&self, alpha: f64, prime: &T) -> Self;
    /// Multiply the prime by a float.
    fn scale_prime(prime: &T, rhs: f64) -> T;
    /// Add a float to the prime.
    fn add_prime(prime: &T, rhs: f64) -> T;
    /// Inner product of two primes.
    fn prime_inner_product(prime: &T, rhs: &T) -> f64;
}
|
||||||
|
|
||||||
/// Implementation of XVar for an f64 type
|
/// Implementation of XVar for an f64 type
|
||||||
@ -11,6 +20,18 @@ impl XVar<f64> for f64 {
|
|||||||
fn update(&self, alpha: f64, prime: &f64) -> Self {
|
fn update(&self, alpha: f64, prime: &f64) -> Self {
|
||||||
self + alpha * prime
|
self + alpha * prime
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn scale_prime(prime: &f64, rhs: f64) -> f64 {
|
||||||
|
prime * rhs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_prime(prime: &f64, rhs: f64) -> f64 {
|
||||||
|
prime + rhs
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prime_inner_product(prime: &f64, rhs: &f64) -> f64 {
|
||||||
|
prime * rhs
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Implementation of XVar for a Vec<f64> type
|
/// Implementation of XVar for a Vec<f64> type
|
||||||
@ -21,4 +42,19 @@ impl XVar<Vec<f64>> for Vec<f64> {
|
|||||||
.map(|(x, xprime)| x + alpha * xprime)
|
.map(|(x, xprime)| x + alpha * xprime)
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn scale_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
|
||||||
|
prime.iter().map(|val| val * rhs).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn add_prime(prime: &Vec<f64>, rhs: f64) -> Vec<f64> {
|
||||||
|
prime.iter().map(|val| val + rhs).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prime_inner_product(prime: &Vec<f64>, rhs: &Vec<f64>) -> f64 {
|
||||||
|
prime
|
||||||
|
.iter()
|
||||||
|
.zip(rhs.iter())
|
||||||
|
.fold(0.0, |acc, a| acc + a.0 * a.1)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user