|
| 1 | +#include "../../../devices/moore/moore_common.h" |
| 2 | +#include "add_rms_norm_moore.h" |
| 3 | + |
| 4 | +#include "../../../devices/moore/moore_kernel_common.h" |
| 5 | +#include <cub/block/block_reduce.cuh> |
| 6 | + |
| 7 | +#include "../../../reduce/cuda/reduce.cuh" |
| 8 | + |
| 9 | +#include "../cuda/kernel.cuh" |
| 10 | + |
| 11 | +// Kernel function template for add_rms_norm on Moore platform |
| 12 | +template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight> |
| 13 | +INFINIOP_MOORE_KERNEL add_rmsnormKernel( |
| 14 | + Tdata *__restrict__ y, |
| 15 | + Tdata *__restrict__ residual_out, |
| 16 | + ptrdiff_t stride_y_batch, |
| 17 | + ptrdiff_t stride_y_nhead, |
| 18 | + ptrdiff_t stride_residual_out_batch, |
| 19 | + ptrdiff_t stride_residual_out_nhead, |
| 20 | + const Tdata *__restrict__ a, |
| 21 | + ptrdiff_t stride_a_batch, |
| 22 | + ptrdiff_t stride_a_nhead, |
| 23 | + const Tdata *__restrict__ b, |
| 24 | + ptrdiff_t stride_b_batch, |
| 25 | + ptrdiff_t stride_b_nhead, |
| 26 | + const Tweight *__restrict__ w, |
| 27 | + size_t nhead, |
| 28 | + size_t dim, |
| 29 | + float epsilon) { |
| 30 | + add_rmsnormBlock<BLOCK_SIZE, Tcompute>( |
| 31 | + y, residual_out, |
| 32 | + stride_y_batch, stride_y_nhead, |
| 33 | + stride_residual_out_batch, stride_residual_out_nhead, |
| 34 | + a, stride_a_batch, stride_a_nhead, |
| 35 | + b, stride_b_batch, stride_b_nhead, |
| 36 | + w, nhead, dim, epsilon); |
| 37 | +} |
| 38 | + |
| 39 | +namespace op::add_rms_norm::moore { |
| 40 | + |
| 41 | +// Internal opaque structure for Moore device handle |
| 42 | +struct Descriptor::Opaque { |
| 43 | + std::shared_ptr<device::moore::Handle::Internal> internal; |
| 44 | +}; |
| 45 | + |
| 46 | +// Destructor |
| 47 | +Descriptor::~Descriptor() { |
| 48 | + delete _opaque; |
| 49 | +} |
| 50 | + |
| 51 | +// Create descriptor for add_rms_norm operator |
| 52 | +infiniStatus_t Descriptor::create( |
| 53 | + infiniopHandle_t handle, |
| 54 | + Descriptor **desc_ptr, |
| 55 | + infiniopTensorDescriptor_t y_desc, |
| 56 | + infiniopTensorDescriptor_t a_desc, |
| 57 | + infiniopTensorDescriptor_t b_desc, |
| 58 | + infiniopTensorDescriptor_t weight_desc, |
| 59 | + float epsilon, |
| 60 | + infiniopTensorDescriptor_t residual_out_desc) { |
| 61 | + auto result = AddRMSNormInfo::create(y_desc, a_desc, b_desc, weight_desc, epsilon, residual_out_desc); |
| 62 | + CHECK_RESULT(result); |
| 63 | + auto info = result.take(); |
| 64 | + |
| 65 | + *desc_ptr = new Descriptor( |
| 66 | + new Opaque{reinterpret_cast<device::moore::Handle *>(handle)->internal()}, |
| 67 | + std::move(info), |
| 68 | + 0, |
| 69 | + handle->device, handle->device_id); |
| 70 | + return INFINI_STATUS_SUCCESS; |
| 71 | +} |
| 72 | + |
| 73 | +// Launch kernel with different data types |
| 74 | +template <unsigned int BLOCK_SIZE> |
| 75 | +infiniStatus_t launchKernel( |
| 76 | + uint32_t batch_size, size_t nhead, size_t dim, |
| 77 | + void *y, infiniDtype_t atype, ptrdiff_t stride_y_batch, ptrdiff_t stride_y_nhead, |
| 78 | + void *residual_out, ptrdiff_t stride_residual_out_batch, ptrdiff_t stride_residual_out_nhead, |
| 79 | + const void *a, ptrdiff_t stride_a_batch, ptrdiff_t stride_a_nhead, |
| 80 | + const void *b, ptrdiff_t stride_b_batch, ptrdiff_t stride_b_nhead, |
| 81 | + const void *w, infiniDtype_t wtype, |
| 82 | + float epsilon, |
| 83 | + musaStream_t musa_stream) { |
| 84 | + |
| 85 | +#define LAUNCH_KERNEL(Tdata, Tweight, Tcompute) \ |
| 86 | + add_rmsnormKernel<BLOCK_SIZE, Tcompute, Tdata, Tweight><<<batch_size * nhead, BLOCK_SIZE, 0, musa_stream>>>( \ |
| 87 | + reinterpret_cast<Tdata *>(y), \ |
| 88 | + reinterpret_cast<Tdata *>(residual_out), \ |
| 89 | + stride_y_batch, \ |
| 90 | + stride_y_nhead, \ |
| 91 | + stride_residual_out_batch, \ |
| 92 | + stride_residual_out_nhead, \ |
| 93 | + reinterpret_cast<const Tdata *>(a), \ |
| 94 | + stride_a_batch, \ |
| 95 | + stride_a_nhead, \ |
| 96 | + reinterpret_cast<const Tdata *>(b), \ |
| 97 | + stride_b_batch, \ |
| 98 | + stride_b_nhead, \ |
| 99 | + reinterpret_cast<const Tweight *>(w), \ |
| 100 | + nhead, \ |
| 101 | + dim, \ |
| 102 | + epsilon) |
| 103 | + |
| 104 | + // Handle different data type combinations |
| 105 | + if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { |
| 106 | + LAUNCH_KERNEL(half, half, float); |
| 107 | + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { |
| 108 | + LAUNCH_KERNEL(half, __mt_bfloat16, float); |
| 109 | + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { |
| 110 | + LAUNCH_KERNEL(half, float, float); |
| 111 | + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { |
| 112 | + LAUNCH_KERNEL(__mt_bfloat16, __mt_bfloat16, float); |
| 113 | + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { |
| 114 | + LAUNCH_KERNEL(__mt_bfloat16, half, float); |
| 115 | + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { |
| 116 | + LAUNCH_KERNEL(__mt_bfloat16, float, float); |
| 117 | + } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { |
| 118 | + LAUNCH_KERNEL(float, float, float); |
| 119 | + } else { |
| 120 | + return INFINI_STATUS_BAD_TENSOR_DTYPE; |
| 121 | + } |
| 122 | + |
| 123 | +#undef LAUNCH_KERNEL |
| 124 | + |
| 125 | + return INFINI_STATUS_SUCCESS; |
| 126 | +} |
| 127 | + |
| 128 | +// Main calculation function |
| 129 | +infiniStatus_t Descriptor::calculate( |
| 130 | + void *workspace, size_t workspace_size, |
| 131 | + void *y, const void *a, const void *b, const void *weight, |
| 132 | + void *residual_out, void *stream) const { |
| 133 | + |
| 134 | + // Check workspace size |
| 135 | + if (workspace_size < _workspace_size) { |
| 136 | + return INFINI_STATUS_INSUFFICIENT_WORKSPACE; |
| 137 | + } |
| 138 | + |
| 139 | + // Extract tensor strides and dimensions |
| 140 | + auto stride_a_batch = _info.a_strides[0]; |
| 141 | + auto stride_a_nhead = _info.a_strides[1]; |
| 142 | + auto stride_b_batch = _info.b_strides[0]; |
| 143 | + auto stride_b_nhead = _info.b_strides[1]; |
| 144 | + auto stride_y_batch = _info.y_strides[0]; |
| 145 | + auto stride_y_nhead = _info.y_strides[1]; |
| 146 | + auto stride_residual_out_batch = _info.residual_out_strides[0]; |
| 147 | + auto stride_residual_out_nhead = _info.residual_out_strides[1]; |
| 148 | + auto dim = _info.dim(); |
| 149 | + uint32_t batch_size = static_cast<uint32_t>(_info.shape[0]); |
| 150 | + size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; |
| 151 | + auto musa_stream = reinterpret_cast<musaStream_t>(stream); |
| 152 | + |
| 153 | + // Launch kernel with appropriate block size based on device capability |
| 154 | + if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_1024) { |
| 155 | + CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_1024>( |
| 156 | + batch_size, nhead, dim, |
| 157 | + y, _info.atype, stride_y_batch, stride_y_nhead, |
| 158 | + residual_out, stride_residual_out_batch, stride_residual_out_nhead, |
| 159 | + a, stride_a_batch, stride_a_nhead, |
| 160 | + b, stride_b_batch, stride_b_nhead, |
| 161 | + weight, _info.wtype, _info.epsilon, musa_stream)); |
| 162 | + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_512) { |
| 163 | + CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_512>( |
| 164 | + batch_size, nhead, dim, |
| 165 | + y, _info.atype, stride_y_batch, stride_y_nhead, |
| 166 | + residual_out, stride_residual_out_batch, stride_residual_out_nhead, |
| 167 | + a, stride_a_batch, stride_a_nhead, |
| 168 | + b, stride_b_batch, stride_b_nhead, |
| 169 | + weight, _info.wtype, _info.epsilon, musa_stream)); |
| 170 | + } else if (_opaque->internal->maxThreadsPerBlock() == MOORE_BLOCK_SIZE_2048) { |
| 171 | + CHECK_STATUS(launchKernel<MOORE_BLOCK_SIZE_2048>( |
| 172 | + batch_size, nhead, dim, |
| 173 | + y, _info.atype, stride_y_batch, stride_y_nhead, |
| 174 | + residual_out, stride_residual_out_batch, stride_residual_out_nhead, |
| 175 | + a, stride_a_batch, stride_a_nhead, |
| 176 | + b, stride_b_batch, stride_b_nhead, |
| 177 | + weight, _info.wtype, _info.epsilon, musa_stream)); |
| 178 | + } else { |
| 179 | + return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; |
| 180 | + } |
| 181 | + return INFINI_STATUS_SUCCESS; |
| 182 | +} |
| 183 | +} // namespace op::add_rms_norm::moore |
0 commit comments