From e369a1cab71e78160e855f5d83b1477b809c2f35 Mon Sep 17 00:00:00 2001 From: Daniel Yoo Date: Thu, 2 Oct 2014 13:54:38 -0700 Subject: [PATCH] added options to save and load model parameters --- src/libfm/libfm.cpp | 64 +++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 62 insertions(+), 2 deletions(-) diff --git a/src/libfm/libfm.cpp b/src/libfm/libfm.cpp index a41667b..5581168 100644 --- a/src/libfm/libfm.cpp +++ b/src/libfm/libfm.cpp @@ -42,6 +42,7 @@ #include #include #include +#include #include #include #include @@ -96,6 +97,14 @@ int main(int argc, char **argv) { const std::string param_relation = cmdline.registerParameter("relation", "BS: filenames for the relations, default=''"); const std::string param_cache_size = cmdline.registerParameter("cache_size", "cache size for data storage (only applicable if data is in binary format), default=infty"); + + + const std::string param_save_model_w0 = cmdline.registerParameter("save_model_w0", "write model w weights to text file; default = ''"); + const std::string param_load_model_w0 = cmdline.registerParameter("load_model_w0", "load model w weights from text file; default = ''"); + const std::string param_save_model_w = cmdline.registerParameter("save_model_w", "write model w weights to binary file; default = ''"); + const std::string param_load_model_w = cmdline.registerParameter("load_model_w", "load model w weights from binary file; default = ''"); + const std::string param_save_model_v = cmdline.registerParameter("save_model_v", "write model v weights to binary file; default = ''"); + const std::string param_load_model_v = cmdline.registerParameter("load_model_v", "load model v weights from binary file; default = ''"); const std::string param_do_sampling = "do_sampling"; @@ -111,7 +120,7 @@ int main(int argc, char **argv) { if (! cmdline.hasParameter(param_method)) { cmdline.setValue(param_method, "mcmc"); } if (! cmdline.hasParameter(param_init_stdev)) { cmdline.setValue(param_init_stdev, "0.1"); } if (! cmdline.hasParameter(param_dim)) { cmdline.setValue(param_dim, "1,1,8"); } - + if (! cmdline.getValue(param_method).compare("als")) { // als is an mcmc without sampling and hyperparameter inference cmdline.setValue(param_method, "mcmc"); if (! cmdline.hasParameter(param_do_sampling)) { cmdline.setValue(param_do_sampling, "0"); } @@ -238,7 +247,9 @@ int main(int argc, char **argv) { fm.init(); } - + + + // (3) Setup the learning method: fm_learn* fml; if (! cmdline.getValue(param_method).compare("sgd")) { @@ -383,9 +394,58 @@ int main(int argc, char **argv) { fm.debug(); fml->debug(); } + + + // () Load in the model to the factorization machine + if (cmdline.hasParameter(param_load_model_w0)) { + std::string filename = cmdline.getValue(param_load_model_w0); + std::cout << "reading " << filename << std::endl; std::cout.flush(); + std::ifstream in(filename.c_str(), std::ios_base::in | std::ios_base::binary); + if (in.is_open()) { + in.read(reinterpret_cast(&fm.w0), sizeof(double)); + in.close(); + } else { + throw "could not open " + filename; + } + } + + if (cmdline.hasParameter(param_load_model_w)) { + std::string filename = cmdline.getValue(param_load_model_w); + fm.w.loadFromBinaryFile(filename); + } + if (cmdline.hasParameter(param_load_model_v)) { + std::string filename = cmdline.getValue(param_load_model_v); + fm.v.loadFromBinaryFile(filename); + } + // () learn fml->learn(train, test); + + + // () Save the learned model + if (cmdline.hasParameter(param_save_model_w0)) { + std::string filename = cmdline.getValue(param_save_model_w0); + std::cout << "writing " << filename << std::endl; std::cout.flush(); + std::ofstream out (filename.c_str(), std::ios_base::binary); + if (out.is_open()) { + out.write(reinterpret_cast(&fm.w0), sizeof(double)); + out.close(); + } else { + throw "could not open " + filename; + } + } + + if (cmdline.hasParameter(param_save_model_w)) { + std::string filename = cmdline.getValue(param_save_model_w); + fm.w.saveToBinaryFile(filename); + } + + if (cmdline.hasParameter(param_save_model_v)) { + std::string filename = cmdline.getValue(param_save_model_v); + fm.v.saveToBinaryFile(filename); + } + // () Prediction at the end (not for mcmc and als) if (cmdline.getValue(param_method).compare("mcmc")) {