//! Steepest-descent optimizer (`roptimize/src/gradient_descent/steepest_descent.rs`).
//! Iteratively steps along the (scaled) gradient until the change in the
//! objective value falls below a tolerance or an iteration cap is reached.

use crate::{
minimize::{Direction, ExitCondition, OptimizationResult},
objective_function::ObjectiveFun,
traits::XVar,
};
use super::line_search::LineSearch;
/// Minimize (or maximize) `fun` by steepest descent starting from `x0`.
///
/// Each iteration evaluates the gradient via `fun.prime` and updates every
/// coordinate by `direction * learning_rate * gradient`, where the learning
/// rate comes from `line_search`. Iteration stops when the absolute change in
/// the objective value drops below `tolerance`, or after `max_iters` steps.
///
/// * `fun` - objective function providing `eval` and `prime` (gradient).
/// * `x0` - starting point; not mutated.
/// * `max_iters` - hard cap on the number of update steps.
/// * `tolerance` - convergence threshold on |f_k - f_{k-1}|.
/// * `line_search` - supplies the step size (learning rate).
/// * `direction` - sign/scale applied to the step; use a negative value to
///   descend when `prime` returns the raw gradient.
///
/// Returns an [`OptimizationResult`] holding the final point, its objective
/// value, the exit condition, and the number of iterations actually performed.
pub fn steepest_descent<T: XVar<E> + Clone, E>(
    fun: &dyn ObjectiveFun<T, E>,
    x0: &[T],
    max_iters: usize,
    tolerance: f64,
    line_search: &LineSearch,
    direction: f64,
) -> OptimizationResult<T> {
    // Work on an owned copy of the starting point.
    let mut xs = x0.to_vec();
    // Evaluate once up front so the degenerate `max_iters == 0` case still
    // reports the true objective value rather than an uninitialized 0.0.
    let mut f = fun.eval(&xs);
    let mut f_iminus1 = f64::INFINITY;
    let mut iters = 0;
    let mut converged = false;
    for _ in 0..max_iters {
        let primes = fun.prime(&xs);
        // In-place coordinate update: x <- x.update(step, gradient_component).
        xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| {
            *x = x.update(direction * line_search.get_learning_rate(), prime)
        });
        f = fun.eval(&xs);
        // Count the step before testing convergence so `iters` reflects the
        // number of updates actually applied (the old code skipped the
        // increment on `break`, under-reporting by one on convergence).
        iters += 1;
        if (f - f_iminus1).abs() < tolerance {
            converged = true;
            break;
        }
        f_iminus1 = f;
    }
    let exit_con = if converged {
        ExitCondition::Converged
    } else {
        ExitCondition::MaxIter
    };
    OptimizationResult {
        best_xs: xs,
        best_fun_val: f,
        exit_con,
        iters,
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::objective_function::Fun;

    /// Minimizing f(x) = sum(x_i^2) from x = 20 must converge well before the
    /// iteration cap and land essentially at the origin.
    #[test]
    pub fn simple_steepest_descent_test() {
        // Objective: sum of squares; gradient: identity on the coordinates.
        let fun = Box::new(|xs: &[f64]| xs.iter().map(|x| x.powi(2)).sum::<f64>());
        let prime = Box::new(|xs: &[f64]| xs.to_vec());
        let obj = Fun::new(fun, prime);

        let line_search = LineSearch::ConstAlpha {
            learning_rate: 0.25,
        };

        let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0);

        // Hitting the iteration cap means we never converged.
        if matches!(res.exit_con, ExitCondition::MaxIter) {
            panic!("Failed to converge to minima");
        }
        println!(
            "{:?} on iteration {}\n{}",
            res.best_xs, res.iters, res.best_fun_val
        );
        assert!(res.best_fun_val < 1e-8);
    }
}