HipBlasLT-Ext GroupedGemm support for RaggedDot op#514
HipBlasLT-Ext GroupedGemm support for RaggedDot op#514mfrancepillois wants to merge 30 commits intorocm-jaxlib-v0.8.0from
Conversation
HipBlasLT-Ext GroupedGemm support for RaggedDot implemented for matrices with and without batch dimension (note that the batch dimension must be the outer-most dimension of the matrix). Three ragged modes supported: - ragged in Non-Contracting dimension - ragged in Contracting dimension - ragged in Batch dimension Not implemented yet: - epilogues - proto for GroupedGemm config
Add GroupedGemm Algo validity check.
|
The performance of GroupedGemm was measured on MI300, ROCm version 7.1, JAX rocm-jaxlib-v0.8.0, using the following script:
import time
import statistics as stats
import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from itertools import product
import os
from datetime import datetime
def random_tokens_per_expert(nb_experts, min_val, max_val, nb_tokens, key):
    """Randomly split ``nb_tokens`` across ``nb_experts`` groups.

    Each group receives between ``min_val`` and ``max_val`` tokens and the
    group sizes sum exactly to ``nb_tokens``.

    Args:
        nb_experts: number of groups (converted to a Python int).
        min_val: minimum tokens per group.
        max_val: maximum tokens per group.
        nb_tokens: total number of tokens to distribute.
        key: a ``jax.random`` PRNG key.

    Returns:
        A shuffled integer array of shape ``(nb_experts,)`` summing to
        ``nb_tokens``.

    Raises:
        ValueError: when no split satisfies the min/max constraints.
    """
    # Convert inputs to Python ints if they are arrays.
    nb_experts = int(nb_experts)
    min_val = int(min_val)
    max_val = int(max_val)
    nb_tokens = int(nb_tokens)
    # Check feasibility.
    if not (nb_experts * min_val <= nb_tokens <= nb_experts * max_val):
        raise ValueError("Impossible to generate with given constraints")
    # Start with the minimum tokens everywhere.
    arr = jnp.full(nb_experts, min_val)
    remaining = nb_tokens - nb_experts * min_val

    # Randomly hand out part of the remaining tokens.
    def body(i, val):
        arr, remaining, key = val
        key, subkey = jax.random.split(key)
        # jax.random.randint's upper bound is exclusive: the "+ 1" lets a
        # group actually absorb up to min(max_val - min_val, remaining).
        add = jax.random.randint(
            subkey, (), 0, jnp.minimum(max_val - min_val, remaining) + 1)
        arr = arr.at[i].set(arr[i] + add)
        remaining -= add
        return arr, remaining, key

    arr, remaining, key = jax.lax.fori_loop(0, nb_experts, body, (arr, remaining, key))
    # The random pass may leave tokens unassigned; distribute the leftover
    # greedily (front to back) so the sizes sum exactly to nb_tokens.
    # Feasibility above guarantees the total spare capacity covers it.
    capacity = max_val - arr
    cum_before = jnp.cumsum(capacity) - capacity
    arr = arr + jnp.clip(remaining - cum_before, 0, capacity)
    # Shuffle so the greedy top-up does not bias the first groups.
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, nb_experts)
    arr = arr[perm]
    return arr
def apply_ragged(data, W, row_splits):
    """Ragged matmul: each contiguous row group of `data` (sized by
    `row_splits`) is multiplied by the corresponding expert matrix in `W`."""
    return lax.ragged_dot(lhs=data, rhs=W, group_sizes=row_splits)
