22
33import math
44from functools import partial
5- from typing import Any , Callable , Dict , List , Optional , Tuple , Union
5+ from typing import Any , Callable , Dict , List , Optional , Tuple , Union , cast
66
77import timm
88import torch
@@ -284,7 +284,6 @@ def __init__(
284284 # Weight Initialization
285285 self .init_std = init_std
286286 self .apply (self ._init_weights )
287- self .fix_init_weight ()
288287
289288 def fix_init_weight (self ) -> None :
290289 """Fix initialization of weights by rescaling them according to layer depth."""
@@ -493,7 +492,6 @@ def __init__(
493492 self .init_std = init_std
494493 trunc_normal_ (self .mask_token , std = self .init_std )
495494 self .apply (self ._init_weights )
496- # self.fix_init_weight()
497495
498496 def fix_init_weight (self ) -> None :
499497 """Fix initialization of weights by rescaling them according to layer depth."""
@@ -567,9 +565,12 @@ def forward(
567565 return self .predictor_proj (x )
568566
569567
570- @store (
571- group = "modules/encoders" ,
572- provider = "mmlearn" ,
568+ @cast (
569+ VisionTransformerPredictor ,
570+ store (
571+ group = "modules/encoders" ,
572+ provider = "mmlearn" ,
573+ ),
573574)
574575def vit_predictor (** kwargs : Any ) -> VisionTransformerPredictor :
575576 """
@@ -585,9 +586,12 @@ def vit_predictor(**kwargs: Any) -> VisionTransformerPredictor:
585586 )
586587
587588
588- @store (
589- group = "modules/encoders" ,
590- provider = "mmlearn" ,
589+ @cast (
590+ VisionTransformer ,
591+ store (
592+ group = "modules/encoders" ,
593+ provider = "mmlearn" ,
594+ ),
591595)
592596def vit_tiny (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
593597 """
@@ -610,9 +614,12 @@ def vit_tiny(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
610614 )
611615
612616
613- @store (
614- group = "modules/encoders" ,
615- provider = "mmlearn" ,
617+ @cast (
618+ VisionTransformer ,
619+ store (
620+ group = "modules/encoders" ,
621+ provider = "mmlearn" ,
622+ ),
616623)
617624def vit_small (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
618625 """
@@ -635,9 +642,12 @@ def vit_small(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
635642 )
636643
637644
638- @store (
639- group = "modules/encoders" ,
640- provider = "mmlearn" ,
645+ @cast (
646+ VisionTransformer ,
647+ store (
648+ group = "modules/encoders" ,
649+ provider = "mmlearn" ,
650+ ),
641651)
642652def vit_base (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
643653 """
@@ -660,9 +670,12 @@ def vit_base(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
660670 )
661671
662672
663- @store (
664- group = "modules/encoders" ,
665- provider = "mmlearn" ,
673+ @cast (
674+ VisionTransformer ,
675+ store (
676+ group = "modules/encoders" ,
677+ provider = "mmlearn" ,
678+ ),
666679)
667680def vit_large (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
668681 """
@@ -685,9 +698,12 @@ def vit_large(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
685698 )
686699
687700
688- @store (
689- group = "modules/encoders" ,
690- provider = "mmlearn" ,
701+ @cast (
702+ VisionTransformer ,
703+ store (
704+ group = "modules/encoders" ,
705+ provider = "mmlearn" ,
706+ ),
691707)
692708def vit_huge (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
693709 """
@@ -710,9 +726,12 @@ def vit_huge(patch_size: int = 16, **kwargs: Any) -> VisionTransformer:
710726 )
711727
712728
713- @store (
714- group = "modules/encoders" ,
715- provider = "mmlearn" ,
729+ @cast (
730+ VisionTransformer ,
731+ store (
732+ group = "modules/encoders" ,
733+ provider = "mmlearn" ,
734+ ),
716735)
717736def vit_giant (patch_size : int = 16 , ** kwargs : Any ) -> VisionTransformer :
718737 """
0 commit comments