diff --git a/ece6775/README.md b/ece6775/README.md new file mode 100644 index 000000000..fbc6ea18f --- /dev/null +++ b/ece6775/README.md @@ -0,0 +1,47 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# ECE6775 - MoE and Attention+MoE Implementations + +Allo implementations of MoE (Mixture of Experts) and Attention+MoE layers for FPGA acceleration. + +## Structure + +- `moe/`: Standalone MoE layer implementations +- `attention_moe/`: Attention + MoE combined implementations +- `llm_config/`: Shared configuration for different model sizes + +## MoE Implementations + +Each directory contains multiple versions: + +- `*_base.py`: Manual implementation (no library functions) +- `*_lib.py`: Uses Allo library functions (nn.linear2d, nn.GeLU) +- `*_alt.py`: Optimized version with fused GeLU and row-level dataflow +- `pytorch_*.py`: PyTorch reference implementation + +## Usage + +Run any implementation directly: + +```bash +python moe/allo_moe_alt.py +python attention_moe/allo_attention_moe_alt.py +``` + +## Configuration + +Edit `llm_config/llm_config.py` to change model configuration. Default: `switch_base_8_scaled_1_8` + +Available configs: +- `switch_base_8*`: Google Switch-Base-8 variants +- `mixtral_8x7b*`: Mixtral-8x7B variants +- `deepseek*`: DeepSeek MoE variants +- `custom`: Small test configuration + +## Notes + +- Small numerical differences (~1e-4) vs PyTorch are normal due to accumulation order +- Set `MODE` in each file to control build target: `llvm`, `sw_emu`, `hw_emu`, `hw`, `csyn` + +GitHub PR link: https://github.com/cornell-zhang/allo/pull/489 \ No newline at end of file diff --git a/ece6775/attention_moe/allo_attention_moe_alt.py b/ece6775/attention_moe/allo_attention_moe_alt.py new file mode 100644 index 000000000..a76aed77d --- /dev/null +++ b/ece6775/attention_moe/allo_attention_moe_alt.py @@ -0,0 +1,1015 @@ +# Copyright Allo authors. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 +# Attention + MoE in Allo +# Q, K, V -> Attention -> MoE -> Output +# custom implementations (no library functions) for full control +# optimizations: fused GeLU, row-level dataflow, unrolled loops, array partitioning + +import numpy as np +import allo +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "csyn" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE + + +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="normalize"): + Z[n, k] = E_exp[n, k] / S[n] + + return Z + + +# softmax for attention scores [L, L] +def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]": + Z: Ty[L, L] + E_exp: Ty[L, L] + M: Ty[L] = -1000000000000.0 + S: Ty[L] = 0.0 + + # find max per row + for i, j in dsl.grid(L, L, name="row_max"): + if X[i, j] > M[i]: + M[i] = X[i, j] + + # exp and sum + for i, j in dsl.grid(L, L, name="exp_sum"): + E_exp[i, j] = dsl.exp(X[i, j] - M[i]) + S[i] += E_exp[i, j] + + # normalize + for i, j in dsl.grid(L, L, name="normalize"): + Z[i, j] = E_exp[i, j] / S[i] + + return Z + + +# scaled dot-product attention (multi-head) +def scaled_dot_product_attention[ + Ty, H, L, D +](Q: "Ty[L, D]", K: "Ty[L, D]", V: "Ty[L, D]") -> "Ty[L, D]": + # scaled 
dot-product attention (multi-head) + + # matches pytorch: split into H heads, compute attention per head, merge + Z: Ty[L, D] = 0.0 + + # scale factor: 1/sqrt(head_dim) + scale: Ty = 1.0 / dsl.sqrt(float(D // H)) + + # process each head + for h in range(H, name="head_loop"): + # split Q, K, V for this head + Q_h: Ty[L, D // H] = 0.0 + K_h: Ty[L, D // H] = 0.0 + V_h: Ty[L, D // H] = 0.0 + + for i, j in dsl.grid(L, D // H, name="split_qkv"): + Q_h[i, j] = Q[i, h * (D // H) + j] + K_h[i, j] = K[i, h * (D // H) + j] + V_h[i, j] = V[i, h * (D // H) + j] + + # QK^T with accumulator to break read-after-write dependency + Y: Ty[L, L] = 0.0 + for i in range(L, name="qkt_i"): + for j in range(L, name="qkt_j"): + acc: Ty = 0.0 + for k in range(D // H, name="qkt_k"): + acc += Q_h[i, k] * K_h[j, k] # K_h[j, k] = K_h^T[k, j] + Y[i, j] = acc + + # scale + Y_scaled: Ty[L, L] = 0.0 + for i, j in dsl.grid(L, L, name="scale"): + Y_scaled[i, j] = Y[i, j] * scale + + # Apply softmax over last dimension + S: Ty[L, L] = 0.0 + E_exp: Ty[L, L] = 0.0 + M: Ty[L] = -1000000000000.0 + Sum: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="softmax_max"): + if Y_scaled[i, j] > M[i]: + M[i] = Y_scaled[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="softmax_exp"): + E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) + Sum[i] += E_exp[i, j] + + # Normalize + for i, j in dsl.grid(L, L, name="softmax_norm"): + S[i, j] = E_exp[i, j] / Sum[i] + + # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H] + # Loop order (i, k, j) for better memory access and pipelining: + # - Sequential access to V_h[k, j] as j changes (inner loop) + # - Better pipelining of reduction loop k + C_h: Ty[L, D // H] = 0.0 + for i in range(L, name="sv_i"): + for k in range(L, name="sv_k"): + for j in range(D // H, name="sv_j"): + C_h[i, j] += S[i, k] * V_h[k, j] + + # Merge back to Z + for i, j in dsl.grid(L, D // H, name="merge_heads"): + Z[i, h * (D // H) + j] = C_h[i, j] + + return Z + + +# 
top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # skip e=0 + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + + +# FFN expert with fused GeLU +# optimizations: fuse GeLU into FC1, row-level dataflow, unroll reduction loops +def expert[ + Ty, N, D_in, D_hidden, D_out +]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]", +) -> "Ty[N, D_out]": + """ + A simple feed-forward expert network with fused GeLU and row-level dataflow. + + Optimizations applied: + 1. GeLU fused into FC1: Compute GeLU immediately after each FC1 element + to eliminate intermediate array storage and enable streaming + 2. Row-level processing: Outer loop over rows (N) enables dataflow between + FC1+GeLU and FC2 stages - while FC2 processes row n, FC1 can process row n+1 + 3. 
Separate reduction loops enable unroll pragmas + + Args: + x: Input tensor of shape [N, D_in] + fc1_weight: First linear layer weights of shape [D_hidden, D_in] + fc1_bias: First linear layer bias of shape [D_hidden] + fc2_weight: Second linear layer weights of shape [D_out, D_hidden] + fc2_bias: Second linear layer bias of shape [D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + output: Ty[N, D_out] = 0.0 + + # Row-level processing: outer loop over rows (tokens) + # This structure enables dataflow between FC1+GeLU and FC2 + for n in range(N, name="row_loop"): + # ===================================================================== + # Stage 1: FC1 + Fused GeLU (produces one row of hidden activations) + # ===================================================================== + # Intermediate buffer for this row's hidden activations (after GeLU) + hidden_row: Ty[D_hidden] = 0.0 + + # Compute FC1 output for each hidden dimension, fuse GeLU immediately + for h in range(D_hidden, name="fc1_gelu_loop"): + # FC1: dot product x[n,:] @ fc1_weight[h,:] + fc1_bias[h] + acc: Ty = 0.0 + for k in range(D_in, name="fc1_reduce"): + acc += x[n, k] * fc1_weight[h, k] + fc1_val: Ty = acc + fc1_bias[h] + + # Fused GeLU: apply immediately to avoid storing fc1_out array + # GeLU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x³))) + # This matches PyTorch's GELU implementation (tanh approximation) + # sqrt(2/π) ≈ 0.7978845608028654 + x3: Ty = fc1_val * fc1_val * fc1_val + inner: Ty = 0.7978845608028654 * (fc1_val + 0.044715 * x3) + tanh_val: Ty = dsl.tanh(inner) + gelu_val: Ty = 0.5 * fc1_val * (1.0 + tanh_val) + hidden_row[h] = gelu_val + + # ===================================================================== + # Stage 2: FC2 (consumes hidden row, produces output row) + # ===================================================================== + for o in range(D_out, name="fc2_loop"): + # FC2: dot product hidden_row @ fc2_weight[o,:] + fc2_bias[o] + acc2: Ty = 
0.0 + for k in range(D_hidden, name="fc2_reduce"): + acc2 += hidden_row[k] * fc2_weight[o, k] + output[n, o] = acc2 + fc2_bias[o] + + return output + + +# MoE layer - routes tokens to experts +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( + x: "Ty[N, D_in]", + # Gate weights + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]", +) -> "Ty[N, D_out]": + """ + Mixture of Experts layer for inference. + + Args: + x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in]) + gate_weight: Gate weight matrix of shape [E, D_in] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D_out] + Returns: + output: Output tensor of shape [N, D_out] + """ + # ========================================================================= + # Step 1: Compute gate logits (custom linear, no library) + # gate_logits[n, e] = x[n, :] @ gate_weight[e, :] + gate_bias[e] + # ========================================================================= + gate_logits: Ty[N, E] = 0.0 + for n, e in dsl.grid(N, E, name="gate_linear"): + acc: Ty = 0.0 + for k in range(D_in, name="gate_reduce"): + acc += x[n, k] * gate_weight[e, k] + gate_logits[n, e] = acc + gate_bias[e] + + # ========================================================================= + # Step 2: Select top-1 expert using top1_select function + # ========================================================================= + top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) + + # 
========================================================================= + # Step 3: Get top-k logits and apply softmax + # ========================================================================= + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # ========================================================================= + # Step 4: Apply softmax to top-k logits using softmax_1d function + # ========================================================================= + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # ========================================================================= + # Step 5: Create sparse weight matrix from top-k weights + # ========================================================================= + gate_weights: Ty[N, E] = 0.0 + for n in range(N, name="sparse_gate"): + expert_idx = top1_indices_1d[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 + + # ========================================================================= + # Step 6: Process each expert: compute outputs for all tokens + # ========================================================================= + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in 
range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # Process all tokens through this expert (uses optimized expert function) + expert_out = expert[Ty, N, D_in, D_hidden, D_out]( + x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b + ) # [N, D_out] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # ========================================================================= + # Step 7: Combine expert outputs using gate weights + # ========================================================================= + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output[n, d_out] += expert_outputs[e, n, d_out] * weight + + return output + + +# ---------------------------------------------------------------------------------- +# Attention + MoE Layer: Combined layer +# Data flow: Q, K, V -> Attention -> MoE -> Output +# ---------------------------------------------------------------------------------- +def attention_moe_layer[ + Ty, B, L, D, H, E, TopK, D_hidden +]( + Query: "Ty[B, L, D]", + Key: "Ty[B, L, D]", + Value: "Ty[B, L, D]", + # Gate weights + gate_weight: "Ty[E, D]", + gate_bias: "Ty[E]", + # Expert weights (E experts, each with 2 linear layers) + expert_fc1_weights: "Ty[E, D_hidden, D]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D, D_hidden]", + expert_fc2_biases: "Ty[E, D]", +) -> "Ty[B, L, D]": + """ + Combined Attention + MoE layer. 
+ + Data flow: Q, K, V -> Attention -> MoE -> Output + + Args: + Query: Query tensor of shape [B, L, D] + Key: Key tensor of shape [B, L, D] + Value: Value tensor of shape [B, L, D] + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D] + expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden] + expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden] + expert_fc2_biases: Expert FC2 biases of shape [E, D] + Returns: + output: Output tensor of shape [B, L, D] + """ + # Output tensor + output: Ty[B, L, D] = 0.0 + + # Process each batch item + for b in range(B, name="batch_loop"): + # Step 1: Extract Q, K, V for this batch item -> [L, D] + Q_b: Ty[L, D] = 0.0 + K_b: Ty[L, D] = 0.0 + V_b: Ty[L, D] = 0.0 + + for l, d in dsl.grid(L, D, name="extract_qkv"): + Q_b[l, d] = Query[b, l, d] + K_b[l, d] = Key[b, l, d] + V_b[l, d] = Value[b, l, d] + + # Step 2: Apply custom scaled_dot_product_attention + # scaled_dot_product_attention[Ty, H, L, D](Q, K, V) -> [L, D] + attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b) # [L, D] + + # Step 3: Apply MoE layer + # moe_layer expects [N, D_in], so we use N=L for single batch + moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden]( + attn_out, + gate_weight, + gate_bias, + expert_fc1_weights, + expert_fc1_biases, + expert_fc2_weights, + expert_fc2_biases, + ) # [L, D] + + # Step 4: Store output for this batch item + for l, d in dsl.grid(L, D, name="store_output"): + output[b, l, d] = moe_out[l, d] + + return output + + +# ================================================================================== +# Schedule optimization function with HLS pragmas +# ================================================================================== +def optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim +): + """ + Create optimized schedules for Attention 
+ MoE and compose them together. + + Optimization strategy: + 1. Pipeline innermost loops for throughput + 2. Unroll small loops for parallelism + 3. Partition arrays for parallel access + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k value (currently only k=1 is supported) + hidden_dim: Hidden dimension for experts + + Returns: + s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed + """ + Ty = float32 + B = batch_size + L = seq_len + D = embed_dim + H = num_heads + E = num_experts + K = k + D_hidden = hidden_dim + head_dim = D // H + + print("=" * 60) + print("Creating and optimizing Attention + MoE schedules...") + print("=" * 60) + + # ========================================================================= + # Step 1: Create and optimize schedule for top1_select + # ========================================================================= + print("\n[1] Creating schedule for top1_select...") + s_top1 = allo.customize(top1_select, instantiate=[Ty, L, E]) + # Pipeline the inner loop for finding argmax + # Note: Use "function_name:loop_var" format for dsl.grid loops + s_top1.pipeline("top1_select:e") + print(" - Created top1_select schedule with pipeline optimization") + + # ========================================================================= + # Step 2: Create and optimize schedule for softmax_1d (for MoE gate) + # ========================================================================= + print("\n[2] Creating schedule for softmax_1d...") + s_softmax_1d = allo.customize(softmax_1d, instantiate=[Ty, L, K]) + # Pipeline softmax loops using get_loops() to get loop handles + loops_softmax = s_softmax_1d.get_loops(s_softmax_1d.top_func_name) + s_softmax_1d.pipeline(loops_softmax["row_max"]["k"]) + s_softmax_1d.pipeline(loops_softmax["exp_sum"]["k"]) + 
s_softmax_1d.pipeline(loops_softmax["normalize"]["k"]) + print(" - Created softmax_1d schedule with pipeline optimization") + + # ========================================================================= + # Step 3: Create and optimize schedule for expert + # Optimizations: + # - Dataflow on row_loop (pipeline FC1+GeLU and FC2 stages) + # - Unroll on fc1_reduce and fc2_reduce loops + # - Array partitioning for parallel access + # ========================================================================= + print("\n[3] Creating schedule for expert (custom, no library functions)...") + s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D]) + + # Get loop handles for expert + expert_loops = s_expert.get_loops(s_expert.top_func_name) + print(f" - Available loops: {list(expert_loops.loops.keys())}") + + # Get nested loops inside row_loop + row_loop = expert_loops["row_loop"] + print(f" - row_loop sub-loops: {list(row_loop.loops.keys())}") + + # ------------------------------------------------------------------------- + # Note: Allo flattens nested loops and uses loop variable names as keys + # row_loop sub-loops: ['n', 'h', 'k', 'o'] where: + # n = row index (from row_loop) + # h = hidden dim index (from fc1_gelu_loop) + # k = reduction index (shared by fc1_reduce and fc2_reduce) + # o = output dim index (from fc2_loop) + # ------------------------------------------------------------------------- + + # ------------------------------------------------------------------------- + # Optimization 1: Pipeline the output dimension loops (h and o) + # This pipelines both the FC1+GeLU loop and FC2 loop + # ------------------------------------------------------------------------- + s_expert.pipeline(row_loop["h"]) # Pipeline FC1+GeLU hidden dim loop + s_expert.pipeline(row_loop["o"]) # Pipeline FC2 output dim loop + print(" - Applied pipeline to h (fc1_gelu) and o (fc2) loops") + + # ------------------------------------------------------------------------- + # 
Optimization 2: Unroll reduction loop for parallel MACs + # The k loop is shared between FC1 and FC2 reductions + # ------------------------------------------------------------------------- + # Determine unroll factor based on dimensions + unroll_factor = min(4, D, D_hidden) + + s_expert.unroll(row_loop["k"], factor=unroll_factor) + print(f" - Applied unroll to k (reduction loop) factor={unroll_factor}") + + # ------------------------------------------------------------------------- + # Note: Removed explicit array partitioning + # Let HLS infer partitioning from pipelining directives for better tool runtime + # Explicit partitioning can explode synthesis time and may not be optimal + # ------------------------------------------------------------------------- + + print(" - Created expert schedule with pipeline and unroll optimizations") + print(" - Note: Array partitioning inferred by HLS from pipelining") + + # ========================================================================= + # Step 4: Create and optimize schedule for moe_layer + # Optimizations: + # - Pipeline on gate_linear and combine_outputs + # - Unroll on gate_reduce for parallel access + # - Compose optimized expert schedule + # ========================================================================= + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden]) + + # Get loop handles for moe_layer + moe_loops = s_moe.get_loops(s_moe.top_func_name) + print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") + + # ------------------------------------------------------------------------- + # Optimize gate_linear: pipeline and unroll + # Note: Allo flattens loops - gate_linear sub-loops are likely ['n', 'e', 'k'] + # ------------------------------------------------------------------------- + gate_linear_loop = moe_loops["gate_linear"] + print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") + + # 
Pipeline the e loop (output dimension) and unroll k (reduction) + s_moe.pipeline(gate_linear_loop["e"]) + print(" - Applied pipeline to gate_linear:e") + + unroll_factor_gate = min(4, D) + s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) + print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") + + # ------------------------------------------------------------------------- + # Optimize other loops in moe_layer + # Note: For dsl.grid loops, sub-loops use variable names (n, e, k, d_out, etc.) + # ------------------------------------------------------------------------- + combine_loop = moe_loops["combine_outputs"] + print(f" - combine_outputs sub-loops: {list(combine_loop.loops.keys())}") + s_moe.pipeline(combine_loop["d_out"]) + + topk_loop = moe_loops["topk_logits"] + print(f" - topk_logits sub-loops: {list(topk_loop.loops.keys())}") + s_moe.pipeline(topk_loop["k"]) + + sparse_loop = moe_loops["sparse_gate"] + print(f" - sparse_gate sub-loops: {list(sparse_loop.loops.keys())}") + s_moe.pipeline(sparse_loop["n"]) + print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") + + # ------------------------------------------------------------------------- + # Compose sub-function schedules + # ------------------------------------------------------------------------- + s_moe.compose(s_top1) + s_moe.compose(s_softmax_1d) + s_moe.compose(s_expert) + print(" - Composed top1_select, softmax_1d, and expert schedules") + print(" - Created moe_layer schedule with pipeline, unroll optimizations") + + # ========================================================================= + # Step 5: Create and optimize schedule for scaled_dot_product_attention + # ========================================================================= + print("\n[5] Creating schedule for custom scaled_dot_product_attention...") + s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D]) + + # Get loop handles for attention + 
attn_loops = s_attn.get_loops(s_attn.top_func_name) + print(f" - Available top-level loops: {list(attn_loops.loops.keys())}") + + # Helper function to print loop hierarchy recursively + def print_loop_hierarchy(loops, indent=0): + for key, handle in loops.loops.items(): + print(" " * indent + f"- {key}") + # Check if handle has children (is a loop wrapper) + if hasattr(handle, "loops"): + print_loop_hierarchy(handle, indent + 2) + + print(" - Full loop hierarchy:") + print_loop_hierarchy(attn_loops, indent=4) + + # Attention has nested loops inside head_loop + # The structure is: head_loop -> h -> [split_qkv, qkt_matmul, scale, softmax_*, sv_matmul, merge_heads] + head_inner = attn_loops["head_loop"] + # print(f" - head_loop sub-loops: {list(head_inner.loops.keys())}") + + # Get loop handles for sv_matmul and qkt_matmul BEFORE applying global pipeline + # Global pipeline using string selectors might modify loop structure/metadata + # sv_loops = head_inner["sv_matmul"] + # qkt_loops = head_inner["qkt_matmul"] + # print(f" - sv_matmul sub-loops: {list(sv_loops.loops.keys())}") + # print(f" - qkt_matmul sub-loops: {list(qkt_loops.loops.keys())}") + + # Pipeline reduction loops (k) for matmuls - this is critical for II=1 + # The reordered loops (i, k, j) allow better pipelining of k reduction + s_attn.pipeline( + "scaled_dot_product_attention:k" + ) # Pipeline k reduction loops (qkt_k, sv_k) + s_attn.pipeline( + "scaled_dot_product_attention:j" + ) # Pipeline j loops for output dimension + s_attn.pipeline( + "scaled_dot_product_attention:i" + ) # Pipeline i loops (sv_i) to overlap row processing + print(" - Applied pipeline to i (row), k (reduction), and j (output) loops") + print( + " - Note: Loop reordering (i,k,j) enables better memory access and pipelining" + ) + print(" - Pipelining sv_i allows overlapping processing of different output rows") + + # ------------------------------------------------------------------------- + # Apply unroll optimization to break 
loop-carried dependency + # Unroll k loop by factor of 4 to allow 4 multiplications in parallel + # ------------------------------------------------------------------------- + # print(" - Applying unroll optimization...") + # s_attn.unroll(sv_loops["k"], factor=4) + # print(" - Applied unroll to sv_matmul:k (factor=4)") + # s_attn.unroll(qkt_loops["k"], factor=4) + # print(" - Applied unroll to qkt_matmul:k (factor=4)") + + # ------------------------------------------------------------------------- + # Apply pipeline optimization to reduction loops + # ------------------------------------------------------------------------- + # print(" - Applying pipeline to reduction loops...") + # s_attn.pipeline(sv_loops["k"]) + # print(" - Applied pipeline to sv_matmul:k (reduction loop)") + # s_attn.pipeline(qkt_loops["k"]) + # print(" - Applied pipeline to qkt_matmul:k (reduction loop)") + + # ------------------------------------------------------------------------- + # Apply reorder optimization for better data locality + # For matrix multiplication C[i,j] += A[i,k] * B[k,j], reorder to i,k,j + # This improves memory access pattern for B (V_h in sv_matmul) + # ------------------------------------------------------------------------- + # print(" - Applying reorder optimization...") + # s_attn.reorder(sv_loops["k"], sv_loops["j"]) + # print(" - Applied reorder to sv_matmul: (i, j, k) -> (i, k, j)") + # s_attn.reorder(qkt_loops["k"], qkt_loops["j"]) + # print(" - Applied reorder to qkt_matmul: (i, j, k) -> (i, k, j)") + + # ------------------------------------------------------------------------- + # Apply buffer_at optimization to reduce memory access + # Creates on-chip buffer for intermediate results + # ------------------------------------------------------------------------- + # print(" - Applying buffer_at optimization...") + # s_attn.buffer_at(s_attn.C_h, axis=sv_loops["i"]) + # print(" - Applied buffer_at to C_h at sv_matmul:i") + # s_attn.buffer_at(s_attn.Y, 
axis=qkt_loops["i"]) + # print(" - Applied buffer_at to Y at qkt_matmul:i") + + print( + " - Created scaled_dot_product_attention schedule with pipeline optimizations" + ) + print(" - Matmul loops reordered to (i,k,j) for better memory access and II=1") + + # ========================================================================= + # Step 6: Create schedule for main attention_moe_layer function + # ========================================================================= + print("\n[6] Creating schedule for attention_moe_layer...") + s_attn_moe = allo.customize( + attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden] + ) + + # Get loop handles for attention_moe_layer + attn_moe_loops = s_attn_moe.get_loops(s_attn_moe.top_func_name) + print(f" - Available top-level loops: {list(attn_moe_loops.loops.keys())}") + + print(" - Full loop hierarchy for attention_moe_layer:") + print_loop_hierarchy(attn_moe_loops, indent=4) + + # Pipeline top-level loops using get_loops() + # Note: extract_qkv and store_output are inside batch_loop + batch_loop = attn_moe_loops["batch_loop"] + + # s_attn_moe.pipeline(batch_loop["extract_qkv"]["d"]) + # s_attn_moe.pipeline(batch_loop["store_output"]["d"]) + + # ========================================================================= + # Step 7: Compose all schedules together + # ========================================================================= + print("\n[7] Composing all schedules together...") + s_attn_moe.compose(s_attn) + s_attn_moe.compose(s_moe) + print(" - Composed scaled_dot_product_attention schedule") + print(" - Composed moe_layer schedule") + + print("\n" + "=" * 60) + print("Schedule composition complete with optimizations!") + print("=" * 60) + print("MoE Optimizations applied:") + print(" 1. Fused GeLU into FC1 (eliminates intermediate array)") + print(" 2. Row-level structure (FC1+GeLU -> FC2)") + print(" 3. 
Unrolled reduction loops (parallel MACs)") + print(f" - Expert k loop: unroll factor {min(4, D, D_hidden)}") + print(f" - Gate k loop: unroll factor {min(4, D)}") + print(" 4. Pipeline on h, o, e loops (output dimensions)") + print(" 5. Array partitioning inferred by HLS (not explicit)") + print("-" * 60) + print("Attention Optimizations (unchanged):") + print(" - Pipeline on innermost loops (j, k)") + print("=" * 60) + + return s_attn_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as torch_nn + from pytorch_attention_moe import AttentionMoE + + # get config + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] # D, must be divisible by num_heads + num_experts = moe_config["num_experts"] # E + k = moe_config["k"] # Top-k MoE + hidden_dim = moe_config["hidden_dim"] # D_hidden + + # Attention-specific parameter: num_heads + # Choose num_heads such that embed_dim is divisible by num_heads + # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim) + if embed_dim >= 768: + num_heads = 12 # Standard for BERT-base + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + # Ensure embed_dim is divisible by num_heads + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + seed = 42 + + print("=" * 60) + print("Attention + MoE Allo Implementation Test") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + # 
---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + # ---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch AttentionMoE layer + pytorch_model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim, + ) + pytorch_model.eval() + + # Initialize with Xavier uniform + for param in pytorch_model.parameters(): + if param.dim() > 1: + torch_nn.init.xavier_uniform_(param) + else: + torch_nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q_pt = torch.randn(batch_size, seq_len, embed_dim) + K_pt = torch.randn(batch_size, seq_len, embed_dim) + V_pt = torch.randn(batch_size, seq_len, embed_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + # ---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = ( + pytorch_model.moe.gate.gate_linear.weight.data + ) # [num_experts, embed_dim] + gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_model.moe.experts] + ) + 
expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_model.moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_model.moe.experts] + ) + + # ---------------------------------------------------------------------------------- + # Convert to numpy arrays + # ---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32) + K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32) + V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + + print(f"Q shape: {Q_np.shape}") + print(f"K shape: {K_np.shape}") + print(f"V shape: {V_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # ---------------------------------------------------------------------------------- + # Run Allo implementation + # ---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with 
composition + allo_schedule = optimize_attention_moe_with_composition( + batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_attention_moe_alt_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) + elif MODE == "hw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) + elif MODE == "hw": + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) + elif MODE == "csyn": + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod( + Q_np, + K_np, + V_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # ---------------------------------------------------------------------------------- + # Compare the Allo and PyTorch outputs + # 
---------------------------------------------------------------------------------- + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) + else: + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) + print("First few differences:") + print(diff.flatten()[:10]) + + # ---------------------------------------------------------------------------------- + # Print sample outputs for comparison + # ---------------------------------------------------------------------------------- + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/attention_moe/allo_attention_moe_base.py b/ece6775/attention_moe/allo_attention_moe_base.py new file mode 100644 index 000000000..2ad20c6ab --- /dev/null +++ b/ece6775/attention_moe/allo_attention_moe_base.py @@ -0,0 +1,743 @@ +# Copyright Allo authors. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0
# Attention + MoE in Allo (no library version)
# Q, K, V -> Attention -> MoE -> Output
# all manual implementation, no library functions

import numpy as np
import allo
from allo.ir.types import float32, int32
from allo import dsl
import sys
import os

# path setup: make the sibling llm_config/ package importable regardless of CWD
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
llm_config_dir = os.path.join(project_root, "llm_config")
if llm_config_dir not in sys.path:
    sys.path.insert(0, llm_config_dir)

# Build target selector consumed by the __main__ harness below.
MODE = "sw_emu"  # llvm, sw_emu, hw_emu, hw, csyn

from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info

CONFIG_MODE = DEFAULT_CONFIG_MODE


def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]":
    """
    Softmax over last dimension for [N, K] shape.

    Numerically stabilized by subtracting the per-row maximum before exp.

    Args:
        X: Input tensor of shape [N, K]
    Returns:
        Z: Output tensor of shape [N, K] with softmax applied over dimension K
    """
    Z: Ty[N, K]
    E_exp: Ty[N, K]
    # Large negative sentinel so any real input replaces it on first compare.
    M: Ty[N] = -1000000000000.0
    S: Ty[N] = 0.0

    # Find max for each row (over dimension K)
    for n, k in dsl.grid(N, K, name="row_max"):
        if X[n, k] > M[n]:
            M[n] = X[n, k]

    # Compute exp(x - rowmax) and accumulate the row sum
    for n, k in dsl.grid(N, K, name="exp_sum"):
        E_exp[n, k] = dsl.exp(X[n, k] - M[n])
        S[n] += E_exp[n, k]

    # Normalize
    for n, k in dsl.grid(N, K, name="update"):
        Z[n, k] = E_exp[n, k] / S[n]

    return Z


# ----------------------------------------------------------------------------------
# Scaled Dot-Product Attention: Custom implementation matching PyTorch
# ----------------------------------------------------------------------------------
def scaled_dot_product_attention[
    Ty, H, L, D
](Q: "Ty[L, D]", K: "Ty[L, D]", V: "Ty[L, D]") -> "Ty[L, D]":
    """
    Scaled Dot-Product Attention (Multi-Head Attention).

    This implementation matches the PyTorch version:
    - Input: Q, K, V of shape [L, D]
    - Split into H heads, each with dimension D // H
    - For each head: softmax(QK^T / sqrt(head_dim)) @ V
    - Merge heads back to [L, D]

    Args:
        Q: Query tensor of shape [L, D]
        K: Key tensor of shape [L, D]
        V: Value tensor of shape [L, D]

    Returns:
        Z: Output tensor of shape [L, D]
    """
    Z: Ty[L, D] = 0.0

    # Compute scale factor: 1 / sqrt(head_dim) = 1 / sqrt(D // H)
    scale: Ty = 1.0 / dsl.sqrt(float(D // H))

    # Process each head
    for h in range(H, name="head_loop"):
        # Split Q, K, V for this head (copy the h-th slice of width D // H)
        Q_h: Ty[L, D // H] = 0.0
        K_h: Ty[L, D // H] = 0.0
        V_h: Ty[L, D // H] = 0.0

        for i, j in dsl.grid(L, D // H, name="split_qkv"):
            Q_h[i, j] = Q[i, h * (D // H) + j]
            K_h[i, j] = K[i, h * (D // H) + j]
            V_h[i, j] = V[i, h * (D // H) + j]

        # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L]
        Y: Ty[L, L] = 0.0
        for i, j, k in dsl.grid(L, L, D // H, name="qkt_matmul"):
            Y[i, j] += Q_h[i, k] * K_h[j, k]  # K_h[j, k] is K_h^T[k, j]

        # Scale by 1/sqrt(head_dim)
        Y_scaled: Ty[L, L] = 0.0
        for i, j in dsl.grid(L, L, name="scale"):
            Y_scaled[i, j] = Y[i, j] * scale

        # Apply softmax over last dimension (inline implementation).
        # M/Sum are declared inside the head loop, so they are re-initialized
        # for every head iteration.
        S: Ty[L, L] = 0.0
        E_exp: Ty[L, L] = 0.0
        M: Ty[L] = -1000000000000.0
        Sum: Ty[L] = 0.0

        # Find max for each row (stabilizes exp below)
        for i, j in dsl.grid(L, L, name="softmax_max"):
            if Y_scaled[i, j] > M[i]:
                M[i] = Y_scaled[i, j]

        # Compute exp and row sum
        for i, j in dsl.grid(L, L, name="softmax_exp"):
            E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i])
            Sum[i] += E_exp[i, j]

        # Normalize
        for i, j in dsl.grid(L, L, name="softmax_norm"):
            S[i, j] = E_exp[i, j] / Sum[i]

        # Compute S @ V_h = [L, L] @ [L, D//H] = [L, D//H]
        C_h: Ty[L, D // H] = 0.0
        for i, j, k in dsl.grid(L, D // H, L, name="sv_matmul"):
            C_h[i, j] += S[i, k] * V_h[k, j]

        # Merge this head's output back into its slice of Z
        for i, j in dsl.grid(L, D // H, name="merge_heads"):
            Z[i, h * (D // H) + j] = C_h[i, j]

    return Z


# ----------------------------------------------------------------------------------
# Top1 selection: Top-k selection for k=1 (argmax)
# ----------------------------------------------------------------------------------
def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]":
    """
    Select top-1 expert (argmax) for each token.

    NOTE(review): not referenced by moe_layer / attention_moe_layer in this
    file — moe_layer re-implements the same argmax inline. Kept as a
    standalone helper.

    Args:
        logits: Input logits of shape [N, E]
    Returns:
        indices: Top-1 expert indices of shape [N]
    """
    indices: int32[N]
    max_val: Ty[N] = -1000000000000.0

    # Initialize indices and max_val with first expert
    for n in range(N, name="init"):
        indices[n] = 0
        max_val[n] = logits[n, 0]

    # Find argmax for each token (search from index 1 onwards)
    for n, e in dsl.grid(N, E, name="argmax"):
        if e > 0:  # Skip e=0 (already initialized)
            if logits[n, e] > max_val[n]:
                max_val[n] = logits[n, e]
                indices[n] = e

    return indices


# ----------------------------------------------------------------------------------
# MoE Layer: Main MoE layer - ALL MANUAL IMPLEMENTATION (NO LIBRARY FUNCTIONS)
# Based on moe_allo.py implementation
# ----------------------------------------------------------------------------------
def moe_layer[
    Ty, N, D_in, D_out, E, K, D_hidden
](
    x: "Ty[N, D_in]",
    # Gate weights
    gate_weight: "Ty[E, D_in]",
    gate_bias: "Ty[E]",
    # Expert weights (E experts, each with 2 linear layers)
    expert_fc1_weights: "Ty[E, D_hidden, D_in]",
    expert_fc1_biases: "Ty[E, D_hidden]",
    expert_fc2_weights: "Ty[E, D_out, D_hidden]",
    expert_fc2_biases: "Ty[E, D_out]",
) -> "Ty[N, D_out]":
    """
    Mixture of Experts layer for inference.
    ALL MANUAL IMPLEMENTATION - NO LIBRARY FUNCTIONS.

    NOTE(review): every expert is evaluated for every token and the results
    are masked by the (sparse) gate weights in Step 6 — dense compute with
    sparse combine, which suits static HLS dataflow but is not a
    sparse-dispatch MoE.

    Args:
        x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in])
        gate_weight: Gate weight matrix of shape [E, D_in]
        gate_bias: Gate bias vector of shape [E]
        expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in]
        expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden]
        expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden]
        expert_fc2_biases: Expert FC2 biases of shape [E, D_out]
    Returns:
        output: Output tensor of shape [N, D_out]
    """
    # =========================================================================
    # Step 1: Compute gate logits using MANUAL linear layer
    # gate_logits = x @ gate_weight^T + gate_bias
    # =========================================================================
    gate_logits: Ty[N, E] = 0.0
    for i, j in dsl.grid(N, E, name="gate_linear"):
        # Initialize with bias
        gate_logits[i, j] = gate_bias[j]
        # Matrix multiplication: x[i] @ gate_weight[j]^T
        for k in range(D_in):
            gate_logits[i, j] += x[i, k] * gate_weight[j, k]

    # =========================================================================
    # Step 2: Select top-1 expert (argmax) for each token
    # =========================================================================
    top1_indices: int32[N]
    max_logit: Ty[N] = -1000000000000.0

    for n in range(N, name="top1_init"):
        top1_indices[n] = 0
        max_logit[n] = gate_logits[n, 0]

    for n, e in dsl.grid(N, E, name="top1_argmax"):
        if e > 0:
            if gate_logits[n, e] > max_logit[n]:
                max_logit[n] = gate_logits[n, e]
                top1_indices[n] = e

    # =========================================================================
    # Step 3: Get top-k logits and apply softmax
    # For k=1, softmax just returns 1.0, but we compute for consistency
    # NOTE(review): only top-1 routing is actually supported here — for any
    # slot k > 0 the expression below falls back to expert 0's logit, so
    # results are only correct when K == 1. Confirm K == 1 in the config.
    # =========================================================================
    top_k_logits: Ty[N, K] = 0.0
    for n, k in dsl.grid(N, K, name="topk_logits"):
        expert_idx: int32 = top1_indices[n] if k == 0 else 0
        top_k_logits[n, k] = gate_logits[n, expert_idx]

    # Apply softmax to top-k logits (inline implementation)
    top_k_weights: Ty[N, K] = 0.0
    softmax_max: Ty[N] = -1000000000000.0
    softmax_sum: Ty[N] = 0.0
    softmax_exp: Ty[N, K] = 0.0

    for n, k in dsl.grid(N, K, name="softmax_max"):
        if top_k_logits[n, k] > softmax_max[n]:
            softmax_max[n] = top_k_logits[n, k]

    for n, k in dsl.grid(N, K, name="softmax_exp"):
        softmax_exp[n, k] = dsl.exp(top_k_logits[n, k] - softmax_max[n])
        softmax_sum[n] += softmax_exp[n, k]

    for n, k in dsl.grid(N, K, name="softmax_norm"):
        top_k_weights[n, k] = softmax_exp[n, k] / softmax_sum[n]

    # =========================================================================
    # Step 4: Create sparse weight matrix from top-k weights
    # (one non-zero column per token: the selected expert)
    # =========================================================================
    gate_weights: Ty[N, E] = 0.0
    for n in range(N, name="gate_weights"):
        expert_idx: int32 = top1_indices[n]
        gate_weights[n, expert_idx] = top_k_weights[n, 0]  # For k=1

    # =========================================================================
    # Step 5: Process each expert - ALL MANUAL (NO LIBRARY FUNCTIONS)
    # =========================================================================
    expert_outputs: Ty[E, N, D_out] = 0.0

    for e in range(E, name="expert_loop"):
        # Extract expert weights for this expert into local buffers
        expert_fc1_w: Ty[D_hidden, D_in] = 0.0
        expert_fc1_b: Ty[D_hidden] = 0.0
        expert_fc2_w: Ty[D_out, D_hidden] = 0.0
        expert_fc2_b: Ty[D_out] = 0.0

        for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"):
            expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in]

        for d_hidden in range(D_hidden, name="extract_fc1_b"):
            expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden]

        for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"):
            expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden]

        for d_out in range(D_out, name="extract_fc2_b"):
            expert_fc2_b[d_out] = expert_fc2_biases[e, d_out]

        # ---------------------------------------------------------------------
        # Expert forward pass - MANUAL IMPLEMENTATION
        # ---------------------------------------------------------------------

        # FC1: fc1_out = x @ fc1_weight^T + fc1_bias
        fc1_out: Ty[N, D_hidden] = 0.0
        for i, j in dsl.grid(N, D_hidden, name="fc1_linear"):
            fc1_out[i, j] = expert_fc1_b[j]
            for k in range(D_in):
                fc1_out[i, j] += x[i, k] * expert_fc1_w[j, k]

        # GELU activation (tanh approximation):
        # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
        gelu_out: Ty[N, D_hidden] = 0.0
        for i, j in dsl.grid(N, D_hidden, name="gelu"):
            x_val: Ty = fc1_out[i, j]
            x3: Ty = x_val * x_val * x_val
            # sqrt(2/pi) ≈ 0.7978845608028654
            inner: Ty = 0.7978845608028654 * (x_val + 0.044715 * x3)
            tanh_term: Ty = dsl.tanh(inner)
            gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term)

        # FC2: fc2_out = gelu_out @ fc2_weight^T + fc2_bias
        expert_out: Ty[N, D_out] = 0.0
        for i, j in dsl.grid(N, D_out, name="fc2_linear"):
            expert_out[i, j] = expert_fc2_b[j]
            for k in range(D_hidden):
                expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k]

        # Store expert outputs
        for n, d_out in dsl.grid(N, D_out, name="store_expert_out"):
            expert_outputs[e, n, d_out] = expert_out[n, d_out]

    # =========================================================================
    # Step 6: Combine expert outputs using gate weights
    # (gate_weights is zero except at each token's selected expert)
    # =========================================================================
    output: Ty[N, D_out] = 0.0
    for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"):
        weight: Ty = gate_weights[n, e]
        output[n, d_out] += expert_outputs[e, n, d_out] * weight

    return output


# ----------------------------------------------------------------------------------
# Attention + MoE Layer: Combined layer
# Data flow: Q, K, V -> Attention -> MoE -> Output
# ----------------------------------------------------------------------------------
def attention_moe_layer[
    Ty, B, L, D, H, E, TopK, D_hidden
](
    Query: "Ty[B, L, D]",
    Key: "Ty[B, L, D]",
    Value: "Ty[B, L, D]",
    # Gate weights
    gate_weight: "Ty[E, D]",
    gate_bias: "Ty[E]",
    # Expert weights (E experts, each with 2 linear layers)
    expert_fc1_weights: "Ty[E, D_hidden, D]",
    expert_fc1_biases: "Ty[E, D_hidden]",
    expert_fc2_weights: "Ty[E, D, D_hidden]",
    expert_fc2_biases: "Ty[E, D]",
) -> "Ty[B, L, D]":
    """
    Combined Attention + MoE layer.

    Data flow: Q, K, V -> Attention -> MoE -> Output

    Args:
        Query: Query tensor of shape [B, L, D]
        Key: Key tensor of shape [B, L, D]
        Value: Value tensor of shape [B, L, D]
        gate_weight: Gate weight matrix of shape [E, D]
        gate_bias: Gate bias vector of shape [E]
        expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D]
        expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden]
        expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden]
        expert_fc2_biases: Expert FC2 biases of shape [E, D]
    Returns:
        output: Output tensor of shape [B, L, D]
    """
    # Output tensor
    output: Ty[B, L, D] = 0.0

    # Process each batch item independently
    for b in range(B, name="batch_loop"):
        # Step 1: Extract Q, K, V for this batch item -> [L, D]
        Q_b: Ty[L, D] = 0.0
        K_b: Ty[L, D] = 0.0
        V_b: Ty[L, D] = 0.0

        for l, d in dsl.grid(L, D, name="extract_qkv"):
            Q_b[l, d] = Query[b, l, d]
            K_b[l, d] = Key[b, l, d]
            V_b[l, d] = Value[b, l, d]

        # Step 2: Apply custom scaled_dot_product_attention
        attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b)  # [L, D]

        # Step 3: Apply MoE layer (all manual implementation).
        # MoE is instantiated with N = L (per-batch token count) and
        # D_in = D_out = D so the residual shapes line up.
        moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden](
            attn_out,
            gate_weight,
            gate_bias,
            expert_fc1_weights,
            expert_fc1_biases,
            expert_fc2_weights,
            expert_fc2_biases,
        )  # [L, D]

        # Step 4: Store output for this batch item
        for l, d in dsl.grid(L, D, name="store_output"):
            output[b, l, d] = moe_out[l, d]

    return output
# ==================================================================================
# Schedule optimization function - NO LIBRARY FUNCTIONS
# ==================================================================================
def optimize_attention_moe_with_composition(
    batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim
):
    """
    Create optimized schedules for Attention + MoE and compose them together.
    NO LIBRARY FUNCTIONS - all manual implementation.

    Args:
        batch_size: Batch size (B)
        seq_len: Sequence length (L)
        embed_dim: Embedding dimension (D)
        num_heads: Number of attention heads (H)
        num_experts: Number of experts (E)
        k: Top-k value (currently only k=1 is supported)
        hidden_dim: Hidden dimension for experts

    Returns:
        s_attn_moe: Optimized schedule for attention_moe_layer with all sub-schedules composed
    """
    # Short aliases matching the type-parameter names used at definition site.
    Ty = float32
    B = batch_size
    L = seq_len
    D = embed_dim
    H = num_heads
    E = num_experts
    K = k
    D_hidden = hidden_dim

    print("=" * 60)
    print("Creating and optimizing Attention + MoE schedules (NO LIBRARY)...")
    print("=" * 60)

    # Step 1: Create schedule for custom attention
    print("\n[1] Creating schedule for custom scaled_dot_product_attention...")
    s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D])
    print(" - Created scaled_dot_product_attention schedule")

    # Step 2: Create schedule for moe_layer (all manual).
    # Instantiated with N = L and D_in = D_out = D, mirroring the call site
    # inside attention_moe_layer.
    print("\n[2] Creating schedule for moe_layer (manual implementation)...")
    s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden])
    print(" - Created moe_layer schedule (no library functions)")

    # Step 3: Create schedule for main attention_moe_layer function
    print("\n[3] Creating schedule for attention_moe_layer...")
    s_attn_moe = allo.customize(
        attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden]
    )

    # Step 4: Compose the sub-schedules into the top-level schedule
    print("\n[4] Composing all schedules together...")
    s_attn_moe.compose(s_attn)
    s_attn_moe.compose(s_moe)
    print(" - Composed scaled_dot_product_attention schedule")
    print(" - Composed moe_layer schedule")

    print("\n" + "=" * 60)
    print("Schedule composition complete (NO LIBRARY FUNCTIONS)!")
    print("=" * 60)

    return s_attn_moe


# ==================================================================================
# Test function to compare Allo and PyTorch implementations
# ==================================================================================

if __name__ == "__main__":
    import torch
    import torch.nn as torch_nn
    from pytorch_attention_moe import AttentionMoE

    # ============================================================================
    # Configuration parameters - use shared config from llm_config.py
    # ============================================================================
    moe_config = get_moe_config(CONFIG_MODE)

    batch_size = moe_config["batch_size"]
    seq_len = moe_config["seq_len"]
    embed_dim = moe_config["input_dim"]  # D, must be divisible by num_heads
    num_experts = moe_config["num_experts"]  # E
    k = moe_config["k"]  # Top-k MoE
    hidden_dim = moe_config["hidden_dim"]  # D_hidden

    # Attention-specific parameter: num_heads
    # NOTE(review): the 512/256 branches and the 128/64 branches assign the
    # same values — the ladder could be collapsed, kept as-is for clarity.
    if embed_dim >= 768:
        num_heads = 12
    elif embed_dim >= 512:
        num_heads = 8
    elif embed_dim >= 256:
        num_heads = 8
    elif embed_dim >= 128:
        num_heads = 4
    elif embed_dim >= 64:
        num_heads = 4
    else:
        num_heads = 2

    # Ensure embed_dim is divisible by num_heads (decrement until it divides)
    while embed_dim % num_heads != 0 and num_heads > 1:
        num_heads -= 1

    seed = 42

    print("=" * 60)
    print("Attention + MoE Allo Implementation Test (NO LIBRARY)")
    print("=" * 60)
    print(f"Configuration Mode: {CONFIG_MODE}")
    print(f"Configuration:")
    print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}")
    print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}")
    print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}")
    print(f" Seed: {seed}")
    print("=" * 60)

    # ----------------------------------------------------------------------------------
    # Run PyTorch implementation to get weights and outputs
    # ----------------------------------------------------------------------------------
    print("\n[1] Running PyTorch implementation...")
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create PyTorch AttentionMoE layer
    pytorch_model = AttentionMoE(
        seq_len=seq_len,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_experts=num_experts,
        k=k,
        expert_hidden_dim=hidden_dim,
    )
    pytorch_model.eval()

    # Initialize with Xavier uniform (weights) / zeros (biases) so the
    # reference run is deterministic under the fixed seed.
    for param in pytorch_model.parameters():
        if param.dim() > 1:
            torch_nn.init.xavier_uniform_(param)
        else:
            torch_nn.init.zeros_(param)

    # Create random inputs (Q, K, V); re-seed so inputs are independent of
    # how many parameters the model consumed from the RNG.
    torch.manual_seed(seed)
    Q_pt = torch.randn(batch_size, seq_len, embed_dim)
    K_pt = torch.randn(batch_size, seq_len, embed_dim)
    V_pt = torch.randn(batch_size, seq_len, embed_dim)

    # Run PyTorch inference
    with torch.no_grad():
        pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False)

    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(
        f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]"
    )

    # ----------------------------------------------------------------------------------
    # Extract weights and biases from PyTorch model
    # ----------------------------------------------------------------------------------
    print("\n[2] Extracting weights from PyTorch model...")

    # Gate weights ([num_experts, embed_dim]); bias may be absent
    gate_weight_pt = pytorch_model.moe.gate.gate_linear.weight.data
    gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias
    if gate_bias_pt is not None:
        gate_bias_pt = gate_bias_pt.data
    else:
        gate_bias_pt = torch.zeros(num_experts)

    # Expert weights, stacked along a leading expert dimension
    expert_fc1_weights_pt = torch.stack(
        [exp.fc1.weight.data for exp in pytorch_model.moe.experts]
    )
    expert_fc1_biases_pt = torch.stack(
        [exp.fc1.bias.data for exp in pytorch_model.moe.experts]
    )
    expert_fc2_weights_pt = torch.stack(
        [exp.fc2.weight.data for exp in pytorch_model.moe.experts]
    )
    expert_fc2_biases_pt = torch.stack(
        [exp.fc2.bias.data for exp in pytorch_model.moe.experts]
    )

    # ----------------------------------------------------------------------------------
    # Convert to numpy arrays (contiguous float32, as the Allo kernel expects)
    # ----------------------------------------------------------------------------------
    print("\n[3] Converting weights to numpy arrays...")
    Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32)
    K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32)
    V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32)

    gate_weight_np = np.ascontiguousarray(
        gate_weight_pt.detach().numpy(), dtype=np.float32
    )
    gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32)

    expert_fc1_weights_np = np.ascontiguousarray(
        expert_fc1_weights_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc1_biases_np = np.ascontiguousarray(
        expert_fc1_biases_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc2_weights_np = np.ascontiguousarray(
        expert_fc2_weights_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc2_biases_np = np.ascontiguousarray(
        expert_fc2_biases_pt.detach().numpy(), dtype=np.float32
    )

    print(f"Q shape: {Q_np.shape}")
    print(f"K shape: {K_np.shape}")
    print(f"V shape: {V_np.shape}")
    print(f"Gate weight shape: {gate_weight_np.shape}")
    print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}")
    print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}")

    # ----------------------------------------------------------------------------------
    # Run Allo implementation
    # ----------------------------------------------------------------------------------
    print("\n[4] Running Allo implementation...")
    try:
        # Create optimized schedule with composition
        allo_schedule = optimize_attention_moe_with_composition(
            batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim
        )

        # Generate project name
        project_name = f"allo_attention_moe_base_{CONFIG_MODE}.prj"
        print(f"Using project name: {project_name}")

        # Build module for the target selected by the module-level MODE flag
        print("\n[5] Building Allo module...")
        if MODE == "llvm":
            mod = allo_schedule.build(target="llvm")
        elif MODE == "sw_emu":
            mod = allo_schedule.build(
                target="vitis_hls", mode="sw_emu", project=project_name
            )
        elif MODE == "hw_emu":
            mod = allo_schedule.build(
                target="vitis_hls", mode="hw_emu", project=project_name
            )
        elif MODE == "hw":
            mod = allo_schedule.build(
                target="vitis_hls", mode="hw", project=project_name
            )
        elif MODE == "csyn":
            mod = allo_schedule.build(
                target="vitis_hls", mode="csyn", project=project_name
            )
        else:
            raise ValueError(f"Unsupported mode: {MODE}")

        # Run Allo inference. llvm returns the output; the vitis_hls emulation
        # targets fill a caller-provided output buffer instead.
        print("\n[6] Running Allo inference...")
        if MODE == "llvm":
            allo_output = mod(
                Q_np,
                K_np,
                V_np,
                gate_weight_np,
                gate_bias_np,
                expert_fc1_weights_np,
                expert_fc1_biases_np,
                expert_fc2_weights_np,
                expert_fc2_biases_np,
            )
        elif MODE in ["sw_emu", "hw_emu", "hw"]:
            allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32)
            mod(
                Q_np,
                K_np,
                V_np,
                gate_weight_np,
                gate_bias_np,
                expert_fc1_weights_np,
                expert_fc1_biases_np,
                expert_fc2_weights_np,
                expert_fc2_biases_np,
                allo_output,
            )
        elif MODE == "csyn":
            # NOTE(review): csyn runs HLS synthesis only; mod() takes no data,
            # so allo_output stays all zeros and the comparison below is not
            # meaningful in this mode.
            allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32)
            mod()
        else:
            raise ValueError(f"Unsupported mode: {MODE}")

        print(f"Allo output shape: {allo_output.shape}")
        print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]")

        # ----------------------------------------------------------------------------------
        # Compare the Allo and PyTorch outputs
        # ----------------------------------------------------------------------------------
        print("\n[7] Comparing outputs...")
        pytorch_output_np = pytorch_output.detach().numpy()

        # Compute differences
        diff = np.abs(allo_output - pytorch_output_np)
        mean_diff = np.mean(diff)
        max_diff = np.max(diff)
        rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8))

        print(f"Mean absolute difference: {mean_diff:.6e}")
        print(f"Max absolute difference: {max_diff:.6e}")
        print(f"Mean relative difference: {rel_diff:.6e}")

        # Check if outputs are close (tolerances account for float32
        # accumulation-order differences between Allo and PyTorch)
        atol = 5e-4
        rtol = 2e-3
        is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol)

        if is_close:
            print(
                f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})"
            )
        else:
            print(
                f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})"
            )
            print("First few differences:")
            print(diff.flatten()[:10])

        # ----------------------------------------------------------------------------------
        # Print sample outputs for comparison
        # ----------------------------------------------------------------------------------
        print("\n[8] Sample outputs (first token, first 5 dimensions):")
        print(f"PyTorch: {pytorch_output_np[0, 0, :5]}")
        print(f"Allo: {allo_output[0, 0, :5]}")
        print(f"Diff: {diff[0, 0, :5]}")

    except Exception as e:
        # Best-effort harness: report the failure and keep going so the
        # closing banner still prints.
        print(f"\n✗ ERROR: Failed to run Allo implementation: {e}")
        import traceback

        traceback.print_exc()

    print("=" * 60)
# Copyright Allo authors. All Rights Reserved.
def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]":
    """
    Numerically stable softmax over the last dimension of an [N, K] tensor.

    Implemented as three separate dsl.grid passes (row max, exp+sum,
    normalize) so Allo can schedule each stage as its own hardware loop nest.

    Args:
        X: Input tensor of shape [N, K]
    Returns:
        Z: Output tensor of shape [N, K] with softmax applied over dimension K
    """
    Z: Ty[N, K]
    E_exp: Ty[N, K]
    # Large negative sentinel stands in for -inf when scanning for the max.
    M: Ty[N] = -1000000000000.0
    S: Ty[N] = 0.0

    # Pass 1: row-wise max; subtracted below to keep exp() from overflowing.
    for n, k in dsl.grid(N, K, name="row_max"):
        if X[n, k] > M[n]:
            M[n] = X[n, k]

    # Pass 2: exponentiate the shifted values and accumulate each row's sum.
    for n, k in dsl.grid(N, K, name="exp_sum"):
        E_exp[n, k] = dsl.exp(X[n, k] - M[n])
        S[n] += E_exp[n, k]

    # Pass 3: normalize each row by its sum.
    for n, k in dsl.grid(N, K, name="update"):
        # NOTE(review): no epsilon is actually added here. The division is
        # still safe: the row-max entry contributes exp(0) = 1, so S[n] >= 1.
        Z[n, k] = E_exp[n, k] / (S[n])

    return Z
def softmax_2d[Ty, L](X: "Ty[L, L]") -> "Ty[L, L]":
    """
    Numerically stable softmax over the last dimension of an [L, L] tensor.

    Used for square attention-score matrices. Same three-pass structure as
    softmax_1d (row max, exp+sum, normalize) so each stage maps to its own
    hardware loop nest.

    Args:
        X: Input tensor of shape [L, L]
    Returns:
        Z: Output tensor of shape [L, L] with softmax applied over last dimension
    """
    Z: Ty[L, L]
    E_exp: Ty[L, L]
    # Large negative sentinel stands in for -inf when scanning for the max.
    M: Ty[L] = -1000000000000.0
    S: Ty[L] = 0.0

    # Pass 1: row-wise max for numerical stability.
    for i, j in dsl.grid(L, L, name="row_max"):
        if X[i, j] > M[i]:
            M[i] = X[i, j]

    # Pass 2: exponentiate shifted values and accumulate row sums.
    for i, j in dsl.grid(L, L, name="exp_sum"):
        E_exp[i, j] = dsl.exp(X[i, j] - M[i])
        S[i] += E_exp[i, j]

    # Pass 3: normalize each row by its sum (S[i] >= 1, see exp of row max).
    for i, j in dsl.grid(L, L, name="update"):
        Z[i, j] = E_exp[i, j] / S[i]

    return Z
Store K normally (not transposed) + V_h[i, j] = V[i, h * (D // H) + j] + + # Compute QK^T = [L, D//H] @ [D//H, L] = [L, L] + # Using linear2d: Q_h @ K_h^T = linear2d(Q_h, K_h_transposed, 0) + # linear2d computes X @ W^T + b, so we need K_h^T as W + # K_h is [L, D//H], so K_h^T is [D//H, L], but linear2d expects W[N, K] = [L, D//H] + # So we need to transpose K_h to [D//H, L], then use it as W[L, D//H] (which is K_h^T) + # Actually: Q_h[L, D//H] @ K_h^T[D//H, L] = Q_h @ (K_h^T) + # linear2d(X[M, K], W[N, K], b) computes X @ W^T, so: + # - X = Q_h [L, D//H] -> M=L, K=D//H + # - W should be [L, D//H] to get W^T = [D//H, L] + # - So W = K_h^T, but we have K_h [L, D//H], so we need to transpose it + K_h_T: Ty[D // H, L] = 0.0 + for i, j in dsl.grid(L, D // H, name="transpose_K"): + K_h_T[j, i] = K_h[i, j] + + # Now use linear2d: Q_h[L, D//H] @ K_h_T^T[D//H, L] = Q_h @ K_h_T^T + # But linear2d expects W[N, K], so we need W[L, D//H] such that W^T = [D//H, L] + # Actually, we can use K_h directly as W[L, D//H], then W^T = [D//H, L] = K_h^T + # Wait, let me reconsider: linear2d(X[M, K], W[N, K], b) = X @ W^T + # For QK^T = Q_h[L, D//H] @ K_h^T[D//H, L]: + # - X = Q_h [L, D//H] -> M=L, K=D//H + # - We want result [L, L], so N=L + # - W should be [L, D//H] such that W^T = [D//H, L] = K_h^T + # - So W = K_h (which is [L, D//H]), then W^T = [D//H, L] = K_h^T ✓ + zero_bias: Ty[L] = 0.0 + Y = linear2d[Ty, Ty, Ty, L, L, D // H](Q_h, K_h, zero_bias) # Q_h @ K_h^T + + # Scale by 1/sqrt(head_dim) + Y_scaled: Ty[L, L] = 0.0 + for i, j in dsl.grid(L, L, name="scale"): + Y_scaled[i, j] = Y[i, j] * scale + + # Apply softmax over last dimension + S: Ty[L, L] = 0.0 + E_exp: Ty[L, L] = 0.0 + M: Ty[L] = -1000000000000.0 + Sum: Ty[L] = 0.0 + + # Find max for each row + for i, j in dsl.grid(L, L, name="softmax_max"): + if Y_scaled[i, j] > M[i]: + M[i] = Y_scaled[i, j] + + # Compute exp and sum + for i, j in dsl.grid(L, L, name="softmax_exp"): + E_exp[i, j] = dsl.exp(Y_scaled[i, j] - M[i]) + 
def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]":
    """
    Select the top-1 expert (argmax over the E logits) for each of N tokens.

    Two-phase structure keeps the init separate from the scan so both map
    cleanly to hardware loops.

    Args:
        logits: Input logits of shape [N, E]
    Returns:
        indices: Top-1 expert indices of shape [N]
    """
    indices: int32[N]
    # Sentinel; immediately overwritten with logits[n, 0] in the init loop.
    max_val: Ty[N] = -1000000000000.0

    # Phase 1: seed the running max/argmax with expert 0.
    for n in range(N, name="init"):
        indices[n] = 0
        max_val[n] = logits[n, 0]

    # Phase 2: scan experts 1..E-1, keeping the first (lowest-index) max
    # because only a strictly greater logit replaces the current winner.
    for n, e in dsl.grid(N, E, name="argmax"):
        if e > 0:  # Skip e=0 (already initialized)
            if logits[n, e] > max_val[n]:
                max_val[n] = logits[n, e]
                indices[n] = e

    return indices
def expert[
    Ty, N, D_in, D_hidden, D_out
](
    x: "Ty[N, D_in]",
    fc1_weight: "Ty[D_hidden, D_in]",
    fc1_bias: "Ty[D_hidden]",
    fc2_weight: "Ty[D_out, D_hidden]",
    fc2_bias: "Ty[D_out]",
) -> "Ty[N, D_out]":
    """
    A simple feed-forward expert network: Linear -> GeLU -> Linear.

    Built from the Allo library kernels linear2d and GeLU so that schedules
    for those kernels can be composed into this one (see the expert schedule
    composition in optimize_attention_moe_with_composition).

    Note: linear2d computes x @ W^T + b, matching PyTorch nn.Linear's
    weight layout [out_features, in_features].

    Args:
        x: Input tensor of shape [N, D_in]
        fc1_weight: First linear layer weights of shape [D_hidden, D_in]
        fc1_bias: First linear layer bias of shape [D_hidden]
        fc2_weight: Second linear layer weights of shape [D_out, D_hidden]
        fc2_bias: Second linear layer bias of shape [D_out]
    Returns:
        output: Output tensor of shape [N, D_out]
    """
    # Step 1: First linear layer using linear2d from library
    fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias)

    # Step 2: GELU activation using GeLU from library
    gelu_out = GeLU[Ty, N, D_hidden](fc1_out)

    # Step 3: Second linear layer using linear2d from library
    fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias)

    return fc2_out
def moe_layer[
    Ty, N, D_in, D_out, E, K, D_hidden
](
    x: "Ty[N, D_in]",
    # Gate weights
    gate_weight: "Ty[E, D_in]",
    gate_bias: "Ty[E]",
    # Expert weights (E experts, each with 2 linear layers)
    expert_fc1_weights: "Ty[E, D_hidden, D_in]",
    expert_fc1_biases: "Ty[E, D_hidden]",
    expert_fc2_weights: "Ty[E, D_out, D_hidden]",
    expert_fc2_biases: "Ty[E, D_out]",
) -> "Ty[N, D_out]":
    """
    Mixture of Experts layer for inference.

    Strategy: dense compute, sparse combine. EVERY expert runs over ALL N
    tokens (Step 6); the routing decision only enters through the sparse
    gate_weights matrix used in the final weighted sum (Step 7). This wastes
    FLOPs vs. true token dispatch but gives a static, data-independent loop
    structure that is friendly to HLS.

    NOTE(review): the routing logic is hard-wired for top-1. Step 3's index
    expression (`... if k == 0 else 0`) and Step 5's use of column 0 only
    are correct solely when K == 1 — confirm before instantiating with K > 1.

    Args:
        x: Input tensor of shape [N, D_in] (already flattened from [B, L, D_in])
        gate_weight: Gate weight matrix of shape [E, D_in]
        gate_bias: Gate bias vector of shape [E]
        expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D_in]
        expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden]
        expert_fc2_weights: Expert FC2 weights of shape [E, D_out, D_hidden]
        expert_fc2_biases: Expert FC2 biases of shape [E, D_out]
    Returns:
        output: Output tensor of shape [N, D_out]
    """
    # Step 1: Compute gate logits using linear2d
    gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in](x, gate_weight, gate_bias)  # [N, E]

    # Step 2: Select top-1 expert using top1_select function
    top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits)

    # Step 3: Gather the selected expert's logit per token (top-1 only).
    top_k_logits: Ty[N, K] = 0.0
    for n, k in dsl.grid(N, K, name="topk_logits"):
        expert_idx = top1_indices_1d[n] if k == 0 else 0  # For k=1, K=1
        top_k_logits[n, k] = gate_logits[n, expert_idx]

    # Step 4: Softmax over the top-k logits. For K == 1 this always yields
    # weight 1.0, matching PyTorch's softmax over a single element.
    top_k_weights = softmax_1d[Ty, N, K](top_k_logits)  # [N, K]

    # Step 5: Scatter the normalized weights into a dense-but-sparse [N, E]
    # matrix (zeros for non-selected experts).
    top_k_indices: int32[N, K] = 0
    gate_weights: Ty[N, E] = 0.0

    for n in range(N, name="gate_weights"):
        # Store top-k indices (for k=1)
        top_k_indices[n, 0] = top1_indices_1d[n]
        # Set gate weights from softmax output
        expert_idx = top1_indices_1d[n]
        gate_weights[n, expert_idx] = top_k_weights[n, 0]  # For k=1, K=1

    # Step 6: Run every expert over all tokens (dense compute).
    expert_outputs: Ty[E, N, D_out] = 0.0

    for e in range(E, name="expert_loop"):
        # Copy this expert's weights into local buffers so the shared
        # `expert` kernel sees fixed-shape 2D/1D arguments.
        expert_fc1_w: Ty[D_hidden, D_in] = 0.0
        expert_fc1_b: Ty[D_hidden] = 0.0
        expert_fc2_w: Ty[D_out, D_hidden] = 0.0
        expert_fc2_b: Ty[D_out] = 0.0

        for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"):
            expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in]

        for d_hidden in range(D_hidden, name="extract_fc1_b"):
            expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden]

        for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"):
            expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden]

        for d_out in range(D_out, name="extract_fc2_b"):
            expert_fc2_b[d_out] = expert_fc2_biases[e, d_out]

        # Process all tokens through this expert using the expert function
        expert_out = expert[Ty, N, D_in, D_hidden, D_out](
            x, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b
        )  # [N, D_out]

        # Store expert outputs
        for n, d_out in dsl.grid(N, D_out, name="store_expert_out"):
            expert_outputs[e, n, d_out] = expert_out[n, d_out]

    # Step 7: Sparse combine — only the selected expert has nonzero weight.
    output: Ty[N, D_out] = 0.0
    for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"):
        weight: Ty = gate_weights[n, e]
        output[n, d_out] += expert_outputs[e, n, d_out] * weight

    return output
def attention_moe_layer[
    Ty, B, L, D, H, E, TopK, D_hidden
](
    Query: "Ty[B, L, D]",
    Key: "Ty[B, L, D]",
    Value: "Ty[B, L, D]",
    # Gate weights
    gate_weight: "Ty[E, D]",
    gate_bias: "Ty[E]",
    # Expert weights (E experts, each with 2 linear layers)
    expert_fc1_weights: "Ty[E, D_hidden, D]",
    expert_fc1_biases: "Ty[E, D_hidden]",
    expert_fc2_weights: "Ty[E, D, D_hidden]",
    expert_fc2_biases: "Ty[E, D]",
) -> "Ty[B, L, D]":
    """
    Combined Attention + MoE layer (top-level kernel).

    Data flow: Q, K, V -> Attention -> MoE -> Output.
    Batch items are processed independently; within one batch item the
    L tokens of the attention output become the N = L tokens routed by
    the MoE layer.

    Args:
        Query: Query tensor of shape [B, L, D]
        Key: Key tensor of shape [B, L, D]
        Value: Value tensor of shape [B, L, D]
        gate_weight: Gate weight matrix of shape [E, D]
        gate_bias: Gate bias vector of shape [E]
        expert_fc1_weights: Expert FC1 weights of shape [E, D_hidden, D]
        expert_fc1_biases: Expert FC1 biases of shape [E, D_hidden]
        expert_fc2_weights: Expert FC2 weights of shape [E, D, D_hidden]
        expert_fc2_biases: Expert FC2 biases of shape [E, D]
    Returns:
        output: Output tensor of shape [B, L, D]
    """
    # Output tensor
    output: Ty[B, L, D] = 0.0

    # Process each batch item
    for b in range(B, name="batch_loop"):
        # Step 1: Copy this batch item's Q, K, V into 2D buffers so the
        # attention kernel takes fixed [L, D] arguments.
        Q_b: Ty[L, D] = 0.0
        K_b: Ty[L, D] = 0.0
        V_b: Ty[L, D] = 0.0

        for l, d in dsl.grid(L, D, name="extract_qkv"):
            Q_b[l, d] = Query[b, l, d]
            K_b[l, d] = Key[b, l, d]
            V_b[l, d] = Value[b, l, d]

        # Step 2: Multi-head attention -> [L, D]
        attn_out = scaled_dot_product_attention[Ty, H, L, D](Q_b, K_b, V_b)  # [L, D]

        # Step 3: MoE over the attention output; note D_in == D_out == D here,
        # and N == L (one batch item at a time).
        moe_out = moe_layer[Ty, L, D, D, E, TopK, D_hidden](
            attn_out,
            gate_weight,
            gate_bias,
            expert_fc1_weights,
            expert_fc1_biases,
            expert_fc2_weights,
            expert_fc2_biases,
        )  # [L, D]

        # Step 4: Store output for this batch item
        for l, d in dsl.grid(L, D, name="store_output"):
            output[b, l, d] = moe_out[l, d]

    return output
def optimize_attention_moe_with_composition(
    batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim
):
    """
    Create schedules for every sub-kernel and compose them into one
    top-level attention_moe_layer schedule.

    Composition order mirrors the call graph: library kernels (linear2d,
    GeLU) compose into `expert`, which composes (with the gate linear,
    top1_select and softmax_1d) into `moe_layer`; finally the attention and
    MoE schedules compose into the top-level schedule that is returned.

    NOTE(review): instantiation shapes must match the call sites exactly —
    e.g. the MoE kernels use N = L (seq_len) because moe_layer is invoked
    per batch item — otherwise compose() will not match the callee.

    Args:
        batch_size: Batch size (B)
        seq_len: Sequence length (L)
        embed_dim: Embedding dimension (D)
        num_heads: Number of attention heads (H)
        num_experts: Number of experts (E)
        k: Top-k value (currently only k=1 is supported)
        hidden_dim: Hidden dimension for experts

    Returns:
        s_attn_moe: Optimized schedule for attention_moe_layer with all
        sub-schedules composed
    """
    Ty = float32
    B = batch_size
    L = seq_len
    D = embed_dim
    H = num_heads
    E = num_experts
    K = k
    D_hidden = hidden_dim

    print("=" * 60)
    print("Creating and optimizing Attention + MoE schedules...")
    print("=" * 60)

    # Step 1: Create schedule for top1_select
    print("\n[1] Creating schedule for top1_select...")
    s_top1 = allo.customize(top1_select, instantiate=[Ty, L, E])
    print(" - Created top1_select schedule for [L, E]")

    # Step 2: Create schedule for softmax_1d (for MoE gate)
    print("\n[2] Creating schedule for softmax_1d...")
    s_softmax_1d = allo.customize(softmax_1d, instantiate=[Ty, L, K])
    print(" - Created softmax_1d schedule for [L, K]")

    # Step 3: Create schedule for expert
    print("\n[3] Creating schedule for expert...")

    # Create schedules for library functions that expert uses
    print(" - Creating schedules for library functions (linear2d, GeLU)...")
    s_linear_fc1 = allo.customize(
        allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D_hidden, D]
    )
    s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, L, D_hidden])
    s_linear_fc2 = allo.customize(
        allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, D, D_hidden]
    )
    print(" - Library function schedules created")

    # Create expert schedule
    print(" - Creating expert schedule...")
    s_expert = allo.customize(expert, instantiate=[Ty, L, D, D_hidden, D])

    # Compose library function schedules (ids disambiguate the two linear2d
    # instantiations inside expert).
    s_expert.compose(s_linear_fc1, id="expert_fc1")
    s_expert.compose(s_gelu, id="expert_gelu")
    s_expert.compose(s_linear_fc2, id="expert_fc2")
    print(" - Created expert schedule")
    print(" - Composed nn.linear2d and nn.GeLU schedules for expert")

    # Step 4: Create schedule for moe_layer
    print("\n[4] Creating schedule for moe_layer...")
    s_moe = allo.customize(moe_layer, instantiate=[Ty, L, D, D, E, K, D_hidden])

    # Compose gate linear
    s_gate_linear = allo.customize(allo_nn.linear2d, instantiate=[Ty, Ty, Ty, L, E, D])
    s_moe.compose(s_gate_linear, id="gate")
    s_moe.compose(s_top1)
    s_moe.compose(s_softmax_1d)
    s_moe.compose(s_expert)
    print(" - Created moe_layer schedule")

    # Step 5: Create schedule for custom attention
    print("\n[5] Creating schedule for custom scaled_dot_product_attention...")
    s_attn = allo.customize(scaled_dot_product_attention, instantiate=[Ty, H, L, D])
    print(" - Created scaled_dot_product_attention schedule")

    # Step 6: Create schedule for main attention_moe_layer function
    print("\n[6] Creating schedule for attention_moe_layer...")
    s_attn_moe = allo.customize(
        attention_moe_layer, instantiate=[Ty, B, L, D, H, E, K, D_hidden]
    )

    # Step 7: Compose all schedules together
    print("\n[7] Composing all schedules together...")
    s_attn_moe.compose(s_attn)
    s_attn_moe.compose(s_moe)
    print(" - Composed scaled_dot_product_attention schedule")
    print(" - Composed moe_layer schedule")

    print("\n" + "=" * 60)
    print("Schedule composition complete!")
    print("=" * 60)

    return s_attn_moe
# ==================================================================================
# Test driver: builds the Allo kernel, runs the PyTorch reference with the
# same weights, and compares the two outputs.
# ==================================================================================

if __name__ == "__main__":
    import torch
    import torch.nn as torch_nn
    from pytorch_attention_moe import AttentionMoE

    # ============================================================================
    # Configuration parameters - use shared config from llm_config.py
    # ============================================================================
    # Get MoE configuration from shared config
    moe_config = get_moe_config(CONFIG_MODE)

    batch_size = moe_config["batch_size"]
    seq_len = moe_config["seq_len"]
    embed_dim = moe_config["input_dim"]  # D, must be divisible by num_heads
    num_experts = moe_config["num_experts"]  # E
    k = moe_config["k"]  # Top-k MoE
    hidden_dim = moe_config["hidden_dim"]  # D_hidden

    # Attention-specific parameter: num_heads
    # Choose num_heads such that embed_dim is divisible by num_heads
    # Common choices: 2, 4, 8, 12, 16 (depending on embed_dim)
    if embed_dim >= 768:
        num_heads = 12  # Standard for BERT-base
    elif embed_dim >= 512:
        num_heads = 8
    elif embed_dim >= 256:
        num_heads = 8
    elif embed_dim >= 128:
        num_heads = 4
    elif embed_dim >= 64:
        num_heads = 4
    else:
        num_heads = 2

    # Ensure embed_dim is divisible by num_heads
    while embed_dim % num_heads != 0 and num_heads > 1:
        num_heads -= 1

    seed = 42

    print("=" * 60)
    print("Attention + MoE Allo Implementation Test")
    print("=" * 60)
    print(f"Configuration Mode: {CONFIG_MODE}")
    print(f"Configuration:")
    print(f"  batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}")
    print(f"  num_heads={num_heads}, head_dim={embed_dim // num_heads}")
    print(f"  num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}")
    print(f"  Seed: {seed}")
    print("=" * 60)

    # ----------------------------------------------------------------------------------
    # Run PyTorch implementation to get weights and outputs
    # ----------------------------------------------------------------------------------
    print("\n[1] Running PyTorch implementation...")
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create PyTorch AttentionMoE layer
    pytorch_model = AttentionMoE(
        seq_len=seq_len,
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_experts=num_experts,
        k=k,
        expert_hidden_dim=hidden_dim,
    )
    pytorch_model.eval()

    # Initialize with Xavier uniform (matrices) / zeros (biases) so the
    # reference weights are deterministic under the fixed seed.
    for param in pytorch_model.parameters():
        if param.dim() > 1:
            torch_nn.init.xavier_uniform_(param)
        else:
            torch_nn.init.zeros_(param)

    # Create random inputs (Q, K, V); reseed so inputs are independent of
    # how many parameters the model consumed from the RNG above.
    torch.manual_seed(seed)
    Q_pt = torch.randn(batch_size, seq_len, embed_dim)
    K_pt = torch.randn(batch_size, seq_len, embed_dim)
    V_pt = torch.randn(batch_size, seq_len, embed_dim)

    # Run PyTorch inference
    with torch.no_grad():
        pytorch_output = pytorch_model(Q_pt, K_pt, V_pt, verbose=False)

    print(f"PyTorch output shape: {pytorch_output.shape}")
    print(
        f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]"
    )

    # ----------------------------------------------------------------------------------
    # Extract weights and biases from PyTorch model
    # ----------------------------------------------------------------------------------
    print("\n[2] Extracting weights from PyTorch model...")

    # Gate weights
    gate_weight_pt = (
        pytorch_model.moe.gate.gate_linear.weight.data
    )  # [num_experts, embed_dim]
    gate_bias_pt = pytorch_model.moe.gate.gate_linear.bias
    # Gate is created with bias=False, so substitute zeros for the kernel.
    if gate_bias_pt is not None:
        gate_bias_pt = gate_bias_pt.data
    else:
        gate_bias_pt = torch.zeros(num_experts)

    # Expert weights, stacked along a leading expert dimension E.
    expert_fc1_weights_pt = torch.stack(
        [exp.fc1.weight.data for exp in pytorch_model.moe.experts]
    )
    expert_fc1_biases_pt = torch.stack(
        [exp.fc1.bias.data for exp in pytorch_model.moe.experts]
    )
    expert_fc2_weights_pt = torch.stack(
        [exp.fc2.weight.data for exp in pytorch_model.moe.experts]
    )
    expert_fc2_biases_pt = torch.stack(
        [exp.fc2.bias.data for exp in pytorch_model.moe.experts]
    )

    # ----------------------------------------------------------------------------------
    # Convert to numpy arrays (contiguous float32, as the Allo kernel expects)
    # ----------------------------------------------------------------------------------
    print("\n[3] Converting weights to numpy arrays...")
    Q_np = np.ascontiguousarray(Q_pt.detach().numpy(), dtype=np.float32)
    K_np = np.ascontiguousarray(K_pt.detach().numpy(), dtype=np.float32)
    V_np = np.ascontiguousarray(V_pt.detach().numpy(), dtype=np.float32)

    gate_weight_np = np.ascontiguousarray(
        gate_weight_pt.detach().numpy(), dtype=np.float32
    )
    gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32)

    expert_fc1_weights_np = np.ascontiguousarray(
        expert_fc1_weights_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc1_biases_np = np.ascontiguousarray(
        expert_fc1_biases_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc2_weights_np = np.ascontiguousarray(
        expert_fc2_weights_pt.detach().numpy(), dtype=np.float32
    )
    expert_fc2_biases_np = np.ascontiguousarray(
        expert_fc2_biases_pt.detach().numpy(), dtype=np.float32
    )

    print(f"Q shape: {Q_np.shape}")
    print(f"K shape: {K_np.shape}")
    print(f"V shape: {V_np.shape}")
    print(f"Gate weight shape: {gate_weight_np.shape}")
    print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}")
    print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}")

    # ----------------------------------------------------------------------------------
    # Run Allo implementation
    # ----------------------------------------------------------------------------------
    print("\n[4] Running Allo implementation...")
    try:
        # Create optimized schedule with composition
        allo_schedule = optimize_attention_moe_with_composition(
            batch_size, seq_len, embed_dim, num_heads, num_experts, k, hidden_dim
        )

        # Generate project name
        project_name = f"allo_attention_moe_lib_{CONFIG_MODE}.prj"
        print(f"Using project name: {project_name}")

        # Build module; MODE selects the backend/flow.
        print("\n[5] Building Allo module...")
        if MODE == "llvm":
            mod = allo_schedule.build(target="llvm")
        elif MODE == "sw_emu":
            mod = allo_schedule.build(
                target="vitis_hls", mode="sw_emu", project=project_name
            )
        elif MODE == "hw_emu":
            mod = allo_schedule.build(
                target="vitis_hls", mode="hw_emu", project=project_name
            )
        elif MODE == "hw":
            mod = allo_schedule.build(
                target="vitis_hls", mode="hw", project=project_name
            )
        elif MODE == "csyn":
            mod = allo_schedule.build(
                target="vitis_hls", mode="csyn", project=project_name
            )
        else:
            raise ValueError(f"Unsupported mode: {MODE}")

        # Run Allo inference. LLVM returns the output; the vitis flows fill
        # a caller-provided output buffer in place.
        print("\n[6] Running Allo inference...")
        if MODE == "llvm":
            allo_output = mod(
                Q_np,
                K_np,
                V_np,
                gate_weight_np,
                gate_bias_np,
                expert_fc1_weights_np,
                expert_fc1_biases_np,
                expert_fc2_weights_np,
                expert_fc2_biases_np,
            )
        elif MODE in ["sw_emu", "hw_emu", "hw"]:
            allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32)
            mod(
                Q_np,
                K_np,
                V_np,
                gate_weight_np,
                gate_bias_np,
                expert_fc1_weights_np,
                expert_fc1_biases_np,
                expert_fc2_weights_np,
                expert_fc2_biases_np,
                allo_output,
            )
        elif MODE == "csyn":
            # NOTE(review): csyn only runs HLS synthesis — mod() takes no
            # data, so allo_output stays all-zero and the comparison in [7]
            # below is meaningless in this mode (it will print a WARNING).
            allo_output = np.zeros((batch_size, seq_len, embed_dim), dtype=np.float32)
            mod()
        else:
            raise ValueError(f"Unsupported mode: {MODE}")

        print(f"Allo output shape: {allo_output.shape}")
        print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]")

        # ----------------------------------------------------------------------------------
        # Compare the Allo and PyTorch outputs
        # ----------------------------------------------------------------------------------
        print("\n[7] Comparing outputs...")
        pytorch_output_np = pytorch_output.detach().numpy()

        # Compute differences
        diff = np.abs(allo_output - pytorch_output_np)
        mean_diff = np.mean(diff)
        max_diff = np.max(diff)
        rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8))

        print(f"Mean absolute difference: {mean_diff:.6e}")
        print(f"Max absolute difference: {max_diff:.6e}")
        print(f"Mean relative difference: {rel_diff:.6e}")

        # Check if outputs are close; tolerances are loose because the Allo
        # kernel accumulates in a different order than PyTorch.
        atol = 5e-4
        rtol = 2e-3
        is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol)

        if is_close:
            print(
                f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})"
            )
        else:
            print(
                f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})"
            )
            print("First few differences:")
            print(diff.flatten()[:10])

        # ----------------------------------------------------------------------------------
        # Print sample outputs for comparison
        # ----------------------------------------------------------------------------------
        print("\n[8] Sample outputs (first token, first 5 dimensions):")
        print(f"PyTorch: {pytorch_output_np[0, 0, :5]}")
        print(f"Allo:    {allo_output[0, 0, :5]}")
        print(f"Diff:    {diff[0, 0, :5]}")

    except Exception as e:
        print(f"\n✗ ERROR: Failed to run Allo implementation: {e}")
        import traceback

        traceback.print_exc()

    print("=" * 60)
class ScaledDotProductAttention(nn.Module):
    """
    Multi-head scaled dot-product attention (PyTorch reference).

    Mirrors the Allo kernel's semantics: the [L, D] inputs are split into
    ``num_heads`` column slices of width ``head_dim = D // num_heads``,
    each head computes ``softmax(Q_h @ K_h^T / sqrt(head_dim)) @ V_h``,
    and the per-head results are concatenated back into an [L, D] output.

    Args:
        num_heads: Number of attention heads (H)
        embed_dim: Model dimension (D); must be divisible by num_heads
    """

    def __init__(self, num_heads: int, embed_dim: int):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        # Scaling factor: 1 / sqrt(head_dim)
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(
        self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False
    ) -> torch.Tensor:
        """
        Run attention over un-batched [L, D] inputs.

        Args:
            Q: Query tensor of shape [L, D]
            K: Key tensor of shape [L, D]
            V: Value tensor of shape [L, D]
            verbose: Whether to print debug information

        Returns:
            Output tensor of shape [L, D]
        """
        seq_len, model_dim = Q.shape
        n_heads = self.num_heads
        per_head = model_dim // n_heads

        # Compute each head on its own column slice, then stitch the
        # slices back together with a single concatenation.
        head_outputs = []
        for head in range(n_heads):
            lo = head * per_head
            hi = lo + per_head

            q_slice = Q[:, lo:hi]  # [L, head_dim]
            k_slice = K[:, lo:hi]  # [L, head_dim]
            v_slice = V[:, lo:hi]  # [L, head_dim]

            # Scaled scores: (Q_h @ K_h^T) / sqrt(head_dim) -> [L, L]
            scores = torch.matmul(q_slice, k_slice.T) * self.scale
            # Row-wise attention weights, then weighted sum of values.
            attn = F.softmax(scores, dim=-1)
            head_outputs.append(torch.matmul(attn, v_slice))  # [L, head_dim]

        Z = torch.cat(head_outputs, dim=1)  # [L, D]

        if verbose:
            print(f"\n[Attention] L={seq_len}, D={model_dim}, H={n_heads}, head_dim={per_head}")
            print(f"[Attention] scale={self.scale:.6f}")
            print(f"[Attention] Output shape: {Z.shape}")

        return Z
class TopKGate(nn.Module):
    """
    Gate module to select top k experts for routing.

    Produces a dense [N, num_experts] weight matrix that is zero for all
    non-selected experts, plus the indices of the selected experts.

    Args:
        input_dim: Input feature dimension
        num_experts: Total number of experts
        k: Number of experts to select per token
    """

    def __init__(self, input_dim: int, num_experts: int, k: int = 1):
        super().__init__()
        self.k = k
        self.num_experts = num_experts
        # Linear layer to compute expert logits (no bias, matching the
        # zero-bias substitution in the Allo test drivers).
        self.gate_linear = nn.Linear(input_dim, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor, verbose: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the gate.

        Args:
            x: Input tensor of shape [batch_size * seq_len, input_dim]
            verbose: Whether to print routing information

        Returns:
            full_weights: Sparse weight matrix [batch_size * seq_len, num_experts]
                (softmax-normalized over the selected k, zero elsewhere)
            top_k_indices: Indices of selected experts [batch_size * seq_len, k]
        """
        # Compute logits for all experts
        logits = self.gate_linear(x)  # [N, num_experts]

        # Select top-k experts
        top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1)

        # Apply softmax to top-k logits for normalized weights
        top_k_weights = F.softmax(top_k_logits, dim=-1)  # [N, k]

        # Create sparse weight matrix (zeros for non-selected experts);
        # scatter_ writes each normalized weight at its expert's column.
        full_weights = torch.zeros_like(logits)
        full_weights.scatter_(1, top_k_indices, top_k_weights)

        # Optional diagnostics: per-expert load and a sanity check that every
        # token selected exactly k experts.
        if verbose:
            num_tokens = x.shape[0]

            # Count tokens per expert
            expert_counts = {}
            for expert_idx in range(self.num_experts):
                count = (top_k_indices == expert_idx).sum().item()
                if count > 0:
                    expert_counts[expert_idx] = count

            # Verify that each token selects exactly k experts
            tokens_per_expert_count = {}
            for i in range(num_tokens):
                num_experts_selected = (top_k_weights[i] > 1e-6).sum().item()
                tokens_per_expert_count[num_experts_selected] = (
                    tokens_per_expert_count.get(num_experts_selected, 0) + 1
                )

            print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}")
            print(f"[Gate Routing] Expert distribution:")
            total_selections = num_tokens * self.k  # Total expert-token pairs
            for expert_idx, count in sorted(expert_counts.items()):
                percentage = (count / total_selections) * 100
                print(
                    f"  Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)"
                )

            # Verify top-k: each token selects exactly k experts
            if tokens_per_expert_count.get(self.k, 0) == num_tokens:
                print(
                    f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)"
                )
            else:
                print(
                    f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}"
                )

        return full_weights, top_k_indices
