Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
}
]
}
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ BUILD_DIR:= $(ROOT_DIR)/build
FILES_CUDA := $(CSRC)/ops.cu $(CSRC)/kernels.cu
FILES_CPP := $(CSRC)/common.cpp $(CSRC)/cpu_ops.cpp $(CSRC)/pythonInterface.c

# Use the configurable CUDA_HOME (consistent with LIB below, which still links
# against $(CUDA_HOME)/lib64) instead of a machine-specific CUDA 11.3 path.
INCLUDE := -I $(CUDA_HOME)/include -I $(ROOT_DIR)/csrc -I $(CONDA_PREFIX)/include -I $(ROOT_DIR)/include
LIB := -L $(CUDA_HOME)/lib64 -lcudart -lcublas -lcublasLt -lcusparse -L $(CONDA_PREFIX)/lib

# NVIDIA NVCC compilation flags
Expand Down
3 changes: 2 additions & 1 deletion bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,7 +838,8 @@ def quantize_4bit(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksiz
else:
raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")
post_call(A.device)

# print(absmax)
# print(out)
datatype = get_4bit_type(quant_type, device=A.device)

if compress_statistics:
Expand Down
7 changes: 7 additions & 0 deletions bitsandbytes/nn/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,9 +154,16 @@ def __new__(cls, data=None, requires_grad=True, quant_state=None, blocksize=64,

def cuda(self, device):
    """Move the parameter to `device`, quantizing it to packed 4-bit on the way.

    The fp16 copy of the weight is quantized with bnb.functional.quantize_4bit;
    the packed result replaces self.data and the quantization state (absmax,
    blocksize, dtype info) is kept on self.quant_state for dequantization.

    Returns self so the call chains like torch's own ``.cuda()``.
    """
    # Removed debug-only scaffolding from the PR: the hard-coded absolute
    # load_library() path, the throwaway rw/absmax/out tensors, the
    # ref_fp4_quantize reference call and the commented prints do not belong
    # in library code (and would crash on any machine without that path).
    w = self.data.contiguous().half().cuda(device)
    w_4bit, quant_state = bnb.functional.quantize_4bit(w, blocksize=self.blocksize, compress_statistics=self.compress_statistics, quant_type=self.quant_type)
    self.data = w_4bit
    self.quant_state = quant_state

    return self

Expand Down
1 change: 1 addition & 0 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include <cuda_runtime_api.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cublas_v2.h>
#include <cublasLt.h>
#include <cusparse.h>
Expand Down
14 changes: 14 additions & 0 deletions tests/custom_op/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
cmake_minimum_required(VERSION 3.1 FATAL_ERROR)
project(custom_allreduce_op LANGUAGES CXX CUDA)

# Locate LibTorch; the caller passes -DCMAKE_PREFIX_PATH=<torch install dir>
# (see run_build.sh, which derives it from the active python's torch package).
find_package(Torch REQUIRED)

# Define our library target: one C++ registration TU plus the CUDA/CPU kernels.
add_library(custom_allreduce_op SHARED pyt_all_reduce_op.cpp pyt_all_reduce_kernel.cu)
# Enable C++14
target_compile_features(custom_allreduce_op PRIVATE cxx_std_14)
# Link against LibTorch
target_link_libraries(custom_allreduce_op "${TORCH_LIBRARIES}")

# HACK: wipe the compile options Torch's imported targets export — presumably
# they break nvcc on this setup. NOTE(review): verify this is still required;
# it silently drops flags (e.g. -D_GLIBCXX_USE_CXX11_ABI) Torch may rely on.
set_property(TARGET torch_cuda PROPERTY INTERFACE_COMPILE_OPTIONS "")
set_property(TARGET torch_cpu PROPERTY INTERFACE_COMPILE_OPTIONS "")
6 changes: 6 additions & 0 deletions tests/custom_op/pyt_all_reduce_cpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch

# Load the custom-op shared library built by run_build.sh; this script must be
# run from tests/custom_op/ for the relative "build/" path to resolve.
torch.ops.load_library("build/libcustom_allreduce_op.so")

# Smoke test: sum 1024 ones on the CPU (int path of the custom all-reduce);
# prints a 1-element float tensor.
A = torch.ones(1024, dtype=torch.int)
b = torch.ops.my_ops.custom_allreduce(A)
print(b)

6 changes: 6 additions & 0 deletions tests/custom_op/pyt_all_reduce_gpu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import torch

# Load the custom-op shared library built by run_build.sh; this script must be
# run from tests/custom_op/ for the relative "build/" path to resolve.
torch.ops.load_library("build/libcustom_allreduce_op.so")

# Smoke test: sum 1024 ones on the GPU through the half-precision kernel path;
# prints the scalar result.
A = torch.ones(1024, dtype=torch.float32, device='cuda')
b = torch.ops.my_ops.custom_allreduce(A.half())
print(b[0]) #.to('cpu'))

157 changes: 157 additions & 0 deletions tests/custom_op/pyt_all_reduce_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#include <torch/script.h>

#include "pyt_all_reduce_kernel.hh"
#define BLOCKX_DIM 256

// Reference CPU reduction: accumulate n elements of `data` in the element
// type, then widen the total into the float pointed to by `sum`.
template <typename scalar_t>
void cpu_all_reduce(float* sum, scalar_t* data, int n) {
  scalar_t acc = 0;
  int i = 0;
  while (i < n) {
    acc += data[i];
    ++i;
  }
  *sum = acc;  // implicit scalar_t -> float conversion
}

// Grid-stride reduction kernel: each thread accumulates a private partial sum
// over elements idx, idx+stride, idx+2*stride, ... and folds it into *sum
// with a single atomicAdd (the partial is widened to float by the call).
template <typename scalar_t>
__global__ void gpu_all_reduce(float* sum, scalar_t* data, int n) {
  const int stride = blockDim.x * gridDim.x;
  scalar_t partial = 0;
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n; i += stride)
    partial += data[i];

  atomicAdd(sum, partial);
}

// Sum all elements of a 1-D tensor into a 1-element float tensor.
// GPU inputs are reduced by the gpu_all_reduce kernel on the current stream;
// CPU inputs fall back to the cpu_all_reduce reference loop.
torch::Tensor all_reduce_launcher(torch::Tensor input) {
  torch::Tensor output = torch::zeros(1, torch::kFloat);
  // Was `input.device() == torch::Device(torch::kCUDA, 0)`: that only matched
  // device cuda:0, so a tensor on cuda:1 fell into the CPU branch and crashed.
  if (input.is_cuda()) {
    output = output.to(input.device());
    dim3 blockSize(BLOCKX_DIM);
    dim3 gridSize((input.size(0) + BLOCKX_DIM - 1) / BLOCKX_DIM);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    // scalar_type() replaces the deprecated type() accessor.
    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
        input.scalar_type(), "gpu_all_reduce", ([&] {
          gpu_all_reduce<scalar_t><<<gridSize, blockSize, 0, stream>>>(
              output.data_ptr<float>(), input.data_ptr<scalar_t>(),
              input.size(0));
        }));
  } else {
    // Was hard-coded to data_ptr<int>(), which threw for any non-int CPU
    // tensor; dispatch on the actual dtype instead (still supports the
    // torch.int tensor used by pyt_all_reduce_cpu_test.py).
    AT_DISPATCH_ALL_TYPES(input.scalar_type(), "cpu_all_reduce", ([&] {
      cpu_all_reduce<scalar_t>(output.data_ptr<float>(),
                               input.data_ptr<scalar_t>(), input.size(0));
    }));
  }
  return output;
}

