From 1d13ea15f607d2bb846bc09cbafe7ba16816ace6 Mon Sep 17 00:00:00 2001 From: xTimeCrystal <68882569+xTimeCrystal@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:29:08 +0800 Subject: [PATCH 1/4] Update rwkv7.py --- reference/rwkv7.py | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/reference/rwkv7.py b/reference/rwkv7.py index a1edfc1..b40a9f1 100644 --- a/reference/rwkv7.py +++ b/reference/rwkv7.py @@ -19,13 +19,13 @@ import torch.nn as nn from torch.nn import functional as F -MyModule = torch.jit.ScriptModule -MyFunction = torch.jit.script_method -MyStatic = torch.jit.script -# MyModule = nn.Module -# def __nop(ob): return ob -# MyFunction = __nop -# MyStatic = __nop +# MyModule = torch.jit.ScriptModule +# MyFunction = torch.jit.script_method +# MyStatic = torch.jit.script +MyModule = nn.Module +def __nop(ob): return ob +MyFunction = __nop +MyStatic = __nop DTYPE = torch.half @@ -47,6 +47,8 @@ def forward(ctx, state, r, w, k, v, a, b): y = torch.empty((T, C), device=k.device, dtype=DTYPE, requires_grad=False, memory_format=torch.contiguous_format) torch.ops.rwkv7_state_fwd_fp16.forward(1, T, C, H, state, r, w, k, v, a, b, y) return y + +@torch.compiler.disable def RWKV7_OP(state, r, w, k, v, a, b): return WKV_7.apply(state, r, w, k, v, a, b) @@ -96,7 +98,9 @@ def forward(self, idx, state, full_output=False): else: return self.forward_one(idx, state) - @MyFunction + # @torch.compile(mode='max-autotune-no-cudagraphs') + # @torch.compile(mode='reduce-overhead') + @torch.compile(mode='max-autotune') def forward_one(self, idx:int, state:List[torch.Tensor]): with torch.no_grad(): z = self.z @@ -127,7 +131,7 @@ def forward_one(self, idx:int, state:List[torch.Tensor]): x = x @ z['head.weight'] return x, state - @MyFunction + @torch.compile(mode='max-autotune-no-cudagraphs') def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=False): with torch.no_grad(): z = self.z From db8e5f343af756187ba73f95f63c8b529709e8a6 Mon Sep 17 00:00:00 2001 From: xTimeCrystal <68882569+xTimeCrystal@users.noreply.github.com> Date: Fri, 5 Sep 2025 18:29:29 +0800 Subject: [PATCH 2/4] Update benchmark.py --- benchmark.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/benchmark.py b/benchmark.py index ab4a11c..82eec48 100644 --- a/benchmark.py +++ b/benchmark.py @@ -23,7 +23,7 @@ # # model download: https://huggingface.co/BlinkDL/rwkv7-g1 # -args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.1b-20250728-ctx4096" +args.MODEL_NAME = "./rwkv7-g1a-0.1b-20250728-ctx4096" args.n_layer = 12 args.n_embd = 768 # args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-0.4b-20250324-ctx4096" @@ -65,7 +65,7 @@ def xprint(s): prompt = "The Eiffel tower is in the city of" print(prompt) - +torch.compiler.cudagraph_mark_step_begin() init_out, init_state = model.forward(tokenizer.encode(prompt), None) probs = F.softmax(init_out.float(), dim=-1) # compute softmax in float (more accurate) _, indices = torch.topk(probs, 5) # print top-5 possibilities @@ -89,6 +89,7 @@ def xprint(s): all_tokens = [] out_last = 0 +torch.compiler.cudagraph_mark_step_begin() init_out, init_state = model.forward(tokenizer.encode(prompt), None) out, state = init_out.clone(), copy.deepcopy(init_state) @@ -109,7 +110,9 @@ def xprint(s): torch.cuda.synchronize() t0 = time.perf_counter() + torch.compiler.cudagraph_mark_step_begin() out, state = model.forward(token, state) + out, state = out.clone(), copy.deepcopy(state) torch.cuda.synchronize() t1 = time.perf_counter() min_time = min(min_time, t1 - t0) From 6ae3e2a179e338e230ab62a4ec2ee0d748fb5162 Mon Sep 17 00:00:00 2001 From: xTimeCrystal <68882569+xTimeCrystal@users.noreply.github.com> Date: Fri, 5 Sep 2025 19:05:58 +0800 Subject: [PATCH 3/4] Update benchmark.py --- benchmark.py | 32 +++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/benchmark.py b/benchmark.py index 82eec48..99debed 100644 --- a/benchmark.py +++ b/benchmark.py @@ -23,21 +23,21 @@ # # model download: https://huggingface.co/BlinkDL/rwkv7-g1 # -args.MODEL_NAME = "./rwkv7-g1a-0.1b-20250728-ctx4096" -args.n_layer = 12 -args.n_embd = 768 -# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-0.4b-20250324-ctx4096" +# args.MODEL_NAME = "./rwkv7-g1a-0.1b-20250728-ctx4096" +# args.n_layer = 12 +# args.n_embd = 768 +# args.MODEL_NAME = "./rwkv7-g1-0.4b-20250324-ctx4096" # args.n_layer = 24 # args.n_embd = 1024 -# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-1.5b-20250429-ctx4096" +# args.MODEL_NAME = "./rwkv7-g1-1.5b-20250429-ctx4096" # args.n_layer = 24 # args.n_embd = 2048 -# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-2.9b-20250519-ctx4096" +# args.MODEL_NAME = "./rwkv7-g1-2.9b-20250519-ctx4096" # args.n_layer = 32 # args.n_embd = 2560 -# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g0a-7.2b-20250829-ctx4096" -# args.n_layer = 32 -# args.n_embd = 4096 +args.MODEL_NAME = "./rwkv7-g0a-7.2b-20250829-ctx4096" +args.n_layer = 32 +args.n_embd = 4096 print(f'\nUsing CUDA fp16. Loading {args.MODEL_NAME} ...\n') @@ -93,8 +93,11 @@ def xprint(s): init_out, init_state = model.forward(tokenizer.encode(prompt), None) out, state = init_out.clone(), copy.deepcopy(init_state) -min_time = 1e10 -min_time_all = 1e10 +# min_time = 1e10 +# min_time_all = 1e10 + +all_times = [] + t000 = time.perf_counter() for i in range(LENGTH_PER_TRIAL): t00 = time.perf_counter() @@ -115,8 +118,11 @@ def xprint(s): out, state = out.clone(), copy.deepcopy(state) torch.cuda.synchronize() t1 = time.perf_counter() - min_time = min(min_time, t1 - t0) - min_time_all = min(min_time_all, t1 - t00) + # min_time = min(min_time, t1 - t0) + # min_time_all = min(min_time_all, t1 - t00) + all_times.append(t1 - t0) + +min_time=min_time_all=np.median(all_times) print(f'\n\nToken/s = {round(1/min_time,2)} (forward), {round(1/min_time_all,2)} (full) || Bandwidth = {round(active_GB/min_time,2)} GB/s || {round(time.perf_counter()-t000,3)}s') From 30465d1f9d970a9e377a01f6572ee997f4d406d2 Mon Sep 17 00:00:00 2001 From: xTimeCrystal <68882569+xTimeCrystal@users.noreply.github.com> Date: Fri, 5 Sep 2025 19:06:14 +0800 Subject: [PATCH 4/4] Update rwkv7.py --- reference/rwkv7.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/reference/rwkv7.py b/reference/rwkv7.py index b40a9f1..3f4cbae 100644 --- a/reference/rwkv7.py +++ b/reference/rwkv7.py @@ -98,9 +98,10 @@ def forward(self, idx, state, full_output=False): else: return self.forward_one(idx, state) - # @torch.compile(mode='max-autotune-no-cudagraphs') + @torch.compile(mode='max-autotune-no-cudagraphs') # @torch.compile(mode='reduce-overhead') - @torch.compile(mode='max-autotune') + # @torch.compile(mode='max-autotune') + # @MyFunction def forward_one(self, idx:int, state:List[torch.Tensor]): with torch.no_grad(): z = self.z @@ -131,7 +132,8 @@ def forward_one(self, idx:int, state:List[torch.Tensor]): x = x @ z['head.weight'] return x, state - @torch.compile(mode='max-autotune-no-cudagraphs') + # @torch.compile(mode='max-autotune-no-cudagraphs') + # @MyFunction def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=False): with torch.no_grad(): z = self.z