From 0094e6c1f8cefc04b1490c4cd7060f0e4696fb6f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 Jan 2026 16:06:18 +0100 Subject: [PATCH 1/3] File not found error for loss --- chebai/preprocessing/datasets/base.py | 32 +++++++++++++------------- chebai/preprocessing/datasets/chebi.py | 10 ++++---- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 02b6ec72..66e9be0c 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or ( - self.label_filter is None - ), "Filter balancing requires a filter" + assert (balance_after_filter is not None) or (self.label_filter is None), ( + "Filter balancing requires a filter" + ) self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert ( - fold_index is None or self.use_inner_cross_validation is not None - ), "fold_index can only be set if cross validation is used" + assert fold_index is None or self.use_inner_cross_validation is not None, ( + "fold_index can only be set if cross validation is used" + ) if fold_index is not None and self.inner_k_folds is not None: - assert ( - fold_index < self.inner_k_folds - ), "fold_index can't be larger than the total number of folds" + assert fold_index < self.inner_k_folds, ( + "fold_index can't be larger than the total number of folds" + ) self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert ( - self._feature_vector_size is not None - ), "size of feature vector must be set" + assert self._feature_vector_size is not None, ( + "size of feature vector must be set" + ) return self._feature_vector_size @property @@ -1173,9 +1173,7 @@ def _retrieve_splits_from_csv(self) -> None: splits_df = pd.read_csv(self.splits_file_path) filename = self.processed_file_names_dict["data"] - data = self.load_processed_data_from_file( - os.path.join(self.processed_dir, filename) - ) + data = self.load_processed_data_from_file(filename) df_data = pd.DataFrame(data) if self.apply_id_filter: @@ -1255,7 +1253,9 @@ def load_processed_data( return self.load_processed_data_from_file(filename) def load_processed_data_from_file(self, filename): - return torch.load(os.path.join(filename), weights_only=False) + return torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) # ------------------------------ Phase: Raw Properties ----------------------------------- @property diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index 379a7f62..e82a83a8 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -144,9 +144,9 @@ def __init__( **kwargs, ): if bool(augment_smiles): - assert ( - int(aug_smiles_variations) > 0 - ), "Number of variations must be greater than 0" + assert int(aug_smiles_variations) > 0, ( + "Number of variations must be greater than 0" + ) aug_smiles_variations = int(aug_smiles_variations) if not kwargs.get("splits_file_path", None): @@ -516,9 +516,7 @@ def _get_data_splits(self) -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: """ try: filename = self.processed_file_names_dict["data"] - data_chebi_version = self.load_processed_data_from_file( - os.path.join(self.processed_dir, filename) - ) + data_chebi_version = self.load_processed_data_from_file(filename) except FileNotFoundError: raise FileNotFoundError( "File data.pt doesn't exists. " From 89cb0055f5e47a16b9ca204ec152fbf85db222a7 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Fri, 9 Jan 2026 16:21:26 +0100 Subject: [PATCH 2/3] pre-commit format --- chebai/preprocessing/datasets/base.py | 24 ++++++++++++------------ chebai/preprocessing/datasets/chebi.py | 6 +++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 66e9be0c..4c269fa3 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -96,9 +96,9 @@ def __init__( self.prediction_kind = prediction_kind self.data_limit = data_limit self.label_filter = label_filter - assert (balance_after_filter is not None) or (self.label_filter is None), ( - "Filter balancing requires a filter" - ) + assert (balance_after_filter is not None) or ( + self.label_filter is None + ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers self.persistent_workers: bool = bool(persistent_workers) @@ -108,13 +108,13 @@ def __init__( self.use_inner_cross_validation = ( inner_k_folds > 1 ) # only use cv if there are at least 2 folds - assert fold_index is None or self.use_inner_cross_validation is not None, ( - "fold_index can only be set if cross validation is used" - ) + assert ( + fold_index is None or self.use_inner_cross_validation is not None + ), "fold_index can only be set if cross validation is used" if fold_index is not None and self.inner_k_folds is not None: - assert fold_index < self.inner_k_folds, ( - "fold_index can't be larger than the total number of folds" - ) + assert ( + fold_index < self.inner_k_folds + ), "fold_index can't be larger than the total number of folds" self.fold_index = fold_index self._base_dir = base_dir self.n_token_limit = n_token_limit @@ -137,9 +137,9 @@ def num_of_labels(self): @property def feature_vector_size(self): - assert self._feature_vector_size is not None, ( - "size of feature vector must be set" - ) + assert ( + self._feature_vector_size is not None + ), "size of feature vector must be set" return self._feature_vector_size @property diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py index e82a83a8..b663b4d9 100644 --- a/chebai/preprocessing/datasets/chebi.py +++ b/chebai/preprocessing/datasets/chebi.py @@ -144,9 +144,9 @@ def __init__( **kwargs, ): if bool(augment_smiles): - assert int(aug_smiles_variations) > 0, ( - "Number of variations must be greater than 0" - ) + assert ( + int(aug_smiles_variations) > 0 + ), "Number of variations must be greater than 0" aug_smiles_variations = int(aug_smiles_variations) if not kwargs.get("splits_file_path", None): From a5ea56a9b3eacfc683c15eff0d94f46a3844db27 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 15 Jan 2026 15:54:15 +0100 Subject: [PATCH 3/3] docstring --- chebai/preprocessing/datasets/base.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 4c269fa3..41c6cada 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -1252,7 +1252,20 @@ def load_processed_data( # If filename is provided return self.load_processed_data_from_file(filename) - def load_processed_data_from_file(self, filename): + def load_processed_data_from_file(self, filename: str) -> list[dict[str, Any]]: + """Load processed data from a file. + + The full path is not required; only the filename is needed, as it will be joined with the processed directory. + + Args: + filename (str): The name of the file to load the processed data from. + + Returns: + List[Dict[str, Any]]: The loaded processed data. + + Example: + data = self.load_processed_data_from_file('data.pt') + """ return torch.load( os.path.join(self.processed_dir, filename), weights_only=False )