diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml
deleted file mode 100644
index bb9154fd..00000000
--- a/.github/workflows/lint.yml
+++ /dev/null
@@ -1,26 +0,0 @@
-name: Lint
-
-on: [push, pull_request]
-
-jobs:
-  lint:
-    runs-on: ubuntu-latest
-
-    steps:
-      - uses: actions/checkout@v2
-
-      - name: Set up Python
-        uses: actions/setup-python@v4
-        with:
-          python-version: '3.10' # or any version your project uses
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install black==25.1.0 ruff==0.12.2
-
-      - name: Run Black
-        run: black --check .
-
-      - name: Run Ruff (no formatting)
-        run: ruff check . --no-fix
diff --git a/.github/workflows/precommit-action.yml b/.github/workflows/precommit-action.yml
new file mode 100644
index 00000000..edea9678
--- /dev/null
+++ b/.github/workflows/precommit-action.yml
@@ -0,0 +1,19 @@
+name: Pre-commit Check
+
+on:
+  push:
+    branches: [main, master]
+  pull_request:
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
+      - name: Run pre-commit
+        uses: pre-commit/action@v3.0.1
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cbb7284d..74fda14f 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,31 +1,19 @@
 repos:
-- repo: https://github.com/psf/black
-  rev: "25.1.0"
-  hooks:
-    - id: black
-    - id: black-jupyter # for formatting jupyter-notebook
+# Use `pre-commit autoupdate` to update all the hook.
 
-- repo: https://github.com/pycqa/isort
-  rev: 5.13.2
-  hooks:
-    - id: isort
-      name: isort (python)
-      args: ["--profile=black"]
+- repo: https://github.com/astral-sh/ruff-pre-commit
+  # Ruff version. https://docs.astral.sh/ruff/integrations/#pre-commit
+  rev: v0.14.11
+  hooks:
+    # Run the linter.
+    - id: ruff-check
+      args: [ --fix ]
+    # Run the formatter.
+    - id: ruff-format
 