# JIT-compiled ragged_dot routed through the experimental hipBLASLt
# grouped-GEMM lowering (the implementation under test).
jit_ragged_dot_group_gemm = jax.jit(apply_ragged, compiler_options={
    "xla_gpu_experimental_use_ragged_dot_grouped_gemm": True,
    "xla_gpu_enable_cublaslt": True,
    "xla_gpu_experimental_use_ragged_dot_fusion": False
})
# Reference path: identical options except the grouped-GEMM flag is off,
# so XLA falls back to the regular (padded-dot) ragged_dot lowering.
jit_ragged_dot = jax.jit(apply_ragged, compiler_options={
    "xla_gpu_experimental_use_ragged_dot_grouped_gemm": False,
    "xla_gpu_enable_cublaslt": True,
    "xla_gpu_experimental_use_ragged_dot_fusion": False
})
def run_single_perf_test(num_experts, num_tokens, d_model, ragged_dot, ragged_dot_ref, group_sizes, check_output=False):
    """Time one ragged-dot configuration and optionally validate its output.

    Args:
        num_experts: number of expert matrices in the rhs.
        num_tokens: number of lhs rows.
        d_model: inner dimension; rhs output dimension is d_model * 4.
        ragged_dot: callable under test, (A, W, group_sizes) -> out.
        ragged_dot_ref: reference callable used when check_output is True.
        group_sizes: per-expert row counts, or None to draw them randomly.
        check_output: when True, compare against ragged_dot_ref elementwise.

    Returns:
        (mean_latency_ms, valid_rate). valid_rate is -1 when unchecked,
        and (0, 0) when the shapes are skipped as too large.
    """
    n = jax.local_device_count()
    key = jax.random.PRNGKey(0)
    A = jax.random.uniform(key, (num_tokens, d_model), dtype=jnp.float32)
    W = jax.random.uniform(key, (num_experts, d_model, d_model * 4), dtype=jnp.float32)
    # Identity check: `== None` on a jnp array is elementwise, not a
    # missing-value test.
    if group_sizes is None:
        group_sizes = random_tokens_per_expert(num_experts, 16, 4096, num_tokens, key)
    assert jnp.sum(group_sizes) <= num_tokens
    # The grouped-GEMM backend indexes with int32: skip shapes whose
    # element count would reach 2^31.
    if (A.size >= pow(2, 32 - 1)) or (W.size >= pow(2, 32 - 1)):
        print("Skipped: Matrix size too large")
        return 0, 0
    print(A.shape)
    print(W.shape)
    print(group_sizes.shape)
    # Warm-up (compile + steady state).
    for _ in range(10):
        out = ragged_dot(A, W, group_sizes)
    jax.block_until_ready(out)  # ensure warm-up is finished
    # Benchmark.
    iters = 10
    t0 = time.perf_counter()
    for _ in range(iters):
        out = ragged_dot(A, W, group_sizes)
    jax.block_until_ready(out)
    t1 = time.perf_counter()
    # Check output against the reference implementation.
    validity = "unchecked"
    valid_rate = -1
    if check_output:
        out_ref = ragged_dot_ref(A, W, group_sizes)
        is_close = jnp.isclose(out, out_ref, rtol=1e-3, atol=1e-3)
        count = jnp.sum(is_close)
        valid_rate = count / is_close.size
        if valid_rate == 1:
            validity = "Valid"
        else:
            validity = str("Do not match the reference : " + str(count) + " / " + str(is_close.size) + " (" + str(valid_rate) + ")")
    ms = (t1 - t0) * 1000.0 / iters
    print(f"devices={n} mean_latency={ms:.3f} ms outputs validity = {validity}")
    return ms, valid_rate
def run_perf_test(num_experts, num_tokens, d_model, group_sizes=None, check_output=False, output_file=None):
    """Repeat a benchmark configuration and report/append summary stats.

    Runs `run_single_perf_test` 10 times with the grouped-GEMM jit as the
    candidate and the regular jit as the reference, then prints per-run
    latencies and a mean/std/min/max summary. When `output_file` is given,
    appends one CSV row (with a validity column only if check_output).
    """
    runs = 10
    ret = [run_single_perf_test(num_experts, num_tokens, d_model, jit_ragged_dot_group_gemm, jit_ragged_dot, group_sizes, check_output) for _ in range(runs)]
    results, validity_rates = zip(*ret)
    mean_ms = stats.mean(results)
    # stdev needs at least two samples.
    std_ms = stats.stdev(results) if len(results) > 1 else 0.0
    min_ms = min(results)
    max_ms = max(results)
    validity_rate = sum(validity_rates) / runs
    if check_output and validity_rate != 1:
        # Fixed message typo: "unvalid" -> "invalid".
        print(f"\n Results invalid : {validity_rate}")
    print("\nPer-run mean latencies (ms):")
    print(", ".join(f"{v:.3f}" for v in results))
    print("\nSummary:")
    print(
        f"num_experts={num_experts} num_tokens={num_tokens} token_per_expert={int(num_tokens/num_experts)} d_model={d_model}, mean={mean_ms:.3f} ms, std={std_ms:.3f} ms, min={min_ms:.3f} ms, max={max_ms:.3f} ms, valid rate={validity_rate} \n\n"
    )
    if output_file:
        if check_output:
            output_file.write(f"{num_experts},{num_tokens},{int(num_tokens/num_experts)},{d_model},{mean_ms:.3f},{std_ms:.3f},{min_ms:.3f},{max_ms:.3f},{validity_rate}\n")
        else:
            output_file.write(f"{num_experts},{num_tokens},{int(num_tokens/num_experts)},{d_model},{mean_ms:.3f},{std_ms:.3f},{min_ms:.3f},{max_ms:.3f}\n")