// Decode one 4-bit FP4 code into a float.
// Bit 3 is the sign; the low three bits select a fixed fraction of `absmax`.
// The table below reproduces the original decision tree exactly:
//   000 -> 0.0           001 -> 0.00520833 (0.0625/12)
//   010 -> 0.66666667    011 -> 1.0
//   100 -> 0.33333333    101 -> 0.5
//   110 -> 0.16666667    111 -> 0.25
float dDequantizeFP4Tree(unsigned char val, float absmax) {
  static const float kFractions[8] = {
      0.00000000f,      5.208333333e-03f, 0.66666667f, 1.00000000f,
      0.33333333f,      0.50000000f,      0.16666667f, 0.25000000f};
  const float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
  return kFractions[val & 0b0111] * absmax * sign;
}

// Encode one float as a 4-bit FP4 code (bias-3 format, bit 3 = sign).
//
// Representable magnitudes (as fractions of the FP4 absmax 12):
//   0b000 = 0        0b001 = 0.0625   0b010 = 8    0b011 = 12
//   0b100 = 4        0b101 = 6        0b110 = 2    0b111 = 3
//
// The input is assumed to be pre-scaled into [-1.0, 1.0]; each comparison
// threshold is the midpoint between two adjacent representable values,
// divided by 12, so the nearest code wins.
unsigned char dQuantizeFP4(float x) {
  const unsigned char sign_bit = (x < 0.0f) ? 0b1000 : 0b0000;
  const float v = fabsf(x);
  unsigned char code;
  if (v > 0.29166667f) {
    // Upper half: codes for 4, 6, 8, 12 (scaled).
    if (v > 0.583333f)
      code = (v > 0.8333333f) ? 0b0011 : 0b0010;
    else
      code = (v > 0.4166667f) ? 0b0101 : 0b0100;
  } else if (v > 0.0859375f) {
    // Middle: codes for 2 and 3 (scaled).
    code = (v > 0.20833333f) ? 0b0111 : 0b0110;
  } else {
    // Subnormal range: 0.0625 or zero.
    code = (v > 0.00260417f) ? 0b0001 : 0b0000;
  }
  // Sign bit is disjoint from the 3 magnitude bits, so OR == the original +.
  return code | sign_bit;
}

