Update traits to operate on single generic types and not slices of generic types. This should more flexibility for the user when defining objective functions

master
Alex Selimov 1 month ago
parent f9bc5adb71
commit 164f5a72c4

@ -1,5 +1,5 @@
use crate::{ use crate::{
minimize::{Direction, ExitCondition, OptimizationResult}, minimize::{ExitCondition, OptimizationResult},
objective_function::ObjectiveFun, objective_function::ObjectiveFun,
traits::XVar, traits::XVar,
}; };
@ -8,15 +8,14 @@ use super::line_search::LineSearch;
pub fn steepest_descent<T: XVar<E> + Clone, E>( pub fn steepest_descent<T: XVar<E> + Clone, E>(
fun: &dyn ObjectiveFun<T, E>, fun: &dyn ObjectiveFun<T, E>,
x0: &[T], x0: &T,
max_iters: usize, max_iters: usize,
tolerance: f64, tolerance: f64,
line_search: &LineSearch, line_search: &LineSearch,
direction: f64, direction: f64,
) -> OptimizationResult<T> { ) -> OptimizationResult<T> {
// Make a mutable copy of x0 to work with // Make a mutable copy of x0 to work with
let mut xs = Vec::new(); let mut xs = x0.clone();
xs.extend_from_slice(x0);
// Perform the iteration // Perform the iteration
let mut f_iminus1 = f64::INFINITY; let mut f_iminus1 = f64::INFINITY;
@ -24,9 +23,8 @@ pub fn steepest_descent<T: XVar<E> + Clone, E>(
let mut i = 0; let mut i = 0;
for _ in 0..max_iters { for _ in 0..max_iters {
let primes = fun.prime(&xs); let primes = fun.prime(&xs);
xs.iter_mut().zip(primes.iter()).for_each(|(x, prime)| { let learning_rate = line_search.get_learning_rate(fun, &xs);
*x = x.update(direction * line_search.get_learning_rate(), prime) xs = xs.update(direction * learning_rate, &primes);
});
f = fun.eval(&xs); f = fun.eval(&xs);
if (f - f_iminus1).abs() < tolerance { if (f - f_iminus1).abs() < tolerance {
@ -58,14 +56,14 @@ mod test {
#[test] #[test]
pub fn simple_steepest_descent_test() { pub fn simple_steepest_descent_test() {
let fun = Box::new(|xs: &[f64]| xs.iter().fold(0.0, |acc, x| acc + x.powi(2))); let fun = Box::new(|xs: &Vec<f64>| xs.iter().fold(0.0, |acc, x| acc + x.powi(2)));
let prime = Box::new(|xs: &[f64]| xs.iter().copied().collect::<Vec<f64>>()); let prime = Box::new(|xs: &Vec<f64>| xs.iter().map(|x| 2.0 * x).collect());
let obj = Fun::new(fun, prime); let obj = Fun::new(fun, prime);
let line_search = LineSearch::ConstAlpha { let line_search = LineSearch::ConstAlpha {
learning_rate: 0.25, learning_rate: 0.25,
}; };
let res = steepest_descent(&obj, &[20.0], 1000, 1e-12, &line_search, -1.0); let res = steepest_descent(&obj, &vec![20.0], 1000, 1e-12, &line_search, -1.0);
if let ExitCondition::MaxIter = res.exit_con { if let ExitCondition::MaxIter = res.exit_con {
panic!("Failed to converge to minima"); panic!("Failed to converge to minima");

@ -21,7 +21,7 @@ impl Direction {
} }
/// Struct holding the results for a minimization call /// Struct holding the results for a minimization call
pub struct OptimizationResult<T> { pub struct OptimizationResult<T> {
pub best_xs: Vec<T>, pub best_xs: T,
pub best_fun_val: f64, pub best_fun_val: f64,
pub exit_con: ExitCondition, pub exit_con: ExitCondition,
pub iters: usize, pub iters: usize,

@ -3,9 +3,9 @@ use crate::traits::XVar;
/// Trait that should be implemented for objects that will be minimzed /// Trait that should be implemented for objects that will be minimzed
pub trait ObjectiveFun<T: XVar<E> + Clone, E> { pub trait ObjectiveFun<T: XVar<E> + Clone, E> {
/// Return the objective function value at a specified coordinate /// Return the objective function value at a specified coordinate
fn eval(&self, xs: &[T]) -> f64; fn eval(&self, xs: &T) -> f64;
/// Return the gradients of the objective function value for specified coordinates /// Return the gradients of the objective function value for specified coordinates
fn prime(&self, xs: &[T]) -> Vec<E>; fn prime(&self, xs: &T) -> E;
} }
/// Enum allowing for selection of style of numerical differentiation /// Enum allowing for selection of style of numerical differentiation
@ -23,12 +23,12 @@ pub struct FunWithNumericalDiff {
style: DiffStyle, style: DiffStyle,
} }
impl ObjectiveFun<f64, f64> for FunWithNumericalDiff { impl ObjectiveFun<Vec<f64>, Vec<f64>> for FunWithNumericalDiff {
fn eval(&self, xs: &[f64]) -> f64 { fn eval(&self, xs: &Vec<f64>) -> f64 {
(self.function)(xs) (self.function)(xs)
} }
fn prime(&self, xs: &[f64]) -> Vec<f64> { fn prime(&self, xs: &Vec<f64>) -> Vec<f64> {
let mut xs_local = Vec::new(); let mut xs_local = Vec::new();
xs_local.extend_from_slice(xs); xs_local.extend_from_slice(xs);
let f: Box<dyn FnMut((usize, &f64)) -> f64> = match self.style { let f: Box<dyn FnMut((usize, &f64)) -> f64> = match self.style {
@ -60,25 +60,25 @@ impl ObjectiveFun<f64, f64> for FunWithNumericalDiff {
/// Struct that wraps two lambda with one providing the objective function evaluation and the other /// Struct that wraps two lambda with one providing the objective function evaluation and the other
/// providing the gradient value /// providing the gradient value
pub struct Fun<T: XVar<E>, E> { pub struct Fun<T: XVar<E>, E> {
function: Box<dyn Fn(&[T]) -> f64>, function: Box<dyn Fn(&T) -> f64>,
prime: Box<dyn Fn(&[T]) -> Vec<E>>, prime: Box<dyn Fn(&T) -> E>,
} }
// Simple type to remove the generics // Simple type to remove the generics
pub type F64Fun = Fun<f64, f64>; pub type F64Fun = Fun<f64, f64>;
impl<T: XVar<E>, E> ObjectiveFun<T, E> for Fun<T, E> { impl<T: XVar<E>, E> ObjectiveFun<T, E> for Fun<T, E> {
fn eval(&self, xs: &[T]) -> f64 { fn eval(&self, xs: &T) -> f64 {
(self.function)(xs) (self.function)(xs)
} }
fn prime(&self, xs: &[T]) -> Vec<E> { fn prime(&self, xs: &T) -> E {
(self.prime)(xs) (self.prime)(xs)
} }
} }
impl<T: XVar<E>, E> Fun<T, E> { impl<T: XVar<E>, E> Fun<T, E> {
pub fn new(function: Box<dyn Fn(&[T]) -> f64>, prime: Box<dyn Fn(&[T]) -> Vec<E>>) -> Self { pub fn new(function: Box<dyn Fn(&T) -> f64>, prime: Box<dyn Fn(&T) -> E>) -> Self {
Fun { function, prime } Fun { function, prime }
} }
} }

@ -1,5 +1,9 @@
pub trait XVar<E>: Clone { /// Trait defining the data structure that must be implemented for the independent variables used
fn update(&self, alpha: f64, prime: &E) -> Self; /// in the objective function. The generic type denotes the type of the prime of that variable
///
pub trait XVar<T>: Clone {
/// Update the current Xvariable based on the prime
fn update(&self, alpha: f64, prime: &T) -> Self;
} }
/// Implementation of XVar for an f64 type /// Implementation of XVar for an f64 type
@ -8,3 +12,13 @@ impl XVar<f64> for f64 {
self + alpha * prime self + alpha * prime
} }
} }
/// Implementation of XVar for a Vec<f64> type
impl XVar<Vec<f64>> for Vec<f64> {
fn update(&self, alpha: f64, prime: &Vec<f64>) -> Self {
self.iter()
.zip(prime)
.map(|(x, xprime)| x + alpha * xprime)
.collect()
}
}

Loading…
Cancel
Save