diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 02b6ec72..41c6cada 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1173,9 +1173,7 @@ def _retrieve_splits_from_csv(self) -> None: splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] - data = self.load_processed_data_from_file( - os.path.join(self.processed_dir, filename) - ) + data = self.load_processed_data_from_file(filename) df_data = pd.DataFrame(data) if self.apply_id_filter: @@ -1254,8 +1252,23 @@ def load_processed_data( # If filename is provided return self.load_processed_data_from_file(filename) - def load_processed_data_from_file(self, filename): - return torch.load(os.path.join(filename), weights_only=False) + def load_processed_data_from_file(self, filename: str) -> list[dict[str, Any]]: + """Load processed data from a file. + + The full path is not required; only the filename is needed, as it will be joined with the processed directory. + + Args: + filename (str): The name of the file to load the processed data from. + + Returns: + List[Dict[str, Any]]: The loaded processed data. + + Example: + data = self.load_processed_data_from_file('data.pt') + """ + return torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) # ------------------------------ Phase: Raw Properties ----------------------------------- @property diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 379a7f62..b663b4d9 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -516,9 +516,7 @@ def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ try: filename = self.processed_file_names_dict["data"] - data_chebi_version = self.load_processed_data_from_file( - os.path.join(self.processed_dir, filename) - ) + data_chebi_version = self.load_processed_data_from_file(filename) except FileNotFoundError: raise FileNotFoundError( "File data.pt doesn't exists. "