1 change: 1 addition & 0 deletions saxml/tools/offline_quantize.py
@@ -59,6 +59,7 @@ def parse_known_args(argv):
'gemma2b',
'gemma7b',
'llama2-70b-weight-linear-only-int8',
'gptfp8',
],
help='Quantization Config.',
)
47 changes: 35 additions & 12 deletions saxml/tools/quant_fn.py
@@ -17,11 +17,14 @@

import apache_beam as beam
from jax import numpy as jnp
from jax import lax
import numpy as np
import quantization_actions
import quantization_configs
import tensorstore_util

from flax.linen import fp8_ops


def _split_with_shard(
var: np.ndarray,
@@ -83,6 +86,7 @@ def _write_quantized_tensor(
zp: np.ndarray | None,
suffix: str = '',
sharding_indices: Optional[list[int]] = None,
scale_act: np.ndarray | None = None,
) -> None:
if number_bit == 4 and action.use_int4_packed_weights:
# Extra pack needed for 4 bit.
@@ -99,6 +103,8 @@ def process(self, action: quantization_actions.OptAction):
if zp:
zp_name = action.target_name + '_quantized_zp' + suffix
self._writer.write_variable(zp_name, zp)
scale_act_name = action.target_name + '_quantized_act_scale' + suffix
self._writer.write_variable(scale_act_name, scale_act)

def process(self, action: quantization_actions.OptAction):
target_var = self._readers[action.input_dir].read_variable(
@@ -115,6 +121,7 @@ def process(self, action: quantization_actions.OptAction):
optimization_on_bound = False
p_value = 1.0
per_channel = False
scale_act = None
if action.number_bit == 4 and action.use_optimization:
optimization_on_bound = True
p_value = action.optimization_p_value
@@ -125,18 +132,33 @@
per_channel = True

if self._symmetric:
target_var, scale = quantization_configs.quantize_tensor(
target_var,
quantize_axis,
quantize_factor,
True,
number_of_bits,
use_fp=action.use_fp,
add_scale_eps=action.add_scale_eps,
optimization_on_bound=optimization_on_bound,
p_value=p_value,
per_channel=per_channel,
)
if action.use_fp and number_of_bits == 8:
assert not per_channel, 'fp8 only supports per-tensor quantization.'
scale_act_name = action.source_name[:-1] + 'einsum.input_scale'
scale_kernel_name = action.source_name[:-1] + 'einsum.kernel_scale'
scale_act = self._readers[action.input_dir].read_variable(
scale_act_name, action.layer_id, action.num_layers
)
scale = self._readers[action.input_dir].read_variable(
scale_kernel_name, action.layer_id, action.num_layers
)
compute_dtype = target_var.dtype
target_var = fp8_ops.quantize(target_var, jnp.float8_e4m3fn, scale, compute_dtype)
# Bitcast to int8 since fp8 tensors cannot be saved directly.
target_var = lax.bitcast_convert_type(target_var, new_dtype=jnp.int8)
else:
target_var, scale = quantization_configs.quantize_tensor(
target_var,
quantize_axis,
quantize_factor,
True,
number_of_bits,
use_fp=action.use_fp,
add_scale_eps=action.add_scale_eps,
optimization_on_bound=optimization_on_bound,
p_value=p_value,
per_channel=per_channel,
)
zp = None
else:
target_var, scale, zp = quantization_configs.quantize_tensor(
@@ -158,6 +180,7 @@ def process(self, action: quantization_actions.OptAction):
scale,
zp,
sharding_indices=action.sharding_indices,
scale_act=scale_act,
)
else:
# no quantization.
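The fp8 branch added to quant_fn.py reads precomputed per-tensor scales ('einsum.input_scale' and 'einsum.kernel_scale') from the source checkpoint, quantizes the kernel to float8_e4m3fn with flax's fp8_ops, and bitcasts the result to int8 so the writer can store it. A minimal sketch of that round trip, assuming a scalar per-tensor scale; the variable names below are illustrative stand-ins, not saxml checkpoint names:

```python
# Sketch of the fp8 save path above and its inverse at load time, assuming a
# scalar per-tensor scale. `w` and `scale` are illustrative, not saxml names.
from flax.linen import fp8_ops
from jax import lax
from jax import numpy as jnp

w = jnp.ones((4, 4), dtype=jnp.bfloat16)     # weight tensor to quantize
scale = jnp.array(0.01, dtype=jnp.float32)   # precomputed per-tensor kernel scale

# Quantize to fp8, mirroring the call in quant_fn.py.
w_fp8 = fp8_ops.quantize(w, jnp.float8_e4m3fn, scale, w.dtype)

# Checkpoint storage cannot hold fp8, so bitcast to int8 (same byte width).
w_saved = lax.bitcast_convert_type(w_fp8, new_dtype=jnp.int8)

# At load time, reverse the bitcast and rescale (shown with plain jnp ops
# purely for illustration).
w_restored = lax.bitcast_convert_type(w_saved, new_dtype=jnp.float8_e4m3fn)
w_dequant = w_restored.astype(w.dtype) * scale.astype(w.dtype)
```

Because float8_e4m3fn and int8 are both one byte wide, the bitcast is lossless and shape-preserving, so the consumer can recover the original fp8 values with the reverse bitcast before applying the stored scales.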
12 changes: 12 additions & 0 deletions saxml/tools/quantization_configs.py
@@ -66,6 +66,18 @@ class QuantizationConfigsGPTJ(QuantizationConfigs):
}


class QuantizationConfigsFP8(QuantizationConfigsGPTJ):
"""Quantization config for FP8 model."""

factor = 1.0
configs = {
'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
'self_attention.combined_qkv.w': ([0, 1, 2, 3], factor, 1, -1),
'self_attention.post.w': ([0, 1, 2], factor, 0, -1),
}


class QuantizationConfigsGPTJStacked(QuantizationConfigs):
"""Quantization config for GPTJ model."""

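QuantizationConfigsFP8 reuses the GPT-J layer coverage, while the fp8 path in quant_fn.py asserts per-tensor quantization and expects the 'einsum.input_scale' and 'einsum.kernel_scale' variables to already exist in the source checkpoint. For reference only, a per-tensor e4m3 scale is typically derived from a calibration amax; the helper below is a hypothetical illustration of how such a scale could be produced, not part of this change:

```python
# Hypothetical calibration helper (not part of this PR): one common way to
# derive the per-tensor scale that the fp8 path reads from the checkpoint.
import jax.numpy as jnp

def per_tensor_fp8_scale(x: jnp.ndarray) -> jnp.ndarray:
  """Scale s such that x / s fits within the float8_e4m3fn range."""
  fp8_max = float(jnp.finfo(jnp.float8_e4m3fn).max)  # 448 for e4m3fn
  amax = jnp.max(jnp.abs(x))
  return (amax / fp8_max).astype(jnp.float32)
```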
1 change: 1 addition & 0 deletions saxml/tools/quantization_provider.py
@@ -23,6 +23,7 @@
'llama2-70b-weight-linear-only-int8': (
quantization_configs.QuantizationConfigsLLaMA70BWeightLinearOnlyInt8()
),
'gptfp8': quantization_configs.QuantizationConfigsFP8(),
}

NAME_TO_CONFIG_STACKED = {
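With the provider entry in place, selecting 'gptfp8' in offline_quantize.py resolves to the new config class. A small lookup sketch, assuming the dict shown above is the non-stacked name-to-config map (only the entry added in this diff is reproduced):

```python
# Minimal lookup sketch; the dict name NAME_TO_CONFIG is assumed from context.
import quantization_configs

NAME_TO_CONFIG = {
    'gptfp8': quantization_configs.QuantizationConfigsFP8(),
}

config = NAME_TO_CONFIG['gptfp8']
# Inherited GPT-J coverage: ffn_layer1/2, combined_qkv, and post weights.
print(sorted(config.configs.keys()))
```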