
Commit 7b94bd9

[common] Added support of FP4 data type (NVIDIA#1779)
* Added support of FP4 data type
* Refactoring to BitsNum in progress
* Fixed compilation errors. All C++ tests passed
* Fixed a typo
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Added FP4 guard to TMA tensor descriptor data type
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Fixed errors in JAX C++ extensions
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* Removed dummy NVFP4 C++ test file
* Make pytorch changes
* Refactored the code per the review notes. Fixed JAX build error.
* Removed unnecessary static casts
* Typo fix
* Pass correct num bits to create_2D_tensor_map; fixes CI
* inline funcs

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
1 parent e963e4a commit 7b94bd9

23 files changed (+391, -169 lines)

tests/cpp/operator/test_normalization.h

Lines changed: 2 additions & 1 deletion
@@ -67,7 +67,8 @@ inline auto compute_gamma(InputType gamma, const bool zero_centered_gamma, const
   // Remove the use_cudnn check here when it is supported by both backends.
   const bool zero_centered_gamma_in_weight_dtype = use_cudnn && cudnn_zero_centered_gamma_in_weight_dtype;

-  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3>){
+  if constexpr (std::is_same_v<InputType, fp8e5m2> || std::is_same_v<InputType, fp8e4m3> ||
+                std::is_same_v<InputType, fp4e2m1>){
     compute_t g = static_cast<compute_t>(gamma);
     if (zero_centered_gamma) {
       g += static_cast<compute_t>(1.f);
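For context, the branch this hunk extends widens narrow inputs before touching them: FP8 and FP4 types carry no native arithmetic, so gamma is converted to the compute type first and the zero-centered offset is applied there. A minimal standalone sketch of that pattern (not from the commit; widen_gamma is a hypothetical name, and compute_t is assumed to be float):

template <typename InputType>
float widen_gamma(InputType gamma, bool zero_centered_gamma) {
  // Widen to float first; narrow FP8/FP4 types cannot represent
  // gamma + 1 directly.
  float g = static_cast<float>(gamma);
  if (zero_centered_gamma) {
    g += 1.f;  // zero-centered gamma stores (gamma - 1)
  }
  return g;
}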

tests/cpp/test_common.cu

Lines changed: 33 additions & 28 deletions
@@ -45,7 +45,7 @@ bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
   return true;
 }

-size_t typeToSize(DType type) {
+size_t typeToNumBits(DType type) {
   TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
   {
     return TypeInfo<T>::size;
@@ -62,7 +62,8 @@ const std::string &typeName(DType type) {
     {DType::kBFloat16, "bfloat16"},
     {DType::kFloat8E4M3, "float8e4m3"},
     {DType::kFloat8E5M2, "float8e5m2"},
-    {DType::kFloat8E8M0, "float8e8m0"}};
+    {DType::kFloat8E8M0, "float8e8m0"},
+    {DType::kFloat4E2M1, "float4e2m1"}};
   return name_map.at(type);
 }

@@ -109,9 +110,16 @@ size_t DIVUP(const size_t &x, const size_t &y){
 struct scale_inv_meta {
   std::vector<size_t> shape;
   DType type;
-  size_t type_size;
+  size_t type_size_bits;
+  size_t bytes() const noexcept {
+    return (product(shape) * type_size_bits) / 8;
+  }
 };

+size_t bytes(const NVTEShape& shape, const DType type) {
+  return (product(shape) * typeToNumBits(type)) / 8;
+}
+
 NVTEShape convertShape(const std::vector<size_t>& s) {
   return nvte_make_shape(s.data(), s.size());
 }
@@ -122,7 +130,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     scale_inv_meta ret;
     ret.shape = {1};
     ret.type = DType::kFloat32;
-    ret.type_size = sizeof(float);
+    ret.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret, ret};
   }
   if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
@@ -152,8 +160,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat8E8M0;
     ret_colwise.type = DType::kFloat8E8M0;
-    ret_rowwise.type_size = sizeof(uint8_t);
-    ret_colwise.type_size = sizeof(uint8_t);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);

     return {ret_rowwise, ret_colwise};
   }
@@ -179,8 +187,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat32;
     ret_colwise.type = DType::kFloat32;
-    ret_rowwise.type_size = sizeof(float);
-    ret_colwise.type_size = sizeof(float);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);

     return {ret_rowwise, ret_colwise};
   }
@@ -205,8 +213,8 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     }
     ret_rowwise.type = DType::kFloat32;
     ret_colwise.type = DType::kFloat32;
-    ret_rowwise.type_size = sizeof(float);
-    ret_colwise.type_size = sizeof(float);
+    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat32);
+    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat32);
     return {ret_rowwise, ret_colwise};
   }

@@ -222,8 +230,7 @@ Tensor::Tensor(const std::string& name,
   gen_.seed(seed);
   rowwise_ = rowwise;
   columnwise_ = columnwise;
-  size_t s = typeToSize(type);
-  size_t total_size = product(shape) * s;
+  size_t total_size = bytes(shape, type);
   void *dptr_rowwise = nullptr;
   void *dptr_columnwise = nullptr;
   cpu_data_rowwise_ = nullptr;
@@ -305,8 +312,8 @@ Tensor::Tensor(const std::string& name,
   } else {
     auto [rowwise_scale_meta, colwise_scale_meta] =
       get_scales(normalized_shape, tensor_.scaling_mode());
-    auto rowwise_scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
-    auto columnwise_scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+    auto rowwise_scale_size = rowwise_scale_meta.bytes();
+    auto columnwise_scale_size = colwise_scale_meta.bytes();
     auto scale_shape = rowwise_scale_meta.shape;
     auto columnwise_scale_shape = colwise_scale_meta.shape;
     if (rowwise) {
@@ -331,7 +338,7 @@ Tensor::Tensor(const std::string& name,

 void Tensor::to_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
     cudaMemcpy(cpu_data_rowwise_.get(),
                tensor_.get_rowwise_data().data_ptr,
@@ -360,14 +367,14 @@ void Tensor::to_cpu() const {
     auto [rowwise_scale_meta, colwise_scale_meta] =
       get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
-      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+      auto scale_size = rowwise_scale_meta.bytes();
       cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
                  tensor_.get_rowwise_scale_inv().data_ptr,
                  scale_size,
                  cudaMemcpyDeviceToHost);
     }
     if (columnwise_) {
-      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+      auto scale_size = colwise_scale_meta.bytes();
       cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
                  tensor_.get_columnwise_scale_inv().data_ptr,
                  scale_size,
@@ -378,34 +385,32 @@ void Tensor::to_cpu() const {

 void Tensor::from_cpu() const {
   const NVTEShape s = tensor_.shape();
-  const size_t size = product(s) * typeToSize(tensor_.dtype());
+  const size_t size = bytes(s, tensor_.dtype());
   if (rowwise_) {
-    cudaMemcpy(tensor_.get_rowwise_data().data_ptr,
-               cpu_data_rowwise_.get(), size, cudaMemcpyHostToDevice);
+    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
+               cudaMemcpyHostToDevice);
   }
   if (columnwise_) {
-    cudaMemcpy(tensor_.get_columnwise_data().data_ptr,
-               cpu_data_columnwise_.get(), size, cudaMemcpyHostToDevice);
+    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
+               cudaMemcpyHostToDevice);
   }
   if (isFp8Type(dtype())) {
     if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
       if (tensor_.amax() != nullptr){
-        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
-                   cudaMemcpyHostToDevice);
+        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
       }
-      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
-                 cudaMemcpyHostToDevice);
+      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
     }
     auto [rowwise_scale_meta, colwise_scale_meta] =
       get_scales(s, tensor_.scaling_mode());
     if (rowwise_) {
-      auto scale_size = product(rowwise_scale_meta.shape) * rowwise_scale_meta.type_size;
+      auto scale_size = rowwise_scale_meta.bytes();
      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
                 rowwise_scale_inv_cpu_data_.get(), scale_size,
                 cudaMemcpyHostToDevice);
     }
     if (columnwise_) {
-      auto scale_size = product(colwise_scale_meta.shape) * colwise_scale_meta.type_size;
+      auto scale_size = colwise_scale_meta.bytes();
       cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
                  columnwise_scale_inv_cpu_data_.get(), scale_size,
                  cudaMemcpyHostToDevice);
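The through-line of these hunks is that sizes move from bytes to bits: sizeof-based math cannot express a half-byte element, so typeToNumBits carries widths in bits and bytes() divides by 8 only at the end. A standalone sketch of that arithmetic (illustrative names, not the test suite's):

#include <cstddef>
#include <cstdio>

// Bit widths per element type; an FP4 element occupies 4 bits.
constexpr std::size_t kBitsFp32 = 32;
constexpr std::size_t kBitsFp8  = 8;
constexpr std::size_t kBitsFp4  = 4;

// Buffer size in bytes for `count` elements that are `bits` wide.
// Assumes count * bits is a multiple of 8, as the commit's bytes() does.
constexpr std::size_t to_bytes(std::size_t count, std::size_t bits) {
  return (count * bits) / 8;
}

int main() {
  std::printf("%zu %zu %zu\n",
              to_bytes(1024, kBitsFp32),  // 4096
              to_bytes(1024, kBitsFp8),   // 1024
              to_bytes(1024, kBitsFp4));  // 512: half a byte per element
  return 0;
}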

tests/cpp/test_common.h

Lines changed: 61 additions & 16 deletions
@@ -10,10 +10,15 @@
 #include <vector>
 #include <array>
 #include <random>
+#include <cudaTypedefs.h>
+#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)

 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 #include <cuda_fp8.h>
+#if FP4_TYPE_SUPPORTED
+#include <cuda_fp4.h>
+#endif
 #include <cuda_runtime_api.h>

 #include <transformer_engine/transformer_engine.h>
@@ -55,19 +60,32 @@ using bf16 = nv_bfloat16;
 using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 using fp8e8m0 = uint8_t;
+#if FP4_TYPE_SUPPORTED
+using fp4e2m1 = __nv_fp4_e2m1;
+#endif

 template <typename T>
-struct TypeInfo{
-    using types = std::tuple<byte,
-                             int16,
-                             int32,
-                             int64,
-                             fp32,
-                             fp16,
-                             bf16,
-                             fp8e4m3,
-                             fp8e5m2,
-                             fp8e8m0>;
+struct BitsNumber;
+
+#if FP4_TYPE_SUPPORTED
+template <>
+struct BitsNumber<fp4e2m1> {
+  static constexpr size_t num_bits = 4;
+};
+#endif
+
+template <typename T>
+struct BitsNumber {
+  static constexpr size_t num_bits = 8 * sizeof(T);
+};
+
+template <typename T>
+struct TypeInfo {
+#if FP4_TYPE_SUPPORTED
+    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0, fp4e2m1>;
+#else
+    using types = std::tuple<byte, int16, int32, int64, fp32, fp16, bf16, fp8e4m3, fp8e5m2, fp8e8m0>;
+#endif

     template <typename U, DType current>
     struct Helper {
@@ -94,7 +112,7 @@ struct TypeInfo{
     }

     constexpr static DType dtype = getType<T>();
-    constexpr static size_t size = sizeof(T);
+    constexpr static size_t size = BitsNumber<T>::num_bits;
 };

 class Tensor {
@@ -416,9 +434,10 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
 inline float srelu(const float x) { return x > 0 ? x * x : 0; }
 inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }

-size_t typeToSize(DType type);
+size_t typeToNumBits(DType type);
 size_t product(const NVTEShape &shape);
 size_t product(const std::vector<size_t> &shape);
+size_t bytes(const NVTEShape& shape, const DType type);

 size_t first_dimension(const std::vector<size_t> &shape);
 size_t last_dimension(const std::vector<size_t> &shape);
@@ -464,6 +483,16 @@ constexpr int32_t blackwellComputeCapability = 100;

 }  // namespace test

+#if FP4_TYPE_SUPPORTED
+#define SWITCH_FP4_TYPE_HANDLE(type, ...) \
+  case DType::kFloat4E2M1: { \
+    using type = fp4e2m1; \
+    { __VA_ARGS__ } \
+  } break;
+#else
+#define SWITCH_FP4_TYPE_HANDLE(type, ...)  // do nothing
+#endif
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
     switch (dtype) { \
       using namespace transformer_engine; \
@@ -515,8 +544,16 @@ constexpr int32_t blackwellComputeCapability = 100;
         {__VA_ARGS__} \
       } \
       break; \
+      case DType::kFloat8E8M0: \
+        { \
+          using type = fp8e8m0; \
+          {__VA_ARGS__} \
+        } \
+        break; \
+      SWITCH_FP4_TYPE_HANDLE(type, __VA_ARGS__) \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        printf("dtype: %d\n", static_cast<int>(dtype)); \
+        NVTE_ERROR("Invalid type MARKED TEST."); \
     }

 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(dtype, type, ...) \
@@ -535,7 +572,15 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        NVTE_ERROR("Invalid type MARKED TEST 2."); \
     }

+#define TRANSFORMER_ENGINE_TYPE_SWITCH_FP4_ONLY(dtype, type, ...) \
+    switch (dtype) { \
+      using namespace transformer_engine; \
+      SWITCH_FP4_HANDLE(type, __VA_ARGS__) \
+      default: \
+        NVTE_ERROR("Invalid type MARKED TEST 3."); \
+    }
+
 #define TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(dtype, type, ...) \
@@ -560,5 +605,5 @@ constexpr int32_t blackwellComputeCapability = 100;
       } \
       break; \
       default: \
-        NVTE_ERROR("Invalid type."); \
+        NVTE_ERROR("Invalid type MARKED TEST 4."); \
     }
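The header changes hinge on two pieces: the FP4_TYPE_SUPPORTED guard, which keys the feature off CUDA_VERSION >= 12080 because cuda_fp4.h and __nv_fp4_e2m1 first ship with CUDA 12.8, and the BitsNumber trait, whose fp4e2m1 specialization lets TypeInfo<T>::size report 4 bits while every byte-sized type keeps 8 * sizeof(T). A compile-time sketch of the trait pattern, using a hypothetical Fp4 tag type so it builds without cuda_fp4.h:

#include <cstddef>
#include <cstdint>

struct Fp4 {};  // hypothetical stand-in for __nv_fp4_e2m1

// Default: byte-sized types report 8 * sizeof(T) bits.
template <typename T>
struct BitsNumber {
  static constexpr std::size_t num_bits = 8 * sizeof(T);
};

// Sub-byte specialization: FP4 is 4 bits wide.
template <>
struct BitsNumber<Fp4> {
  static constexpr std::size_t num_bits = 4;
};

static_assert(BitsNumber<float>::num_bits == 32, "fp32 is 32 bits");
static_assert(BitsNumber<std::uint8_t>::num_bits == 8, "fp8e8m0 is 8 bits");
static_assert(BitsNumber<Fp4>::num_bits == 4, "fp4e2m1 is 4 bits");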
