From c109e0d48958afa002ddd24ea4d2795f00ddbd28 Mon Sep 17 00:00:00 2001
From: John <1229775764@qq.com>
Date: Thu, 10 Jul 2025 22:13:40 +0800
Subject: [PATCH] sqrt support fp8

---
 .../src/deepx/tensorfunc/cuda_math.cuh | 22 +++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
index ff7117f3..9f2c0b1e 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include <cuda_fp8.h>
 #include
 
 namespace deepx::tensorfunc
@@ -37,6 +38,27 @@
         *out = hsqrt(*a);
     }
 
+    template <>
+    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
+    {
+        __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E4M3)); // widen the FP8 (E4M3) bits to FP16
+        __half result_fp16 = hsqrt(input_fp16); // CUDA built-in half-precision square root
+        out->__x = __nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E4M3); // round back to FP8 (E4M3)
+    }
+
+    template <>
+    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
+    {
+        __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E5M2)); // 1. widen the FP8 (E5M2) bits to FP16
+
+        // 2. compute the square root in FP16
+        __half result_fp16 = hsqrt(input_fp16);
+
+        // 3. round back to FP8 (E5M2), saturating finite values
+        out->__x = __nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E5M2);
+    }
+
+
     // pow
     template <typename T>
     __device__ __forceinline__ void deepx_pow(const T *a, const T *b, T *out);
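
For reviewers, a minimal usage sketch (not part of this patch): an elementwise kernel that exercises the new __nv_fp8_e4m3 specialization of deepx_sqrt. The kernel name, grid layout, and include path are assumptions made for illustration; only deepx_sqrt, the FP8 conversions, and cuda_math.cuh come from the patch itself.

// Hypothetical caller, assuming src/ is on the include path.
#include <cuda_fp8.h>
#include "deepx/tensorfunc/cuda_math.cuh"

// Illustrative elementwise kernel: one thread per element.
__global__ void sqrt_fp8_e4m3_kernel(const __nv_fp8_e4m3 *in, __nv_fp8_e4m3 *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        // Template argument deduction picks the __nv_fp8_e4m3 specialization added above.
        deepx::tensorfunc::deepx_sqrt(in + i, out + i);
    }
}

The same call resolves to the __nv_fp8_e5m2 specialization when the pointers are __nv_fp8_e5m2, since both are full specializations of the deepx_sqrt primary template.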