diff --git a/edge/model/inference/inference.py b/edge/model/inference/inference.py index 1031e55..044f8ab 100644 --- a/edge/model/inference/inference.py +++ b/edge/model/inference/inference.py @@ -186,10 +186,11 @@ def append_data(self, x, y, **kwargs): def save(self, save_path): """ - Saves the GP in PyTorch format. - PyTorch does NOT save samples or class structure. Such a model cannot be loaded by a simple "file.open" method. + Saves the GP in PyTorch format, and optionally the Dataset object. + PyTorch does NOT save samples or class structure. Such a model cannot + be loaded by a simple "file.open" method. See the GP.load method for more information. - :param save_path: str or Path: the path of the file where to save the model + :param save_path: str or Path: where to save the GP model """ save_path = str(save_path) if not save_path.endswith('.pth'): @@ -205,6 +206,16 @@ def save(self, save_path): torch.save(save_dict, save_path) + def save_dataset(self, save_path): + """ + Saves a dataset, using the method implemented in the dataset class. + :param save_data: str or Path: where to save the Dataset + """ + save_path = str(save_path) + if not save_path.endswith('.pth'): + save_path += '.pth' + self.dataset.save(save_path) + # Careful: composing decorators with @staticmethod can be tricky. The @staticmethod decorator should be the last # one, because it does NOT return a method but an observer object @staticmethod @@ -234,8 +245,18 @@ def load(load_path, train_x, train_y): **construction_parameters ) model.load_state_dict(save_dict['state_dict']) + return model + def load_dataset(self, load_path): + """ + Loads and sets `train_x` and `train_y`. + :param load_path: str or Path: the path to the data file + """ + load_path = str(load_path) + self.dataset.load(load_path) + self._set_gp_data_to_dataset() + class Dataset: """ @@ -268,6 +289,21 @@ def append(self, append_x, append_y, **kwargs): self.train_x = torch.cat((self.train_x, atleast_2d(append_x)), dim=0) self.train_y = torch.cat((self.train_y, append_y), dim=0) + def save(self, save_path): + save_path = str(save_path) + if not save_path.endswith('.pth'): + save_path += '.pth' + + torch.save({'train_x': self.train_x, + 'train_y': self.train_y}, + save_path) + + def load(self, load_path): + load_path = str(load_path) + save_dict = torch.load(load_path) + self.train_x = save_dict['train_x'] + self.train_y = save_dict['train_y'] + class TimeForgettingDataset(Dataset): """ diff --git a/test/gp_test.py b/test/gp_test.py index 7cc4945..007c617 100644 --- a/test/gp_test.py +++ b/test/gp_test.py @@ -127,6 +127,15 @@ def test_load_save(self): self.assertEqual(model.covar_module.outputscale, loaded.covar_module.outputscale) + save_data = tempfile.NamedTemporaryFile(suffix='.pth').name + model.save_dataset(save_data) + self.assertTrue(os.path.isfile(save_data)) + # load a new GP with different seed, then load the dataset + x2 = np.linspace(2, 3, 11) + loaded = MaternGP.load(save_file, x2, y) + loaded.load_dataset(save_data) + self.assertTrue(torch.all(torch.eq(model.train_x, loaded.train_x))) + def test_hyper_optimization_0(self): warnings.simplefilter('ignore', gpytorch.utils.warnings.GPInputWarning)