Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 62 additions & 2 deletions src/libfm/libfm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
#include <cstdlib>
#include <cstdio>
#include <iostream>
#include <fstream>
#include <string>
#include <iterator>
#include <algorithm>
Expand Down Expand Up @@ -96,6 +97,14 @@ int main(int argc, char **argv) {
const std::string param_relation = cmdline.registerParameter("relation", "BS: filenames for the relations, default=''");

const std::string param_cache_size = cmdline.registerParameter("cache_size", "cache size for data storage (only applicable if data is in binary format), default=infty");


const std::string param_save_model_w0 = cmdline.registerParameter("save_model_w0", "write model w weights to text file; default = ''");
const std::string param_load_model_w0 = cmdline.registerParameter("load_model_w0", "load model w weights from text file; default = ''");
const std::string param_save_model_w = cmdline.registerParameter("save_model_w", "write model w weights to binary file; default = ''");
const std::string param_load_model_w = cmdline.registerParameter("load_model_w", "load model w weights from binary file; default = ''");
const std::string param_save_model_v = cmdline.registerParameter("save_model_v", "write model v weights to binary file; default = ''");
const std::string param_load_model_v = cmdline.registerParameter("load_model_v", "load model v weights from binary file; default = ''");


const std::string param_do_sampling = "do_sampling";
Expand All @@ -111,7 +120,7 @@ int main(int argc, char **argv) {
if (! cmdline.hasParameter(param_method)) { cmdline.setValue(param_method, "mcmc"); }
if (! cmdline.hasParameter(param_init_stdev)) { cmdline.setValue(param_init_stdev, "0.1"); }
if (! cmdline.hasParameter(param_dim)) { cmdline.setValue(param_dim, "1,1,8"); }

if (! cmdline.getValue(param_method).compare("als")) { // als is an mcmc without sampling and hyperparameter inference
cmdline.setValue(param_method, "mcmc");
if (! cmdline.hasParameter(param_do_sampling)) { cmdline.setValue(param_do_sampling, "0"); }
Expand Down Expand Up @@ -238,7 +247,9 @@ int main(int argc, char **argv) {
fm.init();

}




// (3) Setup the learning method:
fm_learn* fml;
if (! cmdline.getValue(param_method).compare("sgd")) {
Expand Down Expand Up @@ -383,9 +394,58 @@ int main(int argc, char **argv) {
fm.debug();
fml->debug();
}


// () Load in the model to the factorization machine
if (cmdline.hasParameter(param_load_model_w0)) {
std::string filename = cmdline.getValue(param_load_model_w0);
std::cout << "reading " << filename << std::endl; std::cout.flush();
std::ifstream in(filename.c_str(), std::ios_base::in | std::ios_base::binary);
if (in.is_open()) {
in.read(reinterpret_cast<char*>(&fm.w0), sizeof(double));
in.close();
} else {
throw "could not open " + filename;
}
}

if (cmdline.hasParameter(param_load_model_w)) {
std::string filename = cmdline.getValue(param_load_model_w);
fm.w.loadFromBinaryFile(filename);
}

if (cmdline.hasParameter(param_load_model_v)) {
std::string filename = cmdline.getValue(param_load_model_v);
fm.v.loadFromBinaryFile(filename);
}

// () learn
fml->learn(train, test);


// () Save the learned model
if (cmdline.hasParameter(param_save_model_w0)) {
std::string filename = cmdline.getValue(param_save_model_w0);
std::cout << "writing " << filename << std::endl; std::cout.flush();
std::ofstream out (filename.c_str(), std::ios_base::binary);
if (out.is_open()) {
out.write(reinterpret_cast<char*>(&fm.w0), sizeof(double));
out.close();
} else {
throw "could not open " + filename;
}
}

if (cmdline.hasParameter(param_save_model_w)) {
std::string filename = cmdline.getValue(param_save_model_w);
fm.w.saveToBinaryFile(filename);
}

if (cmdline.hasParameter(param_save_model_v)) {
std::string filename = cmdline.getValue(param_save_model_v);
fm.v.saveToBinaryFile(filename);
}


// () Prediction at the end (not for mcmc and als)
if (cmdline.getValue(param_method).compare("mcmc")) {
Expand Down