#include "../../src/activation_function.hpp"
#include <cmath>
#include <gtest/gtest.h>
#include <vector>

// Checks Sigmoid on a few representative inputs (zero, large positive and
// negative values, and +/-1) and checks the value returned by
// init_stddev(100).
TEST(ActivationFunctionTest, SigmoidTest) {
  Sigmoid sigmoid;
  const std::vector<float> input = {0.0, 10.0, -10.0, 1.0, -1.0};
  // The +/-10 references are pre-computed literals; the +/-1 references are
  // evaluated from 1 / (1 + e^-x) in double precision and narrowed to float.
  const std::vector<float> expected = {
      0.5, 0.9999546, 0.0000454,
      static_cast<float>(1.0 / (1.0 + std::exp(-1.0))),
      static_cast<float>(1.0 / (1.0 + std::exp(1.0)))};

  // Run the functor on a copy so the original inputs remain available for the
  // failure messages below.
  std::vector<float> test = input;
  sigmoid(test);

  ASSERT_EQ(test.size(), expected.size());
  for (size_t i = 0; i < test.size(); i++) {
    EXPECT_NEAR(test[i], expected[i], 1e-6)
        << "at index " << i << " (input " << input[i] << ")";
  }

  // Test initialization standard deviation
  EXPECT_NEAR(sigmoid.init_stddev(100), std::sqrt(1.0 / 100), 1e-6);
}

// Checks ReLU on zero, positive, and negative inputs (including values very
// close to zero) and checks the value returned by init_stddev(100).
TEST(ActivationFunctionTest, ReLUTest) {
  ReLU relu;
  const std::vector<float> input = {0.0, 5.0, -5.0, 0.0001, -0.0001};
  // Non-negative inputs are expected back unchanged; negative inputs are
  // expected to become zero.
  const std::vector<float> expected = {0.0, 5.0, 0.0, 0.0001, 0.0};

  // Run the functor on a copy so the original inputs remain available for the
  // failure messages below.
  std::vector<float> test = input;
  relu(test);

  ASSERT_EQ(test.size(), expected.size());
  for (size_t i = 0; i < test.size(); i++) {
    EXPECT_FLOAT_EQ(test[i], expected[i])
        << "at index " << i << " (input " << input[i] << ")";
  }

  // Test initialization standard deviation
  EXPECT_NEAR(relu.init_stddev(100), std::sqrt(2.0 / 100), 1e-6);
}

// Checks structural properties of the SoftMax output rather than exact
// values: size is preserved, every entry lies in [0, 1], the entries sum to
// approximately 1, and adjacent entries keep the ordering of their inputs.
// Also checks the value returned by init_stddev(100).
TEST(ActivationFunctionTest, SoftMaxTest) {
  SoftMax softmax;
  const std::vector<float> input = {1.0, 2.0, 3.0, 4.0, 1.0};
  std::vector<float> test = input;

  softmax(test);

  // Test properties of softmax
  ASSERT_EQ(test.size(), input.size());

  // Sum should be approximately 1
  float sum = 0.0f;
  for (float val : test) {
    sum += val;
    // All values should be between 0 and 1
    EXPECT_GE(val, 0.0);
    EXPECT_LE(val, 1.0);
  }
  EXPECT_NEAR(sum, 1.0, 1e-6);

  // Higher input should lead to higher output.
  // The bound is written as `i + 1 < size()` so that it cannot wrap around
  // the way the unsigned expression `size() - 1` would on an empty vector.
  for (size_t i = 0; i + 1 < test.size(); i++) {
    if (input[i] < input[i + 1]) {
      EXPECT_LT(test[i], test[i + 1])
          << "ordering violated between indices " << i << " and " << (i + 1);
    }
  }

  // Test initialization standard deviation
  EXPECT_NEAR(softmax.init_stddev(100), std::sqrt(1.0 / 100), 1e-6);
}