neural_net/src/neural_net.hpp

20 lines
549 B
C++

#ifndef NEURAL_NET_H
#define NEURAL_NET_H
#include <cstddef>
#include <vector>

#include "activation_function.hpp"
/// A feed-forward neural network parameterised on its activation type.
///
/// @tparam ActivationFunction type of the `m_activation_func` member.
///         Presumably it provides the per-neuron activation used between
///         layers (see activation_function.hpp -- TODO confirm its interface).
///
/// NOTE(review): every member that performs inference is private, and no
/// public entry point is declared here. As declared, callers can construct
/// the network but cannot run it.
/// NOTE(review): the vector parameters are taken by non-const reference,
/// which rejects temporaries and const arguments (e.g. `NeuralNet<F>{{2, 3, 1}}`
/// does not compile). Switching them to `const&` must be done together with
/// the out-of-line definitions, which are not in this header.
template <class ActivationFunction> class NeuralNet {
public:
  /// Constructs a network from a list of layer widths.
  /// @param layer_sizes number of neurons per layer; presumably ordered
  ///        input -> output (TODO confirm against the constructor definition).
  NeuralNet(std::vector<std::size_t> &layer_sizes);

private:
  /// Activation-function object of the template parameter type.
  ActivationFunction m_activation_func;
  /// SoftMax object, presumably declared in activation_function.hpp and
  /// presumably applied to the output layer -- confirm in the definitions.
  SoftMax m_soft_max;
  /// Layer widths; presumably a copy of the constructor's `layer_sizes`.
  std::vector<std::size_t> m_sizes;
  /// All weights stored in a single flat, contiguous buffer. Per-layer
  /// slices are presumably located via `feed_layer_forward`'s
  /// `layer_start_idx` (TODO confirm layout: row- vs column-major, biases?).
  std::vector<float> m_weights;

  /// Presumably runs input @p x through every layer and returns the final
  /// activations (TODO confirm against the definition).
  /// @param x input activations.
  std::vector<float> feed_forward(std::vector<float> &x);

  /// Presumably propagates activations @p A through a single layer.
  /// @param layer_start_idx presumably the offset of this layer's weights
  ///        within `m_weights` -- confirm.
  /// @param size presumably the layer's width (number of output neurons) --
  ///        confirm.
  /// @param A activations produced by the previous layer.
  /// @return this layer's output activations.
  std::vector<float> feed_layer_forward(std::size_t layer_start_idx,
                                        std::size_t size,
                                        std::vector<float> &A);
};
#endif