From cf77a55327cfac04d9705336684721a23e58325f Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Tue, 4 Jun 2024 22:56:35 -0500
Subject: [PATCH 1/2] Add fp8 checkpoint converter.

---
 saxml/tools/quant_fn.py             | 47 +++++++++++++++++++++--------
 saxml/tools/quantization_configs.py | 12 ++++++++
 2 files changed, 47 insertions(+), 12 deletions(-)

diff --git a/saxml/tools/quant_fn.py b/saxml/tools/quant_fn.py
index 54d77d42..8c9c03f1 100644
--- a/saxml/tools/quant_fn.py
+++ b/saxml/tools/quant_fn.py
@@ -17,11 +17,14 @@
 
 import apache_beam as beam
 from jax import numpy as jnp
+from jax import lax
 import numpy as np
 import quantization_actions
 import quantization_configs
 import tensorstore_util
 
+from flax.linen import fp8_ops
+
 
 def _split_with_shard(
     var: np.ndarray,
@@ -83,6 +86,7 @@ def _write_quantized_tensor(
       zp: np.ndarray | None,
       suffix: str = '',
       sharding_indices: Optional[list[int]] = None,
+      scale_act: np.ndarray | None = None,
   ) -> None:
     if number_bit == 4 and action.use_int4_packed_weights:
       # Extra pack needed for 4 bit.
@@ -99,6 +103,8 @@
     if zp:
       zp_name = action.target_name + '_quantized_zp' + suffix
       self._writer.write_variable(zp_name, zp)
+    scale_act_name = action.target_name + '_quantized_act_scale' + suffix
+    self._writer.write_variable(scale_act_name, scale_act)
 
   def process(self, action: quantization_actions.OptAction):
     target_var = self._readers[action.input_dir].read_variable(
@@ -115,6 +121,7 @@
     optimization_on_bound = False
     p_value = 1.0
     per_channel = False
+    scale_act = None
     if action.number_bit == 4 and action.use_optimization:
       optimization_on_bound = True
       p_value = action.optimization_p_value
@@ -125,18 +132,33 @@
       per_channel = True
 
     if self._symmetric:
-      target_var, scale = quantization_configs.quantize_tensor(
-          target_var,
-          quantize_axis,
-          quantize_factor,
-          True,
-          number_of_bits,
-          use_fp=action.use_fp,
-          add_scale_eps=action.add_scale_eps,
-          optimization_on_bound=optimization_on_bound,
-          p_value=p_value,
-          per_channel=per_channel,
-      )
+      if action.use_fp and number_of_bits == 8:
+        assert not per_channel, 'fp8 only supports per-tensor quantization.'
+        scale_act_name = action.source_name[:-1] + 'einsum.input_scale'
+        scale_kernel_name = action.source_name[:-1] + 'einsum.kernel_scale'
+        scale_act = self._readers[action.input_dir].read_variable(
+            scale_act_name, action.layer_id, action.num_layers
+        )
+        scale = self._readers[action.input_dir].read_variable(
+            scale_kernel_name, action.layer_id, action.num_layers
+        )
+        compute_dtype = target_var.dtype
+        target_var = fp8_ops.quantize(target_var, jnp.float8_e4m3fn, scale, compute_dtype)
+        # Bitcast to int8 because fp8 tensors cannot be serialized directly.
+        target_var = lax.bitcast_convert_type(target_var, new_dtype=jnp.int8)
+      else:
+        target_var, scale = quantization_configs.quantize_tensor(
+            target_var,
+            quantize_axis,
+            quantize_factor,
+            True,
+            number_of_bits,
+            use_fp=action.use_fp,
+            add_scale_eps=action.add_scale_eps,
+            optimization_on_bound=optimization_on_bound,
+            p_value=p_value,
+            per_channel=per_channel,
+        )
       zp = None
     else:
       target_var, scale, zp = quantization_configs.quantize_tensor(
@@ -158,6 +180,7 @@
           scale,
           zp,
           sharding_indices=action.sharding_indices,
+          scale_act=scale_act,
       )
     else:
       # no quantization.
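Aside (not part of the patch): a minimal sketch of the quantize-then-bitcast round trip the new fp8 branch above relies on. `fp8_ops.quantize` scales the weight and casts it to `float8_e4m3fn`, and the bitcast to int8 preserves the raw bits so the checkpoint writer can serialize them; a loader would bitcast back before dequantizing. The clip-and-cast below only approximates `flax.linen.fp8_ops.quantize`, and the names `w` and `kernel_scale` are illustrative stand-ins for `target_var` and the kernel scale read from the checkpoint.

```python
# Illustrative only -- approximates the fp8 path added in quant_fn.py above.
from jax import lax
from jax import numpy as jnp

w = jnp.array([[0.5, -1.25], [2.0, 0.0]], dtype=jnp.bfloat16)  # stand-in weight
kernel_scale = jnp.array(1.0, dtype=jnp.float32)               # stand-in scale

# Scale, clip to the e4m3 representable range, and cast to fp8.
e4m3 = jnp.finfo(jnp.float8_e4m3fn)
w_fp8 = jnp.clip(
    w.astype(jnp.float32) / kernel_scale, e4m3.min, e4m3.max
).astype(jnp.float8_e4m3fn)

# fp8 cannot be written to the checkpoint directly, so store the raw bits as int8.
w_saved = lax.bitcast_convert_type(w_fp8, new_dtype=jnp.int8)

# At load time, bitcast back to fp8 and dequantize with the saved scale.
w_loaded = lax.bitcast_convert_type(w_saved, new_dtype=jnp.float8_e4m3fn)
w_dequant = w_loaded.astype(jnp.float32) * kernel_scale
```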
diff --git a/saxml/tools/quantization_configs.py b/saxml/tools/quantization_configs.py
index 55d28a99..a1b4a72a 100644
--- a/saxml/tools/quantization_configs.py
+++ b/saxml/tools/quantization_configs.py
@@ -66,6 +66,18 @@ class QuantizationConfigsGPTJ(QuantizationConfigs):
   }
 
 
+class QuantizationConfigsFP8(QuantizationConfigsGPTJ):
+  """Quantization config for FP8 model."""
+
+  factor = 1.0
+  configs = {
+      'ff_layer.ffn_layer1.linear.w': ([0, 1], factor, 0, -1),
+      'ff_layer.ffn_layer2.linear.w': ([0, 1], factor, 0, -1),
+      'self_attention.combined_qkv.w': ([0, 1, 2, 3], factor, 1, -1),
+      'self_attention.post.w': ([0, 1, 2], factor, 0, -1),
+  }
+
+
 class QuantizationConfigsGPTJStacked(QuantizationConfigs):
   """Quantization config for GPTJ model."""
 

From 59e3f1e58dd882e47907b92189498c631818af75 Mon Sep 17 00:00:00 2001
From: Shu Wang
Date: Tue, 4 Jun 2024 23:27:41 -0500
Subject: [PATCH 2/2] Add config name map.

---
 saxml/tools/offline_quantize.py      | 1 +
 saxml/tools/quantization_provider.py | 1 +
 2 files changed, 2 insertions(+)

diff --git a/saxml/tools/offline_quantize.py b/saxml/tools/offline_quantize.py
index 4097b0de..d7f93c06 100644
--- a/saxml/tools/offline_quantize.py
+++ b/saxml/tools/offline_quantize.py
@@ -59,6 +59,7 @@ def parse_known_args(argv):
           'gemma2b',
           'gemma7b',
           'llama2-70b-weight-linear-only-int8',
+          'gptfp8',
       ],
       help='Quantization Config.',
   )
diff --git a/saxml/tools/quantization_provider.py b/saxml/tools/quantization_provider.py
index 73768459..825bc061 100644
--- a/saxml/tools/quantization_provider.py
+++ b/saxml/tools/quantization_provider.py
@@ -23,6 +23,7 @@
     'llama2-70b-weight-linear-only-int8': (
        quantization_configs.QuantizationConfigsLLaMA70BWeightLinearOnlyInt8()
     ),
+    'gptfp8': quantization_configs.QuantizationConfigsFP8(),
 }
 
 NAME_TO_CONFIG_STACKED = {
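Aside (not part of the patch): the second commit only registers the new config under the name 'gptfp8'. A minimal resolution sketch follows, assuming the mapping edited above is the module-level `NAME_TO_CONFIG` dict in `quantization_provider.py` (the dict's name falls outside the hunk context shown) and that the saxml tools are importable as top-level modules, as in quant_fn.py.

```python
# Sketch only. NAME_TO_CONFIG is an assumption; only the 'gptfp8' key and the
# QuantizationConfigsFP8 class are taken from the patch itself.
import quantization_configs
import quantization_provider

config = quantization_provider.NAME_TO_CONFIG['gptfp8']
assert isinstance(config, quantization_configs.QuantizationConfigsFP8)

# Each entry maps a weight name to its quantization spec; the factor is 1.0
# for every fp8 weight, per QuantizationConfigsFP8 above.
print(config.configs['self_attention.combined_qkv.w'])  # ([0, 1, 2, 3], 1.0, 1, -1)
```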