Merged
22 changes: 16 additions & 6 deletions optimum/commands/export/executorch.py
@@ -76,7 +76,7 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear",
type=str,
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w", "fpa4w"],
Collaborator: Does fpa4w work on backends other than Metal?

Contributor Author: No, it only works with Metal.

Collaborator: Can you add a check then, so that --qlinear fpa4w can only be passed together with --device mps?

Contributor Author: done!

required=False,
help=(
"Quantization config for decoder linear layers.\n\n"
@@ -85,7 +85,8 @@ def parse_args_executorch(parser):
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight and 8-bit weight\n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
" 8w - 8-bit weight only\n"
" fpa4w - floating point activation, 4-bit weight (MPS backend)"
),
)
required_group.add_argument(
Expand All @@ -106,7 +107,7 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--qlinear_encoder",
type=str,
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w"],
choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w", "fpa4w"],
required=False,
help=(
"Quantization config for encoder linear layers.\n\n"
Expand All @@ -115,7 +116,8 @@ def parse_args_executorch(parser):
" 8da8w - 8-bit dynamic activation, 8-bit weight\n"
" 8da4w,8da8w - 8-bit dynamic activation, 4-bit weight; fallback on 8-bit dynamic activation, 8-bit weight per-channel where group size doesn't divide block size cleanly \n"
" 4w - 4-bit weight only\n"
" 8w - 8-bit weight only"
" 8w - 8-bit weight only\n"
" fpa4w - floating point activation, 4-bit weight (MPS backend)"
),
)
required_group.add_argument(
@@ -182,9 +184,9 @@ def parse_args_executorch(parser):
required_group.add_argument(
"--device",
type=str,
choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3"],
choices=["cpu", "cuda", "cuda:0", "cuda:1", "cuda:2", "cuda:3", "mps"],
required=False,
help="Device to run the model on. Options: cpu, cuda. Default: cpu.",
help="Device to run the model on. Options: cpu, cuda, mps. Default: cpu.",
)


@@ -219,6 +221,14 @@ def run(self):
"--qlinear_encoder_packing_format can only be used when --qlinear_encoder is set to '4w'"
)

# Validate fpa4w quantization requires MPS device
qlinear = getattr(self.args, "qlinear", None)
qlinear_encoder = getattr(self.args, "qlinear_encoder", None)
if qlinear == "fpa4w" and device != "mps":
raise ValueError("--qlinear=fpa4w can only be used when --device is set to 'mps'")
if qlinear_encoder == "fpa4w" and device != "mps":
raise ValueError("--qlinear_encoder=fpa4w can only be used when --device is set to 'mps'")

kwargs = {}
if self.args.use_custom_sdpa:
kwargs["use_custom_sdpa"] = self.args.use_custom_sdpa
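For reference, a minimal self-contained sketch of the constraint settled on in the review thread above: fpa4w lowers to Metal kernels, so it must be paired with --device mps. The parser below is illustrative only and is not the actual optimum-cli parser; it merely mirrors the flag names and the shape of the check added in this file.

import argparse

# Hypothetical miniature of the new validation; the flag names mirror the PR,
# but this parser is illustrative and not the real optimum-cli parser.
parser = argparse.ArgumentParser()
parser.add_argument("--qlinear", choices=["8da4w", "4w", "8w", "8da8w", "8da4w,8da8w", "fpa4w"])
parser.add_argument("--device", choices=["cpu", "cuda", "mps"], default="cpu")
args = parser.parse_args(["--qlinear", "fpa4w", "--device", "cpu"])

# fpa4w lowers to Metal kernels, so it is only valid together with --device mps.
if args.qlinear == "fpa4w" and args.device != "mps":
    print("rejected: --qlinear=fpa4w can only be used when --device is set to 'mps'")
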
20 changes: 14 additions & 6 deletions optimum/exporters/executorch/quantization.py
@@ -31,6 +31,7 @@ def quantize_model_(
if not (qlinear_config or qembedding_config):
return

from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int4WeightOnlyConfig,
@@ -42,9 +43,9 @@

if qembedding_config:
if qlinear_config == "8w":
assert (
qembedding_group_size == 0
), "8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
assert qembedding_group_size == 0, (
"8-bit embedding quantization only supports per-token at the moment, please use qembedding_group_size = 0."
)
if qembedding_group_size == 0:
embedding_weight_granularity = PerAxis(0)
else:
Expand Down Expand Up @@ -101,6 +102,13 @@ def build_linear_config(quant_config_key: str, granularity: str, packing_format:
weight_dtype=torch.int8,
weight_granularity=PerAxis(0),
)
if quant_config_key == "fpa4w":
# Need to import to load the ops
import torchao.experimental.ops.mps # noqa: F401
Collaborator: Nit: should we import this in torchao.experimental.quant_api, so that "from torchao.experimental.quant_api import UIntxWeightOnlyConfig" can satisfy the import requirement?

Contributor Author: import torchao.experimental.ops.mps will raise an error if the op library isn't found. The Metal ops are not built in torchao by default. For that reason, I thought it would be clearer to have an explicit import that loads the ops, rather than loading them as a side effect of importing the config.

Contributor Author: That's also why I import torchao.experimental.ops.mps only if quant_config_key == "fpa4w".

return UIntxWeightOnlyConfig(
group_size=qlinear_group_size,
bitwidth=4,
)
raise ValueError(f"Unsupported linear quantization config '{quant_config_key}'.")

qlinear_configs = [cfg.strip() for cfg in qlinear_config.split(",")]
@@ -120,9 +128,9 @@
)
fallback_linear_config_key = None
else:
assert (
qlinear_group_size % 2 == 0
), f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
assert qlinear_group_size % 2 == 0, (
f"Linear quantization group size must be a multiple of 2, got {qlinear_group_size}."
)
linear_weight_granularity = PerGroup(qlinear_group_size)

logging.info("Quantizing linear layers.")
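For context, a minimal sketch of the fpa4w path in isolation, assuming a torchao build that ships the experimental Metal ops. Only the torchao.experimental.ops.mps import and UIntxWeightOnlyConfig(group_size=..., bitwidth=4) come from the diff above; the try/except wrapper, its error message, the toy model, and the group_size/device/dtype choices are illustrative assumptions, as is applying the config through torchao's quantize_.

import torch
import torch.nn as nn

# The guarded import mirrors the reasoning in the review thread above: the Metal
# ops are not built into torchao by default, so importing them can fail.
try:
    import torchao.experimental.ops.mps  # noqa: F401  # loads the Metal kernels
except ImportError as exc:
    raise ImportError(
        "fpa4w requires a torchao build with the experimental MPS ops"
    ) from exc

from torchao.experimental.quant_api import UIntxWeightOnlyConfig
from torchao.quantization.quant_api import quantize_

# Toy model; group_size, dtype, and device are illustrative choices.
model = nn.Sequential(nn.Linear(256, 256)).to("mps", dtype=torch.float16)
quantize_(model, UIntxWeightOnlyConfig(group_size=32, bitwidth=4))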