Skip to content

Commit d82cb10

Browse files
Add numerical stability improvements to grad_descent.hpp
1 parent 0581182 commit d82cb10

File tree

1 file changed

+12
-8
lines changed

1 file changed

+12
-8
lines changed

include/spatialize/grad_descent.hpp

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
#include "spatialize/utils.hpp"
1414

1515
namespace sptlz{
16+
// Optimization constants (shared with adaptive_esi_idw.hpp via utils)
17+
constexpr float OPT_EPSILON = 1e-10f; // Numerical stability for optimization
1618

1719
std::vector<std::vector<int>> neighbors(std::vector<int> elements, int d){
1820
std::vector<std::vector<int>> nbs, nb;
@@ -160,18 +162,19 @@ namespace sptlz{
160162
float eval(std::vector<float> *X){
161163
int n = values->size();
162164

163-
float r = 0.0, sum_w, est, wj;
165+
float r = 0.0;
164166
std::vector<float> params = {X->at(1), X->at(2)};
165167

166168
auto tr_coords = transform(this->coords, &params, &(this->centroid));
167169

168170
for(int i=0; i<n; i++){
169171
auto ds = distances(&tr_coords, i);
170-
sum_w = 0.0;
171-
est = 0.0;
172+
float sum_w = 0.0;
173+
float est = 0.0;
174+
float wj;
172175
for(int j=0;j<n;j++){
173176
if(j!=i){
174-
wj = 1.0/(1.0+std::pow(ds.at(j), X->at(0)));
177+
wj = 1.0f/(OPT_EPSILON + std::pow(ds.at(j), X->at(0)));
175178
sum_w += wj;
176179
est += wj*values->at(j);
177180
}
@@ -190,18 +193,19 @@ namespace sptlz{
190193
float eval(std::vector<float> *X){
191194
int n = values->size();
192195

193-
float r = 0.0, sum_w, est, wj;
196+
float r = 0.0;
194197
std::vector<float> params = {X->at(1), X->at(2), X->at(3), X->at(4), X->at(5)};
195198

196199
auto tr_coords = transform(this->coords, &params, &(this->centroid));
197200

198201
for(int i=0; i<n; i++){
199202
auto ds = distances(&tr_coords, i);
200-
sum_w = 0.0;
201-
est = 0.0;
203+
float sum_w = 0.0;
204+
float est = 0.0;
205+
float wj;
202206
for(int j=0;j<n;j++){
203207
if(j!=i){
204-
wj = 1.0/(1.0+std::pow(ds.at(j), X->at(0)));
208+
wj = 1.0f/(OPT_EPSILON + std::pow(ds.at(j), X->at(0)));
205209
sum_w += wj;
206210
est += wj*values->at(j);
207211
}

0 commit comments

Comments
 (0)