// Reference (CPU) blockwise FP4 quantization.
//
// A        : float32 input, read as a flat array of n elements.
// absmax   : float32 [blocks]; receives each block's max(|x|) scale.
// out      : uint8 buffer; receives two FP4 codes per byte (element 2i in the
//            high nibble, element 2i+1 in the low nibble).
// blocksize: elements per quantization block (caller asserts it is even).
// n        : total number of valid elements in A.
void fp4_quantize_launcher(const torch::Tensor& A, torch::Tensor& absmax,
                           torch::Tensor& out, int64_t blocksize, int64_t n) {
  auto blocks = absmax.sizes()[0];
  auto src = A.data_ptr<float>();
  auto absmax_ptr = absmax.data_ptr<float>();
  auto out_ptr = out.data_ptr<unsigned char>();
  for (int b = 0; b < blocks; b++) {
    size_t offset = b * blocksize;
    // |x| >= 0, so 0 is a correct lower bound for the running maximum
    // (the previous -1e14 sentinel could leak into absmax for empty blocks).
    float max = 0.0f;
    for (int i = 0; i < blocksize; i++) {
      if (offset + i >= (size_t)n) break;
      max = std::abs(src[offset + i]) > max ? std::abs(src[offset + i]) : max;
    }
    absmax_ptr[b] = max;
    // All-zero block: keep codes at 0 instead of multiplying by 1/0
    // (inf * 0 -> NaN fed into dQuantizeFP4).
    const float scale = max > 0.0f ? 1.0f / max : 0.0f;
    for (int i = 0; i < blocksize / 2; i++) {
      size_t lo = offset + 2 * i;
      if (lo >= (size_t)n) break;
      unsigned char packed_4bit = dQuantizeFP4(src[lo] * scale) << 4;
      // When n is odd the final byte has no second element; the original
      // code read src[n] out of bounds here (the test uses n = 1023*1023).
      if (lo + 1 < (size_t)n)
        packed_4bit |= dQuantizeFP4(src[lo + 1] * scale);
      out_ptr[offset / 2 + i] = packed_4bit;
    }
  }
}

