|
6 | 6 | #include <random> |
7 | 7 | #include "spatialize/abstract_esi.hpp" |
8 | 8 | #include "spatialize/utils.hpp" |
| 9 | +#include "spatialize/grad_descent.hpp" |
9 | 10 |
|
10 | 11 | namespace sptlz{ |
11 | 12 | class LOO2D{ |
@@ -88,6 +89,11 @@ namespace sptlz{ |
88 | 89 |
|
89 | 90 | class ADAPTIVE_ESI_IDW: public ESI { |
90 | 91 | protected: |
| 92 | + int d, k; |
| 93 | + std::vector<std::vector<float>> param_ranges; |
| 94 | + std::vector<float> steps; |
| 95 | + std::vector<int> ns; |
| 96 | + |
91 | 97 | std::vector<float> leaf_estimation(std::vector<std::vector<float>> *coords, std::vector<float> *values, std::vector<int> *samples_id, std::vector<std::vector<float>> *locations, std::vector<int> *locations_id, std::vector<float> *params){ |
92 | 98 | std::vector<float> result; |
93 | 99 |
|
@@ -281,6 +287,25 @@ namespace sptlz{ |
281 | 287 | } |
282 | 288 |
|
283 | 289 | std::vector<float> get_params(std::vector<std::vector<float>> *coords, std::vector<float> *values){ |
| 290 | + if(coords->size()==0){ |
| 291 | + return(std::vector<float>()); |
| 292 | + }else if(coords->size()==1){ |
| 293 | + return(std::vector<float>({values->at(0)})); |
| 294 | + } |
| 295 | + |
| 296 | + LOOND *fn; |
| 297 | + if(this->d==2){ |
| 298 | + fn = new LOO_2D(coords, values, 0.01); |
| 299 | + }else{ |
| 300 | + fn = new LOO_3D(coords, values, 0.01); |
| 301 | + } |
| 302 | + GradDesc *opt = new GridNBRndDesc(fn, this->param_ranges, this->steps, this->ns, k, std::rand()); |
| 303 | + std::vector<float> m = get_minimum(opt, &(this->param_ranges), 100); |
| 304 | + |
| 305 | + return(m); |
| 306 | + } |
| 307 | + |
| 308 | + std::vector<float> get_params2(std::vector<std::vector<float>> *coords, std::vector<float> *values){ |
284 | 309 | std::uniform_real_distribution<float> uni_float(0, 1); |
285 | 310 | int best_of = 3; |
286 | 311 | if(coords->size()==0){ |
@@ -362,6 +387,31 @@ namespace sptlz{ |
362 | 387 | int seed=206936): |
363 | 388 | ESI(_coords, _values, lambda, forest_size, bbox, visitor, seed){ |
364 | 389 | this->class_name = __func__; |
| 390 | + if(_coords.at(0).size()==2){ |
| 391 | + this->d = 2; |
| 392 | + this->param_ranges = { |
| 393 | + { 0.5, 10.0}, |
| 394 | + {-90.0, 90.0}, |
| 395 | + { 0.1, 1.0} |
| 396 | + }; |
| 397 | + this->steps = {0.5, 10.0, 0.1}; |
| 398 | + this->ns = {19, 18, 9}; |
| 399 | + }else if(_coords.at(0).size()==3){ |
| 400 | + this->d = 3; |
| 401 | + this->param_ranges = { |
| 402 | + { 0.5, 10.0}, |
| 403 | + {-90.0, 90.0}, |
| 404 | + {-90.0, 90.0}, |
| 405 | + {-90.0, 90.0}, |
| 406 | + { 0.1, 1.0}, |
| 407 | + { 0.1, 1.0} |
| 408 | + }; |
| 409 | + this->steps = {0.5, 10.0, 10.0, 10.0, 0.1, 0.1}; |
| 410 | + this->ns = {19, 18, 18, 18, 9, 9}; |
| 411 | + }else{ |
| 412 | + throw std::runtime_error("ADAPTIVE_ESI_IDW available just for 2D and 3D"); |
| 413 | + } |
| 414 | + this->k = (int)std::ceil(0.1*std::pow(3, this->ns.size())); |
365 | 415 | post_process(); |
366 | 416 | } |
367 | 417 |
|
|
0 commit comments