def main(output_file):
    """Drive the benchmark sweep and write CSV rows to output_file."""
    # Sanity configuration with output validation enabled.
    run_perf_test(64, 64 * 128, 1024, None, True)
    output_file.write("num_experts,num_tokens,token_per_expert,d_model,mean_ms,std_ms,min_ms,max_ms,validity\n")
    # Sweep dimensions.
    expert_counts = [16, 32, 64, 128, 256]
    per_expert_counts = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    model_dims = [512, 1024, 2048, 4096]
    total_tokens = 64 * 128
    # Matrix size unsupported by the GroupGemm backend:
    #   A = (16384, 4096), W = (64, 4096, 16384), groups (64,)
    #   -> matrix size >= 2^(32-1) (int32 type)
    for experts, per_expert, dim in product(expert_counts, per_expert_counts, model_dims):
        uniform_groups = jnp.full(experts, per_expert)
        run_perf_test(experts, experts * per_expert, dim, uniform_groups, False, output_file)
        time.sleep(3)
    output_file.write("Random group sizes\n")
    # Randomly sized groups.
    for experts, dim in product(expert_counts, model_dims):
        run_perf_test(experts, total_tokens, dim, None, False, output_file)
        time.sleep(3)
if __name__ == "__main__":
    # One timestamped CSV per run, under ./logs.
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("logs", exist_ok=True)
    csv_path = os.path.join("logs", f"perf_{stamp}.csv")
    with open(csv_path, "w", encoding="utf-8") as f:
        main(f)
