
Commit 46c210e

Peter Johnson authored and committed
Add compressed n-gram cache and updated code to compress/decompress
1 parent 18a54dc commit 46c210e

3 files changed: +39 -41 lines changed

evaluation_function/models/shannon_words_ngram.py

Lines changed: 14 additions & 40 deletions
@@ -1,66 +1,39 @@
 """
 A simple n-gram (word) Shannon-style language model with add-one smoothing.
 """
-import sys, traceback, os
-import random, pickle, tempfile, re
+import os, random, pickle, bz2, tempfile
 from pathlib import Path
 from io import StringIO
 from lf_toolkit.evaluation import Result, Params
-from .utils import csv_to_lists
-import nltk
-from nltk.corpus import brown, reuters, gutenberg, webtext
+from .utils import csv_to_lists, build_counts
+
 
 # Local users run the following once (no need if using Docker):
 #nltk.download("brown"); nltk.download("reuters"); nltk.download("gutenberg"); nltk.download("webtext") # CHANGE (one-time)
 
 START, END = "<s>", "</s>"
 
-def corpus_sents(): # CHANGE
-    # Each yields lists of tokens already sentence-segmented
-    for s in brown.sents(): yield s
-    for s in reuters.sents(): yield s
-    for s in gutenberg.sents(): yield s
-    for s in webtext.sents(): yield s
-
 # Setup paths for saving/loading model and data
 BASE_DIR = Path(__file__).resolve().parent
 MODEL_DIR = Path(os.environ.get("MODEL_DIR", BASE_DIR / "storage"))
 MODEL_DIR.mkdir(parents=True, exist_ok=True)
 WORD_LENGTHS_PATH = MODEL_DIR / "norvig_word_length_frequencies.csv"
-FILE = Path(tempfile.gettempdir()) / "ngram_counts.pkl"
-
-# If not cache:
-def corpus_sents(): # CHANGE
-    # Each yields lists of tokens already sentence-segmented
-    for s in brown.sents(): yield s
-    for s in reuters.sents(): yield s
-    for s in gutenberg.sents(): yield s
-    for s in webtext.sents(): yield s
-
-def build_counts(n=3):
-    counts = {}
-    for sent in corpus_sents():
-        tokens = [w.lower() for w in sent]
-        s = ([START] * (n - 1)) + tokens + ([END] if n > 1 else [])
-        for i in range(len(s)-n+1):
-            ctx = tuple(s[i:i+n-1])
-            nxt = s[i+n-1]
-            counts.setdefault(ctx, {})
-            counts[ctx][nxt] = counts[ctx].get(nxt, 0) + 1
-    return counts
-# End caching part
+# If creating when deployed:
+#FILE = Path(tempfile.gettempdir()) / "ngram_counts.pkl"
+# If creating locally, to be copied when deployed:
+FILE = MODEL_DIR / "ngram_counts.pkl.bz2"
 
-# Always used:
 def get_counts(n=3):
     if os.path.exists(FILE):
-        with open(FILE, "rb") as f:
+        with bz2.BZ2File(FILE, "rb") as f:
             cache = pickle.load(f)
-    else:
+    else: # from here the deployed version will not work because the corpora are not bundled (to save space)
         cache = {}
     if n not in cache:
-        cache[n] = build_counts(n)
+        print(f"Building counts for n={n} (this may take a while)...")
+        cache[n] = build_counts(n, START, END) # similarly, only works if NLTK corpora are available
         try:
-            with open(FILE, "wb") as f:
+            with bz2.BZ2File(FILE, "wb") as f:
                 pickle.dump(cache, f)
         except Exception as e:
            print(f"Warning: couldn't save n-gram cache to {FILE}: {e}")
@@ -111,6 +84,7 @@ def run(response, answer, params:Params) -> Result:
         output.append(generate(context,word_count,context_window))
     preface = 'Context window: '+str(context_window)+', Word count: '+str(word_count)+'. Output: <br>'
     feedback_items = [("general", preface + ' '.join(output))]
-    feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
+    #feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
     is_correct = True
+    print(feedback_items)
     return Result(is_correct=is_correct,feedback_items=feedback_items)
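
generate() itself is untouched by this commit and not shown in the diff; purely as an illustration of how one next word could be drawn from the cached counts dictionary, a hypothetical helper (not the file's actual implementation):

import random

def sample_next(counts, ctx):
    # counts maps (n-1)-gram context tuples to {next_word: count}, as built by build_counts()
    dist = counts.get(ctx, {})
    if not dist:
        return "</s>"                        # unseen context: end the sentence
    words, weights = zip(*dist.items())
    return random.choices(words, weights=weights, k=1)[0]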
evaluation_function/models/storage/ngram_counts.pkl.bz2

13.7 MB
Binary file not shown.
evaluation_function/models/utils.py

Lines changed: 25 additions & 1 deletion
@@ -1,9 +1,33 @@
 import csv
+import nltk
+from nltk.corpus import brown, reuters, gutenberg, webtext
+
 def csv_to_lists(filename: str) -> list:
     frequencies = []
     with open(filename, newline='') as csvfile:
         reader = csv.reader(csvfile)
         next(reader) # Skip header row
         for key,value in reader:
             frequencies.append([key, float(value)])
-    return frequencies
+    return frequencies
+
+
+# Generate word ngram counts from NLTK corpora
+def corpus_sents(): # CHANGE
+    # Each yields lists of tokens already sentence-segmented
+    for s in brown.sents(): yield s
+    for s in reuters.sents(): yield s
+    for s in gutenberg.sents(): yield s
+    for s in webtext.sents(): yield s
+
+def build_counts(n=3, START="<s>", END="</s>"):
+    counts = {}
+    for sent in corpus_sents():
+        tokens = [w.lower() for w in sent]
+        s = ([START] * (n - 1)) + tokens + ([END] if n > 1 else [])
+        for i in range(len(s)-n+1):
+            ctx = tuple(s[i:i+n-1])
+            nxt = s[i+n-1]
+            counts.setdefault(ctx, {})
+            counts[ctx][nxt] = counts[ctx].get(nxt, 0) + 1
+    return counts
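
build_counts() returns a dict mapping each (n-1)-word context tuple to a dict of next-word counts. As a sketch of how the add-one smoothing mentioned in the module docstring could be read off that structure (this helper and the example counts in the comments are illustrative, not part of the commit; building the counts requires the NLTK corpora locally):

def add_one_prob(counts, ctx, word):
    # P(word | ctx) with add-one smoothing over the set of observed next words.
    dist = counts.get(ctx, {})                              # e.g. {"the": 3, "a": 1, ...}
    vocab = {w for d in counts.values() for w in d}         # every next word seen in any context
    return (dist.get(word, 0) + 1) / (sum(dist.values()) + len(vocab))

# Usage sketch:
# counts = build_counts(n=3)
# p = add_one_prob(counts, ("<s>", "<s>"), "the")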
