Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,21 @@
#
# model download: https://huggingface.co/BlinkDL/rwkv7-g1
#
args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1a-0.1b-20250728-ctx4096"
args.n_layer = 12
args.n_embd = 768
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-0.4b-20250324-ctx4096"
# args.MODEL_NAME = "./rwkv7-g1a-0.1b-20250728-ctx4096"
# args.n_layer = 12
# args.n_embd = 768
# args.MODEL_NAME = "./rwkv7-g1-0.4b-20250324-ctx4096"
# args.n_layer = 24
# args.n_embd = 1024
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-1.5b-20250429-ctx4096"
# args.MODEL_NAME = "./rwkv7-g1-1.5b-20250429-ctx4096"
# args.n_layer = 24
# args.n_embd = 2048
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g1-2.9b-20250519-ctx4096"
# args.MODEL_NAME = "./rwkv7-g1-2.9b-20250519-ctx4096"
# args.n_layer = 32
# args.n_embd = 2560
# args.MODEL_NAME = "/mnt/e/RWKV-Runner/models/rwkv7-g0a-7.2b-20250829-ctx4096"
# args.n_layer = 32
# args.n_embd = 4096
args.MODEL_NAME = "./rwkv7-g0a-7.2b-20250829-ctx4096"
args.n_layer = 32
args.n_embd = 4096

print(f'\nUsing CUDA fp16. Loading {args.MODEL_NAME} ...\n')

Expand Down Expand Up @@ -65,7 +65,7 @@ def xprint(s):

prompt = "The Eiffel tower is in the city of"
print(prompt)

torch.compiler.cudagraph_mark_step_begin()
init_out, init_state = model.forward(tokenizer.encode(prompt), None)
probs = F.softmax(init_out.float(), dim=-1) # compute softmax in float (more accurate)
_, indices = torch.topk(probs, 5) # print top-5 possibilities
Expand All @@ -89,11 +89,15 @@ def xprint(s):

all_tokens = []
out_last = 0
torch.compiler.cudagraph_mark_step_begin()
init_out, init_state = model.forward(tokenizer.encode(prompt), None)
out, state = init_out.clone(), copy.deepcopy(init_state)

min_time = 1e10
min_time_all = 1e10
# min_time = 1e10
# min_time_all = 1e10

all_times = []

t000 = time.perf_counter()
for i in range(LENGTH_PER_TRIAL):
t00 = time.perf_counter()
Expand All @@ -109,11 +113,16 @@ def xprint(s):

torch.cuda.synchronize()
t0 = time.perf_counter()
torch.compiler.cudagraph_mark_step_begin()
out, state = model.forward(token, state)
out, state = out.clone(), copy.deepcopy(state)
torch.cuda.synchronize()
t1 = time.perf_counter()
min_time = min(min_time, t1 - t0)
min_time_all = min(min_time_all, t1 - t00)
# min_time = min(min_time, t1 - t0)
# min_time_all = min(min_time_all, t1 - t00)
all_times.append(t1 - t0)

min_time=min_time_all=np.median(all_times)

print(f'\n\nToken/s = {round(1/min_time,2)} (forward), {round(1/min_time_all,2)} (full) || Bandwidth = {round(active_GB/min_time,2)} GB/s || {round(time.perf_counter()-t000,3)}s')

Expand Down
24 changes: 15 additions & 9 deletions reference/rwkv7.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,13 @@
import torch.nn as nn
from torch.nn import functional as F

MyModule = torch.jit.ScriptModule
MyFunction = torch.jit.script_method
MyStatic = torch.jit.script
# MyModule = nn.Module
# def __nop(ob): return ob
# MyFunction = __nop
# MyStatic = __nop
# MyModule = torch.jit.ScriptModule
# MyFunction = torch.jit.script_method
# MyStatic = torch.jit.script
# Plain-PyTorch replacements for the TorchScript aliases: the module base
# class is a regular nn.Module and both decorators become pass-throughs.
MyModule = nn.Module


def __nop(ob):
    """Identity decorator: hand back ``ob`` exactly as it was received."""
    return ob


MyFunction = MyStatic = __nop

DTYPE = torch.half

Expand All @@ -47,6 +47,8 @@ def forward(ctx, state, r, w, k, v, a, b):
y = torch.empty((T, C), device=k.device, dtype=DTYPE, requires_grad=False, memory_format=torch.contiguous_format)
torch.ops.rwkv7_state_fwd_fp16.forward(1, T, C, H, state, r, w, k, v, a, b, y)
return y

@torch.compiler.disable
def RWKV7_OP(state, r, w, k, v, a, b):
    """Forward all arguments to ``WKV_7.apply`` and return its result.

    The ``torch.compiler.disable`` decorator keeps this call out of
    ``torch.compile`` tracing, so it runs eagerly even when the caller
    is compiled.
    """
    result = WKV_7.apply(state, r, w, k, v, a, b)
    return result

Expand Down Expand Up @@ -96,7 +98,10 @@ def forward(self, idx, state, full_output=False):
else:
return self.forward_one(idx, state)

@MyFunction
@torch.compile(mode='max-autotune-no-cudagraphs')
# @torch.compile(mode='reduce-overhead')
# @torch.compile(mode='max-autotune')
# @MyFunction
def forward_one(self, idx:int, state:List[torch.Tensor]):
with torch.no_grad():
z = self.z
Expand Down Expand Up @@ -127,7 +132,8 @@ def forward_one(self, idx:int, state:List[torch.Tensor]):
x = x @ z['head.weight']
return x, state

@MyFunction
# @torch.compile(mode='max-autotune-no-cudagraphs')
# @MyFunction
def forward_seq(self, idx:List[int], state:List[torch.Tensor], full_output:bool=False):
with torch.no_grad():
z = self.z
Expand Down