From d0dd9a3558a9dce28747e23eb33d61bbfe09ecb0 Mon Sep 17 00:00:00 2001 From: kilinchange Date: Fri, 15 Nov 2024 17:28:48 +0800 Subject: [PATCH] fix(matmul): fix cpu matmul --- src/devices/cpu/common_cpu.cc | 25 +++++++++++-------------- src/ops/matmul/cpu/matmul_cpu.cc | 6 +++++- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/devices/cpu/common_cpu.cc b/src/devices/cpu/common_cpu.cc index b5b5f0fd..3e097446 100644 --- a/src/devices/cpu/common_cpu.cc +++ b/src/devices/cpu/common_cpu.cc @@ -44,22 +44,19 @@ uint16_t f32_to_f16(float val) { int32_t exponent = ((f32 >> 23) & 0xFF) - 127;// Extract and de-bias the exponent uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part) - if (exponent == 128) {// Special case for Inf and NaN - if (mantissa != 0) { - // NaN - return sign | 0x7C00 | (mantissa >> 13);// Convert the NaN payload - } else { - // Infinity - return sign | 0x7C00; + if (exponent >= 31) {// Special cases for Inf and NaN + // NaN + if (exponent == 128 && mantissa != 0) { + return sign | 0x7E00; } - } else if (exponent > 15) { // Overflow: Larger than float16 max - return sign | 0x7C00; // Return infinity - } else if (exponent >= -14) {// Normalized float16 + // Infinity + return sign | 0x7C00; + } else if (exponent >= -14) {// Normalized case return sign | ((exponent + 15) << 10) | (mantissa >> 13); - } else if (exponent >= -24) { // Subnormal float16 (leading denormals) - mantissa |= 0x800000; // Add implicit leading 1 - int32_t shift = -exponent - 1;// Calculate shift for subnormal numbers - return sign | (mantissa >> (13 + shift)); + } else if (exponent >= -24) { + mantissa |= 0x800000;// Add implicit leading 1 + mantissa >>= (-14 - exponent); + return sign | (mantissa >> 13); } else { // Too small for subnormal: return signed zero return sign; diff --git a/src/ops/matmul/cpu/matmul_cpu.cc b/src/ops/matmul/cpu/matmul_cpu.cc index b6148852..2dcc9d2e 100644 --- a/src/ops/matmul/cpu/matmul_cpu.cc +++ b/src/ops/matmul/cpu/matmul_cpu.cc @@ -64,7 +64,11 @@ infiniopStatus_t matmul_cpu(MatmulCpuDescriptor_t desc, void *c, float beta, voi } } if constexpr (std::is_same::value) { - *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); + if (beta == 0) { + *c_ = f32_to_f16(alpha * sum); + } else { + *c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum); + } } else { *c_ = beta * (*c_) + alpha * sum; }