#include <gtest/gtest.h>
#include "../../src/activation_function.hpp"
#include <cmath>

// Checks Sigmoid::operator() against the logistic function 1 / (1 + e^(-x))
// at the midpoint, in both saturated tails, and at x = +1 / x = -1.
TEST(ActivationFunctionTest, SigmoidTest) {
    Sigmoid sigmoid;

    // Midpoint: e^0 == 1, so the logistic function is exactly 0.5 at x = 0.
    EXPECT_NEAR(sigmoid(0.0), 0.5, 1e-6);

    // Saturated tails: at |x| = 10 the logistic function is still about
    // 4.5e-5 away from its asymptote, hence the looser 1e-4 tolerance.
    EXPECT_NEAR(sigmoid(10.0), 1.0, 1e-4);
    EXPECT_NEAR(sigmoid(-10.0), 0.0, 1e-4);

    // Reference values computed from the closed form. std::exp is qualified
    // because <cmath> only guarantees the std:: declaration; the unqualified
    // global name is an implementation detail.
    EXPECT_NEAR(sigmoid(1.0), 1.0 / (1.0 + std::exp(-1.0)), 1e-6);
    EXPECT_NEAR(sigmoid(-1.0), 1.0 / (1.0 + std::exp(1.0)), 1e-6);
}

// Checks ReLU::operator() at zero and on positive/negative inputs of both
// ordinary and very small magnitude.
TEST(ActivationFunctionTest, ReLUTest) {
    // One row per probe: the value fed to the functor and the value expected
    // back.
    struct Sample {
        double input;
        double expected;
    };

    const Sample samples[] = {
        {0.0, 0.0},        // zero is expected to map to zero
        {5.0, 5.0},        // positive input is expected back unchanged
        {-5.0, 0.0},       // negative input is expected to clamp to zero
        {0.0001, 0.0001},  // very small positive input
        {-0.0001, 0.0},    // very small negative input
    };

    ReLU relu;
    for (const Sample& sample : samples) {
        // The streamed text is only emitted on failure; it identifies the row.
        EXPECT_DOUBLE_EQ(relu(sample.input), sample.expected)
            << "input = " << sample.input;
    }
}