
Commit 73968c3

issue/843: success per_channel_quant_int8
1 parent 2525802 commit 73968c3

12 files changed: +1205 −2 lines changed

Lines changed: 28 additions & 0 deletions
@@ -0,0 +1,28 @@
#ifndef __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__
#define __INFINIOP_PER_CHANNEL_QUANT_INT8_API_H__

#include "../../operator_descriptor.h"

typedef InfiniopDescriptor *infiniopPerChannelQuantI8Descriptor_t;

__C __export infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
                                                                      infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
                                                                      infiniopTensorDescriptor_t x_packed_desc,
                                                                      infiniopTensorDescriptor_t x_scale_desc,
                                                                      infiniopTensorDescriptor_t x_zero_desc,
                                                                      infiniopTensorDescriptor_t x_desc);

__C __export infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQuantI8Descriptor_t desc, size_t *size);

__C __export infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor_t desc,
                                                      void *workspace,
                                                      size_t workspace_size,
                                                      void *x_packed,
                                                      void *x_scale,
                                                      void *x_zero,
                                                      const void *x,
                                                      void *stream);

__C __export infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannelQuantI8Descriptor_t desc);

#endif
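
For orientation, a minimal host-side usage sketch of this API (an illustration added here, not part of the commit). It assumes the infiniop handle, the four tensor descriptors, the device buffers, and the stream have been created elsewhere, that INFINI_STATUS_SUCCESS is the library's success code, and it leaves workspace allocation as a stub; it only calls the functions declared above.

// Hypothetical helper, for illustration: quantize x (M x K) into x_packed
// with one scale/zero pair per row.
static infiniStatus_t quantizePerRowI8(infiniopHandle_t handle,
                                       infiniopTensorDescriptor_t x_packed_desc,
                                       infiniopTensorDescriptor_t x_scale_desc,
                                       infiniopTensorDescriptor_t x_zero_desc,
                                       infiniopTensorDescriptor_t x_desc,
                                       void *x_packed, void *x_scale, void *x_zero,
                                       const void *x, void *stream) {
    infiniopPerChannelQuantI8Descriptor_t desc = nullptr;
    infiniStatus_t status = infiniopCreatePerChannelQuantI8Descriptor(
        handle, &desc, x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
    if (status != INFINI_STATUS_SUCCESS) {
        return status;
    }

    size_t workspace_size = 0;
    status = infiniopGetPerChannelQuantI8WorkspaceSize(desc, &workspace_size);
    void *workspace = nullptr; // allocate workspace_size bytes on the device if nonzero

    if (status == INFINI_STATUS_SUCCESS) {
        status = infiniopPerChannelQuantI8(desc, workspace, workspace_size,
                                           x_packed, x_scale, x_zero, x, stream);
    }
    infiniopDestroyPerChannelQuantI8Descriptor(desc);
    return status;
}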
Lines changed: 277 additions & 0 deletions
@@ -0,0 +1,277 @@
#ifndef __PERCHANNEL_QUANTINT8_KERNEL_CUH__
#define __PERCHANNEL_QUANTINT8_KERNEL_CUH__

#include <cub/block/block_reduce.cuh>

// Round to the nearest integer with ties away from zero (2.5 -> 3, -2.5 -> -3).
__device__ inline int round_half_away_from_zero(float x) {
    float ax = fabsf(x);
    float r = floorf(ax + 0.5f);
    return (x >= 0.0f) ? (int)r : -(int)r;
}
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8Kernel(
    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
    int M, int K) {
    int row = blockIdx.x;
    int tid = row * K;

    // ---- 1. reduce max ----
    float local_max = op::common_cuda::reduce_op::max<BLOCK_SIZE, Tdata>(
        x + tid, K);

    __shared__ float global_max_f;
    if (threadIdx.x == 0) {
        global_max_f = local_max;
    }
    __syncthreads();

    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // ---- 2. reduce min ----
    float thread_min = __FLT_MAX__;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        thread_min = fminf(thread_min, (float)x[tid + ind]);
    }
    float local_min = BlockReduce(temp_storage).Reduce(thread_min, cub::Min());

    __shared__ float global_min_f;
    if (threadIdx.x == 0) {
        global_min_f = local_min;
    }
    __syncthreads();

    // ---- 3. compute scale/zero in float (to match the Python reference) ----
    float global_max = global_max_f;
    float global_min = global_min_f;

    float scale = (global_max - global_min) / 255.0f;
    if (scale < 1e-8f) {
        scale = 1e-8f;
    }

    float inv_scale = 1.0f / scale;
    float zero = -global_min * inv_scale - 128.0f;

    // write back scale and zero
    x_scale[row] = (Tdata)scale;
    x_zero[row] = (Tdata)zero;

    // ---- 4. quantize in float with half-away-from-zero rounding (matches the Python reference) ----
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        float v = (float)x[tid + ind];
        float qf = v * inv_scale + zero;

        int q = round_half_away_from_zero(qf);

        if (q > 127) {
            q = 127;
        }
        if (q < -128) {
            q = -128;
        }

        x_packed[tid + ind] = (int8_t)q;
    }
}
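// Worked example of the asymmetric mapping above (an illustration added for this
// write-up, not part of the committed file): for a row with min = -1.0 and
// max = 3.0, scale = (3.0 - (-1.0)) / 255 = 4/255, so inv_scale = 63.75 and
// zero = -(-1.0) * 63.75 - 128 = -64.25. Then
//   x =  3.0 -> round( 3.0 * 63.75 - 64.25) = round( 127.0) =  127
//   x = -1.0 -> round(-1.0 * 63.75 - 64.25) = round(-128.0) = -128
// i.e. the row's min and max land exactly on the int8 endpoints.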
template <typename Tdata, unsigned int BLOCK_SIZE>
__device__ void blockPerChannelQuantI8SymKernel(
    int8_t *x_packed, float *x_scale, const Tdata *x,
    int M, int K) {
    int row = blockIdx.x;
    int tid = row * K;

    typedef cub::BlockReduce<float, BLOCK_SIZE> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // ---- 1. reduce max of |x| ----
    float thread_max = -__FLT_MAX__;
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        thread_max = fmaxf(thread_max, fabsf((float)x[tid + ind]));
    }
    float local_max = BlockReduce(temp_storage).Reduce(thread_max, cub::Max());

    __shared__ float global_max_f;
    if (threadIdx.x == 0) {
        global_max_f = local_max;
    }
    __syncthreads();

    // ---- 2. compute scale in float (to match the Python reference) ----
    float global_max = global_max_f;

    float scale = global_max / 127.0f;
    if (scale < 1e-8f) {
        scale = 1e-8f;
    }

    float inv_scale = 1.0f / scale;

    // write back scale
    x_scale[row] = (Tdata)scale;

    // ---- 3. quantize in float with half-away-from-zero rounding (matches the Python reference) ----
    for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE) {
        float v = (float)x[tid + ind];
        float qf = v * inv_scale;

        int q = round_half_away_from_zero(qf);

        if (q > 127) {
            q = 127;
        }
        if (q < -127) {
            q = -127;
        }

        x_packed[tid + ind] = (int8_t)q;
    }
}
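// Illustration (not part of the committed file): with max|x| = 12.7 the symmetric
// scale is 12.7 / 127 = 0.1, so x = -12.7 -> -127 and x = 0.25 -> round(2.5) = 3
// (round_half_away_from_zero rounds the tie up, whereas round-to-even would give 2).
// The clamp range is [-127, 127] here, so -128 is never produced in symmetric mode.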
template <typename T>
struct MaxOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return max(a, b);
    }
};
template <typename T>
struct MinOp {
    __device__ __forceinline__ T operator()(const T &a, const T &b) const {
        return min(a, b);
    }
};
template <template <typename> class ReductionOp, typename T,
          int thread_group_width>
__inline__ __device__ T WarpAllReduce(T val) {
    for (int mask = thread_group_width / 2; mask > 0; mask /= 2) {
        val = ReductionOp<T>()(val, __shfl_xor_sync(0xffffffff, val, mask));
    }
    return val;
}
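// Note added for this write-up: WarpAllReduce is a butterfly (shuffle-XOR) reduction.
// For thread_group_width = 32 the masks are 16, 8, 4, 2, 1, so after five exchanges
// every lane in the group holds the reduction over all 32 lanes, with no shared memory.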
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8Kernel(
    int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x,
    int M, int K) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int tid = otherIdx * K;

    if (otherIdx < M) {

        __shared__ float max_total[BLOCK_SIZE_y];
        __shared__ float min_total[BLOCK_SIZE_y];

        float max_data = -__FLT_MAX__;
        float min_data = __FLT_MAX__;

        // ---- reduce max/min ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            max_data = fmaxf(max_data, v);
            min_data = fminf(min_data, v);
        }

        max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);
        min_data = WarpAllReduce<MinOp, float, BLOCK_SIZE_x>(min_data);

        if (threadIdx.x == 0) {
            max_total[threadIdx.y] = max_data;
            min_total[threadIdx.y] = min_data;
        }
        __syncthreads();

        // ---- compute scale/zero in float (to match the Python float32 reference) ----
        float max_f = max_total[threadIdx.y];
        float min_f = min_total[threadIdx.y];

        float scale = (max_f - min_f) / 255.0f;
        if (scale < 1e-8f) {
            scale = 1e-8f;
        }

        float inv_scale = 1.0f / scale;
        float zero = -min_f * inv_scale - 128.0f;

        x_scale[otherIdx] = scale;
        x_zero[otherIdx] = zero;

        // ---- quantize in float with half-away-from-zero rounding ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            float qf = v * inv_scale + zero;

            int q = round_half_away_from_zero(qf);

            if (q > 127) {
                q = 127;
            }
            if (q < -128) {
                q = -128;
            }

            x_packed[tid + ind] = (int8_t)q;
        }
    }
}
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__device__ void warpPerChannelQuantI8SymKernel(
    int8_t *x_packed, float *x_scale, const Tdata *x,
    int M, int K) {
    int otherIdx = blockIdx.x * blockDim.y + threadIdx.y;
    int tid = otherIdx * K;

    if (otherIdx < M) {

        __shared__ float max_total[BLOCK_SIZE_y];

        float max_data = -__FLT_MAX__;

        // ---- reduce max of |x| ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = fabsf((float)x[tid + ind]);
            max_data = fmaxf(max_data, v);
        }

        max_data = WarpAllReduce<MaxOp, float, BLOCK_SIZE_x>(max_data);

        if (threadIdx.x == 0) {
            max_total[threadIdx.y] = max_data;
        }
        __syncthreads();

        // ---- compute scale in float (to match the Python float32 reference) ----
        float max_f = max_total[threadIdx.y];

        float scale = max_f / 127.0f;
        if (scale < 1e-8f) {
            scale = 1e-8f;
        }

        float inv_scale = 1.0f / scale;

        x_scale[otherIdx] = scale;

        // ---- quantize in float with half-away-from-zero rounding ----
        for (int ind = threadIdx.x; ind < K; ind += BLOCK_SIZE_x) {
            float v = (float)x[tid + ind];
            float qf = v * inv_scale;

            int q = round_half_away_from_zero(qf);

            if (q > 127) {
                q = 127;
            }
            if (q < -127) {
                q = -127;
            }

            x_packed[tid + ind] = (int8_t)q;
        }
    }
}

#endif // __PERCHANNEL_QUANTINT8_KERNEL_CUH__
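
To make the grid mapping concrete, a hedged launch sketch (added here, not part of the commit): the block kernels are written for one CUDA block per row, and the warp kernels for one thread group (one threadIdx.y slice) per row. The __global__ wrappers and launch shapes below are assumptions about how the __device__ functions above would be driven.

// Hypothetical wrappers, for illustration only.
template <typename Tdata, unsigned int BLOCK_SIZE>
__global__ void perChannelQuantI8BlockKernel(int8_t *x_packed, float *x_scale, float *x_zero,
                                             const Tdata *x, int M, int K) {
    blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}

template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
__global__ void perChannelQuantI8WarpKernel(int8_t *x_packed, float *x_scale, float *x_zero,
                                            const Tdata *x, int M, int K) {
    warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}

// Possible launches (assumed shapes):
//   block variant, one block per row:
//     perChannelQuantI8BlockKernel<float, 256><<<M, 256, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
//   warp variant, 32 threads cooperate on a row, 4 rows per block:
//     dim3 block(32, 4);
//     dim3 grid((M + block.y - 1) / block.y);
//     perChannelQuantI8WarpKernel<float, 32, 4><<<grid, block, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);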
Lines changed: 59 additions & 0 deletions
@@ -0,0 +1,59 @@
#ifndef __PER_CHANNEL_QUANT_INT8_INFO_H__
#define __PER_CHANNEL_QUANT_INT8_INFO_H__

#include "../../../../utils.h"
#include "../../../operator.h"
#include "../../../tensor.h"

namespace op::per_channel_quant_int8 {

class PerChannelQuantI8Info {
private:
    PerChannelQuantI8Info() = default;

public:
    infiniDtype_t dtype, packed_type;
    size_t M, K;

    static utils::Result<PerChannelQuantI8Info> createPerChannelQuantI8Info(
        infiniopTensorDescriptor_t x_packed_desc,
        infiniopTensorDescriptor_t x_scale_desc,
        infiniopTensorDescriptor_t x_zero_desc,
        infiniopTensorDescriptor_t x_desc) {

        CHECK_OR_RETURN(
            x_packed_desc != nullptr && x_scale_desc != nullptr && x_desc != nullptr,
            INFINI_STATUS_NULL_POINTER);

        const infiniDtype_t dtype = x_desc->dtype();
        const infiniDtype_t packed_type = x_packed_desc->dtype();

        CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32);
        CHECK_DTYPE(packed_type, INFINI_DTYPE_I8);

        CHECK_OR_RETURN(x_desc->ndim() == 2
                            && x_packed_desc->ndim() == 2
                            && x_scale_desc->ndim() == 2,
                        INFINI_STATUS_BAD_TENSOR_SHAPE);

        size_t M = x_desc->dim(0);
        size_t K = x_desc->dim(1);
        CHECK_OR_RETURN(M == x_packed_desc->dim(0)
                            && K == x_packed_desc->dim(1)
                            && M == x_scale_desc->dim(0)
                            && 1 == x_scale_desc->dim(1),
                        INFINI_STATUS_BAD_TENSOR_SHAPE);
        return utils::Result<PerChannelQuantI8Info>(PerChannelQuantI8Info{
            dtype,
            packed_type,
            M,
            K,
        });
    }
};

} // namespace op::per_channel_quant_int8

#endif // __PER_CHANNEL_QUANT_INT8_INFO_H__
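
For concreteness, the shape and dtype contract these checks are meant to enforce, spelled out as an example. The layout below is inferred from the checks and the kernel signatures, not stated explicitly in the commit:

// x:        (M, K)  F16 / BF16 / F32   -- input, one quantization channel per row
// x_packed: (M, K)  I8                 -- quantized values
// x_scale:  (M, 1)  float              -- per-row scale (the CUDA kernels take float *)
// x_zero:   (M, 1)  float              -- per-row zero point (asymmetric variant only)
//
// e.g. M = 4096 rows of K = 4096 values each produce 4096 scale/zero pairs.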
