From c9205415f24e8d0bf56fd70cdf0a0630de4b37b5 Mon Sep 17 00:00:00 2001
From: Alex Selimov
Date: Wed, 22 Jan 2025 22:13:27 -0500
Subject: [PATCH] Fix missing negative on residual

---
 src/gradient_descent/conjugate_gradient.rs | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/src/gradient_descent/conjugate_gradient.rs b/src/gradient_descent/conjugate_gradient.rs
index c814fca..5c075b8 100644
--- a/src/gradient_descent/conjugate_gradient.rs
+++ b/src/gradient_descent/conjugate_gradient.rs
@@ -99,20 +99,18 @@ pub fn conjugate_gradient + Clone, E: Debug + ConjGradPrime>(
         // Check for convergence
         f = fun.eval(&xs);
         if (f - f_iminus1).abs() < tolerance {
-            println!("{f} {f_iminus1}");
             break;
         } else {
             f_iminus1 = f;
         }
 
         // Update using polack-ribiere
-        let new_residual = fun.prime(&xs);
-        let beta = new_residual
-            .mul(&new_residual.sub(&prev_residual))
-            .div(&new_residual.mul(&new_residual));
+        let new_residual = T::scale_prime(&fun.prime(&xs), -1.0);
+        let beta = (new_residual.mul(&new_residual.sub(&prev_residual)))
+            .div(&prev_residual.mul(&prev_residual));
         let beta = beta.max(0.0);
         direction = new_residual.add(&beta.mul(&direction));
-        prev_residual = new_residual;
+        prev_residual = new_residual.clone();
 
         i += 1;
     }
@@ -153,6 +151,7 @@ mod test {
             res.best_fun_val, res.best_xs
         );
 
+        println!("Exited with {:?}", res.exit_con);
         if let ExitCondition::MaxIter = res.exit_con {
             panic!("Failed to converge to minima");
         }
@@ -186,11 +185,12 @@ mod test {
             gamma: 0.9,
             c: 0.01,
         };
-        let res = conjugate_gradient(&obj, &vec![3.1, 0.5], 10000, 1e-12, &line_search);
+        let res = conjugate_gradient(&obj, &vec![4.0, 1.00], 10000, 1e-12, &line_search);
         println!(
            "Best val is {:?} for xs {:?}",
             res.best_fun_val, res.best_xs
         );
+        println!("Exit condition is: {:?}", res.exit_con);
         assert!(res.best_fun_val < 1e-7);
     }
 
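
Note on the sign fix above (not part of the patch): in nonlinear conjugate gradient the residual is the negative gradient, r_k = -grad f(x_k), and the Polak-Ribiere coefficient is beta_k = r_k . (r_k - r_{k-1}) / (r_{k-1} . r_{k-1}), clamped at zero for the usual PR+ restart. The patch applies both corrections: it negates the gradient when forming new_residual and moves the beta denominator from new_residual onto prev_residual. Below is a minimal, self-contained sketch of that update on plain Vec<f64>; every name in it is an illustrative assumption, not the crate's Objective/ConjGradPrime API.

// Illustrative sketch only: Polak-Ribiere direction update with the residual
// taken as the negated gradient, mirroring the sign fix in this patch.
fn dot(a: &[f64], b: &[f64]) -> f64 {
    a.iter().zip(b).map(|(x, y)| x * y).sum()
}

/// One Polak-Ribiere update. `grad` is the gradient at the new point;
/// `prev_residual` and `direction` carry over from the previous iteration.
/// Returns (new_direction, new_residual).
fn polak_ribiere_update(
    grad: &[f64],
    prev_residual: &[f64],
    direction: &[f64],
) -> (Vec<f64>, Vec<f64>) {
    // Residual is the steepest-descent direction, i.e. the negated gradient.
    let new_residual: Vec<f64> = grad.iter().map(|g| -g).collect();

    // beta = r_k . (r_k - r_{k-1}) / (r_{k-1} . r_{k-1}), clamped at zero.
    let diff: Vec<f64> = new_residual
        .iter()
        .zip(prev_residual)
        .map(|(r, p)| r - p)
        .collect();
    let beta = (dot(&new_residual, &diff) / dot(prev_residual, prev_residual)).max(0.0);

    // d_k = r_k + beta * d_{k-1}
    let new_direction: Vec<f64> = new_residual
        .iter()
        .zip(direction)
        .map(|(r, d)| r + beta * d)
        .collect();

    (new_direction, new_residual)
}

#[cfg(test)]
mod sketch_test {
    use super::*;

    #[test]
    fn beta_is_zero_when_gradient_is_unchanged() {
        // If the gradient did not change, r_k == r_{k-1}, so beta == 0 and the
        // new direction reduces to the steepest-descent direction.
        let grad = vec![2.0, -1.0];
        let prev_residual = vec![-2.0, 1.0];
        let direction = vec![0.5, 0.5];
        let (dir, res) = polak_ribiere_update(&grad, &prev_residual, &direction);
        assert_eq!(res, vec![-2.0, 1.0]);
        assert_eq!(dir, vec![-2.0, 1.0]);
    }
}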