Commit d343567

Fix gemma-3 issue

1 parent 2ca6652 commit d343567

File tree

2 files changed: +26 -12 lines changed

syncode/common.py

Lines changed: 26 additions & 10 deletions

@@ -12,21 +12,32 @@
 
 
 def load_model(model_name, device, quantize, device_map = None):
+    torch_dtype = torch.bfloat16 if quantize else "auto"
+    device_map = device_map if device_map is not None else "auto"
+
+    attn_implementation = None
+    if "gemma-3" in model_name:
+        # This is due to the gemma-3 issue with SDPA implementation
+        # https://github.com/google-deepmind/gemma/issues/169
+        attn_implementation = "eager"
+        logging.info("Using slower \"eager\" attention implementation for gemma-3 due to issue with SDPA implementation")
+
     if model_name == 'test':
         model = AutoModelForCausalLM.from_pretrained('bigcode/tiny_starcoder_py').to(device)
     elif model_name == 'test-instruct':
         model = AutoModelForCausalLM.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
         if device_map is not None:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True, device_map = device_map).eval()
-        else:
-            if (quantize):
-                model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
-            else:
-                model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True).eval().to(device)
+            logging.info(f"Loading model {model_name} with device:{device}, device_map:{device_map}, torch_dtype:{torch_dtype}")
+            model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                torch_dtype=torch_dtype,
+                cache_dir=HF_CACHE,
+                token=HF_ACCESS_TOKEN,
+                trust_remote_code=True,
+                device_map = device_map,
+                attn_implementation=attn_implementation
+            ).eval()
     return model
 
 def load_tokenizer(model_name):
@@ -35,7 +46,12 @@ def load_tokenizer(model_name):
     elif model_name == 'test-instruct':
         tokenizer = AutoTokenizer.from_pretrained("rahuldshetty/tiny-starcoder-instruct")
     else:
-        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=HF_CACHE, token=HF_ACCESS_TOKEN, trust_remote_code=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_name,
+            cache_dir=HF_CACHE,
+            token=HF_ACCESS_TOKEN,
+            trust_remote_code=True
+        )
     return tokenizer
 
 def get_output_path(model_name, grammar, dataset, num_samples, mode):
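
For illustration, a minimal usage sketch of the new load_model path for a gemma-3 checkpoint. This is not part of the commit: the checkpoint id is an assumption, and any model name containing "gemma-3" now triggers the eager-attention fallback, while quantize=True now maps to torch_dtype=torch.bfloat16 (and quantize=False to "auto").

# Minimal usage sketch, not from the commit. Assumes syncode is importable
# and HF_CACHE / HF_ACCESS_TOKEN are configured; the checkpoint id below is
# a placeholder for any model name containing "gemma-3".
from syncode.common import load_model, load_tokenizer

model = load_model(
    model_name="google/gemma-3-1b-it",  # "gemma-3" -> attn_implementation="eager"
    device="cuda",
    quantize=True,                      # -> torch_dtype=torch.bfloat16 (else "auto")
    device_map=None,                    # -> defaulted to "auto" inside load_model
)
tokenizer = load_tokenizer("google/gemma-3-1b-it")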

syncode/language_model.py

Lines changed: 0 additions & 2 deletions

@@ -179,9 +179,7 @@ def get_tokenized_input(self, prompt: Union[str, list], batch_size: int):
         inputs = self.tokenizer(
             input_batch,
             return_tensors="pt",
-            pad_to_multiple_of=8,
         ).to(self.model.device)
-
         return inputs
 
     @torch.inference_mode()
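
Dropping pad_to_multiple_of=8 means batched prompts are padded only as far as the tokenizer's padding strategy requires, not up to the next multiple of 8 tokens. A standalone sketch of the difference follows; the tokenizer choice, the pad-token assignment, and padding=True are assumptions for the example, since the call's padding configuration is not visible in this diff.

# Standalone sketch of the removed behavior; not code from the repo.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bigcode/tiny_starcoder_py")
tokenizer.pad_token = tokenizer.eos_token  # assumption: checkpoint has no pad token

batch = ["def add(a, b):", "print('hi')"]

old = tokenizer(batch, return_tensors="pt", padding=True, pad_to_multiple_of=8)
new = tokenizer(batch, return_tensors="pt", padding=True)

print(old["input_ids"].shape)  # length rounded up to a multiple of 8
print(new["input_ids"].shape)  # length of the longest prompt in the batch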
