Skip to content

Commit 1e8bb49

Browse files
committed
issue/884 - add_rms_norm on metax and moore
1 parent 3b5afff commit 1e8bb49

File tree

5 files changed

+392
-4
lines changed

5 files changed

+392
-4
lines changed
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Public declaration of the add_rms_norm operator descriptor for the Metax
// platform. DESCRIPTOR(metax) is a macro from ../add_rms_norm.h that
// presumably expands to the Descriptor class used by
// op::add_rms_norm::metax (see the .maca implementation) — confirm there.
#ifndef __ADD_RMS_NORM_METAX_CUH__
#define __ADD_RMS_NORM_METAX_CUH__

#include "../add_rms_norm.h"

DESCRIPTOR(metax)

#endif // __ADD_RMS_NORM_METAX_CUH__
Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
#include "../../../devices/metax/metax_common.h"
2+
#include "add_rms_norm_metax.cuh"
3+
4+
#include "../../../devices/metax/metax_kernel_common.h"
5+
#include <cub/block/block_reduce.cuh>
6+
7+
#include "../../../reduce/cuda/reduce.cuh"
8+
9+
#include "../cuda/kernel.cuh"
10+
11+
// Kernel function template for add_rms_norm on Metax platform.
//
// Launch layout (set up in launchKernel below): grid = batch_size * nhead
// blocks, BLOCK_SIZE threads per block, no dynamic shared memory —
// presumably one block handles one (batch, head) row of length `dim`.
// NOTE(review): the actual math (residual add + RMS normalization,
// accumulating in Tcompute) lives in add_rmsnormBlock in
// ../cuda/kernel.cuh, which is not visible here — confirm semantics there.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_METAX_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,                    // normalized output
    Tdata *__restrict__ residual_out,         // second output written by the shared kernel
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,              // first input operand
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,              // second input operand (residual)
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,            // normalization weights
    size_t nhead,
    size_t dim,                               // row length being normalized
    float epsilon) {                          // numerical-stability term
    // Thin wrapper: forwards everything to the platform-shared block kernel.
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}
38+
39+
namespace op::add_rms_norm::metax {
40+
41+
// Internal opaque structure for Metax device handle.
// Holds a shared reference to the device handle internals so calculate()
// can query device properties (e.g. maxThreadsPerBlock) after creation.
struct Descriptor::Opaque {
    std::shared_ptr<device::metax::Handle::Internal> internal;
};
45+
46+
// Destructor — frees the Opaque allocated in create(); the shared_ptr inside
// releases the device-handle internals reference.
Descriptor::~Descriptor() {
    delete _opaque;
}
50+
51+
// Create descriptor for the add_rms_norm operator on Metax.
//
// Validates the tensor descriptors via AddRMSNormInfo::create and, on
// success, allocates a Descriptor that carries the validated info plus a
// shared reference to the device-handle internals. Workspace size is 0:
// this operator needs no scratch memory.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    // Collect and validate shapes/strides/dtypes for all operands.
    auto info_result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(info_result);

    // Keep the device internals alive for later property queries.
    auto metax_handle = reinterpret_cast<device::metax::Handle *>(handle);
    auto opaque = new Opaque{metax_handle->internal()};

    *desc_ptr = new Descriptor(
        opaque,
        info_result.take(),
        /*workspace_size=*/0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
72+
73+
// Launch kernel with different data types.
//
// Dispatches on the (activation dtype, weight dtype) pair and instantiates
// add_rmsnormKernel for the matching concrete types. Every supported
// combination accumulates in float (Tcompute). Launch configuration:
// grid = batch_size * nhead blocks, BLOCK_SIZE threads, no dynamic shared
// memory, asynchronous on `stream`.
// NOTE(review): no hc-runtime error check follows the launch; launch-config
// failures would go unreported here — confirm callers handle this.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    hcStream_t stream) {

// Instantiates and launches the kernel for one concrete (Tdata, Tweight, Tcompute) triple.
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, stream>>>( \
        reinterpret_cast<Tdata *>(y), \
        reinterpret_cast<Tdata *>(residual_out), \
        stride_y_batch, \
        stride_y_nhead, \
        stride_residual_out_batch, \
        stride_residual_out_nhead, \
        reinterpret_cast<const Tdata *>(a), \
        stride_a_batch, \
        stride_a_nhead, \
        reinterpret_cast<const Tdata *>(b), \
        stride_b_batch, \
        stride_b_nhead, \
        reinterpret_cast<const Tweight *>(w), \
        nhead, \
        dim, \
        epsilon)

    // Handle different data type combinations following Metax pattern.
    // NOTE(review): F32 activations are only paired with F32 weights;
    // (F32, F16) and (F32, BF16) fall through to the dtype error — confirm
    // this asymmetry is intentional.
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__hpcc_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __hpcc_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__hpcc_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}
127+
128+
// Main calculation function.
//
// Launches the fused residual-add + RMS-norm kernel asynchronously on the
// caller's stream. Tensors are addressed as (batch, nhead, dim): shape[0]
// is the batch, shape[1] the head count when rank > 2, otherwise nhead
// collapses to 1.
// NOTE(review): in the rank-2 case *_strides[1] is still forwarded as the
// per-head stride; with nhead == 1 the head index is presumably always 0 so
// the value should be unused — confirm against add_rmsnormBlock.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream_) const {

