From e18bb781bc137e540a373e0ce2e324ff8d4d804e Mon Sep 17 00:00:00 2001 From: innerNULL Date: Sat, 28 Mar 2020 15:35:00 +0800 Subject: [PATCH] Supports finetuning mode. --- src/libfm/libfm.cpp | 21 +++++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/src/libfm/libfm.cpp b/src/libfm/libfm.cpp index 82b795a..57e01b1 100644 --- a/src/libfm/libfm.cpp +++ b/src/libfm/libfm.cpp @@ -100,11 +100,12 @@ int main(int argc, char **argv) { const std::string param_save_model = cmdline.registerParameter("save_model", "filename for writing the FM model"); const std::string param_load_model = cmdline.registerParameter("load_model", "filename for reading the FM model"); - + const std::string param_finetuning = cmdline.registerParameter("finetuning", "if finetuning mode, 1 for yes, 0 for not, default 0"); const std::string param_do_sampling = "do_sampling"; const std::string param_do_multilevel = "do_multilevel"; const std::string param_num_eval_cases = "num_eval_cases"; + if (cmdline.hasParameter(param_help) || (argc == 1)) { cmdline.print_help(); return 0; @@ -118,6 +119,7 @@ int main(int argc, char **argv) { if (! cmdline.hasParameter(param_method)) { cmdline.setValue(param_method, "mcmc"); } if (! cmdline.hasParameter(param_init_stdev)) { cmdline.setValue(param_init_stdev, "0.1"); } if (! cmdline.hasParameter(param_dim)) { cmdline.setValue(param_dim, "1,1,8"); } + if (! cmdline.hasParameter(param_finetuning)) { cmdline.setValue(param_finetuning, "0"); } // Check for invalid flags. if (! cmdline.getValue(param_method).compare("mcmc") && cmdline.hasParameter(param_save_model)) { @@ -132,6 +134,13 @@ int main(int argc, char **argv) { return 0; } + if ((!cmdline.hasParameter(param_load_model)) + && (cmdline.getValue(param_finetuning) == "1")) { + std::cout << "WARNING: -finetuning disabled since no loading model." << std::endl; + cmdline.setValue(param_finetuning, "0"); + return 0; + } + if (! cmdline.getValue(param_method).compare("als")) { // als is an mcmc without sampling and hyperparameter inference cmdline.setValue(param_method, "mcmc"); if (! cmdline.hasParameter(param_do_sampling)) { cmdline.setValue(param_do_sampling, "0"); } @@ -280,7 +289,15 @@ int main(int argc, char **argv) { ((fm_learn_sgd_element_adapt_reg*)fml)->validation = validation; } else if (! cmdline.getValue(param_method).compare("mcmc")) { - fm.w.init_normal(fm.init_mean, fm.init_stdev); + if (cmdline.getValue(param_finetuning) != "1") { + std::cout << "WARNING: not finetuning mode. finetuning=" << cmdline.getValue(param_finetuning) << std::endl; + fm.w.init_normal(fm.init_mean, fm.init_stdev); + } else { + std::cout << "WARNING: finetuning mode. finetuning=" << cmdline.getValue(param_finetuning) << std::endl; + //std::cout << "DEBUG: debug finetune parameter" << std::endl; + //fm.saveModel("./last_model_0"); + //return 0; + } fml = new fm_learn_mcmc_simultaneous(); fml->validation = validation; ((fm_learn_mcmc*)fml)->num_iter = cmdline.getValue(param_num_iter, 100);