@@ -1,66 +1,39 @@
1 | 1 | """ |
2 | 2 | A simple n-gram (word) Shannon-style language model with add-one smoothing. |
3 | 3 | """ |
4 | | -import sys, traceback, os |
5 | | -import random, pickle, tempfile, re |
| 4 | +import os, random, pickle, bz2, tempfile |
6 | 5 | from pathlib import Path |
7 | 6 | from io import StringIO |
8 | 7 | from lf_toolkit.evaluation import Result, Params |
9 | | -from .utils import csv_to_lists |
10 | | -import nltk |
11 | | -from nltk.corpus import brown, reuters, gutenberg, webtext |
| 8 | +from .utils import csv_to_lists, build_counts |
| 9 | + |
12 | 10 |
|
 # Local users run the following once (no need if using Docker):
 #nltk.download("brown"); nltk.download("reuters"); nltk.download("gutenberg"); nltk.download("webtext") # CHANGE (one-time)

 START, END = "<s>", "</s>"

-def corpus_sents(): # CHANGE
-    # Each yields lists of tokens already sentence-segmented
-    for s in brown.sents(): yield s
-    for s in reuters.sents(): yield s
-    for s in gutenberg.sents(): yield s
-    for s in webtext.sents(): yield s
-
 # Setup paths for saving/loading model and data
 BASE_DIR = Path(__file__).resolve().parent
 MODEL_DIR = Path(os.environ.get("MODEL_DIR", BASE_DIR / "storage"))
 MODEL_DIR.mkdir(parents=True, exist_ok=True)
 WORD_LENGTHS_PATH = MODEL_DIR / "norvig_word_length_frequencies.csv"
-FILE = Path(tempfile.gettempdir()) / "ngram_counts.pkl"
-
-# If not cache:
-def corpus_sents(): # CHANGE
-    # Each yields lists of tokens already sentence-segmented
-    for s in brown.sents(): yield s
-    for s in reuters.sents(): yield s
-    for s in gutenberg.sents(): yield s
-    for s in webtext.sents(): yield s
-
-def build_counts(n=3):
-    counts = {}
-    for sent in corpus_sents():
-        tokens = [w.lower() for w in sent]
-        s = ([START] * (n - 1)) + tokens + ([END] if n > 1 else [])
-        for i in range(len(s)-n+1):
-            ctx = tuple(s[i:i+n-1])
-            nxt = s[i+n-1]
-            counts.setdefault(ctx, {})
-            counts[ctx][nxt] = counts[ctx].get(nxt, 0) + 1
-    return counts
-# End caching part
+# If the cache is built on the deployed server, write it to a temp dir:
+#FILE = Path(tempfile.gettempdir()) / "ngram_counts.pkl"
+# If built locally and copied into the deployment, keep it under MODEL_DIR:
+FILE = MODEL_DIR / "ngram_counts.pkl.bz2"

-# Always used:
 def get_counts(n=3):
     if os.path.exists(FILE):
-        with open(FILE, "rb") as f:
+        with bz2.BZ2File(FILE, "rb") as f:
             cache = pickle.load(f)
-    else:
+    else:  # past this point the deployed version cannot rebuild the cache: the NLTK corpora are not bundled (to save space)
         cache = {}
     if n not in cache:
+        print(f"Building counts for n={n} (this may take a while)...")
-        cache[n] = build_counts(n)
+        cache[n] = build_counts(n, START, END)  # likewise, only works if the NLTK corpora are available
     try:
-        with open(FILE, "wb") as f:
+        with bz2.BZ2File(FILE, "wb") as f:
             pickle.dump(cache, f)
     except Exception as e:
         print(f"Warning: couldn't save n-gram cache to {FILE}: {e}")
@@ -111,6 +84,7 @@ def run(response, answer, params:Params) -> Result:
         output.append(generate(context,word_count,context_window))
         preface = 'Context window: '+str(context_window)+', Word count: '+str(word_count)+'. Output: <br>'
         feedback_items = [("general", preface + ' '.join(output))]
-        feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
+        #feedback_items.append("| Answer not an integer; used default context window") if not response_used else None
         is_correct = True
+        print(feedback_items)  # debug: log the assembled feedback
         return Result(is_correct=is_correct,feedback_items=feedback_items)
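
The module docstring advertises add-one smoothing, though the smoothed probability itself is not part of these hunks. For illustration only, a hypothetical helper (the name `prob` and the `vocab_size` parameter are not from the file) showing how a Laplace-smoothed next-word probability could be read off the `get_counts` structure:

    def prob(counts, ctx, word, vocab_size):
        # Add-one (Laplace) smoothing: act as if every vocabulary word
        # followed this context once more than actually observed.
        followers = counts.get(tuple(ctx), {})
        total = sum(followers.values())
        return (followers.get(word, 0) + 1) / (total + vocab_size)

For a trigram model (n=3) the context is the last two lowercased tokens; an unseen context degrades gracefully to the uniform 1/vocab_size.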