// Reference (CPU) blockwise FP4 dequantization — inverse of
// fp4_quantize_launcher.
//
// A        : uint8 input holding two FP4 codes per byte.
// absmax   : float32 [blocks] per-block scales produced at quantization time.
// out      : float32 output buffer for n reconstructed elements.
// blocksize: elements per quantization block (caller asserts it is even).
// n        : total number of valid elements in out.
void fp4_dequantize_launcher(const torch::Tensor& A, torch::Tensor& absmax,
                             torch::Tensor& out, int64_t blocksize, int64_t n) {
  auto blocks = absmax.sizes()[0];
  auto src = A.data_ptr<unsigned char>();
  auto absmax_ptr = absmax.data_ptr<float>();
  auto out_ptr = out.data_ptr<float>();
  for (int b = 0; b < blocks; b++) {
    size_t offset = b * blocksize;
    auto max = absmax_ptr[b];
    for (int i = 0; i < blocksize / 2; i++) {
      size_t lo = offset + 2 * i;
      if (lo >= (size_t)n) break;
      unsigned char packed = src[offset / 2 + i];
      out_ptr[lo] = dDequantizeFP4Tree(packed >> 4, max);
      // When n is odd the last byte carries only one element; the original
      // code wrote out_ptr[n] out of bounds here.
      if (lo + 1 < (size_t)n)
        out_ptr[lo + 1] = dDequantizeFP4Tree(packed & 0x0f, max);
    }
  }
}
12 changes: 12 additions & 0 deletions tests/custom_op/pyt_all_reduce_kernel.hh
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <torch/script.h>

torch::Tensor all_reduce_launcher(torch::Tensor input);

void fp4_quantize_launcher(const torch::Tensor& A, torch::Tensor& absmax,
torch::Tensor& out, int64_t blocksize, int64_t n);

void fp4_dequantize_launcher(const torch::Tensor& A, torch::Tensor& absmax,
torch::Tensor& out, int64_t blocksize, int64_t n);
26 changes: 26 additions & 0 deletions tests/custom_op/pyt_all_reduce_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <torch/script.h>
#include <torch/torch.h>

#include "pyt_all_reduce_kernel.hh"

// Thin op-level wrapper: sum all elements of `input` (GPU or CPU path chosen
// inside the launcher) into a 1-element float tensor.
static torch::Tensor custom_allreduce(torch::Tensor input) {
  return all_reduce_launcher(input);
}

// Reference blockwise FP4 quantization: fills `absmax` (per-block scales) and
// `out` (packed 4-bit codes) from the float tensor `A`.
static void ref_fp4_quantize(const torch::Tensor& A, torch::Tensor& absmax,
                             torch::Tensor& out, int64_t blocksize, int64_t n) {
  fp4_quantize_launcher(A, absmax, out, blocksize, n);
}

// Reference blockwise FP4 dequantization: reconstructs floats into `out`
// from the packed codes in `A` and the per-block scales in `absmax`.
static void ref_fp4_dequantize(const torch::Tensor& A, torch::Tensor& absmax,
                               torch::Tensor& out, int64_t blocksize,
                               int64_t n) {
  fp4_dequantize_launcher(A, absmax, out, blocksize, n);
}

// Register the ops under the `my_ops` namespace; Python reaches them as
// torch.ops.my_ops.<name> after torch.ops.load_library() loads this .so.
TORCH_LIBRARY(my_ops, m) {
  m.def("custom_allreduce", &custom_allreduce);
  m.def("ref_fp4_quantize", &ref_fp4_quantize);
  m.def("ref_fp4_dequantize", &ref_fp4_dequantize);
}
9 changes: 9 additions & 0 deletions tests/custom_op/run_build.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
#!/bin/bash
# Configure and build the custom Torch op into ./build.
# Requires cmake, a CUDA toolchain, and a python env providing torch.

# Was unguarded: if mkdir or cd failed, cmake and make ran in the source tree.
set -euo pipefail

rm -rf build
mkdir build
cd build

# Point CMake at the torch installation of the current interpreter.
cmake -DCMAKE_PREFIX_PATH="$(python -c 'import torch; print(torch.__path__[0])')" ..

