-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtranscript-hack.py
More file actions
52 lines (43 loc) · 1.51 KB
/
transcript-hack.py
File metadata and controls
52 lines (43 loc) · 1.51 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import click
import tiktoken
from download_and_use_gpt2 import load_weights_into_gpt
from generate_text_simple import generate_text_simple
from gpt import GPTModel
from gpt_download import download_and_load_gpt2
from model_config import model_configs
from second_generation_test import text_to_token_ids, token_ids_to_text
# Settings shared by every model size. The size-specific entries are merged
# in from model_configs[model] inside main().
_BASE_CONFIG = {
    "vocab_size": 50257,  # token count of tiktoken's "gpt2" encoding used below
    "context_length": 1024,  # also passed to generate_text_simple as context_size
    "drop_rate": 0.0,  # dropout off: this script only runs inference
    "qkv_bias": True,
}

# The prompt frames a "User"/"Bot" chat transcript, so that plain next-token
# prediction from a base (non-chat) model continues with a bot-style answer.
_TRANSCRIPT_PROMPT = (
    "This is a transcript of a conversation between a helpful bot, 'Bot', "
    "and a human, 'User'. The bot is very intelligent and always answers "
    "the human's questions with a useful reply.\n\n"
    "User: Provide a synonym for 'bright'\n\n"
    "Bot: "
)


@click.command()
@click.argument("model", type=click.Choice(list(model_configs)))
@click.option(
    "--max-new-tokens",
    type=click.IntRange(min=0),
    default=23,
    show_default=True,
    help="Number of tokens to generate after the prompt.",
)
def main(model, max_new_tokens=23):
    """Load pretrained GPT-2 weights for MODEL and complete a chat transcript.

    MODEL must be one of the keys of model_configs. Its last space-separated
    word, with surrounding parentheses removed, is used as the model size to
    download. The generated token ids are decoded and printed.
    """
    tokenizer = tiktoken.get_encoding("gpt2")

    # Build a fresh dict instead of mutating the shared base config.
    config = {**_BASE_CONFIG, **model_configs[model]}

    # The size is the last word of the name with its parentheses stripped.
    model_size = model.split(" ")[-1].lstrip("(").rstrip(")")
    # The returned settings are not used; the architecture comes from config.
    _settings, params = download_and_load_gpt2(
        model_size=model_size, models_dir="gpt2"
    )

    print(f"Creating model with drop_rate {config['drop_rate']}")
    # Named `gpt` so it does not shadow the `model` name string argument.
    gpt = GPTModel(config)
    load_weights_into_gpt(gpt, params)
    gpt.eval()

    token_ids = generate_text_simple(
        model=gpt,
        idx=text_to_token_ids(_TRANSCRIPT_PROMPT, tokenizer),
        max_new_tokens=max_new_tokens,
        context_size=config["context_length"],
    )
    print(token_ids_to_text(token_ids, tokenizer))


if __name__ == "__main__":
    main()