Skip to content

Commit 32b0d3e

Browse files
author
Felipe garrido R
committed
improved optimization func for adaptative esi
1 parent cc943d3 commit 32b0d3e

File tree

5 files changed

+835
-3
lines changed

5 files changed

+835
-3
lines changed

include/spatialize/abstract_esi.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,7 +310,7 @@ namespace sptlz{
310310
return(root->search_leaf(point));
311311
}
312312

313-
std::string to_json(){
313+
std::string to_json(){
314314
return(root->to_json(""));
315315
}
316316
};

include/spatialize/adaptive_esi_idw.hpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include <random>
77
#include "spatialize/abstract_esi.hpp"
88
#include "spatialize/utils.hpp"
9+
#include "spatialize/grad_descent.hpp"
910

1011
namespace sptlz{
1112
class LOO2D{
@@ -88,6 +89,11 @@ namespace sptlz{
8889

8990
class ADAPTIVE_ESI_IDW: public ESI {
9091
protected:
92+
int d, k;
93+
std::vector<std::vector<float>> param_ranges;
94+
std::vector<float> steps;
95+
std::vector<int> ns;
96+
9197
std::vector<float> leaf_estimation(std::vector<std::vector<float>> *coords, std::vector<float> *values, std::vector<int> *samples_id, std::vector<std::vector<float>> *locations, std::vector<int> *locations_id, std::vector<float> *params){
9298
std::vector<float> result;
9399

@@ -281,6 +287,25 @@ namespace sptlz{
281287
}
282288

283289
std::vector<float> get_params(std::vector<std::vector<float>> *coords, std::vector<float> *values){
290+
if(coords->size()==0){
291+
return(std::vector<float>());
292+
}else if(coords->size()==1){
293+
return(std::vector<float>({values->at(0)}));
294+
}
295+
296+
LOOND *fn;
297+
if(this->d==2){
298+
fn = new LOO_2D(coords, values, 0.01);
299+
}else{
300+
fn = new LOO_3D(coords, values, 0.01);
301+
}
302+
GradDesc *opt = new GridNBRndDesc(fn, this->param_ranges, this->steps, this->ns, k, std::rand());
303+
std::vector<float> m = get_minimum(opt, &(this->param_ranges), 100);
304+
305+
return(m);
306+
}
307+
308+
std::vector<float> get_params2(std::vector<std::vector<float>> *coords, std::vector<float> *values){
284309
std::uniform_real_distribution<float> uni_float(0, 1);
285310
int best_of = 3;
286311
if(coords->size()==0){
@@ -362,6 +387,31 @@ namespace sptlz{
362387
int seed=206936):
363388
ESI(_coords, _values, lambda, forest_size, bbox, visitor, seed){
364389
this->class_name = __func__;
390+
if(_coords.at(0).size()==2){
391+
this->d = 2;
392+
this->param_ranges = {
393+
{ 0.5, 10.0},
394+
{-90.0, 90.0},
395+
{ 0.1, 1.0}
396+
};
397+
this->steps = {0.5, 10.0, 0.1};
398+
this->ns = {19, 18, 9};
399+
}else if(_coords.at(0).size()==3){
400+
this->d = 3;
401+
this->param_ranges = {
402+
{ 0.5, 10.0},
403+
{-90.0, 90.0},
404+
{-90.0, 90.0},
405+
{-90.0, 90.0},
406+
{ 0.1, 1.0},
407+
{ 0.1, 1.0}
408+
};
409+
this->steps = {0.5, 10.0, 10.0, 10.0, 0.1, 0.1};
410+
this->ns = {19, 18, 18, 18, 9, 9};
411+
}else{
412+
throw std::runtime_error("ADAPTIVE_ESI_IDW available just for 2D and 3D");
413+
}
414+
this->k = (int)std::ceil(0.1*std::pow(3, this->ns.size()));
365415
post_process();
366416
}
367417

0 commit comments

Comments
 (0)