-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathMatrix.cpp
More file actions
163 lines (134 loc) · 5.31 KB
/
Matrix.cpp
File metadata and controls
163 lines (134 loc) · 5.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
#pragma once
#include <cmath>
#include <random>
#include <stdexcept>
#include <vector>
#include "Vector.cpp"
class Matrix {
private:
    // Row-major storage in one contiguous buffer for cache-friendly access.
    // Element (i, j) lives at data_[i * cols_ + j].
    std::vector<float> data_;
    size_t rows_;
    size_t cols_;
public:
    /// Construct a rows x cols matrix; elements are zero-initialized.
    Matrix(size_t rows, size_t cols) : data_(rows * cols), rows_(rows), cols_(cols) {}

    /// Fill the matrix with Xavier/Glorot-uniform random weights.
    ///
    /// Xavier initialization keeps the variance of activations roughly
    /// constant across layers, which helps prevent vanishing/exploding
    /// gradients. The uniform bound is sqrt(6 / (fan_in + fan_out)),
    /// i.e. sqrt(6 / (rows + cols)).
    ///
    /// BUG FIX: the previous code divided by the PRODUCT rows * cols,
    /// which collapses the weight range toward zero for any non-trivial
    /// layer size and defeats the purpose of the scheme.
    void xavier_init() {
        if (data_.empty()) {
            return;  // 0x0 matrix: nothing to do, avoid division by zero below
        }
        const float limit = std::sqrt(6.0f / static_cast<float>(rows_ + cols_));
        // <random> replaces rand(): rand() has weak low-order bits and was
        // used here without including <cstdlib>. A single engine is reused
        // across calls so repeated initializations stay well distributed.
        static std::mt19937 gen{std::random_device{}()};
        std::uniform_real_distribution<float> dist(-limit, limit);
        for (float& w : data_) {
            w = dist(gen);
        }
    }

    /// Matrix-vector product: result = M * vec.
    /// This is the most performance-critical operation in the codebase.
    /// @param vec  input vector; must have exactly cols() elements.
    /// @return     vector of length rows().
    /// @throws std::invalid_argument if vec.size() != cols().
    Vector multiply(const Vector& vec) const {
        if (vec.size() != cols_) {
            throw std::invalid_argument("Matrix and vector dimensions dont match for multiplication");
        }
        Vector result(rows_);
        for (size_t i = 0; i < rows_; ++i) {
            float sum = 0.0f;
            // Hoist the row base offset out of the inner loop; the inner
            // loop then walks two contiguous arrays, which vectorizes well.
            const size_t row_base = i * cols_;
            for (size_t j = 0; j < cols_; ++j) {
                sum += data_[row_base + j] * vec[j];
            }
            result[i] = sum;
        }
        return result;
    }

    /*
    Vector multiply_unrolled(const Vector& vec) const {
    Vector result(rows_);
    for (size_t i = 0; i < rows_; ++i) {
    float sum0 = 0.0f, sum1 = 0.0f, sum2 = 0.0f, sum3 = 0.0f;
    // Process 4 elements at a time
    size_t j = 0;
    for (; j + 3 < cols_; j += 4) {
    sum0 += data_[i * cols_ + j] * vec[j];
    sum1 += data_[i * cols_ + j + 1] * vec[j + 1];
    sum2 += data_[i * cols_ + j + 2] * vec[j + 2];
    sum3 += data_[i * cols_ + j + 3] * vec[j + 3];
    }
    // Handle remaining elements
    float sum = sum0 + sum1 + sum2 + sum3;
    for (; j < cols_; ++j) {
    sum += data_[i * cols_ + j] * vec[j];
    }
    result[i] = sum;
    }
    return result;
    }
    Vector multiply_simd(const Vector& vec) const {
    Vector result(rows_);
    for (size_t i = 0; i < rows_; ++i) {
    // Initialize accumulator to zero
    float32x4_t sum_vec = vdupq_n_f32(0.0f); // Creates vector of 4 zeros
    // Process 4 elements at a time using NEON
    size_t j = 0;
    for (; j + 3 < cols_; j += 4) {
    // Load 4 elements from matrix and vector
    float32x4_t a = vld1q_f32(&data_[i * cols_ + j]);
    float32x4_t b = vld1q_f32(&vec[j]);
    // Multiply and accumulate
    sum_vec = vmlaq_f32(sum_vec, a, b); // sum_vec += a * b
    }
    // Sum the four elements of sum_vec
    float sum = vaddvq_f32(sum_vec); // Horizontal add
    // Handle remaining elements
    for (; j < cols_; ++j) {
    sum += data_[i * cols_ + j] * vec[j];
    }
    result[i] = sum;
    }
    return result;
    }
    Vector multiply_blocked(const Vector& vec) const {
    Vector result(rows_);
    constexpr size_t BLOCK_SIZE = 64; // Tune this based on your CPU's L1 cache size
    for (size_t i = 0; i < rows_; i += BLOCK_SIZE) {
    for (size_t j = 0; j < cols_; j += BLOCK_SIZE) {
    // Process a block of the matrix
    size_t i_end = std::min(i + BLOCK_SIZE, rows_);
    size_t j_end = std::min(j + BLOCK_SIZE, cols_);
    for (size_t ii = i; ii < i_end; ++ii) {
    float sum = result[ii]; // Accumulate into existing sum
    for (size_t jj = j; jj < j_end; ++jj) {
    sum += data_[ii * cols_ + jj] * vec[jj];
    }
    result[ii] = sum;
    }
    }
    }
    return result;
    }
    Vector multiply_parallel(const Vector& vec) const {
    Vector result(rows_);
    #pragma omp parallel for
    for (size_t i = 0; i < rows_; ++i) {
    float32x4_t sum_vec = vdupq_n_f32(0.0f);
    size_t j = 0;
    for (; j + 3 < cols_; j += 4) {
    float32x4_t a = vld1q_f32(&data_[i * cols_ + j]);
    float32x4_t b = vld1q_f32(&vec[j]);
    sum_vec = vmlaq_f32(sum_vec, a, b);
    }
    float sum = vaddvq_f32(sum_vec);
    // Handle remaining elements
    for (; j < cols_; ++j) {
    sum += data_[i * cols_ + j] * vec[j];
    }
    result[i] = sum;
    }
    return result;
    }
    */

    /// Mutable access to element (row, col). No bounds checking (hot path).
    float& at(size_t row, size_t col) {
        return data_[row * cols_ + col];
    }
    /// Read-only access to element (row, col). No bounds checking (hot path).
    const float& at(size_t row, size_t col) const {
        return data_[row * cols_ + col];
    }
    size_t rows() const { return rows_; }
    size_t cols() const { return cols_; }
};