+ + print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}") + print(f"[Gate Routing] Expert distribution:") + total_selections = num_tokens * self.k # Total expert-token pairs + for expert_idx, count in sorted(expert_counts.items()): + percentage = (count / total_selections) * 100 + print( + f" Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)" + ) + + # Verify top-k: each token selects exactly k experts + if tokens_per_expert_count.get(self.k, 0) == num_tokens: + print( + f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)" + ) + else: + print( + f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}" + ) + + return full_weights, top_k_indices + + +class Expert(nn.Module): + """ + A simple feed-forward expert network. + + Args: + input_dim: Input feature dimension + hidden_dim: Hidden layer dimension + output_dim: Output feature dimension + """ + + def __init__(self, input_dim: int, hidden_dim: int, output_dim: int): + super().__init__() + self.fc1 = nn.Linear(input_dim, hidden_dim) + self.fc2 = nn.Linear(hidden_dim, output_dim) + # Use GELU activation (modern choice) + self.activation = nn.GELU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass through the expert.""" + x = self.fc1(x) + x = self.activation(x) + x = self.fc2(x) + return x + + +class MoELayer(nn.Module): + """ + Mixture of Experts layer for inference. 
+ + Args: + input_dim: Input feature dimension + output_dim: Output feature dimension + num_experts: Total number of experts + k: Number of experts to activate per token + expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim) + """ + + def __init__( + self, + input_dim: int, + output_dim: int, + num_experts: int, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + ): + super().__init__() + self.num_experts = num_experts + self.k = k + self.output_dim = output_dim + + if expert_hidden_dim is None: + expert_hidden_dim = input_dim * 4 # Common practice + + # Initialize gate and experts + self.gate = TopKGate(input_dim, num_experts, k) + self.experts = nn.ModuleList( + [ + Expert(input_dim, expert_hidden_dim, output_dim) + for _ in range(num_experts) + ] + ) + + def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor: + """ + Forward pass through MoE layer. + + Args: + x: Input tensor of shape [batch_size, seq_len, input_dim] + verbose: Whether to print verification information for top-1 MoE + + Returns: + Output tensor of shape [batch_size, seq_len, output_dim] + """ + + # Store original shape + original_shape = x.shape + batch_size, seq_len = original_shape[0], original_shape[1] + + # Flatten batch and sequence dimensions + x_flat = x.view(-1, original_shape[-1]) # [N, input_dim], N = batch * seq_len + num_tokens = x_flat.shape[0] + + # Get gating weights and expert indices + gate_weights, top_k_indices = self.gate( + x_flat, verbose=verbose + ) # [N, num_experts], [N, k] + + # Initialize output tensor + output = torch.zeros( + x_flat.shape[0], self.output_dim, device=x.device, dtype=x.dtype + ) + + # Track expert usage statistics + expert_usage_stats = {} + + # Process each expert + for expert_idx in range(self.num_experts): + # Find tokens assigned to this expert + # Create mask: tokens that have this expert in their top-k + expert_mask = (top_k_indices == expert_idx).any(dim=-1) # [N] + + num_tokens_for_expert = 
expert_mask.sum().item() + + if not expert_mask.any(): + expert_usage_stats[expert_idx] = 0 + continue # No tokens for this expert + + # Track usage + expert_usage_stats[expert_idx] = num_tokens_for_expert + + # Get inputs for this expert + expert_inputs = x_flat[expert_mask] # [num_tokens, input_dim] + + # Process through expert + expert_outputs = self.experts[expert_idx]( + expert_inputs + ) # [num_tokens, output_dim] + + # Get corresponding weights for these tokens + # For each token, find the weight for this expert (could be in any of k positions) + token_indices = torch.where(expert_mask)[0] # Indices in flattened tensor + expert_weights = gate_weights[token_indices, expert_idx].unsqueeze( + 1 + ) # [num_tokens, 1] + + # Weight and accumulate outputs + weighted_outputs = expert_outputs * expert_weights + output[token_indices] += weighted_outputs + + # Verify top-k behavior + if verbose: + active_experts = [ + idx for idx, count in expert_usage_stats.items() if count > 0 + ] + total_usage = sum(expert_usage_stats.values()) + + # Verify top-k: each token selects exactly k experts + if top_k_indices.shape[1] != self.k: + print( + f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}" + ) + + # Verify all tokens are processed (for k>1, total_usage may be > num_tokens) + expected_total_usage = num_tokens * self.k + if total_usage != expected_total_usage: + print( + f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}" + ) + + # Check if weights are normalized (should sum to 1.0 per token) + sample_weights = gate_weights.sum(dim=1) + weights_normalized = torch.allclose( + sample_weights, torch.ones_like(sample_weights), atol=1e-5 + ) + + # Final summary for top-k + print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===") + print(f" ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})") + print(f" ✓ All 
{num_tokens} tokens processed") + print( + f" ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})" + ) + print(f" ✓ Active experts: {len(active_experts)}/{self.num_experts}") + print( + f" ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}" + ) + if weights_normalized: + print(f" ✓ Gate weights normalized (sum to 1.0 per token)") + else: + print(f" ✗ Gate weights NOT normalized correctly") + print(f" === Top-{self.k} MoE is working correctly! ===") + + # Reshape back to original shape + output = output.view(batch_size, seq_len, self.output_dim) + + return output + + +class AttentionMoE(nn.Module): + """ + Combined Attention + MoE layer. + + Data flow: Input -> Attention -> MoE -> Output + + This follows the Transformer architecture pattern where: + 1. Attention computes self-attention on the input + 2. MoE replaces the standard FFN (Feed-Forward Network) + + Args: + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of MoE experts (E) + k: Top-k experts to activate per token + expert_hidden_dim: Hidden dimension for experts + """ + + def __init__( + self, + seq_len: int, + embed_dim: int, + num_heads: int, + num_experts: int, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + ): + super().__init__() + self.seq_len = seq_len + self.embed_dim = embed_dim + self.num_heads = num_heads + + # Attention layer + self.attention = ScaledDotProductAttention(num_heads, embed_dim) + + # MoE layer (input_dim = output_dim = embed_dim) + self.moe = MoELayer( + input_dim=embed_dim, + output_dim=embed_dim, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim, + ) + + def forward( + self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, verbose: bool = False + ) -> torch.Tensor: + """ + Forward pass: Attention -> MoE + + Args: + Q: Query tensor of shape [B, L, D] + K: Key tensor of shape [B, L, D] + V: Value tensor of shape [B, L, D] + 
verbose: Whether to print debug information + + Returns: + Output tensor of shape [B, L, D] + """ + batch_size = Q.shape[0] + L = Q.shape[1] + D = Q.shape[2] + + # Process each batch item through attention + # Attention expects [L, D], so we process batch by batch + attn_outputs = [] + for b in range(batch_size): + attn_out = self.attention(Q[b], K[b], V[b], verbose=verbose and b == 0) + attn_outputs.append(attn_out) + + # Stack back to [B, L, D] + attn_output = torch.stack(attn_outputs, dim=0) + + if verbose: + print(f"\n[AttentionMoE] Attention output shape: {attn_output.shape}") + + # Pass through MoE + # MoE expects [B, L, D] and returns [B, L, D] + moe_output = self.moe(attn_output, verbose=verbose) + + if verbose: + print(f"[AttentionMoE] MoE output shape: {moe_output.shape}") + + return moe_output + + +def run_attention_moe_inference( + batch_size: int = 4, + seq_len: int = 4, + embed_dim: int = 64, + num_heads: int = 4, + num_experts: int = 2, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + seed: int = 24, + verbose: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, AttentionMoE]: + """ + Run Attention + MoE inference with fixed seed for reproducible inputs and weights. 
+ + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k experts to activate per token + expert_hidden_dim: Hidden dimension for experts + seed: Random seed for reproducibility + verbose: Whether to print verification information + + Returns: + output: Output tensor from AttentionMoE [B, L, D] + Q, K, V: Input tensors used for inference + model: AttentionMoE model instance + """ + # Set random seed for reproducibility + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create model + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim, + ) + model.eval() + + # Initialize with random weights (Xavier uniform for better stability) + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random inputs (Q, K, V) + torch.manual_seed(seed) + Q = torch.randn(batch_size, seq_len, embed_dim) + K = torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + # Run inference with verbose output for verification + with torch.no_grad(): + output = model(Q, K, V, verbose=verbose) + + return output, Q, K, V, model + + +def verify_gate_layer( + num_tokens: int = 128, + input_dim: int = 96, + num_experts: int = 2, + k: int = 1, + seed: int = 42, +): + """ + Verify that the Gate Layer is effectively routing tokens to different experts. + + This function tests: + 1. Different inputs produce different routing decisions + 2. Gate weights are properly normalized (sum to 1.0) + 3. Expert distribution is reasonable (not all tokens to one expert) + 4. 
Routing is deterministic (same input -> same routing) + """ + print("=" * 70) + print("Gate Layer Verification Test") + print("=" * 70) + print(f"Configuration: num_tokens={num_tokens}, input_dim={input_dim}") + print(f" num_experts={num_experts}, k={k}") + print("=" * 70) + + torch.manual_seed(seed) + + # Create gate + gate = TopKGate(input_dim, num_experts, k) + nn.init.xavier_uniform_(gate.gate_linear.weight) + gate.eval() + + # Test 1: Create random inputs and check routing + print("\n[Test 1] Random Input Routing") + print("-" * 50) + x = torch.randn(num_tokens, input_dim) + + with torch.no_grad(): + full_weights, top_k_indices = gate(x, verbose=False) + + # Count tokens per expert + expert_counts = {} + for e in range(num_experts): + count = (top_k_indices == e).sum().item() + expert_counts[e] = count + percentage = (count / (num_tokens * k)) * 100 + print(f" Expert {e}: {count} tokens ({percentage:.1f}%)") + + # Check if routing is balanced + min_count = min(expert_counts.values()) + max_count = max(expert_counts.values()) + imbalance_ratio = max_count / max(min_count, 1) + + if imbalance_ratio < 10: + print( + f" ✓ Routing is reasonably balanced (imbalance ratio: {imbalance_ratio:.2f})" + ) + else: + print(f" ⚠ Routing is imbalanced (imbalance ratio: {imbalance_ratio:.2f})") + + # Test 2: Verify weights are normalized + print("\n[Test 2] Weight Normalization") + print("-" * 50) + weight_sums = full_weights.sum(dim=1) + all_normalized = torch.allclose( + weight_sums, torch.ones_like(weight_sums), atol=1e-5 + ) + + if all_normalized: + print(f" ✓ All gate weights sum to 1.0 (properly normalized)") + else: + print(f" ✗ Gate weights NOT properly normalized!") + print( + f" Weight sums range: [{weight_sums.min():.6f}, {weight_sums.max():.6f}]" + ) + + # Test 3: Different inputs should (mostly) produce different routings + print("\n[Test 3] Input-Dependent Routing") + print("-" * 50) + + # Create two very different inputs + x1 = torch.randn(10, input_dim) * 10 # 
Large magnitude + x2 = torch.randn(10, input_dim) * 0.1 # Small magnitude + + with torch.no_grad(): + _, indices1 = gate(x1, verbose=False) + _, indices2 = gate(x2, verbose=False) + + # Check if routings are different + same_routing = (indices1 == indices2).all().item() + if not same_routing: + print(f" ✓ Different inputs produce different routing decisions") + else: + print(f" ⚠ Different inputs produced same routing (may indicate issue)") + + # Test 4: Determinism - same input should always produce same routing + print("\n[Test 4] Routing Determinism") + print("-" * 50) + + x_test = torch.randn(20, input_dim) + with torch.no_grad(): + _, indices_run1 = gate(x_test, verbose=False) + _, indices_run2 = gate(x_test, verbose=False) + + is_deterministic = (indices_run1 == indices_run2).all().item() + if is_deterministic: + print(f" ✓ Routing is deterministic (same input -> same expert)") + else: + print(f" ✗ Routing is NOT deterministic!") + + # Test 5: Show actual routing decisions for a few tokens + print("\n[Test 5] Sample Routing Decisions") + print("-" * 50) + + # Create a small set of tokens to visualize + x_sample = torch.randn(8, input_dim) + with torch.no_grad(): + logits = gate.gate_linear(x_sample) + weights, indices = gate(x_sample, verbose=False) + + print(f" Token | Expert Logits (raw) | Selected Expert | Weight") + print(f" " + "-" * 60) + for i in range(min(8, x_sample.shape[0])): + logit_str = ", ".join([f"{l:.3f}" for l in logits[i].tolist()]) + selected = indices[i, 0].item() + weight = weights[i, selected].item() + print(f" {i:5d} | [{logit_str:25s}] | Expert {selected} | {weight:.4f}") + + # Test 6: Verify gate learns meaningful patterns + print("\n[Test 6] Pattern-Based Routing") + print("-" * 50) + + # Create inputs with clear patterns + # Pattern A: positive values in first half, negative in second half + # Pattern B: opposite + pattern_a = torch.cat( + [torch.ones(input_dim // 2), -torch.ones(input_dim // 2)] + ).unsqueeze(0) + pattern_b = 
torch.cat( + [-torch.ones(input_dim // 2), torch.ones(input_dim // 2)] + ).unsqueeze(0) + + # Repeat patterns + x_patterns = torch.cat([pattern_a.repeat(5, 1), pattern_b.repeat(5, 1)], dim=0) + + with torch.no_grad(): + _, pattern_indices = gate(x_patterns, verbose=False) + + pattern_a_experts = pattern_indices[:5, 0].tolist() + pattern_b_experts = pattern_indices[5:, 0].tolist() + + print(f" Pattern A (positive-negative) routed to experts: {pattern_a_experts}") + print(f" Pattern B (negative-positive) routed to experts: {pattern_b_experts}") + + # Check if same patterns go to same expert + a_consistent = len(set(pattern_a_experts)) == 1 + b_consistent = len(set(pattern_b_experts)) == 1 + patterns_different = set(pattern_a_experts) != set(pattern_b_experts) + + if a_consistent and b_consistent: + print(f" ✓ Same patterns consistently route to same expert") + else: + print( + f" ⚠ Same patterns route to different experts (expected with random weights)" + ) + + if patterns_different: + print(f" ✓ Different patterns route to different experts") + else: + print(f" ⚠ Different patterns route to same expert") + + print("\n" + "=" * 70) + print("Gate Layer Verification Complete!") + print("=" * 70) + + return gate, expert_counts + + +def benchmark_attention_moe( + batch_size: int = 1, + seq_len: int = 64, + embed_dim: int = 256, + num_heads: int = 8, + num_experts: int = 4, + k: int = 1, + expert_hidden_dim: Optional[int] = None, + num_warmup: int = 10, + num_runs: int = 100, + device: str = "cpu", + seed: int = 42, +): + """ + Benchmark AttentionMoE inference time. + + Best practices for accurate timing: + 1. Warmup runs: Avoid cold start overhead (JIT compilation, cache warming) + 2. Multiple runs: Get stable average and standard deviation + 3. torch.no_grad(): Disable gradient tracking for inference + 4. 
torch.cuda.synchronize(): Ensure GPU operations complete before timing + + Args: + batch_size: Batch size (B) + seq_len: Sequence length (L) + embed_dim: Embedding dimension (D) + num_heads: Number of attention heads (H) + num_experts: Number of experts (E) + k: Top-k experts per token + expert_hidden_dim: Hidden dim for experts (default: 4 * embed_dim) + num_warmup: Number of warmup iterations + num_runs: Number of timed iterations + device: "cpu" or "cuda" + seed: Random seed + + Returns: + dict: Timing statistics (mean, std, min, max in milliseconds) + """ + import time + import numpy as np + + print("=" * 70) + print("AttentionMoE Benchmark") + print("=" * 70) + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}") + print(f" expert_hidden_dim={expert_hidden_dim or embed_dim * 4}") + print(f" device={device}") + print(f" num_warmup={num_warmup}, num_runs={num_runs}") + print("=" * 70) + + # Set seed and create model + torch.manual_seed(seed) + + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim, + ) + model.eval() + + # Move to device + if device == "cuda" and torch.cuda.is_available(): + model = model.cuda() + print(f" Using GPU: {torch.cuda.get_device_name(0)}") + else: + device = "cpu" + print(f" Using CPU") + + # Initialize weights + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create input tensors + torch.manual_seed(seed) + if device == "cuda": + Q = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + K = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + V = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + else: + Q = torch.randn(batch_size, seq_len, embed_dim) + K = 
torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + # Warmup runs (important for JIT, cache warming, etc.) + print(f"\n[1] Warmup ({num_warmup} iterations)...") + with torch.no_grad(): + for _ in range(num_warmup): + _ = model(Q, K, V, verbose=False) + if device == "cuda": + torch.cuda.synchronize() # Wait for GPU to finish + + # Timed runs + print(f"[2] Timing ({num_runs} iterations)...") + times = [] + + with torch.no_grad(): + for i in range(num_runs): + if device == "cuda": + torch.cuda.synchronize() # Ensure previous work is done + + start_time = time.perf_counter() # High-resolution timer + + _ = model(Q, K, V, verbose=False) + + if device == "cuda": + torch.cuda.synchronize() # Wait for GPU to finish + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1000) # Convert to ms + + # Compute statistics + times = np.array(times) + stats = { + "mean_ms": np.mean(times), + "std_ms": np.std(times), + "min_ms": np.min(times), + "max_ms": np.max(times), + "median_ms": np.median(times), + "p95_ms": np.percentile(times, 95), + "p99_ms": np.percentile(times, 99), + } + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results:") + print("=" * 70) + print(f" Mean: {stats['mean_ms']:.4f} ms") + print(f" Std: {stats['std_ms']:.4f} ms") + print(f" Min: {stats['min_ms']:.4f} ms") + print(f" Max: {stats['max_ms']:.4f} ms") + print(f" Median: {stats['median_ms']:.4f} ms") + print(f" P95: {stats['p95_ms']:.4f} ms") + print(f" P99: {stats['p99_ms']:.4f} ms") + print("-" * 70) + print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") + print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") + print("=" * 70) + + return stats, model + + +def benchmark_with_allo_config( + config_mode: str = "switch_base_8_scaled_1_8", + num_warmup: int = 10, + num_runs: int = 100, + device: str = "cpu", + seed: int = 42, +): + """ + Benchmark using the same configuration as Allo tests. 
+ This ensures fair comparison between PyTorch and Allo implementations. + + Uses llm_config.py for configuration to match allo_attention_moe_alt.py + + Args: + config_mode: Configuration mode from llm_config (e.g., "switch_base_8_scaled_1_8") + num_warmup: Number of warmup iterations + num_runs: Number of timed iterations + device: "cpu" or "cuda" + seed: Random seed (must match Allo test seed for same inputs) + + Returns: + dict: Timing statistics and model + """ + import sys + import os + import time + import numpy as np + + # Add llm_config to path + current_dir = os.path.dirname(os.path.abspath(__file__)) + project_root = os.path.dirname(current_dir) + llm_config_dir = os.path.join(project_root, "llm_config") + if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + + from llm_config import get_moe_config + + # Get configuration (same as allo_attention_moe_alt.py) + moe_config = get_moe_config(config_mode) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + embed_dim = moe_config["input_dim"] + num_experts = moe_config["num_experts"] + k = moe_config["k"] + hidden_dim = moe_config["hidden_dim"] + + # Determine num_heads (same logic as allo_attention_moe_alt.py) + if embed_dim >= 768: + num_heads = 12 + elif embed_dim >= 512: + num_heads = 8 + elif embed_dim >= 256: + num_heads = 8 + elif embed_dim >= 128: + num_heads = 4 + elif embed_dim >= 64: + num_heads = 4 + else: + num_heads = 2 + + while embed_dim % num_heads != 0 and num_heads > 1: + num_heads -= 1 + + print("=" * 70) + print("PyTorch AttentionMoE Benchmark (Allo-compatible config)") + print("=" * 70) + print(f"Config Mode: {config_mode}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // num_heads}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" device={device}, seed={seed}") + print(f" num_warmup={num_warmup}, 
num_runs={num_runs}") + print("=" * 70) + + # Create model with same initialization as Allo test + torch.manual_seed(seed) + if torch.cuda.is_available() and device == "cuda": + torch.cuda.manual_seed_all(seed) + + model = AttentionMoE( + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=hidden_dim, + ) + model.eval() + + # Initialize with Xavier uniform (same as Allo test) + for param in model.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Move to device + if device == "cuda" and torch.cuda.is_available(): + model = model.cuda() + print(f" Using GPU: {torch.cuda.get_device_name(0)}") + else: + device = "cpu" + print(f" Using CPU") + + # Create input tensors with same seed (same as Allo test) + torch.manual_seed(seed) + if device == "cuda": + Q = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + K = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + V = torch.randn(batch_size, seq_len, embed_dim, device="cuda") + else: + Q = torch.randn(batch_size, seq_len, embed_dim) + K = torch.randn(batch_size, seq_len, embed_dim) + V = torch.randn(batch_size, seq_len, embed_dim) + + print(f"\nInput shapes: Q={Q.shape}, K={K.shape}, V={V.shape}") + + # Verify output first + print("\n[1] Verifying output...") + with torch.no_grad(): + output = model(Q, K, V, verbose=False) + print(f" Output shape: {output.shape}") + print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") + + # Warmup + print(f"\n[2] Warmup ({num_warmup} iterations)...") + with torch.no_grad(): + for _ in range(num_warmup): + _ = model(Q, K, V, verbose=False) + if device == "cuda": + torch.cuda.synchronize() + + # Timed runs + print(f"[3] Timing ({num_runs} iterations)...") + times = [] + + with torch.no_grad(): + for i in range(num_runs): + if device == "cuda": + torch.cuda.synchronize() + + start_time = time.perf_counter() + _ = model(Q, K, V, 
verbose=False) + + if device == "cuda": + torch.cuda.synchronize() + + end_time = time.perf_counter() + times.append((end_time - start_time) * 1000) + + # Compute statistics + times = np.array(times) + stats = { + "mean_ms": np.mean(times), + "std_ms": np.std(times), + "min_ms": np.min(times), + "max_ms": np.max(times), + "median_ms": np.median(times), + "p95_ms": np.percentile(times, 95), + "p99_ms": np.percentile(times, 99), + } + + # Print results + print("\n" + "=" * 70) + print("Benchmark Results:") + print("=" * 70) + print(f" Mean: {stats['mean_ms']:.4f} ms") + print(f" Std: {stats['std_ms']:.4f} ms") + print(f" Min: {stats['min_ms']:.4f} ms") + print(f" Max: {stats['max_ms']:.4f} ms") + print(f" Median: {stats['median_ms']:.4f} ms") + print(f" P95: {stats['p95_ms']:.4f} ms") + print(f" P99: {stats['p99_ms']:.4f} ms") + print("-" * 70) + print(f" Throughput: {1000 / stats['mean_ms']:.2f} inferences/sec") + print(f" Tokens/sec: {batch_size * seq_len * 1000 / stats['mean_ms']:.2f}") + print("=" * 70) + + return stats, model, Q, K, V + + +if __name__ == "__main__": + # First, run gate layer verification + print("\n" + "#" * 70) + print("# PART 1: Gate Layer Verification") + print("#" * 70) + + gate, expert_counts = verify_gate_layer( + num_tokens=128, input_dim=96, num_experts=2, k=1, seed=42 + ) + + # Then run the full Attention + MoE test + print("\n" + "#" * 70) + print("# PART 2: Full Attention + MoE Inference Test") + print("#" * 70) + + # Configuration parameters + batch_size = 1 + seq_len = 4 + embed_dim = 8 # D, must be divisible by num_heads + num_heads = 2 # H + num_experts = 2 # E + k = 1 # Top-k MoE: each token uses exactly k experts + expert_hidden_dim = 16 # Hidden dimension for experts + seed = 24 + + print("=" * 60) + print(f"Attention + MoE Inference Test") + print("=" * 60) + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}, embed_dim={embed_dim}") + print(f" num_heads={num_heads}, head_dim={embed_dim // 
num_heads}") + print(f" num_experts={num_experts}, k={k}, expert_hidden_dim={expert_hidden_dim}") + print("=" * 60) + + # Run Attention + MoE inference + output, Q, K, V, model = run_attention_moe_inference( + batch_size=batch_size, + seq_len=seq_len, + embed_dim=embed_dim, + num_heads=num_heads, + num_experts=num_experts, + k=k, + expert_hidden_dim=expert_hidden_dim, + seed=seed, + verbose=True, + ) + + print(f"\n" + "=" * 60) + print(f"Results:") + print(f" Input Q shape: {Q.shape}") + print(f" Input K shape: {K.shape}") + print(f" Input V shape: {V.shape}") + print(f" Output shape: {output.shape}") + print(f" Output range: [{output.min().item():.6f}, {output.max().item():.6f}]") + print("=" * 60) + + # PART 3: Benchmark with Allo-compatible config + print("\n" + "#" * 70) + print("# PART 3: Performance Benchmark (Allo-compatible config)") + print("#" * 70) + + # Use the same config as allo_attention_moe_alt.py for fair comparison + stats, _, Q_bench, K_bench, V_bench = benchmark_with_allo_config( + config_mode="switch_base_8_scaled_1_8", # Same as DEFAULT_CONFIG_MODE in llm_config + num_warmup=10, + num_runs=100, + device="cuda" if torch.cuda.is_available() else "cpu", + seed=42, # Same seed as Allo test + ) diff --git a/ece6775/llm_config/check_llm_config.py b/ece6775/llm_config/check_llm_config.py new file mode 100644 index 000000000..449d572f2 --- /dev/null +++ b/ece6775/llm_config/check_llm_config.py @@ -0,0 +1,163 @@ +# Copyright Allo authors. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0
+# check MoE model config from Hugging Face
+# usage: python check_llm_config.py <model_name>  (e.g. google/switch-base-8)
+
+import sys
+from transformers import AutoConfig
+
+
+def print_moe_config(model_name):
+    # print MoE model config (architecture, expert, FFN, attention, layer params)
+    try:
+        print(f"Loading model configuration: {model_name}")
+        print("=" * 80)
+
+        config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
+
+        # Print basic configuration
+        print("\n[Basic Architecture Parameters]")
+        print(
+            f"Model type: {config.model_type if hasattr(config, 'model_type') else 'N/A'}"
+        )
+        print(
+            f"Vocabulary size: {config.vocab_size if hasattr(config, 'vocab_size') else 'N/A'}"
+        )
+        print(
+            f"Max position embeddings: {config.max_position_embeddings if hasattr(config, 'max_position_embeddings') else 'N/A'}"
+        )
+
+        # Print hidden layer dimensions (field name varies by architecture)
+        print("\n[Hidden Layer Dimensions]")
+        if hasattr(config, "hidden_size"):
+            print(f"Hidden size (hidden_size): {config.hidden_size}")
+        if hasattr(config, "d_model"):
+            print(f"Model dimension (d_model): {config.d_model}")
+        if hasattr(config, "n_embd"):
+            print(f"Embedding dimension (n_embd): {config.n_embd}")
+
+        # Print MoE related parameters
+        print("\n[MoE Expert Parameters]")
+        if hasattr(config, "num_local_experts"):
+            print(
+                f"Number of local experts (num_local_experts): {config.num_local_experts}"
+            )
+        if hasattr(config, "num_experts"):
+            print(f"Number of experts (num_experts): {config.num_experts}")
+        if hasattr(config, "num_experts_per_tok"):
+            print(
+                f"Number of experts per token (num_experts_per_tok): {config.num_experts_per_tok}"
+            )
+        if hasattr(config, "num_experts_to_select"):
+            print(
+                f"Number of experts to select (num_experts_to_select): {config.num_experts_to_select}"
+            )
+
+        # Print feed-forward network parameters
+        print("\n[Feed-Forward Network Parameters]")
+        if hasattr(config, "intermediate_size"):
+            print(f"Intermediate size (intermediate_size): {config.intermediate_size}")
+        if hasattr(config, "ffn_dim"):
+            print(f"FFN dimension (ffn_dim): {config.ffn_dim}")
+        if hasattr(config, "n_inner"):
+            print(f"Inner dimension (n_inner): {config.n_inner}")
+
+        # Print attention parameters
+        print("\n[Attention Parameters]")
+        if hasattr(config, "num_attention_heads"):
+            print(
+                f"Number of attention heads (num_attention_heads): {config.num_attention_heads}"
+            )
+        if hasattr(config, "num_heads"):
+            print(f"Number of heads (num_heads): {config.num_heads}")
+        if hasattr(config, "n_head"):
+            print(f"Number of heads (n_head): {config.n_head}")
+
+        # Print layer count
+        print("\n[Layer Parameters]")
+        if hasattr(config, "num_hidden_layers"):
+            print(
+                f"Number of hidden layers (num_hidden_layers): {config.num_hidden_layers}"
+            )
+        if hasattr(config, "num_layers"):
+            print(f"Number of layers (num_layers): {config.num_layers}")
+        if hasattr(config, "n_layer"):
+            print(f"Number of layers (n_layer): {config.n_layer}")
+
+        # Print all configuration (for debugging); skip large nested values
+        print("\n[Full Configuration Information]")
+        print("-" * 80)
+        for key, value in config.to_dict().items():
+            if isinstance(value, (int, float, str, bool, type(None))):
+                print(f"{key}: {value}")
+            elif isinstance(value, (list, tuple)) and len(value) < 10:
+                print(f"{key}: {value}")
+
+        print("\n" + "=" * 80)
+        print("Configuration loaded successfully!")
+
+        # Provide suggested test parameters (normalize across naming schemes)
+        print("\n[Suggested Test Parameters (for Allo Implementation)]")
+        hidden_size = getattr(
+            config,
+            "hidden_size",
+            getattr(config, "d_model", getattr(config, "n_embd", None)),
+        )
+        num_experts = getattr(
+            config, "num_local_experts", getattr(config, "num_experts", None)
+        )
+        intermediate_size = getattr(
+            config,
+            "intermediate_size",
+            getattr(config, "ffn_dim", getattr(config, "n_inner", None)),
+        )
+        num_experts_per_tok = getattr(
+            config, "num_experts_per_tok", getattr(config, "num_experts_to_select", 1)
+        )
+
+        if hidden_size:
+            print(f"input_dim = {hidden_size} # Input dimension")
+        if intermediate_size:
+            print(f"hidden_dim = {intermediate_size} # Expert hidden dimension")
+        if hidden_size:
+            print(f"output_dim = {hidden_size} # Output dimension")
+        if num_experts:
+            print(f"num_experts = {num_experts} # Number of experts")
+        if num_experts_per_tok:
+            print(
+                f"k = {num_experts_per_tok} # Top-k (number of experts activated per token)"
+            )
+        print(f"batch_size = 1 # Batch size (adjustable as needed)")
+        print(
+            f"seq_len = 128 # Sequence length (adjustable as needed, actual model may support longer)"
+        )
+
+    except Exception as e:
+        print(f"Error: {e}")
+        import traceback
+
+        traceback.print_exc()
+        print(
+            "\nHint: Make sure transformers library is installed: pip install transformers"
+        )
+        print(
+            "If the model requires trust_remote_code=True, make sure to trust the model"
+        )
+
+
+if __name__ == "__main__":
+    if len(sys.argv) < 2:
+        print("Usage: python check_llm_config.py <model_name>")  # module has no docstring, so print(__doc__) emitted "None"
+        print("\n[Common MoE Model Examples]")
+        print("1. DeepSeek MoE:")
+        print("   python check_llm_config.py deepseek-ai/deepseek-moe-16b-base")
+        print("\n2. Mixtral:")
+        print("   python check_llm_config.py mistralai/Mixtral-8x7B-v0.1")
+        print("\n3. Switch Transformer:")
+        print("   python check_llm_config.py google/switch-base-8")
+        print("\n4. GShard:")
+        print("   python check_llm_config.py google/gshard-gpt")
+        sys.exit(1)
+
+    model_name = sys.argv[1]
+    print_moe_config(model_name)
diff --git a/ece6775/llm_config/llm_config.py b/ece6775/llm_config/llm_config.py
new file mode 100644
index 000000000..24bcfd1a4
--- /dev/null
+++ b/ece6775/llm_config/llm_config.py
@@ -0,0 +1,143 @@
+# Copyright Allo authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
# LLM model config - shared for MoE and Attention+MoE

# config modes: switch_base_8*, mixtral_8x7b*, deepseek*, custom
DEFAULT_CONFIG_MODE = "switch_base_8_scaled_1_8"

# Preset table of model configurations.
# Each entry holds the MoE hyper-parameters used by both the standalone MoE
# and the Attention+MoE implementations. Keeping them in one mapping (instead
# of a long if/elif chain) makes it easy to see all presets at a glance and
# to add new ones.
_MOE_CONFIGS = {
    # Google Switch-Base-8 (original)
    "switch_base_8": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 768,  # hidden_size
        "output_dim": 768,  # hidden_size
        "num_experts": 8,
        "k": 1,  # Top-1 MoE
        "hidden_dim": 2048,
    },
    # scaled to 2/3
    "switch_base_8_scaled_2_3": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 512,  # 768 * 2/3
        "output_dim": 512,
        "num_experts": 2,
        "k": 1,  # Top-1 MoE
        "hidden_dim": 1024,
    },
    # scaled to 1/2
    "switch_base_8_scaled_1_2": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 384,  # 768 * 1/2
        "output_dim": 384,
        "num_experts": 2,
        "k": 1,  # Top-1 MoE
        "hidden_dim": 1024,
    },
    # scaled to 1/4
    "switch_base_8_scaled_1_4": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 192,  # 768 * 1/4
        "output_dim": 192,
        "num_experts": 2,
        "k": 1,  # Top-1 MoE
        "hidden_dim": 512,
    },
    # scaled to 1/8
    "switch_base_8_scaled_1_8": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 96,  # 768 * 1/8
        "output_dim": 96,
        "num_experts": 2,  # changed to 2 for testing (original: 8)
        "k": 1,
        "hidden_dim": 256,
    },
    # Mixtral-8x7B (original)
    "mixtral_8x7b": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 4096,  # hidden_size
        "output_dim": 4096,  # hidden_size
        "num_experts": 8,
        "k": 2,  # Top-2 MoE
        "hidden_dim": 14336,
    },
    # scaled to 1/8 (Top-2 MoE)
    "mixtral_8x7b_scaled_1_8": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 512,  # 4096 / 8
        "output_dim": 512,
        "num_experts": 8,
        "k": 2,  # Top-2 MoE (real Mixtral uses Top-2)
        "hidden_dim": 2048,
    },
    # DeepSeek MoE-16B (needs significant resources)
    "deepseek": {
        "batch_size": 1,
        "seq_len": 128,  # Can be adjusted, model supports up to 4096
        "input_dim": 2048,  # hidden_size
        "output_dim": 2048,  # hidden_size
        "num_experts": 64,
        "k": 6,
        "hidden_dim": 10944,
    },
    # scaled version
    "deepseek_scaled": {
        "batch_size": 1,
        "seq_len": 128,
        "input_dim": 512,  # 2048 / 4
        "output_dim": 512,
        "num_experts": 8,  # 64 / 8
        "k": 1,  # 6 -> 1 (current implementation supports k=1)
        "hidden_dim": 2048,
    },
    # small test config
    "custom": {
        "batch_size": 4,
        "seq_len": 4,
        "input_dim": 64,
        "output_dim": 64,
        "num_experts": 2,
        "k": 1,  # Top-1 MoE
        "hidden_dim": 256,
    },
}


def get_moe_config(config_mode=None):
    """Return the MoE configuration dict for *config_mode*.

    Works for both the standalone MoE and the Attention+MoE implementations.

    Args:
        config_mode: One of the keys in ``_MOE_CONFIGS`` (see module header).
            ``None`` selects ``DEFAULT_CONFIG_MODE``. Any unrecognized mode
            falls back to the small ``"custom"`` test configuration, matching
            the original else-branch behavior.

    Returns:
        A fresh dict with keys ``batch_size``, ``seq_len``, ``input_dim``,
        ``output_dim``, ``num_experts``, ``k``, ``hidden_dim``. A copy is
        returned on every call so callers may mutate it without corrupting
        the shared preset table.
    """
    if config_mode is None:
        config_mode = DEFAULT_CONFIG_MODE
    preset = _MOE_CONFIGS.get(config_mode, _MOE_CONFIGS["custom"])
    # Shallow copy: all values are scalars, so this fully isolates the caller.
    return dict(preset)


def print_config_info(config_mode, config):
    """Pretty-print a configuration dict produced by :func:`get_moe_config`."""
    print("=" * 60)
    print("LLM Configuration (MoE)")
    print("=" * 60)
    print(f"Model: {config_mode}")
    print(f"Batch size: {config['batch_size']}")
    print(f"Sequence length: {config['seq_len']}")
    print(f"Input dimension: {config['input_dim']}")
    print(f"Output dimension: {config['output_dim']}")
    print(f"Number of experts: {config['num_experts']}")
    print(f"Top-k: {config['k']}")
    print(f"Hidden dimension: {config['hidden_dim']}")
    print("=" * 60)


# --- next file in patch: ece6775/moe/allo_moe_alt.py (new file, 510 lines) ---
# Copyright Allo authors. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0 + +# MoE with fused GeLU optimization for HLS +# fuses GeLU into FC1 to save memory and enable pipelining +# small diff vs pytorch (~1e-4) from accumulation order + +import numpy as np +import allo +from allo.ir.types import float32, int32 +from allo import dsl + +MODE = "csyn" # llvm, sw_emu, hw_emu, hw, csyn + +import sys +import os + +# path setup for config +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE + + +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # stable softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1e12 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="normalize"): + Z[n, k] = E_exp[n, k] / S[n] + + return Z + + +# top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1e12 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0 and logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + + +# FFN expert with fused GeLU, row-by-row for pipelining +def expert[ + Ty, N, D_in, D_hidden, D_out +]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]", +) -> "Ty[N, D_out]": + output: Ty[N, D_out] = 0.0 + + for n in range(N, 
name="row_loop"): + hidden_row: Ty[D_hidden] = 0.0 + + # FC1 + GeLU together + for h in range(D_hidden, name="fc1_gelu_loop"): + acc: Ty = 0.0 + for k in range(D_in, name="fc1_reduce"): + acc += x[n, k] * fc1_weight[h, k] + fc1_val: Ty = acc + fc1_bias[h] + + # GeLU: tanh approximation + x3: Ty = fc1_val * fc1_val * fc1_val + inner: Ty = 0.7978845608028654 * ( + fc1_val + 0.044715 * x3 + ) # sqrt(2/pi) ≈ 0.7978... + hidden_row[h] = 0.5 * fc1_val * (1.0 + dsl.tanh(inner)) + + # FC2 + for o in range(D_out, name="fc2_loop"): + acc2: Ty = 0.0 + for k in range(D_hidden, name="fc2_reduce"): + acc2 += hidden_row[k] * fc2_weight[o, k] + output[n, o] = acc2 + fc2_bias[o] + + return output + + +# Top-1 MoE: pick one expert per token, weight with softmax +def moe_layer[ + Ty, N, D_in, D_out, E, K, D_hidden +]( + x: "Ty[N, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]", +) -> "Ty[N, D_out]": + # compute gate scores + gate_logits: Ty[N, E] = 0.0 + for n, e in dsl.grid(N, E, name="gate_linear"): + acc: Ty = 0.0 + for k in range(D_in, name="gate_reduce"): + acc += x[n, k] * gate_weight[e, k] + gate_logits[n, e] = acc + gate_bias[e] + + # pick best expert + top1_idx: int32[N] = top1_select[Ty, N, E](gate_logits) + + # get top-k logits (for k=1) + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_idx[n] if k == 0 else 0 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) + + # sparse weights + gate_w: Ty[N, E] = 0.0 + for n in range(N, name="sparse_gate"): + gate_w[n, top1_idx[n]] = top_k_weights[n, 0] + + # run all experts (could optimize to skip unused ones later) + expert_out: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # extract weights for this expert + fc1_w: Ty[D_hidden, 
D_in] = 0.0 + fc1_b: Ty[D_hidden] = 0.0 + fc2_w: Ty[D_out, D_hidden] = 0.0 + fc2_b: Ty[D_out] = 0.0 + + for i, j in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + fc1_w[i, j] = expert_fc1_weights[e, i, j] + for i in range(D_hidden, name="extract_fc1_b"): + fc1_b[i] = expert_fc1_biases[e, i] + for i, j in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + fc2_w[i, j] = expert_fc2_weights[e, i, j] + for i in range(D_out, name="extract_fc2_b"): + fc2_b[i] = expert_fc2_biases[e, i] + + out = expert[Ty, N, D_in, D_hidden, D_out](x, fc1_w, fc1_b, fc2_w, fc2_b) + + for n, d in dsl.grid(N, D_out, name="store_expert_out"): + expert_out[e, n, d] = out[n, d] + + # combine with weights + output: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + output[n, d_out] += expert_out[e, n, d_out] * gate_w[n, e] + + return output + + +# build schedule with pipelining/unrolling optimizations +def optimize_moe_schedule( + num_tokens, input_dim, output_dim, num_experts, k, hidden_dim +): + Ty = float32 + N = num_tokens + D_in = input_dim + D_out = output_dim + E = num_experts + K = k + D_hidden = hidden_dim + + print("Building MoE schedule...") + + # top1_select + s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) + s_top1.pipeline("top1_select:e") + + # softmax + s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) + sm_loops = s_softmax.get_loops(s_softmax.top_func_name) + s_softmax.pipeline(sm_loops["row_max"]["k"]) + s_softmax.pipeline(sm_loops["exp_sum"]["k"]) + s_softmax.pipeline(sm_loops["normalize"]["k"]) + + # expert + s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) + exp_loops = s_expert.get_loops(s_expert.top_func_name) + row_loop = exp_loops["row_loop"] + s_expert.pipeline(row_loop["h"]) + s_expert.pipeline(row_loop["o"]) + + # unroll reduction loop for parallel MACs + unroll_factor = min(4, D_in, D_hidden) + s_expert.unroll(row_loop["k"], factor=unroll_factor) + print(f" - Applied 
unroll to k (reduction loop) factor={unroll_factor}") + + # array partitioning + if D_hidden <= 32: + s_expert.partition(s_expert.hidden_row, dim=0) + print(f" - Applied complete partition to hidden_row (D_hidden={D_hidden})") + else: + s_expert.partition(s_expert.hidden_row, dim=0, factor=4) + print(f" - Applied cyclic partition to hidden_row (factor=4)") + + print( + " - Created expert schedule with pipeline, unroll, and partition optimizations" + ) + + # moe_layer schedule + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize(moe_layer, instantiate=[Ty, N, D_in, D_out, E, K, D_hidden]) + + moe_loops = s_moe.get_loops(s_moe.top_func_name) + print(f" - Available top-level loops: {list(moe_loops.loops.keys())}") + + # optimize gate_linear + gate_linear_loop = moe_loops["gate_linear"] + print(f" - gate_linear sub-loops: {list(gate_linear_loop.loops.keys())}") + + s_moe.pipeline(gate_linear_loop["e"]) + print(" - Applied pipeline to gate_linear:e") + + unroll_factor_gate = min(4, D_in) + s_moe.unroll(gate_linear_loop["k"], factor=unroll_factor_gate) + print(f" - Applied unroll to gate_linear:k (factor={unroll_factor_gate})") + + # other loops + combine_loop = moe_loops["combine_outputs"] + s_moe.pipeline(combine_loop["d_out"]) + + topk_loop = moe_loops["topk_logits"] + s_moe.pipeline(topk_loop["k"]) + + sparse_loop = moe_loops["sparse_gate"] + s_moe.pipeline(sparse_loop["n"]) + + print(" - Applied pipeline to combine_outputs:d_out, topk_logits:k, sparse_gate:n") + + # compose sub-schedules + s_moe.compose(s_top1) + s_moe.compose(s_softmax) + s_moe.compose(s_expert) + print(" - Composed top1_select, softmax_1d, and expert schedules") + + print("\n" + "=" * 60) + print("Schedule done!") + print("=" * 60) + print("Optimizations:") + print(" 1. Fused GeLU into FC1") + print(" 2. Row-level structure") + print(" 3. 
Unrolled reduction loops") + print(f" - Expert k: factor {unroll_factor}") + print(f" - Gate k: factor {unroll_factor_gate}") + print(" 4. Array partitioning for hidden_row") + print(" 5. Pipeline on h, o, e loops") + print("=" * 60) + + return s_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as nn + from pytorch_moe import MoELayer + + moe_config = get_moe_config(CONFIG_MODE) + + batch_size = moe_config["batch_size"] + seq_len = moe_config["seq_len"] + input_dim = moe_config["input_dim"] + output_dim = moe_config["output_dim"] + num_experts = moe_config["num_experts"] + k = moe_config["k"] + hidden_dim = moe_config["hidden_dim"] + seed = 42 + + print("=" * 60) + print("MoE Allo (Vayun Optimized) vs PyTorch Comparison Test") + print("=" * 60) + print(f"Configuration Mode: {CONFIG_MODE}") + print(f"Configuration:") + print(f" batch_size={batch_size}, seq_len={seq_len}") + print(f" input_dim={input_dim}, output_dim={output_dim}") + print(f" num_experts={num_experts}, k={k}, hidden_dim={hidden_dim}") + print(f" Seed: {seed}") + print("=" * 60) + + # pytorch baseline + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform + for param in pytorch_moe.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # extract weights + 
print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] + gate_bias_pt = pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack( + [exp.fc1.weight.data for exp in pytorch_moe.experts] + ) + expert_fc1_biases_pt = torch.stack( + [exp.fc1.bias.data for exp in pytorch_moe.experts] + ) + expert_fc2_weights_pt = torch.stack( + [exp.fc2.weight.data for exp in pytorch_moe.experts] + ) + expert_fc2_biases_pt = torch.stack( + [exp.fc2.bias.data for exp in pytorch_moe.experts] + ) + + # convert to numpy + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + + print(f"Input shape: {x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # flatten input (N = B * L) + num_tokens = batch_size * seq_len + x_flat_np = x_np.reshape(num_tokens, input_dim) + print(f"Flattened input shape: {x_flat_np.shape}") + + # run allo + print("\n[4] Running Allo implementation...") + 
try: + # Create optimized schedule + allo_schedule = optimize_moe_schedule( + num_tokens, input_dim, output_dim, num_experts, k, hidden_dim + ) + + # Generate project name + project_name = f"allo_moe_alt_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) + elif MODE == "hw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) + elif MODE == "hw": + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) + elif MODE == "csyn": + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output_flat = mod( + x_flat_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + ) + elif MODE in ["sw_emu", "hw_emu", "hw"]: + allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) + mod( + x_flat_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output_flat, + ) + elif MODE == "csyn": + allo_output_flat = np.zeros((num_tokens, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Reshape output back to [B, L, D_out] + allo_output = allo_output_flat.reshape(batch_size, seq_len, output_dim) + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute 
differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # Check if outputs are close + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) + else: + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/moe/allo_moe_base.py b/ece6775/moe/allo_moe_base.py new file mode 100644 index 000000000..eba936f99 --- /dev/null +++ b/ece6775/moe/allo_moe_base.py @@ -0,0 +1,566 @@ +# Copyright Allo authors. All Rights Reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +# MoE in Allo - manual implementation (no nn lib) +# matches pytorch but with small numerical differences (1e-4 to 1e-3) from: +# - different accumulation order +# - GELU approximation differences +# - fp precision +# these are normal and don't affect performance + +import allo +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE + + +def softmax_1d[Ty, N, E](X: "Ty[N, E]") -> "Ty[N, E]": + # softmax over last dim + Z: Ty[N, E] + E_exp: Ty[N, E] + M: Ty[N] = -1000000000000.0 # TODO: use -inf if available + S: Ty[N] = 0.0 + + # find max per row + for i in range(N): + for j in range(E): + if X[i, j] > M[i]: + M[i] = X[i, j] + + # exp and sum + for i in range(N): + for j in range(E): + E_exp[i, j] = dsl.exp(X[i, j] - M[i]) + S[i] += E_exp[i, j] + + # normalize + for i in range(N): + for j in range(E): + Z[i, j] = E_exp[i, j] / S[i] + + return Z + + +# top-1 selection (argmax) +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert for each token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for i in range(N): + indices[i] = 0 + max_val[i] = logits[i, 0] + # check rest + for j in range(1, E): + if logits[i, j] > max_val[i]: + max_val[i] = logits[i, j] + indices[i] = j + + return indices + + +# TopKGate: Gate module to select top-k experts +def topk_gate[ + Ty, N, D, E, K +](x: "Ty[N, D]", gate_weight: "Ty[E, D]", gate_bias: "Ty[E]") -> ( + "Ty[N, E]", + "int32[N, K]", +): + """ + Gate module to select top-k experts for routing. 
+ + Args: + x: Input tensor of shape [N, D] where N = batch * seq_len + gate_weight: Gate weight matrix of shape [E, D] + gate_bias: Gate bias vector of shape [E] + Returns: + full_weights: Sparse weight matrix of shape [N, E] + top_k_indices: Indices of selected experts of shape [N, K] + """ + # Compute logits using linear layer (manual implementation) + # Computes: x @ gate_weight^T + gate_bias = [N, D] @ [D, E] + [E] = [N, E] + # Note: In PyTorch version, bias=False, so gate_bias should be zeros + logits: Ty[N, E] = 0.0 + for i in range(N): + for j in range(E): + # Initialize with bias + logits[i, j] = gate_bias[j] + # Matrix multiplication: x[i] @ gate_weight[j]^T + for k in range(D): + logits[i, j] += x[i, k] * gate_weight[j, k] + + # For k=1, use top1_select + top1_indices_1d: int32[N] = top1_select[Ty, N, E](logits) + + # Expand to [N, K] format (for k=1, K=1) + top_k_indices: int32[N, K] = 0 + for i in range(N): + top_k_indices[i, 0] = top1_indices_1d[i] + + # Get top-k logits + top_k_logits: Ty[N, K] = 0.0 + for i in range(N): + for k in range(K): + expert_idx = top_k_indices[i, k] + top_k_logits[i, k] = logits[i, expert_idx] + + # Apply softmax to top-k logits + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # Create sparse weight matrix + full_weights: Ty[N, E] = 0.0 + for i in range(N): + for k in range(K): + expert_idx = top_k_indices[i, k] + full_weights[i, expert_idx] = top_k_weights[i, k] + + return full_weights, top_k_indices + + +# simple FFN expert +def expert[ + Ty, N, D_in, D_hidden, D_out +]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]", +) -> "Ty[N, D_out]": + # FC1 + fc1_out: Ty[N, D_hidden] = 0.0 + for i in range(N): + for j in range(D_hidden): + fc1_out[i, j] = fc1_bias[j] + for k in range(D_in): + fc1_out[i, j] += x[i, k] * fc1_weight[j, k] + + # GELU + gelu_out: Ty[N, D_hidden] = 0.0 + for i in range(N): + for j in 
range(D_hidden): + x_val: Ty = fc1_out[i, j] + x3: Ty = x_val * x_val * x_val + inner: Ty = 0.797885 * (x_val + 0.044715 * x3) + tanh_term: Ty = dsl.tanh(inner) + gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) + + # FC2 + fc2_out: Ty[N, D_out] = 0.0 + for i in range(N): + for j in range(D_out): + fc2_out[i, j] = fc2_bias[j] + for k in range(D_hidden): + fc2_out[i, j] += gelu_out[i, k] * fc2_weight[j, k] + + return fc2_out + + +# main MoE layer +def moe_layer[ + Ty, B, L, D_in, D_out, E, K, D_hidden +]( + x: "Ty[B, L, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]", +) -> "Ty[B, L, D_out]": + # process all tokens through all experts, then weight by gate + # matches pytorch behavior + gate_logits: Ty[B * L, E] = 0.0 + + # Compute gate logits for all tokens + for b in range(B): + for l in range(L): + token_idx = b * L + l + # Compute logits for this token: x[b, l] @ gate_weight^T + gate_bias + for j in range(E): + # Initialize with bias + gate_logits[token_idx, j] = gate_bias[j] + # Matrix multiplication: x[b, l] @ gate_weight[j]^T + for k in range(D_in): + gate_logits[token_idx, j] += x[b, l, k] * gate_weight[j, k] + + # Select top-k experts and compute weights + gate_weights: Ty[B * L, E] = 0.0 + top_k_indices: int32[B * L, K] = 0 + + for i in range(B * L): + # Find top-1 expert (argmax) for this token + top_expert_idx: int32 = 0 + max_logit: Ty = gate_logits[i, 0] + + for j in range(1, E): + if gate_logits[i, j] > max_logit: + max_logit = gate_logits[i, j] + top_expert_idx = j + + # Store top-k indices (for k=1) + top_k_indices[i, 0] = top_expert_idx + + # Get top-k logit and apply softmax + # For k=1, softmax just returns 1.0, but we compute it for consistency + # softmax(top_k_logit) = exp(top_k_logit) / exp(top_k_logit) = 1.0 + gate_weights[i, top_expert_idx] = 1.0 + + # Flatten 
input for expert processing + x_flat: Ty[B * L, D_in] = 0.0 + for b in range(B): + for l in range(L): + for d in range(D_in): + x_flat[b * L + l, d] = x[b, l, d] + + # Process each expert: compute outputs for all tokens + # Then use gate weights to select correct outputs + expert_outputs: Ty[E, B * L, D_out] = 0.0 + + for expert_idx in range(E): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for h in range(D_hidden): + for d in range(D_in): + expert_fc1_w[h, d] = expert_fc1_weights[expert_idx, h, d] + expert_fc1_b[h] = expert_fc1_biases[expert_idx, h] + + for o in range(D_out): + for h in range(D_hidden): + expert_fc2_w[o, h] = expert_fc2_weights[expert_idx, o, h] + expert_fc2_b[o] = expert_fc2_biases[expert_idx, o] + + # Process all tokens through this expert (inline expert function) + # First linear layer: fc1_out = x_flat @ expert_fc1_w^T + expert_fc1_b + fc1_out: Ty[B * L, D_hidden] = 0.0 + for i in range(B * L): + for j in range(D_hidden): + # Initialize with bias + fc1_out[i, j] = expert_fc1_b[j] + # Matrix multiplication: x_flat[i] @ expert_fc1_w[j]^T + for k in range(D_in): + fc1_out[i, j] += x_flat[i, k] * expert_fc1_w[j, k] + + # GELU activation: 0.5 * x * (1 + tanh(0.797885 * (x + 0.044715 * x^3))) + gelu_out: Ty[B * L, D_hidden] = 0.0 + for i in range(B * L): + for j in range(D_hidden): + x_val: Ty = fc1_out[i, j] + # Compute x^3 + x3: Ty = x_val * x_val * x_val + # Inner term: sqrt(2/pi) * (x + 0.044715 * x^3) + # Use exact value: sqrt(2/pi) ≈ 0.7978845608028654 + inner: Ty = 0.7978845608028654 * (x_val + 0.044715 * x3) + # Tanh + tanh_term: Ty = dsl.tanh(inner) + # GELU: 0.5 * x * (1 + tanh_term) + gelu_out[i, j] = 0.5 * x_val * (1.0 + tanh_term) + + # Second linear layer: fc2_out = gelu_out @ expert_fc2_w^T + expert_fc2_b + expert_out: Ty[B * L, D_out] = 0.0 + for i in range(B * L): + for j in 
range(D_out): + # Initialize with bias + expert_out[i, j] = expert_fc2_b[j] + # Matrix multiplication: gelu_out[i] @ expert_fc2_w[j]^T + for k in range(D_hidden): + expert_out[i, j] += gelu_out[i, k] * expert_fc2_w[j, k] + + # Store expert outputs + for i in range(B * L): + for o in range(D_out): + expert_outputs[expert_idx, i, o] = expert_out[i, o] + + # Combine expert outputs using gate weights + # For each token, sum over experts weighted by gate_weights + output_flat: Ty[B * L, D_out] = 0.0 + for i in range(B * L): + for expert_idx in range(E): + weight: Ty = gate_weights[i, expert_idx] + for o in range(D_out): + output_flat[i, o] += expert_outputs[expert_idx, i, o] * weight + + # Reshape back to original shape + output: Ty[B, L, D_out] = 0.0 + for b in range(B): + for l in range(L): + for o in range(D_out): + output[b, l, o] = output_flat[b * L + l, o] + + return output + + +# test/compare with pytorch +if __name__ == "__main__": + import numpy as np + import torch + import torch.nn as nn + from pytorch_moe import MoELayer, run_moe_inference + + config = get_moe_config(CONFIG_MODE) + batch_size = config["batch_size"] + seq_len = config["seq_len"] + input_dim = config["input_dim"] + output_dim = config["output_dim"] + num_experts = config["num_experts"] + k = config["k"] + hidden_dim = config["hidden_dim"] + + seed = 42 + + # Print configuration info using shared function + print("=" * 60) + print("MoE Allo vs PyTorch Comparison Test") + print("=" * 60) + print_config_info(CONFIG_MODE, config) + print(f"Seed: {seed}") + print("=" * 60) + + # pytorch baseline + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform (matching PyTorch) + for param in pytorch_moe.parameters(): + if param.dim() > 1: + 
nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # extract weights + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] + gate_bias_pt = pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack( + [expert.fc1.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack( + [expert.fc1.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack( + [expert.fc2.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack( + [expert.fc2.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim] + + # convert to numpy + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, 
hidden_dim, input_dim] (same) + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + + print(f"Input shape: {x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # run allo + print("\n[4] Running Allo implementation...") + try: + # Customize Allo module + allo_mod = allo.customize( + moe_layer, + instantiate=[ + float32, + batch_size, + seq_len, + input_dim, + output_dim, + num_experts, + k, + hidden_dim, + ], + ) + + # Generate project name based on CONFIG_MODE to avoid conflicts + # This ensures different configurations use different build folders + project_name = f"allo_moe_base_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("Building Allo module...") + if MODE == "llvm": + mod = allo_mod.build(target="llvm") # Use LLVM for CPU testing + elif MODE == "sw_emu": + mod = allo_mod.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) + elif MODE == "hw_emu": + mod = allo_mod.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) + elif MODE == "hw": + mod = allo_mod.build(target="vitis_hls", mode="hw", project=project_name) + elif MODE == "csyn": + mod = allo_mod.build(target="vitis_hls", mode="csyn", project=project_name) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + # Note: In LLVM backend, functions with return values return the result directly + # We don't need to pass output buffer as argument + print("Running Allo 
inference...") + if MODE == "llvm": + allo_output = mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + ) + elif MODE == "sw_emu": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) + elif MODE == "hw_emu": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) + elif MODE == "hw": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[5] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + # Compute differences + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # check if close (1e-4 to 1e-3 diff is normal from accumulation order, GELU approx, fp precision) + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print( 
+ f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) + else: + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[6] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/moe/allo_moe_lib.py b/ece6775/moe/allo_moe_lib.py new file mode 100644 index 000000000..e433c2498 --- /dev/null +++ b/ece6775/moe/allo_moe_lib.py @@ -0,0 +1,511 @@ +# Copyright Allo authors. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +# MoE using Allo library functions (nn.linear2d, nn.GeLU) +# follows test_bert pattern - all functions in one file with generic types + +import numpy as np +import allo +import allo.library.nn as nn +from allo.library.nn import linear2d, GeLU # Direct import for type inference +from allo.ir.types import float32, int32 +from allo import dsl +import sys +import os + +# path setup +current_dir = os.path.dirname(os.path.abspath(__file__)) +project_root = os.path.dirname(current_dir) +llm_config_dir = os.path.join(project_root, "llm_config") +if llm_config_dir not in sys.path: + sys.path.insert(0, llm_config_dir) + +MODE = "sw_emu" # llvm, sw_emu, hw_emu, hw, csyn + +from llm_config import DEFAULT_CONFIG_MODE, get_moe_config, print_config_info + +CONFIG_MODE = DEFAULT_CONFIG_MODE + + +def softmax_1d[Ty, N, K](X: "Ty[N, K]") -> "Ty[N, K]": + # softmax over last dim + Z: Ty[N, K] + E_exp: Ty[N, K] + M: Ty[N] = -1000000000000.0 + S: Ty[N] = 0.0 + + # find max per row + for n, k in dsl.grid(N, K, name="row_max"): + if X[n, k] > M[n]: + M[n] = X[n, k] + + # exp and sum + for n, k in dsl.grid(N, 
K, name="exp_sum"): + E_exp[n, k] = dsl.exp(X[n, k] - M[n]) + S[n] += E_exp[n, k] + + # normalize + for n, k in dsl.grid(N, K, name="update"): + Z[n, k] = E_exp[n, k] / (S[n]) + + return Z + + +# top-1 selection +def top1_select[Ty, N, E](logits: "Ty[N, E]") -> "int32[N]": + # pick best expert per token + indices: int32[N] + max_val: Ty[N] = -1000000000000.0 + + for n in range(N, name="init"): + indices[n] = 0 + max_val[n] = logits[n, 0] + + for n, e in dsl.grid(N, E, name="argmax"): + if e > 0: # skip e=0 + if logits[n, e] > max_val[n]: + max_val[n] = logits[n, e] + indices[n] = e + + return indices + + +# FFN expert using library functions +def expert[ + Ty, N, D_in, D_hidden, D_out +]( + x: "Ty[N, D_in]", + fc1_weight: "Ty[D_hidden, D_in]", + fc1_bias: "Ty[D_hidden]", + fc2_weight: "Ty[D_out, D_hidden]", + fc2_bias: "Ty[D_out]", +) -> "Ty[N, D_out]": + # using linear2d and GeLU from library (imported at top) + fc1_out = linear2d[Ty, Ty, Ty, N, D_hidden, D_in](x, fc1_weight, fc1_bias) + gelu_out = GeLU[Ty, N, D_hidden](fc1_out) + fc2_out = linear2d[Ty, Ty, Ty, N, D_out, D_hidden](gelu_out, fc2_weight, fc2_bias) + return fc2_out + + +# main MoE layer +def moe_layer[ + Ty, B, L, D_in, D_out, E, K, D_hidden +]( + x: "Ty[B, L, D_in]", + gate_weight: "Ty[E, D_in]", + gate_bias: "Ty[E]", + expert_fc1_weights: "Ty[E, D_hidden, D_in]", + expert_fc1_biases: "Ty[E, D_hidden]", + expert_fc2_weights: "Ty[E, D_out, D_hidden]", + expert_fc2_biases: "Ty[E, D_out]", +) -> "Ty[B, L, D_out]": + # Flatten batch and sequence dimensions: N = B * L + N = B * L + x_flat: Ty[N, D_in] = 0.0 + for b, l, d_in in dsl.grid(B, L, D_in, name="flatten"): + x_flat[b * L + l, d_in] = x[b, l, d_in] + + # Step 1: Compute gate logits using linear2d (inlined topk_gate logic) + # Note: Use linear2d (direct import) to avoid conflict with torch.nn + gate_logits = linear2d[Ty, Ty, Ty, N, E, D_in]( + x_flat, gate_weight, gate_bias + ) # [N, E] + + # Step 2: Select top-1 expert using top1_select function + 
top1_indices_1d: int32[N] = top1_select[Ty, N, E](gate_logits) + + # Step 3: Get top-k logits and apply softmax + top_k_logits: Ty[N, K] = 0.0 + for n, k in dsl.grid(N, K, name="topk_logits"): + expert_idx = top1_indices_1d[n] if k == 0 else 0 # For k=1, K=1 + top_k_logits[n, k] = gate_logits[n, expert_idx] + + # Step 4: Apply softmax to top-k logits using softmax_1d function + top_k_weights = softmax_1d[Ty, N, K](top_k_logits) # [N, K] + + # Step 5: Create sparse weight matrix from top-k weights + top_k_indices: int32[N, K] = 0 + gate_weights: Ty[N, E] = 0.0 + + for n in range(N, name="gate_weights"): + # Store top-k indices (for k=1) + top_k_indices[n, 0] = top1_indices_1d[n] + # Set gate weights from softmax output + expert_idx = top1_indices_1d[n] + gate_weights[n, expert_idx] = top_k_weights[n, 0] # For k=1, K=1 + + # Step 2: Process each expert: compute outputs for all tokens + expert_outputs: Ty[E, N, D_out] = 0.0 + + for e in range(E, name="expert_loop"): + # Extract expert weights for this expert + expert_fc1_w: Ty[D_hidden, D_in] = 0.0 + expert_fc1_b: Ty[D_hidden] = 0.0 + expert_fc2_w: Ty[D_out, D_hidden] = 0.0 + expert_fc2_b: Ty[D_out] = 0.0 + + for d_hidden, d_in in dsl.grid(D_hidden, D_in, name="extract_fc1_w"): + expert_fc1_w[d_hidden, d_in] = expert_fc1_weights[e, d_hidden, d_in] + + for d_hidden in range(D_hidden, name="extract_fc1_b"): + expert_fc1_b[d_hidden] = expert_fc1_biases[e, d_hidden] + + for d_out, d_hidden in dsl.grid(D_out, D_hidden, name="extract_fc2_w"): + expert_fc2_w[d_out, d_hidden] = expert_fc2_weights[e, d_out, d_hidden] + + for d_out in range(D_out, name="extract_fc2_b"): + expert_fc2_b[d_out] = expert_fc2_biases[e, d_out] + + # Process all tokens through this expert using the expert function + expert_out = expert[Ty, N, D_in, D_hidden, D_out]( + x_flat, expert_fc1_w, expert_fc1_b, expert_fc2_w, expert_fc2_b + ) # [N, D_out] + + # Store expert outputs + for n, d_out in dsl.grid(N, D_out, name="store_expert_out"): + 
expert_outputs[e, n, d_out] = expert_out[n, d_out] + + # Step 3: Combine expert outputs using gate weights + output_flat: Ty[N, D_out] = 0.0 + for n, e, d_out in dsl.grid(N, E, D_out, name="combine_outputs"): + weight: Ty = gate_weights[n, e] + output_flat[n, d_out] += expert_outputs[e, n, d_out] * weight + + # Step 4: Reshape back to original shape + output: Ty[B, L, D_out] = 0.0 + for b, l, d_out in dsl.grid(B, L, D_out, name="reshape"): + output[b, l, d_out] = output_flat[b * L + l, d_out] + + return output + + +# ================================================================================== +# Schedule optimization function +# ================================================================================== +def optimize_moe_with_composition( + batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim +): + # create schedules for sub-functions and compose them + # TODO: add optimizations later (pipeline, partition, etc) + Ty = float32 + N = batch_size * seq_len + D_in = input_dim + E = num_experts + K = k + D_hidden = hidden_dim + D_out = output_dim + + print("=" * 60) + print("Creating and optimizing MoE schedules...") + print("=" * 60) + + # top1_select schedule + print("\n[1] Creating schedule for top1_select...") + s_top1 = allo.customize(top1_select, instantiate=[Ty, N, E]) + print(" - Created top1_select schedule for [N, E]") + + # softmax schedule + print("\n[2] Creating schedule for softmax_1d...") + s_softmax = allo.customize(softmax_1d, instantiate=[Ty, N, K]) + print(" - Created softmax_1d schedule for [N, K]") + + # expert schedule - need to create library function schedules first (type inference workaround) + print("\n[3] Creating schedule for expert...") + import allo.library.nn as allo_nn + + print(" - Creating schedules for library functions (linear2d, GeLU)...") + s_linear_fc1 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_hidden, D_in] + ) + s_gelu = allo.customize(allo_nn.GeLU, instantiate=[Ty, N, 
D_hidden]) + s_linear_fc2 = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, D_out, D_hidden] + ) + print(" - Library function schedules created") + + print(" - Creating expert schedule...") + s_expert = allo.customize(expert, instantiate=[Ty, N, D_in, D_hidden, D_out]) + + # compose library functions into expert + s_expert.compose(s_linear_fc1, id="expert_fc1") + s_expert.compose(s_gelu, id="expert_gelu") + s_expert.compose(s_linear_fc2, id="expert_fc2") + print(" - Created expert schedule") + print(" - Composed nn.linear2d and nn.GeLU schedules for expert") + + # moe_layer schedule + print("\n[4] Creating schedule for moe_layer...") + s_moe = allo.customize( + moe_layer, + instantiate=[ + Ty, + batch_size, + seq_len, + input_dim, + output_dim, + num_experts, + k, + hidden_dim, + ], + ) + + # compose everything + print("\n[5] Composing all schedules together...") + + # gate linear2d + s_gate_linear = allo.customize( + allo_nn.linear2d, instantiate=[Ty, Ty, Ty, N, E, D_in] + ) + s_moe.compose(s_gate_linear, id="gate") + print(" - Composed linear2d (gate logits) schedule") + + # top1 and softmax + s_moe.compose(s_top1) + s_moe.compose(s_softmax) + print(" - Composed top1_select and softmax_1d schedules") + + # expert (already has library functions composed) + s_moe.compose(s_expert) + print(" - Composed expert schedule (includes nn.linear2d and nn.GeLU)") + + print("\n" + "=" * 60) + print("Schedule composition complete!") + print("=" * 60) + + return s_moe + + +# test/compare with pytorch +if __name__ == "__main__": + import torch + import torch.nn as nn + import sys + import os + + # Import pytorch_moe from the same directory + from pytorch_moe import MoELayer + + # ============================================================================ + # Configuration parameters - using shared config module + # ============================================================================ + # Get configuration from shared module + config = 
get_moe_config(CONFIG_MODE) + batch_size = config["batch_size"] + seq_len = config["seq_len"] + input_dim = config["input_dim"] + output_dim = config["output_dim"] + num_experts = config["num_experts"] + k = config["k"] + hidden_dim = config["hidden_dim"] + + seed = 42 + + # Print configuration info using shared function + print_config_info(CONFIG_MODE, config) + print(f"Seed: {seed}") + + # ---------------------------------------------------------------------------------- + # Run PyTorch implementation to get weights and outputs + # ---------------------------------------------------------------------------------- + print("\n[1] Running PyTorch implementation...") + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + + # Create PyTorch MoE layer from pytorch_moe.py + pytorch_moe = MoELayer(input_dim, output_dim, num_experts, k, hidden_dim) + pytorch_moe.eval() + + # Initialize with Xavier uniform (matching PyTorch) + for param in pytorch_moe.parameters(): + if param.dim() > 1: + nn.init.xavier_uniform_(param) + else: + nn.init.zeros_(param) + + # Create random input + torch.manual_seed(seed) + pytorch_input = torch.randn(batch_size, seq_len, input_dim) + + # Run PyTorch inference + with torch.no_grad(): + pytorch_output = pytorch_moe(pytorch_input, verbose=False) + + print(f"PyTorch output shape: {pytorch_output.shape}") + print( + f"PyTorch output range: [{pytorch_output.min().item():.6f}, {pytorch_output.max().item():.6f}]" + ) + + # ---------------------------------------------------------------------------------- + # Extract weights and biases from PyTorch model + # ---------------------------------------------------------------------------------- + print("\n[2] Extracting weights from PyTorch model...") + + # Gate weights: [input_dim, num_experts] -> [num_experts, input_dim] for Allo + gate_weight_pt = ( + pytorch_moe.gate.gate_linear.weight.data + ) # [num_experts, input_dim] + gate_bias_pt = 
pytorch_moe.gate.gate_linear.bias + if gate_bias_pt is not None: + gate_bias_pt = gate_bias_pt.data + else: + gate_bias_pt = torch.zeros(num_experts) + + # Expert weights + expert_fc1_weights_pt = torch.stack( + [expert.fc1.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim, input_dim] + expert_fc1_biases_pt = torch.stack( + [expert.fc1.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, hidden_dim] + expert_fc2_weights_pt = torch.stack( + [expert.fc2.weight.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim, hidden_dim] + expert_fc2_biases_pt = torch.stack( + [expert.fc2.bias.data for expert in pytorch_moe.experts] + ) # [num_experts, output_dim] + + # ---------------------------------------------------------------------------------- + # Convert the weights and biases to numpy arrays (ensure C-contiguous and correct shape) + # ---------------------------------------------------------------------------------- + print("\n[3] Converting weights to numpy arrays...") + x_np = np.ascontiguousarray(pytorch_input.detach().numpy(), dtype=np.float32) + + # Gate weights: PyTorch has [num_experts, input_dim], Allo expects [num_experts, input_dim] (same) + gate_weight_np = np.ascontiguousarray( + gate_weight_pt.detach().numpy(), dtype=np.float32 + ) + gate_bias_np = np.ascontiguousarray(gate_bias_pt.detach().numpy(), dtype=np.float32) + + # Expert weights: PyTorch has [num_experts, hidden_dim, input_dim], Allo expects [num_experts, hidden_dim, input_dim] (same) + expert_fc1_weights_np = np.ascontiguousarray( + expert_fc1_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc1_biases_np = np.ascontiguousarray( + expert_fc1_biases_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_weights_np = np.ascontiguousarray( + expert_fc2_weights_pt.detach().numpy(), dtype=np.float32 + ) + expert_fc2_biases_np = np.ascontiguousarray( + expert_fc2_biases_pt.detach().numpy(), dtype=np.float32 + ) + + print(f"Input shape: 
{x_np.shape}") + print(f"Gate weight shape: {gate_weight_np.shape}") + print(f"Expert FC1 weights shape: {expert_fc1_weights_np.shape}") + print(f"Expert FC2 weights shape: {expert_fc2_weights_np.shape}") + + # ---------------------------------------------------------------------------------- + # Run Allo implementation + # ---------------------------------------------------------------------------------- + print("\n[4] Running Allo implementation...") + try: + # Create optimized schedule with composition + allo_schedule = optimize_moe_with_composition( + batch_size, seq_len, input_dim, output_dim, num_experts, k, hidden_dim + ) + + # Generate project name based on CONFIG_MODE to avoid conflicts + # This ensures different configurations use different build folders + project_name = f"allo_moe_lib_{CONFIG_MODE}.prj" + print(f"Using project name: {project_name}") + + # Build module + print("\n[5] Building Allo module...") + if MODE == "llvm": + mod = allo_schedule.build(target="llvm") + elif MODE == "sw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="sw_emu", project=project_name + ) + elif MODE == "hw_emu": + mod = allo_schedule.build( + target="vitis_hls", mode="hw_emu", project=project_name + ) + elif MODE == "hw": + mod = allo_schedule.build( + target="vitis_hls", mode="hw", project=project_name + ) + elif MODE == "csyn": + mod = allo_schedule.build( + target="vitis_hls", mode="csyn", project=project_name + ) + else: + raise ValueError(f"Unsupported mode: {MODE}") + + # Run Allo inference + print("\n[6] Running Allo inference...") + if MODE == "llvm": + allo_output = mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + expert_fc2_weights_np, + expert_fc2_biases_np, + ) + elif MODE == "sw_emu" or MODE == "hw_emu" or MODE == "hw": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod( + x_np, + gate_weight_np, + gate_bias_np, + expert_fc1_weights_np, + expert_fc1_biases_np, + 
expert_fc2_weights_np, + expert_fc2_biases_np, + allo_output, + ) + elif MODE == "csyn": + allo_output = np.zeros((batch_size, seq_len, output_dim), dtype=np.float32) + mod() + else: + raise ValueError(f"Unsupported mode: {MODE}") + + print(f"Allo output shape: {allo_output.shape}") + print(f"Allo output range: [{allo_output.min():.6f}, {allo_output.max():.6f}]") + + # compare + print("\n[7] Comparing outputs...") + pytorch_output_np = pytorch_output.detach().numpy() + + diff = np.abs(allo_output - pytorch_output_np) + mean_diff = np.mean(diff) + max_diff = np.max(diff) + rel_diff = np.mean(diff / (np.abs(pytorch_output_np) + 1e-8)) + + print(f"Mean absolute difference: {mean_diff:.6e}") + print(f"Max absolute difference: {max_diff:.6e}") + print(f"Mean relative difference: {rel_diff:.6e}") + + # check if close (1e-4 to 1e-3 diff is normal) + atol = 5e-4 + rtol = 2e-3 + is_close = np.allclose(allo_output, pytorch_output_np, atol=atol, rtol=rtol) + + if is_close: + print( + f"\n✓ SUCCESS: Allo output matches PyTorch output (atol={atol}, rtol={rtol})" + ) + else: + print( + f"\n✗ WARNING: Allo output differs from PyTorch output (atol={atol}, rtol={rtol})" + ) + print("First few differences:") + print(diff.flatten()[:10]) + + # sample outputs + print("\n[8] Sample outputs (first token, first 5 dimensions):") + print(f"PyTorch: {pytorch_output_np[0, 0, :5]}") + print(f"Allo: {allo_output[0, 0, :5]}") + print(f"Diff: {diff[0, 0, :5]}") + + except Exception as e: + print(f"\n✗ ERROR: Failed to run Allo implementation: {e}") + import traceback + + traceback.print_exc() + + print("=" * 60) diff --git a/ece6775/moe/pytorch_moe.py b/ece6775/moe/pytorch_moe.py new file mode 100644 index 000000000..fea2f445a --- /dev/null +++ b/ece6775/moe/pytorch_moe.py @@ -0,0 +1,366 @@ +# Copyright Allo authors. All Rights Reserved. 
# SPDX-License-Identifier: Apache-2.0
"""
Modern Mixture of Experts (MoE) Implementation for Inference Only

