Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added a
Empty file.
40 changes: 40 additions & 0 deletions tffm/base.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import tensorflow as tf
from .core import TFFMCore
import sklearn
from sklearn.base import BaseEstimator
from abc import ABCMeta, abstractmethod
import six
from tqdm import tqdm
import numpy as np
import os
import pickle


def batcher(X_, y_=None, w_=None, batch_size=-1):
Expand Down Expand Up @@ -277,3 +279,41 @@ def destroy(self):
"""Terminates session and destroyes graph."""
self.session.close()
self.core.graph = None

def save_model(self, path):
"""Saves the entire model"""
self.save_state(path)

## Additional core attributes that are to be pickled!
CORE_ATTRS = ['n_features']

# remove loss_function if classifier
if self.__class__.__name__.rpartition('.')[2] == 'TFFMClassifier':
del self._init_params_copy['loss_function']
pickle_params ={
'init': self._init_params_copy,
'core': {k:getattr(self.core, k) for k in CORE_ATTRS}
}
if not os.path.isdir(os.path.dirname(path)):
os.makedirs(path, exist_ok=True)

pickle.dump(pickle_params, open(path+'.tffm', 'wb'))

@staticmethod
def load_model(klass, path):
"""
Restores the TFFM model, along with tensorflow state

Parameters:
-----------
klass: TFFMRegressor or TFFMClassifier
path: path to save pickled model file with extension .tffm
"""
if klass.__name__.rpartition('.')[2] not in ['TFFMRegressor', 'TFFMClassifier']:
raise TypeError('klass is not supported: %s'%klass)
_unpickled_obj = pickle.load(open(path+'.tffm', 'rb'))
_new_model = klass(**_unpickled_obj['init'])
for k in _unpickled_obj['core']:
setattr(_new_model.core, k, _unpickled_obj['core'][k])
_new_model.load_state(path)
return _new_model
4 changes: 3 additions & 1 deletion tffm/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from .base import TFFMBaseModel
from .utils import loss_logistic, loss_mse, sigmoid

import copy


class TFFMClassifier(TFFMBaseModel):
Expand All @@ -24,6 +24,7 @@ def __init__(self, **init_params):
base class TFFMBaseModel."""

init_params['loss_function'] = loss_logistic
self._init_params_copy = copy.deepcopy(init_params)
self.init_basemodel(**init_params)

def _preprocess_sample_weights(self, sample_weight, pos_class_weight, used_y):
Expand Down Expand Up @@ -118,6 +119,7 @@ def __init__(self, **init_params):
not supported for TFFMRegressor. For custom loss function, extend the
base class TFFMBaseModel."""

self._init_params_copy = copy.deepcopy(init_params)
init_params['loss_function'] = loss_mse
self.init_basemodel(**init_params)

Expand Down