Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 39 additions & 3 deletions edge/model/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,11 @@ def append_data(self, x, y, **kwargs):

def save(self, save_path):
"""
Saves the GP in PyTorch format.
PyTorch does NOT save samples or class structure. Such a model cannot be loaded by a simple "file.open" method.
Saves the GP in PyTorch format, and optionally the Dataset object.
PyTorch does NOT save samples or class structure. Such a model cannot
be loaded by a simple "file.open" method.
See the GP.load method for more information.
:param save_path: str or Path: the path of the file where to save the model
:param save_path: str or Path: where to save the GP model
"""
save_path = str(save_path)
if not save_path.endswith('.pth'):
Expand All @@ -205,6 +206,16 @@ def save(self, save_path):

torch.save(save_dict, save_path)

def save_dataset(self, save_path):
"""
Saves a dataset, using the method implemented in the dataset class.
:param save_data: str or Path: where to save the Dataset
"""
save_path = str(save_path)
if not save_path.endswith('.pth'):
save_path += '.pth'
self.dataset.save(save_path)

# Careful: composing decorators with @staticmethod can be tricky. The @staticmethod decorator should be the last
# one, because it does NOT return a method but an observer object
@staticmethod
Expand Down Expand Up @@ -234,8 +245,18 @@ def load(load_path, train_x, train_y):
**construction_parameters
)
model.load_state_dict(save_dict['state_dict'])

return model

def load_dataset(self, load_path):
"""
Loads and sets `train_x` and `train_y`.
:param load_path: str or Path: the path to the data file
"""
load_path = str(load_path)
self.dataset.load(load_path)
self._set_gp_data_to_dataset()


class Dataset:
"""
Expand Down Expand Up @@ -268,6 +289,21 @@ def append(self, append_x, append_y, **kwargs):
self.train_x = torch.cat((self.train_x, atleast_2d(append_x)), dim=0)
self.train_y = torch.cat((self.train_y, append_y), dim=0)

def save(self, save_path):
save_path = str(save_path)
if not save_path.endswith('.pth'):
save_path += '.pth'

torch.save({'train_x': self.train_x,
'train_y': self.train_y},
save_path)

def load(self, load_path):
load_path = str(load_path)
save_dict = torch.load(load_path)
self.train_x = save_dict['train_x']
self.train_y = save_dict['train_y']


class TimeForgettingDataset(Dataset):
"""
Expand Down
9 changes: 9 additions & 0 deletions test/gp_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,15 @@ def test_load_save(self):
self.assertEqual(model.covar_module.outputscale,
loaded.covar_module.outputscale)

save_data = tempfile.NamedTemporaryFile(suffix='.pth').name
model.save_dataset(save_data)
self.assertTrue(os.path.isfile(save_data))
# load a new GP with different seed, then load the dataset
x2 = np.linspace(2, 3, 11)
loaded = MaternGP.load(save_file, x2, y)
loaded.load_dataset(save_data)
self.assertTrue(torch.all(torch.eq(model.train_x, loaded.train_x)))

def test_hyper_optimization_0(self):
warnings.simplefilter('ignore', gpytorch.utils.warnings.GPInputWarning)

Expand Down