This script implements a simplified MoE layer optimized for inference.
Uses random inputs and weights for demonstration.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class TopKGate(nn.Module):
    """
    Gate module to select top k experts for routing.

    Args:
        input_dim: Input feature dimension
        num_experts: Total number of experts
        k: Number of experts to select per token
    """

    def __init__(self, input_dim: int, num_experts: int, k: int = 1):
        super().__init__()
        self.k = k
        self.num_experts = num_experts
        # Linear layer to compute expert logits
        self.gate_linear = nn.Linear(input_dim, num_experts, bias=False)

    def forward(
        self, x: torch.Tensor, verbose: bool = False
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Forward pass through the gate.

        Args:
            x: Input tensor of shape [batch_size * seq_len, input_dim]
            verbose: Whether to print routing information

        Returns:
            full_weights: Sparse weight matrix [batch_size * seq_len, num_experts]
            top_k_indices: Indices of selected experts [batch_size * seq_len, k]
        """
        # Compute logits for all experts
        logits = self.gate_linear(x)  # [N, num_experts]

        # Select top-k experts
        top_k_logits, top_k_indices = torch.topk(logits, self.k, dim=-1)

        # Apply softmax to top-k logits for normalized weights
        top_k_weights = F.softmax(top_k_logits, dim=-1)  # [N, k]

        # Create sparse weight matrix (zeros for non-selected experts)
        full_weights = torch.zeros_like(logits)
        full_weights.scatter_(1, top_k_indices, top_k_weights)

        # Print routing information if verbose
        if verbose:
            num_tokens = x.shape[0]

            # Count tokens per expert
            expert_counts = {}
            for expert_idx in range(self.num_experts):
                count = (top_k_indices == expert_idx).sum().item()
                if count > 0:
                    expert_counts[expert_idx] = count

            # Verify that each token selects exactly k experts
            tokens_per_expert_count = {}
            for i in range(num_tokens):
                num_experts_selected = (top_k_weights[i] > 1e-6).sum().item()
                tokens_per_expert_count[num_experts_selected] = (
                    tokens_per_expert_count.get(num_experts_selected, 0) + 1
                )

            print(f"\n[Gate Routing] Total tokens: {num_tokens}, k={self.k}")
            print("[Gate Routing] Expert distribution:")
            total_selections = num_tokens * self.k  # Total expert-token pairs
            for expert_idx, count in sorted(expert_counts.items()):
                percentage = (count / total_selections) * 100
                print(
                    f"  Expert {expert_idx}: {count} selections ({percentage:.1f}% of {total_selections} total selections)"
                )

            # Verify top-k: each token selects exactly k experts
            if tokens_per_expert_count.get(self.k, 0) == num_tokens:
                print(
                    f"[Gate Routing] ✓ All {num_tokens} tokens select exactly {self.k} expert(s)"
                )
            else:
                print(
                    f"[Gate Routing] ⚠ Expected all tokens to select {self.k} expert(s), but got: {tokens_per_expert_count}"
                )

        return full_weights, top_k_indices


class Expert(nn.Module):
    """
    A simple feed-forward expert network.

    Args:
        input_dim: Input feature dimension
        hidden_dim: Hidden layer dimension
        output_dim: Output feature dimension
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Use GELU activation (modern choice)
        self.activation = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass through the expert."""
        x = self.fc1(x)
        x = self.activation(x)
        x = self.fc2(x)
        return x


class MoELayer(nn.Module):
    """
    Mixture of Experts layer for inference.

    Args:
        input_dim: Input feature dimension
        output_dim: Output feature dimension
        num_experts: Total number of experts
        k: Number of experts to activate per token
        expert_hidden_dim: Hidden dimension for experts (default: 4 * input_dim)
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        num_experts: int,
        k: int = 1,
        expert_hidden_dim: Optional[int] = None,
    ):
        super().__init__()
        self.num_experts = num_experts
        self.k = k
        self.output_dim = output_dim

        if expert_hidden_dim is None:
            expert_hidden_dim = input_dim * 4  # Common practice

        # Initialize gate and experts
        self.gate = TopKGate(input_dim, num_experts, k)
        self.experts = nn.ModuleList(
            [
                Expert(input_dim, expert_hidden_dim, output_dim)
                for _ in range(num_experts)
            ]
        )

    def forward(self, x: torch.Tensor, verbose: bool = False) -> torch.Tensor:
        """
        Forward pass through MoE layer.

        Args:
            x: Input tensor of shape [batch_size, seq_len, input_dim]
            verbose: Whether to print verification information for top-1 MoE

        Returns:
            Output tensor of shape [batch_size, seq_len, output_dim]
        """

        # Store original shape
        original_shape = x.shape
        batch_size, seq_len = original_shape[0], original_shape[1]

        # Flatten batch and sequence dimensions
        x_flat = x.view(-1, original_shape[-1])  # [N, input_dim], N = batch * seq_len
        num_tokens = x_flat.shape[0]

        # Get gating weights and expert indices
        gate_weights, top_k_indices = self.gate(
            x_flat, verbose=verbose
        )  # [N, num_experts], [N, k]

        # Initialize output tensor
        output = torch.zeros(
            x_flat.shape[0], self.output_dim, device=x.device, dtype=x.dtype
        )

        # Track expert usage statistics
        expert_usage_stats = {}

        # Process each expert
        for expert_idx in range(self.num_experts):
            # Find tokens assigned to this expert
            # Create mask: tokens that have this expert in their top-k
            expert_mask = (top_k_indices == expert_idx).any(dim=-1)  # [N]

            # Count once and branch on the count (avoids a second .any() reduction)
            num_tokens_for_expert = expert_mask.sum().item()
            expert_usage_stats[expert_idx] = num_tokens_for_expert

            if num_tokens_for_expert == 0:
                continue  # No tokens for this expert

            # Get inputs for this expert
            expert_inputs = x_flat[expert_mask]  # [num_tokens, input_dim]

            # Process through expert
            expert_outputs = self.experts[expert_idx](
                expert_inputs
            )  # [num_tokens, output_dim]

            # Get corresponding weights for these tokens
            # For each token, find the weight for this expert (could be in any of k positions)
            token_indices = torch.where(expert_mask)[0]  # Indices in flattened tensor
            expert_weights = gate_weights[token_indices, expert_idx].unsqueeze(
                1
            )  # [num_tokens, 1]

            # Weight and accumulate outputs
            weighted_outputs = expert_outputs * expert_weights
            output[token_indices] += weighted_outputs

        # Verify top-k behavior
        if verbose:
            active_experts = [
                idx for idx, count in expert_usage_stats.items() if count > 0
            ]
            total_usage = sum(expert_usage_stats.values())

            # Verify top-k: each token selects exactly k experts
            if top_k_indices.shape[1] != self.k:
                print(
                    f"[MoE Verification] ✗ ERROR: Expected k={self.k}, but top_k_indices has shape {top_k_indices.shape}"
                )

            # Verify all tokens are processed (for k>1, total_usage may be > num_tokens)
            expected_total_usage = num_tokens * self.k
            if total_usage != expected_total_usage:
                print(
                    f"[MoE Verification] ⚠ Expected {expected_total_usage} expert-token pairs (={num_tokens} tokens × {self.k} experts), but got {total_usage}"
                )

            # Check if weights are normalized (should sum to 1.0 per token)
            sample_weights = gate_weights.sum(dim=1)
            weights_normalized = torch.allclose(
                sample_weights, torch.ones_like(sample_weights), atol=1e-5
            )

            # Final summary for top-k
            print(f"\n[MoE Verification] === Top-{self.k} MoE Verification ===")
            print(f"  ✓ Each token routes to exactly {self.k} expert(s) (k={self.k})")
            print(f"  ✓ All {num_tokens} tokens processed")
            print(
                f"  ✓ Total expert-token pairs: {total_usage} (expected: {expected_total_usage})"
            )
            print(f"  ✓ Active experts: {len(active_experts)}/{self.num_experts}")
            print(
                f"  ✓ Expert distribution: {dict(sorted(expert_usage_stats.items()))}"
            )
            if weights_normalized:
                print("  ✓ Gate weights normalized (sum to 1.0 per token)")
            else:
                print("  ✗ Gate weights NOT normalized correctly")
            print(f"  === Top-{self.k} MoE is working correctly! ===")

        # Reshape back to original shape
        output = output.view(batch_size, seq_len, self.output_dim)

        return output


