From 316e0f2f5a350de7645dc103024dbf553f3760b3 Mon Sep 17 00:00:00 2001
From: Mengwei Liu
Date: Tue, 27 Jan 2026 23:25:12 -0800
Subject: [PATCH] Do not export sampler for device==mps && dtype==bfloat16

As titled. Seeing this error when running a Metal-delegated argmax model:

```
E 00:00:13.059103 executorch:et_metal.mm:246] ETMetalShaderLibrary: Failed to compile shader library: program_source:3813:29: error: assigning to 'bfloat' from incompatible type 'float'
    tmp_acc_2 = tmp0;
                ^~~~
E 00:00:13.059124 executorch:et_metal.mm:263] ETMetalShaderLibrary: Library not compiled
E 00:00:13.059126 executorch:et_metal.mm:301] ETMetalShaderLibrary::getKernelFunction: Failed to get pipeline state for 'generated_kernel'
E 00:00:13.059127 executorch:shim_mps.mm:105] aoti_torch_mps_get_kernel_function: Failed to get kernel function 'generated_kernel'
E 00:00:13.059129 executorch:shim_mps.mm:517] aoti_torch_mps_run_command_block: null function handle
```

This PR disables the sampler export code path for that device/dtype
combination, letting the runner fall back to the C++ CPU sampler.
---
 optimum/exporters/executorch/integrations.py | 22 +++++++++++++++++-----
 1 file changed, 17 insertions(+), 5 deletions(-)

diff --git a/optimum/exporters/executorch/integrations.py b/optimum/exporters/executorch/integrations.py
index 28d306f..679bd3f 100644
--- a/optimum/exporters/executorch/integrations.py
+++ b/optimum/exporters/executorch/integrations.py
@@ -919,15 +919,27 @@ def export(
             example_cache_position,
         )
 
-        self.exported_sampler = self._export_sampler(
-            torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
-        )
+        # Skip sampler export for MPS + bfloat16 due to Metal shader compilation error
+        # (assigning float to bfloat in generated shader code)
+        is_mps_bfloat16 = str(self.model.device).startswith("mps") and self.model.dtype == torch.bfloat16
+        if is_mps_bfloat16:
+            logging.warning(
+                "Skipping sampler export for MPS + bfloat16 due to Metal shader compilation issues. "
+                "The runner will use CPU-based sampling instead."
+            )
+            self.exported_sampler = None
+        else:
+            self.exported_sampler = self._export_sampler(
+                torch.randn((1, 1, self.config.vocab_size), dtype=self.model.dtype, device=self.model.device)
+            )
-        return {
+        result = {
             "encoder": self.exported_encoder,
             # Not called "text_encoder" because the encoder could be non-text too, e.g. Whisper.
             "text_decoder": self.exported_decoder,
-            "sampler": self.exported_sampler,
         }
+        if self.exported_sampler is not None:
+            result["sampler"] = self.exported_sampler
+        return result
 
     def generate(self, prompt_token_ids, max_new_tokens):
         with torch.no_grad():
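
Note for reviewers: the new guard reduces to the predicate sketched below. This is a standalone illustration, not the code in the patch; `model` and `should_export_sampler` are hypothetical names standing in for `self.model` and the inline check in `export()`, and the `.device`/`.dtype` attributes are assumed to behave as they do in the diff.

```python
import torch

def should_export_sampler(model) -> bool:
    """Sketch of the MPS + bfloat16 guard (illustrative only; the real
    check lives inline in export() in integrations.py)."""
    # Metal's shader compiler rejects the generated bfloat16 sampler kernel
    # ("assigning to 'bfloat' from incompatible type 'float'"), so the export
    # is skipped on MPS + bfloat16 and the runner uses its C++ CPU sampler.
    is_mps_bfloat16 = str(model.device).startswith("mps") and model.dtype == torch.bfloat16
    return not is_mps_bfloat16
```

Because the returned dict now omits the "sampler" key instead of mapping it to None, consumers should probe with a membership check (e.g. `"sampler" in exported`) rather than indexing the key unconditionally.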