parent 08faf76ea3
commit fdc3ffd164
@@ -0,0 +1,197 @@
use crate::{
    minimize::{ExitCondition, OptimizationResult},
    objective_function::ObjectiveFun,
    traits::XVar,
};
use std::fmt::Debug;

use super::line_search::LineSearch;

/// Trait that should be implemented by the Prime type for conjugate gradient
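/// (the `E` in `XVar<E>`, i.e. the gradient values returned by `ObjectiveFun::prime`).
/// The element-wise operations below are the ones the Polak-Ribière direction update uses.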
pub trait ConjGradPrime: Clone + Debug {
    /// Multiply primes by each other
    fn mul(&self, rhs: &Self) -> Self;
    /// Subtract primes from each other
    fn sub(&self, rhs: &Self) -> Self;
    /// Add primes to each other
    fn add(&self, rhs: &Self) -> Self;
    /// Divide prime by another prime (numerator/denominator)
    fn div(&self, denominator: &Self) -> Self;
    /// Max between the prime and a float
    fn max(&self, rhs: f64) -> Self;
}

impl ConjGradPrime for f64 {
    fn mul(&self, rhs: &f64) -> f64 {
        self * rhs
    }

    fn sub(&self, rhs: &f64) -> f64 {
        self - rhs
    }

    fn div(&self, denominator: &f64) -> f64 {
        self / denominator
    }

    fn max(&self, rhs: f64) -> Self {
        f64::max(*self, rhs)
    }

    fn add(&self, rhs: &Self) -> Self {
        self + rhs
    }
}

impl ConjGradPrime for Vec<f64> {
    fn mul(&self, rhs: &Vec<f64>) -> Vec<f64> {
        self.iter()
            .zip(rhs.iter())
            .map(|(lhs, rhs)| lhs * rhs)
            .collect()
    }

    fn sub(&self, rhs: &Vec<f64>) -> Vec<f64> {
        self.iter()
            .zip(rhs.iter())
            .map(|(lhs, rhs)| lhs - rhs)
            .collect()
    }

    fn div(&self, denominator: &Vec<f64>) -> Vec<f64> {
        self.iter()
            .zip(denominator.iter())
            .map(|(num, denom)| num / denom)
            .collect()
    }

    fn max(&self, rhs: f64) -> Self {
        self.iter().map(|val| val.max(rhs)).collect()
    }

    fn add(&self, rhs: &Self) -> Self {
        // Element-wise addition
        self.iter()
            .zip(rhs.iter())
            .map(|(lhs, rhs)| lhs + rhs)
            .collect()
    }
}

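/// Minimize `fun` starting from `x0` using nonlinear conjugate gradient with a
/// Polak-Ribière style direction update, taking the step size at each iteration from
/// `line_search`. Stops once the change in the objective falls below `tolerance`, or
/// after `max_iters` iterations.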
pub fn conjugate_gradient<T: XVar<E> + Clone, E: Debug + ConjGradPrime>(
    fun: &dyn ObjectiveFun<T, E>,
    x0: &T,
    max_iters: usize,
    tolerance: f64,
    line_search: &LineSearch,
) -> OptimizationResult<T> {
    // Make a mutable copy of x0 to work with
    let mut xs = x0.clone();

    // Perform the iteration
    let mut f_iminus1 = f64::INFINITY;
    let mut f = 0.0;
    let mut i = 0;
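    // The "residual" is the gradient (prime) of the objective at the current point;
    // the initial search direction is the steepest-descent direction, i.e. the scaled
    // negative gradient.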
    let mut prev_residual = fun.prime(&xs);
    let mut direction = T::scale_prime(&prev_residual, -1.0);
    for _ in 0..max_iters {
        let learning_rate = line_search.get_learning_rate(fun, &xs, &direction);
        xs = xs.update(learning_rate, &direction);
        // Check for convergence
        f = fun.eval(&xs);
        if (f - f_iminus1).abs() < tolerance {
            println!("{f} {f_iminus1}");
            break;
        } else {
            f_iminus1 = f;
        }

        // Update the search direction using Polak-Ribière
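        // Polak-Ribière coefficient, beta = g_new * (g_new - g_old) / (g_new * g_new),
        // evaluated element-wise on the prime type and clamped at zero ("PR+"), which
        // drops the previous direction's contribution whenever beta would go negative.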
        let new_residual = fun.prime(&xs);
        let beta = new_residual
            .mul(&new_residual.sub(&prev_residual))
            .div(&new_residual.mul(&new_residual));
        let beta = beta.max(0.0);
        direction = new_residual.add(&beta.mul(&direction));
        prev_residual = new_residual;
        i += 1;
    }

    let exit_con = if i == max_iters {
        ExitCondition::MaxIter
    } else {
        ExitCondition::Converged
    };
    OptimizationResult {
        best_xs: xs,
        best_fun_val: f,
        exit_con,
        iters: i,
    }
}

#[cfg(test)]
mod test {
    use crate::objective_function::Fun;

    use super::*;

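    // Minimizes f(x) = sum of x_i^2, whose unique minimum is 0 at the origin.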
    #[test]
    pub fn simple_conjugate_gradient_test() {
        let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
        let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());

        let obj = Fun::new(fun, prime);
        let line_searches = vec![LineSearch::BackTrack {
            gamma: 0.9,
            c: 0.01,
        }];
        for line_search in line_searches {
            let res = conjugate_gradient(&obj, &vec![20.0, 20.0], 1000, 1e-12, &line_search);

            println!(
                "Best val is {:?} for xs {:?}",
                res.best_fun_val, res.best_xs
            );

            if let ExitCondition::MaxIter = res.exit_con {
                panic!("Failed to converge to minima");
            }
            println!(
                "{:?} on iteration {} has value:\n{}",
                res.best_xs, res.iters, res.best_fun_val
            );
            assert!(res.best_fun_val < 1e-8);
        }
    }

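    // The Beale function has its global minimum of 0 at (3, 0.5); the search starts
    // nearby at (3.1, 0.5).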
    #[test]
    pub fn basic_beale_test() {
        let fun = Box::new(|x: &Vec<f64>| {
            (1.5 - x[0] + x[0] * x[1]).powi(2)
                + (2.25 - x[0] + x[0] * x[1].powi(2)).powi(2)
                + (2.625 - x[0] + x[0] * x[1].powi(3)).powi(2)
        });
        let prime = Box::new(|x: &Vec<f64>| {
            vec![
                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[1] - 1.0)
                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (x[1].powi(2) - 1.0)
                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (x[1].powi(3) - 1.0),
                2.0 * (1.5 - x[0] + x[0] * x[1]) * (x[0])
                    + 2.0 * (2.25 - x[0] + x[0] * x[1].powi(2)) * (2.0 * x[0] * x[1])
                    + 2.0 * (2.625 - x[0] + x[0] * x[1].powi(3)) * (3.0 * x[0] * x[1].powi(2)),
            ]
        });
        let obj = Fun::new(fun, prime);
        let line_search = LineSearch::BackTrack {
            gamma: 0.9,
            c: 0.01,
        };
        let res = conjugate_gradient(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
        println!(
            "Best val is {:?} for xs {:?}",
            res.best_fun_val, res.best_xs
        );
        println!("Exit condition is: {:?}", res.exit_con);
        assert!(res.best_fun_val < 1e-7);
    }
}

@@ -1,2 +1,4 @@
pub mod base;
pub mod conjugate_gradient;
pub mod line_search;
pub mod steepest_descent;