
Commit facac4b

fix precommit errors
1 parent c06390c commit facac4b

File tree

3 files changed: +45 -30 lines

mmlearn/modules/ema.py
mmlearn/modules/encoders/vision.py
mmlearn/tasks/ijepa_pretraining.py

mmlearn/modules/ema.py

Lines changed: 0 additions & 4 deletions
@@ -130,10 +130,6 @@ def restore(self, model: torch.nn.Module) -> torch.nn.Module:
         model.load_state_dict(d, strict=False)
         return model
 
-    # def state_dict(self) -> dict[str, Any]:
-    #     """Return the state dict of the model."""
-    #     return self.model.state_dict()  # type: ignore[no-any-return]
-
     @staticmethod
     def get_annealed_rate(
         start: float,
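Only the signature of `get_annealed_rate` appears as context in this hunk. For orientation, a minimal sketch of how such an annealing helper is typically written; the remaining parameter names (`end`, `curr_step`, `total_steps`) and the linear schedule are assumptions, not part of this commit:

def get_annealed_rate(
    start: float,
    end: float,          # assumed: final decay value
    curr_step: int,      # assumed: current training step
    total_steps: int,    # assumed: number of steps to anneal over
) -> float:
    """Linearly interpolate the EMA decay rate between ``start`` and ``end``."""
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining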

mmlearn/modules/encoders/vision.py

Lines changed: 43 additions & 24 deletions
@@ -2,7 +2,7 @@
 
 import math
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast
 
 import timm
 import torch
@@ -284,7 +284,6 @@ def __init__(
         # Weight Initialization
         self.init_std = init_std
         self.apply(self._init_weights)
-        self.fix_init_weight()
 
     def fix_init_weight(self) -> None:
         """Fix initialization of weights by rescaling them according to layer depth."""
@@ -493,7 +492,6 @@ def __init__(
         self.init_std = init_std
         trunc_normal_(self.mask_token, std=self.init_std)
         self.apply(self._init_weights)
-        # self.fix_init_weight()
 
     def fix_init_weight(self) -> None:
         """Fix initialization of weights by rescaling them according to layer depth."""
@@ -567,9 +565,12 @@ def forward(
         return self.predictor_proj(x)
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformerPredictor,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
     """
@@ -585,9 +586,12 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """
@@ -610,9 +614,12 @@ def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """
@@ -635,9 +642,12 @@ def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """
@@ -660,9 +670,12 @@ def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """
@@ -685,9 +698,12 @@ def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """
@@ -710,9 +726,12 @@ def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     )
 
 
-@store(
-    group="modules/encoders",
-    provider="mmlearn",
+@cast(
+    VisionTransformer,
+    store(
+        group="modules/encoders",
+        provider="mmlearn",
+    ),
 )
 def vit_giant(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
     """

mmlearn/tasks/ijepa_pretraining.py

Lines changed: 2 additions & 2 deletions
@@ -154,7 +154,7 @@ def _shared_step(
 
         return loss
 
-    def configure_optimizers(self):
+    def configure_optimizers(self) -> Dict[str, Any]:
         """Configure the optimizer and learning rate scheduler."""
         weight_decay_value = 0.05  # Desired weight decay
 
@@ -194,7 +194,7 @@ def configure_optimizers(self):
             },
         ]
 
-        optimizer = torch.optim.AdamW(parameters)
+        optimizer = torch.optim.AdamW(parameters, lr=0.001)
 
         # Instantiate the learning rate scheduler if provided
         lr_scheduler = None
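The new `Dict[str, Any]` annotation matches the usual PyTorch Lightning convention of returning the optimizer, and optionally a scheduler, in a dictionary, while the explicit `lr=0.001` spells out AdamW's default learning rate of the same value. A rough sketch of the return shape the annotation implies, assuming a Lightning-style task; the single parameter group and the scheduler choice below are illustrative only:

from typing import Any, Dict

import torch


def configure_optimizers_sketch(model: torch.nn.Module) -> Dict[str, Any]:
    # The real method builds decay/no-decay parameter groups; a single group
    # keeps this sketch short.
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.05)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler}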
