diff --git a/optimum/commands/export/executorch.py b/optimum/commands/export/executorch.py
index 729f6c7f..58d900c0 100644
--- a/optimum/commands/export/executorch.py
+++ b/optimum/commands/export/executorch.py
@@ -76,7 +76,7 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--qlinear",
         type=str,
-        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
+        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w", "fpa4w"],
         required=False,
         help=(
             "Quantization config for decoder linear layers.\n\n"
@@ -85,7 +85,8 @@ def parse_args_executorch(parser):
             " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
             " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
             " 4w - 4-bit weight only\n"
-            " 8w - 8-bit weight only"
+            " 8w - 8-bit weight only\n"
+            " fpa4w - floating point activation, 4-bit weight (MPS backend)"
         ),
     )
     required_group.add_argument(
@@ -106,7 +107,7 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--qlinear_encoder",
         type=str,
-        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
+        choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w", "fpa4w"],
         required=False,
         help=(
             "Quantization config for encoder linear layers.\n\n"
@@ -115,7 +116,8 @@ def parse_args_executorch(parser):
             " 8da8w - 8-bit dynamic activation, 8-bit weight\n"
             " 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly \n"
             " 4w - 4-bit weight only\n"
-            " 8w - 8-bit weight only"
+            " 8w - 8-bit weight only\n"
+            " fpa4w - floating point activation, 4-bit weight (MPS backend)"
         ),
     )
     required_group.add_argument(
@@ -182,9 +184,9 @@ def parse_args_executorch(parser):
     required_group.add_argument(
         "--device",
        type=str,
-        choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
+        choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "mps"],
         required=False,
-        help="Device to run the model on. Options: cpu, cuda. Default: cpu.",
+        help="Device to run the model on. Options: cpu, cuda, mps. Default: cpu.",
     )
 
 
@@ -219,6 +221,14 @@ def run(self):
                 "--qlinear_encoder_packing_format can only be used when --qlinear_encoder is set to '4w'"
             )
 
+        # Validate fpa4w quantization requires MPS device
+        qlinear = getattr(self.args, "qlinear", None)
+        qlinear_encoder = getattr(self.args, "qlinear_encoder", None)
+        if qlinear == "fpa4w" and device != "mps":
+            raise ValueError("--qlinear=fpa4w can only be used when --device is set to 'mps'")
+        if qlinear_encoder == "fpa4w" and device != "mps":
+            raise ValueError("--qlinear_encoder=fpa4w can only be used when --device is set to 'mps'")
+
         kwargs = {}
         if self.args.use_custom_sdpa:
             kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
diff --git a/optimum/exporters/executorch/quantization.py b/optimum/exporters/executorch/quantization.py
index 958ce525..6e48d9dc 100644
--- a/optimum/exporters/executorch/quantization.py
+++ b/optimum/exporters/executorch/quantization.py
@@ -31,6 +31,7 @@ def quantize_model_(
     if not (qlinear_config or qembedding_config):
         return
 
+    from torchao.experimental.quant_api import UIntxWeightOnlyConfig
     from torchao.quantization.granularity import PerAxis, PerGroup
     from torchao.quantization.quant_api import (
         Int4WeightOnlyConfig,
@@ -42,9 +43,9 @@
 
     if qembedding_config:
         if qlinear_config == "8w":
-            assert (
-                qembedding_group_size == 0
-            ), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
+            assert qembedding_group_size == 0, (
+                "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
+            )
         if qembedding_group_size == 0:
             embedding_weight_granularity = PerAxis(0)
         else:
@@ -101,6 +102,13 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
                 weight_dtype=torch.int8,
                 weight_granularity=PerAxis(0),
             )
+        if quant_config_key == "fpa4w":
+            # Need to import to load the ops
+            import torchao.experimental.ops.mps  # noqa: F401
+            return UIntxWeightOnlyConfig(
+                group_size=qlinear_group_size,
+                bitwidth=4,
+            )
         raise ValueError(f"Unsupported linear quantization config '{quant_config_key}'.")
 
     qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
@@ -120,9 +128,9 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
         )
         fallback_linear_config_key = None
     else:
-        assert (
-            qlinear_group_size % 2 == 0
-        ), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
+        assert qlinear_group_size % 2 == 0, (
+            f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
+        )
         linear_weight_granularity = PerGroup(qlinear_group_size)
 
     logging.info("Quantizing linear layers.")
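
Note (not part of the patch): the new "fpa4w" branch of build_linear_config resolves to torchao's UIntxWeightOnlyConfig, and the CLI change gates it on --device mps. A minimal, hypothetical sketch of applying the same config directly with torchao follows; the toy module, the group_size of 32, and the assumption that quantize_ accepts this experimental config the same way it accepts the other configs used in quantize_model_ are illustrative assumptions, not claims about the patch.

    # Hypothetical sketch, not part of this patch: build and apply the config that
    # the new "fpa4w" branch of build_linear_config returns. Requires a machine
    # where the MPS backend is available (Apple silicon).
    import torch
    import torchao.experimental.ops.mps  # noqa: F401  # loads the MPS low-bit ops, as in build_linear_config
    from torchao.experimental.quant_api import UIntxWeightOnlyConfig
    from torchao.quantization.quant_api import quantize_

    # Toy module on the MPS device (mirrors the CLI requirement of --device mps for fpa4w).
    model = torch.nn.Sequential(torch.nn.Linear(256, 128)).eval().to("mps")

    # Floating-point activations, 4-bit weights; group_size=32 is an assumed example value.
    quantize_(model, UIntxWeightOnlyConfig(group_size=32, bitwidth=4))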