diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 036b3cac4cb3..9902242d32a1 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -15,6 +15,9 @@ from vllm.model_executor.layers.fused_moe.modular_kernel import ( FusedMoEPrepareAndFinalize, ) +from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( + build_flashinfer_fp8_cutlass_moe_prepare_finalize, +) from vllm.platforms import current_platform from vllm.utils.import_utils import has_deep_ep, has_pplx @@ -77,12 +80,17 @@ def maybe_make_prepare_finalize( prepare_finalize: FusedMoEPrepareAndFinalize | None = None - # TODO(rob): update this as part of the MoE refactor. - assert not moe.use_flashinfer_cutlass_kernels, ( - "Must be created in modelopt.py or fp8.py" - ) + if moe.use_flashinfer_cutlass_kernels: + assert quant_config is not None + use_deepseek_fp8_block_scale = ( + quant_config is not None and quant_config.is_block_quantized + ) + prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( + moe=moe, + use_deepseek_fp8_block_scale=use_deepseek_fp8_block_scale, + ) - if moe.use_pplx_kernels: + elif moe.use_pplx_kernels: assert quant_config is not None hidden_dim_bytes, hidden_scale_bytes = pplx_hidden_dim_scale_bytes( diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 5353830db04b..f864634c6617 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -241,9 +241,7 @@ def flashinfer_cutlass_moe_fp4( apply_router_weight_on_input: bool = False, ) -> torch.Tensor: fused_experts = mk.FusedMoEModularKernel( - create_flashinfer_prepare_finalize( - use_dp=False, use_nvfp4=True, enable_alltoallv=False - ), + create_flashinfer_prepare_finalize(use_dp=False), FlashInferExperts( out_dtype=hidden_states.dtype, quant_config=quant_config, diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index f9104b6bf7f5..99efab468256 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -48,7 +48,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, get_flashinfer_moe_backend, register_moe_scaling_factors, rotate_flashinfer_fp8_moe_weights, @@ -150,7 +149,7 @@ def get_fp8_moe_backend( if block_quant and current_platform.is_device_capability_family(100): raise ValueError( "FlashInfer FP8 MoE throughput backend does not " - "support block quantization on SM100. Please use " + "support block quantization. Please use " "VLLM_FLASHINFER_MOE_BACKEND=latency " "instead." ) @@ -1103,13 +1102,6 @@ def maybe_make_prepare_finalize( or self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM ): return None - elif self.fp8_backend == Fp8MoeBackend.FLASHINFER_CUTLASS: - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=self.block_quant, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 115edb2b3a34..3327f856ce56 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -46,7 +46,6 @@ from vllm.model_executor.layers.quantization.utils.flashinfer_utils import ( FlashinferMoeBackend, apply_flashinfer_per_tensor_scale_fp8, - build_flashinfer_fp8_cutlass_moe_prepare_finalize, flashinfer_cutlass_moe_fp8, get_flashinfer_moe_backend, is_flashinfer_supporting_global_sf, @@ -751,17 +750,6 @@ def maybe_make_prepare_finalize( # TRT LLM not supported with all2all yet. if self.flashinfer_moe_backend == FlashinferMoeBackend.TENSORRT_LLM: return None - elif self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS: - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None - - prepare_finalize = build_flashinfer_fp8_cutlass_moe_prepare_finalize( - self.moe, - use_deepseek_fp8_block_scale=False, - ) - logger.debug_once("%s", prepare_finalize.__class__.__name__) - return prepare_finalize return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( @@ -1456,9 +1444,6 @@ def maybe_make_prepare_finalize( self.allow_flashinfer and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS ): - # TP case: avoid convert to ModularKernelMethod - to be refactored. - if self.moe.dp_size == 1: - return None # For now, fp4 moe only works with the flashinfer dispatcher. prepare_finalize = build_flashinfer_fp4_cutlass_moe_prepare_finalize( self.moe