 
 
 def load_model(model_name, device, quantize, device_map=None):
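+    # Load in bfloat16 when reduced precision is requested; otherwise let
+    # transformers pick the dtype from the checkpoint ("auto").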
+    torch_dtype = torch.bfloat16 if quantize else "auto"
+    device_map = device_map if device_map is not None else "auto"
+
+    attn_implementation = None
+    if "gemma-3" in model_name:
+        # Work around the gemma-3 bug in the SDPA attention implementation:
+        # https://github.com/google-deepmind/gemma/issues/169
+        attn_implementation = "eager"
+        logging.info("Using slower \"eager\" attention implementation for gemma-3 due to an issue with the SDPA implementation")
+
     if model_name == 'test':
         model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
     elif model_name == 'test-instruct':
         model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
         if device_map is not None:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map=device_map).eval()
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map=device_map).eval()
-        else:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            logging.info(f"Loading model {model_name} with device:{device}, device_map:{device_map}, torch_dtype:{torch_dtype}")
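+            # One consolidated call replaces the four quantize/device_map branches
+            # above; device placement is handled by device_map rather than .to(device),
+            # and attn_implementation=None keeps the library's default selection.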
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                cache_dir=HF_CACHE,
+                token=HF_ACCESS_TOKEN,
+                trust_remote_code=True,
+                device_map=device_map,
+                attn_implementation=attn_implementation
+            ).eval()
     return model
 
 def load_tokenizer(model_name):
@@ -35,7 +46,12 @@ def load_tokenizer(model_name):
     elif model_name == 'test-instruct':
         tokenizer = AutoTokenizer.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)
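+        # trust_remote_code=True permits custom tokenizer code from the model repo to run.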
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir=HF_CACHE,
+            token=HF_ACCESS_TOKEN,
+            trust_remote_code=True
+        )
     return tokenizer
 
 def get_output_path(model_name, grammar, dataset, num_samples, mode):
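
For reference, a minimal sketch of how these helpers might be exercised after this change. The model id, device, prompt, and generation parameters are placeholders (not from this commit), and HF_CACHE / HF_ACCESS_TOKEN are assumed to be defined elsewhere in the module:

import torch

# Hypothetical call site: quantize=True loads the weights in bfloat16, and the
# default device_map="auto" places them across whatever devices are available.
model = load_model("google/gemma-3-4b-it", device="cuda", quantize=True)
tokenizer = load_tokenizer("google/gemma-3-4b-it")

inputs = tokenizer("def fibonacci(n):", return_tensors="pt").to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))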