278 changes: 210 additions & 68 deletions difflogic/cuda/difflogic_kernel.cu
@@ -9,6 +9,8 @@
#include <vector>

#define BACKWARD_W_BATCH_THREADS 32
#define MAX_SHARED_MEM 48000
// Ceiling division of the shared-memory budget by the 80-byte per-fanout-entry footprint.
#define FANOUT_MEMLIM (MAX_SHARED_MEM / (3 * 8 + 14 * 4) + !!(MAX_SHARED_MEM % (3 * 8 + 14 * 4)))

#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
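
FANOUT_MEMLIM bounds how many fanout entries can be cached per block within the 48 kB shared-memory budget: each entry needs two int64 indices, one grad_y row pointer, and 14 four-byte weights. A minimal sketch of that ceiling division (the constexpr names below are illustrative, not from the patch; it keeps the macro's assumption of 4-byte weights):

// Budget behind FANOUT_MEMLIM: per cached fanout entry,
//   idx_a (8 B) + idx_b (8 B) + grad_y row pointer (8 B) + 14 weights (14 * 4 B) = 80 B.
// 48000 / 80 = 600 with remainder 0, so FANOUT_MEMLIM evaluates to 600 entries.
constexpr int kPerEntryBytes = 3 * 8 + 14 * 4;  // 80
constexpr int kFanoutLimit = 48000 / kPerEntryBytes + !!(48000 % kPerEntryBytes);
static_assert(kFanoutLimit == 600, "48 kB budget caches up to 600 fanout entries");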
@@ -95,43 +97,54 @@ __global__ void logic_layer_cuda_forward_kernel(
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> w,
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> y
) {
auto col = blockIdx.x; // neuron dim
auto row = blockIdx.y * blockDim.x + threadIdx.x; // batch dim
// batch_segment = blockIdx.y
// batch_dim_within_segment = threadIdx.x

__shared__ int64_t idx_a;
__shared__ int64_t idx_b;
__shared__ scalar_t w_[16];

// if (threadIdx.x == 0) {
// idx_a = a[col];
// idx_b = b[col];
// } else if (threadIdx.x > 0 && threadIdx.x < 17) {
// w_[threadIdx.x - 1] = w[col][threadIdx.x - 1];
// }
if (threadIdx.x == 0) {
idx_a = a[col];
idx_b = b[col];
for (int wIdx = 0; wIdx < 16; wIdx++) {
w_[wIdx] = w[col][wIdx];
}
}

for ( // batch dim
auto row = blockIdx.x * blockDim.x + threadIdx.x;
row < y.size(1);
row += blockDim.x * gridDim.x
) {
for ( // neuron dim
auto col = blockIdx.y * blockDim.y + threadIdx.y;
col < y.size(0);
col += blockDim.y * gridDim.y
) {

const auto idx_a = a[col];
const auto idx_b = b[col];
const auto a_ = x[idx_a][row];
const auto b_ = x[idx_b][row];

const auto w_ = w[col];

y[col][row] = (
((w_[1] * (a_ * b_)
+ w_[2] * (a_ - a_ * b_))
+ (w_[3] * a_
+ w_[4] * (b_ - a_ * b_)))
+ ((w_[5] * b_
+ w_[6] * (a_ + b_ - static_cast<scalar_t>(2) * a_ * b_))
+ (w_[7] * (a_ + b_ - a_ * b_)
+ w_[8] * (static_cast<scalar_t>(1) - (a_ + b_ - a_ * b_)))))
+ (((w_[9] * (static_cast<scalar_t>(1) - (a_ + b_ - static_cast<scalar_t>(2) * a_ * b_))
+ w_[10] * (static_cast<scalar_t>(1) - b_)) +
(w_[11] * (static_cast<scalar_t>(1) - b_ + a_ * b_)
+ w_[12] * (static_cast<scalar_t>(1) - a_))) +
(w_[13] * (static_cast<scalar_t>(1) - a_ + a_ * b_)
+ w_[14] * (static_cast<scalar_t>(1) - a_ * b_)
+ w_[15])
);
}}
__syncthreads();


if (row < y.size(1) && col < y.size(0)) {
const auto a_ = x[idx_a][row];
const auto b_ = x[idx_b][row];

y[col][row] = (
((w_[1] * (a_ * b_)
+ w_[2] * (a_ - a_ * b_))
+ (w_[3] * a_
+ w_[4] * (b_ - a_ * b_)))
+ ((w_[5] * b_
+ w_[6] * (a_ + b_ - static_cast<scalar_t>(2) * a_ * b_))
+ (w_[7] * (a_ + b_ - a_ * b_)
+ w_[8] * (static_cast<scalar_t>(1) - (a_ + b_ - a_ * b_)))))
+ (((w_[9] * (static_cast<scalar_t>(1) - (a_ + b_ - static_cast<scalar_t>(2) * a_ * b_))
+ w_[10] * (static_cast<scalar_t>(1) - b_)) +
(w_[11] * (static_cast<scalar_t>(1) - b_ + a_ * b_)
+ w_[12] * (static_cast<scalar_t>(1) - a_))) +
(w_[13] * (static_cast<scalar_t>(1) - a_ + a_ * b_)
+ w_[14] * (static_cast<scalar_t>(1) - a_ * b_)
+ w_[15])
);
}
}
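
The rewritten forward kernel dedicates one block per output neuron (blockIdx.x) and covers the batch with blockIdx.y and threadIdx.x; thread 0 first caches the neuron's two input indices and its 16 gate weights in shared memory, and after __syncthreads() every thread evaluates the same weighted mixture of the 16 relaxed two-input gates. A host-side sketch of that mixture, with the usual gate names as comments (w[0] multiplies the constant-false gate and therefore drops out of the sum):

// Reference (host-side) form of the soft gate the kernel evaluates per (neuron, batch) pair.
// a_ and b_ are the relaxed inputs in [0, 1]; w holds the 16 learned gate weights.
template <typename scalar_t>
scalar_t soft_gate(scalar_t a_, scalar_t b_, const scalar_t w[16]) {
    const scalar_t ab = a_ * b_;
    return w[1]  * ab                        // A AND B
         + w[2]  * (a_ - ab)                 // A AND NOT B
         + w[3]  * a_                        // A
         + w[4]  * (b_ - ab)                 // NOT A AND B
         + w[5]  * b_                        // B
         + w[6]  * (a_ + b_ - 2 * ab)        // A XOR B
         + w[7]  * (a_ + b_ - ab)            // A OR B
         + w[8]  * (1 - (a_ + b_ - ab))      // NOR
         + w[9]  * (1 - (a_ + b_ - 2 * ab))  // XNOR
         + w[10] * (1 - b_)                  // NOT B
         + w[11] * (1 - b_ + ab)             // B IMPLIES A
         + w[12] * (1 - a_)                  // NOT A
         + w[13] * (1 - a_ + ab)             // A IMPLIES B
         + w[14] * (1 - ab)                  // NAND
         + w[15];                            // constant 1
}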


@@ -177,7 +190,6 @@ logic_layer_cuda_backward_w_kernel(
}
}


template <typename scalar_t>
__global__ void
logic_layer_cuda_backward_x_kernel(
@@ -255,6 +267,99 @@ logic_layer_cuda_backward_x_kernel(
}}
}

template <typename scalar_t>
__global__ void
logic_layer_cuda_backward_x_kernel_optimized(
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> x,
torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> a,
torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> b,
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> w,
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> grad_y,
torch::PackedTensorAccessor64<scalar_t, 2, torch::RestrictPtrTraits> grad_x,
torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> given_x_indices_of_y_start,
torch::PackedTensorAccessor64<int64_t, 1, torch::RestrictPtrTraits> given_x_indices_of_y,
int max_fanout
) {
auto col = blockIdx.x; // neuron dim
auto row = blockIdx.y * blockDim.x + threadIdx.x; // batch dim
// batch_segment = blockIdx.y
// batch_dim_within_segment = threadIdx.x

scalar_t grad_x_ = 0;
extern __shared__ char block_memory[];
int64_t* idx_a_ = (int64_t*) &block_memory[0];
int64_t* idx_b_ = (int64_t*) &block_memory[8*max_fanout];
scalar_t* w_ = (scalar_t*) &block_memory[2*8*max_fanout];
scalar_t** grad_y_idx_y = (scalar_t**) &block_memory[2*8*max_fanout + 14 * sizeof(scalar_t) * max_fanout];

const auto start = given_x_indices_of_y_start[col];
const auto end = given_x_indices_of_y_start[col + 1];

if (threadIdx.x == 0) {
for (int cur = start; cur < end; ++cur) {
const auto idx_y = given_x_indices_of_y[cur];
int local_idx = cur - start;
idx_a_[local_idx] = a[idx_y];
idx_b_[local_idx] = b[idx_y];
grad_y_idx_y[local_idx] = &grad_y[idx_y][0];
for (int weight_idx = 0; weight_idx < 14; weight_idx++) {
w_[local_idx*14 + weight_idx] = w[idx_y][weight_idx + 1];
}
}
}

__syncthreads();


if (row < grad_x.size(1) && col < grad_x.size(0)) {
for (int cur = start; cur < end; ++cur) {
int local_idx = cur - start;
const auto idx_a = idx_a_[local_idx];
const auto idx_b = idx_b_[local_idx];
const auto grad_y_ = grad_y_idx_y[local_idx][row];
const auto idx_is_a = idx_a == col;

// compute grad_x
if (idx_is_a) {
const auto b_ = x[idx_b][row];
const auto dy_dx = (
(w_[local_idx*14 + 1 - 1] * b_
+ w_[local_idx*14 + 2 - 1] * (static_cast<scalar_t>(1) - b_)
+ w_[local_idx*14 + 3 - 1]) +
(w_[local_idx*14 + 4 - 1] * -b_
+ w_[local_idx*14 + 6 - 1] * (static_cast<scalar_t>(1) - static_cast<scalar_t>(2) * b_)
+ w_[local_idx*14 + 7 - 1] * (static_cast<scalar_t>(1) - b_)))
+ ((w_[local_idx*14 + 8 - 1] * (b_ - static_cast<scalar_t>(1))
+ w_[local_idx*14 + 9 - 1] * (static_cast<scalar_t>(2) * b_ - static_cast<scalar_t>(1))
+ w_[local_idx*14 + 11 - 1] * b_)
+ (-w_[local_idx*14 + 12 - 1]
+ w_[local_idx*14 + 13 - 1] * (b_ - static_cast<scalar_t>(1))
+ w_[local_idx*14 + 14 - 1] * -b_)
);
grad_x_ += dy_dx * grad_y_;
} else {
const auto a_ = x[idx_a][row];
const auto dy_dx = (
(w_[local_idx*14 + 1 - 1] * a_
+ w_[local_idx*14 + 2 - 1] * -a_
+ w_[local_idx*14 + 4 - 1] * (static_cast<scalar_t>(1) - a_))
+ (w_[local_idx*14 + 5 - 1]
+ w_[local_idx*14 + 6 - 1] * (static_cast<scalar_t>(1) - static_cast<scalar_t>(2) * a_)
+ w_[local_idx*14 + 7 - 1] * (static_cast<scalar_t>(1) - a_)))
+ ((w_[local_idx*14 + 8 - 1] * (a_ - static_cast<scalar_t>(1))
+ w_[local_idx*14 + 9 - 1] * (static_cast<scalar_t>(2) * a_ - static_cast<scalar_t>(1))
- w_[local_idx*14 + 10 - 1])
+ (w_[local_idx*14 + 11 - 1] * (a_ - static_cast<scalar_t>(1))
+ w_[local_idx*14 + 13 - 1] * a_
+ w_[local_idx*14 + 14 - 1] * -a_)
);
grad_x_ += dy_dx * grad_y_;
}
}
grad_x[col][row] = grad_x_;
}
}
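
The optimized backward kernel assigns one block per input neuron and caches everything that is reused across the batch for that input's fanout: each consuming gate's a/b indices, a pointer to its grad_y row, and the 14 weights that appear in dy/dx. A sketch of how the dynamic shared-memory buffer is carved up (byte offsets as indexed above; the helper function is illustrative, not part of the patch, and the hardcoded host-side size of (3 * 8 + 14 * 4) * max_fanout matches this layout only for 4-byte scalar_t):

// Layout of the dynamic shared-memory buffer, with max_fanout entries per block:
//   bytes [0,               8 * max_fanout)    int64_t   idx_a_[max_fanout]
//   bytes [8 * max_fanout, 16 * max_fanout)    int64_t   idx_b_[max_fanout]
//   bytes [16 * max_fanout, 16 * max_fanout + 14 * sizeof(scalar_t) * max_fanout)
//                                              scalar_t  w_[14 * max_fanout]
//   remaining 8 * max_fanout bytes             scalar_t* grad_y_idx_y[max_fanout]
// One entry therefore costs 2 * sizeof(int64_t) + 14 * sizeof(scalar_t) + sizeof(void*) bytes.
inline size_t backward_x_shared_bytes(size_t max_fanout, size_t scalar_bytes) {
    return max_fanout * (2 * sizeof(int64_t) + 14 * scalar_bytes + sizeof(void*));
}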


torch::Tensor logic_layer_cuda_forward(
torch::Tensor x,
@@ -273,11 +378,10 @@ torch::Tensor logic_layer_cuda_forward(

auto y = torch::empty({out_size, batch_size}, torch::dtype(x.dtype()).device(x.device()));

dim3 threads_per_block(32, 32);
dim3 threads_per_block(min(static_cast<int64_t>(1024), batch_size), 1);

const dim3 blocks_per_grid(
min(static_cast<int64_t>(65535), ceil_div(batch_size, static_cast<int64_t>(threads_per_block.x))),
min(static_cast<int64_t>(65535), ceil_div(out_size, static_cast<int64_t>(threads_per_block.y)))
out_size, ceil_div(batch_size, static_cast<int64_t>(1024))
);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_forward", ([&] {
@@ -382,31 +486,67 @@ torch::Tensor logic_layer_cuda_backward_x(
CHECK_INPUT(given_x_indices_of_y);

auto grad_x = torch::empty_like(x);
int max_fanout = ceil_div(grad_y.size(0) * 2, x.size(0));

if (max_fanout < FANOUT_MEMLIM) {
const auto batch_size = x.size(1);
const auto out_size = w.size(0);
const auto in_size = x.size(0);

// int shared_mem_size = (3 * sizeof(int64_t) + 14 * sizeof(w.type())) * max_fanout;
int shared_mem_size = (3 * 8 + 14 * 4) * max_fanout;

dim3 threads_per_block(min(static_cast<int64_t>(1024), batch_size), 1);

const dim3 blocks_per_grid(
in_size, ceil_div(batch_size, static_cast<int64_t>(1024))
);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] {
logic_layer_cuda_backward_x_kernel_optimized<scalar_t><<<blocks_per_grid, threads_per_block, shared_mem_size>>>(
x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
a.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
b.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
w.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_y.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
given_x_indices_of_y_start.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
given_x_indices_of_y.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
max_fanout
);
}));

gpuErrchk(cudaPeekAtLastError());
gpuErrchk(cudaDeviceSynchronize());

return grad_x;
} else {
dim3 threads_per_block(32, 32);

const dim3 blocks_per_grid(
min(static_cast<int64_t>(65535), ceil_div(x.size(1), static_cast<int64_t>(threads_per_block.x))),
min(static_cast<int64_t>(65535), ceil_div(x.size(0), static_cast<int64_t>(threads_per_block.y)))
);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] {
logic_layer_cuda_backward_x_kernel<scalar_t><<<blocks_per_grid, threads_per_block>>>(
x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
a.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
b.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
w.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_y.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
given_x_indices_of_y_start.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
given_x_indices_of_y.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>()
);
}));

gpuErrchk(cudaPeekAtLastError());
gpuErrchk(cudaDeviceSynchronize());

return grad_x;
}

dim3 threads_per_block(32, 32);

const dim3 blocks_per_grid(
min(static_cast<int64_t>(65535), ceil_div(x.size(1), static_cast<int64_t>(threads_per_block.x))),
min(static_cast<int64_t>(65535), ceil_div(x.size(0), static_cast<int64_t>(threads_per_block.y)))
);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] {
logic_layer_cuda_backward_x_kernel<scalar_t><<<blocks_per_grid, threads_per_block>>>(
x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
a.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
b.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
w.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_y.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
grad_x.packed_accessor64<scalar_t, 2, torch::RestrictPtrTraits>(),
given_x_indices_of_y_start.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>(),
given_x_indices_of_y.packed_accessor64<int64_t, 1, torch::RestrictPtrTraits>()
);
}));

gpuErrchk(cudaPeekAtLastError());
gpuErrchk(cudaDeviceSynchronize());

return grad_x;
}
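
On the host side, max_fanout is estimated as the average fanout, ceil(2 * out_size / in_size), since every gate reads exactly two inputs; when that estimate stays below FANOUT_MEMLIM the cached kernel is launched with the hardcoded 80-byte-per-entry shared allocation, otherwise the original grid-stride kernel runs unchanged. The commented-out shared_mem_size line hints at a dtype-aware size; a hedged sketch of what that could look like once scalar_t is known inside the dispatch (illustrative only, not part of the patch):

// Sketch: derive the shared-memory request from the actual scalar width, which is
// only known inside the AT_DISPATCH lambda (the patch hardcodes 14 * 4 bytes for the weights).
AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.type(), "logic_layer_cuda_backward_x", ([&] {
    const int shared_mem_size =
        (3 * sizeof(int64_t) + 14 * sizeof(scalar_t)) * max_fanout;  // idx_a, idx_b, grad_y pointer, 14 weights
    logic_layer_cuda_backward_x_kernel_optimized<scalar_t>
        <<<blocks_per_grid, threads_per_block, shared_mem_size>>>(
            /* ... same packed-accessor arguments as above ... */);
}));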


@@ -557,7 +697,7 @@ __global__ void tensor_packbits_cuda_kernel(
) {
for ( // batch in b
auto col = blockIdx.x * blockDim.x + threadIdx.x;
col < t.size(1);
col < b.size(1);
col += blockDim.x * gridDim.x
) {

@@ -569,8 +709,11 @@
constexpr int bit_count = std::numeric_limits<unsigned_scalar_t>::digits;
val.signed_scalar = b[row][col];
for (unsigned int i = 0; i < bit_count; ++i) {
const unsigned_scalar_t bit_mask = static_cast<unsigned_scalar_t>(t[row][bit_count * col + i]) << i;
val.unsigned_scalar = val.unsigned_scalar | bit_mask;
const auto t_col = bit_count * col + i;
if (t_col < t.size(1)) {
const unsigned_scalar_t bit_mask = static_cast<unsigned_scalar_t>(t[row][t_col]) << i;
val.unsigned_scalar = val.unsigned_scalar | bit_mask;
}
}
b[row][col] = val.signed_scalar;
}
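
This hunk fixes the column loop to run over the packed tensor b rather than the unpacked t, and guards the last, possibly partial word: packed column col collects bits from unpacked columns bit_count * col through bit_count * col + bit_count - 1, and when t.size(1) is not a multiple of bit_count the missing high bits are simply left at zero. A scalar sketch of that packing (illustrative helper, shown for an 8-bit word):

#include <cstdint>
#include <vector>

// Packs one output word: bit i comes from unpacked column bit_count * col + i;
// columns past the end of the row contribute nothing, so a partial final word keeps zeros on top.
uint8_t pack_word(const std::vector<uint8_t>& row_bits, int col, int bit_count = 8) {
    uint8_t word = 0;
    for (int i = 0; i < bit_count; ++i) {
        const int t_col = bit_count * col + i;
        if (t_col < static_cast<int>(row_bits.size())) {
            word |= static_cast<uint8_t>(row_bits[t_col] != 0) << i;
        }
    }
    return word;
}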
Expand Down Expand Up @@ -698,4 +841,3 @@ torch::Tensor groupbitsum(


/**********************************************************************************************************************/

15 changes: 14 additions & 1 deletion difflogic/packbitstensor.py
@@ -1,5 +1,6 @@
import difflogic_cuda
import torch
import numpy as np


class PackBitsTensor:
@@ -25,4 +26,16 @@ def flatten(self, start_dim=0, end_dim=-1, **kwargs):
Returns the PackBitsTensor object itself.
Arguments are ignored.
"""
return self
return self

def _get_member_repr(self, member):
if len(member) <= 4:
result = [(np.binary_repr(integer, width=self.bit_count))[::-1] for integer in member]
return ' '.join(result)
first_three = [(np.binary_repr(integer, width=self.bit_count))[::-1] for integer in member[:3]]
sep = "..."
final = np.binary_repr(member[-1], width=self.bit_count)[::-1]
return f"{' '.join(first_three)} {sep} {final}"

def __repr__(self):
return '\n'.join([self._get_member_repr(item) for item in self.t])