diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..2b2502c69 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,16 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "name": "Python: Current File", + "type": "python", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal", + "justMyCode": false + } + ] +} \ No newline at end of file diff --git a/Makefile b/Makefile index 7ccbcb191..8f786442f 100644 --- a/Makefile +++ b/Makefile @@ -24,7 +24,8 @@ BUILD_DIR:= $(ROOT_DIR)/build FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c -INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include +# INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include +INCLUDE := -I /usr/local/cuda-11.3/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib # NVIDIA NVCC compilation flags diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 837f6bf1c..2802a1183 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -838,7 +838,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}") post_call(A.device) - + # print(absmax) + # print(out) datatype = get_4bit_type(quant_type, device=A.device) if compress_statistics: diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index 3d34bb45f..c9526873b 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py 
@@ -154,9 +154,16 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64, def cuda(self, device): w = self.data.contiguous().half().cuda(device) + rw=self.data.contiguous() w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type) self.data = w_4bit self.quant_state = quant_state + absmax=torch.zeros(32,dtype=torch.float) + out = torch.zeros(1024,1, dtype=torch.uint8) + torch.ops.load_library("/home/zhe/bitsandbytes/tests/custom_op/build/libcustom_allreduce_op.so") + torch.ops.my_ops.ref_fp4_quantize(rw,absmax,out,64,2048) + # print(absmax) + # print(out) return self diff --git a/csrc/ops.cuh b/csrc/ops.cuh index f37b3b3af..0b6b0418d 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -14,6 +14,7 @@ #include #include +#include #include #include #include diff --git a/tests/custom_op/CMakeLists.txt b/tests/custom_op/CMakeLists.txt new file mode 100644 index 000000000..a4db9f82d --- /dev/null +++ b/tests/custom_op/CMakeLists.txt @@ -0,0 +1,14 @@ +cmake_minimum_required(VERSION 3.1 FATAL_ERROR) +project(custom_allreduce_op LANGUAGES CXX CUDA) + +find_package(Torch REQUIRED) + +# Define our library target +add_library(custom_allreduce_op SHARED pyt_all_reduce_op.cpp pyt_all_reduce_kernel.cu) +# Enable C++14 +target_compile_features(custom_allreduce_op PRIVATE cxx_std_14) +# Link against LibTorch +target_link_libraries(custom_allreduce_op "${TORCH_LIBRARIES}") + +set_property(TARGET torch_cuda PROPERTY INTERFACE_COMPILE_OPTIONS "") +set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "") diff --git a/tests/custom_op/pyt_all_reduce_cpu_test.py b/tests/custom_op/pyt_all_reduce_cpu_test.py new file mode 100644 index 000000000..043f8e317 --- /dev/null +++ b/tests/custom_op/pyt_all_reduce_cpu_test.py @@ -0,0 +1,6 @@ +import torch +torch.ops.load_library("build/libcustom_allreduce_op.so") +A = torch.ones(1024, dtype=torch.int) +b = 
torch.ops.my_ops.custom_allreduce(A) +print(b) + diff --git a/tests/custom_op/pyt_all_reduce_gpu_test.py b/tests/custom_op/pyt_all_reduce_gpu_test.py new file mode 100644 index 000000000..d61f73175 --- /dev/null +++ b/tests/custom_op/pyt_all_reduce_gpu_test.py @@ -0,0 +1,6 @@ +import torch +torch.ops.load_library("build/libcustom_allreduce_op.so") +A = torch.ones(1024, dtype=torch.float32, device='cuda') +b = torch.ops.my_ops.custom_allreduce(A.half()) +print(b[0]) #.to('cpu')) + diff --git a/tests/custom_op/pyt_all_reduce_kernel.cu b/tests/custom_op/pyt_all_reduce_kernel.cu new file mode 100644 index 000000000..0cd992907 --- /dev/null +++ b/tests/custom_op/pyt_all_reduce_kernel.cu @@ -0,0 +1,157 @@ +#include + +#include "pyt_all_reduce_kernel.hh" +#define BLOCKX_DIM 256 + +template +void cpu_all_reduce(float* sum, scalar_t* data, int n) { + scalar_t temp_sum = 0; + for (int i = 0; i < n; ++i) { + temp_sum += data[i]; + } + *sum = temp_sum; +} + +template +__global__ void gpu_all_reduce(float* sum, scalar_t* data, int n) { + int idx = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + scalar_t temp = 0; + for (int i = idx; i < n; i += stride) { + temp += data[i]; + } + + atomicAdd(sum, temp); +} + +torch::Tensor all_reduce_launcher(torch::Tensor input) { + torch::Device device(torch::kCUDA, 0); + torch::Tensor output = torch::zeros(1, torch::kFloat); + if (input.device() == device) { + output = output.to(device); + dim3 blockSize(BLOCKX_DIM); + dim3 gridSize((input.size(0) + BLOCKX_DIM - 1) / BLOCKX_DIM); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_DISPATCH_FLOATING_TYPES_AND_HALF( + input.type(), "gpu_all_reduce", ([&] { + gpu_all_reduce<<>>( + output.data_ptr(), input.data_ptr(), + input.size(0)); + })); + } else { + cpu_all_reduce(output.data_ptr(), input.data_ptr(), + input.size(0)); + } + return output; +} + +float dDequantizeFP4Tree(unsigned char val, float absmax) { + float sign = (val & 0b1000) == 8 ? 
-1.0f : 1.0f; + if ((val & 0b0100) == 4) // 0 + if ((val & 0b0010) == 2) // 01 + if ((val & 0b0001) == 1) // 111 + return 0.25000000f * absmax * sign; // 1111 + else + return 0.16666667f * absmax * sign; // 1110 + else if ((val & 0b0001) == 1) // 110 + return 0.50000000f * absmax * sign; // 1101 + else + return 0.33333333f * absmax * sign; // 1100 + else if ((val & 0b0010) == 2) // 10 + if ((val & 0b0001) == 1) // 101 + return 1.00000000f * absmax * sign; // 1011 + else + return 0.66666667f * absmax * sign; // 1010 + else if ((val & 0b0001) == 1) // 100 + return 5.208333333e-03f * absmax * sign; // 1001 + else + return 0.00000000f * absmax * sign; // 1000 +} + +unsigned char dQuantizeFP4(float x) { + // FP4 with bias of 3 + // first bit is a sign + // subnormals + // 0b000 = 0 + // 0b001 = 0.0625 + // 0b110 = 2 + // 0b111 = 3 + // 0b100 = 4 + // 0b101 = 6 + // 0b010 = 8 + // 0b011 = 12 + + // we do a binary search + // the pivots are divided by 12 (the FP4 absmax) + // since we assume input data is in [-1.0, 1.0] + + // !be careful here, it's easy to make a mistake + // that is difficult to notice if you add an extra + // zero somewhere! + + int sign = x < 0 ?
0b1000 : 0b0000; + x = fabsf(x); + if (x > 0.29166667f) + if (x > 0.583333f) + if (x > 0.8333333f) + return 0b0011 + sign; + else + return 0b0010 + sign; + else if (x > 0.4166667f) + return 0b101 + sign; + else + return 0b100 + sign; + else if (x > 0.0859375f) + if (x > 0.20833333f) + return 0b0111 + sign; + else + return 0b0110 + sign; + else if (x > 0.00260417f) + return 0b0001 + sign; + else + return 0b0000 + sign; +} + +void fp4_quantize_launcher(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, int64_t n) { + auto blocks = absmax.sizes()[0]; + auto src = A.data_ptr(); + auto absmax_ptr = absmax.data_ptr(); + auto out_ptr = out.data_ptr(); + for (int b = 0; b < blocks; b++) { + float max = -99999999999999.f; + size_t offset = b * blocksize; + for (int i = 0; i < blocksize; i++) { + if (offset + i >= n) break; + max = std::abs(src[offset + i]) > max ? std::abs(src[offset + i]) : max; + } + absmax_ptr[b] = max; + for (int i = 0; i < blocksize / 2; i++) { + unsigned char packed_4bit = 0; + if (offset + i * 2 >= n) break; + packed_4bit |= dQuantizeFP4(src[offset + 2 * i] * (1.f / max)) << 4; + packed_4bit |= dQuantizeFP4(src[offset + 2 * i + 1] * (1.f / max)); + out_ptr[offset / 2 + i] = packed_4bit; + } + } +} + +void fp4_dequantize_launcher(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, int64_t n) { + auto blocks = absmax.sizes()[0]; + auto src = A.data_ptr(); + auto absmax_ptr = absmax.data_ptr(); + auto out_ptr = out.data_ptr(); + for (int b = 0; b < blocks; b++) { + size_t offset = b * blocksize; + auto max = absmax_ptr[b]; + for (int i = 0; i < blocksize / 2; i++) { + unsigned char packed_4bit = 0; + if (offset + i * 2 >= n) break; + out_ptr[offset + 2 * i] = + dDequantizeFP4Tree(src[offset / 2 + i] >> 4, max); + out_ptr[offset + 2 * i + 1] = + dDequantizeFP4Tree(src[offset / 2 + i] & 0x0f, max); + } + } +} \ No newline at end of file diff --git 
a/tests/custom_op/pyt_all_reduce_kernel.hh b/tests/custom_op/pyt_all_reduce_kernel.hh new file mode 100644 index 000000000..dd60102be --- /dev/null +++ b/tests/custom_op/pyt_all_reduce_kernel.hh @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +torch::Tensor all_reduce_launcher(torch::Tensor input); + +void fp4_quantize_launcher(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, int64_t n); + +void fp4_dequantize_launcher(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, int64_t n); \ No newline at end of file diff --git a/tests/custom_op/pyt_all_reduce_op.cpp b/tests/custom_op/pyt_all_reduce_op.cpp new file mode 100644 index 000000000..d8dfc776c --- /dev/null +++ b/tests/custom_op/pyt_all_reduce_op.cpp @@ -0,0 +1,26 @@ +#include +#include + +#include "pyt_all_reduce_kernel.hh" + +static torch::Tensor custom_allreduce(torch::Tensor input) { + return all_reduce_launcher(input); +} + +static void ref_fp4_quantize(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, int64_t n) { + fp4_quantize_launcher(A, absmax, out, blocksize, n); +} + +static void ref_fp4_dequantize(const torch::Tensor& A, torch::Tensor& absmax, + torch::Tensor& out, int64_t blocksize, + int64_t n) { + fp4_dequantize_launcher(A, absmax, out, blocksize, n); +} + +// static auto registry = torch::RegisterOperators("myop::skbmm", &skbmm); +TORCH_LIBRARY(my_ops, m) { + m.def("custom_allreduce", &custom_allreduce); + m.def("ref_fp4_quantize", &ref_fp4_quantize); + m.def("ref_fp4_dequantize", &ref_fp4_dequantize); +} diff --git a/tests/custom_op/run_build.sh b/tests/custom_op/run_build.sh new file mode 100755 index 000000000..35f068b6f --- /dev/null +++ b/tests/custom_op/run_build.sh @@ -0,0 +1,9 @@ +#!/bin/bash + +rm -rf build; +mkdir build; +cd build + +cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.__path__[0])')" .. 
+ +make -j$(nproc) diff --git a/tests/fp4_ref.py b/tests/fp4_ref.py new file mode 100644 index 000000000..d0ab2dcf5 --- /dev/null +++ b/tests/fp4_ref.py @@ -0,0 +1,32 @@ +import torch +from torch import Tensor +# TODO: support double-quant for absmax +torch.ops.load_library("/home/zhe/bitsandbytes/tests/custom_op/build/libcustom_allreduce_op.so") +def ref_quantizeblockwise_fp4(A: Tensor,absmax: Tensor=None,out: Tensor=None,blocksize=64,compress_statics=False,quant_type="fp4"): + n = A.numel() + input_shape = A.shape + + if absmax is None: + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) + + + if out is None: + out = torch.zeros(((n+1)//2, 1), dtype=torch.uint8, device=A.device) + + assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] + + torch.ops.my_ops.ref_fp4_quantize(A,absmax,out,64,n) + + return absmax,out + +# TODO: support absmax-dequant +def ref_dequantizeblockwise_fp4(A: Tensor,out: Tensor,absmax: Tensor=None,blocksize=64,compress_statics=False,quant_type="fp4"): + assert(out != None) + shape = out.shape + dtype = out.dtype + n = out.numel() + torch.ops.my_ops.ref_fp4_dequantize(A,absmax,out,64,n) + return out + \ No newline at end of file diff --git a/tests/test_functional.py b/tests/test_functional.py index d7212b047..0a574d120 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -12,6 +12,8 @@ from bitsandbytes import functional as F from scipy.stats import norm +import fp4_ref + torch.set_printoptions( precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000 ) @@ -2257,10 +2259,16 @@ def test_fp4_quant(dtype): result = sign*exp*frac code[idx] = result - A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype) + A1 = torch.randn(1023, 1023, device='cuda', dtype=dtype) + A11=A1.to("cpu") qa, SA = F.quantize_fp4(A1, blocksize=64) + qaa,SAA=fp4_ref.ref_quantizeblockwise_fp4(A11,blocksize=64) + # np.allclose(qa.cpu(),qaa,1.e-1) 
+ # np.allclose(SA[0].cpu(),SAA,1) A2 = F.dequantize_fp4(qa, SA) - + out=torch.zeros(1023,1023,dtype=dtype) + A22=fp4_ref.ref_dequantizeblockwise_fp4(SAA,out,qaa) + np.allclose(A2.cpu(),A22,0) err = (A1 - A2).abs().float() relerr = (err/(A1.abs().float()+1e-8)).mean() idx = err > 1.0 @@ -2553,3 +2561,4 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant): #torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080) +test_fp4_quant(torch.float32) \ No newline at end of file diff --git a/tests/test_modules.py b/tests/test_modules.py index 7d2d03498..ab0226484 100644 --- a/tests/test_modules.py +++ b/tests/test_modules.py @@ -317,16 +317,17 @@ def forward(self, x): @pytest.mark.parametrize("threshold", values, ids=names) def test_linear8bitlt_inference(threshold): - l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() - assert l1.weight.device.type == "cuda" - assert l1.weight.dtype == torch.float16 + # l1 = bnb.nn.Linear8bitLt(32, 64, threshold=threshold).cuda().half() + test=bnb.nn.LinearFP4(32,64).cuda() + # assert l1.weight.device.type == "cuda" + # assert l1.weight.dtype == torch.float16 - l1.eval() - for i in range(100): - b1 = torch.randn(16, 8, 32, device="cuda").half() - o1 = l1(b1) - if i == 1: - assert l1.state.CxB is not None + # l1.eval() + # for i in range(100): + # b1 = torch.randn(16, 8, 32, device="cuda").half() + # o1 = l1(b1) + # if i == 1: + # assert l1.state.CxB is not None def test_linear8bitlt_accumulated_gradient(): @@ -641,3 +642,4 @@ def test_4bit_warnings(): +test_linear8bitlt_inference(0.1) \ No newline at end of file