diff --git a/Pipfile b/Pipfile index b65227a..576163e 100644 --- a/Pipfile +++ b/Pipfile @@ -35,6 +35,7 @@ PyYAML = "*" Shapely = "*" numpy = "*" scipy = "*" +marisa_trie = "*" [dev-packages] diff --git a/litecoder/__init__.py b/litecoder/__init__.py index e67b0f9..ec7ce0a 100644 --- a/litecoder/__init__.py +++ b/litecoder/__init__.py @@ -9,9 +9,9 @@ DATA_DIR = os.path.join(os.path.dirname(__file__), 'data') -US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.p') +US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.marisa') -US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.p') +US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.marisa') logging.basicConfig( diff --git a/litecoder/usa.py b/litecoder/usa.py index d6ff15a..362d485 100644 --- a/litecoder/usa.py +++ b/litecoder/usa.py @@ -1,7 +1,8 @@ import re -import pickle +import marisa_trie +import ujson as json from tqdm import tqdm from collections import defaultdict @@ -147,99 +148,53 @@ def state_key_iter(row): yield ' '.join((abbr, usa)) -class Match: - - def __init__(self, row): - """Set model class, PK, metadata. - """ - state = inspect(row) - - # Don't store the actual row, so we can serialize. - self._model_cls = state.class_ - self._pk = state.identity - - self.data = Box(dict(row)) - - @cached_property - def db_row(self): - """Hydrate database row, lazily. - """ - return self._model_cls.query.get(self._pk) - - -class CityMatch(Match): - - def __repr__(self): - return '%s<%s, %s, %s, wof:%d>' % ( - self.__class__.__name__, - self.data.name, - self.data.name_a1, - self.data.name_a0, - self.data.wof_id, - ) - - -class StateMatch(Match): - - def __repr__(self): - return '%s<%s, %s, wof:%d>' % ( - self.__class__.__name__, - self.data.name, - self.data.name_a0, - self.data.wof_id, - ) - - class Index: - @classmethod - def load(cls, path): - with open(path, 'rb') as fh: - return pickle.load(fh) + def load(self, path, mmap=False): + if mmap: + self._trie.mmap(path) + else: + self._trie.load(path) def __init__(self): - self._key_to_ids = defaultdict(set) - self._id_to_loc = dict() + self._trie = marisa_trie.BytesTrie() + + # We use prefixes here to store the keys -> ids and ids -> loc "maps" as subtrees in one marisa trie. + self._keys_prefix = "A" + self._ids_prefix = "B" def __len__(self): - return len(self._key_to_ids) + return len(self._trie.keys(self._keys_prefix)) def __repr__(self): return '%s<%d keys, %d entities>' % ( self.__class__.__name__, - len(self._key_to_ids), - len(self._id_to_loc), + len(self._trie.keys(self._keys_prefix)), + len(self._trie.keys(self._ids_prefix)), ) def __getitem__(self, text): """Get ids, map to records only if there is a match in the index """ - if keyify(text) not in self._key_to_ids: + normalized_key = self._keys_prefix + keyify(text) + val = self._trie.get(normalized_key, None) + if not val: return None + ids = json.loads(val[0]) - ids = self._key_to_ids[keyify(text)] - - return [self._id_to_loc[id] for id in ids] - - def add_key(self, key, id): - self._key_to_ids[key].add(id) - - def add_location(self, id, location): - self._id_to_loc[id] = location + return [json.loads(self._trie[self._ids_prefix + id][0]) for id in ids] def locations(self): - return list(self._id_to_loc.values()) + return self._trie.items(self._ids_prefix) def save(self, path): - with open(path, 'wb') as fh: - pickle.dump(self, fh) + self._trie.save(path) class USCityIndex(Index): - @classmethod - def load(cls, path=US_CITY_PATH): - return super().load(path) + def load(self, path=US_CITY_PATH, mmap=False): + return super().load(path, mmap) def __init__(self, bare_name_blocklist=None): super().__init__() @@ -248,6 +203,7 @@ def __init__(self, bare_name_blocklist=None): def build(self): """Index all US cities. """ + allow_bare = AllowBareCityName(blocklist=self.bare_name_blocklist) iter_keys = CityKeyIter(allow_bare) @@ -257,21 +213,27 @@ def build(self): logger.info('Indexing US cities.') + key_to_ids = defaultdict(set) + id_to_loc_items = list() + for row in tqdm(cities): # Key -> id(s) for key in map(keyify, iter_keys(row)): - self.add_key(key, row.wof_id) + key_to_ids[key].add(str(row.wof_id)) # ID -> city - self.add_location(row.wof_id, CityMatch(row)) + id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8"))) + + key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids] + + self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items) class USStateIndex(Index): - @classmethod - def load(cls, path=US_STATE_PATH): - return super().load(path) + def load(self, path=US_STATE_PATH, mmap=False): + return super().load(path, mmap) def build(self): """Index all US states. @@ -280,11 +242,18 @@ def build(self): logger.info('Indexing US states.') + key_to_ids = defaultdict(set) + id_to_loc_items = list() + for row in tqdm(states): # Key -> id(s) for key in map(keyify, state_key_iter(row)): - self.add_key(key, row.wof_id) + key_to_ids[key].add(str(row.wof_id)) # ID -> state - self.add_location(row.wof_id, StateMatch(row)) + id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8"))) + + key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids] + + self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items) diff --git a/tests/prod_db/conftest.py b/tests/prod_db/conftest.py index 1005d12..06583e3 100644 --- a/tests/prod_db/conftest.py +++ b/tests/prod_db/conftest.py @@ -7,9 +7,13 @@ @pytest.fixture(scope='session') def city_idx(): - return USCityIndex.load() + city_idx = USCityIndex() + city_idx.load() + return city_idx @pytest.fixture(scope='session') def state_idx(): - return USStateIndex.load() + state_idx = USStateIndex() + state_idx.load() + return state_idx diff --git a/tests/prod_db/test_us_city_index.py b/tests/prod_db/test_us_city_index.py index aec57f0..e05cfa2 100644 --- a/tests/prod_db/test_us_city_index.py +++ b/tests/prod_db/test_us_city_index.py @@ -33,7 +33,7 @@ def test_cases(city_idx, query, matches, xfail): res = city_idx[query] - ids = [r.data.wof_id for r in res] + ids = [r["wof_id"] for r in res] # Exact id list match. assert sorted(ids) == sorted(matches) @@ -49,6 +49,6 @@ def test_topn(city_idx, city): """Smoke test N most populous cities. """ res = city_idx['%s, %s' % (city.name, city.name_a1)] - res_ids = [r.data.wof_id for r in res] + res_ids = [r["wof_id"] for r in res] assert city.wof_id in res_ids diff --git a/tests/prod_db/test_us_state_index.py b/tests/prod_db/test_us_state_index.py index 9adafc6..2ae1561 100644 --- a/tests/prod_db/test_us_state_index.py +++ b/tests/prod_db/test_us_state_index.py @@ -28,7 +28,7 @@ def test_cases(state_idx, query, matches): res = state_idx[query] - ids = [r.data.wof_id for r in res] + ids = [r["wof_id"] for r in res] assert sorted(ids) == sorted(matches) @@ -41,6 +41,6 @@ def test_all(state_idx, state): """Smoke test N most populous cities. """ res = state_idx[state.name] - res_ids = [r.data.wof_id for r in res] + res_ids = [r["wof_id"] for r in res] assert state.wof_id in res_ids diff --git a/tests/runtime/concurrency_test.py b/tests/runtime/concurrency_test.py new file mode 100644 index 0000000..6adc618 --- /dev/null +++ b/tests/runtime/concurrency_test.py @@ -0,0 +1,45 @@ +from multiprocessing import Pool +from litecoder.usa import USCityIndex, USStateIndex +import time + +NUM_PROCESSES = 4 + +# Load 50 test city lookups +with open("tests/runtime/test_city_lookups.txt", "r") as lookups_file: + city_tests = lookups_file.read().splitlines() + +# Increase the number of lookups for the speed test if necessary +for x in range (10): + city_tests += city_tests +num_tests_per_process = len(city_tests) +num_tests = NUM_PROCESSES * num_tests_per_process + +# Load USCityIndex +city_idx = USCityIndex() +city_idx.load() + + +def lookup_cities(process_num): + print ('Process {}: looking up {} cities'.format(process_num, num_tests_per_process)) + start_time = time.time() + for city in city_tests: + city_idx[city] + ms = 1000*(time.time() - start_time) + print("Process {}: finished, took {}ms @ {} ms/lookup!".format(process_num, ms, float(ms/num_tests_per_process))) + +if __name__ == '__main__': + print("Looking up {} cities on {} processes...".format(num_tests, NUM_PROCESSES)) + start_time = time.time() + with Pool(5) as p: + p.map(lookup_cities, range(1, NUM_PROCESSES+1)) + ms = 1000*(time.time() - start_time) + print("Fully finished: took {}ms @ {} ms/lookup!".format(ms, float(ms/num_tests))) + + print() + print("Looking up all {} cities on one process...".format(num_tests), end="") + start_time = time.time() + for i in range(NUM_PROCESSES): + for city in city_tests: + city_idx[city] + ms = 1000*(time.time() - start_time) + print("finished: took {}ms @ {} ms/lookup!".format(ms, ms/num_tests)) \ No newline at end of file diff --git a/tests/runtime/speed_test.py b/tests/runtime/speed_test.py new file mode 100644 index 0000000..17a7092 --- /dev/null +++ b/tests/runtime/speed_test.py @@ -0,0 +1,44 @@ +from litecoder.usa import USCityIndex, USStateIndex +import time + +print("Loading USCityIndex... ", end="") +start_time = time.time() +city_idx = USCityIndex() +city_idx.load() +print("finished: {}s!".format(time.time() - start_time)) + +# Load 50 test city lookups +with open("test_city_lookups.txt", "r") as lookups_file: + city_tests = lookups_file.read().splitlines() + +# Increase the number of lookups for the speed test if necessary +for x in range (5): + city_tests += city_tests +num_tests = len(city_tests) +print("measuring time for {} cities... ".format(num_tests), end="") +start_time = time.time() +for city in city_tests: + city_idx[city] +ms = 1000*(time.time() - start_time) +print("finished: took {}ms at {} ms/lookup!".format(ms, float(ms/num_tests))) + +print("Loading USStateIndex... ", end="") +start_time = time.time() +state_idx = USStateIndex() +state_idx.load() +print("finished: {}s!".format(time.time() - start_time)) + +# Load 50 test state lookups +with open("test_state_lookups.txt", "r") as lookups_file: + state_tests = lookups_file.read().splitlines() + +# Increase the number of lookups for the speed test if necessary +for x in range (5): + state_tests += state_tests +num_tests = len(state_tests) +print("measuring time for {} states... ".format(num_tests), end="") +start_time = time.time() +for state in state_tests: + state_idx[state] +ms = 1000*(time.time() - start_time) +print("finished: took {}ms at {} ms/lookup!".format(ms, float(ms/num_tests))) diff --git a/tests/runtime/test_city_lookups.txt b/tests/runtime/test_city_lookups.txt new file mode 100644 index 0000000..43d1bd5 --- /dev/null +++ b/tests/runtime/test_city_lookups.txt @@ -0,0 +1,50 @@ +Edinburg, Texas +Lakeville , Minnesota +Woodland, CA. +Gary, IN +Cornelius, NC +Okeechobee, Fl +Saginaw Township South, MI +Lansdowne, PA +Knoxville TN +OAKLAND, CA +suffolk va +Port Orange, FL +Sedona, AZ +Cedar City UT +Cincinnati. +Huntington Beach CA +Wooster,Ohio +Lewisville, Texas +traverse city mi +Pennsauken, New Jersey +Jonesboro, Arkansas +Zephyrhills, FL +West Jefferson, NC +Escondido, CA +Lumberton, NC +Cayce, SC +Stratford, Connecticut, USA +Avondale, AZ +Coral Springs, FL +Gaithersburg, MD +Westchester, IL +Louisa, Virginia +Norway, ME +Philadelphia PA, USA +Fort worth, tx +Eureka Springs, Arkansas +Nashville , TN +Ellenwood Ga +Floral Park, NY +Nashville Tennessee +Malvern, AR +Valdosta, Georgia +Valley Center Ca +St. Robert Mo. +Hollandale, MS +New Castle, PA +Harlem, FL +Kings Mills, OH +knoxville Tennessee +BrooklYn \ No newline at end of file diff --git a/tests/runtime/test_state_lookups.txt b/tests/runtime/test_state_lookups.txt new file mode 100644 index 0000000..fd8b3c9 --- /dev/null +++ b/tests/runtime/test_state_lookups.txt @@ -0,0 +1,50 @@ +North Carolina, USA +District of Columbia +Illinois, United States +Georgia United States +north carolina +texas +iowa +Florida, United States +Vermont, USA +TX USA +FL U.S.A. +pennsylvania usa +nebraska + Oregon +Pennsylvania +New Hampshire USA +Nebraska, USA +New mexico +Indiana +South Dakota + Oklahoma +Ohio,US +Kansas, USA +indiana +MA, USA + New York +Ohio, United States +NJ USA +ohio usa +Connecticut, USA +MICHIGAN, United States +Missouri +New York +California - USA +Massachusetts, USA + Missouri +FL, United States of America +New Hampshire +Georgia. +Nevada USA + PENNSYLVANIA +Virginia, USA. +Alabama, USA +Indiana +Louisiana, United States +New Mexico +Ohio USA +Nevada, USA +LOUISIANA +New Jersey, us \ No newline at end of file