Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions src/libfm/libfm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,12 @@ int main(int argc, char **argv) {

const std::string param_save_model = cmdline.registerParameter("save_model", "filename for writing the FM model");
const std::string param_load_model = cmdline.registerParameter("load_model", "filename for reading the FM model");

const std::string param_finetuning = cmdline.registerParameter("finetuning", "if finetuning mode, 1 for yes, 0 for not, default 0");
const std::string param_do_sampling = "do_sampling";
const std::string param_do_multilevel = "do_multilevel";
const std::string param_num_eval_cases = "num_eval_cases";


if (cmdline.hasParameter(param_help) || (argc == 1)) {
cmdline.print_help();
return 0;
Expand All @@ -118,6 +119,7 @@ int main(int argc, char **argv) {
if (! cmdline.hasParameter(param_method)) { cmdline.setValue(param_method, "mcmc"); }
if (! cmdline.hasParameter(param_init_stdev)) { cmdline.setValue(param_init_stdev, "0.1"); }
if (! cmdline.hasParameter(param_dim)) { cmdline.setValue(param_dim, "1,1,8"); }
if (! cmdline.hasParameter(param_finetuning)) { cmdline.setValue(param_finetuning, "0"); }

// Check for invalid flags.
if (! cmdline.getValue(param_method).compare("mcmc") && cmdline.hasParameter(param_save_model)) {
Expand All @@ -132,6 +134,13 @@ int main(int argc, char **argv) {
return 0;
}

if ((!cmdline.hasParameter(param_load_model))
&& (cmdline.getValue(param_finetuning) == "1")) {
std::cout << "WARNING: -finetuning disabled since no loading model." << std::endl;
cmdline.setValue(param_finetuning, "0");
return 0;
}

if (! cmdline.getValue(param_method).compare("als")) { // als is an mcmc without sampling and hyperparameter inference
cmdline.setValue(param_method, "mcmc");
if (! cmdline.hasParameter(param_do_sampling)) { cmdline.setValue(param_do_sampling, "0"); }
Expand Down Expand Up @@ -280,7 +289,15 @@ int main(int argc, char **argv) {
((fm_learn_sgd_element_adapt_reg*)fml)->validation = validation;

} else if (! cmdline.getValue(param_method).compare("mcmc")) {
fm.w.init_normal(fm.init_mean, fm.init_stdev);
if (cmdline.getValue(param_finetuning) != "1") {
std::cout << "WARNING: not finetuning mode. finetuning=" << cmdline.getValue(param_finetuning) << std::endl;
fm.w.init_normal(fm.init_mean, fm.init_stdev);
} else {
std::cout << "WARNING: finetuning mode. finetuning=" << cmdline.getValue(param_finetuning) << std::endl;
//std::cout << "DEBUG: debug finetune parameter" << std::endl;
//fm.saveModel("./last_model_0");
//return 0;
}
fml = new fm_learn_mcmc_simultaneous();
fml->validation = validation;
((fm_learn_mcmc*)fml)->num_iter = cmdline.getValue(param_num_iter, 100);
Expand Down