80 lines
2.0 KiB
Rust
80 lines
2.0 KiB
Rust
use crate::{
|
|
minimize::{Direction, ExitCondition, OptimizationResult},
|
|
objective_function::ObjectiveFun,
|
|
traits::XVar,
|
|
};
|
|
|
|
use super::line_search::LineSearch;
|
|
|
|
pub fn steepest_descent<T: XVar<E> + Clone, E>(
|
|
fun: &dyn ObjectiveFun<T, E>,
|
|
x0: &[T],
|
|
max_iters: usize,
|
|
tolerance: f64,
|
|
line_search: &LineSearch,
|
|
direction: f64,
|
|
) -> OptimizationResult<T> {
|
|
// Make a mutable copy of x0 to work with
|
|
let mut xs = Vec::new();
|
|
xs.extend_from_slice(x0);
|
|
|
|
// Perform the iteration
|
|
let mut f_iminus1 = f64::INFINITY;
|
|
let mut f = 0.0;
|
|
let mut i = 0;
|
|
for _ in 0..max_iters {
|
|
let primes = fun.prime(&xs);
|
|
xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| {
|
|
*x = x.update(direction * line_search.get_learning_rate(), prime)
|
|
});
|
|
f = fun.eval(&xs);
|
|
|
|
if (f - f_iminus1).abs() < tolerance {
|
|
break;
|
|
} else {
|
|
f_iminus1 = f;
|
|
}
|
|
i += 1;
|
|
}
|
|
|
|
let exit_con = if i == max_iters {
|
|
ExitCondition::MaxIter
|
|
} else {
|
|
ExitCondition::Converged
|
|
};
|
|
OptimizationResult {
|
|
best_xs: xs,
|
|
best_fun_val: f,
|
|
exit_con,
|
|
iters: i,
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod test {
    use crate::objective_function::Fun;

    use super::*;

    /// Minimizes `f(x) = Σ x_i²` from `x = [20.0]` with a constant learning rate and
    /// checks that the solver reports convergence to (numerically) zero.
    #[test]
    fn simple_steepest_descent_test() {
        let fun = Box::new(|xs: &[f64]| xs.iter().map(|x| x.powi(2)).sum::<f64>());
        // ∂/∂x_i Σ x_j² = 2·x_i — the derivative must match the objective above.
        let prime = Box::new(|xs: &[f64]| xs.iter().map(|x| 2.0 * x).collect::<Vec<f64>>());

        let obj = Fun::new(fun, prime);
        let line_search = LineSearch::ConstAlpha {
            learning_rate: 0.25,
        };
        // `direction = -1.0` steps against the derivative, i.e. descends.
        let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0);

        assert!(
            !matches!(res.exit_con, ExitCondition::MaxIter),
            "Failed to converge to minima: {:?} after {} iterations (f = {})",
            res.best_xs,
            res.iters,
            res.best_fun_val
        );
        assert!(
            res.best_fun_val < 1e-8,
            "objective not near zero: {:?} after {} iterations (f = {})",
            res.best_xs,
            res.iters,
            res.best_fun_val
        );
    }
}
|