diff --git a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
index ff7117f..9f2c0b1 100644
--- a/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
+++ b/excuter/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
@@ -4,6 +4,7 @@
 #include
 #include
 #include
+#include <cuda_fp8.h>
 #include
 
 namespace deepx::tensorfunc
@@ -37,6 +38,27 @@ namespace deepx::tensorfunc
         *out = hsqrt(*a);
     }
 
+    template <>
+    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
+    {
+        __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E4M3)); // reinterpret raw E4M3 bits, widen to FP16
+        __half result_fp16 = hsqrt(input_fp16);                                 // CUDA built-in half-precision square root
+        out->__x = __nv_cvt_halfraw_to_fp8(__half_raw(result_fp16), __NV_SATFINITE, __NV_E4M3); // narrow back to E4M3, saturate to finite
+    }
+
+    template <>
+    __device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
+    {
+        __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E5M2)); // 1. reinterpret raw E5M2 bits, widen to FP16
+
+        // 2. compute the square root in half precision
+        __half result_fp16 = hsqrt(input_fp16);
+
+        // 3. convert back to FP8 (E5M2 format), saturating to finite values
+        out->__x = __nv_cvt_halfraw_to_fp8(__half_raw(result_fp16), __NV_SATFINITE, __NV_E5M2);
+    }
+
+
     // pow
     template <typename T>
     __device__ __forceinline__ void deepx_pow(const T *a, const T *b, T *out);
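
Note (not part of the patch): the specializations go through the raw FP8 storage byte (`__x`) and the `__nv_cvt_fp8_to_halfraw` / `__nv_cvt_halfraw_to_fp8` intrinsics, because the class-level casts of `__nv_fp8_e4m3` / `__nv_fp8_e5m2` to and from integer types perform numeric conversions, not bit reinterpretation. The sketch below is a minimal, self-contained way the E4M3 path could be exercised; `deepx_sqrt_e4m3`, `sqrt_e4m3_kernel`, and the launch configuration are illustrative names and choices of mine, not code from this repository. It assumes an nvcc build with half-precision math available (e.g. `nvcc -arch=sm_89 sqrt_fp8_demo.cu`).

```cuda
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cstdio>

// Hypothetical standalone copy of the E4M3 specialization above, so this sketch compiles on its own.
__device__ __forceinline__ void deepx_sqrt_e4m3(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
{
    __half input_fp16 = __half(__nv_cvt_fp8_to_halfraw(a->__x, __NV_E4M3)); // raw E4M3 bits -> FP16
    __half result_fp16 = hsqrt(input_fp16);                                 // half-precision sqrt
    out->__x = __nv_cvt_halfraw_to_fp8(__half_raw(result_fp16), __NV_SATFINITE, __NV_E4M3); // back to E4M3
}

// Elementwise kernel: one thread per element (illustrative, no grid-stride loop).
__global__ void sqrt_e4m3_kernel(const __nv_fp8_e4m3 *in, __nv_fp8_e4m3 *out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        deepx_sqrt_e4m3(&in[i], &out[i]);
}

int main()
{
    const int n = 8;
    __nv_fp8_e4m3 h_in[n], h_out[n];
    for (int i = 0; i < n; ++i)
        h_in[i] = __nv_fp8_e4m3(float(i)); // 0..7, rounded to E4M3 on the host

    __nv_fp8_e4m3 *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(__nv_fp8_e4m3));
    cudaMalloc(&d_out, n * sizeof(__nv_fp8_e4m3));
    cudaMemcpy(d_in, h_in, n * sizeof(__nv_fp8_e4m3), cudaMemcpyHostToDevice);

    sqrt_e4m3_kernel<<<1, 32>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(__nv_fp8_e4m3), cudaMemcpyDeviceToHost);

    for (int i = 0; i < n; ++i)
        printf("sqrt(%g) ~= %g\n", float(h_in[i]), float(h_out[i])); // widened to float for printing

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```

Because both FP8 formats are narrower than FP16, computing the sqrt in half precision and rounding once at the end (with `__NV_SATFINITE`) loses no accuracy relative to the FP8 result itself; the E5M2 path differs only in the interpretation constant.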