Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Pipfile
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ PyYAML = "*"
Shapely = "*"
numpy = "*"
scipy = "*"
marisa_trie = "*"

[dev-packages]

Expand Down
4 changes: 2 additions & 2 deletions litecoder/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@

DATA_DIR = os.path.join(os.path.dirname(__file__), 'data')

US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.p')
US_STATE_PATH = os.path.join(DATA_DIR, 'us-states.marisa')

US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.p')
US_CITY_PATH = os.path.join(DATA_DIR, 'us-cities.marisa')


logging.basicConfig(
Expand Down
121 changes: 45 additions & 76 deletions litecoder/usa.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@


import re
import pickle
import marisa_trie
import ujson as json

from tqdm import tqdm
from collections import defaultdict
Expand Down Expand Up @@ -147,99 +148,53 @@ def state_key_iter(row):
yield ' '.join((abbr, usa))


class Match:

def __init__(self, row):
"""Set model class, PK, metadata.
"""
state = inspect(row)

# Don't store the actual row, so we can serialize.
self._model_cls = state.class_
self._pk = state.identity

self.data = Box(dict(row))

@cached_property
def db_row(self):
"""Hydrate database row, lazily.
"""
return self._model_cls.query.get(self._pk)


class CityMatch(Match):

def __repr__(self):
return '%s<%s, %s, %s, wof:%d>' % (
self.__class__.__name__,
self.data.name,
self.data.name_a1,
self.data.name_a0,
self.data.wof_id,
)


class StateMatch(Match):

def __repr__(self):
return '%s<%s, %s, wof:%d>' % (
self.__class__.__name__,
self.data.name,
self.data.name_a0,
self.data.wof_id,
)


class Index:

@classmethod
def load(cls, path):
with open(path, 'rb') as fh:
return pickle.load(fh)
def load(self, path, mmap=False):
if mmap:
self._trie.mmap(path)
else:
self._trie.load(path)

def __init__(self):
self._key_to_ids = defaultdict(set)
self._id_to_loc = dict()
self._trie = marisa_trie.BytesTrie()

# We use prefixes here to store the keys -> ids and ids -> loc "maps" as subtrees in one marisa trie.
self._keys_prefix = "A"
self._ids_prefix = "B"

def __len__(self):
return len(self._key_to_ids)
return len(self._trie.keys(self._keys_prefix))

def __repr__(self):
return '%s<%d keys, %d entities>' % (
self.__class__.__name__,
len(self._key_to_ids),
len(self._id_to_loc),
len(self._trie.keys(self._keys_prefix)),
len(self._trie.keys(self._ids_prefix)),
)

def __getitem__(self, text):
"""Get ids, map to records only if there is a match in the index
"""
if keyify(text) not in self._key_to_ids:
normalized_key = self._keys_prefix + keyify(text)
val = self._trie.get(normalized_key, None)
if not val:
return None
ids = json.loads(val[0])

ids = self._key_to_ids[keyify(text)]

return [self._id_to_loc[id] for id in ids]

def add_key(self, key, id):
self._key_to_ids[key].add(id)

def add_location(self, id, location):
self._id_to_loc[id] = location
return [json.loads(self._trie[self._ids_prefix + id][0]) for id in ids]

def locations(self):
return list(self._id_to_loc.values())
return self._trie.items(self._ids_prefix)

def save(self, path):
with open(path, 'wb') as fh:
pickle.dump(self, fh)
self._trie.save(path)


class USCityIndex(Index):

@classmethod
def load(cls, path=US_CITY_PATH):
return super().load(path)
def load(self, path=US_CITY_PATH, mmap=False):
return super().load(path, mmap)

def __init__(self, bare_name_blocklist=None):
super().__init__()
Expand All @@ -248,6 +203,7 @@ def __init__(self, bare_name_blocklist=None):
def build(self):
"""Index all US cities.
"""

allow_bare = AllowBareCityName(blocklist=self.bare_name_blocklist)

iter_keys = CityKeyIter(allow_bare)
Expand All @@ -257,21 +213,27 @@ def build(self):

logger.info('Indexing US cities.')

key_to_ids = defaultdict(set)
id_to_loc_items = list()

for row in tqdm(cities):

# Key -> id(s)
for key in map(keyify, iter_keys(row)):
self.add_key(key, row.wof_id)
key_to_ids[key].add(str(row.wof_id))

# ID -> city
self.add_location(row.wof_id, CityMatch(row))
id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8")))

key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids]

self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items)


class USStateIndex(Index):

@classmethod
def load(cls, path=US_STATE_PATH):
return super().load(path)
def load(self, path=US_STATE_PATH, mmap=False):
return super().load(path, mmap)

