-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathWeightMatrix.cpp
More file actions
70 lines (58 loc) · 2.15 KB
/
WeightMatrix.cpp
File metadata and controls
70 lines (58 loc) · 2.15 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
#pragma once
#include <vector>
#include <stdexcept>
#include <cmath>
#include "Vector.cpp"
#include "Matrix.cpp"
#include "Activation.cpp"
// Holds the trainable parameters of one layer -- a weight matrix and a bias
// vector -- and provides the forward-pass product and the gradient update
// applied to them.
class WeightMatrix {
private:
    Matrix weights_;  // constructed with output_size rows and input_size columns
    Vector bias_;     // constructed with output_size entries

public:
    // Builds the weight matrix as (output_size x input_size) and the bias
    // vector with output_size entries, then initializes them by calling
    // Matrix::xavier_init() and Vector::uniform_init() respectively.
    WeightMatrix(size_t input_size, size_t output_size)
        : weights_(output_size, input_size), bias_(output_size)
    {
        weights_.xavier_init();
        bias_.uniform_init();
    }

    // Forward pass: returns weights * input + bias.
    //
    // Throws std::invalid_argument when input.size() differs from
    // input_size(). The check mirrors the dimension validation done in
    // update(), so a wrongly sized input is rejected here instead of being
    // handed to Matrix::multiply unchecked.
    Vector multiply(const Vector& input) const {
        if (input.size() != weights_.cols()) {
            throw std::invalid_argument("Input dimensions don't match");
        }

        Vector output = weights_.multiply(input);
        for (size_t i = 0; i < bias_.size(); ++i) {
            output[i] += bias_[i];
        }
        return output;
    }

    // Applies one gradient step during backpropagation:
    //   weights -= learning_rate * weight_gradients
    //   bias    -= learning_rate * bias_gradients
    //
    // Throws std::invalid_argument when weight_gradients does not have the
    // same rows/cols as the stored weights, or when bias_gradients does not
    // have the same size as the stored bias. Both checks run before any
    // element is modified, so a failed call leaves the parameters untouched.
    void update(const Matrix& weight_gradients, const Vector& bias_gradients, float learning_rate = 0.01f) {
        if (weight_gradients.rows() != weights_.rows() ||
            weight_gradients.cols() != weights_.cols()) {
            throw std::invalid_argument("Weight gradient dimensions don't match");
        }
        if (bias_gradients.size() != bias_.size()) {
            throw std::invalid_argument("Bias gradient dimensions don't match");
        }

        for (size_t i = 0; i < weights_.rows(); ++i) {
            for (size_t j = 0; j < weights_.cols(); ++j) {
                weights_.at(i, j) -= learning_rate * weight_gradients.at(i, j);
            }
        }
        for (size_t i = 0; i < bias_.size(); ++i) {
            bias_[i] -= learning_rate * bias_gradients[i];
        }
    }

    // Mutable access to a single weight; forwards to Matrix::at(row, col).
    // NOTE(review): bounds handling is whatever Matrix::at provides -- confirm.
    float& at(size_t row, size_t col) {
        return weights_.at(row, col);
    }

    // Read-only access to a single weight; forwards to Matrix::at(row, col).
    const float& at(size_t row, size_t col) const {
        return weights_.at(row, col);
    }

    // Number of inputs the layer accepts (columns of the weight matrix).
    size_t input_size() const { return weights_.cols(); }

    // Number of outputs the layer produces (rows of the weight matrix).
    size_t output_size() const { return weights_.rows(); }
};