// Generated from the aselimov/cpp_project_template repository template (76 lines, 2.0 KiB, C++).
#include <gtest/gtest.h>
|
|
#include "../../src/activation_function.hpp"
|
|
#include <cmath>
|
|
#include <vector>
|
|
|
|
// Checks that Sigmoid rewrites each element x of the vector it is given to
// 1 / (1 + e^-x), and that init_stddev(n) returns sqrt(1 / n) for n = 100.
TEST(ActivationFunctionTest, SigmoidTest) {
    Sigmoid sigmoid;
    const std::vector<float> input = {0.0f, 10.0f, -10.0f, 1.0f, -1.0f};

    // The reference values for +-1 are computed at runtime in double. They must
    // be narrowed to float explicitly: an implicit double->float conversion of a
    // non-constant expression inside a braced initializer is a narrowing
    // conversion, which the standard makes ill-formed (clang rejects it).
    const std::vector<float> expected = {
        0.5f,
        0.9999546f,  // sigmoid(10)  ~= 0.99995460
        0.0000454f,  // sigmoid(-10) ~= 0.00004540
        static_cast<float>(1.0 / (1.0 + std::exp(-1.0))),
        static_cast<float>(1.0 / (1.0 + std::exp(1.0))),
    };

    // The activation is applied to the vector passed in, so work on a copy.
    std::vector<float> test = input;
    sigmoid(test);

    ASSERT_EQ(test.size(), expected.size());
    for (size_t i = 0; i < test.size(); ++i) {
        EXPECT_NEAR(test[i], expected[i], 1e-6);
    }

    // Test initialization standard deviation: sqrt(1 / n) for n = 100.
    EXPECT_NEAR(sigmoid.init_stddev(100), std::sqrt(1.0 / 100), 1e-6);
}
|
|
|
|
// Checks that ReLU rewrites each element x of the vector it is given to
// max(0, x), including values just above and below zero, and that
// init_stddev(n) returns sqrt(2 / n) for n = 100.
TEST(ActivationFunctionTest, ReLUTest) {
    ReLU relu;
    const std::vector<float> input = {0.0f, 5.0f, -5.0f, 0.0001f, -0.0001f};
    const std::vector<float> expected = {0.0f, 5.0f, 0.0f, 0.0001f, 0.0f};

    // The activation is applied to the vector passed in, so work on a copy.
    std::vector<float> test = input;
    relu(test);

    ASSERT_EQ(test.size(), expected.size());
    for (size_t i = 0; i < test.size(); ++i) {
        // ReLU either passes a value through unchanged or returns zero, so an
        // ULP-based comparison is appropriate here.
        EXPECT_FLOAT_EQ(test[i], expected[i]);
    }

    // Test initialization standard deviation: sqrt(2 / n) for n = 100.
    // std::sqrt is used because <cmath> only guarantees declarations in std::.
    EXPECT_NEAR(relu.init_stddev(100), std::sqrt(2.0 / 100), 1e-6);
}
|
|
|
|
// Checks SoftMax on a small vector. The outputs must lie in [0, 1], sum to 1,
// match exp(x_i) / sum_j exp(x_j), and preserve the ordering of the inputs.
// Also checks that init_stddev(n) returns sqrt(1 / n) for n = 100.
TEST(ActivationFunctionTest, SoftMaxTest) {
    SoftMax softmax;
    const std::vector<float> input = {1.0f, 2.0f, 3.0f, 4.0f, 1.0f};

    // The activation is applied to the vector passed in, so work on a copy.
    std::vector<float> test = input;
    softmax(test);

    // Test properties of softmax
    ASSERT_EQ(test.size(), input.size());

    // Reference denominator, computed in double: sum_j exp(x_j).
    double denom = 0.0;
    for (float x : input) {
        denom += std::exp(static_cast<double>(x));
    }

    float sum = 0.0f;
    for (size_t i = 0; i < test.size(); ++i) {
        const float val = test[i];
        sum += val;
        // All values should be between 0 and 1
        EXPECT_GE(val, 0.0);
        EXPECT_LE(val, 1.0);
        // ... and agree with the definition exp(x_i) / sum_j exp(x_j).
        EXPECT_NEAR(val, std::exp(static_cast<double>(input[i])) / denom, 1e-6);
    }
    // Sum should be approximately 1
    EXPECT_NEAR(sum, 1.0, 1e-6);

    // Ordering is checked over every pair, not just neighbours, so that
    // non-adjacent inputs (e.g. 4.0 vs the trailing 1.0) are covered and equal
    // inputs are required to yield equal outputs. Iterating pairs also avoids
    // the size() - 1 underflow an adjacent-pair loop has on empty input.
    for (size_t i = 0; i < test.size(); ++i) {
        for (size_t j = 0; j < test.size(); ++j) {
            if (input[i] < input[j]) {
                EXPECT_LT(test[i], test[j]);
            } else if (input[i] == input[j]) {
                EXPECT_FLOAT_EQ(test[i], test[j]);
            }
        }
    }

    // Test initialization standard deviation: sqrt(1 / n) for n = 100.
    EXPECT_NEAR(softmax.init_stddev(100), std::sqrt(1.0 / 100), 1e-6);
}
|