Merged: optimum/exporters/executorch/quantization.py (16 additions, 0 deletions)
@@ -18,14 +18,24 @@
 import torch
 
 
+# Applied to IntxWeightOnlyConfig and Int8DynamicActivationIntxWeightConfig.
+# Not applied to Int4WeightOnlyConfig (different API) or UIntxWeightOnlyConfig (no such param).
+DEFAULT_QPARAMS_ALGORITHM = "hqq_scale_only"
+
+
 def quantize_model_(
     eager_model: torch.nn.Module,
     qlinear_config: Optional[str] = None,
     qlinear_group_size: Optional[int] = 32,
     qlinear_packing_format: Optional[str] = None,
     qembedding_config: Optional[str] = None,
     qembedding_group_size: Optional[int] = 0,
+    qparams_algorithm: Optional[str] = None,
 ) -> torch.nn.Module:
+    # qparams_algorithm is applied to IntxWeightOnlyConfig and Int8DynamicActivationIntxWeightConfig.
+    # Not applied to Int4WeightOnlyConfig (different API) or UIntxWeightOnlyConfig (no such param).
+    if qparams_algorithm is None:
+        qparams_algorithm = DEFAULT_QPARAMS_ALGORITHM
     if not (qlinear_config or qembedding_config):
         return
 
@@ -54,10 +64,12 @@ def quantize_model_(
         "4w": IntxWeightOnlyConfig(
             weight_dtype=torch.int4,
             granularity=embedding_weight_granularity,
+            intx_choose_qparams_algorithm=qparams_algorithm,
         ),
         "8w": IntxWeightOnlyConfig(
             weight_dtype=torch.int8,
             granularity=embedding_weight_granularity,
+            intx_choose_qparams_algorithm=qparams_algorithm,
         ),
     }[qembedding_config]
 
@@ -75,6 +87,7 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
             return Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int4,
                 weight_granularity=granularity,
+                intx_choose_qparams_algorithm=qparams_algorithm,
             )
         if quant_config_key == "4w":
             # Determine if we need to use Int4WeightOnlyConfig with int4_packing_format
@@ -88,16 +101,19 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
             return IntxWeightOnlyConfig(
                 weight_dtype=torch.int4,
                 granularity=granularity,
+                intx_choose_qparams_algorithm=qparams_algorithm,
             )
         if quant_config_key == "8w":
             return IntxWeightOnlyConfig(
                 weight_dtype=torch.int8,
                 granularity=granularity,
+                intx_choose_qparams_algorithm=qparams_algorithm,
             )
         if quant_config_key == "8da8w":
             return Int8DynamicActivationIntxWeightConfig(
                 weight_dtype=torch.int8,
                 weight_granularity=PerAxis(0),
+                intx_choose_qparams_algorithm=qparams_algorithm,
             )
         if quant_config_key == "fpa4w":
             # Need to import to load the ops
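Usage note: a minimal sketch of how the new parameter is exercised. The toy model and argument values below are illustrative assumptions, not taken from the PR; quantize_model_ mutates the model in place, and the underlying intx_choose_qparams_algorithm behavior comes from torchao, so the exact set of accepted algorithm strings may vary by torchao version.

    import torch

    from optimum.exporters.executorch.quantization import quantize_model_

    # Toy eager model with an embedding and a linear layer (illustrative only).
    model = torch.nn.Sequential(
        torch.nn.Embedding(1024, 64),
        torch.nn.Linear(64, 64),
    )

    # Omitting qparams_algorithm (or passing None) now falls back to
    # DEFAULT_QPARAMS_ALGORITHM, i.e. "hqq_scale_only".
    quantize_model_(
        model,
        qlinear_config="8da4w",    # int8 dynamic activations, int4 weights
        qembedding_config="8w",    # int8 weight-only embeddings
        qparams_algorithm=None,    # resolved to "hqq_scale_only"
    )

Note that the "4w" linear path may route to Int4WeightOnlyConfig when a packing format is requested, in which case the algorithm parameter is intentionally not forwarded, per the comments in the diff.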