    // Check workspace size (this operator requested 0 bytes at create()).
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Extract tensor strides and dimensions.
    auto stride_a_batch = _info.a_strides[0];
    auto stride_a_nhead = _info.a_strides[1];
    auto stride_b_batch = _info.b_strides[0];
    auto stride_b_nhead = _info.b_strides[1];
    auto stride_y_batch = _info.y_strides[0];
    auto stride_y_nhead = _info.y_strides[1];
    auto stride_residual_out_batch = _info.residual_out_strides[0];
    auto stride_residual_out_nhead = _info.residual_out_strides[1];
    auto dim = _info.dim();
    uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    auto stream = reinterpret_cast<hcStream_t>(stream_);

    // Launch kernel with appropriate block size based on device capability.
    // Only devices reporting exactly METAX_BLOCK_SIZE_1024 threads per block
    // are supported (the Moore backend in this change also handles 512/2048).
    if (_opaque->internal->maxThreadsPerBlock() == METAX_BLOCK_SIZE_1024) {
        CHECK_STATUS(launchKernel<METAX_BLOCK_SIZE_1024>(
            batch_size, nhead, dim,
            y, _info.atype, stride_y_batch, stride_y_nhead,
            residual_out, stride_residual_out_batch, stride_residual_out_nhead,
            a, stride_a_batch, stride_a_nhead,
            b, stride_b_batch, stride_b_nhead,
            weight, _info.wtype, _info.epsilon, stream));
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
    return INFINI_STATUS_SUCCESS;
}
167+
} // namespace op::add_rms_norm::metax
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
// Public declaration of the add_rms_norm operator descriptor for the Moore
// platform. DESCRIPTOR(moore) is a macro from ../add_rms_norm.h that
// presumably expands to the Descriptor class used by
// op::add_rms_norm::moore (see the .mu implementation) — confirm there.
#ifndef __ADD_RMS_NORM_MOORE_H__
#define __ADD_RMS_NORM_MOORE_H__

#include "../add_rms_norm.h"

DESCRIPTOR(moore)