| static_cast<hipblaslt_ext::UserArguments *>( | ||
| host_userArgs.get()->opaque()); | ||
| for (int i = 0; i < group_count; i++) { | ||
| std::cout << "m[" << i << "] = " << h_userArgs[i].m << std::endl; |
There was a problem hiding this comment.
If you want to keep this function, I guess we need to change 'std::cout' to VLOG(x)
…r group-gemm). Implement the on-device update of the user args. Only initialize algorithm if new. Remove redundant information in the GroupedGemmConfig. General code clean-up.
Add condition on target for using group-gemm. Add tests to verify group-gemm use-conditions (flags, ...).
| __global__ void SetUserArgsKernelRaggedInNonContractingDim( | ||
| hipblaslt_ext::UserArguments* dest_args, void* a, void* b, void* c, void* d, | ||
| void* e, const void* group_sizes, size_t byte_width_elem_a, | ||
| size_t byte_width_elem_b, size_t byte_width_elem_c, |
There was a problem hiding this comment.
maybe uint8_t, or uint32_t
There was a problem hiding this comment.
Also we support mixed element type grouped matmuls?
There was a problem hiding this comment.
Maybe even pass it as log2 of said value so you can use shift instead of multiply.
There was a problem hiding this comment.
Or maybe an even a template parameter.
There was a problem hiding this comment.
The code has been updated. We now pass log2 for bitwidths and shift the address accordingly.
| template <typename T> | ||
| __global__ void SetUserArgsKernelRaggedInNonContractingDim( | ||
| hipblaslt_ext::UserArguments* dest_args, void* a, void* b, void* c, void* d, | ||
| void* e, const void* group_sizes, size_t byte_width_elem_a, |
There was a problem hiding this comment.
Use T*? I guess it helps compiler a bit.
(including change from multiply to shift to take into account datatype in pointer calculation):
Use a single-block grid-stride loop to accumulate group sizes when the number of groups > block_size, using shared memory
| arg.activationType = 0; | ||
|
|
||
| // Copy from shared memory to global memory | ||
| dest_args[idx] = arg; |
There was a problem hiding this comment.
You are not taking any advantage of coalescing here. In the code as it is, each thread copies its own arg. What should happen is that all threads in the block collectively perform a "memcpy" to global memory. I think the largest store per thread is 4 x int32. So thread 0 reads 16 bytes, thread 1 the next 16, and so on, storing them in global memory consecutively. And I guess that needs to be repeated 12 times in order to store everything.
There was a problem hiding this comment.
The code has been updated to use all the threads to write back the user-arg.
| arg.n = n; | ||
| arg.a = const_cast<void*>(static_cast<const void*>( | ||
| static_cast<const uint8_t*>(a) + | ||
| ((offset_group * stride_a) << log2_byte_width_elem_a))); |
There was a problem hiding this comment.
You might need to cast to intptr_t before shifting. I guess you can overfow uint32_t?
| offset += BLOCK_SIZE * BYTES_PER_THREAD) { | ||
| size_t bytes_to_copy = min(BYTES_PER_THREAD, total_bytes - offset); | ||
| memcpy(&dest_ptr[offset], &src_ptr[offset], bytes_to_copy); | ||
| } |
There was a problem hiding this comment.
You might need the __barrier(__CLK_LOCAL_MEM_FENCE); before next iteration.
There was a problem hiding this comment.
Tests have shown that local synchronisation is not sufficient between iterations, as we need to ensure that all data has been copied from shared memory to global memory before rewriting shared memory in the next iteration of the loop. The code has been updated accordingly.
| uint32_t batch_size = | ||
| min(BLOCK_SIZE, static_cast<uint32_t>(num_gemms - batch_start)); | ||
|
|
||
| __syncthreads(); |
There was a problem hiding this comment.
Is this __syncthreads needed for the ExclusiveSum above? If not, move this update below the next barrier to reduce the number of syncs.
There was a problem hiding this comment.
Actually, you might not need to move it at all. I think it should be protected by the barrier below.
There was a problem hiding this comment.
We need a barrier here anyway, because we reuse the shared memory used by the BlockScan for sharedUserArgs.
There was a problem hiding this comment.
I see. Then maybe we go back to not using dynamic shmem as temporary storage for the block scan. I suspect it is small anyway.
There was a problem hiding this comment.
The code has been updated to use two different shared-memory storages for the block scan and the args, and the synchronization has been removed.
| for (size_t offset = threadIdx.x * BYTES_PER_THREAD; offset < total_bytes; | ||
| offset += BLOCK_SIZE * BYTES_PER_THREAD) { | ||
| size_t bytes_to_copy = min(BYTES_PER_THREAD, total_bytes - offset); | ||
| memcpy(&dest_ptr[offset], &src_ptr[offset], bytes_to_copy); |
There was a problem hiding this comment.
Can you post the assembly for this function to see of this gets nicely unrolled?
There was a problem hiding this comment.
This code was not working correctly. It has been replaced with a "simple" copy (
xla/xla/stream_executor/rocm/rocm_helpers.cu.cc
Lines 170 to 196 in 2dcac72
HipBlasLT-Ext GroupedGemm support for RaggedDot operation implemented for matrices with and without batch dimension (note that the batch dimension must be the outer-most dimension of the matrix).
Three ragged modes supported:
- ragged in Non-Contracting dimension
- ragged in Contracting dimension
- ragged in Batch dimension
Test Plan
Unit tests are included in this PR. The tests verify:
CublasLtGroupedMatmulThunkand execute correctly on the device.Integration test calling the Jax/Lax RaggedDot operation had been successfully carried out (using Jax
rocm-jaxlib-v0.8.0.branch)The support has only been tested on MI300 - Rocm 7.1.
Test Result
Performance evaluation exposed a higher overhead for small inputs compared to the regular implementation (padded dot); however, GroupedGemm presents a significant speed-up for large inputs.
Here the performance trend on MI 300 Rocm 7.1 Jax rocm-jaxlib-v0.8.0.` (evaluated using commit c5986c3 with dtype=float32)
num_experts = 16
num_experts = 32