Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions src/devices/cpu/common_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,22 +44,19 @@ uint16_t f32_to_f16(float val) {
int32_t exponent = ((f32 >> 23) & 0xFF) - 127;// Extract and de-bias the exponent
uint32_t mantissa = f32 & 0x7FFFFF; // Extract the mantissa (fraction part)

if (exponent == 128) {// Special case for Inf and NaN
if (mantissa != 0) {
// NaN
return sign | 0x7C00 | (mantissa >> 13);// Convert the NaN payload
} else {
// Infinity
return sign | 0x7C00;
if (exponent >= 31) {// Special cases for Inf and NaN
// NaN
if (exponent == 128 && mantissa != 0) {
return sign | 0x7E00;
}
} else if (exponent > 15) { // Overflow: Larger than float16 max
return sign | 0x7C00; // Return infinity
} else if (exponent >= -14) {// Normalized float16
// Infinity
return sign | 0x7C00;
} else if (exponent >= -14) {// Normalized case
return sign | ((exponent + 15) << 10) | (mantissa >> 13);
} else if (exponent >= -24) { // Subnormal float16 (leading denormals)
mantissa |= 0x800000; // Add implicit leading 1
int32_t shift = -exponent - 1;// Calculate shift for subnormal numbers
return sign | (mantissa >> (13 + shift));
} else if (exponent >= -24) {
mantissa |= 0x800000;// Add implicit leading 1
mantissa >>= (-14 - exponent);
return sign | (mantissa >> 13);
} else {
// Too small for subnormal: return signed zero
return sign;
Expand Down
6 changes: 5 additions & 1 deletion src/ops/matmul/cpu/matmul_cpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,11 @@ infiniopStatus_t matmul_cpu(MatmulCpuDescriptor_t desc, void *c, float beta, voi
}
}
if constexpr (std::is_same<Tdata, uint16_t>::value) {
*c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum);
if (beta == 0) {
*c_ = f32_to_f16(alpha * sum);
} else {
*c_ = f32_to_f16(beta * f16_to_f32(*c_) + alpha * sum);
}
} else {
*c_ = beta * (*c_) + alpha * sum;
}
Expand Down
Loading