26 changes: 0 additions & 26 deletions .github/workflows/lint.yml

This file was deleted.

19 changes: 19 additions & 0 deletions .github/workflows/precommit-action.yml
@@ -0,0 +1,19 @@
+name: Pre-commit Check
+
+on:
+  push:
+    branches: [main, master]
+  pull_request:
+
+jobs:
+  pre-commit:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
+      - name: Run pre-commit
+        uses: pre-commit/action@v3.0.1
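Note (not part of the PR): pre-commit/action essentially runs the configured hooks over the whole repository, so the CI check can be reproduced locally with `pre-commit run --all-files`. A minimal Python sketch of that invocation, assuming the pre-commit package is installed:

import subprocess
import sys


def run_hooks_locally() -> int:
    """Run every hook from .pre-commit-config.yaml on all files,
    mirroring the CI job above."""
    # `pre-commit run --all-files` is the standard pre-commit CLI entry point.
    return subprocess.run(["pre-commit", "run", "--all-files"]).returncode


if __name__ == "__main__":
    sys.exit(run_hooks_locally())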
44 changes: 16 additions & 28 deletions .pre-commit-config.yaml
@@ -1,31 +1,19 @@
 repos:
-  - repo: https://github.com/psf/black
-    rev: "25.1.0"
-    hooks:
-      - id: black
-      - id: black-jupyter # for formatting jupyter-notebook
+  # Use `pre-commit autoupdate` to update all the hook.
 
-  - repo: https://github.com/pycqa/isort
-    rev: 5.13.2
-    hooks:
-      - id: isort
-        name: isort (python)
-        args: ["--profile=black"]
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    # Ruff version. https://docs.astral.sh/ruff/integrations/#pre-commit
+    rev: v0.14.11
+    hooks:
+      # Run the linter.
+      - id: ruff-check
+        args: [ --fix ]
+      # Run the formatter.
+      - id: ruff-format
 
-  - repo: https://github.com/asottile/seed-isort-config
-    rev: v2.2.0
-    hooks:
-      - id: seed-isort-config
-
-  - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.6.0
-    hooks:
-      - id: check-yaml
-      - id: end-of-file-fixer
-      - id: trailing-whitespace
-
-  - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.12.2
-    hooks:
-      - id: ruff
-        args: [--fix]
+  - repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v6.0.0
+    hooks:
+      - id: check-yaml
+      - id: end-of-file-fixer
+      - id: trailing-whitespace
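For orientation (not part of the diff itself): the config drops black, isort, and seed-isort-config in favor of Ruff, with ruff-check linting (and autofixing) and ruff-format replacing black as the formatter. The style difference that accounts for most of the Python hunks below is how the two formatters wrap a long assert; a before/after sketch with a made-up condition and message:

beta = 1000.0

# black (old): parenthesize the condition, keep the message outside.
# assert (
#     isinstance(beta, float) and beta > 0.0
# ), "beta must be a positive float"

# ruff-format (new): keep the condition inline, parenthesize the message.
assert isinstance(beta, float) and beta > 0.0, (
    "beta must be a positive float"
)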
16 changes: 9 additions & 7 deletions chebai/loss/bce_weighted.py
@@ -32,18 +32,20 @@ def __init__(
             data_extractor = data_extractor.labeled
         self.data_extractor = data_extractor
 
-        assert (
-            isinstance(beta, float) and beta > 0.0
-        ), f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+        assert isinstance(beta, float) and beta > 0.0, (
+            f"Beta parameter must be a float with value greater than 0.0, for loss class {self.__class__.__name__}."
+        )
 
-        assert (
-            self.data_extractor is not None
-        ), f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+        assert self.data_extractor is not None, (
+            f"Data extractor must be provided if this loss class ({self.__class__.__name__}) is used."
+        )
 
         assert all(
             os.path.exists(os.path.join(self.data_extractor.processed_dir, file_name))
             for file_name in self.data_extractor.processed_file_names
-        ), "Dataset files not found. Make sure the dataset is processed before using this loss."
+        ), (
+            "Dataset files not found. Make sure the dataset is processed before using this loss."
+        )
 
         assert (
             isinstance(self.data_extractor, _ChEBIDataExtractor)
6 changes: 3 additions & 3 deletions chebai/loss/focal_loss.py
@@ -36,9 +36,9 @@ def __init__(
             and alpha is not None
             and isinstance(alpha, (list, torch.Tensor))
         ):
-            assert (
-                num_classes is not None
-            ), "num_classes must be specified for multi-class classification"
+            assert num_classes is not None, (
+                "num_classes must be specified for multi-class classification"
+            )
             if isinstance(alpha, list):
                 self.alpha = torch.Tensor(alpha)
             else:
24 changes: 12 additions & 12 deletions chebai/preprocessing/datasets/base.py
@@ -96,9 +96,9 @@ def __init__(
         self.prediction_kind = prediction_kind
         self.data_limit = data_limit
         self.label_filter = label_filter
-        assert (balance_after_filter is not None) or (
-            self.label_filter is None
-        ), "Filter balancing requires a filter"
+        assert (balance_after_filter is not None) or (self.label_filter is None), (
+            "Filter balancing requires a filter"
+        )
         self.balance_after_filter = balance_after_filter
         self.num_workers = num_workers
         self.persistent_workers: bool = bool(persistent_workers)
@@ -108,13 +108,13 @@ def __init__(
         self.use_inner_cross_validation = (
             inner_k_folds > 1
         )  # only use cv if there are at least 2 folds
-        assert (
-            fold_index is None or self.use_inner_cross_validation is not None
-        ), "fold_index can only be set if cross validation is used"
+        assert fold_index is None or self.use_inner_cross_validation is not None, (
+            "fold_index can only be set if cross validation is used"
+        )
         if fold_index is not None and self.inner_k_folds is not None:
-            assert (
-                fold_index < self.inner_k_folds
-            ), "fold_index can't be larger than the total number of folds"
+            assert fold_index < self.inner_k_folds, (
+                "fold_index can't be larger than the total number of folds"
+            )
         self.fold_index = fold_index
         self._base_dir = base_dir
         self.n_token_limit = n_token_limit
@@ -137,9 +137,9 @@ def num_of_labels(self):
 
     @property
     def feature_vector_size(self):
-        assert (
-            self._feature_vector_size is not None
-        ), "size of feature vector must be set"
+        assert self._feature_vector_size is not None, (
+            "size of feature vector must be set"
+        )
         return self._feature_vector_size
 
     @property
6 changes: 3 additions & 3 deletions chebai/preprocessing/datasets/chebi.py
@@ -144,9 +144,9 @@ def __init__(
         **kwargs,
     ):
         if bool(augment_smiles):
-            assert (
-                int(aug_smiles_variations) > 0
-            ), "Number of variations must be greater than 0"
+            assert int(aug_smiles_variations) > 0, (
+                "Number of variations must be greater than 0"
+            )
             aug_smiles_variations = int(aug_smiles_variations)
 
         if not kwargs.get("splits_file_path", None):
4 changes: 1 addition & 3 deletions chebai/preprocessing/datasets/pubchem.py
@@ -408,9 +408,7 @@ def download(self):
         print("Selecting most dissimilar values from random subsets...")
         for i in tqdm.tqdm(range(self.n_random_subsets)):
             smiles_i = random_smiles[
-                i
-                * len(random_smiles)
-                // self.n_random_subsets : (i + 1)
+                i * len(random_smiles) // self.n_random_subsets : (i + 1)
                 * len(random_smiles)
                 // self.n_random_subsets
             ]
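The rewrapped slice above selects subset i of n_random_subsets contiguous, roughly equal chunks of random_smiles. A hypothetical helper (not in the PR) isolating the same arithmetic:

from typing import List, Sequence


def split_into_subsets(seq: Sequence, n: int) -> List[Sequence]:
    # Subset i covers [i * len(seq) // n, (i + 1) * len(seq) // n);
    # integer division spreads any remainder over the subsets.
    return [seq[i * len(seq) // n : (i + 1) * len(seq) // n] for i in range(n)]


# Every element lands in exactly one subset, even when n does not divide len(seq).
assert sum(len(s) for s in split_into_subsets(list(range(10)), 3)) == 10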
1 change: 0 additions & 1 deletion chebai/preprocessing/migration/chebi_data_migration.py
@@ -308,7 +308,6 @@ def _old_raw_dir(self) -> str:
 
 
 class Main:
-
     def migrate(
         self,
         datamodule: Optional[_ChEBIDataExtractor] = None,
1 change: 0 additions & 1 deletion chebai/preprocessing/reader.py
@@ -215,7 +215,6 @@ def _read_data(self, raw_data: str) -> List[int]:
         return None
 
     def _back_to_smiles(self, smiles_encoded):
-
         token_file = self.reader.token_path
         token_coding = {}
         counter = 0
6 changes: 3 additions & 3 deletions chebai/result/analyse_sem.py
@@ -642,9 +642,9 @@ def run_all(
         for file in os.listdir(os.path.join(ckpt_dir, run_name)):
             if f"epoch={epoch}_" in file or f"epoch={epoch}." in file:
                 ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file))
-        assert (
-            ckpt_path is not None
-        ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}"
+        assert ckpt_path is not None, (
+            f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}"
+        )
         print(f"Starting run {run_name} (epoch {epoch})")
 
         for dataset, dataset_key in prediction_datasets:
6 changes: 3 additions & 3 deletions chebai/result/base.py
@@ -22,9 +22,9 @@ def close(self):
         pass
 
     def __init_subclass__(cls, **kwargs):
-        assert (
-            cls._identifier() not in PROCESSORS
-        ), f"ResultProcessor {cls.__name__} does not have a unique identifier"
+        assert cls._identifier() not in PROCESSORS, (
+            f"ResultProcessor {cls.__name__} does not have a unique identifier"
+        )
         PROCESSORS[cls._identifier()] = cls
 
     def process_prediction(self, proc_id, features, labels, pred, ident):
2 changes: 0 additions & 2 deletions chebai/result/classification.py
@@ -120,7 +120,6 @@ def metrics_classification_multilabel(
     labels: Tensor,
     device: torch.device,
 ):
-
     if device != labels.device:
         device = labels.device
 
@@ -145,7 +144,6 @@ def metrics_classification_binary(
     labels: Tensor,
     device: torch.device,
 ):
-
     if device != labels.device:
         device = labels.device
 
4 changes: 3 additions & 1 deletion chebai/result/generate_class_properties.py
@@ -242,7 +242,9 @@ def generate(
         "train",
         "val",
         "test",
-    ], f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
+    ], (
+        f"Given data partition invalid: {data_partition}, Choose one of the value among `train`, `val`, `test` "
+    )
     generator = ClassesPropertiesGenerator()
     generator.generate_props(
         data_partition,
18 changes: 9 additions & 9 deletions chebai/result/utils.py
@@ -655,9 +655,9 @@ def load_data_instance(data_cls_path: str, data_cls_kwargs: dict):
     assert isinstance(data_cls_kwargs, dict), "data_cls_kwargs must be a dict"
     data_cls = load_class(data_cls_path)
     assert isinstance(data_cls, type), f"{data_cls} is not a class."
-    assert issubclass(
-        data_cls, XYBaseDataModule
-    ), f"{data_cls} must inherit from XYBaseDataModule"
+    assert issubclass(data_cls, XYBaseDataModule), (
+        f"{data_cls} must inherit from XYBaseDataModule"
+    )
     return data_cls(**data_cls_kwargs)
 
 
@@ -682,17 +682,17 @@ def load_model_for_inference(
     lightning_cls = load_class(model_cls_path)
 
     assert isinstance(lightning_cls, type), f"{lightning_cls} is not a class."
-    assert issubclass(
-        lightning_cls, ChebaiBaseNet
-    ), f"{lightning_cls} must inherit from ChebaiBaseNet"
+    assert issubclass(lightning_cls, ChebaiBaseNet), (
+        f"{lightning_cls} must inherit from ChebaiBaseNet"
+    )
     try:
         model = lightning_cls.load_from_checkpoint(model_ckpt_path, **model_load_kwargs)
     except Exception as e:
        raise RuntimeError(f"Error loading model {model_name} \n Error: {e}") from e
 
-    assert isinstance(
-        model, ChebaiBaseNet
-    ), f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
+    assert isinstance(model, ChebaiBaseNet), (
+        f"Model: {model}(Model Name: {model_name}) is not a ChebaiBaseNet instance."
+    )
     model.eval()
     model.freeze()
     return model
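Context for the asserts above, sketched under assumptions: load_class resolves a dotted path such as "chebai.models.SomeNet" to a class object, which the asserts then validate before instantiation. A hypothetical stand-in, not chebai's actual implementation:

import importlib


def load_class(dotted_path: str) -> type:
    """Resolve "pkg.module.ClassName" to the class it names."""
    module_path, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, class_name)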
6 changes: 3 additions & 3 deletions chebai/train.py
@@ -277,9 +277,9 @@ def batchify(x: List, y: List) -> List:
     ]
 
 
-def load_data() -> (
-    Tuple[List[Molecule], List[torch.Tensor], List[Molecule], List[torch.Tensor]]
-):
+def load_data() -> Tuple[
+    List[Molecule], List[torch.Tensor], List[Molecule], List[torch.Tensor]
+]:
     """
     Load and preprocess the data.
 
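The same formatter migration, sketched here with made-up names: ruff-format removes black's wrapping parentheses around a long return annotation and breaks inside the subscript instead.

from typing import List, Tuple

# black (old):
# def load_pairs() -> (
#     Tuple[List[int], List[str]]
# ):
#     ...


# ruff-format (new): the subscript itself carries the line break.
def load_pairs() -> Tuple[
    List[int], List[str]
]:
    return ([1, 2], ["a", "b"])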
1 change: 0 additions & 1 deletion chebai/trainer/CustomTrainer.py
@@ -159,7 +159,6 @@ def log_dir(self) -> Optional[str]:
 
 
 class LoadDataLaterFitLoop(_FitLoop):
-
     def on_advance_start(self) -> None:
         """Calls the hook ``on_train_epoch_start`` **before** the dataloaders are setup. This is necessary
         so that the dataloaders can get information from the model. For example: The on_train_epoch_start
2 changes: 1 addition & 1 deletion configs/data/moleculenet/bace_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.BaceChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
2 changes: 1 addition & 1 deletion configs/data/moleculenet/hiv_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.HIVChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
2 changes: 1 addition & 1 deletion configs/data/moleculenet/lipo_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_regression.LipoChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
2 changes: 1 addition & 1 deletion configs/data/moleculenet/muv_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.MUVChem
 init_args:
   batch_size: 32
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
2 changes: 1 addition & 1 deletion configs/data/moleculenet/sider_moleculenet.yml
@@ -2,4 +2,4 @@ class_path: chebai.preprocessing.datasets.molecule_classification.SiderChem
 init_args:
   batch_size: 10
   validation_split: 0.05
-  test_split: 0.15
\ No newline at end of file
+  test_split: 0.15
2 changes: 1 addition & 1 deletion configs/loss/bce.yml
@@ -1,3 +1,3 @@
 class_path: chebai.loss.bce_weighted.BCEWeighted
 init_args:
-  beta: 1000
\ No newline at end of file
+  beta: 1000
2 changes: 1 addition & 1 deletion configs/loss/bce_new.yml
@@ -1 +1 @@
-class_path: torch.nn.BCEWithLogitsLoss
\ No newline at end of file
+class_path: torch.nn.BCEWithLogitsLoss
2 changes: 1 addition & 1 deletion configs/loss/bce_try.yml
@@ -1 +1 @@
-class_path: torch.nn.BCELoss
\ No newline at end of file
+class_path: torch.nn.BCELoss
2 changes: 1 addition & 1 deletion configs/loss/focal_loss_12.yml
@@ -1,4 +1,4 @@
 class_path: chebai.loss.focal_loss.FocalLoss
 init_args:
   task_type: multi-label
-  num_classes: 12
\ No newline at end of file
+  num_classes: 12
2 changes: 1 addition & 1 deletion configs/loss/mae.yml
@@ -1 +1 @@
-class_path: torch.nn.L1Loss
\ No newline at end of file
+class_path: torch.nn.L1Loss
2 changes: 1 addition & 1 deletion configs/loss/mse.yml
@@ -1 +1 @@
-class_path: torch.nn.MSELoss
\ No newline at end of file
+class_path: torch.nn.MSELoss