neural_net/src/neural_net.hpp

#ifndef NEURAL_NET_H
#define NEURAL_NET_H

#include "activation_function.hpp"
#include "matrix.hpp"

#include <cstddef>
#include <random>
#include <stdexcept>
#include <vector>
template <class ActivationFunction> class NeuralNet {
public:
  explicit NeuralNet(const std::vector<size_t> &layer_sizes)
      : m_sizes(layer_sizes) {
    // A network needs at least an input and an output layer
    if (m_sizes.size() < 2) {
      throw std::invalid_argument("Network needs at least two layers");
    }
    // Create the random number source for weight initialization
    std::random_device rd{};
    std::mt19937 gen{rd()};
    std::normal_distribution<float> dist{0.0f, 1.0f};
    // Initialize weights for each layer connection
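    // (The scale factor init_stddev(fan_in) presumably implements a fan-in
    // rule such as He initialization, sqrt(2 / fan_in), for ReLU-like
    // activations; the exact formula is defined in activation_function.hpp.)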
    for (size_t i = 0; i < m_sizes.size() - 1; i++) {
      size_t rows = m_sizes[i + 1]; // neurons in the next layer
      size_t cols = m_sizes[i];     // neurons in the current layer
      // Create and fill the weight matrix with scaled Gaussian samples
      Matrix<float> W(rows, cols, 0.0f);
      for (size_t j = 0; j < rows; j++) {
        for (size_t k = 0; k < cols; k++) {
          W(j, k) = dist(gen) * m_activation_func.init_stddev(cols);
        }
      }
      m_weights.push_back(W);
    }
  }
  // Set new weights for the network
  void set_weights(const std::vector<Matrix<float>> &new_weights) {
    // Validate new weights
    if (new_weights.empty()) {
      throw std::invalid_argument("Weights vector cannot be empty");
    }
    // Validate that the number of weight matrices matches the architecture
    if (new_weights.size() != m_weights.size()) {
      throw std::invalid_argument(
          "Number of weight matrices doesn't match network architecture");
    }
    // Validate layer connectivity
    for (size_t i = 0; i < new_weights.size(); i++) {
      if (new_weights[i].rows() != m_weights[i].rows()) {
        throw std::invalid_argument(
            "New weight matrix rows don't match existing architecture");
      }
      if (new_weights[i].cols() != m_weights[i].cols()) {
        throw std::invalid_argument(
            "New weight matrix columns don't match existing architecture");
      }
    }
    // Update weights
    m_weights = new_weights;
  }
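  // Example (a sketch; `load_weights` is a hypothetical helper standing in
  // for whatever deserialization the project uses):
  //
  //   std::vector<Matrix<float>> trained = load_weights("model.bin");
  //   net.set_weights(trained); // throws if the shapes don't match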
  /** Pass an input vector through the neural network.
   * The network is fully connected: each hidden layer computes Z = W * A and
   * applies the user-specified activation function; the final layer always
   * applies softmax.
   * @param x Input vector; its length must equal the input layer size
   * @return Output of the feed-forward pass
   */
  std::vector<float> feed_forward(const std::vector<float> &x) {
    if (x.size() != m_sizes.front()) {
      throw std::invalid_argument("Input size doesn't match the input layer");
    }
    // Convert the input vector to a column matrix
    Matrix<float> A(x.size(), 1, x);
    // Feed forward through every layer except the last, using the
    // user-specified activation function
    for (size_t i = 0; i < m_sizes.size() - 2; i++) {
      // Calculate Z = W * A
      Matrix<float> Z = m_weights[i] * A;
      // Apply the activation function in place
      m_activation_func(Z.data());
      A = Z;
    }
    // Always use softmax for the final layer
    Matrix<float> Z = m_weights.back() * A;
    m_soft_max(Z.data());
    // Convert the final output back to a vector
    std::vector<float> output(Z.rows());
    for (size_t i = 0; i < Z.rows(); i++) {
      output[i] = Z(i, 0);
    }
    return output;
  }
private:
  ActivationFunction m_activation_func;
  SoftMax m_soft_max;
  std::vector<size_t> m_sizes;
  std::vector<Matrix<float>> m_weights;
};
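
// Example usage (a sketch, assuming activation_function.hpp provides a
// functor such as `ReLU` with `init_stddev(fan_in)` and a call operator that
// transforms a layer's raw values in place, matching the interface used
// above):
//
//   std::vector<size_t> sizes{784, 128, 10}; // input, hidden, output
//   NeuralNet<ReLU> net(sizes);
//   std::vector<float> input(784, 0.5f);
//   std::vector<float> probs = net.feed_forward(input); // softmax output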
#endif // NEURAL_NET_H