def run_moe_inference(
    batch_size: int = 4,
    seq_len: int = 4,
    input_dim: int = 64,
    output_dim: int = 64,
    num_experts: int = 2,
    k: int = 1,
    seed: int = 24,
    verbose: bool = True,
    expert_hidden_dim: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, MoELayer]:
    """
    Run MoE layer inference with fixed seed for reproducible inputs and weights.

    Args:
        batch_size: Batch size
        seq_len: Sequence length
        input_dim: Input feature dimension
        output_dim: Output feature dimension
        num_experts: Number of experts
        k: Top-k experts to activate per token
        seed: Random seed for reproducibility
        verbose: Whether to print verification information
        expert_hidden_dim: Hidden dimension for experts; None keeps the
            MoELayer default of 4 * input_dim (backward compatible)

    Returns:
        output: Output tensor from MoE layer [batch_size, seq_len, output_dim]
        input_tensor: Input tensor used for inference
        moe_layer: MoE layer instance with initialized weights
    """
    # Set random seed for reproducibility
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

    # Create MoE layer
    moe_layer = MoELayer(input_dim, output_dim, num_experts, k, expert_hidden_dim)
    moe_layer.eval()

    # Initialize with random weights (Xavier uniform for better stability)
    for param in moe_layer.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
        else:
            nn.init.zeros_(param)

    # Create random input
    torch.manual_seed(seed)
    input_tensor = torch.randn(batch_size, seq_len, input_dim)

    # Run inference with verbose output for verification
    with torch.no_grad():
        output = moe_layer(input_tensor, verbose=verbose)

    return output, input_tensor, moe_layer


if __name__ == "__main__":
    # Configuration parameters
    batch_size = 4
    seq_len = 1
    input_dim = 1
    output_dim = 1
    num_experts = 2
    k = 1  # Top-k MoE: each token uses exactly k experts
    seed = 24

    print("=" * 60)
    print(f"Top-{k} MoE Inference Test")
    print("=" * 60)
    print(
        f"Configuration: batch={batch_size}, seq_len={seq_len}, input_dim={input_dim}"
    )
    print(f"Experts: {num_experts}, k={k}")
    print("=" * 60)

    # Run MoE inference
    output, input_tensor, moe_layer = run_moe_inference(
        batch_size=batch_size,
        seq_len=seq_len,
        input_dim=input_dim,
        output_dim=output_dim,
        num_experts=num_experts,
        k=k,
        seed=seed,
        verbose=True,
    )

    print(f"\nOutput shape: {output.shape}")
    print("=" * 60)