HipBlasLT-Ext GroupedGemm support for RaggedDot op#514

Draft
mfrancepillois wants to merge 30 commits into rocm-jaxlib-v0.8.0 from ci_support_groupedGemm

Conversation

@mfrancepillois

HipBlasLT-Ext GroupedGemm support for the RaggedDot operation, implemented for matrices with and without a batch dimension (note that the batch dimension must be the outermost dimension of the matrix).
Three ragged modes are supported:
- ragged in Non-Contracting dimension
- ragged in Contracting dimension
- ragged in Batch dimension
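For readers unfamiliar with the op, here is a minimal NumPy sketch of the ragged-in-non-contracting-dimension semantics (illustrative only; the function and variable names are not from this PR): `group_sizes` partitions the token dimension, and each contiguous group of rows is multiplied by its own expert matrix.

```python
import numpy as np

def ragged_dot_reference(lhs, rhs, group_sizes):
    """Reference for ragged dot, raggedness in the non-contracting dimension.

    lhs:  [num_tokens, d_model]
    rhs:  [num_groups, d_model, d_out] -- one matrix per group/expert
    group_sizes: [num_groups], entries summing to num_tokens
    """
    out = np.empty((lhs.shape[0], rhs.shape[2]), dtype=lhs.dtype)
    start = 0
    for g, size in enumerate(group_sizes):
        # Rows [start, start + size) all use the expert matrix rhs[g].
        out[start:start + size] = lhs[start:start + size] @ rhs[g]
        start += size
    return out

lhs = np.arange(12, dtype=np.float64).reshape(4, 3)
rhs = np.stack([np.eye(3), 2.0 * np.eye(3)])  # expert 0: identity, expert 1: x2
out = ragged_dot_reference(lhs, rhs, [1, 3])  # first row -> expert 0, rest -> expert 1
```

GroupedGemm executes these per-group matmuls as one batched call instead of padding every group to the maximum size.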

Test Plan

Unit tests are included in this PR. The tests verify:

  • the rewriting of RaggedDot into a custom-call
  • the translation of the custom-call into a CublasLtGroupedMatmulThunk, and that it executes correctly on the device

An integration test calling the JAX/LAX RaggedDot operation has been successfully carried out (using the JAX rocm-jaxlib-v0.8.0 branch).

The support has only been tested on MI300 with ROCm 7.1.

Test Result

Performance evaluation showed higher overhead for small inputs compared to the regular implementation (padded dot); however, GroupedGemm delivers a significant speed-up for large inputs.

Performance trend on MI300, ROCm 7.1, JAX rocm-jaxlib-v0.8.0 (evaluated using commit c5986c3 with dtype=float32); runtimes are in ms:

num_experts = 16

num_tokens token_per_expert d_model GroupedGemm runtime (ms) Baseline runtime (ms)
256 16 512 1.561 0.228
256 16 1024 1.663 0.446
256 16 2048 1.919 1.252
256 16 4096 3.225 4.528
512 32 512 2.253 0.268
512 32 1024 2.263 0.871
512 32 2048 2.054 2.295
512 32 4096 3.953 8.623
1024 64 512 2.225 0.43
1024 64 1024 2.3 1.488
1024 64 2048 2.43 4.578
1024 64 4096 5.217 17.287
2048 128 512 2.274 0.994
2048 128 1024 2.503 2.225
2048 128 2048 3.3 8.203
2048 128 4096 8.125 33.118
4096 256 512 1.691 1.43
4096 256 1024 2.174 4.487
4096 256 2048 4.947 17.342
4096 256 4096 14.305 64.274
8256 516 512 1.855 2.193
8256 516 1024 2.767 8.227
8256 516 2048 8.273 32.756
8256 516 4096 27.387 130.875
16384 1024 512 2.157 4.334
16384 1024 1024 4.046 16.611
16384 1024 2048 14.63 64.586
16384 1024 4096 51.259 255.86
32768 2048 512 2.786 8.331
32768 2048 1024 6.378 32.267
32768 2048 2048 27.502 128.509
32768 2048 4096 100.317 512.569
65536 4096 512 4.127 16.769
65536 4096 1024 11.853 64.811
65536 4096 2048 53.116 257.95
65536 4096 4096 198.419 1021.526

num_experts = 32

num_tokens token_per_expert d_model GroupedGemm runtime (ms) Baseline runtime (ms)
512 16 512 1.606 0.437
512 16 1024 1.733 1.221
512 16 2048 2.126 4.354
1024 32 512 1.619 0.691
1024 32 1024 1.782 2.22
1024 32 2048 2.408 8.499
2048 64 512 1.693 1.272
2048 64 1024 2.061 4.234
2048 64 2048 3.235 16.436
4096 128 512 1.736 2.204
4096 128 1024 2.24 8.591
4096 128 2048 4.829 32.981
8192 256 512 1.899 4.193
8192 256 1024 2.892 16.308
8192 256 2048 8.047 64.677
16512 516 512 2.246 8.353
16512 516 1024 4.148 32.731
16512 516 2048 14.975 129.56
32768 1024 512 2.873 16.575
32768 1024 1024 6.701 64.009
32768 1024 2048 27.477 256.705
65536 2048 512 4.336 33.04
65536 2048 1024 11.95 128.634
65536 2048 2048 53.259 514.378
131072 4096 512 7.45 66.187
131072 4096 1024 22.699 256.132
131072 4096 2048 104.371 1022.845

HipBlasLT-Ext GroupedGemm support for RaggedDot implemented for matrices with and without batch dimension (note that the batch dimension must be the outer-most dimension of the matrix).
Three ragged modes supported:
- ragged in Non-Contracting dimension
- ragged in Contracting dimension
- ragged in Batch dimension
Not implemented yet:
- epilogues
- proto for GroupedGemm config
Add GroupedGemm Algo validity check.
@mfrancepillois
Author

The performance of GroupedGemm was evaluated on MI300 (ROCm 7.1, JAX rocm-jaxlib-v0.8.0) using the following script:

import time
import statistics as stats
import jax
import jax.numpy as jnp
from jax import lax
from jax import random
from itertools import product
import os
from datetime import datetime

def random_tokens_per_expert(nb_experts, min_val, max_val, nb_tokens, key):
    # Convert inputs to Python ints if they are arrays
    nb_experts = int(nb_experts)
    min_val = int(min_val)
    max_val = int(max_val)
    nb_tokens = int(nb_tokens)

    # Check feasibility
    if not (nb_experts * min_val <= nb_tokens <= nb_experts * max_val):
        raise ValueError("Impossible to generate with given constraints")

    # Start with minimum tokens
    arr = jnp.full(nb_experts, min_val)
    remaining = nb_tokens - nb_experts * min_val

    # Function to distribute remaining tokens
    def body(i, val):
        arr, remaining, key = val
        key, subkey = jax.random.split(key)
        add = jax.random.randint(subkey, (), 0, jnp.minimum(max_val - min_val, remaining))
        arr = arr.at[i].set(arr[i] + add)
        remaining -= add
        return arr, remaining, key

    # Distribute remaining tokens
    arr, remaining, key = jax.lax.fori_loop(0, nb_experts, body, (arr, remaining, key))

    # Shuffle the array
    key, subkey = jax.random.split(key)
    perm = jax.random.permutation(subkey, nb_experts)
    arr = arr[perm]

    return arr

def apply_ragged(data, W, row_splits):
    return lax.ragged_dot(
        lhs=data,
        rhs=W,
        group_sizes=row_splits,
    )

jit_ragged_dot_group_gemm = jax.jit(apply_ragged,  compiler_options={
        "xla_gpu_experimental_use_ragged_dot_grouped_gemm": True,
        "xla_gpu_enable_cublaslt": True,
        "xla_gpu_experimental_use_ragged_dot_fusion": False
    })

jit_ragged_dot = jax.jit(apply_ragged,  compiler_options={
        "xla_gpu_experimental_use_ragged_dot_grouped_gemm": False,
        "xla_gpu_enable_cublaslt": True,
        "xla_gpu_experimental_use_ragged_dot_fusion": False
    })

def run_single_perf_test(num_experts, num_tokens, d_model, ragged_dot, ragged_dot_ref, group_sizes, check_output = False):
    n = jax.local_device_count()
    key = jax.random.PRNGKey(0)
    A = jax.random.uniform(key, (num_tokens, d_model), dtype=jnp.float32)
    W = jax.random.uniform(key, (num_experts, d_model,d_model*4), dtype=jnp.float32)

    if group_sizes is None:
        group_sizes = random_tokens_per_expert(num_experts, 16, 4096, num_tokens, key)
    assert jnp.sum(group_sizes) <= num_tokens

    max_group_size = jnp.max(group_sizes)

    if (A.size >= 2**31) or (W.size >= 2**31):
        # skip: element counts must fit in int32 (< 2^31)
        print("Skipped: Matrix size too large")
        return 0, 0

    print(A.shape)
    print(W.shape)
    print(group_sizes.shape)

    # warm-up (compile + steady state)
    for _ in range(10):
        out = ragged_dot(A, W, group_sizes)
    jax.block_until_ready(out)  # ensure warm-up is finished

    # benchmark
    iters = 10
    t0 = time.perf_counter()
    for _ in range(iters):
        out = ragged_dot(A, W, group_sizes)
        jax.block_until_ready(out)
    t1 = time.perf_counter()

    # check output
    validity = "unchecked"
    valid_rate = -1
    if check_output:
        out_ref = ragged_dot_ref(A, W, group_sizes)
        is_close = jnp.isclose(out, out_ref, rtol=1e-3, atol=1e-3)
        count = jnp.sum(is_close)
        valid_rate = count/is_close.size
        if (valid_rate == 1):
            validity = "Valid"
        else:
            validity = str("Do not match the reference : " + str(count) + " / " + str(is_close.size) + " (" + str(valid_rate) + ")")

    ms = (t1 - t0) * 1000.0 / iters
    print(f"devices={n}  mean_latency={ms:.3f} ms outputs validity = {validity}")
    
    return ms, valid_rate


def run_perf_test(num_experts, num_tokens, d_model, group_sizes = None, check_output = False, output_file = None):
    runs = 10
    ret = [run_single_perf_test(num_experts, num_tokens, d_model, jit_ragged_dot_group_gemm, jit_ragged_dot, group_sizes, check_output) for _ in range(runs)]
    
    results, validity_rates = zip(*ret)

    mean_ms = stats.mean(results)
    std_ms = stats.stdev(results) if len(results) > 1 else 0.0
    min_ms = min(results)
    max_ms = max(results)

    validity_rate = sum(validity_rates)/runs

    if check_output and validity_rate != 1 :
        print(f"\n Results invalid : {validity_rate}")

    print("\nPer-run mean latencies (ms):")
    print(", ".join(f"{v:.3f}" for v in results))
    print("\nSummary:")
    print(
        f"num_experts={num_experts} num_tokens={num_tokens} token_per_expert={int(num_tokens/num_experts)} d_model={d_model}, mean={mean_ms:.3f} ms, std={std_ms:.3f} ms, min={min_ms:.3f} ms, max={max_ms:.3f} ms, valid rate={validity_rate} \n\n"
    )
    if output_file:
        if check_output:
            output_file.write(f"{num_experts},{num_tokens},{int(num_tokens/num_experts)},{d_model},{mean_ms:.3f},{std_ms:.3f},{min_ms:.3f},{max_ms:.3f},{validity_rate}\n")
        else:
            output_file.write(f"{num_experts},{num_tokens},{int(num_tokens/num_experts)},{d_model},{mean_ms:.3f},{std_ms:.3f},{min_ms:.3f},{max_ms:.3f}\n")

def main(output_file):
    # check output validity
    run_perf_test(64, 64*128, 1024, None, True)

    output_file.write("num_experts,num_tokens,token_per_expert,d_model,mean_ms,std_ms,min_ms,max_ms,validity\n")

    # Evaluate perf
    num_expert = [16, 32, 64, 128, 256]
    token_per_expert = [16, 32, 64, 128, 256, 512, 1024, 2048, 4096]
    d_model = [512, 1024, 2048, 4096]
    tokens = 64 * 128

    # Matrix size unsupported by the GroupGemm backend
    # A = (16384, 4096)
    # w = (64, 4096, 16384)
    # (64,)
    # Matrix size >= 2^(32-1) (int32 type)

    for nb_experts, te, d in product(num_expert, token_per_expert, d_model):
        group_sizes = jnp.full(nb_experts, te)
        run_perf_test(nb_experts, nb_experts*te, d, group_sizes, False, output_file)
        time.sleep(3)

    output_file.write("Random group sizes\n")

    # random groups
    for nb_experts, d in product(num_expert, d_model):
        run_perf_test(nb_experts, tokens, d, None, False, output_file)
        time.sleep(3)


if __name__ == "__main__":
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join("logs", f"perf_{timestamp}.csv")

    os.makedirs("logs", exist_ok=True)

    with open(file_path, "w", encoding="utf-8") as f:
        main(f)

static_cast<hipblaslt_ext::UserArguments *>(host_userArgs.get()->opaque());
for (int i = 0; i < group_count; i++) {
  std::cout << "m[" << i << "] = " << h_userArgs[i].m << std::endl;

If you want to keep this function, I guess we need to change 'std::cout' to VLOG(x)

…r group-gemm).

Implement the on-device update of the user args.
Only initialize algorithm if new.
Remove redundant information in the GroupedGemmConfig.
General code clean-up.
Add condition on target for using group-gemm.
Add tests to verify group-gemm use-conditions (flags, ...).
__global__ void SetUserArgsKernelRaggedInNonContractingDim(
    hipblaslt_ext::UserArguments* dest_args, void* a, void* b, void* c, void* d,
    void* e, const void* group_sizes, size_t byte_width_elem_a,
    size_t byte_width_elem_b, size_t byte_width_elem_c,
Copy link

@draganmladjenovic Jan 22, 2026

maybe uint8_t, or uint32_t

Also, do we support mixed element type grouped matmuls?

Maybe even pass it as log2 of said value so you can use shift instead of multiply.

Or maybe even a template parameter.

Author

The code has been updated. We now pass log2 for bitwidths and shift the address accordingly.
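The multiply-to-shift equivalence this thread converges on can be sketched in a few lines of Python (illustrative only; the actual kernels are HIP C++ and these helper names are hypothetical): for power-of-two element byte widths, multiplying an element offset by the byte width equals shifting it left by log2 of that width.

```python
# For a power-of-two byte width bw, offset * bw == offset << log2(bw), so the
# kernel can receive log2_byte_width as a parameter and replace the multiply
# with a shift when computing byte offsets into a matrix.
def byte_offset_mul(elem_offset, byte_width):
    return elem_offset * byte_width

def byte_offset_shift(elem_offset, log2_byte_width):
    return elem_offset << log2_byte_width

# Byte widths for int8, fp16/bf16, fp32, fp64 and their log2 values.
for bw, log2_bw in [(1, 0), (2, 1), (4, 2), (8, 3)]:
    for off in (0, 1, 7, 123456):
        assert byte_offset_mul(off, bw) == byte_offset_shift(off, log2_bw)
```

As noted later in the thread, the offset should be widened (e.g. to intptr_t) before shifting so the byte offset cannot overflow a 32-bit integer.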

template <typename T>
__global__ void SetUserArgsKernelRaggedInNonContractingDim(
    hipblaslt_ext::UserArguments* dest_args, void* a, void* b, void* c, void* d,
    void* e, const void* group_sizes, size_t byte_width_elem_a,

Use T*? I guess it helps the compiler a bit.

(including change from multiply to shift to take the datatype into account in pointer calculations):
Use a single-block grid-stride loop to accumulate group sizes when the number of groups > block_size, using shared memory.
arg.activationType = 0;

// Copy from shared memory to global memory
dest_args[idx] = arg;

You are not taking any advantage of coalescing here. In the code as it is, each thread copies its own arg. What should happen is that all threads in the block collectively perform a "memcpy" to global memory. I think the largest store per thread is 4 x int32. So thread 0 reads 16 bytes, thread 1 the next 16, and so on, and they store them in global memory consecutively. And I guess that needs to be repeated 12 times in order to store everything.
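The access pattern described here can be emulated in plain Python (a sketch only; thread IDs and the 16-byte chunking are modeled, not real GPU code, and the constants are illustrative): each thread copies 16-byte chunks at a block-wide stride, so consecutive threads touch consecutive chunks and the stores coalesce.

```python
CHUNK = 16       # bytes per thread per iteration (e.g. one uint4 store)
BLOCK_SIZE = 4   # illustrative thread count

def coalesced_copy(src: bytes) -> bytearray:
    """Emulate a block-collective, grid-stride copy of src to dest."""
    dest = bytearray(len(src))
    n_chunks = -(-len(src) // CHUNK)  # ceil division
    for tid in range(BLOCK_SIZE):     # on a GPU these run concurrently
        # Grid-stride loop: thread tid handles chunks tid, tid+BLOCK_SIZE, ...
        for chunk in range(tid, n_chunks, BLOCK_SIZE):
            lo = chunk * CHUNK
            hi = min(lo + CHUNK, len(src))
            dest[lo:hi] = src[lo:hi]
    return dest

data = bytes(range(250)) * 2  # 500 bytes, deliberately not a chunk multiple
copied = coalesced_copy(data)
```

Every chunk is covered exactly once, including the ragged tail, which mirrors the final copy_shared_to_global helper quoted at the end of this thread.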

Author

The code has been updated to use all the threads to write back the user args.

arg.n = n;
arg.a = const_cast<void*>(static_cast<const void*>(
    static_cast<const uint8_t*>(a) +
    ((offset_group * stride_a) << log2_byte_width_elem_a)));

You might need to cast to intptr_t before shifting. I guess you can overflow uint32_t?

     offset += BLOCK_SIZE * BYTES_PER_THREAD) {
  size_t bytes_to_copy = min(BYTES_PER_THREAD, total_bytes - offset);
  memcpy(&dest_ptr[offset], &src_ptr[offset], bytes_to_copy);
}

You might need the __barrier(__CLK_LOCAL_MEM_FENCE); before next iteration.

Author

Tests have shown that local synchronisation is not sufficient between iterations, as we need to ensure that all data has been copied from shared memory to global memory before rewriting shared memory in the next iteration of the loop. The code has been updated accordingly.

uint32_t batch_size =
    min(BLOCK_SIZE, static_cast<uint32_t>(num_gemms - batch_start));

__syncthreads();

Is this __syncthreads needed for the ExclusiveSum above? If not, move this update below the next barrier to reduce the number of syncs.

Actually, you might not need to move it at all. I think it should be protected by the barrier below.

Author

We need a barrier here anyway, because we reuse the shared memory used by the BlockScan for sharedUserArgs.

I see. Then maybe we go back to not using dynamic shmem as temporary storage for the block scan. I suspect it is small anyway.

Author

The code has been updated to use two separate shared memory buffers for the block scan and the args, and the synchronization has been removed.

for (size_t offset = threadIdx.x * BYTES_PER_THREAD; offset < total_bytes;
     offset += BLOCK_SIZE * BYTES_PER_THREAD) {
  size_t bytes_to_copy = min(BYTES_PER_THREAD, total_bytes - offset);
  memcpy(&dest_ptr[offset], &src_ptr[offset], bytes_to_copy);

Can you post the assembly for this function to see if this gets nicely unrolled?

Author

@mfrancepillois Feb 20, 2026

This code was not working correctly. It has been replaced with a "simple" copy:

__device__ __forceinline__ void copy_shared_to_global(void* shared_src,
                                                      void* global_dest,
                                                      size_t total_bytes) {
  size_t count_uint4 = total_bytes / sizeof(uint4);
  // Vectorized copy using uint4 (16 bytes per iteration)
  if (count_uint4 > 0) {
    uint4* src_ptr = reinterpret_cast<uint4*>(shared_src);
    uint4* dest_ptr = reinterpret_cast<uint4*>(global_dest);
    for (size_t i = threadIdx.x; i < count_uint4; i += blockDim.x) {
      dest_ptr[i] = src_ptr[i];
    }
  }
  // Handle remaining bytes (if total_bytes is not a multiple of 16)
  size_t remaining_bytes = total_bytes % sizeof(uint4);
  if (remaining_bytes > 0) {
    uint8_t* src_ptr = reinterpret_cast<uint8_t*>(shared_src);
    uint8_t* dest_ptr = reinterpret_cast<uint8_t*>(global_dest);
    size_t offset = count_uint4 * sizeof(uint4);
    for (size_t i = threadIdx.x; i < remaining_bytes; i += blockDim.x) {
      dest_ptr[offset + i] = src_ptr[offset + i];
    }
  }
}