def build(self):
"""Index all US states.
Expand All @@ -280,11 +242,18 @@ def build(self):

logger.info('Indexing US states.')

key_to_ids = defaultdict(set)
id_to_loc_items = list()

for row in tqdm(states):

# Key -> id(s)
for key in map(keyify, state_key_iter(row)):
self.add_key(key, row.wof_id)
key_to_ids[key].add(str(row.wof_id))

# ID -> state
self.add_location(row.wof_id, StateMatch(row))
id_to_loc_items.append((self._ids_prefix + str(row.wof_id), bytes(json.dumps(dict(row)), encoding="utf-8")))

key_to_ids_items = [(self._keys_prefix + key, json.dumps(list(key_to_ids[key])).encode("utf-8")) for key in key_to_ids]

self._trie = marisa_trie.BytesTrie(key_to_ids_items + id_to_loc_items)
8 changes: 6 additions & 2 deletions tests/prod_db/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@

@pytest.fixture(scope='session')
def city_idx():
return USCityIndex.load()
city_idx = USCityIndex()
city_idx.load()
return city_idx


@pytest.fixture(scope='session')
def state_idx():
return USStateIndex.load()
state_idx = USStateIndex()
state_idx.load()
return state_idx
4 changes: 2 additions & 2 deletions tests/prod_db/test_us_city_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def test_cases(city_idx, query, matches, xfail):

res = city_idx[query]

ids = [r.data.wof_id for r in res]
ids = [r["wof_id"] for r in res]

# Exact id list match.
assert sorted(ids) == sorted(matches)
Expand All @@ -49,6 +49,6 @@ def test_topn(city_idx, city):
"""Smoke test N most populous cities.
"""
res = city_idx['%s, %s' % (city.name, city.name_a1)]
res_ids = [r.data.wof_id for r in res]
res_ids = [r["wof_id"] for r in res]

assert city.wof_id in res_ids
4 changes: 2 additions & 2 deletions tests/prod_db/test_us_state_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def test_cases(state_idx, query, matches):

res = state_idx[query]

ids = [r.data.wof_id for r in res]
ids = [r["wof_id"] for r in res]

assert sorted(ids) == sorted(matches)

Expand All @@ -41,6 +41,6 @@ def test_all(state_idx, state):
"""Smoke test N most populous cities.
"""
res = state_idx[state.name]
res_ids = [r.data.wof_id for r in res]
res_ids = [r["wof_id"] for r in res]

assert state.wof_id in res_ids
45 changes: 45 additions & 0 deletions tests/runtime/concurrency_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import os
import time
from multiprocessing import Pool

from litecoder.usa import USCityIndex, USStateIndex

NUM_PROCESSES = 4

# Resolve the fixture file relative to this script so the benchmark works
# regardless of the current working directory (it previously assumed the
# repo root as cwd).
_LOOKUPS_PATH = os.path.join(os.path.dirname(__file__), 'test_city_lookups.txt')

# Load the test city lookups (one query per line).
with open(_LOOKUPS_PATH, 'r') as lookups_file:
    city_tests = lookups_file.read().splitlines()

# Double the lookup list repeatedly so each per-process run is long enough
# to time reliably.
for _ in range(10):
    city_tests += city_tests

num_tests_per_process = len(city_tests)
num_tests = NUM_PROCESSES * num_tests_per_process

# Load the index at module level: fork-based platforms share it with the
# workers for free, and spawn-based platforms rebuild it on re-import.
city_idx = USCityIndex()
city_idx.load()


def lookup_cities(process_num):
    """Run every test lookup once in this worker and print per-lookup latency."""
    print('Process {}: looking up {} cities'.format(process_num, num_tests_per_process))
    start_time = time.time()
    for city in city_tests:
        city_idx[city]
    ms = 1000 * (time.time() - start_time)
    print("Process {}: finished, took {}ms @ {} ms/lookup!".format(
        process_num, ms, float(ms / num_tests_per_process)))


if __name__ == '__main__':
    print("Looking up {} cities on {} processes...".format(num_tests, NUM_PROCESSES))
    start_time = time.time()
    # Size the pool from NUM_PROCESSES — it was hard-coded to 5, which left
    # one worker permanently idle and desynchronized from NUM_PROCESSES.
    with Pool(NUM_PROCESSES) as p:
        p.map(lookup_cities, range(1, NUM_PROCESSES + 1))
    ms = 1000 * (time.time() - start_time)
    print("Fully finished: took {}ms @ {} ms/lookup!".format(ms, float(ms / num_tests)))

    print()
    # Baseline: the same total number of lookups, serially in one process.
    print("Looking up all {} cities on one process...".format(num_tests), end="")
    start_time = time.time()
    for _ in range(NUM_PROCESSES):
        for city in city_tests:
            city_idx[city]
    ms = 1000 * (time.time() - start_time)
    print("finished: took {}ms @ {} ms/lookup!".format(ms, ms / num_tests))
44 changes: 44 additions & 0 deletions tests/runtime/speed_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import os
import time

from litecoder.usa import USCityIndex, USStateIndex

# Fixture files live next to this script; resolve them relative to __file__
# so the benchmark works from any working directory (the bare relative paths
# previously only worked when run from inside tests/runtime).
_HERE = os.path.dirname(__file__)


def _load_lookups(filename, doublings=5):
    """Read lookup queries (one per line) and repeatedly double the list so
    the sample is large enough to time reliably."""
    with open(os.path.join(_HERE, filename), 'r') as lookups_file:
        lookups = lookups_file.read().splitlines()
    for _ in range(doublings):
        lookups += lookups
    return lookups


def _timed_load(index_cls, label):
    """Instantiate and load an index, printing how long the load takes."""
    print("Loading {}... ".format(label), end="")
    start_time = time.time()
    index = index_cls()
    index.load()
    print("finished: {}s!".format(time.time() - start_time))
    return index


def _benchmark(index, lookups, label):
    """Time one lookup of every query against `index` and print the totals."""
    num_tests = len(lookups)
    print("measuring time for {} {}... ".format(num_tests, label), end="")
    start_time = time.time()
    for query in lookups:
        index[query]
    ms = 1000 * (time.time() - start_time)
    print("finished: took {}ms at {} ms/lookup!".format(ms, float(ms / num_tests)))


# The city and state runs were copy-pasted duplicates; drive both through the
# same helpers.
city_idx = _timed_load(USCityIndex, "USCityIndex")
_benchmark(city_idx, _load_lookups("test_city_lookups.txt"), "cities")

state_idx = _timed_load(USStateIndex, "USStateIndex")
_benchmark(state_idx, _load_lookups("test_state_lookups.txt"), "states")
Loading