make -j"$(nproc)"
32 changes: 32 additions & 0 deletions tests/fp4_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
from torch import Tensor
# TODO: support double-quant for absmax
torch.ops.load_library("/home/zhe/bitsandbytes/tests/custom_op/build/libcustom_allreduce_op.so")
def ref_quantizeblockwise_fp4(A: Tensor, absmax: Tensor = None, out: Tensor = None, blocksize=64, compress_statics=False, quant_type="fp4"):
    """Blockwise FP4 quantization via the reference C++ op.

    Args:
        A: float32 tensor to quantize (treated as a flat array of n elements).
        absmax: optional float32 buffer of per-block scales; allocated if None.
        out: optional uint8 output buffer ((n+1)//2 packed bytes); allocated if None.
        blocksize: elements per quantization block.
        compress_statics: unused; kept for signature parity with bitsandbytes
            (double-quant of absmax is a TODO).
        quant_type: unused; only "fp4" is implemented by the reference op.

    Returns:
        (absmax, out): per-block scales and the packed 4-bit codes.
    """
    n = A.numel()

    if absmax is None:
        # One scale per block, rounding up for a trailing partial block.
        blocks = n // blocksize
        blocks += 1 if n % blocksize > 0 else 0
        absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)

    if out is None:
        # Two 4-bit codes per byte; (n+1)//2 bytes covers odd n.
        out = torch.zeros(((n + 1) // 2, 1), dtype=torch.uint8, device=A.device)

    assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64]

    # Was hard-coded to 64, silently ignoring the blocksize argument that the
    # assert above validates.
    torch.ops.my_ops.ref_fp4_quantize(A, absmax, out, blocksize, n)

    return absmax, out

# TODO: support absmax-dequant
def ref_dequantizeblockwise_fp4(A: Tensor, out: Tensor, absmax: Tensor = None, blocksize=64, compress_statics=False, quant_type="fp4"):
    """Blockwise FP4 dequantization via the reference C++ op.

    Args:
        A: uint8 tensor of packed 4-bit codes (two elements per byte).
        out: pre-allocated float32 tensor that receives the reconstruction;
            its numel determines how many elements are decoded.
        absmax: per-block scales produced by ref_quantizeblockwise_fp4.
        blocksize: elements per quantization block (must match quantization).
        compress_statics / quant_type: unused; kept for signature parity.

    Returns:
        out, filled in place.
    """
    # Was `assert(out != None)`: on a Tensor that is an elementwise comparison,
    # not a None check.
    assert out is not None
    n = out.numel()
    # Was hard-coded to 64, ignoring the blocksize argument.
    torch.ops.my_ops.ref_fp4_dequantize(A, absmax, out, blocksize, n)
    return out

13 changes: 11 additions & 2 deletions tests/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
from bitsandbytes import functional as F
from scipy.stats import norm

import fp4_ref

torch.set_printoptions(
precision=5, sci_mode=False, linewidth=120, edgeitems=20, threshold=10000
)
Expand Down Expand Up @@ -2257,10 +2259,16 @@ def test_fp4_quant(dtype):
result = sign*exp*frac
code[idx] = result

A1 = torch.randn(1024, 1024, device='cuda', dtype=dtype)
A1 = torch.randn(1023, 1023, device='cuda', dtype=dtype)
A11=A1.to("cpu")
qa, SA = F.quantize_fp4(A1, blocksize=64)
qaa,SAA=fp4_ref.ref_quantizeblockwise_fp4(A11,blocksize=64)
# np.allclose(qa.cpu(),qaa,1.e-1)
# np.allclose(SA[0].cpu(),SAA,1)
A2 = F.dequantize_fp4(qa, SA)

out=torch.zeros(1023,1023,dtype=dtype)
A22=fp4_ref.ref_dequantizeblockwise_fp4(SAA,out,qaa)
np.allclose(A2.cpu(),A22,0)
err = (A1 - A2).abs().float()
relerr = (err/(A1.abs().float()+1e-8)).mean()
idx = err > 1.0
Expand Down Expand Up @@ -2553,3 +2561,4 @@ def test_gemv_eye_4bit(storage_type, dtype, double_quant):
#torch.testing.assert_close(A, C2, rtol=1e-5, atol=0.080)


# Guard direct execution: an unguarded module-level call runs this CUDA test
# whenever the file is imported, including during pytest collection of
# unrelated tests.
if __name__ == "__main__":
    test_fp4_quant(torch.float32)
Loading