From 863182b6892229c42e321550f205bf61991a5a48 Mon Sep 17 00:00:00 2001 From: Mahdi Zaferanchi Date: Sat, 20 Jan 2024 19:39:33 +0330 Subject: [PATCH 1/2] make training kernels faster by using shared memory --- difflogic/cuda/difflogic_kernel.cu | 269 ++++++++++++++++++++++------- 1 file changed, 204 insertions(+), 65 deletions(-) diff --git a/difflogic/cuda/difflogic_kernel.cu b/difflogic/cuda/difflogic_kernel.cu index 52a7c90..b3017fb 100644 --- a/difflogic/cuda/difflogic_kernel.cu +++ b/difflogic/cuda/difflogic_kernel.cu @@ -9,6 +9,8 @@ #include #define BACKWARD_W_BATCH_THREADS 32 +#define MAX_SHARED_MEM 48000 +#define FANOUT_MEMLIM 48000 / (3 * 8 + 14 * 4) + !!(48000 % (3 * 8 + 14 * 4)) #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") @@ -95,43 +97,54 @@ __global__ void logic_layer_cuda_forward_kernel( torch::PackedTensorAccessor64 w, torch::PackedTensorAccessor64 y ) { + auto col = blockIdx.x; // neuron dim + auto row = blockIdx.y * blockDim.x + threadIdx.x; // batch dim + // batch_segment = blockIdx.y + // batch_dim_within_segment = threadIdx.x + + __shared__ int64_t idx_a; + __shared__ int64_t idx_b; + __shared__ scalar_t w_[16]; + + // if (threadIdx.x == 0) { + // idx_a = a[col]; + // idx_b = b[col]; + // } else if (threadIdx.x > 0 && threadIdx.x < 17) { + // w_[threadIdx.x - 1] = w[col][threadIdx.x - 1]; + // } + if (threadIdx.x == 0) { + idx_a = a[col]; + idx_b = b[col]; + for (int wIdx = 0; wIdx < 16; wIdx++) { + w_[wIdx] = w[col][wIdx]; + } + } - for ( // batch dim - auto row = blockIdx.x * blockDim.x + threadIdx.x; - row < y.size(1); - row += blockDim.x * gridDim.x - ) { - for ( // neuron dim - auto col = blockIdx.y * blockDim.y + threadIdx.y; - col < y.size(0); - col += blockDim.y * gridDim.y - ) { - - const auto idx_a = a[col]; - const auto idx_b = b[col]; - const auto a_ = x[idx_a][row]; - const auto b_ = x[idx_b][row]; - - const auto w_ = w[col]; - - y[col][row] = ( - ((w_[1] * (a_ * b_) - + w_[2] * (a_ - a_ * b_)) - + (w_[3] * a_ - + w_[4] * (b_ - a_ * b_))) - + ((w_[5] * b_ - + w_[6] * (a_ + b_ - static_cast(2) * a_ * b_)) - + (w_[7] * (a_ + b_ - a_ * b_) - + w_[8] * (static_cast(1) - (a_ + b_ - a_ * b_))))) - + (((w_[9] * (static_cast(1) - (a_ + b_ - static_cast(2) * a_ * b_)) - + w_[10] * (static_cast(1) - b_)) + - (w_[11] * (static_cast(1) - b_ + a_ * b_) - + w_[12] * (static_cast(1) - a_))) + - (w_[13] * (static_cast(1) - a_ + a_ * b_) - + w_[14] * (static_cast(1) - a_ * b_) - + w_[15]) - ); - }} + __syncthreads(); + + + if (row < y.size(1) && col < y.size(0)) { + const auto a_ = x[idx_a][row]; + const auto b_ = x[idx_b][row]; + + y[col][row] = ( + ((w_[1] * (a_ * b_) + + w_[2] * (a_ - a_ * b_)) + + (w_[3] * a_ + + w_[4] * (b_ - a_ * b_))) + + ((w_[5] * b_ + + w_[6] * (a_ + b_ - static_cast(2) * a_ * b_)) + + (w_[7] * (a_ + b_ - a_ * b_) + + w_[8] * (static_cast(1) - (a_ + b_ - a_ * b_))))) + + (((w_[9] * (static_cast(1) - (a_ + b_ - static_cast(2) * a_ * b_)) + + w_[10] * (static_cast(1) - b_)) + + (w_[11] * (static_cast(1) - b_ + a_ * b_) + + w_[12] * (static_cast(1) - a_))) + + (w_[13] * (static_cast(1) - a_ + a_ * b_) + + w_[14] * (static_cast(1) - a_ * b_) + + w_[15]) + ); + } } @@ -177,7 +190,6 @@ logic_layer_cuda_backward_w_kernel( } } - template __global__ void logic_layer_cuda_backward_x_kernel( @@ -255,6 +267,99 @@ logic_layer_cuda_backward_x_kernel( }} } +template +__global__ void +logic_layer_cuda_backward_x_kernel_optimized( + torch::PackedTensorAccessor64 x, + torch::PackedTensorAccessor64 a, + torch::PackedTensorAccessor64 b, + torch::PackedTensorAccessor64 w, + torch::PackedTensorAccessor64 grad_y, + torch::PackedTensorAccessor64 grad_x, + torch::PackedTensorAccessor64 given_x_indices_of_y_start, + torch::PackedTensorAccessor64 given_x_indices_of_y, + int max_fanout +) { + auto col = blockIdx.x; // neuron dim + auto row = blockIdx.y * blockDim.x + threadIdx.x; // batch dim + // batch_segment = blockIdx.y + // batch_dim_within_segment = threadIdx.x + + scalar_t grad_x_ = 0; + extern __shared__ char block_memory[]; + int64_t* idx_a_ = (int64_t*) &block_memory[0]; + int64_t* idx_b_ = (int64_t*) &block_memory[8*max_fanout]; + scalar_t* w_ = (scalar_t*) &block_memory[2*8*max_fanout]; + scalar_t** grad_y_idx_y = (scalar_t**) &block_memory[2*8*max_fanout + 14 * sizeof(scalar_t) * max_fanout]; + + const auto start = given_x_indices_of_y_start[col]; + const auto end = given_x_indices_of_y_start[col + 1]; + + if (threadIdx.x == 0) { + for (int cur = start; cur < end; ++cur) { + const auto idx_y = given_x_indices_of_y[cur]; + int local_idx = cur - start; + idx_a_[local_idx] = a[idx_y]; + idx_b_[local_idx] = b[idx_y]; + grad_y_idx_y[local_idx] = &grad_y[idx_y][0]; + for (int weight_idx = 0; weight_idx < 14; weight_idx++) { + w_[local_idx*14 + weight_idx] = w[idx_y][weight_idx + 1]; + } + } + } + + __syncthreads(); + + + if (row < grad_x.size(1) && col < grad_x.size(0)) { + for (int cur = start; cur < end; ++cur) { + int local_idx = cur - start; + const auto idx_a = idx_a_[local_idx]; + const auto idx_b = idx_b_[local_idx]; + const auto grad_y_ = grad_y_idx_y[local_idx][row]; + const auto idx_is_a = idx_a == col; + + // compute grad_x + if (idx_is_a) { + const auto b_ = x[idx_b][row]; + const auto dy_dx = ( + (w_[local_idx*14 + 1 - 1] * b_ + + w_[local_idx*14 + 2 - 1] * (static_cast(1) - b_) + + w_[local_idx*14 + 3 - 1]) + + (w_[local_idx*14 + 4 - 1] * -b_ + + w_[local_idx*14 + 6 - 1] * (static_cast(1) - static_cast(2) * b_) + + w_[local_idx*14 + 7 - 1] * (static_cast(1) - b_))) + + ((w_[local_idx*14 + 8 - 1] * (b_ - static_cast(1)) + + w_[local_idx*14 + 9 - 1] * (static_cast(2) * b_ - static_cast(1)) + + w_[local_idx*14 + 11 - 1] * b_) + + (-w_[local_idx*14 + 12 - 1] + + w_[local_idx*14 + 13 - 1] * (b_ - static_cast(1)) + + w_[local_idx*14 + 14 - 1] * -b_) + ); + grad_x_ += dy_dx * grad_y_; + } else { + const auto a_ = x[idx_a][row]; + const auto dy_dx = ( + (w_[local_idx*14 + 1 - 1] * a_ + + w_[local_idx*14 + 2 - 1] * -a_ + + w_[local_idx*14 + 4 - 1] * (static_cast(1) - a_)) + + (w_[local_idx*14 + 5 - 1] + + w_[local_idx*14 + 6 - 1] * (static_cast(1) - static_cast(2) * a_) + + w_[local_idx*14 + 7 - 1] * (static_cast(1) - a_))) + + ((w_[local_idx*14 + 8 - 1] * (a_ - static_cast(1)) + + w_[local_idx*14 + 9 - 1] * (static_cast(2) * a_ - static_cast(1)) + - w_[local_idx*14 + 10 - 1]) + + (w_[local_idx*14 + 11 - 1] * (a_ - static_cast(1)) + + w_[local_idx*14 + 13 - 1] * a_ + + w_[local_idx*14 + 14 - 1] * -a_) + ); + grad_x_ += dy_dx * grad_y_; + } + } + grad_x[col][row] = grad_x_; + } +} + torch::Tensor logic_layer_cuda_forward( torch::Tensor x, @@ -273,11 +378,10 @@ torch::Tensor logic_layer_cuda_forward( auto y = torch::empty({out_size, batch_size}, torch::dtype(x.dtype()).device(x.device())); - dim3 threads_per_block(32, 32); + dim3 threads_per_block(min(static_cast(1024), batch_size), 1); const dim3 blocks_per_grid( - min(static_cast(65535), ceil_div(batch_size, static_cast(threads_per_block.x))), - min(static_cast(65535), ceil_div(out_size, static_cast(threads_per_block.y))) + out_size, ceil_div(batch_size, static_cast(1024)) ); AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_forward", ([&] { @@ -382,31 +486,67 @@ torch::Tensor logic_layer_cuda_backward_x( CHECK_INPUT(given_x_indices_of_y); auto grad_x = torch::empty_like(x); + int max_fanout = ceil_div(grad_y.size(0) * 2, x.size(0)); + + if (max_fanout < FANOUT_MEMLIM) { + const auto batch_size = x.size(1); + const auto out_size = w.size(0); + const auto in_size = x.size(0); + + // int shared_mem_size = (3 * sizeof(int64_t) + 14 * sizeof(w.type())) * max_fanout; + int shared_mem_size = (3 * 8 + 14 * 4) * max_fanout; + + dim3 threads_per_block(min(static_cast(1024), batch_size), 1); + + const dim3 blocks_per_grid( + in_size, ceil_div(batch_size, static_cast(1024)) + ); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] { + logic_layer_cuda_backward_x_kernel_optimized<<>>( + x.packed_accessor64(), + a.packed_accessor64(), + b.packed_accessor64(), + w.packed_accessor64(), + grad_y.packed_accessor64(), + grad_x.packed_accessor64(), + given_x_indices_of_y_start.packed_accessor64(), + given_x_indices_of_y.packed_accessor64(), + max_fanout + ); + })); + + gpuErrchk(cudaPeekAtLastError()); + gpuErrchk(cudaDeviceSynchronize()); + + return grad_x; + } else { + dim3 threads_per_block(32, 32); + + const dim3 blocks_per_grid( + min(static_cast(65535), ceil_div(x.size(1), static_cast(threads_per_block.x))), + min(static_cast(65535), ceil_div(x.size(0), static_cast(threads_per_block.y))) + ); + + AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] { + logic_layer_cuda_backward_x_kernel<<>>( + x.packed_accessor64(), + a.packed_accessor64(), + b.packed_accessor64(), + w.packed_accessor64(), + grad_y.packed_accessor64(), + grad_x.packed_accessor64(), + given_x_indices_of_y_start.packed_accessor64(), + given_x_indices_of_y.packed_accessor64() + ); + })); + + gpuErrchk(cudaPeekAtLastError()); + gpuErrchk(cudaDeviceSynchronize()); + + return grad_x; + } - dim3 threads_per_block(32, 32); - - const dim3 blocks_per_grid( - min(static_cast(65535), ceil_div(x.size(1), static_cast(threads_per_block.x))), - min(static_cast(65535), ceil_div(x.size(0), static_cast(threads_per_block.y))) - ); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] { - logic_layer_cuda_backward_x_kernel<<>>( - x.packed_accessor64(), - a.packed_accessor64(), - b.packed_accessor64(), - w.packed_accessor64(), - grad_y.packed_accessor64(), - grad_x.packed_accessor64(), - given_x_indices_of_y_start.packed_accessor64(), - given_x_indices_of_y.packed_accessor64() - ); - })); - - gpuErrchk(cudaPeekAtLastError()); - gpuErrchk(cudaDeviceSynchronize()); - - return grad_x; } @@ -698,4 +838,3 @@ torch::Tensor groupbitsum( /**********************************************************************************************************************/ - From 47c6954665a7a67e012396f6a15587e1039eb41e Mon Sep 17 00:00:00 2001 From: Mahdi Zaferanchi Date: Mon, 11 Mar 2024 18:39:20 +0330 Subject: [PATCH 2/2] fix tensor_packbits_cuda_kernel for when batch size is not a multiple of bit count add __repr__ to PackBitsTensor --- difflogic/cuda/difflogic_kernel.cu | 9 ++++++--- difflogic/packbitstensor.py | 15 ++++++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) diff --git a/difflogic/cuda/difflogic_kernel.cu b/difflogic/cuda/difflogic_kernel.cu index 52a7c90..14b9b42 100644 --- a/difflogic/cuda/difflogic_kernel.cu +++ b/difflogic/cuda/difflogic_kernel.cu @@ -557,7 +557,7 @@ __global__ void tensor_packbits_cuda_kernel( ) { for ( // batch in b auto col = blockIdx.x * blockDim.x + threadIdx.x; - col < t.size(1); + col < b.size(1); col += blockDim.x * gridDim.x ) { @@ -569,8 +569,11 @@ __global__ void tensor_packbits_cuda_kernel( constexpr int bit_count = std::numeric_limits::digits; val.signed_scalar = b[row][col]; for (unsigned int i = 0; i < bit_count; ++i) { - const unsigned_scalar_t bit_mask = static_cast(t[row][bit_count * col + i]) << i; - val.unsigned_scalar = val.unsigned_scalar | bit_mask; + const auto t_col = bit_count * col + i; + if (t_col < t.size(1)) { + const unsigned_scalar_t bit_mask = static_cast(t[row][t_col]) << i; + val.unsigned_scalar = val.unsigned_scalar | bit_mask; + } } b[row][col] = val.signed_scalar; } diff --git a/difflogic/packbitstensor.py b/difflogic/packbitstensor.py index ae4443d..bc1afd3 100644 --- a/difflogic/packbitstensor.py +++ b/difflogic/packbitstensor.py @@ -1,5 +1,6 @@ import difflogic_cuda import torch +import numpy as np class PackBitsTensor: @@ -25,4 +26,16 @@ def flatten(self, start_dim=0, end_dim=-1, **kwargs): Returns the PackBitsTensor object itself. Arguments are ignored. """ - return self \ No newline at end of file + return self + + def _get_member_repr(self, member): + if len(member) <= 4: + result = [(np.binary_repr(integer, width=self.bit_count))[::-1] for integer in member] + return ' '.join(result) + first_three = [(np.binary_repr(integer, width=self.bit_count))[::-1] for integer in member[:3]] + sep = "..." + final = np.binary_repr(member[-1], width=self.bit_count)[::-1] + return f"{' '.join(first_three)} {sep} {final}" + + def __repr__(self): + return '\n'.join([self._get_member_repr(item) for item in self.t]) \ No newline at end of file