diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 28d306f..679bd3f 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -919,15 +919,27 @@ def export(
             example_cache_position,
         )
 
-        self.exported_sampler = self._export_sampler(
-            torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
-        )
+        # Skip sampler export for MPS + bfloat16 due to a Metal shader compilation error
+        # (assigning float to bfloat in the generated shader code).
+        is_mps_bfloat16 = str(self.model.device).startswith("mps") and self.model.dtype == torch.bfloat16
+        if is_mps_bfloat16:
+            logging.warning(
+                "Skipping sampler export for MPS + bfloat16 due to Metal shader compilation issues. "
+                "The runner will use CPU-based sampling instead."
+            )
+            self.exported_sampler = None
+        else:
+            self.exported_sampler = self._export_sampler(
+                torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
+            )
 
-        return {
+        result = {
             "encoder": self.exported_encoder,  # Not called "text_encoder" because the encoder could be non-text too, e.g. Whisper.
             "text_decoder": self.exported_decoder,
-            "sampler": self.exported_sampler,
         }
+        if self.exported_sampler is not None:
+            result["sampler"] = self.exported_sampler
+        return result
 
     def generate(self, prompt_token_ids, max_new_tokens):
         with torch.no_grad():
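
Not part of the patch itself, but for context: since `"sampler"` is now omitted from the returned dict when the export is skipped, callers have to treat that entry as optional. A minimal sketch of one way to do that follows; `sample_next_token` and `exported_programs` are illustrative names, and the assumption that the exported sampler is a `torch.export.ExportedProgram` (so `.module()` gives a callable) is not confirmed by this diff.

```python
import torch


def sample_next_token(exported_programs: dict, logits: torch.Tensor) -> torch.Tensor:
    """Pick the next token, falling back to CPU greedy sampling when no sampler was exported."""
    sampler = exported_programs.get("sampler")
    if sampler is not None:
        # Assumption: _export_sampler returned a torch.export.ExportedProgram,
        # whose .module() yields a callable GraphModule over the logits.
        return sampler.module()(logits)
    # CPU fallback for the MPS + bfloat16 case: cast to float32 on CPU so the
    # Metal bfloat16 shader issue is avoided entirely, then take the argmax.
    return torch.argmax(logits.to(device="cpu", dtype=torch.float32), dim=-1)
```

This mirrors the warning message above ("The runner will use CPU-based sampling instead"): whatever consumes the export result should carry its own CPU sampling path rather than assume the sampler program exists.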