diff --git a/a b/a new file mode 100644 index 0000000..e69de29 diff --git a/tffm/base.py b/tffm/base.py index 4e72626..ed62227 100644 --- a/tffm/base.py +++ b/tffm/base.py @@ -1,11 +1,13 @@ import tensorflow as tf from .core import TFFMCore +import sklearn from sklearn.base import BaseEstimator from abc import ABCMeta, abstractmethod import six from tqdm import tqdm import numpy as np import os +import pickle def batcher(X_, y_=None, w_=None, batch_size=-1): @@ -277,3 +279,41 @@ def destroy(self): """Terminates session and destroyes graph.""" self.session.close() self.core.graph = None + + def save_model(self, path): + """Saves the entire model""" + self.save_state(path) + + ## Additional core attributes that are to be pickled! + CORE_ATTRS = ['n_features'] + + # remove loss_function if classifier + if self.__class__.__name__.rpartition('.')[2] == 'TFFMClassifier': + del self._init_params_copy['loss_function'] + pickle_params ={ + 'init': self._init_params_copy, + 'core': {k:getattr(self.core, k) for k in CORE_ATTRS} + } + if not os.path.isdir(os.path.dirname(path)): + os.makedirs(path, exist_ok=True) + + pickle.dump(pickle_params, open(path+'.tffm', 'wb')) + + @staticmethod + def load_model(klass, path): + """ + Restores the TFFM model, along with tensorflow state + + Parameters: + ----------- + klass: TFFMRegressor or TFFMClassifier + path: path to save pickled model file with extension .tffm + """ + if klass.__name__.rpartition('.')[2] not in ['TFFMRegressor', 'TFFMClassifier']: + raise TypeError('klass is not supported: %s'%klass) + _unpickled_obj = pickle.load(open(path+'.tffm', 'rb')) + _new_model = klass(**_unpickled_obj['init']) + for k in _unpickled_obj['core']: + setattr(_new_model.core, k, _unpickled_obj['core'][k]) + _new_model.load_state(path) + return _new_model diff --git a/tffm/models.py b/tffm/models.py index b37e98e..5f8a5bb 100644 --- a/tffm/models.py +++ b/tffm/models.py @@ -3,7 +3,7 @@ import numpy as np from .base import TFFMBaseModel from .utils import loss_logistic, loss_mse, sigmoid - +import copy class TFFMClassifier(TFFMBaseModel): @@ -24,6 +24,7 @@ def __init__(self, **init_params): base class TFFMBaseModel.""" init_params['loss_function'] = loss_logistic + self._init_params_copy = copy.deepcopy(init_params) self.init_basemodel(**init_params) def _preprocess_sample_weights(self, sample_weight, pos_class_weight, used_y): @@ -118,6 +119,7 @@ def __init__(self, **init_params): not supported for TFFMRegressor. For custom loss function, extend the base class TFFMBaseModel.""" + self._init_params_copy = copy.deepcopy(init_params) init_params['loss_function'] = loss_mse self.init_basemodel(**init_params)