#endif // __ADD_RMS_NORM_MOORE_H__
Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
#include "../../../devices/moore/moore_common.h"
2+
#include "add_rms_norm_moore.h"
3+
4+
#include "../../../devices/moore/moore_kernel_common.h"
5+
#include <cub/block/block_reduce.cuh>
6+
7+
#include "../../../reduce/cuda/reduce.cuh"
8+
9+
#include "../cuda/kernel.cuh"
10+
11+
// Kernel function template for add_rms_norm on Moore platform.
//
// Launch layout (set up in launchKernel below): grid = batch_size * nhead
// blocks, BLOCK_SIZE threads per block, no dynamic shared memory —
// presumably one block handles one (batch, head) row of length `dim`.
// NOTE(review): the actual math (residual add + RMS normalization,
// accumulating in Tcompute) lives in add_rmsnormBlock in
// ../cuda/kernel.cuh, which is not visible here — confirm semantics there.
template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_MOORE_KERNEL add_rmsnormKernel(
    Tdata *__restrict__ y,                    // normalized output
    Tdata *__restrict__ residual_out,         // second output written by the shared kernel
    ptrdiff_t stride_y_batch,
    ptrdiff_t stride_y_nhead,
    ptrdiff_t stride_residual_out_batch,
    ptrdiff_t stride_residual_out_nhead,
    const Tdata *__restrict__ a,              // first input operand
    ptrdiff_t stride_a_batch,
    ptrdiff_t stride_a_nhead,
    const Tdata *__restrict__ b,              // second input operand (residual)
    ptrdiff_t stride_b_batch,
    ptrdiff_t stride_b_nhead,
    const Tweight *__restrict__ w,            // normalization weights
    size_t nhead,
    size_t dim,                               // row length being normalized
    float epsilon) {                          // numerical-stability term
    // Thin wrapper: forwards everything to the platform-shared block kernel.
    add_rmsnormBlock<BLOCK_SIZE, Tcompute>(
        y, residual_out,
        stride_y_batch, stride_y_nhead,
        stride_residual_out_batch, stride_residual_out_nhead,
        a, stride_a_batch, stride_a_nhead,
        b, stride_b_batch, stride_b_nhead,
        w, nhead, dim, epsilon);
}
38+
39+
namespace op::add_rms_norm::moore {
40+
41+
// Internal opaque structure for Moore device handle.
// Holds a shared reference to the device handle internals so calculate()
// can query device properties (e.g. maxThreadsPerBlock) after creation.
struct Descriptor::Opaque {
    std::shared_ptr<device::moore::Handle::Internal> internal;
};
45+
46+
// Destructor — frees the Opaque allocated in create(); the shared_ptr inside
// releases the device-handle internals reference.
Descriptor::~Descriptor() {
    delete _opaque;
}
50+
51+
// Create descriptor for the add_rms_norm operator on Moore.
//
// Validates the tensor descriptors via AddRMSNormInfo::create and, on
// success, allocates a Descriptor that carries the validated info plus a
// shared reference to the device-handle internals. Workspace size is 0:
// this operator needs no scratch memory.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t weight_desc,
    float epsilon,
    infiniopTensorDescriptor_t residual_out_desc) {
    // Collect and validate shapes/strides/dtypes for all operands.
    auto info_result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc);
    CHECK_RESULT(info_result);

    // Keep the device internals alive for later property queries.
    auto moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto opaque = new Opaque{moore_handle->internal()};

    *desc_ptr = new Descriptor(
        opaque,
        info_result.take(),
        /*workspace_size=*/0,
        handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
72+
73+
// Launch kernel with different data types.
//
// Dispatches on the (activation dtype, weight dtype) pair and instantiates
// add_rmsnormKernel for the matching concrete types. Every supported
// combination accumulates in float (Tcompute). Launch configuration:
// grid = batch_size * nhead blocks, BLOCK_SIZE threads, no dynamic shared
// memory, asynchronous on `musa_stream`.
// NOTE(review): no musa-runtime error check follows the launch; launch-config
// failures would go unreported here — confirm callers handle this.
template <unsigned int BLOCK_SIZE>
infiniStatus_t launchKernel(
    uint32_t batch_size, size_t nhead, size_t dim,
    void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead,
    void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead,
    const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead,
    const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead,
    const void *w, infiniDtype_t wtype,
    float epsilon,
    musaStream_t musa_stream) {

// Instantiates and launches the kernel for one concrete (Tdata, Tweight, Tcompute) triple.
#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \
    add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \
        reinterpret_cast<Tdata *>(y), \
        reinterpret_cast<Tdata *>(residual_out), \
        stride_y_batch, \
        stride_y_nhead, \
        stride_residual_out_batch, \
        stride_residual_out_nhead, \
        reinterpret_cast<const Tdata *>(a), \
        stride_a_batch, \
        stride_a_nhead, \
        reinterpret_cast<const Tdata *>(b), \
        stride_b_batch, \
        stride_b_nhead, \
        reinterpret_cast<const Tweight *>(w), \
        nhead, \
        dim, \
        epsilon)

    // Handle different data type combinations.
    // NOTE(review): F32 activations are only paired with F32 weights;
    // (F32, F16) and (F32, BF16) fall through to the dtype error — confirm
    // this asymmetry is intentional (it matches the Metax backend).
    if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(half, half, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(half, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(half, float, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) {
        LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) {
        LAUNCH_KERNEL(__mt_bfloat16, half, float);
    } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(__mt_bfloat16, float, float);
    } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) {
        LAUNCH_KERNEL(float, float, float);
    } else {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

#undef LAUNCH_KERNEL

    return INFINI_STATUS_SUCCESS;
}
127+
128+
// Main calculation function.
//
// Launches the fused residual-add + RMS-norm kernel asynchronously on the
// caller's stream. Tensors are addressed as (batch, nhead, dim): shape[0]
// is the batch, shape[1] the head count when rank > 2, otherwise nhead
// collapses to 1. The kernel block size is picked to match the device's
// reported thread limit; unrecognized limits are rejected.
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *y, const void *a, const void *b, const void *weight,
    void *residual_out, void *stream) const {

    // Reject callers that provided less scratch space than requested.
    if (workspace_size < _workspace_size) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    // Derive launch geometry from the validated tensor info.
    const uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]);
    const size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
    const auto dim = _info.dim();
    const auto musa_stream = reinterpret_cast<musaStream_t>(stream);

// One launch per compile-time block size: BLOCK must be a template argument
// of launchKernel, so each branch needs its own instantiation.
#define LAUNCH_WITH_BLOCK(BLOCK)                                                        \
    CHECK_STATUS(launchKernel<BLOCK>(                                                   \
        batch_size, nhead, dim,                                                         \
        y, _info.atype, _info.y_strides[0], _info.y_strides[1],                         \
        residual_out, _info.residual_out_strides[0], _info.residual_out_strides[1],     \
        a, _info.a_strides[0], _info.a_strides[1],                                      \
        b, _info.b_strides[0], _info.b_strides[1],                                      \
        weight, _info.wtype, _info.epsilon, musa_stream))

    // Match the device's per-block thread limit to a supported kernel size.
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();
    if (max_threads == MOORE_BLOCK_SIZE_1024) {
        LAUNCH_WITH_BLOCK(MOORE_BLOCK_SIZE_1024);
    } else if (max_threads == MOORE_BLOCK_SIZE_512) {
        LAUNCH_WITH_BLOCK(MOORE_BLOCK_SIZE_512);
    } else if (max_threads == MOORE_BLOCK_SIZE_2048) {
        LAUNCH_WITH_BLOCK(MOORE_BLOCK_SIZE_2048);
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }

#undef LAUNCH_WITH_BLOCK

    return INFINI_STATUS_SUCCESS;
}
183+
} // namespace op::add_rms_norm::moore

0 commit comments

Comments
 (0)