-- repo: https://github.com/asottile/seed-isort-config
-  rev: v2.2.0
-  hooks:
-    - id: seed-isort-config
-
-- repo: https://github.com/pre-commit/pre-commit-hooks
-  rev: v4.6.0
-  hooks:
-    - id: check-yaml
-    - id: end-of-file-fixer
-    - id: trailing-whitespace
-
-- repo: https://github.com/astral-sh/ruff-pre-commit
-  rev: v0.12.2
-  hooks:
-    - id: ruff
-      args: [--fix]
+- repo: https://github.com/pre-commit/pre-commit-hooks
+  rev: v6.0.0
+  hooks:
+    - id: check-yaml
+    - id: end-of-file-fixer
+    - id: trailing-whitespace
diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py
index 1f21b04b..48fe5bf2 100644
--- a/chebai/loss/bce_weighted.py
+++ b/chebai/loss/bce_weighted.py
@@ -32,18 +32,20 @@ def __init__(
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
 
-        assert (
-            isinstance(beta, float) and beta > 0.0
-        ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+        assert isinstance(beta, float) and beta > 0.0, (
+            f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+        )
 
-        assert (
-            self.data_extractor is not None
-        ), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+        assert self.data_extractor is not None, (
+            f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+        )
 
         assert all(
             os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
             for file_name in self.data_extractor.processed_file_names
-        ), "Dataset files not found. Make sure the dataset is processed before using this loss."
+        ), (
+            "Dataset files not found. Make sure the dataset is processed before using this loss."
+        )
 
         assert (
             isinstance(self.data_extractor, _ChEBIDataExtractor)
diff --git a/chebai/loss/focal_loss.py b/chebai/loss/focal_loss.py
index 0fcc3c61..584229da 100644
--- a/chebai/loss/focal_loss.py
+++ b/chebai/loss/focal_loss.py
@@ -36,9 +36,9 @@ def __init__(
             and alpha is not None
             and isinstance(alpha, (list, torch.Tensor))
         ):
-            assert (
-                num_classes is not None
-            ), "num_classes must be specified for multi-class classification"
+            assert num_classes is not None, (
+                "num_classes must be specified for multi-class classification"
+            )
             if isinstance(alpha, list):
                 self.alpha = torch.Tensor(alpha)
             else:
diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py
index 02b6ec72..07082119 100644
--- a/chebai/preprocessing/datasets/base.py
+++ b/chebai/preprocessing/datasets/base.py
@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (
-            self.label_filter is None
-        ), "Filter balancing requires a filter"
+        assert (balance_after_filter is not None) or (self.label_filter is None), (
+            "Filter balancing requires a filter"
+        )
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert (
-            fold_index is None or self.use_inner_cross_validation is not None
-        ), "fold_index can only be set if cross validation is used"
+        assert fold_index is None or self.use_inner_cross_validation is not None, (
+            "fold_index can only be set if cross validation is used"
+        )
         if fold_index is not None and self.inner_k_folds is not None:
-            assert (
-                fold_index < self.inner_k_folds
-            ), "fold_index can't be larger than the total number of folds"
+            assert fold_index < self.inner_k_folds, (
+                "fold_index can't be larger than the total number of folds"
+            )
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert (
-            self._feature_vector_size is not None
-        ), "size of feature vector must be set"
+        assert self._feature_vector_size is not None, (
+            "size of feature vector must be set"
+        )
         return self._feature_vector_size
 
     @property
diff --git a/chebai/preprocessing/datasets/chebi.py b/chebai/preprocessing/datasets/chebi.py
index 379a7f62..c6659d05 100644
--- a/chebai/preprocessing/datasets/chebi.py
+++ b/chebai/preprocessing/datasets/chebi.py
@@ -144,9 +144,9 @@ def __init__(
         **kwargs,
     ):
         if bool(augment_smiles):
-            assert (
-                int(aug_smiles_variations) > 0
-            ), "Number of variations must be greater than 0"
+            assert int(aug_smiles_variations) > 0, (
+                "Number of variations must be greater than 0"
+            )
             aug_smiles_variations = int(aug_smiles_variations)
 
         if not kwargs.get("splits_file_path", None):
diff --git a/chebai/preprocessing/datasets/pubchem.py b/chebai/preprocessing/datasets/pubchem.py
index 1491463c..8cc208b9 100644
--- a/chebai/preprocessing/datasets/pubchem.py
+++ b/chebai/preprocessing/datasets/pubchem.py
@@ -408,9 +408,7 @@ def download(self):
         print("Selecting most dissimilar values from random subsets...")
         for i in tqdm.tqdm(range(self.n_random_subsets)):
             smiles_i = random_smiles[
-                i
-                * len(random_smiles)
-                // self.n_random_subsets : (i + 1)
+                i * len(random_smiles) // self.n_random_subsets : (i + 1)
                 * len(random_smiles)
                 // self.n_random_subsets
             ]
diff --git a/chebai/preprocessing/migration/chebi_data_migration.py b/chebai/preprocessing/migration/chebi_data_migration.py
index af136557..8815a223 100644
--- a/chebai/preprocessing/migration/chebi_data_migration.py
+++ b/chebai/preprocessing/migration/chebi_data_migration.py
@@ -308,7 +308,6 @@ def _old_raw_dir(self) -> str:
 
 
 class Main:
-
     def migrate(
         self,
         datamodule: Optional[_ChEBIDataExtractor] = None,
diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py
index 2b3b1b0e..dcd72ed0 100644
--- a/chebai/preprocessing/reader.py
+++ b/chebai/preprocessing/reader.py
@@ -215,7 +215,6 @@ def _read_data(self, raw_data: str) -> List[int]:
         return None
 
     def _back_to_smiles(self, smiles_encoded):
-
         token_file = self.reader.token_path
         token_coding = {}
         counter = 0
diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py
index b33ea01d..2b2e0693 100644
--- a/chebai/result/analyse_sem.py
+++ b/chebai/result/analyse_sem.py
@@ -642,9 +642,9 @@ def run_all(
     for file in os.listdir(os.path.join(ckpt_dir, run_name)):
         if f"epoch={epoch}_" in file or f"epoch={epoch}." in file:
             ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file))
-    assert (
-        ckpt_path is not None
-    ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}"
+    assert ckpt_path is not None, (
+        f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}"
+    )
 
     print(f"Starting run {run_name} (epoch {epoch})")
     for dataset, dataset_key in prediction_datasets:
diff --git a/chebai/result/base.py b/chebai/result/base.py
index 7acd820c..311e3f1a 100644
--- a/chebai/result/base.py
+++ b/chebai/result/base.py
@@ -22,9 +22,9 @@ def close(self):
         pass
 
     def __init_subclass__(cls, **kwargs):
-        assert (
-            cls._identifier() not in PROCESSORS
-        ), f"ResultProcessor {cls.__name__} does not have a unique identifier"
+        assert cls._identifier() not in PROCESSORS, (
+            f"ResultProcessor {cls.__name__} does not have a unique identifier"
+        )
         PROCESSORS[cls._identifier()] = cls
 
     def process_prediction(self, proc_id, features, labels, pred, ident):
diff --git a/chebai/result/classification.py b/chebai/result/classification.py
index ab8b1e2d..725be248 100644
--- a/chebai/result/classification.py
+++ b/chebai/result/classification.py
@@ -120,7 +120,6 @@ def metrics_classification_multilabel(
     labels: Tensor,
     device: torch.device,
 ):
-
     if device != labels.device:
         device = labels.device
 
@@ -145,7 +144,6 @@ def metrics_classification_binary(
     labels: Tensor,
     device: torch.device,
 ):
-
     if device != labels.device:
         device = labels.device
 
diff --git a/chebai/result/generate_class_properties.py b/chebai/result/generate_class_properties.py
index 6a043e5a..cea62a16 100644
--- a/chebai/result/generate_class_properties.py
+++ b/chebai/result/generate_class_properties.py
@@ -242,7 +242,9 @@ def generate(
         "train",
         "val",
         "test",
-    ], f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
+    ], (
+        f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
+    )
     generator = ClassesPropertiesGenerator()
     generator.generate_props(
         data_partition,
diff --git a/chebai/result/utils.py b/chebai/result/utils.py
index 55549297..61bd3782 100644
--- a/chebai/result/utils.py
+++ b/chebai/result/utils.py
@@ -655,9 +655,9 @@ def load_data_instance(data_cls_path: str, data_cls_kwargs: dict):
     assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict"
     data_cls = load_class(data_cls_path)
     assert isinstance(data_cls, type), f"{data_cls} is not a class."
-    assert issubclass(
-        data_cls, XYBaseDataModule
-    ), f"{data_cls} must inherit from XYBaseDataModule"
+    assert issubclass(data_cls, XYBaseDataModule), (
+        f"{data_cls} must inherit from XYBaseDataModule"
+    )
     return data_cls(**data_cls_kwargs)
 
 
@@ -682,17 +682,17 @@ def load_model_for_inference(
     lightning_cls = load_class(model_cls_path)
     assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class."
-    assert issubclass(
-        lightning_cls, ChebaiBaseNet
-    ), f"{lightning_cls} must inherit from ChebaiBaseNet"
+    assert issubclass(lightning_cls, ChebaiBaseNet), (
+        f"{lightning_cls} must inherit from ChebaiBaseNet"
+    )
 
     try:
         model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs)
     except Exception as e:
         raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e
 
-    assert isinstance(
-        model, ChebaiBaseNet
-    ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
+    assert isinstance(model, ChebaiBaseNet), (
+        f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
+    )
     model.eval()
     model.freeze()
     return model
diff --git a/chebai/train.py b/chebai/train.py
index 883f2263..763b6304 100644
--- a/chebai/train.py
+++ b/chebai/train.py
@@ -277,9 +277,9 @@ def batchify(x: List, y: List) -> List:
     ]
 
 
-def load_data() -> (
-    Tuple[List[Molecule], List[torch.Tensor], List[Molecule], List[torch.Tensor]]
-):
+def load_data() -> Tuple[
+    List[Molecule], List[torch.Tensor], List[Molecule], List[torch.Tensor]
+]:
     """
     Load and preprocess the data.
 
diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py
index e93cff85..5c960007 100644
--- a/chebai/trainer/CustomTrainer.py
+++ b/chebai/trainer/CustomTrainer.py
@@ -159,7 +159,6 @@ def log_dir(self) -> Optional[str]:
 
 
 class LoadDataLaterFitLoop(_FitLoop):
-
     def on_advance_start(self) -> None:
         """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary
         so that the dataloaders can get information from the model. For example: The on_train_epoch_start
diff --git a/configs/data/moleculenet/bace_moleculenet.yml b/configs/data/moleculenet/bace_moleculenet.yml
index e5d4bdb7..bd6c04a8 100644
--- a/configs/data/moleculenet/bace_moleculenet.yml
+++ b/configs/data/moleculenet/bace_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.BaceChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
diff --git a/configs/data/moleculenet/hiv_moleculenet.yml b/configs/data/moleculenet/hiv_moleculenet.yml
index 70c74434..3bef06b2 100644
--- a/configs/data/moleculenet/hiv_moleculenet.yml
+++ b/configs/data/moleculenet/hiv_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.HIVChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
diff --git a/configs/data/moleculenet/lipo_moleculenet.yml b/configs/data/moleculenet/lipo_moleculenet.yml
index c246db5b..79ac3db6 100644
--- a/configs/data/moleculenet/lipo_moleculenet.yml
+++ b/configs/data/moleculenet/lipo_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_regression.LipoChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
diff --git a/configs/data/moleculenet/muv_moleculenet.yml b/configs/data/moleculenet/muv_moleculenet.yml
index f4eba3e1..d7498305 100644
--- a/configs/data/moleculenet/muv_moleculenet.yml
+++ b/configs/data/moleculenet/muv_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.MUVChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
diff --git a/configs/data/moleculenet/sider_moleculenet.yml b/configs/data/moleculenet/sider_moleculenet.yml
index a1d635c5..1a1d81ee 100644
--- a/configs/data/moleculenet/sider_moleculenet.yml
+++ b/configs/data/moleculenet/sider_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.SiderChem
 init_args:
   batch_size: 10
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
diff --git a/configs/loss/bce.yml b/configs/loss/bce.yml
index 10135513..6cf1c0df 100644
--- a/configs/loss/bce.yml
+++ b/configs/loss/bce.yml
@@ -1,3 +1,3 @@
 class_path: chebai.loss.bce_weighted.BCEWeighted
 init_args:
-  beta: 1000
\ No newline at end of file
+  beta: 1000
diff --git a/configs/loss/bce_new.yml b/configs/loss/bce_new.yml
index f8fbe98d..d53533c0 100644
--- a/configs/loss/bce_new.yml
+++ b/configs/loss/bce_new.yml
@@ -1 +1 @@
-class_path: torch.nn.BCEWithLogitsLoss
\ No newline at end of file
+class_path: torch.nn.BCEWithLogitsLoss
diff --git a/configs/loss/bce_try.yml b/configs/loss/bce_try.yml
index ff8f9d4e..6ee636d0 100644
--- a/configs/loss/bce_try.yml
+++ b/configs/loss/bce_try.yml
@@ -1 +1 @@
-class_path: torch.nn.BCELoss
\ No newline at end of file
+class_path: torch.nn.BCELoss
diff --git a/configs/loss/focal_loss_12.yml b/configs/loss/focal_loss_12.yml
index 0351a942..5d8f09df 100644
--- a/configs/loss/focal_loss_12.yml
+++ b/configs/loss/focal_loss_12.yml
@@ -1,4 +1,4 @@
 class_path: chebai.loss.focal_loss.FocalLoss
 init_args:
   task_type: multi-label
-  num_classes: 12
\ No newline at end of file
+  num_classes: 12
diff --git a/configs/loss/mae.yml b/configs/loss/mae.yml
index 75e011be..03378498 100644
--- a/configs/loss/mae.yml
+++ b/configs/loss/mae.yml
@@ -1 +1 @@
-class_path: torch.nn.L1Loss
\ No newline at end of file
+class_path: torch.nn.L1Loss
diff --git a/configs/loss/mse.yml b/configs/loss/mse.yml
index 16fab1c8..92c92b05 100644
--- a/configs/loss/mse.yml
+++ b/configs/loss/mse.yml
@@ -1 +1 @@
-class_path: torch.nn.MSELoss
\ No newline at end of file
+class_path: torch.nn.MSELoss
diff --git a/configs/metrics/mae.yml b/configs/metrics/mae.yml
index 323e5fb4..4985eb61 100644
--- a/configs/metrics/mae.yml
+++ b/configs/metrics/mae.yml
@@ -2,4 +2,4 @@ class_path: torchmetrics.MetricCollection
 init_args:
   metrics:
     mae:
-      class_path: torchmetrics.regression.MeanAbsoluteError
\ No newline at end of file
+      class_path: torchmetrics.regression.MeanAbsoluteError
diff --git a/configs/metrics/micro-macro-f1-roc-auc-17_test.yml b/configs/metrics/micro-macro-f1-roc-auc-17_test.yml
index 0a42fb0e..7c790ee8 100644
--- a/configs/metrics/micro-macro-f1-roc-auc-17_test.yml
+++ b/configs/metrics/micro-macro-f1-roc-auc-17_test.yml
@@ -12,11 +12,11 @@ init_args:
       class_path: torchmetrics.classification.MultilabelAUROC
       init_args:
         num_labels: 17
-    precision: 
+    precision:
       class_path: torchmetrics.classification.MultilabelPrecision
       init_args:
         num_labels: 17
     recall:
       class_path: torchmetrics.classification.MultilabelRecall
       init_args:
-        num_labels: 17
\ No newline at end of file
+        num_labels: 17
diff --git a/configs/metrics/mse-rmse-r2.yml b/configs/metrics/mse-rmse-r2.yml
index ad7bb53f..f58fc1de 100644
--- a/configs/metrics/mse-rmse-r2.yml
+++ b/configs/metrics/mse-rmse-r2.yml
@@ -8,4 +8,4 @@ init_args:
       init_args:
         squared: True
     r2:
-      class_path: torchmetrics.regression.R2Score
\ No newline at end of file
+      class_path: torchmetrics.regression.R2Score
diff --git a/configs/metrics/mse.yml b/configs/metrics/mse.yml
index 1914442e..c62efca2 100644
--- a/configs/metrics/mse.yml
+++ b/configs/metrics/mse.yml
@@ -2,4 +2,4 @@ class_path: torchmetrics.MetricCollection
 init_args:
   metrics:
     mse:
-      class_path: torchmetrics.regression.MeanSquaredError
\ No newline at end of file
+      class_path: torchmetrics.regression.MeanSquaredError
diff --git a/configs/model/electra.yml b/configs/model/electra.yml
index 053f5d65..a6c197bd 100644
--- a/configs/model/electra.yml
+++ b/configs/model/electra.yml
@@ -9,4 +9,4 @@ init_args:
     num_attention_heads: 8
     num_hidden_layers: 6
     type_vocab_size: 1
-    hidden_size: 256
\ No newline at end of file
+    hidden_size: 256
diff --git a/configs/training/binary_callbacks.yml b/configs/training/binary_callbacks.yml
index 013b8c77..21c85c62 100644
--- a/configs/training/binary_callbacks.yml
+++ b/configs/training/binary_callbacks.yml
@@ -32,7 +32,7 @@
 # patience: 5
 # verbose: False
 # mode: "max"
-    
+
 
 # - class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping
 # init_args:
diff --git a/configs/training/wandb_logger.yml b/configs/training/wandb_logger.yml
index b7c51418..b0dd8870 100644
--- a/configs/training/wandb_logger.yml
+++ b/configs/training/wandb_logger.yml
@@ -3,4 +3,4 @@ init_args:
   save_dir: logs
   project: 'chebai'
   entity: 'chebai'
-  log_model: 'all'
\ No newline at end of file
+  log_model: 'all'
diff --git a/tests/unit/dataset_classes/testChebiDataExtractor.py b/tests/unit/dataset_classes/testChebiDataExtractor.py
index 8da900da..800384d3 100644
--- a/tests/unit/dataset_classes/testChebiDataExtractor.py
+++ b/tests/unit/dataset_classes/testChebiDataExtractor.py
@@ -9,7 +9,6 @@
 
 
 class TestChEBIDataExtractor(unittest.TestCase):
-
     @classmethod
     @patch.multiple(_ChEBIDataExtractor, __abstractmethods__=frozenset())
     @patch.object(_ChEBIDataExtractor, "base_dir", new_callable=PropertyMock)
diff --git a/tests/unit/dataset_classes/testChebiOverXPartial.py b/tests/unit/dataset_classes/testChebiOverXPartial.py
index 0e263335..b97a9a33 100644
--- a/tests/unit/dataset_classes/testChebiOverXPartial.py
+++ b/tests/unit/dataset_classes/testChebiOverXPartial.py
@@ -8,7 +8,6 @@
 
 
 class TestChEBIOverX(unittest.TestCase):
-
     @classmethod
     @patch.multiple(ChEBIOverXPartial, __abstractmethods__=frozenset())
     @patch("os.makedirs", return_value=None)
diff --git a/tests/unit/mock_data/tox_mock_data.py b/tests/unit/mock_data/tox_mock_data.py
index bee21a2d..fcf5633f 100644
--- a/tests/unit/mock_data/tox_mock_data.py
+++ b/tests/unit/mock_data/tox_mock_data.py
@@ -194,9 +194,9 @@ def get_processed_grouped_data() -> List[Dict]:
         processed_data = Tox21MolNetMockData.get_processed_data()
         groups = ["A", "A", "B", "B", "C", "C", "C", "C"]
 
-        assert len(processed_data) == len(
-            groups
-        ), "The number of processed data entries does not match the number of groups."
+        assert len(processed_data) == len(groups), (
+            "The number of processed data entries does not match the number of groups."
+        )
 
         # Combine processed data with their corresponding groups
         grouped_data = [
@@ -208,7 +208,6 @@ def get_processed_grouped_data() -> List[Dict]:
 
 
 class Tox21ChallengeMockData:
-
     MOL_BINARY_STR = (
         b"cyclobutane\n"
         b" RDKit 2D\n\n"
diff --git a/tutorials/demo_process_results.ipynb b/tutorials/demo_process_results.ipynb
index 76a181b6..55f358a9 100644
--- a/tutorials/demo_process_results.ipynb
+++ b/tutorials/demo_process_results.ipynb
@@ -338,9 +338,15 @@
     "    \"per_epoch=99_val_loss=0.0167_val_micro-f1=0.91.ckpt\",\n",
     ")\n",
     "model_path_v200 = \"electra_c100_bce_unweighted.ckpt\"\n",
-    "model_v148 = Electra.load_from_checkpoint(model_path_v148, pretrained_checkpoint=None).to(\"cpu\")\n",
-    "model_v200 = Electra.load_from_checkpoint(model_path_v200, pretrained_checkpoint=None).to(\"cpu\")\n",
-    "model_v227 = Electra.load_from_checkpoint(model_path_v227, pretrained_checkpoint=None).to(\"cpu\")\n",
+    "model_v148 = Electra.load_from_checkpoint(\n",
+    "    model_path_v148, pretrained_checkpoint=None\n",
+    ").to(\"cpu\")\n",
+    "model_v200 = Electra.load_from_checkpoint(\n",
+    "    model_path_v200, pretrained_checkpoint=None\n",
+    ").to(\"cpu\")\n",
+    "model_v227 = Electra.load_from_checkpoint(\n",
+    "    model_path_v227, pretrained_checkpoint=None\n",
+    ").to(\"cpu\")\n",
     "\n",
     "data_module_v200 = ChEBIOver100()\n",
     "data_module_v148 = ChEBIOver100(chebi_version_train=148)\n",
diff --git a/tutorials/process_results_old_chebi.ipynb b/tutorials/process_results_old_chebi.ipynb
index cb3ec3be..a61e8a9f 100644
--- a/tutorials/process_results_old_chebi.ipynb
+++ b/tutorials/process_results_old_chebi.ipynb
@@ -61,8 +61,12 @@
     "model_path_v200 = os.path.join(\"models\", \"electra_c100_bce_unweighted.ckpt\")\n",
     "model_path_v148 = os.path.join(\"models\", \"electra_c100_bce_unweighted_v148.ckpt\")\n",
     "\n",
-    "model_v200 = Electra.load_from_checkpoint(model_path_v200, pretrained_checkpoint=None).to(DEVICE)\n",
-    "model_v148 = Electra.load_from_checkpoint(model_path_v148, pretrained_checkpoint=None).to(DEVICE)\n",
+    "model_v200 = Electra.load_from_checkpoint(\n",
+    "    model_path_v200, pretrained_checkpoint=None\n",
+    ").to(DEVICE)\n",
+    "model_v148 = Electra.load_from_checkpoint(\n",
+    "    model_path_v148, pretrained_checkpoint=None\n",
+    ").to(DEVICE)\n",
     "\n",
     "data_module_v200 = ChEBIOver100(chebi_version=200)\n",
     "data_module_v148 = ChEBIOver100(chebi_version=200, chebi_version_train=148)"