From 96620adf77eb12eb31e9e0266633d1881d2657c4 Mon Sep 17 00:00:00 2001 From: zh476 Date: Fri, 12 Dec 2025 20:18:25 -0500 Subject: [PATCH 1/3] initial commit --- ece6775 | 1 + 1 file changed, 1 insertion(+) create mode 160000 ece6775 diff --git a/ece6775 b/ece6775 new file mode 160000 index 000000000..18e963df7 --- /dev/null +++ b/ece6775 @@ -0,0 +1 @@ +Subproject commit 18e963df74647205763c75cdecde87e3f87c927c From 03206cc3862adbbf42becff09dda8c062677449f Mon Sep 17 00:00:00 2001 From: zh476 Date: Fri, 12 Dec 2025 21:15:36 -0500 Subject: [PATCH 2/3] with readme --- ece6775 | 1 - ece6775/README.md | 43 + .../attention_moe/allo_attention_moe_alt.py | 945 +++++++++++++++ .../attention_moe/allo_attention_moe_base.py | 689 +++++++++++ .../attention_moe/allo_attention_moe_lib.py | 771 ++++++++++++ .../attention_moe/pytorch_attention_moe.py | 1060 +++++++++++++++++ ece6775/llm_config/check_llm_config.py | 124 ++ ece6775/llm_config/llm_config.py | 142 +++ ece6775/moe/allo_moe_alt.py | 446 +++++++ ece6775/moe/allo_moe_base.py | 489 ++++++++ ece6775/moe/allo_moe_lib.py | 437 +++++++ ece6775/moe/pytorch_moe.py | 336 ++++++ 12 files changed, 5482 insertions(+), 1 deletion(-) delete mode 160000 ece6775 create mode 100644 ece6775/README.md create mode 100644 ece6775/attention_moe/allo_attention_moe_alt.py create mode 100644 ece6775/attention_moe/allo_attention_moe_base.py create mode 100644 ece6775/attention_moe/allo_attention_moe_lib.py create mode 100644 ece6775/attention_moe/pytorch_attention_moe.py create mode 100644 ece6775/llm_config/check_llm_config.py create mode 100644 ece6775/llm_config/llm_config.py create mode 100644 ece6775/moe/allo_moe_alt.py create mode 100644 ece6775/moe/allo_moe_base.py create mode 100644 ece6775/moe/allo_moe_lib.py create mode 100644 ece6775/moe/pytorch_moe.py diff --git a/ece6775 b/ece6775 deleted file mode 160000 index 18e963df7..000000000 --- a/ece6775 +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 18e963df74647205763c75cdecde87e3f87c927c diff --git a/ece6775/README.md b/ece6775/README.md new file mode 100644 index 000000000..b2be2a2c0 --- /dev/null +++ b/ece6775/README.md @@ -0,0 +1,43 @@ +# ECE6775 - MoE and Attention+MoE Implementations + +Allo implementations of MoE (Mixture of Experts) and Attention+MoE layers for FPGA acceleration. + +## Structure + +- `moe/`: Standalone MoE layer implementations +- `attention_moe/`: Attention + MoE combined implementations +- `llm_config/`: Shared configuration for different model sizes + +## MoE Implementations + +Each directory contains multiple versions: + +- `*_base.py`: Manual implementation (no library functions) +- `*_lib.py`: Uses Allo library functions (nn.linear2d, nn.GeLU) +- `*_alt.py`: Optimized version with fused GeLU and row-level dataflow +- `pytorch_*.py`: PyTorch reference implementation + +## Usage + +Run any implementation directly: + +```bash +python moe/allo_moe_alt.py +python attention_moe/allo_attention_moe_alt.py +``` + +## Configuration + +Edit `llm_config/llm_config.py` to change model configuration. Default: `switch_base_8_scaled_1_8` + +Available configs: +- `switch_base_8*`: Google Switch-Base-8 variants +- `mixtral_8x7b*`: Mixtral-8x7B variants +- `deepseek*`: DeepSeek MoE variants +- `custom`: Small test configuration + +## Notes + +- Small numerical differences (~1e-4) vs PyTorch are normal due to accumulation order +- Set `MODE` in each file to control build target: `llvm`, `sw_emu`, `hw_emu`, `hw`, `csyn` + diff --git a/ece6775/attention_moe/allo_attention_moe_alt.py b/ece6775/attention_moe/allo_attention_moe_alt.py new file mode 100644 index 000000000..c30b616b6 --- /dev/null +++ b/ece6775/attention_moe/allo_attention_moe_alt.py @@ -0,0 +1,945 @@ +# Attention + MoE in Allo +# Q, K, V -> Attention -> MoE -> Output +# custom implementations (no library functions) for full control +# optimizations: fused GeLU, row-level dataflow, unrolled loops, array partitioning + +import numpy as np +import allo +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "csyn" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="normalize"): + Z[n, k] = E_exp[n, k] / S[n] + + return Z + +# softmax for attention scores [L, L] +def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": + Z: Ty[L, L] + E_exp: Ty[L, L] + M: Ty[L] = -1000000000000.0 + S: Ty[L] = 0.0 + + # find max per row + for i, j in dsl.grid(L, L, name="row_max"): + if X[i, j] > M[i]: + M[i] = X[i, j] + + # exp and sum + for i, j in dsl.grid(L, L, name="exp_sum"): + E_exp[i, j] = dsl.exp(X[i, j] - M[i]) + S[i] += E_exp[i, j] + + # normalize + for i, j in dsl.grid(L, L, name="normalize"): + Z[i, j] = E_exp[i, j] / S[i] + + return Z + +# scaled dot-product attention (multi-head) +def scaled_dot_product_attention[Ty, H, L, D]( + Q: "Ty[L, D]", + K: "Ty[L, D]", + V: "Ty[L, D]" +) -> "Ty[L, D]": + # scaled dot-product attention (multi-head) + + # matches pytorch: split into H heads, compute attention per head, merge + Z: Ty[L, D] = 0.0 + + # scale factor: 1/sqrt(head_dim) + scale: Ty = 1.0 / dsl.sqrt(float(D // H)) + + # process each head + for h in range(H, name="head_loop"): + # split Q, K, V for this head + Q_h: Ty[L, D // H] = 0.0 + K_h: Ty[L, D // H] = 0.0 + V_h: Ty[L, D // H] = 0.0 + + for i, j in dsl.grid(L, D // H, name="split_qkv"): + Q_h[i, j] = Q[i, h * (D // H) + j] + K_h[i, j] = K[i, h * (D // H) + j] + V_h[i, j] = V[i, h * (D // H) + j] + + # QK^T with accumulator to break read-after-write dependency + Y: Ty[L, L] = 0.0 + for i in range(L, name="qkt_i"): + for j in range(L, name="qkt_j"): + acc: Ty = 0.0 + for k in range(D // H, name="qkt_k"): + acc += Q_h[i, k] * K_h[j, k] # K_h[j, k] = K_h^T[k, j] + Y[i, j] = acc + + # scale + Y_scaled: Ty[L, L] = 0.0 + for i, j in dsl.grid(L, L, name="scale"): + Y_scaled[i, j] = Y[i, j] * scale + + # Apply softmax over last dimension + S: Ty[L, L] = 0.0 + E_exp: Ty[L, L] = 0.0 + M: Ty[L] = -1000000000000.0 + Sum: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="softmax_max"): + if Y_scaled[i, j] > M[i]: + M[i] = Y_scaled[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="softmax_exp"): + E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) + Sum[i] += E_exp[i, j] + + # Normalize + for i, j in dsl.grid(L, L, name="softmax_norm"): + S[i, j] = E_exp[i, j] / Sum[i] + + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] + # Loop order (i, k, j) for better memory access and pipelining: + # - Sequential access to V_h[k, j] as j changes (inner loop) + # - Better pipelining of reduction loop k + C_h: Ty[L, D // H] = 0.0 + for i in range(L, name="sv_i"): + for k in range(L, name="sv_k"): + for j in range(D // H, name="sv_j"): + C_h[i, j] += S[i, k] * V_h[k, j] + + # Merge back to Z + for i, j in dsl.grid(L, D // H, name="merge_heads"): + Z[i, h * (D // H) + j] = C_h[i, j] + + return Z + +# top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # skip e=0 + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + +# FFN expert with fused GeLU +# optimizations: fuse GeLU into FC1, row-level dataflow, unroll reduction loops +def expert[Ty, N, D_in, D_hidden, D_out]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]" +) -> "Ty[N, D_out]": + """ + A simple feed-forward expert network with fused GeLU and row-level dataflow. + + Optimizations applied: + 1. GeLU fused into FC1: Compute GeLU immediately after each FC1 element + to eliminate intermediate array storage and enable streaming + 2. Row-level processing: Outer loop over rows (N) enables dataflow between + FC1+GeLU and FC2 stages - while FC2 processes row n, FC1 can process row n+1 + 3. Separate reduction loops enable unroll pragmas + + Args: + x: Input tensor of shape [N, D_in] + fc1_weight: First linear layer weights of shape [D_hidden, D_in] + fc1_bias: First linear layer bias of shape [D_hidden] + fc2_weight: Second linear layer weights of shape [D_out, D_hidden] + fc2_bias: Second linear layer bias of shape [D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + output: Ty[N, D_out] = 0.0 + + # Row-level processing: outer loop over rows (tokens) + # This structure enables dataflow between FC1+GeLU and FC2 + for n in range(N, name="row_loop"): + # ===================================================================== + # Stage 1: FC1 + Fused GeLU (produces one row of hidden activations) + # ===================================================================== + # Intermediate buffer for this row's hidden activations (after GeLU) + hidden_row: Ty[D_hidden] = 0.0 + + # Compute FC1 output for each hidden dimension, fuse GeLU immediately + for h in range(D_hidden, name="fc1_gelu_loop"): + # FC1: dot product x[n,:] @ fc1_weight[h,:] + fc1_bias[h] + acc: Ty = 0.0 + for k in range(D_in, name="fc1_reduce"): + acc += x[n, k] * fc1_weight[h, k] + fc1_val: Ty = acc + fc1_bias[h] + + # Fused GeLU: apply immediately to avoid storing fc1_out array + # GeLU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) + # This matches PyTorch's GELU implementation (tanh approximation) + # sqrt(2/π) ≈ 0.7978845608028654 + x3: Ty = fc1_val * fc1_val * fc1_val + inner: Ty = 0.7978845608028654 * (fc1_val + 0.044715 * x3) + tanh_val: Ty = dsl.tanh(inner) + gelu_val: Ty = 0.5 * fc1_val * (1.0 + tanh_val) + hidden_row[h] = gelu_val + + # ===================================================================== + # Stage 2: FC2 (consumes hidden row, produces output row) + # ===================================================================== + for o in range(D_out, name="fc2_loop"): + # FC2: dot product hidden_row @ fc2_weight[o,:] + fc2_bias[o] + acc2: Ty = 0.0 + for k in range(D_hidden, name="fc2_reduce"): + acc2 += hidden_row[k] * fc2_weight[o, k] + output[n, o] = acc2 + fc2_bias[o] + + return output + +# MoE layer - routes tokens to experts +def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( + x: "Ty[N, D_in]", + # Gate weights + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[N, D_out]": + """ + Mixture of Experts layer for inference. + + Args: + x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) + gate_weight: Gate weight matrix of shape [E, D_in] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + # ========================================================================= + # Step 1: Compute gate logits (custom linear, no library) + # gate_logits[n, e] = x[n, :] @ gate_weight[e, :] + gate_bias[e] + # ========================================================================= + gate_logits: Ty[N, E] = 0.0 + for n, e in dsl.grid(N, E, name="gate_linear"): + acc: Ty = 0.0 + for k in range(D_in, name="gate_reduce"): + acc += x[n, k] * gate_weight[e, k] + gate_logits[n, e] = acc + gate_bias[e] + + # ========================================================================= + # Step 2: Select top-1 expert using top1_select function + # ========================================================================= + top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) + + # ========================================================================= + # Step 3: Get top-k logits and apply softmax + # ========================================================================= + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # ========================================================================= + # Step 4: Apply softmax to top-k logits using softmax_1d function + # ========================================================================= + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # ========================================================================= + # Step 5: Create sparse weight matrix from top-k weights + # ========================================================================= + gate_weights: Ty[N, E] = 0.0 + for n in range(N, name="sparse_gate"): + expert_idx = top1_indices_1d[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 + + # ========================================================================= + # Step 6: Process each expert: compute outputs for all tokens + # ========================================================================= + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # Process all tokens through this expert (uses optimized expert function) + expert_out = expert[Ty, N, D_in, D_hidden, D_out]( + x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b + ) # [N, D_out] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # ========================================================================= + # Step 7: Combine expert outputs using gate weights + # ========================================================================= + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output[n, d_out] += expert_outputs[e, n, d_out] * weight + + return output + +#---------------------------------------------------------------------------------- +# Attention + MoE Layer: Combined layer +# Data flow: Q, K, V -> Attention -> MoE -> Output +#---------------------------------------------------------------------------------- +def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( + Query: "Ty[B, L, D]", + Key: "Ty[B, L, D]", + Value: "Ty[B, L, D]", + # Gate weights + gate_weight: "Ty[E, D]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D, D_hidden]", + expert_fc2_biases: "Ty[E, D]" +) -> "Ty[B, L, D]": + """ + Combined Attention + MoE layer. + + Data flow: Q, K, V -> Attention -> MoE -> Output + + Args: + Query: Query tensor of shape [B, L, D] + Key: Key tensor of shape [B, L, D] + Value: Value tensor of shape [B, L, D] + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D] + Returns: + output: Output tensor of shape [B, L, D] + """ + # Output tensor + output: Ty[B, L, D] = 0.0 + + # Process each batch item + for b in range(B, name="batch_loop"): + # Step 1: Extract Q, K, V for this batch item -> [L, D] + Q_b: Ty[L, D] = 0.0 + K_b: Ty[L, D] = 0.0 + V_b: Ty[L, D] = 0.0 + + for l, d in dsl.grid(L, D, name="extract_qkv"): + Q_b[l, d] = Query[b, l, d] + K_b[l, d] = Key[b, l, d] + V_b[l, d] = Value[b, l, d] + + # Step 2: Apply custom scaled_dot_product_attention + # scaled_dot_product_attention[Ty, H, L, D](Q, K, V) -> [L, D] + attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] + + # Step 3: Apply MoE layer + # moe_layer expects [N, D_in], so we use N=L for single batch + moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( + attn_out, + gate_weight, + gate_bias, + expert_fc1_weights, + expert_fc1_biases, + expert_fc2_weights, + expert_fc2_biases + ) # [L, D] + + # Step 4: Store output for this batch item + for l, d in dsl.grid(L, D, name="store_output"): + output[b, l, d] = moe_out[l, d] + + return output + + +#================================================================================== +# Schedule optimization function with HLS pragmas +#================================================================================== +def optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim +): + """ + Create optimized schedules for Attention + MoE and compose them together. + + Optimization strategy: + 1. Pipeline innermost loops for throughput + 2. Unroll small loops for parallelism + 3. Partition arrays for parallel access + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k value (currently only k=1 is supported) + hidden_dim: Hidden dimension for experts + + Returns: + s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed + """ + Ty = float32 + B = batch_size + L = seq_len + D = embed_dim + H = num_heads + E = num_experts + K = k + D_hidden = hidden_dim + head_dim = D // H + + print("=" * 60) + print("Creating and optimizing Attention + MoE schedules...") + print("=" * 60) + + # ========================================================================= + # Step 1: Create and optimize schedule for top1_select + # ========================================================================= + print("\n[1] Creating schedule for top1_select...") + s_top1 = allo.customize(top1_select, instantiate=[Ty, L, E]) + # Pipeline the inner loop for finding argmax + # Note: Use "function_name:loop_var" format for dsl.grid loops + s_top1.pipeline("top1_select:e") + print(" - Created top1_select schedule with pipeline optimization") + + # ========================================================================= + # Step 2: Create and optimize schedule for softmax_1d (for MoE gate) + # ========================================================================= + print("\n[2] Creating schedule for softmax_1d...") + s_softmax_1d = allo.customize(softmax_1d, instantiate=[Ty, L, K]) + # Pipeline softmax loops using get_loops() to get loop handles + loops_softmax = s_softmax_1d.get_loops(s_softmax_1d.top_func_name) + s_softmax_1d.pipeline(loops_softmax["row_max"]["k"]) + s_softmax_1d.pipeline(loops_softmax["exp_sum"]["k"]) + s_softmax_1d.pipeline(loops_softmax["normalize"]["k"]) + print(" - Created softmax_1d schedule with pipeline optimization") + + # ========================================================================= + # Step 3: Create and optimize schedule for expert + # Optimizations: + # - Dataflow on row_loop (pipeline FC1+GeLU and FC2 stages) + # - Unroll on fc1_reduce and fc2_reduce loops + # - Array partitioning for parallel access + # ========================================================================= + print("\n[3] Creating schedule for expert (custom, no library functions)...") + s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D]) + + # Get loop handles for expert + expert_loops = s_expert.get_loops(s_expert.top_func_name) + print(f" - Available loops: {list(expert_loops.loops.keys())}") + + # Get nested loops inside row_loop + row_loop = expert_loops["row_loop"] + print(f" - row_loop sub-loops: {list(row_loop.loops.keys())}") + + # ------------------------------------------------------------------------- + # Note: Allo flattens nested loops and uses loop variable names as keys + # row_loop sub-loops: ['n', 'h', 'k', 'o'] where: + # n = row index (from row_loop) + # h = hidden dim index (from fc1_gelu_loop) + # k = reduction index (shared by fc1_reduce and fc2_reduce) + # o = output dim index (from fc2_loop) + # ------------------------------------------------------------------------- + + # ------------------------------------------------------------------------- + # Optimization 1: Pipeline the output dimension loops (h and o) + # This pipelines both the FC1+GeLU loop and FC2 loop + # ------------------------------------------------------------------------- + s_expert.pipeline(row_loop["h"]) # Pipeline FC1+GeLU hidden dim loop + s_expert.pipeline(row_loop["o"]) # Pipeline FC2 output dim loop + print(" - Applied pipeline to h (fc1_gelu) and o (fc2) loops") + + # ------------------------------------------------------------------------- + # Optimization 2: Unroll reduction loop for parallel MACs + # The k loop is shared between FC1 and FC2 reductions + # ------------------------------------------------------------------------- + # Determine unroll factor based on dimensions + unroll_factor = min(4, D, D_hidden) + + s_expert.unroll(row_loop["k"], factor=unroll_factor) + print(f" - Applied unroll to k (reduction loop) factor={unroll_factor}") + + # ------------------------------------------------------------------------- + # Note: Removed explicit array partitioning + # Let HLS infer partitioning from pipelining directives for better tool runtime + # Explicit partitioning can explode synthesis time and may not be optimal + # ------------------------------------------------------------------------- + + print(" - Created expert schedule with pipeline and unroll optimizations") + print(" - Note: Array partitioning inferred by HLS from pipelining") + + # ========================================================================= + # Step 4: Create and optimize schedule for moe_layer + # Optimizations: + # - Pipeline on gate_linear and combine_outputs + # - Unroll on gate_reduce for parallel access + # - Compose optimized expert schedule + # ========================================================================= + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) + + # Get loop handles for moe_layer + moe_loops = s_moe.get_loops(s_moe.top_func_name) + print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") + + # ------------------------------------------------------------------------- + # Optimize gate_linear: pipeline and unroll + # Note: Allo flattens loops - gate_linear sub-loops are likely ['n', 'e', 'k'] + # ------------------------------------------------------------------------- + gate_linear_loop = moe_loops["gate_linear"] + print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") + + # Pipeline the e loop (output dimension) and unroll k (reduction) + s_moe.pipeline(gate_linear_loop["e"]) + print(" - Applied pipeline to gate_linear:e") + + unroll_factor_gate = min(4, D) + s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) + print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") + + # ------------------------------------------------------------------------- + # Optimize other loops in moe_layer + # Note: For dsl.grid loops, sub-loops use variable names (n, e, k, d_out, etc.) + # ------------------------------------------------------------------------- + combine_loop = moe_loops["combine_outputs"] + print(f" - combine_outputs sub-loops: {list(combine_loop.loops.keys())}") + s_moe.pipeline(combine_loop["d_out"]) + + topk_loop = moe_loops["topk_logits"] + print(f" - topk_logits sub-loops: {list(topk_loop.loops.keys())}") + s_moe.pipeline(topk_loop["k"]) + + sparse_loop = moe_loops["sparse_gate"] + print(f" - sparse_gate sub-loops: {list(sparse_loop.loops.keys())}") + s_moe.pipeline(sparse_loop["n"]) + print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") + + # ------------------------------------------------------------------------- + # Compose sub-function schedules + # ------------------------------------------------------------------------- + s_moe.compose(s_top1) + s_moe.compose(s_softmax_1d) + s_moe.compose(s_expert) + print(" - Composed top1_select, softmax_1d, and expert schedules") + print(" - Created moe_layer schedule with pipeline, unroll optimizations") + + # ========================================================================= + # Step 5: Create and optimize schedule for scaled_dot_product_attention + # ========================================================================= + print("\n[5] Creating schedule for custom scaled_dot_product_attention...") + s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) + + # Get loop handles for attention + attn_loops = s_attn.get_loops(s_attn.top_func_name) + print(f" - Available top-level loops: {list(attn_loops.loops.keys())}") + + # Helper function to print loop hierarchy recursively + def print_loop_hierarchy(loops, indent=0): + for key, handle in loops.loops.items(): + print(" " * indent + f"- {key}") + # Check if handle has children (is a loop wrapper) + if hasattr(handle, "loops"): + print_loop_hierarchy(handle, indent + 2) + + print(" - Full loop hierarchy:") + print_loop_hierarchy(attn_loops, indent=4) + + # Attention has nested loops inside head_loop + # The structure is: head_loop -> h -> [split_qkv, qkt_matmul, scale, softmax_*, sv_matmul, merge_heads] + head_inner = attn_loops["head_loop"] + # print(f" - head_loop sub-loops: {list(head_inner.loops.keys())}") + + # Get loop handles for sv_matmul and qkt_matmul BEFORE applying global pipeline + # Global pipeline using string selectors might modify loop structure/metadata + # sv_loops = head_inner["sv_matmul"] + # qkt_loops = head_inner["qkt_matmul"] + # print(f" - sv_matmul sub-loops: {list(sv_loops.loops.keys())}") + # print(f" - qkt_matmul sub-loops: {list(qkt_loops.loops.keys())}") + + # Pipeline reduction loops (k) for matmuls - this is critical for II=1 + # The reordered loops (i, k, j) allow better pipelining of k reduction + s_attn.pipeline("scaled_dot_product_attention:k") # Pipeline k reduction loops (qkt_k, sv_k) + s_attn.pipeline("scaled_dot_product_attention:j") # Pipeline j loops for output dimension + s_attn.pipeline("scaled_dot_product_attention:i") # Pipeline i loops (sv_i) to overlap row processing + print(" - Applied pipeline to i (row), k (reduction), and j (output) loops") + print(" - Note: Loop reordering (i,k,j) enables better memory access and pipelining") + print(" - Pipelining sv_i allows overlapping processing of different output rows") + + # ------------------------------------------------------------------------- + # Apply unroll optimization to break loop-carried dependency + # Unroll k loop by factor of 4 to allow 4 multiplications in parallel + # ------------------------------------------------------------------------- + # print(" - Applying unroll optimization...") + # s_attn.unroll(sv_loops["k"], factor=4) + # print(" - Applied unroll to sv_matmul:k (factor=4)") + # s_attn.unroll(qkt_loops["k"], factor=4) + # print(" - Applied unroll to qkt_matmul:k (factor=4)") + + # ------------------------------------------------------------------------- + # Apply pipeline optimization to reduction loops + # ------------------------------------------------------------------------- + # print(" - Applying pipeline to reduction loops...") + # s_attn.pipeline(sv_loops["k"]) + # print(" - Applied pipeline to sv_matmul:k (reduction loop)") + # s_attn.pipeline(qkt_loops["k"]) + # print(" - Applied pipeline to qkt_matmul:k (reduction loop)") + + # ------------------------------------------------------------------------- + # Apply reorder optimization for better data locality + # For matrix multiplication C[i,j] += A[i,k] * B[k,j], reorder to i,k,j + # This improves memory access pattern for B (V_h in sv_matmul) + # ------------------------------------------------------------------------- + # print(" - Applying reorder optimization...") + # s_attn.reorder(sv_loops["k"], sv_loops["j"]) + # print(" - Applied reorder to sv_matmul: (i, j, k) -> (i, k, j)") + # s_attn.reorder(qkt_loops["k"], qkt_loops["j"]) + # print(" - Applied reorder to qkt_matmul: (i, j, k) -> (i, k, j)") + + # ------------------------------------------------------------------------- + # Apply buffer_at optimization to reduce memory access + # Creates on-chip buffer for intermediate results + # ------------------------------------------------------------------------- + # print(" - Applying buffer_at optimization...") + # s_attn.buffer_at(s_attn.C_h, axis=sv_loops["i"]) + # print(" - Applied buffer_at to C_h at sv_matmul:i") + # s_attn.buffer_at(s_attn.Y, axis=qkt_loops["i"]) + # print(" - Applied buffer_at to Y at qkt_matmul:i") + + print(" - Created scaled_dot_product_attention schedule with pipeline optimizations") + print(" - Matmul loops reordered to (i,k,j) for better memory access and II=1") + + # ========================================================================= + # Step 6: Create schedule for main attention_moe_layer function + # ========================================================================= + print("\n[6] Creating schedule for attention_moe_layer...") + s_attn_moe = allo.customize( + attention_moe_layer, + instantiate=[Ty, B, L, D, H, E, K, D_hidden] + ) + + # Get loop handles for attention_moe_layer + attn_moe_loops = s_attn_moe.get_loops(s_attn_moe.top_func_name) + print(f" - Available top-level loops: {list(attn_moe_loops.loops.keys())}") + + print(" - Full loop hierarchy for attention_moe_layer:") + print_loop_hierarchy(attn_moe_loops, indent=4) + + # Pipeline top-level loops using get_loops() + # Note: extract_qkv and store_output are inside batch_loop + batch_loop = attn_moe_loops["batch_loop"] + + # s_attn_moe.pipeline(batch_loop["extract_qkv"]["d"]) + # s_attn_moe.pipeline(batch_loop["store_output"]["d"]) + + # ========================================================================= + # Step 7: Compose all schedules together + # ========================================================================= + print("\n[7] Composing all schedules together...") + s_attn_moe.compose(s_attn) + s_attn_moe.compose(s_moe) + print(" - Composed scaled_dot_product_attention schedule") + print(" - Composed moe_layer schedule") + + print("\n" + "=" * 60) + print("Schedule composition complete with optimizations!") + print("=" * 60) + print("MoE Optimizations applied:") + print(" 1. Fused GeLU into FC1 (eliminates intermediate array)") + print(" 2. Row-level structure (FC1+GeLU -> FC2)") + print(" 3. Unrolled reduction loops (parallel MACs)") + print(f" - Expert k loop: unroll factor {min(4, D, D_hidden)}") + print(f" - Gate k loop: unroll factor {min(4, D)}") + print(" 4. Pipeline on h, o, e loops (output dimensions)") + print(" 5. Array partitioning inferred by HLS (not explicit)") + print("-" * 60) + print("Attention Optimizations (unchanged):") + print(" - Pipeline on innermost loops (j, k)") + print("=" * 60) + + return s_attn_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as torch_nn + from pytorch_attention_moe import AttentionMoE + + # get config + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads + num_experts = moe_config["num_experts"] # E + k = moe_config["k"] # Top-k MoE + hidden_dim = moe_config["hidden_dim"] # D_hidden + + # Attention-specific parameter: num_heads + # Choose num_heads such that embed_dim is divisible by num_heads + # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim) + if embed_dim >= 768: + num_heads = 12 # Standard for BERT-base + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + # Ensure embed_dim is divisible by num_heads + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + seed = 42 + + print("=" * 60) + print("Attention + MoE Allo Implementation Test") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + #---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + #---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch AttentionMoE layer + pytorch_model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim + ) + pytorch_model.eval() + + # Initialize with Xavier uniform + for param in pytorch_model.parameters(): + if param.dim() > 1: + torch_nn.init.xavier_uniform_(param) + else: + torch_nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q_pt = torch.randn(batch_size, seq_len, embed_dim) + K_pt = torch.randn(batch_size, seq_len, embed_dim) + V_pt = torch.randn(batch_size, seq_len, embed_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + #---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + #---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data # [num_experts, embed_dim] + gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) + expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) + expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) + expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) + + #---------------------------------------------------------------------------------- + # Convert to numpy arrays + #---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) + K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) + V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Q shape: {Q_np.shape}") + print(f"K shape: {K_np.shape}") + print(f"V shape: {V_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + #---------------------------------------------------------------------------------- + # Run Allo implementation + #---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with composition + allo_schedule = optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_attention_moe_alt_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod(Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np, + allo_output) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + #---------------------------------------------------------------------------------- + # Compare the Allo and PyTorch outputs + #---------------------------------------------------------------------------------- + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + #---------------------------------------------------------------------------------- + # Print sample outputs for comparison + #---------------------------------------------------------------------------------- + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/attention_moe/allo_attention_moe_base.py b/ece6775/attention_moe/allo_attention_moe_base.py new file mode 100644 index 000000000..a4d851576 --- /dev/null +++ b/ece6775/attention_moe/allo_attention_moe_base.py @@ -0,0 +1,689 @@ +# Attention + MoE in Allo (no library version) +# Q, K, V -> Attention -> MoE -> Output +# all manual implementation, no library functions + +import numpy as np +import allo +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + """ + Softmax over last dimension for [N, K] shape. + + Args: + X: Input tensor of shape [N, K] + Returns: + Z: Output tensor of shape [N, K] with softmax applied over dimension K + """ + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # Find max for each row (over dimension K) + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # Compute exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # Normalize + for n, k in dsl.grid(N, K, name="update"): + Z[n, k] = E_exp[n, k] / S[n] + + return Z + +#---------------------------------------------------------------------------------- +# Scaled Dot-Product Attention: Custom implementation matching PyTorch +#---------------------------------------------------------------------------------- +def scaled_dot_product_attention[Ty, H, L, D]( + Q: "Ty[L, D]", + K: "Ty[L, D]", + V: "Ty[L, D]" +) -> "Ty[L, D]": + """ + Scaled Dot-Product Attention (Multi-Head Attention). + + This implementation matches the PyTorch version: + - Input: Q, K, V of shape [L, D] + - Split into H heads, each with dimension D // H + - For each head: softmax(QK^T / sqrt(head_dim)) @ V + - Merge heads back to [L, D] + + Args: + Q: Query tensor of shape [L, D] + K: Key tensor of shape [L, D] + V: Value tensor of shape [L, D] + + Returns: + Z: Output tensor of shape [L, D] + """ + Z: Ty[L, D] = 0.0 + + # Compute scale factor: 1 / sqrt(head_dim) = 1 / sqrt(D // H) + scale: Ty = 1.0 / dsl.sqrt(float(D // H)) + + # Process each head + for h in range(H, name="head_loop"): + # Split Q, K, V for this head + Q_h: Ty[L, D // H] = 0.0 + K_h: Ty[L, D // H] = 0.0 + V_h: Ty[L, D // H] = 0.0 + + for i, j in dsl.grid(L, D // H, name="split_qkv"): + Q_h[i, j] = Q[i, h * (D // H) + j] + K_h[i, j] = K[i, h * (D // H) + j] + V_h[i, j] = V[i, h * (D // H) + j] + + # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L] + Y: Ty[L, L] = 0.0 + for i, j, k in dsl.grid(L, L, D // H, name="qkt_matmul"): + Y[i, j] += Q_h[i, k] * K_h[j, k] # K_h[j, k] is K_h^T[k, j] + + # Scale by 1/sqrt(head_dim) + Y_scaled: Ty[L, L] = 0.0 + for i, j in dsl.grid(L, L, name="scale"): + Y_scaled[i, j] = Y[i, j] * scale + + # Apply softmax over last dimension (inline implementation) + S: Ty[L, L] = 0.0 + E_exp: Ty[L, L] = 0.0 + M: Ty[L] = -1000000000000.0 + Sum: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="softmax_max"): + if Y_scaled[i, j] > M[i]: + M[i] = Y_scaled[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="softmax_exp"): + E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) + Sum[i] += E_exp[i, j] + + # Normalize + for i, j in dsl.grid(L, L, name="softmax_norm"): + S[i, j] = E_exp[i, j] / Sum[i] + + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] + C_h: Ty[L, D // H] = 0.0 + for i, j, k in dsl.grid(L, D // H, L, name="sv_matmul"): + C_h[i, j] += S[i, k] * V_h[k, j] + + # Merge back to Z + for i, j in dsl.grid(L, D // H, name="merge_heads"): + Z[i, h * (D // H) + j] = C_h[i, j] + + return Z + +#---------------------------------------------------------------------------------- +# Top1 selection: Top-k selection for k=1 (argmax) +#---------------------------------------------------------------------------------- +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + """ + Select top-1 expert (argmax) for each token. + + Args: + logits: Input logits of shape [N, E] + Returns: + indices: Top-1 expert indices of shape [N] + """ + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + # Initialize indices and max_val with first expert + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + # Find argmax for each token (search from index 1 onwards) + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # Skip e=0 (already initialized) + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + +#---------------------------------------------------------------------------------- +# MoE Layer: Main MoE layer - ALL MANUAL IMPLEMENTATION (NO LIBRARY FUNCTIONS) +# Based on moe_allo.py implementation +#---------------------------------------------------------------------------------- +def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( + x: "Ty[N, D_in]", + # Gate weights + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[N, D_out]": + """ + Mixture of Experts layer for inference. + ALL MANUAL IMPLEMENTATION - NO LIBRARY FUNCTIONS. + + Args: + x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) + gate_weight: Gate weight matrix of shape [E, D_in] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + # ========================================================================= + # Step 1: Compute gate logits using MANUAL linear layer + # gate_logits = x @ gate_weight^T + gate_bias + # ========================================================================= + gate_logits: Ty[N, E] = 0.0 + for i, j in dsl.grid(N, E, name="gate_linear"): + # Initialize with bias + gate_logits[i, j] = gate_bias[j] + # Matrix multiplication: x[i] @ gate_weight[j]^T + for k in range(D_in): + gate_logits[i, j] += x[i, k] * gate_weight[j, k] + + # ========================================================================= + # Step 2: Select top-1 expert (argmax) for each token + # ========================================================================= + top1_indices: int32[N] + max_logit: Ty[N] = -1000000000000.0 + + for n in range(N, name="top1_init"): + top1_indices[n] = 0 + max_logit[n] = gate_logits[n, 0] + + for n, e in dsl.grid(N, E, name="top1_argmax"): + if e > 0: + if gate_logits[n, e] > max_logit[n]: + max_logit[n] = gate_logits[n, e] + top1_indices[n] = e + + # ========================================================================= + # Step 3: Get top-k logits and apply softmax + # For k=1, softmax just returns 1.0, but we compute for consistency + # ========================================================================= + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx: int32 = top1_indices[n] if k == 0 else 0 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # Apply softmax to top-k logits (inline implementation) + top_k_weights: Ty[N, K] = 0.0 + softmax_max: Ty[N] = -1000000000000.0 + softmax_sum: Ty[N] = 0.0 + softmax_exp: Ty[N, K] = 0.0 + + for n, k in dsl.grid(N, K, name="softmax_max"): + if top_k_logits[n, k] > softmax_max[n]: + softmax_max[n] = top_k_logits[n, k] + + for n, k in dsl.grid(N, K, name="softmax_exp"): + softmax_exp[n, k] = dsl.exp(top_k_logits[n, k] - softmax_max[n]) + softmax_sum[n] += softmax_exp[n, k] + + for n, k in dsl.grid(N, K, name="softmax_norm"): + top_k_weights[n, k] = softmax_exp[n, k] / softmax_sum[n] + + # ========================================================================= + # Step 4: Create sparse weight matrix from top-k weights + # ========================================================================= + gate_weights: Ty[N, E] = 0.0 + for n in range(N, name="gate_weights"): + expert_idx: int32 = top1_indices[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1 + + # ========================================================================= + # Step 5: Process each expert - ALL MANUAL (NO LIBRARY FUNCTIONS) + # ========================================================================= + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # --------------------------------------------------------------------- + # Expert forward pass - MANUAL IMPLEMENTATION + # --------------------------------------------------------------------- + + # FC1: fc1_out = x @ fc1_weight^T + fc1_bias + fc1_out: Ty[N, D_hidden] = 0.0 + for i, j in dsl.grid(N, D_hidden, name="fc1_linear"): + fc1_out[i, j] = expert_fc1_b[j] + for k in range(D_in): + fc1_out[i, j] += x[i, k] * expert_fc1_w[j, k] + + # GELU activation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + gelu_out: Ty[N, D_hidden] = 0.0 + for i, j in dsl.grid(N, D_hidden, name="gelu"): + x_val: Ty = fc1_out[i, j] + x3: Ty = x_val * x_val * x_val + # sqrt(2/pi) ≈ 0.7978845608028654 + inner: Ty = 0.7978845608028654 * (x_val + 0.044715 * x3) + tanh_term: Ty = dsl.tanh(inner) + gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) + + # FC2: fc2_out = gelu_out @ fc2_weight^T + fc2_bias + expert_out: Ty[N, D_out] = 0.0 + for i, j in dsl.grid(N, D_out, name="fc2_linear"): + expert_out[i, j] = expert_fc2_b[j] + for k in range(D_hidden): + expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # ========================================================================= + # Step 6: Combine expert outputs using gate weights + # ========================================================================= + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output[n, d_out] += expert_outputs[e, n, d_out] * weight + + return output + +#---------------------------------------------------------------------------------- +# Attention + MoE Layer: Combined layer +# Data flow: Q, K, V -> Attention -> MoE -> Output +#---------------------------------------------------------------------------------- +def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( + Query: "Ty[B, L, D]", + Key: "Ty[B, L, D]", + Value: "Ty[B, L, D]", + # Gate weights + gate_weight: "Ty[E, D]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D, D_hidden]", + expert_fc2_biases: "Ty[E, D]" +) -> "Ty[B, L, D]": + """ + Combined Attention + MoE layer. + + Data flow: Q, K, V -> Attention -> MoE -> Output + + Args: + Query: Query tensor of shape [B, L, D] + Key: Key tensor of shape [B, L, D] + Value: Value tensor of shape [B, L, D] + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D] + Returns: + output: Output tensor of shape [B, L, D] + """ + # Output tensor + output: Ty[B, L, D] = 0.0 + + # Process each batch item + for b in range(B, name="batch_loop"): + # Step 1: Extract Q, K, V for this batch item -> [L, D] + Q_b: Ty[L, D] = 0.0 + K_b: Ty[L, D] = 0.0 + V_b: Ty[L, D] = 0.0 + + for l, d in dsl.grid(L, D, name="extract_qkv"): + Q_b[l, d] = Query[b, l, d] + K_b[l, d] = Key[b, l, d] + V_b[l, d] = Value[b, l, d] + + # Step 2: Apply custom scaled_dot_product_attention + attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] + + # Step 3: Apply MoE layer (all manual implementation) + moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( + attn_out, + gate_weight, + gate_bias, + expert_fc1_weights, + expert_fc1_biases, + expert_fc2_weights, + expert_fc2_biases + ) # [L, D] + + # Step 4: Store output for this batch item + for l, d in dsl.grid(L, D, name="store_output"): + output[b, l, d] = moe_out[l, d] + + return output + + +#================================================================================== +# Schedule optimization function - NO LIBRARY FUNCTIONS +#================================================================================== +def optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim +): + """ + Create optimized schedules for Attention + MoE and compose them together. + NO LIBRARY FUNCTIONS - all manual implementation. + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k value (currently only k=1 is supported) + hidden_dim: Hidden dimension for experts + + Returns: + s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed + """ + Ty = float32 + B = batch_size + L = seq_len + D = embed_dim + H = num_heads + E = num_experts + K = k + D_hidden = hidden_dim + + print("=" * 60) + print("Creating and optimizing Attention + MoE schedules (NO LIBRARY)...") + print("=" * 60) + + # Step 1: Create schedule for custom attention + print("\n[1] Creating schedule for custom scaled_dot_product_attention...") + s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) + print(" - Created scaled_dot_product_attention schedule") + + # Step 2: Create schedule for moe_layer (all manual) + print("\n[2] Creating schedule for moe_layer (manual implementation)...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) + print(" - Created moe_layer schedule (no library functions)") + + # Step 3: Create schedule for main attention_moe_layer function + print("\n[3] Creating schedule for attention_moe_layer...") + s_attn_moe = allo.customize( + attention_moe_layer, + instantiate=[Ty, B, L, D, H, E, K, D_hidden] + ) + + # Step 4: Compose all schedules together + print("\n[4] Composing all schedules together...") + s_attn_moe.compose(s_attn) + s_attn_moe.compose(s_moe) + print(" - Composed scaled_dot_product_attention schedule") + print(" - Composed moe_layer schedule") + + print("\n" + "=" * 60) + print("Schedule composition complete (NO LIBRARY FUNCTIONS)!") + print("=" * 60) + + return s_attn_moe + + +#================================================================================== +# Test function to compare Allo and PyTorch implementations +#================================================================================== + +if __name__ == "__main__": + import torch + import torch.nn as torch_nn + from pytorch_attention_moe import AttentionMoE + + # ============================================================================ + # Configuration parameters - use shared config from llm_config.py + # ============================================================================ + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads + num_experts = moe_config["num_experts"] # E + k = moe_config["k"] # Top-k MoE + hidden_dim = moe_config["hidden_dim"] # D_hidden + + # Attention-specific parameter: num_heads + if embed_dim >= 768: + num_heads = 12 + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + # Ensure embed_dim is divisible by num_heads + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + seed = 42 + + print("=" * 60) + print("Attention + MoE Allo Implementation Test (NO LIBRARY)") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + #---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + #---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch AttentionMoE layer + pytorch_model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim + ) + pytorch_model.eval() + + # Initialize with Xavier uniform + for param in pytorch_model.parameters(): + if param.dim() > 1: + torch_nn.init.xavier_uniform_(param) + else: + torch_nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q_pt = torch.randn(batch_size, seq_len, embed_dim) + K_pt = torch.randn(batch_size, seq_len, embed_dim) + V_pt = torch.randn(batch_size, seq_len, embed_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + #---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + #---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data + gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) + expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) + expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) + expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) + + #---------------------------------------------------------------------------------- + # Convert to numpy arrays + #---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) + K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) + V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Q shape: {Q_np.shape}") + print(f"K shape: {K_np.shape}") + print(f"V shape: {V_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + #---------------------------------------------------------------------------------- + # Run Allo implementation + #---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with composition + allo_schedule = optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_attention_moe_base_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod(Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np, + allo_output) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + #---------------------------------------------------------------------------------- + # Compare the Allo and PyTorch outputs + #---------------------------------------------------------------------------------- + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + #---------------------------------------------------------------------------------- + # Print sample outputs for comparison + #---------------------------------------------------------------------------------- + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/attention_moe/allo_attention_moe_lib.py b/ece6775/attention_moe/allo_attention_moe_lib.py new file mode 100644 index 000000000..d1833b89f --- /dev/null +++ b/ece6775/attention_moe/allo_attention_moe_lib.py @@ -0,0 +1,771 @@ +# Attention + MoE in Allo +# Q, K, V -> Attention -> MoE -> Output +# uses library functions (nn.linear2d, nn.GeLU) for expert + +import numpy as np +import allo +import allo.library.nn as allo_nn +from allo.library.nn import linear2d, GeLU # Direct import for type inference +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + """ + Softmax over last dimension for [N, K] shape. + + Args: + X: Input tensor of shape [N, K] + Returns: + Z: Output tensor of shape [N, K] with softmax applied over dimension K + """ + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # Find max for each row (over dimension K) + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # Compute exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # Normalize + for n, k in dsl.grid(N, K, name="update"): + # Add small epsilon to prevent division by zero + Z[n, k] = E_exp[n, k] / (S[n]) + + return Z + +#---------------------------------------------------------------------------------- +# Softmax 2D: Softmax over last dimension for [L, L] shape (for attention) +#---------------------------------------------------------------------------------- +def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": + """ + Softmax over last dimension for [L, L] shape. + Used for attention scores. + + Args: + X: Input tensor of shape [L, L] + Returns: + Z: Output tensor of shape [L, L] with softmax applied over last dimension + """ + Z: Ty[L, L] + E_exp: Ty[L, L] + M: Ty[L] = -1000000000000.0 + S: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="row_max"): + if X[i, j] > M[i]: + M[i] = X[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="exp_sum"): + E_exp[i, j] = dsl.exp(X[i, j] - M[i]) + S[i] += E_exp[i, j] + + # Normalize + for i, j in dsl.grid(L, L, name="update"): + Z[i, j] = E_exp[i, j] / S[i] + + return Z + +#---------------------------------------------------------------------------------- +# Scaled Dot-Product Attention: Custom implementation matching PyTorch +#---------------------------------------------------------------------------------- +def scaled_dot_product_attention[Ty, H, L, D]( + Q: "Ty[L, D]", + K: "Ty[L, D]", + V: "Ty[L, D]" +) -> "Ty[L, D]": + """ + Scaled Dot-Product Attention (Multi-Head Attention). + + This implementation matches the PyTorch version: + - Input: Q, K, V of shape [L, D] + - Split into H heads, each with dimension D // H + - For each head: softmax(QK^T / sqrt(head_dim)) @ V + - Merge heads back to [L, D] + - Uses linear2d for matrix multiplication (replacing systolic for HLS compatibility) + + Args: + Q: Query tensor of shape [L, D] + K: Key tensor of shape [L, D] + V: Value tensor of shape [L, D] + + Returns: + Z: Output tensor of shape [L, D] + """ + Z: Ty[L, D] = 0.0 + + # Compute scale factor: 1 / sqrt(head_dim) = 1 / sqrt(D // H) + # For D=8, H=2: head_dim=4, scale = 1/sqrt(4) = 0.5 + scale: Ty = 1.0 / dsl.sqrt(float(D // H)) + + # Process each head + for h in range(H, name="head_loop"): + # Split Q, K, V for this head + Q_h: Ty[L, D // H] = 0.0 + K_h: Ty[L, D // H] = 0.0 # Not transposed, will transpose for linear2d + V_h: Ty[L, D // H] = 0.0 + + for i, j in dsl.grid(L, D // H, name="split_qkv"): + Q_h[i, j] = Q[i, h * (D // H) + j] + K_h[i, j] = K[i, h * (D // H) + j] # Store K normally (not transposed) + V_h[i, j] = V[i, h * (D // H) + j] + + # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L] + # Using linear2d: Q_h @ K_h^T = linear2d(Q_h, K_h_transposed, 0) + # linear2d computes X @ W^T + b, so we need K_h^T as W + # K_h is [L, D//H], so K_h^T is [D//H, L], but linear2d expects W[N, K] = [L, D//H] + # So we need to transpose K_h to [D//H, L], then use it as W[L, D//H] (which is K_h^T) + # Actually: Q_h[L, D//H] @ K_h^T[D//H, L] = Q_h @ (K_h^T) + # linear2d(X[M, K], W[N, K], b) computes X @ W^T, so: + # - X = Q_h [L, D//H] -> M=L, K=D//H + # - W should be [L, D//H] to get W^T = [D//H, L] + # - So W = K_h^T, but we have K_h [L, D//H], so we need to transpose it + K_h_T: Ty[D // H, L] = 0.0 + for i, j in dsl.grid(L, D // H, name="transpose_K"): + K_h_T[j, i] = K_h[i, j] + + # Now use linear2d: Q_h[L, D//H] @ K_h_T^T[D//H, L] = Q_h @ K_h_T^T + # But linear2d expects W[N, K], so we need W[L, D//H] such that W^T = [D//H, L] + # Actually, we can use K_h directly as W[L, D//H], then W^T = [D//H, L] = K_h^T + # Wait, let me reconsider: linear2d(X[M, K], W[N, K], b) = X @ W^T + # For QK^T = Q_h[L, D//H] @ K_h^T[D//H, L]: + # - X = Q_h [L, D//H] -> M=L, K=D//H + # - We want result [L, L], so N=L + # - W should be [L, D//H] such that W^T = [D//H, L] = K_h^T + # - So W = K_h (which is [L, D//H]), then W^T = [D//H, L] = K_h^T ✓ + zero_bias: Ty[L] = 0.0 + Y = linear2d[Ty, Ty, Ty, L, L, D // H](Q_h, K_h, zero_bias) # Q_h @ K_h^T + + # Scale by 1/sqrt(head_dim) + Y_scaled: Ty[L, L] = 0.0 + for i, j in dsl.grid(L, L, name="scale"): + Y_scaled[i, j] = Y[i, j] * scale + + # Apply softmax over last dimension + S: Ty[L, L] = 0.0 + E_exp: Ty[L, L] = 0.0 + M: Ty[L] = -1000000000000.0 + Sum: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="softmax_max"): + if Y_scaled[i, j] > M[i]: + M[i] = Y_scaled[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="softmax_exp"): + E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) + Sum[i] += E_exp[i, j] + + # Normalize + for i, j in dsl.grid(L, L, name="softmax_norm"): + S[i, j] = E_exp[i, j] / Sum[i] + + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] + # Using linear2d: S @ V_h = linear2d(S, V_h_transposed, 0) + # linear2d(X[M, K], W[N, K], b) = X @ W^T + # For S[L, L] @ V_h[L, D//H]: + # - X = S [L, L] -> M=L, K=L + # - We want result [L, D//H], so N=D//H + # - W should be [D//H, L] such that W^T = [L, D//H] = V_h + # - So W = V_h^T (which is [D//H, L]), then W^T = [L, D//H] = V_h ✓ + V_h_T: Ty[D // H, L] = 0.0 + for i, j in dsl.grid(L, D // H, name="transpose_V"): + V_h_T[j, i] = V_h[i, j] + + zero_bias_v: Ty[D // H] = 0.0 + C_h = linear2d[Ty, Ty, Ty, L, D // H, L](S, V_h_T, zero_bias_v) # S @ V_h + + # Merge back to Z + for i, j in dsl.grid(L, D // H, name="merge_heads"): + Z[i, h * (D // H) + j] = C_h[i, j] + + return Z + +#---------------------------------------------------------------------------------- +# Top1 selection: Top-k selection for k=1 (argmax) +#---------------------------------------------------------------------------------- +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + """ + Select top-1 expert (argmax) for each token. + + Args: + logits: Input logits of shape [N, E] + Returns: + indices: Top-1 expert indices of shape [N] + """ + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + # Initialize indices and max_val with first expert + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + # Find argmax for each token (search from index 1 onwards) + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # Skip e=0 (already initialized) + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + +#---------------------------------------------------------------------------------- +# Expert: A simple feed-forward expert network +#---------------------------------------------------------------------------------- +def expert[Ty, N, D_in, D_hidden, D_out]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]" +) -> "Ty[N, D_out]": + """ + A simple feed-forward expert network. + + This implementation uses linear2d and GeLU from Allo library + for optimized computation. + + Args: + x: Input tensor of shape [N, D_in] + fc1_weight: First linear layer weights of shape [D_hidden, D_in] + fc1_bias: First linear layer bias of shape [D_hidden] + fc2_weight: Second linear layer weights of shape [D_out, D_hidden] + fc2_bias: Second linear layer bias of shape [D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + # Step 1: First linear layer using linear2d from library + fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias) + + # Step 2: GELU activation using GeLU from library + gelu_out = GeLU[Ty, N, D_hidden](fc1_out) + + # Step 3: Second linear layer using linear2d from library + fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias) + + return fc2_out + +#---------------------------------------------------------------------------------- +# MoE Layer: Main MoE layer that routes tokens to experts +#---------------------------------------------------------------------------------- +def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( + x: "Ty[N, D_in]", + # Gate weights + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[N, D_out]": + """ + Mixture of Experts layer for inference. + + Args: + x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) + gate_weight: Gate weight matrix of shape [E, D_in] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + # Step 1: Compute gate logits using linear2d + gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in](x, gate_weight, gate_bias) # [N, E] + + # Step 2: Select top-1 expert using top1_select function + top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) + + # Step 3: Get top-k logits and apply softmax + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # Step 4: Apply softmax to top-k logits using softmax_1d function + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # Step 5: Create sparse weight matrix from top-k weights + top_k_indices: int32[N, K] = 0 + gate_weights: Ty[N, E] = 0.0 + + for n in range(N, name="gate_weights"): + # Store top-k indices (for k=1) + top_k_indices[n, 0] = top1_indices_1d[n] + # Set gate weights from softmax output + expert_idx = top1_indices_1d[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 + + # Step 6: Process each expert: compute outputs for all tokens + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # Process all tokens through this expert using the expert function + expert_out = expert[Ty, N, D_in, D_hidden, D_out]( + x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b + ) # [N, D_out] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # Step 7: Combine expert outputs using gate weights + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output[n, d_out] += expert_outputs[e, n, d_out] * weight + + return output + +#---------------------------------------------------------------------------------- +# Attention + MoE Layer: Combined layer +# Data flow: Q, K, V -> Attention -> MoE -> Output +#---------------------------------------------------------------------------------- +def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( + Query: "Ty[B, L, D]", + Key: "Ty[B, L, D]", + Value: "Ty[B, L, D]", + # Gate weights + gate_weight: "Ty[E, D]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D, D_hidden]", + expert_fc2_biases: "Ty[E, D]" +) -> "Ty[B, L, D]": + """ + Combined Attention + MoE layer. + + Data flow: Q, K, V -> Attention -> MoE -> Output + + Args: + Query: Query tensor of shape [B, L, D] + Key: Key tensor of shape [B, L, D] + Value: Value tensor of shape [B, L, D] + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D] + Returns: + output: Output tensor of shape [B, L, D] + """ + # Output tensor + output: Ty[B, L, D] = 0.0 + + # Process each batch item + for b in range(B, name="batch_loop"): + # Step 1: Extract Q, K, V for this batch item -> [L, D] + Q_b: Ty[L, D] = 0.0 + K_b: Ty[L, D] = 0.0 + V_b: Ty[L, D] = 0.0 + + for l, d in dsl.grid(L, D, name="extract_qkv"): + Q_b[l, d] = Query[b, l, d] + K_b[l, d] = Key[b, l, d] + V_b[l, d] = Value[b, l, d] + + # Step 2: Apply custom scaled_dot_product_attention + # scaled_dot_product_attention[Ty, H, L, D](Q, K, V) -> [L, D] + attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] + + # Step 3: Apply MoE layer + # moe_layer expects [N, D_in], so we use N=L for single batch + moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( + attn_out, + gate_weight, + gate_bias, + expert_fc1_weights, + expert_fc1_biases, + expert_fc2_weights, + expert_fc2_biases + ) # [L, D] + + # Step 4: Store output for this batch item + for l, d in dsl.grid(L, D, name="store_output"): + output[b, l, d] = moe_out[l, d] + + return output + + +#================================================================================== +# Schedule optimization function +#================================================================================== +def optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim +): + """ + Create optimized schedules for Attention + MoE and compose them together. + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k value (currently only k=1 is supported) + hidden_dim: Hidden dimension for experts + + Returns: + s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed + """ + Ty = float32 + B = batch_size + L = seq_len + D = embed_dim + H = num_heads + E = num_experts + K = k + D_hidden = hidden_dim + + print("=" * 60) + print("Creating and optimizing Attention + MoE schedules...") + print("=" * 60) + + # Step 1: Create schedule for top1_select + print("\n[1] Creating schedule for top1_select...") + s_top1 = allo.customize(top1_select, instantiate=[Ty, L, E]) + print(" - Created top1_select schedule for [L, E]") + + # Step 2: Create schedule for softmax_1d (for MoE gate) + print("\n[2] Creating schedule for softmax_1d...") + s_softmax_1d = allo.customize(softmax_1d, instantiate=[Ty, L, K]) + print(" - Created softmax_1d schedule for [L, K]") + + # Step 3: Create schedule for expert + print("\n[3] Creating schedule for expert...") + + # Create schedules for library functions that expert uses + print(" - Creating schedules for library functions (linear2d, GeLU)...") + s_linear_fc1 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D_hidden, D]) + s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, L, D_hidden]) + s_linear_fc2 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D, D_hidden]) + print(" - Library function schedules created") + + # Create expert schedule + print(" - Creating expert schedule...") + s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D]) + + # Compose library function schedules + s_expert.compose(s_linear_fc1, id="expert_fc1") + s_expert.compose(s_gelu, id="expert_gelu") + s_expert.compose(s_linear_fc2, id="expert_fc2") + print(" - Created expert schedule") + print(" - Composed nn.linear2d and nn.GeLU schedules for expert") + + # Step 4: Create schedule for moe_layer + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) + + # Compose gate linear + s_gate_linear = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, E, D]) + s_moe.compose(s_gate_linear, id="gate") + s_moe.compose(s_top1) + s_moe.compose(s_softmax_1d) + s_moe.compose(s_expert) + print(" - Created moe_layer schedule") + + # Step 5: Create schedule for custom attention + print("\n[5] Creating schedule for custom scaled_dot_product_attention...") + s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) + print(" - Created scaled_dot_product_attention schedule") + + # Step 6: Create schedule for main attention_moe_layer function + print("\n[6] Creating schedule for attention_moe_layer...") + s_attn_moe = allo.customize( + attention_moe_layer, + instantiate=[Ty, B, L, D, H, E, K, D_hidden] + ) + + # Step 7: Compose all schedules together + print("\n[7] Composing all schedules together...") + s_attn_moe.compose(s_attn) + s_attn_moe.compose(s_moe) + print(" - Composed scaled_dot_product_attention schedule") + print(" - Composed moe_layer schedule") + + print("\n" + "=" * 60) + print("Schedule composition complete!") + print("=" * 60) + + return s_attn_moe + + +#================================================================================== +# Test function to compare Allo and PyTorch implementations +#================================================================================== + +if __name__ == "__main__": + import torch + import torch.nn as torch_nn + from pytorch_attention_moe import AttentionMoE + # ============================================================================ + # Configuration parameters - use shared config from llm_config.py + # ============================================================================ + # Get MoE configuration from shared config + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads + num_experts = moe_config["num_experts"] # E + k = moe_config["k"] # Top-k MoE + hidden_dim = moe_config["hidden_dim"] # D_hidden + + # Attention-specific parameter: num_heads + # Choose num_heads such that embed_dim is divisible by num_heads + # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim) + if embed_dim >= 768: + num_heads = 12 # Standard for BERT-base + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + # Ensure embed_dim is divisible by num_heads + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + seed = 42 + + print("=" * 60) + print("Attention + MoE Allo Implementation Test") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + #---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + #---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch AttentionMoE layer + pytorch_model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim + ) + pytorch_model.eval() + + # Initialize with Xavier uniform + for param in pytorch_model.parameters(): + if param.dim() > 1: + torch_nn.init.xavier_uniform_(param) + else: + torch_nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q_pt = torch.randn(batch_size, seq_len, embed_dim) + K_pt = torch.randn(batch_size, seq_len, embed_dim) + V_pt = torch.randn(batch_size, seq_len, embed_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + #---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + #---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data # [num_experts, embed_dim] + gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) + expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) + expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) + expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) + + #---------------------------------------------------------------------------------- + # Convert to numpy arrays + #---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) + K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) + V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Q shape: {Q_np.shape}") + print(f"K shape: {K_np.shape}") + print(f"V shape: {V_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + #---------------------------------------------------------------------------------- + # Run Allo implementation + #---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with composition + allo_schedule = optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_attention_moe_lib_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod(Q_np, K_np, V_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np, + allo_output) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + #---------------------------------------------------------------------------------- + # Compare the Allo and PyTorch outputs + #---------------------------------------------------------------------------------- + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + #---------------------------------------------------------------------------------- + # Print sample outputs for comparison + #---------------------------------------------------------------------------------- + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/attention_moe/pytorch_attention_moe.py b/ece6775/attention_moe/pytorch_attention_moe.py new file mode 100644 index 000000000..f942dbcc6 --- /dev/null +++ b/ece6775/attention_moe/pytorch_attention_moe.py @@ -0,0 +1,1060 @@ +""" +Attention + Mixture of Experts (MoE) Implementation for Inference Only + +This script implements: +1. Scaled Dot-Product Attention (Multi-Head Attention) +2. MoE layer +3. Combined Attention -> MoE pipeline + +Uses random inputs and weights for demonstration. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional +import math + + +class ScaledDotProductAttention(nn.Module): + """ + Scaled Dot-Product Attention (Multi-Head Attention). + + This implementation matches the Allo library version: + - Input: Q, K, V of shape [L, D] + - Split into H heads, each with dimension D // H + - For each head: softmax(QK^T / sqrt(D // H)) @ V + - Merge heads back to [L, D] + + Args: + num_heads: Number of attention heads (H) + head_dim: Dimension per head (D // H) + """ + + def __init__(self, num_heads: int, embed_dim: int): + super().__init__() + self.num_heads = num_heads + self.embed_dim = embed_dim + self.head_dim = embed_dim // num_heads + assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + + # Scaling factor: 1 / sqrt(head_dim) + self.scale = 1.0 / math.sqrt(self.head_dim) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False) -> torch.Tensor: + """ + Forward pass through scaled dot-product attention. + + Args: + Q: Query tensor of shape [L, D] + K: Key tensor of shape [L, D] + V: Value tensor of shape [L, D] + verbose: Whether to print debug information + + Returns: + Output tensor of shape [L, D] + """ + L, D = Q.shape + H = self.num_heads + head_dim = D // H + + # Initialize output + Z = torch.zeros(L, D, device=Q.device, dtype=Q.dtype) + + # Process each head + for h in range(H): + # Split Q, K, V for this head + start_idx = h * head_dim + end_idx = (h + 1) * head_dim + + Q_h = Q[:, start_idx:end_idx] # [L, head_dim] + K_h = K[:, start_idx:end_idx] # [L, head_dim] + V_h = V[:, start_idx:end_idx] # [L, head_dim] + + # QK^T = [L, head_dim] @ [head_dim, L] = [L, L] + # Note: K_h.T is the transpose + Y = torch.matmul(Q_h, K_h.T) # [L, L] + + # Scale by 1/sqrt(head_dim) + Y = Y * self.scale + + # Softmax over last dimension + S = F.softmax(Y, dim=-1) # [L, L] + + # YV = [L, L] @ [L, head_dim] = [L, head_dim] + C_h = torch.matmul(S, V_h) # [L, head_dim] + + # Merge back to Z + Z[:, start_idx:end_idx] = C_h + + if verbose: + print(f"\n[Attention] L={L}, D={D}, H={H}, head_dim={head_dim}") + print(f"[Attention] scale={self.scale:.6f}") + print(f"[Attention] Output shape: {Z.shape}") + + return Z + + +class TopKGate(nn.Module): + """ + Gate module to select top k experts for routing. + + Args: + input_dim: Input feature dimension + num_experts: Total number of experts + k: Number of experts to select per token + """ + + def __init__(self, input_dim: int, num_experts: int, k: int = 1): + super().__init__() + self.k = k + self.num_experts = num_experts + # Linear layer to compute expert logits + self.gate_linear = nn.Linear(input_dim, num_experts, bias=False) + + def forward(self, x: torch.Tensor, verbose: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the gate. + + Args: + x: Input tensor of shape [batch_size * seq_len, input_dim] + verbose: Whether to print routing information + + Returns: + full_weights: Sparse weight matrix [batch_size * seq_len, num_experts] + top_k_indices: Indices of selected experts [batch_size * seq_len, k] + """ + # Compute logits for all experts + logits = self.gate_linear(x) # [N, num_experts] + + # Select top-k experts + top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1) + + # Apply softmax to top-k logits for normalized weights + top_k_weights = F.softmax(top_k_logits, dim=-1) # [N, k] + + # Create sparse weight matrix (zeros for non-selected experts) + full_weights = torch.zeros_like(logits) + full_weights.scatter_(1, top_k_indices, top_k_weights) + + # Print routing information if verbose + if verbose: + num_tokens = x.shape[0] + + # Count tokens per expert + expert_counts = {} + for expert_idx in range(self.num_experts): + count = (top_k_indices == expert_idx).sum().item() + if count > 0: + expert_counts[expert_idx] = count + + # Verify that each token selects exactly k experts + tokens_per_expert_count = {} + for i in range(num_tokens): + num_experts_selected = (top_k_weights[i] > 1e-6).sum().item() + tokens_per_expert_count[num_experts_selected] = tokens_per_expert_count.get(num_experts_selected, 0) + 1 + + print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}") + print(f"[Gate Routing] Expert distribution:") + total_selections = num_tokens * self.k # Total expert-token pairs + for expert_idx, count in sorted(expert_counts.items()): + percentage = (count / total_selections) * 100 + print(f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)") + + # Verify top-k: each token selects exactly k experts + if tokens_per_expert_count.get(self.k, 0) == num_tokens: + print(f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)") + else: + print(f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}") + + return full_weights, top_k_indices + + +class Expert(nn.Module): + """ + A simple feed-forward expert network. + + Args: + input_dim: Input feature dimension + hidden_dim: Hidden layer dimension + output_dim: Output feature dimension + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + # Use GELU activation (modern choice) + self.activation = nn.GELU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the expert.""" + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class MoELayer(nn.Module): + """ + Mixture of Experts layer for inference. + + Args: + input_dim: Input feature dimension + output_dim: Output feature dimension + num_experts: Total number of experts + k: Number of experts to activate per token + expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim) + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_experts: int, + k: int = 1, + expert_hidden_dim: Optional[int] = None + ): + super().__init__() + self.num_experts = num_experts + self.k = k + self.output_dim = output_dim + + if expert_hidden_dim is None: + expert_hidden_dim = input_dim * 4 # Common practice + + # Initialize gate and experts + self.gate = TopKGate(input_dim, num_experts, k) + self.experts = nn.ModuleList([ + Expert(input_dim, expert_hidden_dim, output_dim) + for _ in range(num_experts) + ]) + + def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor: + """ + Forward pass through MoE layer. + + Args: + x: Input tensor of shape [batch_size, seq_len, input_dim] + verbose: Whether to print verification information for top-1 MoE + + Returns: + Output tensor of shape [batch_size, seq_len, output_dim] + """ + + # Store original shape + original_shape = x.shape + batch_size, seq_len = original_shape[0], original_shape[1] + + # Flatten batch and sequence dimensions + x_flat = x.view(-1, original_shape[-1]) # [N, input_dim], N = batch * seq_len + num_tokens = x_flat.shape[0] + + # Get gating weights and expert indices + gate_weights, top_k_indices = self.gate(x_flat, verbose=verbose) # [N, num_experts], [N, k] + + # Initialize output tensor + output = torch.zeros( + x_flat.shape[0], + self.output_dim, + device=x.device, + dtype=x.dtype + ) + + # Track expert usage statistics + expert_usage_stats = {} + + # Process each expert + for expert_idx in range(self.num_experts): + # Find tokens assigned to this expert + # Create mask: tokens that have this expert in their top-k + expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [N] + + num_tokens_for_expert = expert_mask.sum().item() + + if not expert_mask.any(): + expert_usage_stats[expert_idx] = 0 + continue # No tokens for this expert + + # Track usage + expert_usage_stats[expert_idx] = num_tokens_for_expert + + # Get inputs for this expert + expert_inputs = x_flat[expert_mask] # [num_tokens, input_dim] + + # Process through expert + expert_outputs = self.experts[expert_idx](expert_inputs) # [num_tokens, output_dim] + + # Get corresponding weights for these tokens + # For each token, find the weight for this expert (could be in any of k positions) + token_indices = torch.where(expert_mask)[0] # Indices in flattened tensor + expert_weights = gate_weights[token_indices, expert_idx].unsqueeze(1) # [num_tokens, 1] + + # Weight and accumulate outputs + weighted_outputs = expert_outputs * expert_weights + output[token_indices] += weighted_outputs + + # Verify top-k behavior + if verbose: + active_experts = [idx for idx, count in expert_usage_stats.items() if count > 0] + total_usage = sum(expert_usage_stats.values()) + + # Verify top-k: each token selects exactly k experts + if top_k_indices.shape[1] != self.k: + print(f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}") + + # Verify all tokens are processed (for k>1, total_usage may be > num_tokens) + expected_total_usage = num_tokens * self.k + if total_usage != expected_total_usage: + print(f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}") + + # Check if weights are normalized (should sum to 1.0 per token) + sample_weights = gate_weights.sum(dim=1) + weights_normalized = torch.allclose(sample_weights, torch.ones_like(sample_weights), atol=1e-5) + + # Final summary for top-k + print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===") + print(f" ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})") + print(f" ✓ All {num_tokens} tokens processed") + print(f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})") + print(f" ✓ Active experts: {len(active_experts)}/{self.num_experts}") + print(f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}") + if weights_normalized: + print(f" ✓ Gate weights normalized (sum to 1.0 per token)") + else: + print(f" ✗ Gate weights NOT normalized correctly") + print(f" === Top-{self.k} MoE is working correctly! ===") + + # Reshape back to original shape + output = output.view(batch_size, seq_len, self.output_dim) + + return output + + +class AttentionMoE(nn.Module): + """ + Combined Attention + MoE layer. + + Data flow: Input -> Attention -> MoE -> Output + + This follows the Transformer architecture pattern where: + 1. Attention computes self-attention on the input + 2. MoE replaces the standard FFN (Feed-Forward Network) + + Args: + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of MoE experts (E) + k: Top-k experts to activate per token + expert_hidden_dim: Hidden dimension for experts + """ + + def __init__( + self, + seq_len: int, + embed_dim: int, + num_heads: int, + num_experts: int, + k: int = 1, + expert_hidden_dim: Optional[int] = None + ): + super().__init__() + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_heads = num_heads + + # Attention layer + self.attention = ScaledDotProductAttention(num_heads, embed_dim) + + # MoE layer (input_dim = output_dim = embed_dim) + self.moe = MoELayer( + input_dim=embed_dim, + output_dim=embed_dim, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim + ) + + def forward( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + verbose: bool = False + ) -> torch.Tensor: + """ + Forward pass: Attention -> MoE + + Args: + Q: Query tensor of shape [B, L, D] + K: Key tensor of shape [B, L, D] + V: Value tensor of shape [B, L, D] + verbose: Whether to print debug information + + Returns: + Output tensor of shape [B, L, D] + """ + batch_size = Q.shape[0] + L = Q.shape[1] + D = Q.shape[2] + + # Process each batch item through attention + # Attention expects [L, D], so we process batch by batch + attn_outputs = [] + for b in range(batch_size): + attn_out = self.attention(Q[b], K[b], V[b], verbose=verbose and b == 0) + attn_outputs.append(attn_out) + + # Stack back to [B, L, D] + attn_output = torch.stack(attn_outputs, dim=0) + + if verbose: + print(f"\n[AttentionMoE] Attention output shape: {attn_output.shape}") + + # Pass through MoE + # MoE expects [B, L, D] and returns [B, L, D] + moe_output = self.moe(attn_output, verbose=verbose) + + if verbose: + print(f"[AttentionMoE] MoE output shape: {moe_output.shape}") + + return moe_output + + +def run_attention_moe_inference( + batch_size: int = 4, + seq_len: int = 4, + embed_dim: int = 64, + num_heads: int = 4, + num_experts: int = 2, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + seed: int = 24, + verbose: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, AttentionMoE]: + """ + Run Attention + MoE inference with fixed seed for reproducible inputs and weights. + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k experts to activate per token + expert_hidden_dim: Hidden dimension for experts + seed: Random seed for reproducibility + verbose: Whether to print verification information + + Returns: + output: Output tensor from AttentionMoE [B, L, D] + Q, K, V: Input tensors used for inference + model: AttentionMoE model instance + """ + # Set random seed for reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create model + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim + ) + model.eval() + + # Initialize with random weights (Xavier uniform for better stability) + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q = torch.randn(batch_size, seq_len, embed_dim) + K = torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + # Run inference with verbose output for verification + with torch.no_grad(): + output = model(Q, K, V, verbose=verbose) + + return output, Q, K, V, model + + +def verify_gate_layer( + num_tokens: int = 128, + input_dim: int = 96, + num_experts: int = 2, + k: int = 1, + seed: int = 42 +): + """ + Verify that the Gate Layer is effectively routing tokens to different experts. + + This function tests: + 1. Different inputs produce different routing decisions + 2. Gate weights are properly normalized (sum to 1.0) + 3. Expert distribution is reasonable (not all tokens to one expert) + 4. Routing is deterministic (same input -> same routing) + """ + print("=" * 70) + print("Gate Layer Verification Test") + print("=" * 70) + print(f"Configuration: num_tokens={num_tokens}, input_dim={input_dim}") + print(f" num_experts={num_experts}, k={k}") + print("=" * 70) + + torch.manual_seed(seed) + + # Create gate + gate = TopKGate(input_dim, num_experts, k) + nn.init.xavier_uniform_(gate.gate_linear.weight) + gate.eval() + + # Test 1: Create random inputs and check routing + print("\n[Test 1] Random Input Routing") + print("-" * 50) + x = torch.randn(num_tokens, input_dim) + + with torch.no_grad(): + full_weights, top_k_indices = gate(x, verbose=False) + + # Count tokens per expert + expert_counts = {} + for e in range(num_experts): + count = (top_k_indices == e).sum().item() + expert_counts[e] = count + percentage = (count / (num_tokens * k)) * 100 + print(f" Expert {e}: {count} tokens ({percentage:.1f}%)") + + # Check if routing is balanced + min_count = min(expert_counts.values()) + max_count = max(expert_counts.values()) + imbalance_ratio = max_count / max(min_count, 1) + + if imbalance_ratio < 10: + print(f" ✓ Routing is reasonably balanced (imbalance ratio: {imbalance_ratio:.2f})") + else: + print(f" ⚠ Routing is imbalanced (imbalance ratio: {imbalance_ratio:.2f})") + + # Test 2: Verify weights are normalized + print("\n[Test 2] Weight Normalization") + print("-" * 50) + weight_sums = full_weights.sum(dim=1) + all_normalized = torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5) + + if all_normalized: + print(f" ✓ All gate weights sum to 1.0 (properly normalized)") + else: + print(f" ✗ Gate weights NOT properly normalized!") + print(f" Weight sums range: [{weight_sums.min():.6f}, {weight_sums.max():.6f}]") + + # Test 3: Different inputs should (mostly) produce different routings + print("\n[Test 3] Input-Dependent Routing") + print("-" * 50) + + # Create two very different inputs + x1 = torch.randn(10, input_dim) * 10 # Large magnitude + x2 = torch.randn(10, input_dim) * 0.1 # Small magnitude + + with torch.no_grad(): + _, indices1 = gate(x1, verbose=False) + _, indices2 = gate(x2, verbose=False) + + # Check if routings are different + same_routing = (indices1 == indices2).all().item() + if not same_routing: + print(f" ✓ Different inputs produce different routing decisions") + else: + print(f" ⚠ Different inputs produced same routing (may indicate issue)") + + # Test 4: Determinism - same input should always produce same routing + print("\n[Test 4] Routing Determinism") + print("-" * 50) + + x_test = torch.randn(20, input_dim) + with torch.no_grad(): + _, indices_run1 = gate(x_test, verbose=False) + _, indices_run2 = gate(x_test, verbose=False) + + is_deterministic = (indices_run1 == indices_run2).all().item() + if is_deterministic: + print(f" ✓ Routing is deterministic (same input -> same expert)") + else: + print(f" ✗ Routing is NOT deterministic!") + + # Test 5: Show actual routing decisions for a few tokens + print("\n[Test 5] Sample Routing Decisions") + print("-" * 50) + + # Create a small set of tokens to visualize + x_sample = torch.randn(8, input_dim) + with torch.no_grad(): + logits = gate.gate_linear(x_sample) + weights, indices = gate(x_sample, verbose=False) + + print(f" Token | Expert Logits (raw) | Selected Expert | Weight") + print(f" " + "-" * 60) + for i in range(min(8, x_sample.shape[0])): + logit_str = ", ".join([f"{l:.3f}" for l in logits[i].tolist()]) + selected = indices[i, 0].item() + weight = weights[i, selected].item() + print(f" {i:5d} | [{logit_str:25s}] | Expert {selected} | {weight:.4f}") + + # Test 6: Verify gate learns meaningful patterns + print("\n[Test 6] Pattern-Based Routing") + print("-" * 50) + + # Create inputs with clear patterns + # Pattern A: positive values in first half, negative in second half + # Pattern B: opposite + pattern_a = torch.cat([torch.ones(input_dim // 2), -torch.ones(input_dim // 2)]).unsqueeze(0) + pattern_b = torch.cat([-torch.ones(input_dim // 2), torch.ones(input_dim // 2)]).unsqueeze(0) + + # Repeat patterns + x_patterns = torch.cat([pattern_a.repeat(5, 1), pattern_b.repeat(5, 1)], dim=0) + + with torch.no_grad(): + _, pattern_indices = gate(x_patterns, verbose=False) + + pattern_a_experts = pattern_indices[:5, 0].tolist() + pattern_b_experts = pattern_indices[5:, 0].tolist() + + print(f" Pattern A (positive-negative) routed to experts: {pattern_a_experts}") + print(f" Pattern B (negative-positive) routed to experts: {pattern_b_experts}") + + # Check if same patterns go to same expert + a_consistent = len(set(pattern_a_experts)) == 1 + b_consistent = len(set(pattern_b_experts)) == 1 + patterns_different = set(pattern_a_experts) != set(pattern_b_experts) + + if a_consistent and b_consistent: + print(f" ✓ Same patterns consistently route to same expert") + else: + print(f" ⚠ Same patterns route to different experts (expected with random weights)") + + if patterns_different: + print(f" ✓ Different patterns route to different experts") + else: + print(f" ⚠ Different patterns route to same expert") + + print("\n" + "=" * 70) + print("Gate Layer Verification Complete!") + print("=" * 70) + + return gate, expert_counts + + +def benchmark_attention_moe( + batch_size: int = 1, + seq_len: int = 64, + embed_dim: int = 256, + num_heads: int = 8, + num_experts: int = 4, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + num_warmup: int = 10, + num_runs: int = 100, + device: str = "cpu", + seed: int = 42 +): + """ + Benchmark AttentionMoE inference time. + + Best practices for accurate timing: + 1. Warmup runs: Avoid cold start overhead (JIT compilation, cache warming) + 2. Multiple runs: Get stable average and standard deviation + 3. torch.no_grad(): Disable gradient tracking for inference + 4. torch.cuda.synchronize(): Ensure GPU operations complete before timing + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k experts per token + expert_hidden_dim: Hidden dim for experts (default: 4 * embed_dim) + num_warmup: Number of warmup iterations + num_runs: Number of timed iterations + device: "cpu" or "cuda" + seed: Random seed + + Returns: + dict: Timing statistics (mean, std, min, max in milliseconds) + """ + import time + import numpy as np + + print("=" * 70) + print("AttentionMoE Benchmark") + print("=" * 70) + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}") + print(f" expert_hidden_dim={expert_hidden_dim or embed_dim * 4}") + print(f" device={device}") + print(f" num_warmup={num_warmup}, num_runs={num_runs}") + print("=" * 70) + + # Set seed and create model + torch.manual_seed(seed) + + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim + ) + model.eval() + + # Move to device + if device == "cuda" and torch.cuda.is_available(): + model = model.cuda() + print(f" Using GPU: {torch.cuda.get_device_name(0)}") + else: + device = "cpu" + print(f" Using CPU") + + # Initialize weights + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create input tensors + torch.manual_seed(seed) + if device == "cuda": + Q = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + K = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + V = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + else: + Q = torch.randn(batch_size, seq_len, embed_dim) + K = torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + # Warmup runs (important for JIT, cache warming, etc.) + print(f"\n[1] Warmup ({num_warmup} iterations)...") + with torch.no_grad(): + for _ in range(num_warmup): + _ = model(Q, K, V, verbose=False) + if device == "cuda": + torch.cuda.synchronize() # Wait for GPU to finish + + # Timed runs + print(f"[2] Timing ({num_runs} iterations)...") + times = [] + + with torch.no_grad(): + for i in range(num_runs): + if device == "cuda": + torch.cuda.synchronize() # Ensure previous work is done + + start_time = time.perf_counter() # High-resolution timer + + _ = model(Q, K, V, verbose=False) + + if device == "cuda": + torch.cuda.synchronize() # Wait for GPU to finish + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1000) # Convert to ms + + # Compute statistics + times = np.array(times) + stats = { + "mean_ms": np.mean(times), + "std_ms": np.std(times), + "min_ms": np.min(times), + "max_ms": np.max(times), + "median_ms": np.median(times), + "p95_ms": np.percentile(times, 95), + "p99_ms": np.percentile(times, 99), + } + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results:") + print("=" * 70) + print(f" Mean: {stats['mean_ms']:.4f} ms") + print(f" Std: {stats['std_ms']:.4f} ms") + print(f" Min: {stats['min_ms']:.4f} ms") + print(f" Max: {stats['max_ms']:.4f} ms") + print(f" Median: {stats['median_ms']:.4f} ms") + print(f" P95: {stats['p95_ms']:.4f} ms") + print(f" P99: {stats['p99_ms']:.4f} ms") + print("-" * 70) + print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") + print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") + print("=" * 70) + + return stats, model + + +def benchmark_with_allo_config( + config_mode: str = "switch_base_8_scaled_1_8", + num_warmup: int = 10, + num_runs: int = 100, + device: str = "cpu", + seed: int = 42 +): + """ + Benchmark using the same configuration as Allo tests. + This ensures fair comparison between PyTorch and Allo implementations. + + Uses llm_config.py for configuration to match allo_attention_moe_alt.py + + Args: + config_mode: Configuration mode from llm_config (e.g., "switch_base_8_scaled_1_8") + num_warmup: Number of warmup iterations + num_runs: Number of timed iterations + device: "cpu" or "cuda" + seed: Random seed (must match Allo test seed for same inputs) + + Returns: + dict: Timing statistics and model + """ + import sys + import os + import time + import numpy as np + + # Add llm_config to path + current_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.dirname(current_dir) + llm_config_dir = os.path.join(project_root, "llm_config") + if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + + from llm_config import get_moe_config + + # Get configuration (same as allo_attention_moe_alt.py) + moe_config = get_moe_config(config_mode) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] + num_experts = moe_config["num_experts"] + k = moe_config["k"] + hidden_dim = moe_config["hidden_dim"] + + # Determine num_heads (same logic as allo_attention_moe_alt.py) + if embed_dim >= 768: + num_heads = 12 + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + print("=" * 70) + print("PyTorch AttentionMoE Benchmark (Allo-compatible config)") + print("=" * 70) + print(f"Config Mode: {config_mode}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" device={device}, seed={seed}") + print(f" num_warmup={num_warmup}, num_runs={num_runs}") + print("=" * 70) + + # Create model with same initialization as Allo test + torch.manual_seed(seed) + if torch.cuda.is_available() and device == "cuda": + torch.cuda.manual_seed_all(seed) + + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim + ) + model.eval() + + # Initialize with Xavier uniform (same as Allo test) + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Move to device + if device == "cuda" and torch.cuda.is_available(): + model = model.cuda() + print(f" Using GPU: {torch.cuda.get_device_name(0)}") + else: + device = "cpu" + print(f" Using CPU") + + # Create input tensors with same seed (same as Allo test) + torch.manual_seed(seed) + if device == "cuda": + Q = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + K = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + V = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + else: + Q = torch.randn(batch_size, seq_len, embed_dim) + K = torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + print(f"\nInput shapes: Q={Q.shape}, K={K.shape}, V={V.shape}") + + # Verify output first + print("\n[1] Verifying output...") + with torch.no_grad(): + output = model(Q, K, V, verbose=False) + print(f" Output shape: {output.shape}") + print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") + + # Warmup + print(f"\n[2] Warmup ({num_warmup} iterations)...") + with torch.no_grad(): + for _ in range(num_warmup): + _ = model(Q, K, V, verbose=False) + if device == "cuda": + torch.cuda.synchronize() + + # Timed runs + print(f"[3] Timing ({num_runs} iterations)...") + times = [] + + with torch.no_grad(): + for i in range(num_runs): + if device == "cuda": + torch.cuda.synchronize() + + start_time = time.perf_counter() + _ = model(Q, K, V, verbose=False) + + if device == "cuda": + torch.cuda.synchronize() + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1000) + + # Compute statistics + times = np.array(times) + stats = { + "mean_ms": np.mean(times), + "std_ms": np.std(times), + "min_ms": np.min(times), + "max_ms": np.max(times), + "median_ms": np.median(times), + "p95_ms": np.percentile(times, 95), + "p99_ms": np.percentile(times, 99), + } + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results:") + print("=" * 70) + print(f" Mean: {stats['mean_ms']:.4f} ms") + print(f" Std: {stats['std_ms']:.4f} ms") + print(f" Min: {stats['min_ms']:.4f} ms") + print(f" Max: {stats['max_ms']:.4f} ms") + print(f" Median: {stats['median_ms']:.4f} ms") + print(f" P95: {stats['p95_ms']:.4f} ms") + print(f" P99: {stats['p99_ms']:.4f} ms") + print("-" * 70) + print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") + print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") + print("=" * 70) + + return stats, model, Q, K, V + + +if __name__ == "__main__": + # First, run gate layer verification + print("\n" + "#" * 70) + print("# PART 1: Gate Layer Verification") + print("#" * 70) + + gate, expert_counts = verify_gate_layer( + num_tokens=128, + input_dim=96, + num_experts=2, + k=1, + seed=42 + ) + + # Then run the full Attention + MoE test + print("\n" + "#" * 70) + print("# PART 2: Full Attention + MoE Inference Test") + print("#" * 70) + + # Configuration parameters + batch_size = 1 + seq_len = 4 + embed_dim = 8 # D, must be divisible by num_heads + num_heads = 2 # H + num_experts = 2 # E + k = 1 # Top-k MoE: each token uses exactly k experts + expert_hidden_dim = 16 # Hidden dimension for experts + seed = 24 + + print("=" * 60) + print(f"Attention + MoE Inference Test") + print("=" * 60) + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, expert_hidden_dim={expert_hidden_dim}") + print("=" * 60) + + # Run Attention + MoE inference + output, Q, K, V, model = run_attention_moe_inference( + batch_size=batch_size, + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim, + seed=seed, + verbose=True + ) + + print(f"\n" + "=" * 60) + print(f"Results:") + print(f" Input Q shape: {Q.shape}") + print(f" Input K shape: {K.shape}") + print(f" Input V shape: {V.shape}") + print(f" Output shape: {output.shape}") + print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") + print("=" * 60) + + # PART 3: Benchmark with Allo-compatible config + print("\n" + "#" * 70) + print("# PART 3: Performance Benchmark (Allo-compatible config)") + print("#" * 70) + + # Use the same config as allo_attention_moe_alt.py for fair comparison + stats, _, Q_bench, K_bench, V_bench = benchmark_with_allo_config( + config_mode="switch_base_8_scaled_1_8", # Same as DEFAULT_CONFIG_MODE in llm_config + num_warmup=10, + num_runs=100, + device="cuda" if torch.cuda.is_available() else "cpu", + seed=42 # Same seed as Allo test + ) diff --git a/ece6775/llm_config/check_llm_config.py b/ece6775/llm_config/check_llm_config.py new file mode 100644 index 000000000..e9b7a78bd --- /dev/null +++ b/ece6775/llm_config/check_llm_config.py @@ -0,0 +1,124 @@ +# check MoE model config from Hugging Face +# usage: python check_llm_config.py + +import sys +from transformers import AutoConfig + +def print_moe_config(model_name): + # print MoE model config + try: + print(f"Loading model configuration: {model_name}") + print("=" * 80) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + + # Print basic configuration + print("\n[Basic Architecture Parameters]") + print(f"Model type: {config.model_type if hasattr(config, 'model_type') else 'N/A'}") + print(f"Vocabulary size: {config.vocab_size if hasattr(config, 'vocab_size') else 'N/A'}") + print(f"Max position embeddings: {config.max_position_embeddings if hasattr(config, 'max_position_embeddings') else 'N/A'}") + + # Print hidden layer dimensions + print("\n[Hidden Layer Dimensions]") + if hasattr(config, 'hidden_size'): + print(f"Hidden size (hidden_size): {config.hidden_size}") + if hasattr(config, 'd_model'): + print(f"Model dimension (d_model): {config.d_model}") + if hasattr(config, 'n_embd'): + print(f"Embedding dimension (n_embd): {config.n_embd}") + + # Print MoE related parameters + print("\n[MoE Expert Parameters]") + if hasattr(config, 'num_local_experts'): + print(f"Number of local experts (num_local_experts): {config.num_local_experts}") + if hasattr(config, 'num_experts'): + print(f"Number of experts (num_experts): {config.num_experts}") + if hasattr(config, 'num_experts_per_tok'): + print(f"Number of experts per token (num_experts_per_tok): {config.num_experts_per_tok}") + if hasattr(config, 'num_experts_to_select'): + print(f"Number of experts to select (num_experts_to_select): {config.num_experts_to_select}") + + # Print feed-forward network parameters + print("\n[Feed-Forward Network Parameters]") + if hasattr(config, 'intermediate_size'): + print(f"Intermediate size (intermediate_size): {config.intermediate_size}") + if hasattr(config, 'ffn_dim'): + print(f"FFN dimension (ffn_dim): {config.ffn_dim}") + if hasattr(config, 'n_inner'): + print(f"Inner dimension (n_inner): {config.n_inner}") + + # Print attention parameters + print("\n[Attention Parameters]") + if hasattr(config, 'num_attention_heads'): + print(f"Number of attention heads (num_attention_heads): {config.num_attention_heads}") + if hasattr(config, 'num_heads'): + print(f"Number of heads (num_heads): {config.num_heads}") + if hasattr(config, 'n_head'): + print(f"Number of heads (n_head): {config.n_head}") + + # Print layer count + print("\n[Layer Parameters]") + if hasattr(config, 'num_hidden_layers'): + print(f"Number of hidden layers (num_hidden_layers): {config.num_hidden_layers}") + if hasattr(config, 'num_layers'): + print(f"Number of layers (num_layers): {config.num_layers}") + if hasattr(config, 'n_layer'): + print(f"Number of layers (n_layer): {config.n_layer}") + + # Print all configuration (for debugging) + print("\n[Full Configuration Information]") + print("-" * 80) + for key, value in config.to_dict().items(): + if isinstance(value, (int, float, str, bool, type(None))): + print(f"{key}: {value}") + elif isinstance(value, (list, tuple)) and len(value) < 10: + print(f"{key}: {value}") + + print("\n" + "=" * 80) + print("Configuration loaded successfully!") + + # Provide suggested test parameters + print("\n[Suggested Test Parameters (for Allo Implementation)]") + hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'n_embd', None))) + num_experts = getattr(config, 'num_local_experts', getattr(config, 'num_experts', None)) + intermediate_size = getattr(config, 'intermediate_size', getattr(config, 'ffn_dim', getattr(config, 'n_inner', None))) + num_experts_per_tok = getattr(config, 'num_experts_per_tok', getattr(config, 'num_experts_to_select', 1)) + + if hidden_size: + print(f"input_dim = {hidden_size} # Input dimension") + if intermediate_size: + print(f"hidden_dim = {intermediate_size} # Expert hidden dimension") + if hidden_size: + print(f"output_dim = {hidden_size} # Output dimension") + if num_experts: + print(f"num_experts = {num_experts} # Number of experts") + if num_experts_per_tok: + print(f"k = {num_experts_per_tok} # Top-k (number of experts activated per token)") + print(f"batch_size = 1 # Batch size (adjustable as needed)") + print(f"seq_len = 128 # Sequence length (adjustable as needed, actual model may support longer)") + + except Exception as e: + print(f"Error: {e}") + import traceback + traceback.print_exc() + print("\nHint: Make sure transformers library is installed: pip install transformers") + print("If the model requires trust_remote_code=True, make sure to trust the model") + + +if __name__ == "__main__": + if len(sys.argv) < 2: + print(__doc__) + print("\n[Common MoE Model Examples]") + print("1. DeepSeek MoE:") + print(" python check_moe_config.py deepseek-ai/deepseek-moe-16b-base") + print("\n2. Mixtral:") + print(" python check_moe_config.py mistralai/Mixtral-8x7B-v0.1") + print("\n3. Switch Transformer:") + print(" python check_moe_config.py google/switch-base-8") + print("\n4. GShard:") + print(" python check_moe_config.py google/gshard-gpt") + sys.exit(1) + + model_name = sys.argv[1] + print_moe_config(model_name) + diff --git a/ece6775/llm_config/llm_config.py b/ece6775/llm_config/llm_config.py new file mode 100644 index 000000000..f126c0ed3 --- /dev/null +++ b/ece6775/llm_config/llm_config.py @@ -0,0 +1,142 @@ +# LLM model config - shared for MoE and Attention+MoE + +# config modes: switch_base_8*, mixtral_8x7b*, deepseek*, custom +DEFAULT_CONFIG_MODE = "switch_base_8_scaled_1_8" + + +def get_moe_config(config_mode=None): + # get MoE config (works for standalone MoE and Attention+MoE) + if config_mode is None: + config_mode = DEFAULT_CONFIG_MODE + + config = {} + + if config_mode == "switch_base_8": + # Google Switch-Base-8 (original) + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 768, # hidden_size + "output_dim": 768, # hidden_size + "num_experts": 8, + "k": 1, # Top-1 MoE + "hidden_dim": 2048, + } + elif config_mode == "switch_base_8_scaled_2_3": + # scaled to 2/3 + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 512, # 768 * 2/3 + "output_dim": 512, + "num_experts": 2, + "k": 1, # Top-1 MoE + "hidden_dim": 1024, + } + elif config_mode == "switch_base_8_scaled_1_2": + # scaled to 1/2 + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 384, # 768 * 1/2 + "output_dim": 384, + "num_experts": 2, + "k": 1, # Top-1 MoE + "hidden_dim": 1024, + } + elif config_mode == "switch_base_8_scaled_1_4": + # scaled to 1/4 + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 192, # 768 * 1/4 + "output_dim": 192, + "num_experts": 2, + "k": 1, # Top-1 MoE + "hidden_dim": 512, + } + elif config_mode == "switch_base_8_scaled_1_8": + # scaled to 1/8 + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 96, # 768 * 1/8 + "output_dim": 96, + "num_experts": 2, # changed to 2 for testing (original: 8) + "k": 1, + "hidden_dim": 256, + } + elif config_mode == "mixtral_8x7b": + # Mixtral-8x7B (original) + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 4096, # hidden_size + "output_dim": 4096, # hidden_size + "num_experts": 8, + "k": 2, # Top-2 MoE + "hidden_dim": 14336, + } + elif config_mode == "mixtral_8x7b_scaled_1_8": + # scaled to 1/8 (Top-2 MoE) + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 512, # 4096 / 8 + "output_dim": 512, + "num_experts": 8, + "k": 2, # Top-2 MoE (real Mixtral uses Top-2) + "hidden_dim": 2048, + } + elif config_mode == "deepseek": + # DeepSeek MoE-16B (needs significant resources) + config = { + "batch_size": 1, + "seq_len": 128, # Can be adjusted, model supports up to 4096 + "input_dim": 2048, # hidden_size + "output_dim": 2048, # hidden_size + "num_experts": 64, + "k": 6, + "hidden_dim": 10944, + } + elif config_mode == "deepseek_scaled": + # scaled version + config = { + "batch_size": 1, + "seq_len": 128, + "input_dim": 512, # 2048 / 4 + "output_dim": 512, + "num_experts": 8, # 64 / 8 + "k": 1, # 6 -> 1 (current implementation supports k=1) + "hidden_dim": 2048, + } + else: # custom + # small test config + config = { + "batch_size": 4, + "seq_len": 4, + "input_dim": 64, + "output_dim": 64, + "num_experts": 2, + "k": 1, # Top-1 MoE + "hidden_dim": 256, + } + + return config + + +def print_config_info(config_mode, config): + # print config info + print("=" * 60) + print(f"LLM Configuration (MoE)") + print("=" * 60) + print(f"Model: {config_mode}") + print(f"Batch size: {config['batch_size']}") + print(f"Sequence length: {config['seq_len']}") + print(f"Input dimension: {config['input_dim']}") + print(f"Output dimension: {config['output_dim']}") + print(f"Number of experts: {config['num_experts']}") + print(f"Top-k: {config['k']}") + print(f"Hidden dimension: {config['hidden_dim']}") + print("=" * 60) + diff --git a/ece6775/moe/allo_moe_alt.py b/ece6775/moe/allo_moe_alt.py new file mode 100644 index 000000000..1906c9ac2 --- /dev/null +++ b/ece6775/moe/allo_moe_alt.py @@ -0,0 +1,446 @@ +# MoE with fused GeLU optimization for HLS +# fuses GeLU into FC1 to save memory and enable pipelining +# small diff vs pytorch (~1e-4) from accumulation order + +import numpy as np +import allo +from allo.ir.types import float32, int32 +from allo import dsl + +MODE = "csyn" # llvm, sw_emu, hw_emu, hw, csyn + +import sys +import os +# path setup for config +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # stable softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1e12 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="normalize"): + Z[n, k] = E_exp[n, k] / S[n] + + return Z + +# top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1e12 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0 and logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices +# FFN expert with fused GeLU, row-by-row for pipelining +def expert[Ty, N, D_in, D_hidden, D_out]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]" +) -> "Ty[N, D_out]": + output: Ty[N, D_out] = 0.0 + + for n in range(N, name="row_loop"): + hidden_row: Ty[D_hidden] = 0.0 + + # FC1 + GeLU together + for h in range(D_hidden, name="fc1_gelu_loop"): + acc: Ty = 0.0 + for k in range(D_in, name="fc1_reduce"): + acc += x[n, k] * fc1_weight[h, k] + fc1_val: Ty = acc + fc1_bias[h] + + # GeLU: tanh approximation + x3: Ty = fc1_val * fc1_val * fc1_val + inner: Ty = 0.7978845608028654 * (fc1_val + 0.044715 * x3) # sqrt(2/pi) ≈ 0.7978... + hidden_row[h] = 0.5 * fc1_val * (1.0 + dsl.tanh(inner)) + + # FC2 + for o in range(D_out, name="fc2_loop"): + acc2: Ty = 0.0 + for k in range(D_hidden, name="fc2_reduce"): + acc2 += hidden_row[k] * fc2_weight[o, k] + output[n, o] = acc2 + fc2_bias[o] + + return output +# Top-1 MoE: pick one expert per token, weight with softmax +def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( + x: "Ty[N, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[N, D_out]": + # compute gate scores + gate_logits: Ty[N, E] = 0.0 + for n, e in dsl.grid(N, E, name="gate_linear"): + acc: Ty = 0.0 + for k in range(D_in, name="gate_reduce"): + acc += x[n, k] * gate_weight[e, k] + gate_logits[n, e] = acc + gate_bias[e] + + # pick best expert + top1_idx: int32[N] = top1_select[Ty, N, E](gate_logits) + + # get top-k logits (for k=1) + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_idx[n] if k == 0 else 0 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) + + # sparse weights + gate_w: Ty[N, E] = 0.0 + for n in range(N, name="sparse_gate"): + gate_w[n, top1_idx[n]] = top_k_weights[n, 0] + + # run all experts (could optimize to skip unused ones later) + expert_out: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # extract weights for this expert + fc1_w: Ty[D_hidden, D_in] = 0.0 + fc1_b: Ty[D_hidden] = 0.0 + fc2_w: Ty[D_out, D_hidden] = 0.0 + fc2_b: Ty[D_out] = 0.0 + + for i, j in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + fc1_w[i, j] = expert_fc1_weights[e, i, j] + for i in range(D_hidden, name="extract_fc1_b"): + fc1_b[i] = expert_fc1_biases[e, i] + for i, j in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + fc2_w[i, j] = expert_fc2_weights[e, i, j] + for i in range(D_out, name="extract_fc2_b"): + fc2_b[i] = expert_fc2_biases[e, i] + + out = expert[Ty, N, D_in, D_hidden, D_out](x, fc1_w, fc1_b, fc2_w, fc2_b) + + for n, d in dsl.grid(N, D_out, name="store_expert_out"): + expert_out[e, n, d] = out[n, d] + + # combine with weights + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + output[n, d_out] += expert_out[e, n, d_out] * gate_w[n, e] + + return output + + +# build schedule with pipelining/unrolling optimizations +def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hidden_dim): + Ty = float32 + N = num_tokens + D_in = input_dim + D_out = output_dim + E = num_experts + K = k + D_hidden = hidden_dim + + print("Building MoE schedule...") + + # top1_select + s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) + s_top1.pipeline("top1_select:e") + + # softmax + s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) + sm_loops = s_softmax.get_loops(s_softmax.top_func_name) + s_softmax.pipeline(sm_loops["row_max"]["k"]) + s_softmax.pipeline(sm_loops["exp_sum"]["k"]) + s_softmax.pipeline(sm_loops["normalize"]["k"]) + + # expert + s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) + exp_loops = s_expert.get_loops(s_expert.top_func_name) + row_loop = exp_loops["row_loop"] + s_expert.pipeline(row_loop["h"]) + s_expert.pipeline(row_loop["o"]) + + # unroll reduction loop for parallel MACs + unroll_factor = min(4, D_in, D_hidden) + s_expert.unroll(row_loop["k"], factor=unroll_factor) + print(f" - Applied unroll to k (reduction loop) factor={unroll_factor}") + + # array partitioning + if D_hidden <= 32: + s_expert.partition(s_expert.hidden_row, dim=0) + print(f" - Applied complete partition to hidden_row (D_hidden={D_hidden})") + else: + s_expert.partition(s_expert.hidden_row, dim=0, factor=4) + print(f" - Applied cyclic partition to hidden_row (factor=4)") + + print(" - Created expert schedule with pipeline, unroll, and partition optimizations") + + # moe_layer schedule + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, N, D_in, D_out, E, K, D_hidden]) + + moe_loops = s_moe.get_loops(s_moe.top_func_name) + print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") + + # optimize gate_linear + gate_linear_loop = moe_loops["gate_linear"] + print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") + + s_moe.pipeline(gate_linear_loop["e"]) + print(" - Applied pipeline to gate_linear:e") + + unroll_factor_gate = min(4, D_in) + s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) + print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") + + # other loops + combine_loop = moe_loops["combine_outputs"] + s_moe.pipeline(combine_loop["d_out"]) + + topk_loop = moe_loops["topk_logits"] + s_moe.pipeline(topk_loop["k"]) + + sparse_loop = moe_loops["sparse_gate"] + s_moe.pipeline(sparse_loop["n"]) + + print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") + + # compose sub-schedules + s_moe.compose(s_top1) + s_moe.compose(s_softmax) + s_moe.compose(s_expert) + print(" - Composed top1_select, softmax_1d, and expert schedules") + + print("\n" + "=" * 60) + print("Schedule done!") + print("=" * 60) + print("Optimizations:") + print(" 1. Fused GeLU into FC1") + print(" 2. Row-level structure") + print(" 3. Unrolled reduction loops") + print(f" - Expert k: factor {unroll_factor}") + print(f" - Gate k: factor {unroll_factor_gate}") + print(" 4. Array partitioning for hidden_row") + print(" 5. Pipeline on h, o, e loops") + print("=" * 60) + + return s_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as nn + from pytorch_moe import MoELayer + + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + input_dim = moe_config["input_dim"] + output_dim = moe_config["output_dim"] + num_experts = moe_config["num_experts"] + k = moe_config["k"] + hidden_dim = moe_config["hidden_dim"] + seed = 42 + + print("=" * 60) + print("MoE Allo (Vayun Optimized) vs PyTorch Comparison Test") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}") + print(f" input_dim={input_dim}, output_dim={output_dim}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + # pytorch baseline + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform + for param in pytorch_moe.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + # extract weights + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_bias_pt = pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_moe.experts]) + expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_moe.experts]) + expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_moe.experts]) + expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_moe.experts]) + + # convert to numpy + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Input shape: {x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # flatten input (N = B * L) + num_tokens = batch_size * seq_len + x_flat_np = x_np.reshape(num_tokens, input_dim) + print(f"Flattened input shape: {x_flat_np.shape}") + + # run allo + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule + allo_schedule = optimize_moe_schedule( + num_tokens, input_dim, output_dim, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_moe_alt_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output_flat = mod( + x_flat_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) + mod(x_flat_np, + gate_weight_np, gate_bias_np, + expert_fc1_weights_np, expert_fc1_biases_np, + expert_fc2_weights_np, expert_fc2_biases_np, + allo_output_flat) + elif MODE == "csyn": + allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Reshape output back to [B, L, D_out] + allo_output = allo_output_flat.reshape(batch_size, seq_len, output_dim) + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) \ No newline at end of file diff --git a/ece6775/moe/allo_moe_base.py b/ece6775/moe/allo_moe_base.py new file mode 100644 index 000000000..7688f723b --- /dev/null +++ b/ece6775/moe/allo_moe_base.py @@ -0,0 +1,489 @@ +# MoE in Allo - manual implementation (no nn lib) +# matches pytorch but with small numerical differences (1e-4 to 1e-3) from: +# - different accumulation order +# - GELU approximation differences +# - fp precision +# these are normal and don't affect performance + +import allo +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, E](X: "Ty[N, E]") -> "Ty[N, E]": + # softmax over last dim + Z: Ty[N, E] + E_exp: Ty[N, E] + M: Ty[N] = -1000000000000.0 # TODO: use -inf if available + S: Ty[N] = 0.0 + + # find max per row + for i in range(N): + for j in range(E): + if X[i, j] > M[i]: + M[i] = X[i, j] + + # exp and sum + for i in range(N): + for j in range(E): + E_exp[i, j] = dsl.exp(X[i, j] - M[i]) + S[i] += E_exp[i, j] + + # normalize + for i in range(N): + for j in range(E): + Z[i, j] = E_exp[i, j] / S[i] + + return Z + + +# top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert for each token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for i in range(N): + indices[i] = 0 + max_val[i] = logits[i, 0] + # check rest + for j in range(1, E): + if logits[i, j] > max_val[i]: + max_val[i] = logits[i, j] + indices[i] = j + + return indices + + +# TopKGate: Gate module to select top-k experts +def topk_gate[Ty, N, D, E, K]( + x: "Ty[N, D]", + gate_weight: "Ty[E, D]", + gate_bias: "Ty[E]" +) -> ("Ty[N, E]", "int32[N, K]"): + """ + Gate module to select top-k experts for routing. + + Args: + x: Input tensor of shape [N, D] where N = batch * seq_len + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + Returns: + full_weights: Sparse weight matrix of shape [N, E] + top_k_indices: Indices of selected experts of shape [N, K] + """ + # Compute logits using linear layer (manual implementation) + # Computes: x @ gate_weight^T + gate_bias = [N, D] @ [D, E] + [E] = [N, E] + # Note: In PyTorch version, bias=False, so gate_bias should be zeros + logits: Ty[N, E] = 0.0 + for i in range(N): + for j in range(E): + # Initialize with bias + logits[i, j] = gate_bias[j] + # Matrix multiplication: x[i] @ gate_weight[j]^T + for k in range(D): + logits[i, j] += x[i, k] * gate_weight[j, k] + + # For k=1, use top1_select + top1_indices_1d: int32[N] = top1_select[Ty, N, E](logits) + + # Expand to [N, K] format (for k=1, K=1) + top_k_indices: int32[N, K] = 0 + for i in range(N): + top_k_indices[i, 0] = top1_indices_1d[i] + + # Get top-k logits + top_k_logits: Ty[N, K] = 0.0 + for i in range(N): + for k in range(K): + expert_idx = top_k_indices[i, k] + top_k_logits[i, k] = logits[i, expert_idx] + + # Apply softmax to top-k logits + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # Create sparse weight matrix + full_weights: Ty[N, E] = 0.0 + for i in range(N): + for k in range(K): + expert_idx = top_k_indices[i, k] + full_weights[i, expert_idx] = top_k_weights[i, k] + + return full_weights, top_k_indices + + +# simple FFN expert +def expert[Ty, N, D_in, D_hidden, D_out]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]" +) -> "Ty[N, D_out]": + # FC1 + fc1_out: Ty[N, D_hidden] = 0.0 + for i in range(N): + for j in range(D_hidden): + fc1_out[i, j] = fc1_bias[j] + for k in range(D_in): + fc1_out[i, j] += x[i, k] * fc1_weight[j, k] + + # GELU + gelu_out: Ty[N, D_hidden] = 0.0 + for i in range(N): + for j in range(D_hidden): + x_val: Ty = fc1_out[i, j] + x3: Ty = x_val * x_val * x_val + inner: Ty = 0.797885 * (x_val + 0.044715 * x3) + tanh_term: Ty = dsl.tanh(inner) + gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) + + # FC2 + fc2_out: Ty[N, D_out] = 0.0 + for i in range(N): + for j in range(D_out): + fc2_out[i, j] = fc2_bias[j] + for k in range(D_hidden): + fc2_out[i, j] += gelu_out[i, k] * fc2_weight[j, k] + + return fc2_out + + +# main MoE layer +def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( + x: "Ty[B, L, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[B, L, D_out]": + # process all tokens through all experts, then weight by gate + # matches pytorch behavior + gate_logits: Ty[B * L, E] = 0.0 + + # Compute gate logits for all tokens + for b in range(B): + for l in range(L): + token_idx = b * L + l + # Compute logits for this token: x[b, l] @ gate_weight^T + gate_bias + for j in range(E): + # Initialize with bias + gate_logits[token_idx, j] = gate_bias[j] + # Matrix multiplication: x[b, l] @ gate_weight[j]^T + for k in range(D_in): + gate_logits[token_idx, j] += x[b, l, k] * gate_weight[j, k] + + # Select top-k experts and compute weights + gate_weights: Ty[B * L, E] = 0.0 + top_k_indices: int32[B * L, K] = 0 + + for i in range(B * L): + # Find top-1 expert (argmax) for this token + top_expert_idx: int32 = 0 + max_logit: Ty = gate_logits[i, 0] + + for j in range(1, E): + if gate_logits[i, j] > max_logit: + max_logit = gate_logits[i, j] + top_expert_idx = j + + # Store top-k indices (for k=1) + top_k_indices[i, 0] = top_expert_idx + + # Get top-k logit and apply softmax + # For k=1, softmax just returns 1.0, but we compute it for consistency + # softmax(top_k_logit) = exp(top_k_logit) / exp(top_k_logit) = 1.0 + gate_weights[i, top_expert_idx] = 1.0 + + # Flatten input for expert processing + x_flat: Ty[B * L, D_in] = 0.0 + for b in range(B): + for l in range(L): + for d in range(D_in): + x_flat[b * L + l, d] = x[b, l, d] + + # Process each expert: compute outputs for all tokens + # Then use gate weights to select correct outputs + expert_outputs: Ty[E, B * L, D_out] = 0.0 + + for expert_idx in range(E): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for h in range(D_hidden): + for d in range(D_in): + expert_fc1_w[h, d] = expert_fc1_weights[expert_idx, h, d] + expert_fc1_b[h] = expert_fc1_biases[expert_idx, h] + + for o in range(D_out): + for h in range(D_hidden): + expert_fc2_w[o, h] = expert_fc2_weights[expert_idx, o, h] + expert_fc2_b[o] = expert_fc2_biases[expert_idx, o] + + # Process all tokens through this expert (inline expert function) + # First linear layer: fc1_out = x_flat @ expert_fc1_w^T + expert_fc1_b + fc1_out: Ty[B * L, D_hidden] = 0.0 + for i in range(B * L): + for j in range(D_hidden): + # Initialize with bias + fc1_out[i, j] = expert_fc1_b[j] + # Matrix multiplication: x_flat[i] @ expert_fc1_w[j]^T + for k in range(D_in): + fc1_out[i, j] += x_flat[i, k] * expert_fc1_w[j, k] + + # GELU activation: 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x^3))) + gelu_out: Ty[B * L, D_hidden] = 0.0 + for i in range(B * L): + for j in range(D_hidden): + x_val: Ty = fc1_out[i, j] + # Compute x^3 + x3: Ty = x_val * x_val * x_val + # Inner term: sqrt(2/pi) * (x + 0.044715 * x^3) + # Use exact value: sqrt(2/pi) ≈ 0.7978845608028654 + inner: Ty = 0.7978845608028654 * (x_val + 0.044715 * x3) + # Tanh + tanh_term: Ty = dsl.tanh(inner) + # GELU: 0.5 * x * (1 + tanh_term) + gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) + + # Second linear layer: fc2_out = gelu_out @ expert_fc2_w^T + expert_fc2_b + expert_out: Ty[B * L, D_out] = 0.0 + for i in range(B * L): + for j in range(D_out): + # Initialize with bias + expert_out[i, j] = expert_fc2_b[j] + # Matrix multiplication: gelu_out[i] @ expert_fc2_w[j]^T + for k in range(D_hidden): + expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k] + + # Store expert outputs + for i in range(B * L): + for o in range(D_out): + expert_outputs[expert_idx, i, o] = expert_out[i, o] + + # Combine expert outputs using gate weights + # For each token, sum over experts weighted by gate_weights + output_flat: Ty[B * L, D_out] = 0.0 + for i in range(B * L): + for expert_idx in range(E): + weight: Ty = gate_weights[i, expert_idx] + for o in range(D_out): + output_flat[i, o] += expert_outputs[expert_idx, i, o] * weight + + # Reshape back to original shape + output: Ty[B, L, D_out] = 0.0 + for b in range(B): + for l in range(L): + for o in range(D_out): + output[b, l, o] = output_flat[b * L + l, o] + + return output + + +# test/compare with pytorch +if __name__ == "__main__": + import numpy as np + import torch + import torch.nn as nn + from pytorch_moe import MoELayer, run_moe_inference + + config = get_moe_config(CONFIG_MODE) + batch_size = config["batch_size"] + seq_len = config["seq_len"] + input_dim = config["input_dim"] + output_dim = config["output_dim"] + num_experts = config["num_experts"] + k = config["k"] + hidden_dim = config["hidden_dim"] + + seed = 42 + + # Print configuration info using shared function + print("=" * 60) + print("MoE Allo vs PyTorch Comparison Test") + print("=" * 60) + print_config_info(CONFIG_MODE, config) + print(f"Seed: {seed}") + print("=" * 60) + + # pytorch baseline + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform (matching PyTorch) + for param in pytorch_moe.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + # extract weights + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo + gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_bias_pt = pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([expert.fc1.weight.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack([expert.fc1.bias.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack([expert.fc2.weight.data for expert in pytorch_moe.experts]) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack([expert.fc2.bias.data for expert in pytorch_moe.experts]) # [num_experts, output_dim] + + # convert to numpy + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, hidden_dim, input_dim] (same) + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Input shape: {x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # run allo + print("\n[4] Running Allo implementation...") + try: + # Customize Allo module + allo_mod = allo.customize( + moe_layer, + instantiate=[float32, batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim] + ) + + # Generate project name based on CONFIG_MODE to avoid conflicts + # This ensures different configurations use different build folders + project_name = f"allo_moe_base_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("Building Allo module...") + if MODE == "llvm": + mod = allo_mod.build(target="llvm") # Use LLVM for CPU testing + elif MODE == "sw_emu": + mod = allo_mod.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_mod.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_mod.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_mod.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + # Note: In LLVM backend, functions with return values return the result directly + # We don't need to pass output buffer as argument + print("Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np + ) + elif MODE == "sw_emu": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + elif MODE == "hw_emu": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + elif MODE == "hw": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[5] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # check if close (1e-4 to 1e-3 diff is normal from accumulation order, GELU approx, fp precision) + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[6] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) \ No newline at end of file diff --git a/ece6775/moe/allo_moe_lib.py b/ece6775/moe/allo_moe_lib.py new file mode 100644 index 000000000..05d6bb16d --- /dev/null +++ b/ece6775/moe/allo_moe_lib.py @@ -0,0 +1,437 @@ +# MoE using Allo library functions (nn.linear2d, nn.GeLU) +# follows test_bert pattern - all functions in one file with generic types + +import numpy as np +import allo +import allo.library.nn as nn +from allo.library.nn import linear2d, GeLU # Direct import for type inference +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="update"): + Z[n, k] = E_exp[n, k] / (S[n]) + + return Z + +# top-1 selection +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # skip e=0 + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + +# FFN expert using library functions +def expert[Ty, N, D_in, D_hidden, D_out]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]" +) -> "Ty[N, D_out]": + # using linear2d and GeLU from library (imported at top) + fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias) + gelu_out = GeLU[Ty, N, D_hidden](fc1_out) + fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias) + return fc2_out + +# main MoE layer +def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( + x: "Ty[B, L, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]" +) -> "Ty[B, L, D_out]": + # Flatten batch and sequence dimensions: N = B * L + N = B * L + x_flat: Ty[N, D_in] = 0.0 + for b, l, d_in in dsl.grid(B, L, D_in, name="flatten"): + x_flat[b * L + l, d_in] = x[b, l, d_in] + + # Step 1: Compute gate logits using linear2d (inlined topk_gate logic) + # Note: Use linear2d (direct import) to avoid conflict with torch.nn + gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in](x_flat, gate_weight, gate_bias) # [N, E] + + # Step 2: Select top-1 expert using top1_select function + top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) + + # Step 3: Get top-k logits and apply softmax + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # Step 4: Apply softmax to top-k logits using softmax_1d function + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # Step 5: Create sparse weight matrix from top-k weights + top_k_indices: int32[N, K] = 0 + gate_weights: Ty[N, E] = 0.0 + + for n in range(N, name="gate_weights"): + # Store top-k indices (for k=1) + top_k_indices[n, 0] = top1_indices_1d[n] + # Set gate weights from softmax output + expert_idx = top1_indices_1d[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 + + # Step 2: Process each expert: compute outputs for all tokens + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # Process all tokens through this expert using the expert function + expert_out = expert[Ty, N, D_in, D_hidden, D_out]( + x_flat, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b + ) # [N, D_out] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # Step 3: Combine expert outputs using gate weights + output_flat: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output_flat[n, d_out] += expert_outputs[e, n, d_out] * weight + + # Step 4: Reshape back to original shape + output: Ty[B, L, D_out] = 0.0 + for b, l, d_out in dsl.grid(B, L, D_out, name="reshape"): + output[b, l, d_out] = output_flat[b * L + l, d_out] + + return output + + +#================================================================================== +# Schedule optimization function +#================================================================================== +def optimize_moe_with_composition( + batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim +): + # create schedules for sub-functions and compose them + # TODO: add optimizations later (pipeline, partition, etc) + Ty = float32 + N = batch_size * seq_len + D_in = input_dim + E = num_experts + K = k + D_hidden = hidden_dim + D_out = output_dim + + print("=" * 60) + print("Creating and optimizing MoE schedules...") + print("=" * 60) + + # top1_select schedule + print("\n[1] Creating schedule for top1_select...") + s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) + print(" - Created top1_select schedule for [N, E]") + + # softmax schedule + print("\n[2] Creating schedule for softmax_1d...") + s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) + print(" - Created softmax_1d schedule for [N, K]") + + # expert schedule - need to create library function schedules first (type inference workaround) + print("\n[3] Creating schedule for expert...") + import allo.library.nn as allo_nn + + print(" - Creating schedules for library functions (linear2d, GeLU)...") + s_linear_fc1 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_hidden, D_in]) + s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, N, D_hidden]) + s_linear_fc2 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_out, D_hidden]) + print(" - Library function schedules created") + + print(" - Creating expert schedule...") + s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) + + # compose library functions into expert + s_expert.compose(s_linear_fc1, id="expert_fc1") + s_expert.compose(s_gelu, id="expert_gelu") + s_expert.compose(s_linear_fc2, id="expert_fc2") + print(" - Created expert schedule") + print(" - Composed nn.linear2d and nn.GeLU schedules for expert") + + # moe_layer schedule + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize( + moe_layer, + instantiate=[Ty, batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim] + ) + + # compose everything + print("\n[5] Composing all schedules together...") + + # gate linear2d + s_gate_linear = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, E, D_in]) + s_moe.compose(s_gate_linear, id="gate") + print(" - Composed linear2d (gate logits) schedule") + + # top1 and softmax + s_moe.compose(s_top1) + s_moe.compose(s_softmax) + print(" - Composed top1_select and softmax_1d schedules") + + # expert (already has library functions composed) + s_moe.compose(s_expert) + print(" - Composed expert schedule (includes nn.linear2d and nn.GeLU)") + + print("\n" + "=" * 60) + print("Schedule composition complete!") + print("=" * 60) + + return s_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as nn + import sys + import os + # Import pytorch_moe from the same directory + from pytorch_moe import MoELayer + + # ============================================================================ + # Configuration parameters - using shared config module + # ============================================================================ + # Get configuration from shared module + config = get_moe_config(CONFIG_MODE) + batch_size = config["batch_size"] + seq_len = config["seq_len"] + input_dim = config["input_dim"] + output_dim = config["output_dim"] + num_experts = config["num_experts"] + k = config["k"] + hidden_dim = config["hidden_dim"] + + seed = 42 + + # Print configuration info using shared function + print_config_info(CONFIG_MODE, config) + print(f"Seed: {seed}") + + #---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + #---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer from pytorch_moe.py + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform (matching PyTorch) + for param in pytorch_moe.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") + + #---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + #---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo + gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_bias_pt = pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack([expert.fc1.weight.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack([expert.fc1.bias.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack([expert.fc2.weight.data for expert in pytorch_moe.experts]) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack([expert.fc2.bias.data for expert in pytorch_moe.experts]) # [num_experts, output_dim] + + #---------------------------------------------------------------------------------- + # Convert the weights and biases to numpy arrays (ensure C-contiguous and correct shape) + #---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) + gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, hidden_dim, input_dim] (same) + expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) + expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) + expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) + + print(f"Input shape: {x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + #---------------------------------------------------------------------------------- + # Run Allo implementation + #---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with composition + allo_schedule = optimize_moe_with_composition( + batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim + ) + + # Generate project name based on CONFIG_MODE to avoid conflicts + # This ensures different configurations use different build folders + project_name = f"allo_moe_lib_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + elif MODE == "hw_emu": + mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + elif MODE == "hw": + mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np + ) + elif MODE == "sw_emu" or MODE == "hw_emu" or MODE == "hw": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # check if close (1e-4 to 1e-3 diff is normal) + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + else: + print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/moe/pytorch_moe.py b/ece6775/moe/pytorch_moe.py new file mode 100644 index 000000000..f34ef6427 --- /dev/null +++ b/ece6775/moe/pytorch_moe.py @@ -0,0 +1,336 @@ +""" +Modern Mixture of Experts (MoE) Implementation for Inference Only + +This script implements a simplified MoE layer optimized for inference. +Uses random inputs and weights for demonstration. +""" + +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Tuple, Optional + + +class TopKGate(nn.Module): + """ + Gate module to select top k experts for routing. + + Args: + input_dim: Input feature dimension + num_experts: Total number of experts + k: Number of experts to select per token + """ + + def __init__(self, input_dim: int, num_experts: int, k: int = 1): + super().__init__() + self.k = k + self.num_experts = num_experts + # Linear layer to compute expert logits + self.gate_linear = nn.Linear(input_dim, num_experts, bias=False) + + def forward(self, x: torch.Tensor, verbose: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Forward pass through the gate. + + Args: + x: Input tensor of shape [batch_size * seq_len, input_dim] + verbose: Whether to print routing information + + Returns: + full_weights: Sparse weight matrix [batch_size * seq_len, num_experts] + top_k_indices: Indices of selected experts [batch_size * seq_len, k] + """ + # Compute logits for all experts + logits = self.gate_linear(x) # [N, num_experts] + + # Select top-k experts + top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1) + + # Apply softmax to top-k logits for normalized weights + top_k_weights = F.softmax(top_k_logits, dim=-1) # [N, k] + + # Create sparse weight matrix (zeros for non-selected experts) + full_weights = torch.zeros_like(logits) + full_weights.scatter_(1, top_k_indices, top_k_weights) + + # Print routing information if verbose + if verbose: + num_tokens = x.shape[0] + + # Count tokens per expert + expert_counts = {} + for expert_idx in range(self.num_experts): + count = (top_k_indices == expert_idx).sum().item() + if count > 0: + expert_counts[expert_idx] = count + + # Verify that each token selects exactly k experts + tokens_per_expert_count = {} + for i in range(num_tokens): + num_experts_selected = (top_k_weights[i] > 1e-6).sum().item() + tokens_per_expert_count[num_experts_selected] = tokens_per_expert_count.get(num_experts_selected, 0) + 1 + + print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}") + print(f"[Gate Routing] Expert distribution:") + total_selections = num_tokens * self.k # Total expert-token pairs + for expert_idx, count in sorted(expert_counts.items()): + percentage = (count / total_selections) * 100 + print(f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)") + + # Verify top-k: each token selects exactly k experts + if tokens_per_expert_count.get(self.k, 0) == num_tokens: + print(f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)") + else: + print(f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}") + + return full_weights, top_k_indices + + +class Expert(nn.Module): + """ + A simple feed-forward expert network. + + Args: + input_dim: Input feature dimension + hidden_dim: Hidden layer dimension + output_dim: Output feature dimension + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + # Use GELU activation (modern choice) + self.activation = nn.GELU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the expert.""" + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class MoELayer(nn.Module): + """ + Mixture of Experts layer for inference. + + Args: + input_dim: Input feature dimension + output_dim: Output feature dimension + num_experts: Total number of experts + k: Number of experts to activate per token + expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim) + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_experts: int, + k: int = 1, + expert_hidden_dim: Optional[int] = None + ): + super().__init__() + self.num_experts = num_experts + self.k = k + self.output_dim = output_dim + + if expert_hidden_dim is None: + expert_hidden_dim = input_dim * 4 # Common practice + + # Initialize gate and experts + self.gate = TopKGate(input_dim, num_experts, k) + self.experts = nn.ModuleList([ + Expert(input_dim, expert_hidden_dim, output_dim) + for _ in range(num_experts) + ]) + + def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor: + """ + Forward pass through MoE layer. + + Args: + x: Input tensor of shape [batch_size, seq_len, input_dim] + verbose: Whether to print verification information for top-1 MoE + + Returns: + Output tensor of shape [batch_size, seq_len, output_dim] + """ + + # Store original shape + original_shape = x.shape + batch_size, seq_len = original_shape[0], original_shape[1] + + # Flatten batch and sequence dimensions + x_flat = x.view(-1, original_shape[-1]) # [N, input_dim], N = batch * seq_len + num_tokens = x_flat.shape[0] + + # Get gating weights and expert indices + gate_weights, top_k_indices = self.gate(x_flat, verbose=verbose) # [N, num_experts], [N, k] + + # Initialize output tensor + output = torch.zeros( + x_flat.shape[0], + self.output_dim, + device=x.device, + dtype=x.dtype + ) + + # Track expert usage statistics + expert_usage_stats = {} + + # Process each expert + for expert_idx in range(self.num_experts): + # Find tokens assigned to this expert + # Create mask: tokens that have this expert in their top-k + expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [N] + + num_tokens_for_expert = expert_mask.sum().item() + + if not expert_mask.any(): + expert_usage_stats[expert_idx] = 0 + continue # No tokens for this expert + + # Track usage + expert_usage_stats[expert_idx] = num_tokens_for_expert + + # Get inputs for this expert + expert_inputs = x_flat[expert_mask] # [num_tokens, input_dim] + + # Process through expert + expert_outputs = self.experts[expert_idx](expert_inputs) # [num_tokens, output_dim] + + # Get corresponding weights for these tokens + # For each token, find the weight for this expert (could be in any of k positions) + token_indices = torch.where(expert_mask)[0] # Indices in flattened tensor + expert_weights = gate_weights[token_indices, expert_idx].unsqueeze(1) # [num_tokens, 1] + + # Weight and accumulate outputs + weighted_outputs = expert_outputs * expert_weights + output[token_indices] += weighted_outputs + + # Verify top-k behavior + if verbose: + active_experts = [idx for idx, count in expert_usage_stats.items() if count > 0] + total_usage = sum(expert_usage_stats.values()) + + # Verify top-k: each token selects exactly k experts + if top_k_indices.shape[1] != self.k: + print(f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}") + + # Verify all tokens are processed (for k>1, total_usage may be > num_tokens) + expected_total_usage = num_tokens * self.k + if total_usage != expected_total_usage: + print(f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}") + + # Check if weights are normalized (should sum to 1.0 per token) + sample_weights = gate_weights.sum(dim=1) + weights_normalized = torch.allclose(sample_weights, torch.ones_like(sample_weights), atol=1e-5) + + # Final summary for top-k + print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===") + print(f" ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})") + print(f" ✓ All {num_tokens} tokens processed") + print(f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})") + print(f" ✓ Active experts: {len(active_experts)}/{self.num_experts}") + print(f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}") + if weights_normalized: + print(f" ✓ Gate weights normalized (sum to 1.0 per token)") + else: + print(f" ✗ Gate weights NOT normalized correctly") + print(f" === Top-{self.k} MoE is working correctly! ===") + + # Reshape back to original shape + output = output.view(batch_size, seq_len, self.output_dim) + + return output + + +def run_moe_inference( + batch_size: int = 4, + seq_len: int = 4, + input_dim: int = 64, + output_dim: int = 64, + num_experts: int = 2, + k: int = 1, + seed: int = 24, + verbose: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, MoELayer]: + """ + Run MoE layer inference with fixed seed for reproducible inputs and weights. + + Args: + batch_size: Batch size + seq_len: Sequence length + input_dim: Input feature dimension + output_dim: Output feature dimension + num_experts: Number of experts + k: Top-k experts to activate per token + seed: Random seed for reproducibility + verbose: Whether to print verification information + + Returns: + output: Output tensor from MoE layer [batch_size, seq_len, output_dim] + input_tensor: Input tensor used for inference + moe_layer: MoE layer instance with initialized weights + """ + # Set random seed for reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create MoE layer + moe_layer = MoELayer(input_dim, output_dim, num_experts, k) + moe_layer.eval() + + # Initialize with random weights (Xavier uniform for better stability) + for param in moe_layer.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + input_tensor = torch.randn(batch_size, seq_len, input_dim) + + # Run inference with verbose output for verification + with torch.no_grad(): + output = moe_layer(input_tensor, verbose=verbose) + + return output, input_tensor, moe_layer + + +if __name__ == "__main__": + # Configuration parameters + batch_size = 4 + seq_len = 1 + input_dim = 1 + output_dim = 1 + num_experts = 2 + k = 1 # Top-k MoE: each token uses exactly k experts + seed = 24 + + print("=" * 60) + print(f"Top-{k} MoE Inference Test") + print("=" * 60) + print(f"Configuration: batch={batch_size}, seq_len={seq_len}, input_dim={input_dim}") + print(f"Experts: {num_experts}, k={k}") + print("=" * 60) + + # Run MoE inference + output, input_tensor, moe_layer = run_moe_inference( + batch_size=batch_size, + seq_len=seq_len, + input_dim=input_dim, + output_dim=output_dim, + num_experts=num_experts, + k=k, + seed=seed, + verbose=True + ) + + print(f"\nOutput shape: {output.shape}") + print("=" * 60) + From b3f85ccf7293df5f5c08173164f5cd4a34ebb824 Mon Sep 17 00:00:00 2001 From: zh476 Date: Fri, 12 Dec 2025 21:44:10 -0500 Subject: [PATCH 3/3] formatting --- ece6775/README.md | 4 + .../attention_moe/allo_attention_moe_alt.py | 448 ++++++++++------- .../attention_moe/allo_attention_moe_base.py | 370 ++++++++------ .../attention_moe/allo_attention_moe_lib.py | 421 +++++++++------- .../attention_moe/pytorch_attention_moe.py | 471 ++++++++++-------- ece6775/llm_config/check_llm_config.py | 131 +++-- ece6775/llm_config/llm_config.py | 9 +- ece6775/moe/allo_moe_alt.py | 276 ++++++---- ece6775/moe/allo_moe_base.py | 267 ++++++---- ece6775/moe/allo_moe_lib.py | 284 +++++++---- ece6775/moe/pytorch_moe.py | 192 ++++--- 11 files changed, 1693 insertions(+), 1180 deletions(-) diff --git a/ece6775/README.md b/ece6775/README.md index b2be2a2c0..fbc6ea18f 100644 --- a/ece6775/README.md +++ b/ece6775/README.md @@ -1,3 +1,6 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + # ECE6775 - MoE and Attention+MoE Implementations Allo implementations of MoE (Mixture of Experts) and Attention+MoE layers for FPGA acceleration. @@ -41,3 +44,4 @@ Available configs: - Small numerical differences (~1e-4) vs PyTorch are normal due to accumulation order - Set `MODE` in each file to control build target: `llvm`, `sw_emu`, `hw_emu`, `hw`, `csyn` +Github PR link: https://github.com/cornell-zhang/allo/pull/489 \ No newline at end of file diff --git a/ece6775/attention_moe/allo_attention_moe_alt.py b/ece6775/attention_moe/allo_attention_moe_alt.py index c30b616b6..a76aed77d 100644 --- a/ece6775/attention_moe/allo_attention_moe_alt.py +++ b/ece6775/attention_moe/allo_attention_moe_alt.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 # Attention + MoE in Allo # Q, K, V -> Attention -> MoE -> Output # custom implementations (no library functions) for full control @@ -22,78 +24,80 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": # softmax over last dim Z: Ty[N, K] E_exp: Ty[N, K] M: Ty[N] = -1000000000000.0 S: Ty[N] = 0.0 - + # find max per row for n, k in dsl.grid(N, K, name="row_max"): if X[n, k] > M[n]: M[n] = X[n, k] - + # exp and sum for n, k in dsl.grid(N, K, name="exp_sum"): E_exp[n, k] = dsl.exp(X[n, k] - M[n]) S[n] += E_exp[n, k] - + # normalize for n, k in dsl.grid(N, K, name="normalize"): Z[n, k] = E_exp[n, k] / S[n] - + return Z + # softmax for attention scores [L, L] def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": Z: Ty[L, L] E_exp: Ty[L, L] M: Ty[L] = -1000000000000.0 S: Ty[L] = 0.0 - + # find max per row for i, j in dsl.grid(L, L, name="row_max"): if X[i, j] > M[i]: M[i] = X[i, j] - + # exp and sum for i, j in dsl.grid(L, L, name="exp_sum"): E_exp[i, j] = dsl.exp(X[i, j] - M[i]) S[i] += E_exp[i, j] - + # normalize for i, j in dsl.grid(L, L, name="normalize"): Z[i, j] = E_exp[i, j] / S[i] - + return Z + # scaled dot-product attention (multi-head) -def scaled_dot_product_attention[Ty, H, L, D]( - Q: "Ty[L, D]", - K: "Ty[L, D]", - V: "Ty[L, D]" -) -> "Ty[L, D]": +def scaled_dot_product_attention[ + Ty, H, L, D +](Q: "Ty[L, D]", K: "Ty[L, D]", V: "Ty[L, D]") -> "Ty[L, D]": # scaled dot-product attention (multi-head) - + # matches pytorch: split into H heads, compute attention per head, merge Z: Ty[L, D] = 0.0 - + # scale factor: 1/sqrt(head_dim) scale: Ty = 1.0 / dsl.sqrt(float(D // H)) - + # process each head for h in range(H, name="head_loop"): # split Q, K, V for this head Q_h: Ty[L, D // H] = 0.0 K_h: Ty[L, D // H] = 0.0 V_h: Ty[L, D // H] = 0.0 - + for i, j in dsl.grid(L, D // H, name="split_qkv"): Q_h[i, j] = Q[i, h * (D // H) + j] K_h[i, j] = K[i, h * (D // H) + j] V_h[i, j] = V[i, h * (D // H) + j] - + # QK^T with accumulator to break read-after-write dependency Y: Ty[L, L] = 0.0 for i in range(L, name="qkt_i"): @@ -102,32 +106,32 @@ def scaled_dot_product_attention[Ty, H, L, D]( for k in range(D // H, name="qkt_k"): acc += Q_h[i, k] * K_h[j, k] # K_h[j, k] = K_h^T[k, j] Y[i, j] = acc - + # scale Y_scaled: Ty[L, L] = 0.0 for i, j in dsl.grid(L, L, name="scale"): Y_scaled[i, j] = Y[i, j] * scale - + # Apply softmax over last dimension S: Ty[L, L] = 0.0 E_exp: Ty[L, L] = 0.0 M: Ty[L] = -1000000000000.0 Sum: Ty[L] = 0.0 - + # Find max for each row for i, j in dsl.grid(L, L, name="softmax_max"): if Y_scaled[i, j] > M[i]: M[i] = Y_scaled[i, j] - + # Compute exp and sum for i, j in dsl.grid(L, L, name="softmax_exp"): E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) Sum[i] += E_exp[i, j] - + # Normalize for i, j in dsl.grid(L, L, name="softmax_norm"): S[i, j] = E_exp[i, j] / Sum[i] - + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] # Loop order (i, k, j) for better memory access and pipelining: # - Sequential access to V_h[k, j] as j changes (inner loop) @@ -137,50 +141,54 @@ def scaled_dot_product_attention[Ty, H, L, D]( for k in range(L, name="sv_k"): for j in range(D // H, name="sv_j"): C_h[i, j] += S[i, k] * V_h[k, j] - + # Merge back to Z for i, j in dsl.grid(L, D // H, name="merge_heads"): Z[i, h * (D // H) + j] = C_h[i, j] - + return Z + # top-1 selection (argmax) def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": # pick best expert per token indices: int32[N] max_val: Ty[N] = -1000000000000.0 - + for n in range(N, name="init"): indices[n] = 0 max_val[n] = logits[n, 0] - + for n, e in dsl.grid(N, E, name="argmax"): if e > 0: # skip e=0 if logits[n, e] > max_val[n]: max_val[n] = logits[n, e] indices[n] = e - + return indices + # FFN expert with fused GeLU # optimizations: fuse GeLU into FC1, row-level dataflow, unroll reduction loops -def expert[Ty, N, D_in, D_hidden, D_out]( +def expert[ + Ty, N, D_in, D_hidden, D_out +]( x: "Ty[N, D_in]", fc1_weight: "Ty[D_hidden, D_in]", fc1_bias: "Ty[D_hidden]", fc2_weight: "Ty[D_out, D_hidden]", - fc2_bias: "Ty[D_out]" + fc2_bias: "Ty[D_out]", ) -> "Ty[N, D_out]": """ A simple feed-forward expert network with fused GeLU and row-level dataflow. - + Optimizations applied: 1. GeLU fused into FC1: Compute GeLU immediately after each FC1 element to eliminate intermediate array storage and enable streaming 2. Row-level processing: Outer loop over rows (N) enables dataflow between FC1+GeLU and FC2 stages - while FC2 processes row n, FC1 can process row n+1 3. Separate reduction loops enable unroll pragmas - + Args: x: Input tensor of shape [N, D_in] fc1_weight: First linear layer weights of shape [D_hidden, D_in] @@ -191,7 +199,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( output: Output tensor of shape [N, D_out] """ output: Ty[N, D_out] = 0.0 - + # Row-level processing: outer loop over rows (tokens) # This structure enables dataflow between FC1+GeLU and FC2 for n in range(N, name="row_loop"): @@ -200,7 +208,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( # ===================================================================== # Intermediate buffer for this row's hidden activations (after GeLU) hidden_row: Ty[D_hidden] = 0.0 - + # Compute FC1 output for each hidden dimension, fuse GeLU immediately for h in range(D_hidden, name="fc1_gelu_loop"): # FC1: dot product x[n,:] @ fc1_weight[h,:] + fc1_bias[h] @@ -208,7 +216,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( for k in range(D_in, name="fc1_reduce"): acc += x[n, k] * fc1_weight[h, k] fc1_val: Ty = acc + fc1_bias[h] - + # Fused GeLU: apply immediately to avoid storing fc1_out array # GeLU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) # This matches PyTorch's GELU implementation (tanh approximation) @@ -218,7 +226,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( tanh_val: Ty = dsl.tanh(inner) gelu_val: Ty = 0.5 * fc1_val * (1.0 + tanh_val) hidden_row[h] = gelu_val - + # ===================================================================== # Stage 2: FC2 (consumes hidden row, produces output row) # ===================================================================== @@ -228,11 +236,14 @@ def expert[Ty, N, D_in, D_hidden, D_out]( for k in range(D_hidden, name="fc2_reduce"): acc2 += hidden_row[k] * fc2_weight[o, k] output[n, o] = acc2 + fc2_bias[o] - + return output + # MoE layer - routes tokens to experts -def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( x: "Ty[N, D_in]", # Gate weights gate_weight: "Ty[E, D_in]", @@ -241,11 +252,11 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[N, D_out]": """ Mixture of Experts layer for inference. - + Args: x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) gate_weight: Gate weight matrix of shape [E, D_in] @@ -267,12 +278,12 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for k in range(D_in, name="gate_reduce"): acc += x[n, k] * gate_weight[e, k] gate_logits[n, e] = acc + gate_bias[e] - + # ========================================================================= # Step 2: Select top-1 expert using top1_select function # ========================================================================= top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) - + # ========================================================================= # Step 3: Get top-k logits and apply softmax # ========================================================================= @@ -280,12 +291,12 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n, k in dsl.grid(N, K, name="topk_logits"): expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 top_k_logits[n, k] = gate_logits[n, expert_idx] - + # ========================================================================= # Step 4: Apply softmax to top-k logits using softmax_1d function # ========================================================================= top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] - + # ========================================================================= # Step 5: Create sparse weight matrix from top-k weights # ========================================================================= @@ -293,40 +304,40 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n in range(N, name="sparse_gate"): expert_idx = top1_indices_1d[n] gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 - + # ========================================================================= # Step 6: Process each expert: compute outputs for all tokens # ========================================================================= expert_outputs: Ty[E, N, D_out] = 0.0 - + for e in range(E, name="expert_loop"): # Extract expert weights for this expert expert_fc1_w: Ty[D_hidden, D_in] = 0.0 expert_fc1_b: Ty[D_hidden] = 0.0 expert_fc2_w: Ty[D_out, D_hidden] = 0.0 expert_fc2_b: Ty[D_out] = 0.0 - + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] - + for d_hidden in range(D_hidden, name="extract_fc1_b"): expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] - + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] - + for d_out in range(D_out, name="extract_fc2_b"): expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] - + # Process all tokens through this expert (uses optimized expert function) expert_out = expert[Ty, N, D_in, D_hidden, D_out]( x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b ) # [N, D_out] - + # Store expert outputs for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): expert_outputs[e, n, d_out] = expert_out[n, d_out] - + # ========================================================================= # Step 7: Combine expert outputs using gate weights # ========================================================================= @@ -334,14 +345,17 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): weight: Ty = gate_weights[n, e] output[n, d_out] += expert_outputs[e, n, d_out] * weight - + return output -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Attention + MoE Layer: Combined layer # Data flow: Q, K, V -> Attention -> MoE -> Output -#---------------------------------------------------------------------------------- -def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( +# ---------------------------------------------------------------------------------- +def attention_moe_layer[ + Ty, B, L, D, H, E, TopK, D_hidden +]( Query: "Ty[B, L, D]", Key: "Ty[B, L, D]", Value: "Ty[B, L, D]", @@ -352,13 +366,13 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D, D_hidden]", - expert_fc2_biases: "Ty[E, D]" + expert_fc2_biases: "Ty[E, D]", ) -> "Ty[B, L, D]": """ Combined Attention + MoE layer. - + Data flow: Q, K, V -> Attention -> MoE -> Output - + Args: Query: Query tensor of shape [B, L, D] Key: Key tensor of shape [B, L, D] @@ -374,23 +388,23 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( """ # Output tensor output: Ty[B, L, D] = 0.0 - + # Process each batch item for b in range(B, name="batch_loop"): # Step 1: Extract Q, K, V for this batch item -> [L, D] Q_b: Ty[L, D] = 0.0 K_b: Ty[L, D] = 0.0 V_b: Ty[L, D] = 0.0 - + for l, d in dsl.grid(L, D, name="extract_qkv"): Q_b[l, d] = Query[b, l, d] K_b[l, d] = Key[b, l, d] V_b[l, d] = Value[b, l, d] - + # Step 2: Apply custom scaled_dot_product_attention # scaled_dot_product_attention[Ty, H, L, D](Q, K, V) -> [L, D] attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] - + # Step 3: Apply MoE layer # moe_layer expects [N, D_in], so we use N=L for single batch moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( @@ -400,30 +414,30 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights, expert_fc1_biases, expert_fc2_weights, - expert_fc2_biases + expert_fc2_biases, ) # [L, D] - + # Step 4: Store output for this batch item for l, d in dsl.grid(L, D, name="store_output"): output[b, l, d] = moe_out[l, d] - + return output -#================================================================================== +# ================================================================================== # Schedule optimization function with HLS pragmas -#================================================================================== +# ================================================================================== def optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ): """ Create optimized schedules for Attention + MoE and compose them together. - + Optimization strategy: 1. Pipeline innermost loops for throughput 2. Unroll small loops for parallelism 3. Partition arrays for parallel access - + Args: batch_size: Batch size (B) seq_len: Sequence length (L) @@ -432,7 +446,7 @@ def optimize_attention_moe_with_composition( num_experts: Number of experts (E) k: Top-k value (currently only k=1 is supported) hidden_dim: Hidden dimension for experts - + Returns: s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed """ @@ -445,11 +459,11 @@ def optimize_attention_moe_with_composition( K = k D_hidden = hidden_dim head_dim = D // H - + print("=" * 60) print("Creating and optimizing Attention + MoE schedules...") print("=" * 60) - + # ========================================================================= # Step 1: Create and optimize schedule for top1_select # ========================================================================= @@ -459,7 +473,7 @@ def optimize_attention_moe_with_composition( # Note: Use "function_name:loop_var" format for dsl.grid loops s_top1.pipeline("top1_select:e") print(" - Created top1_select schedule with pipeline optimization") - + # ========================================================================= # Step 2: Create and optimize schedule for softmax_1d (for MoE gate) # ========================================================================= @@ -471,7 +485,7 @@ def optimize_attention_moe_with_composition( s_softmax_1d.pipeline(loops_softmax["exp_sum"]["k"]) s_softmax_1d.pipeline(loops_softmax["normalize"]["k"]) print(" - Created softmax_1d schedule with pipeline optimization") - + # ========================================================================= # Step 3: Create and optimize schedule for expert # Optimizations: @@ -481,24 +495,24 @@ def optimize_attention_moe_with_composition( # ========================================================================= print("\n[3] Creating schedule for expert (custom, no library functions)...") s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D]) - + # Get loop handles for expert expert_loops = s_expert.get_loops(s_expert.top_func_name) print(f" - Available loops: {list(expert_loops.loops.keys())}") - + # Get nested loops inside row_loop row_loop = expert_loops["row_loop"] print(f" - row_loop sub-loops: {list(row_loop.loops.keys())}") - + # ------------------------------------------------------------------------- # Note: Allo flattens nested loops and uses loop variable names as keys # row_loop sub-loops: ['n', 'h', 'k', 'o'] where: # n = row index (from row_loop) - # h = hidden dim index (from fc1_gelu_loop) + # h = hidden dim index (from fc1_gelu_loop) # k = reduction index (shared by fc1_reduce and fc2_reduce) # o = output dim index (from fc2_loop) # ------------------------------------------------------------------------- - + # ------------------------------------------------------------------------- # Optimization 1: Pipeline the output dimension loops (h and o) # This pipelines both the FC1+GeLU loop and FC2 loop @@ -506,26 +520,26 @@ def optimize_attention_moe_with_composition( s_expert.pipeline(row_loop["h"]) # Pipeline FC1+GeLU hidden dim loop s_expert.pipeline(row_loop["o"]) # Pipeline FC2 output dim loop print(" - Applied pipeline to h (fc1_gelu) and o (fc2) loops") - + # ------------------------------------------------------------------------- # Optimization 2: Unroll reduction loop for parallel MACs # The k loop is shared between FC1 and FC2 reductions # ------------------------------------------------------------------------- # Determine unroll factor based on dimensions unroll_factor = min(4, D, D_hidden) - + s_expert.unroll(row_loop["k"], factor=unroll_factor) print(f" - Applied unroll to k (reduction loop) factor={unroll_factor}") - + # ------------------------------------------------------------------------- # Note: Removed explicit array partitioning # Let HLS infer partitioning from pipelining directives for better tool runtime # Explicit partitioning can explode synthesis time and may not be optimal # ------------------------------------------------------------------------- - + print(" - Created expert schedule with pipeline and unroll optimizations") print(" - Note: Array partitioning inferred by HLS from pipelining") - + # ========================================================================= # Step 4: Create and optimize schedule for moe_layer # Optimizations: @@ -535,26 +549,26 @@ def optimize_attention_moe_with_composition( # ========================================================================= print("\n[4] Creating schedule for moe_layer...") s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) - + # Get loop handles for moe_layer moe_loops = s_moe.get_loops(s_moe.top_func_name) print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") - + # ------------------------------------------------------------------------- # Optimize gate_linear: pipeline and unroll # Note: Allo flattens loops - gate_linear sub-loops are likely ['n', 'e', 'k'] # ------------------------------------------------------------------------- gate_linear_loop = moe_loops["gate_linear"] print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") - + # Pipeline the e loop (output dimension) and unroll k (reduction) s_moe.pipeline(gate_linear_loop["e"]) print(" - Applied pipeline to gate_linear:e") - + unroll_factor_gate = min(4, D) s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") - + # ------------------------------------------------------------------------- # Optimize other loops in moe_layer # Note: For dsl.grid loops, sub-loops use variable names (n, e, k, d_out, etc.) @@ -562,16 +576,16 @@ def optimize_attention_moe_with_composition( combine_loop = moe_loops["combine_outputs"] print(f" - combine_outputs sub-loops: {list(combine_loop.loops.keys())}") s_moe.pipeline(combine_loop["d_out"]) - + topk_loop = moe_loops["topk_logits"] print(f" - topk_logits sub-loops: {list(topk_loop.loops.keys())}") s_moe.pipeline(topk_loop["k"]) - + sparse_loop = moe_loops["sparse_gate"] print(f" - sparse_gate sub-loops: {list(sparse_loop.loops.keys())}") s_moe.pipeline(sparse_loop["n"]) print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") - + # ------------------------------------------------------------------------- # Compose sub-function schedules # ------------------------------------------------------------------------- @@ -580,17 +594,17 @@ def optimize_attention_moe_with_composition( s_moe.compose(s_expert) print(" - Composed top1_select, softmax_1d, and expert schedules") print(" - Created moe_layer schedule with pipeline, unroll optimizations") - + # ========================================================================= # Step 5: Create and optimize schedule for scaled_dot_product_attention # ========================================================================= print("\n[5] Creating schedule for custom scaled_dot_product_attention...") s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) - + # Get loop handles for attention attn_loops = s_attn.get_loops(s_attn.top_func_name) print(f" - Available top-level loops: {list(attn_loops.loops.keys())}") - + # Helper function to print loop hierarchy recursively def print_loop_hierarchy(loops, indent=0): for key, handle in loops.loops.items(): @@ -601,7 +615,7 @@ def print_loop_hierarchy(loops, indent=0): print(" - Full loop hierarchy:") print_loop_hierarchy(attn_loops, indent=4) - + # Attention has nested loops inside head_loop # The structure is: head_loop -> h -> [split_qkv, qkt_matmul, scale, softmax_*, sv_matmul, merge_heads] head_inner = attn_loops["head_loop"] @@ -613,14 +627,22 @@ def print_loop_hierarchy(loops, indent=0): # qkt_loops = head_inner["qkt_matmul"] # print(f" - sv_matmul sub-loops: {list(sv_loops.loops.keys())}") # print(f" - qkt_matmul sub-loops: {list(qkt_loops.loops.keys())}") - + # Pipeline reduction loops (k) for matmuls - this is critical for II=1 # The reordered loops (i, k, j) allow better pipelining of k reduction - s_attn.pipeline("scaled_dot_product_attention:k") # Pipeline k reduction loops (qkt_k, sv_k) - s_attn.pipeline("scaled_dot_product_attention:j") # Pipeline j loops for output dimension - s_attn.pipeline("scaled_dot_product_attention:i") # Pipeline i loops (sv_i) to overlap row processing + s_attn.pipeline( + "scaled_dot_product_attention:k" + ) # Pipeline k reduction loops (qkt_k, sv_k) + s_attn.pipeline( + "scaled_dot_product_attention:j" + ) # Pipeline j loops for output dimension + s_attn.pipeline( + "scaled_dot_product_attention:i" + ) # Pipeline i loops (sv_i) to overlap row processing print(" - Applied pipeline to i (row), k (reduction), and j (output) loops") - print(" - Note: Loop reordering (i,k,j) enables better memory access and pipelining") + print( + " - Note: Loop reordering (i,k,j) enables better memory access and pipelining" + ) print(" - Pipelining sv_i allows overlapping processing of different output rows") # ------------------------------------------------------------------------- @@ -632,7 +654,7 @@ def print_loop_hierarchy(loops, indent=0): # print(" - Applied unroll to sv_matmul:k (factor=4)") # s_attn.unroll(qkt_loops["k"], factor=4) # print(" - Applied unroll to qkt_matmul:k (factor=4)") - + # ------------------------------------------------------------------------- # Apply pipeline optimization to reduction loops # ------------------------------------------------------------------------- @@ -641,7 +663,7 @@ def print_loop_hierarchy(loops, indent=0): # print(" - Applied pipeline to sv_matmul:k (reduction loop)") # s_attn.pipeline(qkt_loops["k"]) # print(" - Applied pipeline to qkt_matmul:k (reduction loop)") - + # ------------------------------------------------------------------------- # Apply reorder optimization for better data locality # For matrix multiplication C[i,j] += A[i,k] * B[k,j], reorder to i,k,j @@ -652,7 +674,7 @@ def print_loop_hierarchy(loops, indent=0): # print(" - Applied reorder to sv_matmul: (i, j, k) -> (i, k, j)") # s_attn.reorder(qkt_loops["k"], qkt_loops["j"]) # print(" - Applied reorder to qkt_matmul: (i, j, k) -> (i, k, j)") - + # ------------------------------------------------------------------------- # Apply buffer_at optimization to reduce memory access # Creates on-chip buffer for intermediate results @@ -662,33 +684,34 @@ def print_loop_hierarchy(loops, indent=0): # print(" - Applied buffer_at to C_h at sv_matmul:i") # s_attn.buffer_at(s_attn.Y, axis=qkt_loops["i"]) # print(" - Applied buffer_at to Y at qkt_matmul:i") - - print(" - Created scaled_dot_product_attention schedule with pipeline optimizations") + + print( + " - Created scaled_dot_product_attention schedule with pipeline optimizations" + ) print(" - Matmul loops reordered to (i,k,j) for better memory access and II=1") - + # ========================================================================= # Step 6: Create schedule for main attention_moe_layer function # ========================================================================= print("\n[6] Creating schedule for attention_moe_layer...") s_attn_moe = allo.customize( - attention_moe_layer, - instantiate=[Ty, B, L, D, H, E, K, D_hidden] + attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden] ) - + # Get loop handles for attention_moe_layer attn_moe_loops = s_attn_moe.get_loops(s_attn_moe.top_func_name) print(f" - Available top-level loops: {list(attn_moe_loops.loops.keys())}") print(" - Full loop hierarchy for attention_moe_layer:") print_loop_hierarchy(attn_moe_loops, indent=4) - + # Pipeline top-level loops using get_loops() # Note: extract_qkv and store_output are inside batch_loop batch_loop = attn_moe_loops["batch_loop"] - + # s_attn_moe.pipeline(batch_loop["extract_qkv"]["d"]) # s_attn_moe.pipeline(batch_loop["store_output"]["d"]) - + # ========================================================================= # Step 7: Compose all schedules together # ========================================================================= @@ -697,7 +720,7 @@ def print_loop_hierarchy(loops, indent=0): s_attn_moe.compose(s_moe) print(" - Composed scaled_dot_product_attention schedule") print(" - Composed moe_layer schedule") - + print("\n" + "=" * 60) print("Schedule composition complete with optimizations!") print("=" * 60) @@ -713,7 +736,7 @@ def print_loop_hierarchy(loops, indent=0): print("Attention Optimizations (unchanged):") print(" - Pipeline on innermost loops (j, k)") print("=" * 60) - + return s_attn_moe @@ -722,17 +745,17 @@ def print_loop_hierarchy(loops, indent=0): import torch import torch.nn as torch_nn from pytorch_attention_moe import AttentionMoE - + # get config moe_config = get_moe_config(CONFIG_MODE) - + batch_size = moe_config["batch_size"] seq_len = moe_config["seq_len"] embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads num_experts = moe_config["num_experts"] # E k = moe_config["k"] # Top-k MoE hidden_dim = moe_config["hidden_dim"] # D_hidden - + # Attention-specific parameter: num_heads # Choose num_heads such that embed_dim is divisible by num_heads # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim) @@ -748,13 +771,13 @@ def print_loop_hierarchy(loops, indent=0): num_heads = 4 else: num_heads = 2 - + # Ensure embed_dim is divisible by num_heads while embed_dim % num_heads != 0 and num_heads > 1: num_heads -= 1 - + seed = 42 - + print("=" * 60) print("Attention + MoE Allo Implementation Test") print("=" * 60) @@ -765,15 +788,15 @@ def print_loop_hierarchy(loops, indent=0): print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") print(f" Seed: {seed}") print("=" * 60) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run PyTorch implementation to get weights and outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch AttentionMoE layer pytorch_model = AttentionMoE( seq_len=seq_len, @@ -781,165 +804,212 @@ def print_loop_hierarchy(loops, indent=0): num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=hidden_dim + expert_hidden_dim=hidden_dim, ) pytorch_model.eval() - + # Initialize with Xavier uniform for param in pytorch_model.parameters(): if param.dim() > 1: torch_nn.init.xavier_uniform_(param) else: torch_nn.init.zeros_(param) - + # Create random inputs (Q, K, V) torch.manual_seed(seed) Q_pt = torch.randn(batch_size, seq_len, embed_dim) K_pt = torch.randn(batch_size, seq_len, embed_dim) V_pt = torch.randn(batch_size, seq_len, embed_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - - #---------------------------------------------------------------------------------- + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- # Extract weights and biases from PyTorch model - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights - gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data # [num_experts, embed_dim] + gate_weight_pt = ( + pytorch_model.moe.gate.gate_linear.weight.data + ) # [num_experts, embed_dim] gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias if gate_bias_pt is not None: gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) - expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) - expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) - expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) - - #---------------------------------------------------------------------------------- + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_model.moe.experts] + ) + + # ---------------------------------------------------------------------------------- # Convert to numpy arrays - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[3] Converting weights to numpy arrays...") Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) - - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Q shape: {Q_np.shape}") print(f"K shape: {K_np.shape}") print(f"V shape: {V_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run Allo implementation - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[4] Running Allo implementation...") try: # Create optimized schedule with composition allo_schedule = optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ) - + # Generate project name project_name = f"allo_attention_moe_alt_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("\n[5] Building Allo module...") if MODE == "llvm": mod = allo_schedule.build(target="llvm") elif MODE == "sw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": - mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) elif MODE == "csyn": - mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference print("\n[6] Running Allo inference...") if MODE == "llvm": allo_output = mod( - Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, ) elif MODE in ["sw_emu", "hw_emu", "hw"]: allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) - mod(Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np, - allo_output) + mod( + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "csyn": allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Compare the Allo and PyTorch outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[7] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + # Compute differences diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # Check if outputs are close atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Print sample outputs for comparison - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[8] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - + print("=" * 60) diff --git a/ece6775/attention_moe/allo_attention_moe_base.py b/ece6775/attention_moe/allo_attention_moe_base.py index a4d851576..2ad20c6ab 100644 --- a/ece6775/attention_moe/allo_attention_moe_base.py +++ b/ece6775/attention_moe/allo_attention_moe_base.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 # Attention + MoE in Allo (no library version) # Q, K, V -> Attention -> MoE -> Output # all manual implementation, no library functions @@ -21,10 +23,12 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": """ Softmax over last dimension for [N, K] shape. - + Args: X: Input tensor of shape [N, K] Returns: @@ -34,113 +38,113 @@ def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": E_exp: Ty[N, K] M: Ty[N] = -1000000000000.0 S: Ty[N] = 0.0 - + # Find max for each row (over dimension K) for n, k in dsl.grid(N, K, name="row_max"): if X[n, k] > M[n]: M[n] = X[n, k] - + # Compute exp and sum for n, k in dsl.grid(N, K, name="exp_sum"): E_exp[n, k] = dsl.exp(X[n, k] - M[n]) S[n] += E_exp[n, k] - + # Normalize for n, k in dsl.grid(N, K, name="update"): Z[n, k] = E_exp[n, k] / S[n] - + return Z -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Scaled Dot-Product Attention: Custom implementation matching PyTorch -#---------------------------------------------------------------------------------- -def scaled_dot_product_attention[Ty, H, L, D]( - Q: "Ty[L, D]", - K: "Ty[L, D]", - V: "Ty[L, D]" -) -> "Ty[L, D]": +# ---------------------------------------------------------------------------------- +def scaled_dot_product_attention[ + Ty, H, L, D +](Q: "Ty[L, D]", K: "Ty[L, D]", V: "Ty[L, D]") -> "Ty[L, D]": """ Scaled Dot-Product Attention (Multi-Head Attention). - + This implementation matches the PyTorch version: - Input: Q, K, V of shape [L, D] - Split into H heads, each with dimension D // H - For each head: softmax(QK^T / sqrt(head_dim)) @ V - Merge heads back to [L, D] - + Args: Q: Query tensor of shape [L, D] K: Key tensor of shape [L, D] V: Value tensor of shape [L, D] - + Returns: Z: Output tensor of shape [L, D] """ Z: Ty[L, D] = 0.0 - + # Compute scale factor: 1 / sqrt(head_dim) = 1 / sqrt(D // H) scale: Ty = 1.0 / dsl.sqrt(float(D // H)) - + # Process each head for h in range(H, name="head_loop"): # Split Q, K, V for this head Q_h: Ty[L, D // H] = 0.0 K_h: Ty[L, D // H] = 0.0 V_h: Ty[L, D // H] = 0.0 - + for i, j in dsl.grid(L, D // H, name="split_qkv"): Q_h[i, j] = Q[i, h * (D // H) + j] K_h[i, j] = K[i, h * (D // H) + j] V_h[i, j] = V[i, h * (D // H) + j] - + # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L] Y: Ty[L, L] = 0.0 for i, j, k in dsl.grid(L, L, D // H, name="qkt_matmul"): Y[i, j] += Q_h[i, k] * K_h[j, k] # K_h[j, k] is K_h^T[k, j] - + # Scale by 1/sqrt(head_dim) Y_scaled: Ty[L, L] = 0.0 for i, j in dsl.grid(L, L, name="scale"): Y_scaled[i, j] = Y[i, j] * scale - + # Apply softmax over last dimension (inline implementation) S: Ty[L, L] = 0.0 E_exp: Ty[L, L] = 0.0 M: Ty[L] = -1000000000000.0 Sum: Ty[L] = 0.0 - + # Find max for each row for i, j in dsl.grid(L, L, name="softmax_max"): if Y_scaled[i, j] > M[i]: M[i] = Y_scaled[i, j] - + # Compute exp and sum for i, j in dsl.grid(L, L, name="softmax_exp"): E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) Sum[i] += E_exp[i, j] - + # Normalize for i, j in dsl.grid(L, L, name="softmax_norm"): S[i, j] = E_exp[i, j] / Sum[i] - + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] C_h: Ty[L, D // H] = 0.0 for i, j, k in dsl.grid(L, D // H, L, name="sv_matmul"): C_h[i, j] += S[i, k] * V_h[k, j] - + # Merge back to Z for i, j in dsl.grid(L, D // H, name="merge_heads"): Z[i, h * (D // H) + j] = C_h[i, j] - + return Z -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Top1 selection: Top-k selection for k=1 (argmax) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": """ Select top-1 expert (argmax) for each token. - + Args: logits: Input logits of shape [N, E] Returns: @@ -148,26 +152,29 @@ def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": """ indices: int32[N] max_val: Ty[N] = -1000000000000.0 - + # Initialize indices and max_val with first expert for n in range(N, name="init"): indices[n] = 0 max_val[n] = logits[n, 0] - + # Find argmax for each token (search from index 1 onwards) for n, e in dsl.grid(N, E, name="argmax"): if e > 0: # Skip e=0 (already initialized) if logits[n, e] > max_val[n]: max_val[n] = logits[n, e] indices[n] = e - + return indices -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # MoE Layer: Main MoE layer - ALL MANUAL IMPLEMENTATION (NO LIBRARY FUNCTIONS) # Based on moe_allo.py implementation -#---------------------------------------------------------------------------------- -def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( +# ---------------------------------------------------------------------------------- +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( x: "Ty[N, D_in]", # Gate weights gate_weight: "Ty[E, D_in]", @@ -176,12 +183,12 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[N, D_out]": """ Mixture of Experts layer for inference. ALL MANUAL IMPLEMENTATION - NO LIBRARY FUNCTIONS. - + Args: x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) gate_weight: Gate weight matrix of shape [E, D_in] @@ -204,23 +211,23 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( # Matrix multiplication: x[i] @ gate_weight[j]^T for k in range(D_in): gate_logits[i, j] += x[i, k] * gate_weight[j, k] - + # ========================================================================= # Step 2: Select top-1 expert (argmax) for each token # ========================================================================= top1_indices: int32[N] max_logit: Ty[N] = -1000000000000.0 - + for n in range(N, name="top1_init"): top1_indices[n] = 0 max_logit[n] = gate_logits[n, 0] - + for n, e in dsl.grid(N, E, name="top1_argmax"): if e > 0: if gate_logits[n, e] > max_logit[n]: max_logit[n] = gate_logits[n, e] top1_indices[n] = e - + # ========================================================================= # Step 3: Get top-k logits and apply softmax # For k=1, softmax just returns 1.0, but we compute for consistency @@ -229,24 +236,24 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n, k in dsl.grid(N, K, name="topk_logits"): expert_idx: int32 = top1_indices[n] if k == 0 else 0 top_k_logits[n, k] = gate_logits[n, expert_idx] - + # Apply softmax to top-k logits (inline implementation) top_k_weights: Ty[N, K] = 0.0 softmax_max: Ty[N] = -1000000000000.0 softmax_sum: Ty[N] = 0.0 softmax_exp: Ty[N, K] = 0.0 - + for n, k in dsl.grid(N, K, name="softmax_max"): if top_k_logits[n, k] > softmax_max[n]: softmax_max[n] = top_k_logits[n, k] - + for n, k in dsl.grid(N, K, name="softmax_exp"): softmax_exp[n, k] = dsl.exp(top_k_logits[n, k] - softmax_max[n]) softmax_sum[n] += softmax_exp[n, k] - + for n, k in dsl.grid(N, K, name="softmax_norm"): top_k_weights[n, k] = softmax_exp[n, k] / softmax_sum[n] - + # ========================================================================= # Step 4: Create sparse weight matrix from top-k weights # ========================================================================= @@ -254,42 +261,42 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n in range(N, name="gate_weights"): expert_idx: int32 = top1_indices[n] gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1 - + # ========================================================================= # Step 5: Process each expert - ALL MANUAL (NO LIBRARY FUNCTIONS) # ========================================================================= expert_outputs: Ty[E, N, D_out] = 0.0 - + for e in range(E, name="expert_loop"): # Extract expert weights for this expert expert_fc1_w: Ty[D_hidden, D_in] = 0.0 expert_fc1_b: Ty[D_hidden] = 0.0 expert_fc2_w: Ty[D_out, D_hidden] = 0.0 expert_fc2_b: Ty[D_out] = 0.0 - + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] - + for d_hidden in range(D_hidden, name="extract_fc1_b"): expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] - + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] - + for d_out in range(D_out, name="extract_fc2_b"): expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] - + # --------------------------------------------------------------------- # Expert forward pass - MANUAL IMPLEMENTATION # --------------------------------------------------------------------- - + # FC1: fc1_out = x @ fc1_weight^T + fc1_bias fc1_out: Ty[N, D_hidden] = 0.0 for i, j in dsl.grid(N, D_hidden, name="fc1_linear"): fc1_out[i, j] = expert_fc1_b[j] for k in range(D_in): fc1_out[i, j] += x[i, k] * expert_fc1_w[j, k] - + # GELU activation: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) gelu_out: Ty[N, D_hidden] = 0.0 for i, j in dsl.grid(N, D_hidden, name="gelu"): @@ -299,18 +306,18 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( inner: Ty = 0.7978845608028654 * (x_val + 0.044715 * x3) tanh_term: Ty = dsl.tanh(inner) gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) - + # FC2: fc2_out = gelu_out @ fc2_weight^T + fc2_bias expert_out: Ty[N, D_out] = 0.0 for i, j in dsl.grid(N, D_out, name="fc2_linear"): expert_out[i, j] = expert_fc2_b[j] for k in range(D_hidden): expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k] - + # Store expert outputs for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): expert_outputs[e, n, d_out] = expert_out[n, d_out] - + # ========================================================================= # Step 6: Combine expert outputs using gate weights # ========================================================================= @@ -318,14 +325,17 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): weight: Ty = gate_weights[n, e] output[n, d_out] += expert_outputs[e, n, d_out] * weight - + return output -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Attention + MoE Layer: Combined layer # Data flow: Q, K, V -> Attention -> MoE -> Output -#---------------------------------------------------------------------------------- -def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( +# ---------------------------------------------------------------------------------- +def attention_moe_layer[ + Ty, B, L, D, H, E, TopK, D_hidden +]( Query: "Ty[B, L, D]", Key: "Ty[B, L, D]", Value: "Ty[B, L, D]", @@ -336,13 +346,13 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D, D_hidden]", - expert_fc2_biases: "Ty[E, D]" + expert_fc2_biases: "Ty[E, D]", ) -> "Ty[B, L, D]": """ Combined Attention + MoE layer. - + Data flow: Q, K, V -> Attention -> MoE -> Output - + Args: Query: Query tensor of shape [B, L, D] Key: Key tensor of shape [B, L, D] @@ -358,22 +368,22 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( """ # Output tensor output: Ty[B, L, D] = 0.0 - + # Process each batch item for b in range(B, name="batch_loop"): # Step 1: Extract Q, K, V for this batch item -> [L, D] Q_b: Ty[L, D] = 0.0 K_b: Ty[L, D] = 0.0 V_b: Ty[L, D] = 0.0 - + for l, d in dsl.grid(L, D, name="extract_qkv"): Q_b[l, d] = Query[b, l, d] K_b[l, d] = Key[b, l, d] V_b[l, d] = Value[b, l, d] - + # Step 2: Apply custom scaled_dot_product_attention attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] - + # Step 3: Apply MoE layer (all manual implementation) moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( attn_out, @@ -382,26 +392,26 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights, expert_fc1_biases, expert_fc2_weights, - expert_fc2_biases + expert_fc2_biases, ) # [L, D] - + # Step 4: Store output for this batch item for l, d in dsl.grid(L, D, name="store_output"): output[b, l, d] = moe_out[l, d] - + return output -#================================================================================== +# ================================================================================== # Schedule optimization function - NO LIBRARY FUNCTIONS -#================================================================================== +# ================================================================================== def optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ): """ Create optimized schedules for Attention + MoE and compose them together. NO LIBRARY FUNCTIONS - all manual implementation. - + Args: batch_size: Batch size (B) seq_len: Sequence length (L) @@ -410,7 +420,7 @@ def optimize_attention_moe_with_composition( num_experts: Number of experts (E) k: Top-k value (currently only k=1 is supported) hidden_dim: Hidden dimension for experts - + Returns: s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed """ @@ -422,63 +432,62 @@ def optimize_attention_moe_with_composition( E = num_experts K = k D_hidden = hidden_dim - + print("=" * 60) print("Creating and optimizing Attention + MoE schedules (NO LIBRARY)...") print("=" * 60) - + # Step 1: Create schedule for custom attention print("\n[1] Creating schedule for custom scaled_dot_product_attention...") s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) print(" - Created scaled_dot_product_attention schedule") - + # Step 2: Create schedule for moe_layer (all manual) print("\n[2] Creating schedule for moe_layer (manual implementation)...") s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) print(" - Created moe_layer schedule (no library functions)") - + # Step 3: Create schedule for main attention_moe_layer function print("\n[3] Creating schedule for attention_moe_layer...") s_attn_moe = allo.customize( - attention_moe_layer, - instantiate=[Ty, B, L, D, H, E, K, D_hidden] + attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden] ) - + # Step 4: Compose all schedules together print("\n[4] Composing all schedules together...") s_attn_moe.compose(s_attn) s_attn_moe.compose(s_moe) print(" - Composed scaled_dot_product_attention schedule") print(" - Composed moe_layer schedule") - + print("\n" + "=" * 60) print("Schedule composition complete (NO LIBRARY FUNCTIONS)!") print("=" * 60) - + return s_attn_moe -#================================================================================== +# ================================================================================== # Test function to compare Allo and PyTorch implementations -#================================================================================== +# ================================================================================== if __name__ == "__main__": import torch import torch.nn as torch_nn from pytorch_attention_moe import AttentionMoE - + # ============================================================================ # Configuration parameters - use shared config from llm_config.py # ============================================================================ moe_config = get_moe_config(CONFIG_MODE) - + batch_size = moe_config["batch_size"] seq_len = moe_config["seq_len"] embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads num_experts = moe_config["num_experts"] # E k = moe_config["k"] # Top-k MoE hidden_dim = moe_config["hidden_dim"] # D_hidden - + # Attention-specific parameter: num_heads if embed_dim >= 768: num_heads = 12 @@ -492,13 +501,13 @@ def optimize_attention_moe_with_composition( num_heads = 4 else: num_heads = 2 - + # Ensure embed_dim is divisible by num_heads while embed_dim % num_heads != 0 and num_heads > 1: num_heads -= 1 - + seed = 42 - + print("=" * 60) print("Attention + MoE Allo Implementation Test (NO LIBRARY)") print("=" * 60) @@ -509,15 +518,15 @@ def optimize_attention_moe_with_composition( print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") print(f" Seed: {seed}") print("=" * 60) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run PyTorch implementation to get weights and outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch AttentionMoE layer pytorch_model = AttentionMoE( seq_len=seq_len, @@ -525,35 +534,37 @@ def optimize_attention_moe_with_composition( num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=hidden_dim + expert_hidden_dim=hidden_dim, ) pytorch_model.eval() - + # Initialize with Xavier uniform for param in pytorch_model.parameters(): if param.dim() > 1: torch_nn.init.xavier_uniform_(param) else: torch_nn.init.zeros_(param) - + # Create random inputs (Q, K, V) torch.manual_seed(seed) Q_pt = torch.randn(batch_size, seq_len, embed_dim) K_pt = torch.randn(batch_size, seq_len, embed_dim) V_pt = torch.randn(batch_size, seq_len, embed_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - - #---------------------------------------------------------------------------------- + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- # Extract weights and biases from PyTorch model - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias @@ -561,129 +572,172 @@ def optimize_attention_moe_with_composition( gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) - expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) - expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) - expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) - - #---------------------------------------------------------------------------------- + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_model.moe.experts] + ) + + # ---------------------------------------------------------------------------------- # Convert to numpy arrays - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[3] Converting weights to numpy arrays...") Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) - - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Q shape: {Q_np.shape}") print(f"K shape: {K_np.shape}") print(f"V shape: {V_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run Allo implementation - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[4] Running Allo implementation...") try: # Create optimized schedule with composition allo_schedule = optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ) - + # Generate project name project_name = f"allo_attention_moe_base_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("\n[5] Building Allo module...") if MODE == "llvm": mod = allo_schedule.build(target="llvm") elif MODE == "sw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": - mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) elif MODE == "csyn": - mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference print("\n[6] Running Allo inference...") if MODE == "llvm": allo_output = mod( - Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, ) elif MODE in ["sw_emu", "hw_emu", "hw"]: allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) - mod(Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np, - allo_output) + mod( + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "csyn": allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Compare the Allo and PyTorch outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[7] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + # Compute differences diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # Check if outputs are close atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Print sample outputs for comparison - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[8] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - + print("=" * 60) diff --git a/ece6775/attention_moe/allo_attention_moe_lib.py b/ece6775/attention_moe/allo_attention_moe_lib.py index d1833b89f..0e71b21f6 100644 --- a/ece6775/attention_moe/allo_attention_moe_lib.py +++ b/ece6775/attention_moe/allo_attention_moe_lib.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 # Attention + MoE in Allo # Q, K, V -> Attention -> MoE -> Output # uses library functions (nn.linear2d, nn.GeLU) for expert @@ -23,10 +25,12 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": """ Softmax over last dimension for [N, K] shape. - + Args: X: Input tensor of shape [N, K] Returns: @@ -36,32 +40,33 @@ def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": E_exp: Ty[N, K] M: Ty[N] = -1000000000000.0 S: Ty[N] = 0.0 - + # Find max for each row (over dimension K) for n, k in dsl.grid(N, K, name="row_max"): if X[n, k] > M[n]: M[n] = X[n, k] - + # Compute exp and sum for n, k in dsl.grid(N, K, name="exp_sum"): E_exp[n, k] = dsl.exp(X[n, k] - M[n]) S[n] += E_exp[n, k] - + # Normalize for n, k in dsl.grid(N, K, name="update"): # Add small epsilon to prevent division by zero Z[n, k] = E_exp[n, k] / (S[n]) - + return Z -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Softmax 2D: Softmax over last dimension for [L, L] shape (for attention) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": """ Softmax over last dimension for [L, L] shape. Used for attention scores. - + Args: X: Input tensor of shape [L, L] Returns: @@ -71,67 +76,66 @@ def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": E_exp: Ty[L, L] M: Ty[L] = -1000000000000.0 S: Ty[L] = 0.0 - + # Find max for each row for i, j in dsl.grid(L, L, name="row_max"): if X[i, j] > M[i]: M[i] = X[i, j] - + # Compute exp and sum for i, j in dsl.grid(L, L, name="exp_sum"): E_exp[i, j] = dsl.exp(X[i, j] - M[i]) S[i] += E_exp[i, j] - + # Normalize for i, j in dsl.grid(L, L, name="update"): Z[i, j] = E_exp[i, j] / S[i] - + return Z -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Scaled Dot-Product Attention: Custom implementation matching PyTorch -#---------------------------------------------------------------------------------- -def scaled_dot_product_attention[Ty, H, L, D]( - Q: "Ty[L, D]", - K: "Ty[L, D]", - V: "Ty[L, D]" -) -> "Ty[L, D]": +# ---------------------------------------------------------------------------------- +def scaled_dot_product_attention[ + Ty, H, L, D +](Q: "Ty[L, D]", K: "Ty[L, D]", V: "Ty[L, D]") -> "Ty[L, D]": """ Scaled Dot-Product Attention (Multi-Head Attention). - + This implementation matches the PyTorch version: - Input: Q, K, V of shape [L, D] - Split into H heads, each with dimension D // H - For each head: softmax(QK^T / sqrt(head_dim)) @ V - Merge heads back to [L, D] - Uses linear2d for matrix multiplication (replacing systolic for HLS compatibility) - + Args: Q: Query tensor of shape [L, D] K: Key tensor of shape [L, D] V: Value tensor of shape [L, D] - + Returns: Z: Output tensor of shape [L, D] """ Z: Ty[L, D] = 0.0 - + # Compute scale factor: 1 / sqrt(head_dim) = 1 / sqrt(D // H) # For D=8, H=2: head_dim=4, scale = 1/sqrt(4) = 0.5 scale: Ty = 1.0 / dsl.sqrt(float(D // H)) - + # Process each head for h in range(H, name="head_loop"): # Split Q, K, V for this head Q_h: Ty[L, D // H] = 0.0 K_h: Ty[L, D // H] = 0.0 # Not transposed, will transpose for linear2d V_h: Ty[L, D // H] = 0.0 - + for i, j in dsl.grid(L, D // H, name="split_qkv"): Q_h[i, j] = Q[i, h * (D // H) + j] K_h[i, j] = K[i, h * (D // H) + j] # Store K normally (not transposed) V_h[i, j] = V[i, h * (D // H) + j] - + # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L] # Using linear2d: Q_h @ K_h^T = linear2d(Q_h, K_h_transposed, 0) # linear2d computes X @ W^T + b, so we need K_h^T as W @@ -145,7 +149,7 @@ def scaled_dot_product_attention[Ty, H, L, D]( K_h_T: Ty[D // H, L] = 0.0 for i, j in dsl.grid(L, D // H, name="transpose_K"): K_h_T[j, i] = K_h[i, j] - + # Now use linear2d: Q_h[L, D//H] @ K_h_T^T[D//H, L] = Q_h @ K_h_T^T # But linear2d expects W[N, K], so we need W[L, D//H] such that W^T = [D//H, L] # Actually, we can use K_h directly as W[L, D//H], then W^T = [D//H, L] = K_h^T @@ -157,32 +161,32 @@ def scaled_dot_product_attention[Ty, H, L, D]( # - So W = K_h (which is [L, D//H]), then W^T = [D//H, L] = K_h^T ✓ zero_bias: Ty[L] = 0.0 Y = linear2d[Ty, Ty, Ty, L, L, D // H](Q_h, K_h, zero_bias) # Q_h @ K_h^T - + # Scale by 1/sqrt(head_dim) Y_scaled: Ty[L, L] = 0.0 for i, j in dsl.grid(L, L, name="scale"): Y_scaled[i, j] = Y[i, j] * scale - + # Apply softmax over last dimension S: Ty[L, L] = 0.0 E_exp: Ty[L, L] = 0.0 M: Ty[L] = -1000000000000.0 Sum: Ty[L] = 0.0 - + # Find max for each row for i, j in dsl.grid(L, L, name="softmax_max"): if Y_scaled[i, j] > M[i]: M[i] = Y_scaled[i, j] - + # Compute exp and sum for i, j in dsl.grid(L, L, name="softmax_exp"): E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) Sum[i] += E_exp[i, j] - + # Normalize for i, j in dsl.grid(L, L, name="softmax_norm"): S[i, j] = E_exp[i, j] / Sum[i] - + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] # Using linear2d: S @ V_h = linear2d(S, V_h_transposed, 0) # linear2d(X[M, K], W[N, K], b) = X @ W^T @@ -194,23 +198,24 @@ def scaled_dot_product_attention[Ty, H, L, D]( V_h_T: Ty[D // H, L] = 0.0 for i, j in dsl.grid(L, D // H, name="transpose_V"): V_h_T[j, i] = V_h[i, j] - + zero_bias_v: Ty[D // H] = 0.0 C_h = linear2d[Ty, Ty, Ty, L, D // H, L](S, V_h_T, zero_bias_v) # S @ V_h - + # Merge back to Z for i, j in dsl.grid(L, D // H, name="merge_heads"): Z[i, h * (D // H) + j] = C_h[i, j] - + return Z -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Top1 selection: Top-k selection for k=1 (argmax) -#---------------------------------------------------------------------------------- +# ---------------------------------------------------------------------------------- def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": """ Select top-1 expert (argmax) for each token. - + Args: logits: Input logits of shape [N, E] Returns: @@ -218,37 +223,40 @@ def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": """ indices: int32[N] max_val: Ty[N] = -1000000000000.0 - + # Initialize indices and max_val with first expert for n in range(N, name="init"): indices[n] = 0 max_val[n] = logits[n, 0] - + # Find argmax for each token (search from index 1 onwards) for n, e in dsl.grid(N, E, name="argmax"): if e > 0: # Skip e=0 (already initialized) if logits[n, e] > max_val[n]: max_val[n] = logits[n, e] indices[n] = e - + return indices -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Expert: A simple feed-forward expert network -#---------------------------------------------------------------------------------- -def expert[Ty, N, D_in, D_hidden, D_out]( +# ---------------------------------------------------------------------------------- +def expert[ + Ty, N, D_in, D_hidden, D_out +]( x: "Ty[N, D_in]", fc1_weight: "Ty[D_hidden, D_in]", fc1_bias: "Ty[D_hidden]", fc2_weight: "Ty[D_out, D_hidden]", - fc2_bias: "Ty[D_out]" + fc2_bias: "Ty[D_out]", ) -> "Ty[N, D_out]": """ A simple feed-forward expert network. - + This implementation uses linear2d and GeLU from Allo library for optimized computation. - + Args: x: Input tensor of shape [N, D_in] fc1_weight: First linear layer weights of shape [D_hidden, D_in] @@ -260,19 +268,22 @@ def expert[Ty, N, D_in, D_hidden, D_out]( """ # Step 1: First linear layer using linear2d from library fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias) - + # Step 2: GELU activation using GeLU from library gelu_out = GeLU[Ty, N, D_hidden](fc1_out) - + # Step 3: Second linear layer using linear2d from library fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias) - + return fc2_out -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # MoE Layer: Main MoE layer that routes tokens to experts -#---------------------------------------------------------------------------------- -def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( +# ---------------------------------------------------------------------------------- +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( x: "Ty[N, D_in]", # Gate weights gate_weight: "Ty[E, D_in]", @@ -281,11 +292,11 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[N, D_out]": """ Mixture of Experts layer for inference. - + Args: x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) gate_weight: Gate weight matrix of shape [E, D_in] @@ -299,74 +310,77 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( """ # Step 1: Compute gate logits using linear2d gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in](x, gate_weight, gate_bias) # [N, E] - + # Step 2: Select top-1 expert using top1_select function top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) - + # Step 3: Get top-k logits and apply softmax top_k_logits: Ty[N, K] = 0.0 for n, k in dsl.grid(N, K, name="topk_logits"): expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 top_k_logits[n, k] = gate_logits[n, expert_idx] - + # Step 4: Apply softmax to top-k logits using softmax_1d function top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] - + # Step 5: Create sparse weight matrix from top-k weights top_k_indices: int32[N, K] = 0 gate_weights: Ty[N, E] = 0.0 - + for n in range(N, name="gate_weights"): # Store top-k indices (for k=1) top_k_indices[n, 0] = top1_indices_1d[n] # Set gate weights from softmax output expert_idx = top1_indices_1d[n] gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 - + # Step 6: Process each expert: compute outputs for all tokens expert_outputs: Ty[E, N, D_out] = 0.0 - + for e in range(E, name="expert_loop"): # Extract expert weights for this expert expert_fc1_w: Ty[D_hidden, D_in] = 0.0 expert_fc1_b: Ty[D_hidden] = 0.0 expert_fc2_w: Ty[D_out, D_hidden] = 0.0 expert_fc2_b: Ty[D_out] = 0.0 - + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] - + for d_hidden in range(D_hidden, name="extract_fc1_b"): expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] - + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] - + for d_out in range(D_out, name="extract_fc2_b"): expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] - + # Process all tokens through this expert using the expert function expert_out = expert[Ty, N, D_in, D_hidden, D_out]( x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b ) # [N, D_out] - + # Store expert outputs for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): expert_outputs[e, n, d_out] = expert_out[n, d_out] - + # Step 7: Combine expert outputs using gate weights output: Ty[N, D_out] = 0.0 for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): weight: Ty = gate_weights[n, e] output[n, d_out] += expert_outputs[e, n, d_out] * weight - + return output -#---------------------------------------------------------------------------------- + +# ---------------------------------------------------------------------------------- # Attention + MoE Layer: Combined layer # Data flow: Q, K, V -> Attention -> MoE -> Output -#---------------------------------------------------------------------------------- -def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( +# ---------------------------------------------------------------------------------- +def attention_moe_layer[ + Ty, B, L, D, H, E, TopK, D_hidden +]( Query: "Ty[B, L, D]", Key: "Ty[B, L, D]", Value: "Ty[B, L, D]", @@ -377,13 +391,13 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights: "Ty[E, D_hidden, D]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D, D_hidden]", - expert_fc2_biases: "Ty[E, D]" + expert_fc2_biases: "Ty[E, D]", ) -> "Ty[B, L, D]": """ Combined Attention + MoE layer. - + Data flow: Q, K, V -> Attention -> MoE -> Output - + Args: Query: Query tensor of shape [B, L, D] Key: Key tensor of shape [B, L, D] @@ -399,23 +413,23 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( """ # Output tensor output: Ty[B, L, D] = 0.0 - + # Process each batch item for b in range(B, name="batch_loop"): # Step 1: Extract Q, K, V for this batch item -> [L, D] Q_b: Ty[L, D] = 0.0 K_b: Ty[L, D] = 0.0 V_b: Ty[L, D] = 0.0 - + for l, d in dsl.grid(L, D, name="extract_qkv"): Q_b[l, d] = Query[b, l, d] K_b[l, d] = Key[b, l, d] V_b[l, d] = Value[b, l, d] - + # Step 2: Apply custom scaled_dot_product_attention # scaled_dot_product_attention[Ty, H, L, D](Q, K, V) -> [L, D] attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] - + # Step 3: Apply MoE layer # moe_layer expects [N, D_in], so we use N=L for single batch moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( @@ -425,25 +439,25 @@ def attention_moe_layer[Ty, B, L, D, H, E, TopK, D_hidden]( expert_fc1_weights, expert_fc1_biases, expert_fc2_weights, - expert_fc2_biases + expert_fc2_biases, ) # [L, D] - + # Step 4: Store output for this batch item for l, d in dsl.grid(L, D, name="store_output"): output[b, l, d] = moe_out[l, d] - + return output -#================================================================================== +# ================================================================================== # Schedule optimization function -#================================================================================== +# ================================================================================== def optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ): """ Create optimized schedules for Attention + MoE and compose them together. - + Args: batch_size: Batch size (B) seq_len: Sequence length (L) @@ -452,7 +466,7 @@ def optimize_attention_moe_with_composition( num_experts: Number of experts (E) k: Top-k value (currently only k=1 is supported) hidden_dim: Hidden dimension for experts - + Returns: s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed """ @@ -464,46 +478,50 @@ def optimize_attention_moe_with_composition( E = num_experts K = k D_hidden = hidden_dim - + print("=" * 60) print("Creating and optimizing Attention + MoE schedules...") print("=" * 60) - + # Step 1: Create schedule for top1_select print("\n[1] Creating schedule for top1_select...") s_top1 = allo.customize(top1_select, instantiate=[Ty, L, E]) print(" - Created top1_select schedule for [L, E]") - + # Step 2: Create schedule for softmax_1d (for MoE gate) print("\n[2] Creating schedule for softmax_1d...") s_softmax_1d = allo.customize(softmax_1d, instantiate=[Ty, L, K]) print(" - Created softmax_1d schedule for [L, K]") - + # Step 3: Create schedule for expert print("\n[3] Creating schedule for expert...") - + # Create schedules for library functions that expert uses print(" - Creating schedules for library functions (linear2d, GeLU)...") - s_linear_fc1 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D_hidden, D]) + s_linear_fc1 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D_hidden, D] + ) s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, L, D_hidden]) - s_linear_fc2 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D, D_hidden]) + s_linear_fc2 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D, D_hidden] + ) print(" - Library function schedules created") - + # Create expert schedule print(" - Creating expert schedule...") s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D]) - + # Compose library function schedules s_expert.compose(s_linear_fc1, id="expert_fc1") s_expert.compose(s_gelu, id="expert_gelu") s_expert.compose(s_linear_fc2, id="expert_fc2") print(" - Created expert schedule") print(" - Composed nn.linear2d and nn.GeLU schedules for expert") - + # Step 4: Create schedule for moe_layer print("\n[4] Creating schedule for moe_layer...") s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) - + # Compose gate linear s_gate_linear = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, E, D]) s_moe.compose(s_gate_linear, id="gate") @@ -511,54 +529,54 @@ def optimize_attention_moe_with_composition( s_moe.compose(s_softmax_1d) s_moe.compose(s_expert) print(" - Created moe_layer schedule") - + # Step 5: Create schedule for custom attention print("\n[5] Creating schedule for custom scaled_dot_product_attention...") s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) print(" - Created scaled_dot_product_attention schedule") - + # Step 6: Create schedule for main attention_moe_layer function print("\n[6] Creating schedule for attention_moe_layer...") s_attn_moe = allo.customize( - attention_moe_layer, - instantiate=[Ty, B, L, D, H, E, K, D_hidden] + attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden] ) - + # Step 7: Compose all schedules together print("\n[7] Composing all schedules together...") s_attn_moe.compose(s_attn) s_attn_moe.compose(s_moe) print(" - Composed scaled_dot_product_attention schedule") print(" - Composed moe_layer schedule") - + print("\n" + "=" * 60) print("Schedule composition complete!") print("=" * 60) - + return s_attn_moe -#================================================================================== +# ================================================================================== # Test function to compare Allo and PyTorch implementations -#================================================================================== +# ================================================================================== if __name__ == "__main__": import torch import torch.nn as torch_nn from pytorch_attention_moe import AttentionMoE + # ============================================================================ # Configuration parameters - use shared config from llm_config.py # ============================================================================ # Get MoE configuration from shared config moe_config = get_moe_config(CONFIG_MODE) - + batch_size = moe_config["batch_size"] seq_len = moe_config["seq_len"] embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads num_experts = moe_config["num_experts"] # E k = moe_config["k"] # Top-k MoE hidden_dim = moe_config["hidden_dim"] # D_hidden - + # Attention-specific parameter: num_heads # Choose num_heads such that embed_dim is divisible by num_heads # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim) @@ -574,13 +592,13 @@ def optimize_attention_moe_with_composition( num_heads = 4 else: num_heads = 2 - + # Ensure embed_dim is divisible by num_heads while embed_dim % num_heads != 0 and num_heads > 1: num_heads -= 1 - + seed = 42 - + print("=" * 60) print("Attention + MoE Allo Implementation Test") print("=" * 60) @@ -591,15 +609,15 @@ def optimize_attention_moe_with_composition( print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") print(f" Seed: {seed}") print("=" * 60) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run PyTorch implementation to get weights and outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch AttentionMoE layer pytorch_model = AttentionMoE( seq_len=seq_len, @@ -607,165 +625,212 @@ def optimize_attention_moe_with_composition( num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=hidden_dim + expert_hidden_dim=hidden_dim, ) pytorch_model.eval() - + # Initialize with Xavier uniform for param in pytorch_model.parameters(): if param.dim() > 1: torch_nn.init.xavier_uniform_(param) else: torch_nn.init.zeros_(param) - + # Create random inputs (Q, K, V) torch.manual_seed(seed) Q_pt = torch.randn(batch_size, seq_len, embed_dim) K_pt = torch.randn(batch_size, seq_len, embed_dim) V_pt = torch.randn(batch_size, seq_len, embed_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - - #---------------------------------------------------------------------------------- + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- # Extract weights and biases from PyTorch model - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights - gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data # [num_experts, embed_dim] + gate_weight_pt = ( + pytorch_model.moe.gate.gate_linear.weight.data + ) # [num_experts, embed_dim] gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias if gate_bias_pt is not None: gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_model.moe.experts]) - expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_model.moe.experts]) - expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_model.moe.experts]) - expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_model.moe.experts]) - - #---------------------------------------------------------------------------------- + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_model.moe.experts] + ) + + # ---------------------------------------------------------------------------------- # Convert to numpy arrays - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[3] Converting weights to numpy arrays...") Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) - - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Q shape: {Q_np.shape}") print(f"K shape: {K_np.shape}") print(f"V shape: {V_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run Allo implementation - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[4] Running Allo implementation...") try: # Create optimized schedule with composition allo_schedule = optimize_attention_moe_with_composition( batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim ) - + # Generate project name project_name = f"allo_attention_moe_lib_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("\n[5] Building Allo module...") if MODE == "llvm": mod = allo_schedule.build(target="llvm") elif MODE == "sw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": - mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) elif MODE == "csyn": - mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference print("\n[6] Running Allo inference...") if MODE == "llvm": allo_output = mod( - Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, ) elif MODE in ["sw_emu", "hw_emu", "hw"]: allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) - mod(Q_np, K_np, V_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np, - allo_output) + mod( + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "csyn": allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Compare the Allo and PyTorch outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[7] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + # Compute differences diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # Check if outputs are close atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Print sample outputs for comparison - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[8] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - + print("=" * 60) diff --git a/ece6775/attention_moe/pytorch_attention_moe.py b/ece6775/attention_moe/pytorch_attention_moe.py index f942dbcc6..068430a6a 100644 --- a/ece6775/attention_moe/pytorch_attention_moe.py +++ b/ece6775/attention_moe/pytorch_attention_moe.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 """ Attention + Mixture of Experts (MoE) Implementation for Inference Only @@ -19,174 +21,186 @@ class ScaledDotProductAttention(nn.Module): """ Scaled Dot-Product Attention (Multi-Head Attention). - + This implementation matches the Allo library version: - Input: Q, K, V of shape [L, D] - Split into H heads, each with dimension D // H - For each head: softmax(QK^T / sqrt(D // H)) @ V - Merge heads back to [L, D] - + Args: num_heads: Number of attention heads (H) head_dim: Dimension per head (D // H) """ - + def __init__(self, num_heads: int, embed_dim: int): super().__init__() self.num_heads = num_heads self.embed_dim = embed_dim self.head_dim = embed_dim // num_heads assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" - + # Scaling factor: 1 / sqrt(head_dim) self.scale = 1.0 / math.sqrt(self.head_dim) - - def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False) -> torch.Tensor: + + def forward( + self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False + ) -> torch.Tensor: """ Forward pass through scaled dot-product attention. - + Args: Q: Query tensor of shape [L, D] K: Key tensor of shape [L, D] V: Value tensor of shape [L, D] verbose: Whether to print debug information - + Returns: Output tensor of shape [L, D] """ L, D = Q.shape H = self.num_heads head_dim = D // H - + # Initialize output Z = torch.zeros(L, D, device=Q.device, dtype=Q.dtype) - + # Process each head for h in range(H): # Split Q, K, V for this head start_idx = h * head_dim end_idx = (h + 1) * head_dim - + Q_h = Q[:, start_idx:end_idx] # [L, head_dim] K_h = K[:, start_idx:end_idx] # [L, head_dim] V_h = V[:, start_idx:end_idx] # [L, head_dim] - + # QK^T = [L, head_dim] @ [head_dim, L] = [L, L] # Note: K_h.T is the transpose Y = torch.matmul(Q_h, K_h.T) # [L, L] - + # Scale by 1/sqrt(head_dim) Y = Y * self.scale - + # Softmax over last dimension S = F.softmax(Y, dim=-1) # [L, L] - + # YV = [L, L] @ [L, head_dim] = [L, head_dim] C_h = torch.matmul(S, V_h) # [L, head_dim] - + # Merge back to Z Z[:, start_idx:end_idx] = C_h - + if verbose: print(f"\n[Attention] L={L}, D={D}, H={H}, head_dim={head_dim}") print(f"[Attention] scale={self.scale:.6f}") print(f"[Attention] Output shape: {Z.shape}") - + return Z class TopKGate(nn.Module): """ Gate module to select top k experts for routing. - + Args: input_dim: Input feature dimension num_experts: Total number of experts k: Number of experts to select per token """ - + def __init__(self, input_dim: int, num_experts: int, k: int = 1): super().__init__() self.k = k self.num_experts = num_experts # Linear layer to compute expert logits self.gate_linear = nn.Linear(input_dim, num_experts, bias=False) - - def forward(self, x: torch.Tensor, verbose: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + + def forward( + self, x: torch.Tensor, verbose: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through the gate. - + Args: x: Input tensor of shape [batch_size * seq_len, input_dim] verbose: Whether to print routing information - + Returns: full_weights: Sparse weight matrix [batch_size * seq_len, num_experts] top_k_indices: Indices of selected experts [batch_size * seq_len, k] """ # Compute logits for all experts logits = self.gate_linear(x) # [N, num_experts] - + # Select top-k experts top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1) - + # Apply softmax to top-k logits for normalized weights top_k_weights = F.softmax(top_k_logits, dim=-1) # [N, k] - + # Create sparse weight matrix (zeros for non-selected experts) full_weights = torch.zeros_like(logits) full_weights.scatter_(1, top_k_indices, top_k_weights) - + # Print routing information if verbose if verbose: num_tokens = x.shape[0] - + # Count tokens per expert expert_counts = {} for expert_idx in range(self.num_experts): count = (top_k_indices == expert_idx).sum().item() if count > 0: expert_counts[expert_idx] = count - + # Verify that each token selects exactly k experts tokens_per_expert_count = {} for i in range(num_tokens): num_experts_selected = (top_k_weights[i] > 1e-6).sum().item() - tokens_per_expert_count[num_experts_selected] = tokens_per_expert_count.get(num_experts_selected, 0) + 1 - + tokens_per_expert_count[num_experts_selected] = ( + tokens_per_expert_count.get(num_experts_selected, 0) + 1 + ) + print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}") print(f"[Gate Routing] Expert distribution:") total_selections = num_tokens * self.k # Total expert-token pairs for expert_idx, count in sorted(expert_counts.items()): percentage = (count / total_selections) * 100 - print(f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)") - + print( + f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)" + ) + # Verify top-k: each token selects exactly k experts if tokens_per_expert_count.get(self.k, 0) == num_tokens: - print(f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)") + print( + f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)" + ) else: - print(f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}") - + print( + f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}" + ) + return full_weights, top_k_indices class Expert(nn.Module): """ A simple feed-forward expert network. - + Args: input_dim: Input feature dimension hidden_dim: Hidden layer dimension output_dim: Output feature dimension """ - + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) # Use GELU activation (modern choice) self.activation = nn.GELU() - + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the expert.""" x = self.fc1(x) @@ -198,7 +212,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoELayer(nn.Module): """ Mixture of Experts layer for inference. - + Args: input_dim: Input feature dimension output_dim: Output feature dimension @@ -206,141 +220,158 @@ class MoELayer(nn.Module): k: Number of experts to activate per token expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim) """ - + def __init__( self, input_dim: int, output_dim: int, num_experts: int, k: int = 1, - expert_hidden_dim: Optional[int] = None + expert_hidden_dim: Optional[int] = None, ): super().__init__() self.num_experts = num_experts self.k = k self.output_dim = output_dim - + if expert_hidden_dim is None: expert_hidden_dim = input_dim * 4 # Common practice - + # Initialize gate and experts self.gate = TopKGate(input_dim, num_experts, k) - self.experts = nn.ModuleList([ - Expert(input_dim, expert_hidden_dim, output_dim) - for _ in range(num_experts) - ]) - + self.experts = nn.ModuleList( + [ + Expert(input_dim, expert_hidden_dim, output_dim) + for _ in range(num_experts) + ] + ) + def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor: """ Forward pass through MoE layer. - + Args: x: Input tensor of shape [batch_size, seq_len, input_dim] verbose: Whether to print verification information for top-1 MoE - + Returns: Output tensor of shape [batch_size, seq_len, output_dim] """ - + # Store original shape original_shape = x.shape batch_size, seq_len = original_shape[0], original_shape[1] - + # Flatten batch and sequence dimensions x_flat = x.view(-1, original_shape[-1]) # [N, input_dim], N = batch * seq_len num_tokens = x_flat.shape[0] - + # Get gating weights and expert indices - gate_weights, top_k_indices = self.gate(x_flat, verbose=verbose) # [N, num_experts], [N, k] - + gate_weights, top_k_indices = self.gate( + x_flat, verbose=verbose + ) # [N, num_experts], [N, k] + # Initialize output tensor output = torch.zeros( - x_flat.shape[0], - self.output_dim, - device=x.device, - dtype=x.dtype + x_flat.shape[0], self.output_dim, device=x.device, dtype=x.dtype ) - + # Track expert usage statistics expert_usage_stats = {} - + # Process each expert for expert_idx in range(self.num_experts): # Find tokens assigned to this expert # Create mask: tokens that have this expert in their top-k expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [N] - + num_tokens_for_expert = expert_mask.sum().item() - + if not expert_mask.any(): expert_usage_stats[expert_idx] = 0 continue # No tokens for this expert - + # Track usage expert_usage_stats[expert_idx] = num_tokens_for_expert - + # Get inputs for this expert expert_inputs = x_flat[expert_mask] # [num_tokens, input_dim] - + # Process through expert - expert_outputs = self.experts[expert_idx](expert_inputs) # [num_tokens, output_dim] - + expert_outputs = self.experts[expert_idx]( + expert_inputs + ) # [num_tokens, output_dim] + # Get corresponding weights for these tokens # For each token, find the weight for this expert (could be in any of k positions) token_indices = torch.where(expert_mask)[0] # Indices in flattened tensor - expert_weights = gate_weights[token_indices, expert_idx].unsqueeze(1) # [num_tokens, 1] - + expert_weights = gate_weights[token_indices, expert_idx].unsqueeze( + 1 + ) # [num_tokens, 1] + # Weight and accumulate outputs weighted_outputs = expert_outputs * expert_weights output[token_indices] += weighted_outputs - + # Verify top-k behavior if verbose: - active_experts = [idx for idx, count in expert_usage_stats.items() if count > 0] + active_experts = [ + idx for idx, count in expert_usage_stats.items() if count > 0 + ] total_usage = sum(expert_usage_stats.values()) - + # Verify top-k: each token selects exactly k experts if top_k_indices.shape[1] != self.k: - print(f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}") - + print( + f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}" + ) + # Verify all tokens are processed (for k>1, total_usage may be > num_tokens) expected_total_usage = num_tokens * self.k if total_usage != expected_total_usage: - print(f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}") - + print( + f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}" + ) + # Check if weights are normalized (should sum to 1.0 per token) sample_weights = gate_weights.sum(dim=1) - weights_normalized = torch.allclose(sample_weights, torch.ones_like(sample_weights), atol=1e-5) - + weights_normalized = torch.allclose( + sample_weights, torch.ones_like(sample_weights), atol=1e-5 + ) + # Final summary for top-k print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===") print(f" ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})") print(f" ✓ All {num_tokens} tokens processed") - print(f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})") + print( + f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})" + ) print(f" ✓ Active experts: {len(active_experts)}/{self.num_experts}") - print(f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}") + print( + f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}" + ) if weights_normalized: print(f" ✓ Gate weights normalized (sum to 1.0 per token)") else: print(f" ✗ Gate weights NOT normalized correctly") print(f" === Top-{self.k} MoE is working correctly! ===") - + # Reshape back to original shape output = output.view(batch_size, seq_len, self.output_dim) - + return output class AttentionMoE(nn.Module): """ Combined Attention + MoE layer. - + Data flow: Input -> Attention -> MoE -> Output - + This follows the Transformer architecture pattern where: 1. Attention computes self-attention on the input 2. MoE replaces the standard FFN (Feed-Forward Network) - + Args: seq_len: Sequence length (L) embed_dim: Embedding dimension (D) @@ -349,7 +380,7 @@ class AttentionMoE(nn.Module): k: Top-k experts to activate per token expert_hidden_dim: Hidden dimension for experts """ - + def __init__( self, seq_len: int, @@ -357,68 +388,64 @@ def __init__( num_heads: int, num_experts: int, k: int = 1, - expert_hidden_dim: Optional[int] = None + expert_hidden_dim: Optional[int] = None, ): super().__init__() self.seq_len = seq_len self.embed_dim = embed_dim self.num_heads = num_heads - + # Attention layer self.attention = ScaledDotProductAttention(num_heads, embed_dim) - + # MoE layer (input_dim = output_dim = embed_dim) self.moe = MoELayer( input_dim=embed_dim, output_dim=embed_dim, num_experts=num_experts, k=k, - expert_hidden_dim=expert_hidden_dim + expert_hidden_dim=expert_hidden_dim, ) - + def forward( - self, - Q: torch.Tensor, - K: torch.Tensor, - V: torch.Tensor, - verbose: bool = False + self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False ) -> torch.Tensor: """ Forward pass: Attention -> MoE - + Args: Q: Query tensor of shape [B, L, D] K: Key tensor of shape [B, L, D] V: Value tensor of shape [B, L, D] verbose: Whether to print debug information - + Returns: Output tensor of shape [B, L, D] """ batch_size = Q.shape[0] L = Q.shape[1] D = Q.shape[2] - + # Process each batch item through attention # Attention expects [L, D], so we process batch by batch attn_outputs = [] for b in range(batch_size): attn_out = self.attention(Q[b], K[b], V[b], verbose=verbose and b == 0) attn_outputs.append(attn_out) - + # Stack back to [B, L, D] attn_output = torch.stack(attn_outputs, dim=0) - + if verbose: print(f"\n[AttentionMoE] Attention output shape: {attn_output.shape}") - + # Pass through MoE # MoE expects [B, L, D] and returns [B, L, D] moe_output = self.moe(attn_output, verbose=verbose) - + if verbose: print(f"[AttentionMoE] MoE output shape: {moe_output.shape}") - + return moe_output @@ -431,11 +458,11 @@ def run_attention_moe_inference( k: int = 1, expert_hidden_dim: Optional[int] = None, seed: int = 24, - verbose: bool = True + verbose: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, AttentionMoE]: """ Run Attention + MoE inference with fixed seed for reproducible inputs and weights. - + Args: batch_size: Batch size (B) seq_len: Sequence length (L) @@ -446,7 +473,7 @@ def run_attention_moe_inference( expert_hidden_dim: Hidden dimension for experts seed: Random seed for reproducibility verbose: Whether to print verification information - + Returns: output: Output tensor from AttentionMoE [B, L, D] Q, K, V: Input tensors used for inference @@ -456,7 +483,7 @@ def run_attention_moe_inference( torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create model model = AttentionMoE( seq_len=seq_len, @@ -464,27 +491,27 @@ def run_attention_moe_inference( num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=expert_hidden_dim + expert_hidden_dim=expert_hidden_dim, ) model.eval() - + # Initialize with random weights (Xavier uniform for better stability) for param in model.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create random inputs (Q, K, V) torch.manual_seed(seed) Q = torch.randn(batch_size, seq_len, embed_dim) K = torch.randn(batch_size, seq_len, embed_dim) V = torch.randn(batch_size, seq_len, embed_dim) - + # Run inference with verbose output for verification with torch.no_grad(): output = model(Q, K, V, verbose=verbose) - + return output, Q, K, V, model @@ -493,11 +520,11 @@ def verify_gate_layer( input_dim: int = 96, num_experts: int = 2, k: int = 1, - seed: int = 42 + seed: int = 42, ): """ Verify that the Gate Layer is effectively routing tokens to different experts. - + This function tests: 1. Different inputs produce different routing decisions 2. Gate weights are properly normalized (sum to 1.0) @@ -510,22 +537,22 @@ def verify_gate_layer( print(f"Configuration: num_tokens={num_tokens}, input_dim={input_dim}") print(f" num_experts={num_experts}, k={k}") print("=" * 70) - + torch.manual_seed(seed) - + # Create gate gate = TopKGate(input_dim, num_experts, k) nn.init.xavier_uniform_(gate.gate_linear.weight) gate.eval() - + # Test 1: Create random inputs and check routing print("\n[Test 1] Random Input Routing") print("-" * 50) x = torch.randn(num_tokens, input_dim) - + with torch.no_grad(): full_weights, top_k_indices = gate(x, verbose=False) - + # Count tokens per expert expert_counts = {} for e in range(num_experts): @@ -533,73 +560,79 @@ def verify_gate_layer( expert_counts[e] = count percentage = (count / (num_tokens * k)) * 100 print(f" Expert {e}: {count} tokens ({percentage:.1f}%)") - + # Check if routing is balanced min_count = min(expert_counts.values()) max_count = max(expert_counts.values()) imbalance_ratio = max_count / max(min_count, 1) - + if imbalance_ratio < 10: - print(f" ✓ Routing is reasonably balanced (imbalance ratio: {imbalance_ratio:.2f})") + print( + f" ✓ Routing is reasonably balanced (imbalance ratio: {imbalance_ratio:.2f})" + ) else: print(f" ⚠ Routing is imbalanced (imbalance ratio: {imbalance_ratio:.2f})") - + # Test 2: Verify weights are normalized print("\n[Test 2] Weight Normalization") print("-" * 50) weight_sums = full_weights.sum(dim=1) - all_normalized = torch.allclose(weight_sums, torch.ones_like(weight_sums), atol=1e-5) - + all_normalized = torch.allclose( + weight_sums, torch.ones_like(weight_sums), atol=1e-5 + ) + if all_normalized: print(f" ✓ All gate weights sum to 1.0 (properly normalized)") else: print(f" ✗ Gate weights NOT properly normalized!") - print(f" Weight sums range: [{weight_sums.min():.6f}, {weight_sums.max():.6f}]") - + print( + f" Weight sums range: [{weight_sums.min():.6f}, {weight_sums.max():.6f}]" + ) + # Test 3: Different inputs should (mostly) produce different routings print("\n[Test 3] Input-Dependent Routing") print("-" * 50) - + # Create two very different inputs x1 = torch.randn(10, input_dim) * 10 # Large magnitude x2 = torch.randn(10, input_dim) * 0.1 # Small magnitude - + with torch.no_grad(): _, indices1 = gate(x1, verbose=False) _, indices2 = gate(x2, verbose=False) - + # Check if routings are different same_routing = (indices1 == indices2).all().item() if not same_routing: print(f" ✓ Different inputs produce different routing decisions") else: print(f" ⚠ Different inputs produced same routing (may indicate issue)") - + # Test 4: Determinism - same input should always produce same routing print("\n[Test 4] Routing Determinism") print("-" * 50) - + x_test = torch.randn(20, input_dim) with torch.no_grad(): _, indices_run1 = gate(x_test, verbose=False) _, indices_run2 = gate(x_test, verbose=False) - + is_deterministic = (indices_run1 == indices_run2).all().item() if is_deterministic: print(f" ✓ Routing is deterministic (same input -> same expert)") else: print(f" ✗ Routing is NOT deterministic!") - + # Test 5: Show actual routing decisions for a few tokens print("\n[Test 5] Sample Routing Decisions") print("-" * 50) - + # Create a small set of tokens to visualize x_sample = torch.randn(8, input_dim) with torch.no_grad(): logits = gate.gate_linear(x_sample) weights, indices = gate(x_sample, verbose=False) - + print(f" Token | Expert Logits (raw) | Selected Expert | Weight") print(f" " + "-" * 60) for i in range(min(8, x_sample.shape[0])): @@ -607,48 +640,54 @@ def verify_gate_layer( selected = indices[i, 0].item() weight = weights[i, selected].item() print(f" {i:5d} | [{logit_str:25s}] | Expert {selected} | {weight:.4f}") - + # Test 6: Verify gate learns meaningful patterns print("\n[Test 6] Pattern-Based Routing") print("-" * 50) - + # Create inputs with clear patterns # Pattern A: positive values in first half, negative in second half # Pattern B: opposite - pattern_a = torch.cat([torch.ones(input_dim // 2), -torch.ones(input_dim // 2)]).unsqueeze(0) - pattern_b = torch.cat([-torch.ones(input_dim // 2), torch.ones(input_dim // 2)]).unsqueeze(0) - + pattern_a = torch.cat( + [torch.ones(input_dim // 2), -torch.ones(input_dim // 2)] + ).unsqueeze(0) + pattern_b = torch.cat( + [-torch.ones(input_dim // 2), torch.ones(input_dim // 2)] + ).unsqueeze(0) + # Repeat patterns x_patterns = torch.cat([pattern_a.repeat(5, 1), pattern_b.repeat(5, 1)], dim=0) - + with torch.no_grad(): _, pattern_indices = gate(x_patterns, verbose=False) - + pattern_a_experts = pattern_indices[:5, 0].tolist() pattern_b_experts = pattern_indices[5:, 0].tolist() - + print(f" Pattern A (positive-negative) routed to experts: {pattern_a_experts}") print(f" Pattern B (negative-positive) routed to experts: {pattern_b_experts}") - + # Check if same patterns go to same expert a_consistent = len(set(pattern_a_experts)) == 1 b_consistent = len(set(pattern_b_experts)) == 1 patterns_different = set(pattern_a_experts) != set(pattern_b_experts) - + if a_consistent and b_consistent: print(f" ✓ Same patterns consistently route to same expert") else: - print(f" ⚠ Same patterns route to different experts (expected with random weights)") - + print( + f" ⚠ Same patterns route to different experts (expected with random weights)" + ) + if patterns_different: print(f" ✓ Different patterns route to different experts") else: print(f" ⚠ Different patterns route to same expert") - + print("\n" + "=" * 70) print("Gate Layer Verification Complete!") print("=" * 70) - + return gate, expert_counts @@ -663,17 +702,17 @@ def benchmark_attention_moe( num_warmup: int = 10, num_runs: int = 100, device: str = "cpu", - seed: int = 42 + seed: int = 42, ): """ Benchmark AttentionMoE inference time. - + Best practices for accurate timing: 1. Warmup runs: Avoid cold start overhead (JIT compilation, cache warming) 2. Multiple runs: Get stable average and standard deviation 3. torch.no_grad(): Disable gradient tracking for inference 4. torch.cuda.synchronize(): Ensure GPU operations complete before timing - + Args: batch_size: Batch size (B) seq_len: Sequence length (L) @@ -686,13 +725,13 @@ def benchmark_attention_moe( num_runs: Number of timed iterations device: "cpu" or "cuda" seed: Random seed - + Returns: dict: Timing statistics (mean, std, min, max in milliseconds) """ import time import numpy as np - + print("=" * 70) print("AttentionMoE Benchmark") print("=" * 70) @@ -704,20 +743,20 @@ def benchmark_attention_moe( print(f" device={device}") print(f" num_warmup={num_warmup}, num_runs={num_runs}") print("=" * 70) - + # Set seed and create model torch.manual_seed(seed) - + model = AttentionMoE( seq_len=seq_len, embed_dim=embed_dim, num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=expert_hidden_dim + expert_hidden_dim=expert_hidden_dim, ) model.eval() - + # Move to device if device == "cuda" and torch.cuda.is_available(): model = model.cuda() @@ -725,14 +764,14 @@ def benchmark_attention_moe( else: device = "cpu" print(f" Using CPU") - + # Initialize weights for param in model.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create input tensors torch.manual_seed(seed) if device == "cuda": @@ -743,7 +782,7 @@ def benchmark_attention_moe( Q = torch.randn(batch_size, seq_len, embed_dim) K = torch.randn(batch_size, seq_len, embed_dim) V = torch.randn(batch_size, seq_len, embed_dim) - + # Warmup runs (important for JIT, cache warming, etc.) print(f"\n[1] Warmup ({num_warmup} iterations)...") with torch.no_grad(): @@ -751,26 +790,26 @@ def benchmark_attention_moe( _ = model(Q, K, V, verbose=False) if device == "cuda": torch.cuda.synchronize() # Wait for GPU to finish - + # Timed runs print(f"[2] Timing ({num_runs} iterations)...") times = [] - + with torch.no_grad(): for i in range(num_runs): if device == "cuda": torch.cuda.synchronize() # Ensure previous work is done - + start_time = time.perf_counter() # High-resolution timer - + _ = model(Q, K, V, verbose=False) - + if device == "cuda": torch.cuda.synchronize() # Wait for GPU to finish - + end_time = time.perf_counter() times.append((end_time - start_time) * 1000) # Convert to ms - + # Compute statistics times = np.array(times) stats = { @@ -782,7 +821,7 @@ def benchmark_attention_moe( "p95_ms": np.percentile(times, 95), "p99_ms": np.percentile(times, 99), } - + # Print results print("\n" + "=" * 70) print("Benchmark Results:") @@ -798,7 +837,7 @@ def benchmark_attention_moe( print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") print("=" * 70) - + return stats, model @@ -807,21 +846,21 @@ def benchmark_with_allo_config( num_warmup: int = 10, num_runs: int = 100, device: str = "cpu", - seed: int = 42 + seed: int = 42, ): """ Benchmark using the same configuration as Allo tests. This ensures fair comparison between PyTorch and Allo implementations. - + Uses llm_config.py for configuration to match allo_attention_moe_alt.py - + Args: config_mode: Configuration mode from llm_config (e.g., "switch_base_8_scaled_1_8") num_warmup: Number of warmup iterations num_runs: Number of timed iterations device: "cpu" or "cuda" seed: Random seed (must match Allo test seed for same inputs) - + Returns: dict: Timing statistics and model """ @@ -829,26 +868,26 @@ def benchmark_with_allo_config( import os import time import numpy as np - + # Add llm_config to path current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(current_dir) llm_config_dir = os.path.join(project_root, "llm_config") if llm_config_dir not in sys.path: sys.path.insert(0, llm_config_dir) - + from llm_config import get_moe_config - + # Get configuration (same as allo_attention_moe_alt.py) moe_config = get_moe_config(config_mode) - + batch_size = moe_config["batch_size"] seq_len = moe_config["seq_len"] embed_dim = moe_config["input_dim"] num_experts = moe_config["num_experts"] k = moe_config["k"] hidden_dim = moe_config["hidden_dim"] - + # Determine num_heads (same logic as allo_attention_moe_alt.py) if embed_dim >= 768: num_heads = 12 @@ -862,10 +901,10 @@ def benchmark_with_allo_config( num_heads = 4 else: num_heads = 2 - + while embed_dim % num_heads != 0 and num_heads > 1: num_heads -= 1 - + print("=" * 70) print("PyTorch AttentionMoE Benchmark (Allo-compatible config)") print("=" * 70) @@ -877,29 +916,29 @@ def benchmark_with_allo_config( print(f" device={device}, seed={seed}") print(f" num_warmup={num_warmup}, num_runs={num_runs}") print("=" * 70) - + # Create model with same initialization as Allo test torch.manual_seed(seed) if torch.cuda.is_available() and device == "cuda": torch.cuda.manual_seed_all(seed) - + model = AttentionMoE( seq_len=seq_len, embed_dim=embed_dim, num_heads=num_heads, num_experts=num_experts, k=k, - expert_hidden_dim=hidden_dim + expert_hidden_dim=hidden_dim, ) model.eval() - + # Initialize with Xavier uniform (same as Allo test) for param in model.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Move to device if device == "cuda" and torch.cuda.is_available(): model = model.cuda() @@ -907,7 +946,7 @@ def benchmark_with_allo_config( else: device = "cpu" print(f" Using CPU") - + # Create input tensors with same seed (same as Allo test) torch.manual_seed(seed) if device == "cuda": @@ -918,16 +957,16 @@ def benchmark_with_allo_config( Q = torch.randn(batch_size, seq_len, embed_dim) K = torch.randn(batch_size, seq_len, embed_dim) V = torch.randn(batch_size, seq_len, embed_dim) - + print(f"\nInput shapes: Q={Q.shape}, K={K.shape}, V={V.shape}") - + # Verify output first print("\n[1] Verifying output...") with torch.no_grad(): output = model(Q, K, V, verbose=False) print(f" Output shape: {output.shape}") print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") - + # Warmup print(f"\n[2] Warmup ({num_warmup} iterations)...") with torch.no_grad(): @@ -935,25 +974,25 @@ def benchmark_with_allo_config( _ = model(Q, K, V, verbose=False) if device == "cuda": torch.cuda.synchronize() - + # Timed runs print(f"[3] Timing ({num_runs} iterations)...") times = [] - + with torch.no_grad(): for i in range(num_runs): if device == "cuda": torch.cuda.synchronize() - + start_time = time.perf_counter() _ = model(Q, K, V, verbose=False) - + if device == "cuda": torch.cuda.synchronize() - + end_time = time.perf_counter() times.append((end_time - start_time) * 1000) - + # Compute statistics times = np.array(times) stats = { @@ -965,7 +1004,7 @@ def benchmark_with_allo_config( "p95_ms": np.percentile(times, 95), "p99_ms": np.percentile(times, 99), } - + # Print results print("\n" + "=" * 70) print("Benchmark Results:") @@ -981,7 +1020,7 @@ def benchmark_with_allo_config( print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") print("=" * 70) - + return stats, model, Q, K, V @@ -990,20 +1029,16 @@ def benchmark_with_allo_config( print("\n" + "#" * 70) print("# PART 1: Gate Layer Verification") print("#" * 70) - + gate, expert_counts = verify_gate_layer( - num_tokens=128, - input_dim=96, - num_experts=2, - k=1, - seed=42 + num_tokens=128, input_dim=96, num_experts=2, k=1, seed=42 ) - + # Then run the full Attention + MoE test print("\n" + "#" * 70) print("# PART 2: Full Attention + MoE Inference Test") print("#" * 70) - + # Configuration parameters batch_size = 1 seq_len = 4 @@ -1013,7 +1048,7 @@ def benchmark_with_allo_config( k = 1 # Top-k MoE: each token uses exactly k experts expert_hidden_dim = 16 # Hidden dimension for experts seed = 24 - + print("=" * 60) print(f"Attention + MoE Inference Test") print("=" * 60) @@ -1022,7 +1057,7 @@ def benchmark_with_allo_config( print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") print(f" num_experts={num_experts}, k={k}, expert_hidden_dim={expert_hidden_dim}") print("=" * 60) - + # Run Attention + MoE inference output, Q, K, V, model = run_attention_moe_inference( batch_size=batch_size, @@ -1033,9 +1068,9 @@ def benchmark_with_allo_config( k=k, expert_hidden_dim=expert_hidden_dim, seed=seed, - verbose=True + verbose=True, ) - + print(f"\n" + "=" * 60) print(f"Results:") print(f" Input Q shape: {Q.shape}") @@ -1044,17 +1079,17 @@ def benchmark_with_allo_config( print(f" Output shape: {output.shape}") print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") print("=" * 60) - + # PART 3: Benchmark with Allo-compatible config print("\n" + "#" * 70) print("# PART 3: Performance Benchmark (Allo-compatible config)") print("#" * 70) - + # Use the same config as allo_attention_moe_alt.py for fair comparison stats, _, Q_bench, K_bench, V_bench = benchmark_with_allo_config( config_mode="switch_base_8_scaled_1_8", # Same as DEFAULT_CONFIG_MODE in llm_config num_warmup=10, num_runs=100, device="cuda" if torch.cuda.is_available() else "cpu", - seed=42 # Same seed as Allo test + seed=42, # Same seed as Allo test ) diff --git a/ece6775/llm_config/check_llm_config.py b/ece6775/llm_config/check_llm_config.py index e9b7a78bd..449d572f2 100644 --- a/ece6775/llm_config/check_llm_config.py +++ b/ece6775/llm_config/check_llm_config.py @@ -1,70 +1,89 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 # check MoE model config from Hugging Face # usage: python check_llm_config.py import sys from transformers import AutoConfig + def print_moe_config(model_name): # print MoE model config try: print(f"Loading model configuration: {model_name}") print("=" * 80) - + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) - + # Print basic configuration print("\n[Basic Architecture Parameters]") - print(f"Model type: {config.model_type if hasattr(config, 'model_type') else 'N/A'}") - print(f"Vocabulary size: {config.vocab_size if hasattr(config, 'vocab_size') else 'N/A'}") - print(f"Max position embeddings: {config.max_position_embeddings if hasattr(config, 'max_position_embeddings') else 'N/A'}") - + print( + f"Model type: {config.model_type if hasattr(config, 'model_type') else 'N/A'}" + ) + print( + f"Vocabulary size: {config.vocab_size if hasattr(config, 'vocab_size') else 'N/A'}" + ) + print( + f"Max position embeddings: {config.max_position_embeddings if hasattr(config, 'max_position_embeddings') else 'N/A'}" + ) + # Print hidden layer dimensions print("\n[Hidden Layer Dimensions]") - if hasattr(config, 'hidden_size'): + if hasattr(config, "hidden_size"): print(f"Hidden size (hidden_size): {config.hidden_size}") - if hasattr(config, 'd_model'): + if hasattr(config, "d_model"): print(f"Model dimension (d_model): {config.d_model}") - if hasattr(config, 'n_embd'): + if hasattr(config, "n_embd"): print(f"Embedding dimension (n_embd): {config.n_embd}") - + # Print MoE related parameters print("\n[MoE Expert Parameters]") - if hasattr(config, 'num_local_experts'): - print(f"Number of local experts (num_local_experts): {config.num_local_experts}") - if hasattr(config, 'num_experts'): + if hasattr(config, "num_local_experts"): + print( + f"Number of local experts (num_local_experts): {config.num_local_experts}" + ) + if hasattr(config, "num_experts"): print(f"Number of experts (num_experts): {config.num_experts}") - if hasattr(config, 'num_experts_per_tok'): - print(f"Number of experts per token (num_experts_per_tok): {config.num_experts_per_tok}") - if hasattr(config, 'num_experts_to_select'): - print(f"Number of experts to select (num_experts_to_select): {config.num_experts_to_select}") - + if hasattr(config, "num_experts_per_tok"): + print( + f"Number of experts per token (num_experts_per_tok): {config.num_experts_per_tok}" + ) + if hasattr(config, "num_experts_to_select"): + print( + f"Number of experts to select (num_experts_to_select): {config.num_experts_to_select}" + ) + # Print feed-forward network parameters print("\n[Feed-Forward Network Parameters]") - if hasattr(config, 'intermediate_size'): + if hasattr(config, "intermediate_size"): print(f"Intermediate size (intermediate_size): {config.intermediate_size}") - if hasattr(config, 'ffn_dim'): + if hasattr(config, "ffn_dim"): print(f"FFN dimension (ffn_dim): {config.ffn_dim}") - if hasattr(config, 'n_inner'): + if hasattr(config, "n_inner"): print(f"Inner dimension (n_inner): {config.n_inner}") - + # Print attention parameters print("\n[Attention Parameters]") - if hasattr(config, 'num_attention_heads'): - print(f"Number of attention heads (num_attention_heads): {config.num_attention_heads}") - if hasattr(config, 'num_heads'): + if hasattr(config, "num_attention_heads"): + print( + f"Number of attention heads (num_attention_heads): {config.num_attention_heads}" + ) + if hasattr(config, "num_heads"): print(f"Number of heads (num_heads): {config.num_heads}") - if hasattr(config, 'n_head'): + if hasattr(config, "n_head"): print(f"Number of heads (n_head): {config.n_head}") - + # Print layer count print("\n[Layer Parameters]") - if hasattr(config, 'num_hidden_layers'): - print(f"Number of hidden layers (num_hidden_layers): {config.num_hidden_layers}") - if hasattr(config, 'num_layers'): + if hasattr(config, "num_hidden_layers"): + print( + f"Number of hidden layers (num_hidden_layers): {config.num_hidden_layers}" + ) + if hasattr(config, "num_layers"): print(f"Number of layers (num_layers): {config.num_layers}") - if hasattr(config, 'n_layer'): + if hasattr(config, "n_layer"): print(f"Number of layers (n_layer): {config.n_layer}") - + # Print all configuration (for debugging) print("\n[Full Configuration Information]") print("-" * 80) @@ -73,17 +92,29 @@ def print_moe_config(model_name): print(f"{key}: {value}") elif isinstance(value, (list, tuple)) and len(value) < 10: print(f"{key}: {value}") - + print("\n" + "=" * 80) print("Configuration loaded successfully!") - + # Provide suggested test parameters print("\n[Suggested Test Parameters (for Allo Implementation)]") - hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'n_embd', None))) - num_experts = getattr(config, 'num_local_experts', getattr(config, 'num_experts', None)) - intermediate_size = getattr(config, 'intermediate_size', getattr(config, 'ffn_dim', getattr(config, 'n_inner', None))) - num_experts_per_tok = getattr(config, 'num_experts_per_tok', getattr(config, 'num_experts_to_select', 1)) - + hidden_size = getattr( + config, + "hidden_size", + getattr(config, "d_model", getattr(config, "n_embd", None)), + ) + num_experts = getattr( + config, "num_local_experts", getattr(config, "num_experts", None) + ) + intermediate_size = getattr( + config, + "intermediate_size", + getattr(config, "ffn_dim", getattr(config, "n_inner", None)), + ) + num_experts_per_tok = getattr( + config, "num_experts_per_tok", getattr(config, "num_experts_to_select", 1) + ) + if hidden_size: print(f"input_dim = {hidden_size} # Input dimension") if intermediate_size: @@ -93,16 +124,25 @@ def print_moe_config(model_name): if num_experts: print(f"num_experts = {num_experts} # Number of experts") if num_experts_per_tok: - print(f"k = {num_experts_per_tok} # Top-k (number of experts activated per token)") + print( + f"k = {num_experts_per_tok} # Top-k (number of experts activated per token)" + ) print(f"batch_size = 1 # Batch size (adjustable as needed)") - print(f"seq_len = 128 # Sequence length (adjustable as needed, actual model may support longer)") - + print( + f"seq_len = 128 # Sequence length (adjustable as needed, actual model may support longer)" + ) + except Exception as e: print(f"Error: {e}") import traceback + traceback.print_exc() - print("\nHint: Make sure transformers library is installed: pip install transformers") - print("If the model requires trust_remote_code=True, make sure to trust the model") + print( + "\nHint: Make sure transformers library is installed: pip install transformers" + ) + print( + "If the model requires trust_remote_code=True, make sure to trust the model" + ) if __name__ == "__main__": @@ -118,7 +158,6 @@ def print_moe_config(model_name): print("\n4. GShard:") print(" python check_moe_config.py google/gshard-gpt") sys.exit(1) - + model_name = sys.argv[1] print_moe_config(model_name) - diff --git a/ece6775/llm_config/llm_config.py b/ece6775/llm_config/llm_config.py index f126c0ed3..24bcfd1a4 100644 --- a/ece6775/llm_config/llm_config.py +++ b/ece6775/llm_config/llm_config.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 # LLM model config - shared for MoE and Attention+MoE # config modes: switch_base_8*, mixtral_8x7b*, deepseek*, custom @@ -8,9 +10,9 @@ def get_moe_config(config_mode=None): # get MoE config (works for standalone MoE and Attention+MoE) if config_mode is None: config_mode = DEFAULT_CONFIG_MODE - + config = {} - + if config_mode == "switch_base_8": # Google Switch-Base-8 (original) config = { @@ -121,7 +123,7 @@ def get_moe_config(config_mode=None): "k": 1, # Top-1 MoE "hidden_dim": 256, } - + return config @@ -139,4 +141,3 @@ def print_config_info(config_mode, config): print(f"Top-k: {config['k']}") print(f"Hidden dimension: {config['hidden_dim']}") print("=" * 60) - diff --git a/ece6775/moe/allo_moe_alt.py b/ece6775/moe/allo_moe_alt.py index 1906c9ac2..2fc90f26d 100644 --- a/ece6775/moe/allo_moe_alt.py +++ b/ece6775/moe/allo_moe_alt.py @@ -1,3 +1,6 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + # MoE with fused GeLU optimization for HLS # fuses GeLU into FC1 to save memory and enable pipelining # small diff vs pytorch (~1e-4) from accumulation order @@ -11,6 +14,7 @@ import sys import os + # path setup for config current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(current_dir) @@ -21,87 +25,100 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": # stable softmax over last dim Z: Ty[N, K] E_exp: Ty[N, K] M: Ty[N] = -1e12 S: Ty[N] = 0.0 - + # find max per row for n, k in dsl.grid(N, K, name="row_max"): if X[n, k] > M[n]: M[n] = X[n, k] - + # exp and sum for n, k in dsl.grid(N, K, name="exp_sum"): E_exp[n, k] = dsl.exp(X[n, k] - M[n]) S[n] += E_exp[n, k] - + # normalize for n, k in dsl.grid(N, K, name="normalize"): Z[n, k] = E_exp[n, k] / S[n] - + return Z + # top-1 selection (argmax) def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": # pick best expert per token indices: int32[N] max_val: Ty[N] = -1e12 - + for n in range(N, name="init"): indices[n] = 0 max_val[n] = logits[n, 0] - + for n, e in dsl.grid(N, E, name="argmax"): if e > 0 and logits[n, e] > max_val[n]: max_val[n] = logits[n, e] indices[n] = e - + return indices + + # FFN expert with fused GeLU, row-by-row for pipelining -def expert[Ty, N, D_in, D_hidden, D_out]( +def expert[ + Ty, N, D_in, D_hidden, D_out +]( x: "Ty[N, D_in]", fc1_weight: "Ty[D_hidden, D_in]", fc1_bias: "Ty[D_hidden]", fc2_weight: "Ty[D_out, D_hidden]", - fc2_bias: "Ty[D_out]" + fc2_bias: "Ty[D_out]", ) -> "Ty[N, D_out]": output: Ty[N, D_out] = 0.0 - + for n in range(N, name="row_loop"): hidden_row: Ty[D_hidden] = 0.0 - + # FC1 + GeLU together for h in range(D_hidden, name="fc1_gelu_loop"): acc: Ty = 0.0 for k in range(D_in, name="fc1_reduce"): acc += x[n, k] * fc1_weight[h, k] fc1_val: Ty = acc + fc1_bias[h] - + # GeLU: tanh approximation x3: Ty = fc1_val * fc1_val * fc1_val - inner: Ty = 0.7978845608028654 * (fc1_val + 0.044715 * x3) # sqrt(2/pi) ≈ 0.7978... + inner: Ty = 0.7978845608028654 * ( + fc1_val + 0.044715 * x3 + ) # sqrt(2/pi) ≈ 0.7978... hidden_row[h] = 0.5 * fc1_val * (1.0 + dsl.tanh(inner)) - + # FC2 for o in range(D_out, name="fc2_loop"): acc2: Ty = 0.0 for k in range(D_hidden, name="fc2_reduce"): acc2 += hidden_row[k] * fc2_weight[o, k] output[n, o] = acc2 + fc2_bias[o] - + return output + + # Top-1 MoE: pick one expert per token, weight with softmax -def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( x: "Ty[N, D_in]", gate_weight: "Ty[E, D_in]", gate_bias: "Ty[E]", expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[N, D_out]": # compute gate scores gate_logits: Ty[N, E] = 0.0 @@ -110,33 +127,33 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( for k in range(D_in, name="gate_reduce"): acc += x[n, k] * gate_weight[e, k] gate_logits[n, e] = acc + gate_bias[e] - + # pick best expert top1_idx: int32[N] = top1_select[Ty, N, E](gate_logits) - + # get top-k logits (for k=1) top_k_logits: Ty[N, K] = 0.0 for n, k in dsl.grid(N, K, name="topk_logits"): expert_idx = top1_idx[n] if k == 0 else 0 top_k_logits[n, k] = gate_logits[n, expert_idx] - + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) - + # sparse weights gate_w: Ty[N, E] = 0.0 for n in range(N, name="sparse_gate"): gate_w[n, top1_idx[n]] = top_k_weights[n, 0] - + # run all experts (could optimize to skip unused ones later) expert_out: Ty[E, N, D_out] = 0.0 - + for e in range(E, name="expert_loop"): # extract weights for this expert fc1_w: Ty[D_hidden, D_in] = 0.0 fc1_b: Ty[D_hidden] = 0.0 fc2_w: Ty[D_out, D_hidden] = 0.0 fc2_b: Ty[D_out] = 0.0 - + for i, j in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): fc1_w[i, j] = expert_fc1_weights[e, i, j] for i in range(D_hidden, name="extract_fc1_b"): @@ -145,22 +162,24 @@ def moe_layer[Ty, N, D_in, D_out, E, K, D_hidden]( fc2_w[i, j] = expert_fc2_weights[e, i, j] for i in range(D_out, name="extract_fc2_b"): fc2_b[i] = expert_fc2_biases[e, i] - + out = expert[Ty, N, D_in, D_hidden, D_out](x, fc1_w, fc1_b, fc2_w, fc2_b) - + for n, d in dsl.grid(N, D_out, name="store_expert_out"): expert_out[e, n, d] = out[n, d] - + # combine with weights output: Ty[N, D_out] = 0.0 for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): output[n, d_out] += expert_out[e, n, d_out] * gate_w[n, e] - + return output # build schedule with pipelining/unrolling optimizations -def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hidden_dim): +def optimize_moe_schedule( + num_tokens, input_dim, output_dim, num_experts, k, hidden_dim +): Ty = float32 N = num_tokens D_in = input_dim @@ -168,32 +187,32 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid E = num_experts K = k D_hidden = hidden_dim - + print("Building MoE schedule...") - + # top1_select s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) s_top1.pipeline("top1_select:e") - + # softmax s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) sm_loops = s_softmax.get_loops(s_softmax.top_func_name) s_softmax.pipeline(sm_loops["row_max"]["k"]) s_softmax.pipeline(sm_loops["exp_sum"]["k"]) s_softmax.pipeline(sm_loops["normalize"]["k"]) - + # expert s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) exp_loops = s_expert.get_loops(s_expert.top_func_name) row_loop = exp_loops["row_loop"] s_expert.pipeline(row_loop["h"]) s_expert.pipeline(row_loop["o"]) - + # unroll reduction loop for parallel MACs unroll_factor = min(4, D_in, D_hidden) s_expert.unroll(row_loop["k"], factor=unroll_factor) print(f" - Applied unroll to k (reduction loop) factor={unroll_factor}") - + # array partitioning if D_hidden <= 32: s_expert.partition(s_expert.hidden_row, dim=0) @@ -201,45 +220,47 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid else: s_expert.partition(s_expert.hidden_row, dim=0, factor=4) print(f" - Applied cyclic partition to hidden_row (factor=4)") - - print(" - Created expert schedule with pipeline, unroll, and partition optimizations") - + + print( + " - Created expert schedule with pipeline, unroll, and partition optimizations" + ) + # moe_layer schedule print("\n[4] Creating schedule for moe_layer...") s_moe = allo.customize(moe_layer, instantiate=[Ty, N, D_in, D_out, E, K, D_hidden]) - + moe_loops = s_moe.get_loops(s_moe.top_func_name) print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") - + # optimize gate_linear gate_linear_loop = moe_loops["gate_linear"] print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") - + s_moe.pipeline(gate_linear_loop["e"]) print(" - Applied pipeline to gate_linear:e") - + unroll_factor_gate = min(4, D_in) s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") - + # other loops combine_loop = moe_loops["combine_outputs"] s_moe.pipeline(combine_loop["d_out"]) - + topk_loop = moe_loops["topk_logits"] s_moe.pipeline(topk_loop["k"]) - + sparse_loop = moe_loops["sparse_gate"] s_moe.pipeline(sparse_loop["n"]) - + print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") - + # compose sub-schedules s_moe.compose(s_top1) s_moe.compose(s_softmax) s_moe.compose(s_expert) print(" - Composed top1_select, softmax_1d, and expert schedules") - + print("\n" + "=" * 60) print("Schedule done!") print("=" * 60) @@ -252,7 +273,7 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid print(" 4. Array partitioning for hidden_row") print(" 5. Pipeline on h, o, e loops") print("=" * 60) - + return s_moe @@ -261,9 +282,9 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid import torch import torch.nn as nn from pytorch_moe import MoELayer - + moe_config = get_moe_config(CONFIG_MODE) - + batch_size = moe_config["batch_size"] seq_len = moe_config["seq_len"] input_dim = moe_config["input_dim"] @@ -272,7 +293,7 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid k = moe_config["k"] hidden_dim = moe_config["hidden_dim"] seed = 42 - + print("=" * 60) print("MoE Allo (Vayun Optimized) vs PyTorch Comparison Test") print("=" * 60) @@ -283,74 +304,96 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") print(f" Seed: {seed}") print("=" * 60) - + # pytorch baseline print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch MoE layer pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) pytorch_moe.eval() - + # Initialize with Xavier uniform for param in pytorch_moe.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create random input torch.manual_seed(seed) pytorch_input = torch.randn(batch_size, seq_len, input_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_moe(pytorch_input, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + # extract weights print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights - gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] gate_bias_pt = pytorch_moe.gate.gate_linear.bias if gate_bias_pt is not None: gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([exp.fc1.weight.data for exp in pytorch_moe.experts]) - expert_fc1_biases_pt = torch.stack([exp.fc1.bias.data for exp in pytorch_moe.experts]) - expert_fc2_weights_pt = torch.stack([exp.fc2.weight.data for exp in pytorch_moe.experts]) - expert_fc2_biases_pt = torch.stack([exp.fc2.bias.data for exp in pytorch_moe.experts]) - + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_moe.experts] + ) + expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_moe.experts] + ) + # convert to numpy print("\n[3] Converting weights to numpy arrays...") x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) - - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Input shape: {x_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - + # flatten input (N = B * L) num_tokens = batch_size * seq_len x_flat_np = x_np.reshape(num_tokens, input_dim) print(f"Flattened input shape: {x_flat_np.shape}") - + # run allo print("\n[4] Running Allo implementation...") try: @@ -358,89 +401,110 @@ def optimize_moe_schedule(num_tokens, input_dim, output_dim, num_experts, k, hid allo_schedule = optimize_moe_schedule( num_tokens, input_dim, output_dim, num_experts, k, hidden_dim ) - + # Generate project name project_name = f"allo_moe_alt_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("\n[5] Building Allo module...") if MODE == "llvm": mod = allo_schedule.build(target="llvm") elif MODE == "sw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": - mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) elif MODE == "csyn": - mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference print("\n[6] Running Allo inference...") if MODE == "llvm": allo_output_flat = mod( x_flat_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, ) elif MODE in ["sw_emu", "hw_emu", "hw"]: allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) - mod(x_flat_np, - gate_weight_np, gate_bias_np, - expert_fc1_weights_np, expert_fc1_biases_np, - expert_fc2_weights_np, expert_fc2_biases_np, - allo_output_flat) + mod( + x_flat_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output_flat, + ) elif MODE == "csyn": allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + # Reshape output back to [B, L, D_out] allo_output = allo_output_flat.reshape(batch_size, seq_len, output_dim) - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - + # compare print("\n[7] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + # Compute differences diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # Check if outputs are close atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - + # sample outputs print("\n[8] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - - print("=" * 60) \ No newline at end of file + + print("=" * 60) diff --git a/ece6775/moe/allo_moe_base.py b/ece6775/moe/allo_moe_base.py index 7688f723b..eba936f99 100644 --- a/ece6775/moe/allo_moe_base.py +++ b/ece6775/moe/allo_moe_base.py @@ -1,7 +1,10 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + # MoE in Allo - manual implementation (no nn lib) # matches pytorch but with small numerical differences (1e-4 to 1e-3) from: # - different accumulation order -# - GELU approximation differences +# - GELU approximation differences # - fp precision # these are normal and don't affect performance @@ -23,30 +26,32 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, E](X: "Ty[N, E]") -> "Ty[N, E]": # softmax over last dim Z: Ty[N, E] E_exp: Ty[N, E] M: Ty[N] = -1000000000000.0 # TODO: use -inf if available S: Ty[N] = 0.0 - + # find max per row for i in range(N): for j in range(E): if X[i, j] > M[i]: M[i] = X[i, j] - + # exp and sum for i in range(N): for j in range(E): E_exp[i, j] = dsl.exp(X[i, j] - M[i]) S[i] += E_exp[i, j] - + # normalize for i in range(N): for j in range(E): Z[i, j] = E_exp[i, j] / S[i] - + return Z @@ -55,7 +60,7 @@ def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": # pick best expert for each token indices: int32[N] max_val: Ty[N] = -1000000000000.0 - + for i in range(N): indices[i] = 0 max_val[i] = logits[i, 0] @@ -64,19 +69,20 @@ def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": if logits[i, j] > max_val[i]: max_val[i] = logits[i, j] indices[i] = j - + return indices # TopKGate: Gate module to select top-k experts -def topk_gate[Ty, N, D, E, K]( - x: "Ty[N, D]", - gate_weight: "Ty[E, D]", - gate_bias: "Ty[E]" -) -> ("Ty[N, E]", "int32[N, K]"): +def topk_gate[ + Ty, N, D, E, K +](x: "Ty[N, D]", gate_weight: "Ty[E, D]", gate_bias: "Ty[E]") -> ( + "Ty[N, E]", + "int32[N, K]", +): """ Gate module to select top-k experts for routing. - + Args: x: Input tensor of shape [N, D] where N = batch * seq_len gate_weight: Gate weight matrix of shape [E, D] @@ -96,42 +102,44 @@ def topk_gate[Ty, N, D, E, K]( # Matrix multiplication: x[i] @ gate_weight[j]^T for k in range(D): logits[i, j] += x[i, k] * gate_weight[j, k] - + # For k=1, use top1_select top1_indices_1d: int32[N] = top1_select[Ty, N, E](logits) - + # Expand to [N, K] format (for k=1, K=1) top_k_indices: int32[N, K] = 0 for i in range(N): top_k_indices[i, 0] = top1_indices_1d[i] - + # Get top-k logits top_k_logits: Ty[N, K] = 0.0 for i in range(N): for k in range(K): expert_idx = top_k_indices[i, k] top_k_logits[i, k] = logits[i, expert_idx] - + # Apply softmax to top-k logits top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] - + # Create sparse weight matrix full_weights: Ty[N, E] = 0.0 for i in range(N): for k in range(K): expert_idx = top_k_indices[i, k] full_weights[i, expert_idx] = top_k_weights[i, k] - + return full_weights, top_k_indices # simple FFN expert -def expert[Ty, N, D_in, D_hidden, D_out]( +def expert[ + Ty, N, D_in, D_hidden, D_out +]( x: "Ty[N, D_in]", fc1_weight: "Ty[D_hidden, D_in]", fc1_bias: "Ty[D_hidden]", fc2_weight: "Ty[D_out, D_hidden]", - fc2_bias: "Ty[D_out]" + fc2_bias: "Ty[D_out]", ) -> "Ty[N, D_out]": # FC1 fc1_out: Ty[N, D_hidden] = 0.0 @@ -140,7 +148,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( fc1_out[i, j] = fc1_bias[j] for k in range(D_in): fc1_out[i, j] += x[i, k] * fc1_weight[j, k] - + # GELU gelu_out: Ty[N, D_hidden] = 0.0 for i in range(N): @@ -150,7 +158,7 @@ def expert[Ty, N, D_in, D_hidden, D_out]( inner: Ty = 0.797885 * (x_val + 0.044715 * x3) tanh_term: Ty = dsl.tanh(inner) gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) - + # FC2 fc2_out: Ty[N, D_out] = 0.0 for i in range(N): @@ -158,24 +166,26 @@ def expert[Ty, N, D_in, D_hidden, D_out]( fc2_out[i, j] = fc2_bias[j] for k in range(D_hidden): fc2_out[i, j] += gelu_out[i, k] * fc2_weight[j, k] - + return fc2_out # main MoE layer -def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( +def moe_layer[ + Ty, B, L, D_in, D_out, E, K, D_hidden +]( x: "Ty[B, L, D_in]", gate_weight: "Ty[E, D_in]", gate_bias: "Ty[E]", expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[B, L, D_out]": # process all tokens through all experts, then weight by gate # matches pytorch behavior gate_logits: Ty[B * L, E] = 0.0 - + # Compute gate logits for all tokens for b in range(B): for l in range(L): @@ -187,57 +197,57 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( # Matrix multiplication: x[b, l] @ gate_weight[j]^T for k in range(D_in): gate_logits[token_idx, j] += x[b, l, k] * gate_weight[j, k] - + # Select top-k experts and compute weights gate_weights: Ty[B * L, E] = 0.0 top_k_indices: int32[B * L, K] = 0 - + for i in range(B * L): # Find top-1 expert (argmax) for this token top_expert_idx: int32 = 0 max_logit: Ty = gate_logits[i, 0] - + for j in range(1, E): if gate_logits[i, j] > max_logit: max_logit = gate_logits[i, j] top_expert_idx = j - + # Store top-k indices (for k=1) top_k_indices[i, 0] = top_expert_idx - + # Get top-k logit and apply softmax # For k=1, softmax just returns 1.0, but we compute it for consistency # softmax(top_k_logit) = exp(top_k_logit) / exp(top_k_logit) = 1.0 gate_weights[i, top_expert_idx] = 1.0 - + # Flatten input for expert processing x_flat: Ty[B * L, D_in] = 0.0 for b in range(B): for l in range(L): for d in range(D_in): x_flat[b * L + l, d] = x[b, l, d] - + # Process each expert: compute outputs for all tokens # Then use gate weights to select correct outputs expert_outputs: Ty[E, B * L, D_out] = 0.0 - + for expert_idx in range(E): # Extract expert weights for this expert expert_fc1_w: Ty[D_hidden, D_in] = 0.0 expert_fc1_b: Ty[D_hidden] = 0.0 expert_fc2_w: Ty[D_out, D_hidden] = 0.0 expert_fc2_b: Ty[D_out] = 0.0 - + for h in range(D_hidden): for d in range(D_in): expert_fc1_w[h, d] = expert_fc1_weights[expert_idx, h, d] expert_fc1_b[h] = expert_fc1_biases[expert_idx, h] - + for o in range(D_out): for h in range(D_hidden): expert_fc2_w[o, h] = expert_fc2_weights[expert_idx, o, h] expert_fc2_b[o] = expert_fc2_biases[expert_idx, o] - + # Process all tokens through this expert (inline expert function) # First linear layer: fc1_out = x_flat @ expert_fc1_w^T + expert_fc1_b fc1_out: Ty[B * L, D_hidden] = 0.0 @@ -248,7 +258,7 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( # Matrix multiplication: x_flat[i] @ expert_fc1_w[j]^T for k in range(D_in): fc1_out[i, j] += x_flat[i, k] * expert_fc1_w[j, k] - + # GELU activation: 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x^3))) gelu_out: Ty[B * L, D_hidden] = 0.0 for i in range(B * L): @@ -263,7 +273,7 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( tanh_term: Ty = dsl.tanh(inner) # GELU: 0.5 * x * (1 + tanh_term) gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) - + # Second linear layer: fc2_out = gelu_out @ expert_fc2_w^T + expert_fc2_b expert_out: Ty[B * L, D_out] = 0.0 for i in range(B * L): @@ -273,12 +283,12 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( # Matrix multiplication: gelu_out[i] @ expert_fc2_w[j]^T for k in range(D_hidden): expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k] - + # Store expert outputs for i in range(B * L): for o in range(D_out): expert_outputs[expert_idx, i, o] = expert_out[i, o] - + # Combine expert outputs using gate weights # For each token, sum over experts weighted by gate_weights output_flat: Ty[B * L, D_out] = 0.0 @@ -287,14 +297,14 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( weight: Ty = gate_weights[i, expert_idx] for o in range(D_out): output_flat[i, o] += expert_outputs[expert_idx, i, o] * weight - + # Reshape back to original shape output: Ty[B, L, D_out] = 0.0 for b in range(B): for l in range(L): for o in range(D_out): output[b, l, o] = output_flat[b * L + l, o] - + return output @@ -304,7 +314,7 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( import torch import torch.nn as nn from pytorch_moe import MoELayer, run_moe_inference - + config = get_moe_config(CONFIG_MODE) batch_size = config["batch_size"] seq_len = config["seq_len"] @@ -313,9 +323,9 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( num_experts = config["num_experts"] k = config["k"] hidden_dim = config["hidden_dim"] - + seed = 42 - + # Print configuration info using shared function print("=" * 60) print("MoE Allo vs PyTorch Comparison Test") @@ -323,100 +333,135 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( print_config_info(CONFIG_MODE, config) print(f"Seed: {seed}") print("=" * 60) - + # pytorch baseline print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch MoE layer pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) pytorch_moe.eval() - + # Initialize with Xavier uniform (matching PyTorch) for param in pytorch_moe.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create random input torch.manual_seed(seed) pytorch_input = torch.randn(batch_size, seq_len, input_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_moe(pytorch_input, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + # extract weights print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo - gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] gate_bias_pt = pytorch_moe.gate.gate_linear.bias if gate_bias_pt is not None: gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([expert.fc1.weight.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim, input_dim] - expert_fc1_biases_pt = torch.stack([expert.fc1.bias.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim] - expert_fc2_weights_pt = torch.stack([expert.fc2.weight.data for expert in pytorch_moe.experts]) # [num_experts, output_dim, hidden_dim] - expert_fc2_biases_pt = torch.stack([expert.fc2.bias.data for expert in pytorch_moe.experts]) # [num_experts, output_dim] - + expert_fc1_weights_pt = torch.stack( + [expert.fc1.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack( + [expert.fc1.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack( + [expert.fc2.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack( + [expert.fc2.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim] + # convert to numpy print("\n[3] Converting weights to numpy arrays...") x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) - + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, hidden_dim, input_dim] (same) - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Input shape: {x_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - + # run allo print("\n[4] Running Allo implementation...") try: # Customize Allo module allo_mod = allo.customize( - moe_layer, - instantiate=[float32, batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim] + moe_layer, + instantiate=[ + float32, + batch_size, + seq_len, + input_dim, + output_dim, + num_experts, + k, + hidden_dim, + ], ) - + # Generate project name based on CONFIG_MODE to avoid conflicts # This ensures different configurations use different build folders project_name = f"allo_moe_base_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("Building Allo module...") if MODE == "llvm": mod = allo_mod.build(target="llvm") # Use LLVM for CPU testing elif MODE == "sw_emu": - mod = allo_mod.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_mod.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_mod.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_mod.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": mod = allo_mod.build(target="vitis_hls", mode="hw", project=project_name) elif MODE == "csyn": mod = allo_mod.build(target="vitis_hls", mode="csyn", project=project_name) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference # Note: In LLVM backend, functions with return values return the result directly # We don't need to pass output buffer as argument @@ -429,61 +474,93 @@ def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, - expert_fc2_biases_np + expert_fc2_biases_np, ) elif MODE == "sw_emu": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) - mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "hw_emu": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) - mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "hw": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) - mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "csyn": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - + # compare print("\n[5] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + # Compute differences diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # check if close (1e-4 to 1e-3 diff is normal from accumulation order, GELU approx, fp precision) atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - + # sample outputs print("\n[6] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - - print("=" * 60) \ No newline at end of file + + print("=" * 60) diff --git a/ece6775/moe/allo_moe_lib.py b/ece6775/moe/allo_moe_lib.py index 05d6bb16d..e433c2498 100644 --- a/ece6775/moe/allo_moe_lib.py +++ b/ece6775/moe/allo_moe_lib.py @@ -1,3 +1,6 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + # MoE using Allo library functions (nn.linear2d, nn.GeLU) # follows test_bert pattern - all functions in one file with generic types @@ -22,54 +25,60 @@ from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info CONFIG_MODE = DEFAULT_CONFIG_MODE + + def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": # softmax over last dim Z: Ty[N, K] E_exp: Ty[N, K] M: Ty[N] = -1000000000000.0 S: Ty[N] = 0.0 - + # find max per row for n, k in dsl.grid(N, K, name="row_max"): if X[n, k] > M[n]: M[n] = X[n, k] - + # exp and sum for n, k in dsl.grid(N, K, name="exp_sum"): E_exp[n, k] = dsl.exp(X[n, k] - M[n]) S[n] += E_exp[n, k] - + # normalize for n, k in dsl.grid(N, K, name="update"): Z[n, k] = E_exp[n, k] / (S[n]) - + return Z + # top-1 selection def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": # pick best expert per token indices: int32[N] max_val: Ty[N] = -1000000000000.0 - + for n in range(N, name="init"): indices[n] = 0 max_val[n] = logits[n, 0] - + for n, e in dsl.grid(N, E, name="argmax"): if e > 0: # skip e=0 if logits[n, e] > max_val[n]: max_val[n] = logits[n, e] indices[n] = e - + return indices + # FFN expert using library functions -def expert[Ty, N, D_in, D_hidden, D_out]( +def expert[ + Ty, N, D_in, D_hidden, D_out +]( x: "Ty[N, D_in]", fc1_weight: "Ty[D_hidden, D_in]", fc1_bias: "Ty[D_hidden]", fc2_weight: "Ty[D_out, D_hidden]", - fc2_bias: "Ty[D_out]" + fc2_bias: "Ty[D_out]", ) -> "Ty[N, D_out]": # using linear2d and GeLU from library (imported at top) fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias) @@ -77,97 +86,102 @@ def expert[Ty, N, D_in, D_hidden, D_out]( fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias) return fc2_out + # main MoE layer -def moe_layer[Ty, B, L, D_in, D_out, E, K, D_hidden]( +def moe_layer[ + Ty, B, L, D_in, D_out, E, K, D_hidden +]( x: "Ty[B, L, D_in]", gate_weight: "Ty[E, D_in]", gate_bias: "Ty[E]", expert_fc1_weights: "Ty[E, D_hidden, D_in]", expert_fc1_biases: "Ty[E, D_hidden]", expert_fc2_weights: "Ty[E, D_out, D_hidden]", - expert_fc2_biases: "Ty[E, D_out]" + expert_fc2_biases: "Ty[E, D_out]", ) -> "Ty[B, L, D_out]": # Flatten batch and sequence dimensions: N = B * L N = B * L x_flat: Ty[N, D_in] = 0.0 for b, l, d_in in dsl.grid(B, L, D_in, name="flatten"): x_flat[b * L + l, d_in] = x[b, l, d_in] - + # Step 1: Compute gate logits using linear2d (inlined topk_gate logic) # Note: Use linear2d (direct import) to avoid conflict with torch.nn - gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in](x_flat, gate_weight, gate_bias) # [N, E] - + gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in]( + x_flat, gate_weight, gate_bias + ) # [N, E] + # Step 2: Select top-1 expert using top1_select function top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) - + # Step 3: Get top-k logits and apply softmax top_k_logits: Ty[N, K] = 0.0 for n, k in dsl.grid(N, K, name="topk_logits"): expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 top_k_logits[n, k] = gate_logits[n, expert_idx] - + # Step 4: Apply softmax to top-k logits using softmax_1d function top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] - + # Step 5: Create sparse weight matrix from top-k weights top_k_indices: int32[N, K] = 0 gate_weights: Ty[N, E] = 0.0 - + for n in range(N, name="gate_weights"): # Store top-k indices (for k=1) top_k_indices[n, 0] = top1_indices_1d[n] # Set gate weights from softmax output expert_idx = top1_indices_1d[n] gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 - + # Step 2: Process each expert: compute outputs for all tokens expert_outputs: Ty[E, N, D_out] = 0.0 - + for e in range(E, name="expert_loop"): # Extract expert weights for this expert expert_fc1_w: Ty[D_hidden, D_in] = 0.0 expert_fc1_b: Ty[D_hidden] = 0.0 expert_fc2_w: Ty[D_out, D_hidden] = 0.0 expert_fc2_b: Ty[D_out] = 0.0 - + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] - + for d_hidden in range(D_hidden, name="extract_fc1_b"): expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] - + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] - + for d_out in range(D_out, name="extract_fc2_b"): expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] - + # Process all tokens through this expert using the expert function expert_out = expert[Ty, N, D_in, D_hidden, D_out]( x_flat, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b ) # [N, D_out] - + # Store expert outputs for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): expert_outputs[e, n, d_out] = expert_out[n, d_out] - + # Step 3: Combine expert outputs using gate weights output_flat: Ty[N, D_out] = 0.0 for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): weight: Ty = gate_weights[n, e] output_flat[n, d_out] += expert_outputs[e, n, d_out] * weight - + # Step 4: Reshape back to original shape output: Ty[B, L, D_out] = 0.0 for b, l, d_out in dsl.grid(B, L, D_out, name="reshape"): output[b, l, d_out] = output_flat[b * L + l, d_out] - + return output -#================================================================================== +# ================================================================================== # Schedule optimization function -#================================================================================== +# ================================================================================== def optimize_moe_with_composition( batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim ): @@ -180,69 +194,84 @@ def optimize_moe_with_composition( K = k D_hidden = hidden_dim D_out = output_dim - + print("=" * 60) print("Creating and optimizing MoE schedules...") print("=" * 60) - + # top1_select schedule print("\n[1] Creating schedule for top1_select...") s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) print(" - Created top1_select schedule for [N, E]") - + # softmax schedule print("\n[2] Creating schedule for softmax_1d...") s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) print(" - Created softmax_1d schedule for [N, K]") - + # expert schedule - need to create library function schedules first (type inference workaround) print("\n[3] Creating schedule for expert...") import allo.library.nn as allo_nn - + print(" - Creating schedules for library functions (linear2d, GeLU)...") - s_linear_fc1 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_hidden, D_in]) + s_linear_fc1 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_hidden, D_in] + ) s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, N, D_hidden]) - s_linear_fc2 = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_out, D_hidden]) + s_linear_fc2 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_out, D_hidden] + ) print(" - Library function schedules created") - + print(" - Creating expert schedule...") s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) - + # compose library functions into expert s_expert.compose(s_linear_fc1, id="expert_fc1") s_expert.compose(s_gelu, id="expert_gelu") s_expert.compose(s_linear_fc2, id="expert_fc2") print(" - Created expert schedule") print(" - Composed nn.linear2d and nn.GeLU schedules for expert") - + # moe_layer schedule print("\n[4] Creating schedule for moe_layer...") s_moe = allo.customize( moe_layer, - instantiate=[Ty, batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim] + instantiate=[ + Ty, + batch_size, + seq_len, + input_dim, + output_dim, + num_experts, + k, + hidden_dim, + ], ) - + # compose everything print("\n[5] Composing all schedules together...") - + # gate linear2d - s_gate_linear = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, E, D_in]) + s_gate_linear = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, E, D_in] + ) s_moe.compose(s_gate_linear, id="gate") print(" - Composed linear2d (gate logits) schedule") - + # top1 and softmax s_moe.compose(s_top1) s_moe.compose(s_softmax) print(" - Composed top1_select and softmax_1d schedules") - + # expert (already has library functions composed) s_moe.compose(s_expert) print(" - Composed expert schedule (includes nn.linear2d and nn.GeLU)") - + print("\n" + "=" * 60) print("Schedule composition complete!") print("=" * 60) - + return s_moe @@ -252,9 +281,10 @@ def optimize_moe_with_composition( import torch.nn as nn import sys import os + # Import pytorch_moe from the same directory from pytorch_moe import MoELayer - + # ============================================================================ # Configuration parameters - using shared config module # ============================================================================ @@ -267,113 +297,143 @@ def optimize_moe_with_composition( num_experts = config["num_experts"] k = config["k"] hidden_dim = config["hidden_dim"] - + seed = 42 - + # Print configuration info using shared function print_config_info(CONFIG_MODE, config) print(f"Seed: {seed}") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run PyTorch implementation to get weights and outputs - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[1] Running PyTorch implementation...") torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create PyTorch MoE layer from pytorch_moe.py pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) pytorch_moe.eval() - + # Initialize with Xavier uniform (matching PyTorch) for param in pytorch_moe.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create random input torch.manual_seed(seed) pytorch_input = torch.randn(batch_size, seq_len, input_dim) - + # Run PyTorch inference with torch.no_grad(): pytorch_output = pytorch_moe(pytorch_input, verbose=False) - + print(f"PyTorch output shape: {pytorch_output.shape}") - print(f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]") - - #---------------------------------------------------------------------------------- + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- # Extract weights and biases from PyTorch model - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[2] Extracting weights from PyTorch model...") - + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo - gate_weight_pt = pytorch_moe.gate.gate_linear.weight.data # [num_experts, input_dim] + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] gate_bias_pt = pytorch_moe.gate.gate_linear.bias if gate_bias_pt is not None: gate_bias_pt = gate_bias_pt.data else: gate_bias_pt = torch.zeros(num_experts) - + # Expert weights - expert_fc1_weights_pt = torch.stack([expert.fc1.weight.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim, input_dim] - expert_fc1_biases_pt = torch.stack([expert.fc1.bias.data for expert in pytorch_moe.experts]) # [num_experts, hidden_dim] - expert_fc2_weights_pt = torch.stack([expert.fc2.weight.data for expert in pytorch_moe.experts]) # [num_experts, output_dim, hidden_dim] - expert_fc2_biases_pt = torch.stack([expert.fc2.bias.data for expert in pytorch_moe.experts]) # [num_experts, output_dim] - - #---------------------------------------------------------------------------------- + expert_fc1_weights_pt = torch.stack( + [expert.fc1.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack( + [expert.fc1.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack( + [expert.fc2.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack( + [expert.fc2.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim] + + # ---------------------------------------------------------------------------------- # Convert the weights and biases to numpy arrays (ensure C-contiguous and correct shape) - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[3] Converting weights to numpy arrays...") x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) - + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) - gate_weight_np = np.ascontiguousarray(gate_weight_pt.detach().numpy(), dtype=np.float32) + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) - + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, hidden_dim, input_dim] (same) - expert_fc1_weights_np = np.ascontiguousarray(expert_fc1_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc1_biases_np = np.ascontiguousarray(expert_fc1_biases_pt.detach().numpy(), dtype=np.float32) - expert_fc2_weights_np = np.ascontiguousarray(expert_fc2_weights_pt.detach().numpy(), dtype=np.float32) - expert_fc2_biases_np = np.ascontiguousarray(expert_fc2_biases_pt.detach().numpy(), dtype=np.float32) - + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + print(f"Input shape: {x_np.shape}") print(f"Gate weight shape: {gate_weight_np.shape}") print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") - - #---------------------------------------------------------------------------------- + + # ---------------------------------------------------------------------------------- # Run Allo implementation - #---------------------------------------------------------------------------------- + # ---------------------------------------------------------------------------------- print("\n[4] Running Allo implementation...") try: # Create optimized schedule with composition allo_schedule = optimize_moe_with_composition( batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim ) - + # Generate project name based on CONFIG_MODE to avoid conflicts # This ensures different configurations use different build folders project_name = f"allo_moe_lib_{CONFIG_MODE}.prj" print(f"Using project name: {project_name}") - + # Build module print("\n[5] Building Allo module...") if MODE == "llvm": mod = allo_schedule.build(target="llvm") elif MODE == "sw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="sw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) elif MODE == "hw_emu": - mod = allo_schedule.build(target="vitis_hls", mode="hw_emu", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) elif MODE == "hw": - mod = allo_schedule.build(target="vitis_hls", mode="hw", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) elif MODE == "csyn": - mod = allo_schedule.build(target="vitis_hls", mode="csyn", project=project_name) + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) else: raise ValueError(f"Unsupported mode: {MODE}") - + # Run Allo inference print("\n[6] Running Allo inference...") if MODE == "llvm": @@ -384,54 +444,68 @@ def optimize_moe_with_composition( expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, - expert_fc2_biases_np + expert_fc2_biases_np, ) elif MODE == "sw_emu" or MODE == "hw_emu" or MODE == "hw": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) - mod(x_np, gate_weight_np, gate_bias_np, expert_fc1_weights_np, expert_fc1_biases_np, expert_fc2_weights_np, expert_fc2_biases_np, allo_output) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) elif MODE == "csyn": allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) mod() else: raise ValueError(f"Unsupported mode: {MODE}") - + print(f"Allo output shape: {allo_output.shape}") print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") - + # compare print("\n[7] Comparing outputs...") pytorch_output_np = pytorch_output.detach().numpy() - + diff = np.abs(allo_output - pytorch_output_np) mean_diff = np.mean(diff) max_diff = np.max(diff) rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) - + print(f"Mean absolute difference: {mean_diff:.6e}") print(f"Max absolute difference: {max_diff:.6e}") print(f"Mean relative difference: {rel_diff:.6e}") - + # check if close (1e-4 to 1e-3 diff is normal) atol = 5e-4 rtol = 2e-3 is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) - + if is_close: - print(f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) else: - print(f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})") + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) print("First few differences:") print(diff.flatten()[:10]) - + # sample outputs print("\n[8] Sample outputs (first token, first 5 dimensions):") print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") print(f"Allo: {allo_output[0, 0, :5]}") print(f"Diff: {diff[0, 0, :5]}") - + except Exception as e: print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") import traceback + traceback.print_exc() - + print("=" * 60) diff --git a/ece6775/moe/pytorch_moe.py b/ece6775/moe/pytorch_moe.py index f34ef6427..fea2f445a 100644 --- a/ece6775/moe/pytorch_moe.py +++ b/ece6775/moe/pytorch_moe.py @@ -1,3 +1,5 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 """ Modern Mixture of Experts (MoE) Implementation for Inference Only @@ -14,95 +16,105 @@ class TopKGate(nn.Module): """ Gate module to select top k experts for routing. - + Args: input_dim: Input feature dimension num_experts: Total number of experts k: Number of experts to select per token """ - + def __init__(self, input_dim: int, num_experts: int, k: int = 1): super().__init__() self.k = k self.num_experts = num_experts # Linear layer to compute expert logits self.gate_linear = nn.Linear(input_dim, num_experts, bias=False) - - def forward(self, x: torch.Tensor, verbose: bool = False) -> Tuple[torch.Tensor, torch.Tensor]: + + def forward( + self, x: torch.Tensor, verbose: bool = False + ) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass through the gate. - + Args: x: Input tensor of shape [batch_size * seq_len, input_dim] verbose: Whether to print routing information - + Returns: full_weights: Sparse weight matrix [batch_size * seq_len, num_experts] top_k_indices: Indices of selected experts [batch_size * seq_len, k] """ # Compute logits for all experts logits = self.gate_linear(x) # [N, num_experts] - + # Select top-k experts top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1) - + # Apply softmax to top-k logits for normalized weights top_k_weights = F.softmax(top_k_logits, dim=-1) # [N, k] - + # Create sparse weight matrix (zeros for non-selected experts) full_weights = torch.zeros_like(logits) full_weights.scatter_(1, top_k_indices, top_k_weights) - + # Print routing information if verbose if verbose: num_tokens = x.shape[0] - + # Count tokens per expert expert_counts = {} for expert_idx in range(self.num_experts): count = (top_k_indices == expert_idx).sum().item() if count > 0: expert_counts[expert_idx] = count - + # Verify that each token selects exactly k experts tokens_per_expert_count = {} for i in range(num_tokens): num_experts_selected = (top_k_weights[i] > 1e-6).sum().item() - tokens_per_expert_count[num_experts_selected] = tokens_per_expert_count.get(num_experts_selected, 0) + 1 - + tokens_per_expert_count[num_experts_selected] = ( + tokens_per_expert_count.get(num_experts_selected, 0) + 1 + ) + print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}") print(f"[Gate Routing] Expert distribution:") total_selections = num_tokens * self.k # Total expert-token pairs for expert_idx, count in sorted(expert_counts.items()): percentage = (count / total_selections) * 100 - print(f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)") - + print( + f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)" + ) + # Verify top-k: each token selects exactly k experts if tokens_per_expert_count.get(self.k, 0) == num_tokens: - print(f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)") + print( + f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)" + ) else: - print(f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}") - + print( + f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}" + ) + return full_weights, top_k_indices class Expert(nn.Module): """ A simple feed-forward expert network. - + Args: input_dim: Input feature dimension hidden_dim: Hidden layer dimension output_dim: Output feature dimension """ - + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): super().__init__() self.fc1 = nn.Linear(input_dim, hidden_dim) self.fc2 = nn.Linear(hidden_dim, output_dim) # Use GELU activation (modern choice) self.activation = nn.GELU() - + def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass through the expert.""" x = self.fc1(x) @@ -114,7 +126,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MoELayer(nn.Module): """ Mixture of Experts layer for inference. - + Args: input_dim: Input feature dimension output_dim: Output feature dimension @@ -122,128 +134,145 @@ class MoELayer(nn.Module): k: Number of experts to activate per token expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim) """ - + def __init__( self, input_dim: int, output_dim: int, num_experts: int, k: int = 1, - expert_hidden_dim: Optional[int] = None + expert_hidden_dim: Optional[int] = None, ): super().__init__() self.num_experts = num_experts self.k = k self.output_dim = output_dim - + if expert_hidden_dim is None: expert_hidden_dim = input_dim * 4 # Common practice - + # Initialize gate and experts self.gate = TopKGate(input_dim, num_experts, k) - self.experts = nn.ModuleList([ - Expert(input_dim, expert_hidden_dim, output_dim) - for _ in range(num_experts) - ]) - + self.experts = nn.ModuleList( + [ + Expert(input_dim, expert_hidden_dim, output_dim) + for _ in range(num_experts) + ] + ) + def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor: """ Forward pass through MoE layer. - + Args: x: Input tensor of shape [batch_size, seq_len, input_dim] verbose: Whether to print verification information for top-1 MoE - + Returns: Output tensor of shape [batch_size, seq_len, output_dim] """ - + # Store original shape original_shape = x.shape batch_size, seq_len = original_shape[0], original_shape[1] - + # Flatten batch and sequence dimensions x_flat = x.view(-1, original_shape[-1]) # [N, input_dim], N = batch * seq_len num_tokens = x_flat.shape[0] - + # Get gating weights and expert indices - gate_weights, top_k_indices = self.gate(x_flat, verbose=verbose) # [N, num_experts], [N, k] - + gate_weights, top_k_indices = self.gate( + x_flat, verbose=verbose + ) # [N, num_experts], [N, k] + # Initialize output tensor output = torch.zeros( - x_flat.shape[0], - self.output_dim, - device=x.device, - dtype=x.dtype + x_flat.shape[0], self.output_dim, device=x.device, dtype=x.dtype ) - + # Track expert usage statistics expert_usage_stats = {} - + # Process each expert for expert_idx in range(self.num_experts): # Find tokens assigned to this expert # Create mask: tokens that have this expert in their top-k expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [N] - + num_tokens_for_expert = expert_mask.sum().item() - + if not expert_mask.any(): expert_usage_stats[expert_idx] = 0 continue # No tokens for this expert - + # Track usage expert_usage_stats[expert_idx] = num_tokens_for_expert - + # Get inputs for this expert expert_inputs = x_flat[expert_mask] # [num_tokens, input_dim] - + # Process through expert - expert_outputs = self.experts[expert_idx](expert_inputs) # [num_tokens, output_dim] - + expert_outputs = self.experts[expert_idx]( + expert_inputs + ) # [num_tokens, output_dim] + # Get corresponding weights for these tokens # For each token, find the weight for this expert (could be in any of k positions) token_indices = torch.where(expert_mask)[0] # Indices in flattened tensor - expert_weights = gate_weights[token_indices, expert_idx].unsqueeze(1) # [num_tokens, 1] - + expert_weights = gate_weights[token_indices, expert_idx].unsqueeze( + 1 + ) # [num_tokens, 1] + # Weight and accumulate outputs weighted_outputs = expert_outputs * expert_weights output[token_indices] += weighted_outputs - + # Verify top-k behavior if verbose: - active_experts = [idx for idx, count in expert_usage_stats.items() if count > 0] + active_experts = [ + idx for idx, count in expert_usage_stats.items() if count > 0 + ] total_usage = sum(expert_usage_stats.values()) - + # Verify top-k: each token selects exactly k experts if top_k_indices.shape[1] != self.k: - print(f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}") - + print( + f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}" + ) + # Verify all tokens are processed (for k>1, total_usage may be > num_tokens) expected_total_usage = num_tokens * self.k if total_usage != expected_total_usage: - print(f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}") - + print( + f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}" + ) + # Check if weights are normalized (should sum to 1.0 per token) sample_weights = gate_weights.sum(dim=1) - weights_normalized = torch.allclose(sample_weights, torch.ones_like(sample_weights), atol=1e-5) - + weights_normalized = torch.allclose( + sample_weights, torch.ones_like(sample_weights), atol=1e-5 + ) + # Final summary for top-k print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===") print(f" ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})") print(f" ✓ All {num_tokens} tokens processed") - print(f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})") + print( + f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})" + ) print(f" ✓ Active experts: {len(active_experts)}/{self.num_experts}") - print(f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}") + print( + f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}" + ) if weights_normalized: print(f" ✓ Gate weights normalized (sum to 1.0 per token)") else: print(f" ✗ Gate weights NOT normalized correctly") print(f" === Top-{self.k} MoE is working correctly! ===") - + # Reshape back to original shape output = output.view(batch_size, seq_len, self.output_dim) - + return output @@ -255,11 +284,11 @@ def run_moe_inference( num_experts: int = 2, k: int = 1, seed: int = 24, - verbose: bool = True + verbose: bool = True, ) -> Tuple[torch.Tensor, torch.Tensor, MoELayer]: """ Run MoE layer inference with fixed seed for reproducible inputs and weights. - + Args: batch_size: Batch size seq_len: Sequence length @@ -269,7 +298,7 @@ def run_moe_inference( k: Top-k experts to activate per token seed: Random seed for reproducibility verbose: Whether to print verification information - + Returns: output: Output tensor from MoE layer [batch_size, seq_len, output_dim] input_tensor: Input tensor used for inference @@ -279,26 +308,26 @@ def run_moe_inference( torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) - + # Create MoE layer moe_layer = MoELayer(input_dim, output_dim, num_experts, k) moe_layer.eval() - + # Initialize with random weights (Xavier uniform for better stability) for param in moe_layer.parameters(): if param.dim() > 1: nn.init.xavier_uniform_(param) else: nn.init.zeros_(param) - + # Create random input torch.manual_seed(seed) input_tensor = torch.randn(batch_size, seq_len, input_dim) - + # Run inference with verbose output for verification with torch.no_grad(): output = moe_layer(input_tensor, verbose=verbose) - + return output, input_tensor, moe_layer @@ -311,14 +340,16 @@ def run_moe_inference( num_experts = 2 k = 1 # Top-k MoE: each token uses exactly k experts seed = 24 - + print("=" * 60) print(f"Top-{k} MoE Inference Test") print("=" * 60) - print(f"Configuration: batch={batch_size}, seq_len={seq_len}, input_dim={input_dim}") + print( + f"Configuration: batch={batch_size}, seq_len={seq_len}, input_dim={input_dim}" + ) print(f"Experts: {num_experts}, k={k}") print("=" * 60) - + # Run MoE inference output, input_tensor, moe_layer = run_moe_inference( batch_size=batch_size, @@ -328,9 +359,8 @@ def run_moe_inference( num_experts=num_experts, k=k, seed=seed, - verbose=True + verbose=True, ) - + print(f"\nOutput shape: {output.shape}") print("=" * 60) -