parent c9205415f2
commit 2993580861
@@ -0,0 +1,161 @@
use crate::{
    gradient_descent::consts::EPS,
    minimize::{ExitCondition, OptimizationResult},
    objective_function::ObjectiveFun,
    traits::XVar,
};
use std::fmt::Debug;

use super::conjugate_gradient::ConjGradPrime;

pub struct AdamParameters {
    /// Step size for each update
    alpha0: f64,
    /// Exponential decay rate for the first-moment estimate
    beta1: f64,
    /// Exponential decay rate for the second-moment estimate
    beta2: f64,
}
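
// Typical values from the Adam paper are alpha0 = 0.001, beta1 = 0.9, and
// beta2 = 0.999; the tests below use a larger alpha0 = 0.1 to speed
// convergence on small, well-scaled problems.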

pub trait AdamPrime: ConjGradPrime {
    fn zero(&self) -> Self;
    fn sqrt(&self) -> Self;
}
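
// `zero` produces a zero gradient of the same shape as `self` (used to
// initialize the moment estimates), while `sqrt` is applied element-wise in
// the denominator of the update; the remaining arithmetic (`scale`, `add`,
// `mul`, `div`) comes from the `ConjGradPrime` supertrait.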

impl AdamPrime for f64 {
    fn zero(&self) -> Self {
        0.0
    }

    fn sqrt(&self) -> Self {
        f64::sqrt(*self)
    }
}

impl AdamPrime for Vec<f64> {
    fn zero(&self) -> Self {
        vec![0.0; self.len()]
    }

    fn sqrt(&self) -> Self {
        self.iter().map(|val| val.sqrt()).collect()
    }
}
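
// The loop below implements the standard Adam recurrences (Kingma & Ba, 2015):
//   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
//   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
//   mhat_t = m_t / (1 - beta1^t),  vhat_t = v_t / (1 - beta2^t)
//   x_t = x_{t-1} - alpha0 * mhat_t / (sqrt(vhat_t) + EPS)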
pub fn adam<T: XVar<E> + Clone, E: Debug + AdamPrime>(
    fun: &dyn ObjectiveFun<T, E>,
    x0: &T,
    max_iters: usize,
    tolerance: f64,
    params: &AdamParameters,
) -> OptimizationResult<T> {
    // Make a mutable copy of x0 to work with
    let mut xs = x0.clone();

    // Initialize the iteration state: gradient, zeroed moment estimates, and
    // the objective value used for the convergence check
    let mut t = 0;
    let mut prime = fun.prime(x0);
    let mut m = prime.zero();
    let mut v = prime.zero();
    let mut old_f = fun.eval(x0);
    let mut f = old_f;
    for _ in 0..max_iters {
        // Update the biased moment estimates, then correct for their
        // initialization at zero
        m = m.scale(params.beta1).add(&prime.scale(1.0 - params.beta1));
        v = (v.scale(params.beta2)).add(&prime.mul(&prime.scale(1.0 - params.beta2)));
        let mhat = m.scale(1.0 / (1.0 - params.beta1.powi(t as i32 + 1)));
        let vhat = v.scale(1.0 / (1.0 - params.beta2.powi(t as i32 + 1)));
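        // Step against the bias-corrected first moment, scaled per component
        // by 1 / (sqrt(vhat) + EPS); EPS keeps the denominator nonzero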
        let update_direction = mhat.div(&vhat.sqrt().add_float(EPS)).scale(-1.0);

        xs = xs.update(params.alpha0, &update_direction);
        prime = fun.prime(&xs);

        // Check convergence
        f = fun.eval(&xs);
        if f.is_nan() {
            break;
        }
        if (f - old_f).abs() < tolerance {
            break;
        }
        old_f = f;
        t += 1;
    }

    let exit_con = if t == max_iters {
        ExitCondition::MaxIter
    } else {
        ExitCondition::Converged
    };
    OptimizationResult {
        best_xs: xs,
        best_fun_val: f,
        exit_con,
        iters: t,
    }
}

#[cfg(test)]
mod test {
    use crate::objective_function::Fun;

    use super::*;

    #[test]
    pub fn simple_adam_test() {
        let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
        let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());

        let obj = Fun::new(fun, prime);
        let params = AdamParameters {
            alpha0: 0.1,
            beta1: 0.9,
            beta2: 0.999,
        };
        let res = adam(&obj, &vec![10.0, 10.0], 1000, 1e-12, &params);

        println!(
            "Best val is {:?} for xs {:?}",
            res.best_fun_val, res.best_xs
        );

        println!("Exited with {:?}", res.exit_con);
        if let ExitCondition::MaxIter = res.exit_con {
            panic!("Failed to converge to a minimum");
        }
        println!(
            "{:?} on iteration {} has value:\n{}",
            res.best_xs, res.iters, res.best_fun_val
        );
        assert!(res.best_fun_val < 1e-8);
    }

    #[test]
    pub fn basic_beale_test() {
        let fun = Box::new(|x: &Vec<f64>| {
            (1.5 - x[0] + x[0] * x[1]).powi(2)
                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
        });
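        // Beale's function: global minimum of 0 at (3.0, 0.5)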
        let prime = Box::new(|x: &Vec<f64>| {
            vec![
                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
                // d(x0 * x1^3)/dx1 = 3 * x0 * x1^2
                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
            ]
        });
        let obj = Fun::new(fun, prime);
        let params = AdamParameters {
            alpha0: 0.1,
            beta1: 0.9,
            beta2: 0.999,
        };
        let res = adam(&obj, &vec![4.0, 1.00], 1000, 1e-12, &params);
        println!(
            "Best val is {:?} for xs {:?} in {} iterations",
            res.best_fun_val, res.best_xs, res.iters
        );

        println!("Exit condition is: {:?}", res.exit_con);
        assert!(res.best_fun_val < 1e-7);
    }
}

@@ -0,0 +1 @@
pub const EPS: f64 = 1e-12;
@ -1,4 +1,6 @@
|
||||
pub mod adam;
|
||||
pub mod base;
|
||||
pub mod conjugate_gradient;
|
||||
pub mod consts;
|
||||
pub mod line_search;
|
||||
pub mod steepest_descent;
|
||||
|