diff --git a/optimum/exporters/executorch/quantization.py b/optimum/exporters/executorch/quantization.py index 7e32244..53c3d4e 100644 --- a/optimum/exporters/executorch/quantization.py +++ b/optimum/exporters/executorch/quantization.py @@ -18,6 +18,11 @@ import torch +# Applied to IntxWeightOnlyConfig and Int8DynamicActivationIntxWeightConfig. +# Not applied to Int4WeightOnlyConfig (different API) or UIntxWeightOnlyConfig (no such param). +DEFAULT_QPARAMS_ALGORITHM = "hqq_scale_only" + + def quantize_model_( eager_model: torch.nn.Module, qlinear_config: Optional[str] = None, @@ -25,7 +30,12 @@ def quantize_model_( qlinear_packing_format: Optional[str] = None, qembedding_config: Optional[str] = None, qembedding_group_size: Optional[int] = 0, + qparams_algorithm: Optional[str] = None, ) -> torch.nn.Module: + # qparams_algorithm is applied to IntxWeightOnlyConfig and Int8DynamicActivationIntxWeightConfig. + # Not applied to Int4WeightOnlyConfig (different API) or UIntxWeightOnlyConfig (no such param). + if qparams_algorithm is None: + qparams_algorithm = DEFAULT_QPARAMS_ALGORITHM if not (qlinear_config or qembedding_config): return @@ -54,10 +64,12 @@ def quantize_model_( "4w": IntxWeightOnlyConfig( weight_dtype=torch.int4, granularity=embedding_weight_granularity, + intx_choose_qparams_algorithm=qparams_algorithm, ), "8w": IntxWeightOnlyConfig( weight_dtype=torch.int8, granularity=embedding_weight_granularity, + intx_choose_qparams_algorithm=qparams_algorithm, ), }[qembedding_config] @@ -75,6 +87,7 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format: return Int8DynamicActivationIntxWeightConfig( weight_dtype=torch.int4, weight_granularity=granularity, + intx_choose_qparams_algorithm=qparams_algorithm, ) if quant_config_key == "4w": # Determine if we need to use Int4WeightOnlyConfig with int4_packing_format @@ -88,16 +101,19 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format: return IntxWeightOnlyConfig( weight_dtype=torch.int4, granularity=granularity, + intx_choose_qparams_algorithm=qparams_algorithm, ) if quant_config_key == "8w": return IntxWeightOnlyConfig( weight_dtype=torch.int8, granularity=granularity, + intx_choose_qparams_algorithm=qparams_algorithm, ) if quant_config_key == "8da8w": return Int8DynamicActivationIntxWeightConfig( weight_dtype=torch.int8, weight_granularity=PerAxis(0), + intx_choose_qparams_algorithm=qparams_algorithm, ) if quant_config_key == "fpa4w": # Need to import to load the ops