From fbc53fb6b08319907cb84d7ffecc21db57f9e624 Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 3 Dec 2020 20:41:43 +0000 Subject: [PATCH 1/5] First commit after 7 years --- .github/workflows/build.yml | 38 + .github/workflows/release.yml | 36 + .gitignore | 1 + .travis.yml | 36 - Makefile | 36 + clean.py | 27 +- covrun.py | 5 +- dev/requirements-dev.txt | 7 + requirements.txt => dev/requirements.txt | 0 dev/run-black.sh | 2 + docs/source/_ext/sphinxtogithub.py | 4 +- docs/source/conf.py | 7 +- docs/source/examples/ts.py | 2 +- examples/data.py | 116 +-- examples/m2m.py | 5 +- examples/models.py | 57 +- examples/observer.py | 38 +- examples/permissions.py | 136 ++-- examples/spelling/spelling.py | 12 +- examples/sql.py | 6 +- examples/tsmodels.py | 26 +- examples/wordsearch/basicwords.py | 6 +- requirements_dev.txt | 6 - runtests.py | 43 +- setup.cfg | 29 + setup.py | 124 +-- stdnet/__init__.py | 47 +- stdnet/apps/__init__.py | 4 +- stdnet/apps/columnts/__init__.py | 4 +- stdnet/apps/columnts/models.py | 192 ++--- stdnet/apps/columnts/npts.py | 16 +- stdnet/apps/columnts/redis.py | 213 +++-- stdnet/apps/searchengine/__init__.py | 103 +-- stdnet/apps/searchengine/models.py | 19 +- .../apps/searchengine/processors/__init__.py | 13 +- stdnet/apps/searchengine/processors/ignore.py | 14 +- .../apps/searchengine/processors/metaphone.py | 1 + stdnet/apps/searchengine/processors/porter.py | 314 ++++--- stdnet/apps/tasks/__init__.py | 8 +- stdnet/apps/tasks/models.py | 5 +- stdnet/backends/__init__.py | 236 +++--- stdnet/backends/redisb/__init__.py | 479 ++++++----- stdnet/backends/redisb/client/__init__.py | 39 +- stdnet/backends/redisb/client/async.py | 21 +- stdnet/backends/redisb/client/client.py | 22 +- stdnet/backends/redisb/client/extensions.py | 251 +++--- stdnet/backends/redisb/client/prefixed.py | 158 ++-- stdnet/odm/__init__.py | 12 +- stdnet/odm/base.py | 421 +++++----- stdnet/odm/fields.py | 770 ++++++++++-------- stdnet/odm/globals.py | 36 +- stdnet/odm/mapper.py | 268 +++--- stdnet/odm/models.py | 167 ++-- stdnet/odm/query.py | 560 +++++++------ stdnet/odm/related.py | 156 ++-- stdnet/odm/search.py | 208 +++-- stdnet/odm/session.py | 474 +++++------ stdnet/odm/struct.py | 334 ++++---- stdnet/odm/structfields.py | 233 +++--- stdnet/odm/utils.py | 149 ++-- stdnet/utils/__init__.py | 28 +- stdnet/utils/dates.py | 42 +- stdnet/utils/encoders.py | 115 +-- stdnet/utils/fallbacks/_collections.py | 24 +- stdnet/utils/fallbacks/_importlib.py | 11 +- stdnet/utils/importer.py | 2 +- stdnet/utils/jsontools.py | 145 ++-- stdnet/utils/populate.py | 81 +- stdnet/utils/py2py3.py | 50 +- stdnet/utils/skiplist.py | 51 +- stdnet/utils/structures.py | 2 +- stdnet/utils/test.py | 222 ++--- stdnet/utils/version.py | 40 +- stdnet/utils/zset.py | 20 +- tests/all/apps/columnts/evaluate.py | 21 +- tests/all/apps/columnts/field.py | 15 +- tests/all/apps/columnts/main.py | 357 ++++---- tests/all/apps/columnts/manipulate.py | 25 +- tests/all/apps/columnts/npts.py | 47 +- tests/all/apps/columnts/readonly.py | 93 +-- tests/all/apps/searchengine/add.py | 59 +- tests/all/apps/searchengine/meta.py | 237 +++--- tests/all/apps/searchengine/search.py | 54 +- tests/all/backends/interface.py | 25 +- tests/all/backends/redis/async.py | 28 +- tests/all/backends/redis/client.py | 284 ++++--- tests/all/backends/redis/info.py | 94 +-- tests/all/backends/redis/prefixed.py | 34 +- tests/all/benchmarks/__init__.py | 11 +- tests/all/fields/fk.py | 36 +- tests/all/fields/fknotrequired.py | 123 +-- tests/all/fields/id.py | 113 ++- 
tests/all/fields/integer.py | 10 +- tests/all/fields/jsonfield.py | 300 +++---- tests/all/fields/meta.py | 10 +- tests/all/fields/pickle.py | 33 +- tests/all/fields/pk.py | 34 +- tests/all/fields/scalar.py | 72 +- tests/all/lib/autoincrement.py | 37 +- tests/all/lib/local.py | 46 +- tests/all/lib/me.py | 9 +- tests/all/lib/meta.py | 99 ++- tests/all/lib/register.py | 41 +- tests/all/multifields/hash.py | 55 +- tests/all/multifields/list.py | 22 +- tests/all/multifields/set.py | 37 +- tests/all/multifields/string.py | 20 +- tests/all/multifields/struct.py | 39 +- tests/all/multifields/timeseries.py | 292 +++---- tests/all/query/contains.py | 45 +- tests/all/query/delete.py | 98 +-- tests/all/query/get_field.py | 76 +- tests/all/query/instruments.py | 133 ++- tests/all/query/load_only.py | 214 +++-- tests/all/query/load_related.py | 151 ++-- tests/all/query/manager.py | 81 +- tests/all/query/manytomany.py | 158 ++-- tests/all/query/meta.py | 43 +- tests/all/query/ranges.py | 104 ++- tests/all/query/related.py | 106 +-- tests/all/query/session.py | 40 +- tests/all/query/signal.py | 21 +- tests/all/query/slice.py | 23 +- tests/all/query/sorting.py | 96 +-- tests/all/query/transaction.py | 54 +- tests/all/query/unique.py | 87 +- tests/all/query/where.py | 21 +- tests/all/serialize/base.py | 43 +- tests/all/serialize/csv.py | 20 +- tests/all/serialize/json.py | 16 +- tests/all/structures/base.py | 32 +- tests/all/structures/hash.py | 58 +- tests/all/structures/list.py | 18 +- tests/all/structures/numarray.py | 19 +- tests/all/structures/set.py | 19 +- tests/all/structures/string.py | 16 +- tests/all/structures/ts.py | 21 +- tests/all/structures/zset.py | 81 +- tests/all/topics/finance.py | 42 +- tests/all/topics/observer.py | 24 +- tests/all/topics/permissions.py | 41 +- tests/all/topics/twitter.py | 68 +- tests/all/utils/intervals.py | 98 ++- tests/all/utils/tools.py | 210 ++--- tests/all/utils/zset.py | 49 +- 145 files changed, 6731 insertions(+), 6092 deletions(-) create mode 100644 .github/workflows/build.yml create mode 100644 .github/workflows/release.yml delete mode 100644 .travis.yml create mode 100644 Makefile create mode 100644 dev/requirements-dev.txt rename requirements.txt => dev/requirements.txt (100%) create mode 100755 dev/run-black.sh delete mode 100644 requirements_dev.txt create mode 100644 setup.cfg diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml new file mode 100644 index 0000000..3796015 --- /dev/null +++ b/.github/workflows/build.yml @@ -0,0 +1,38 @@ +name: build + +on: + push: + branches-ignore: + - deploy + tags-ignore: + - v* + +jobs: + + build: + runs-on: ubuntu-latest + env: + PYTHON_ENV: ci + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }} + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + + steps: + - uses: actions/checkout@v2 + - name: run postgres + run: make postgresql + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: make install + - name: run lint + run: make test-lint + - name: run tests + run: make test + - name: upload coverage + if: matrix.python-version == '3.8' + run: coveralls diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..24a9700 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,36 @@ +name: release + +on: + push: + branches: + - deploy + tags-ignore: + - v* + 
+jobs: + release: + runs-on: ubuntu-latest + env: + PYTHON_ENV: ci + PYPI_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + GITHUB_TOKEN: ${{ secrets.QMBOT_GITHUB_TOKEN }} + strategy: + matrix: + python-version: [3.6, 3.7, 3.8, 3.9] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: make install + - name: test version + run: make test-version + - name: build python bundle + run: "make bundle${{ matrix.python-version }}" + - name: release to pypi + run: make release-pypi + - name: release to github + if: matrix.python-version == '3.8' + run: make release-github diff --git a/.gitignore b/.gitignore index 1df572f..7ef1d48 100644 --- a/.gitignore +++ b/.gitignore @@ -6,6 +6,7 @@ *.o *.def dist +venv __pycache__ extensions/src/cparser.cpp build diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 65431e4..0000000 --- a/.travis.yml +++ /dev/null @@ -1,36 +0,0 @@ -language: python - -python: - - "2.6" - - "2.7" - - "3.2" - - "3.3" - - "pypy" - -install: - - if [[ $TRAVIS_PYTHON_VERSION == '2.6' ]]; then pip install --use-mirrors argparse unittest2; fi - - pip install -r requirements_dev.txt --use-mirrors - - git clone https://github.com/quantmind/pulsar.git - - cd pulsar - - python setup.py install - - cd .. - - sudo rm -rf pulsar - - python setup.py install - - sudo rm -rf /dev/shm && sudo ln -s /run/shm /dev/shm - -services: - - redis-server - -script: - - pep8 stdnet --exclude stdnet/apps/searchengine/processors - - sudo rm -rf stdnet - - python -m covrun - -notifications: - email: false - -# Only test master and dev -branches: - only: - - master - - dev diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..2bc16be --- /dev/null +++ b/Makefile @@ -0,0 +1,36 @@ +REPO_NAME = quantmind + +# Fixed - dont modify these lines ================================== +K8S_NS ?= prod +LOCAL_DOCKER_NETWORK = services_default +# ================================================================== + +GIT_SHA := $(shell git rev-parse HEAD) +TIMESTAMP := $(shell date -u) + + +.PHONY: help clean deploy env freeze install image serverless test + +help: + @echo ======================== METACORE ==================================================== + @fgrep -h "##" $(MAKEFILE_LIST) | fgrep -v fgrep | sed -e 's/\\$$//' | sed -e 's/##//' + @echo ====================================================================================== + +clean: ## remove python cache files + ../devops/dev/clean.sh + + +cloc: ## Count lines of code + cloc --exclude-dir=tools,venv,node-modules,build,.pytest_cache,.mypy_cache,target . + + +install: ## install python dependencies in venv + @pip install -U pip twine + @pip install -U -r ./dev/requirements-dev.txt + @pip install -U -r ./dev/requirements.txt + + +lint: ## run linters + isort . 
+ ./dev/run-black.sh + flake8 diff --git a/clean.py b/clean.py index dca9d2c..fc981fe 100644 --- a/clean.py +++ b/clean.py @@ -1,17 +1,18 @@ import os import shutil - + + def rmgeneric(path, __func__): try: __func__(path) - #print 'Removed ', path + # print 'Removed ', path return 1 except OSError as e: - print('Could not remove {0}, {1}'.format(path,e)) + print("Could not remove {0}, {1}".format(path, e)) return 0 - - -def rmfiles(path, ext = None, rmcache = True): + + +def rmfiles(path, ext=None, rmcache=True): if not os.path.isdir(path): return 0 trem = 0 @@ -20,24 +21,22 @@ def rmfiles(path, ext = None, rmcache = True): for f in files: fullpath = os.path.join(path, f) if os.path.isfile(fullpath): - sf = f.split('.') + sf = f.split(".") if len(sf) == 2: if ext == None or sf[1] == ext: tall += 1 trem += rmgeneric(fullpath, os.remove) - elif f == '__pycache__' and rmcache: + elif f == "__pycache__" and rmcache: shutil.rmtree(fullpath) tall += 1 elif os.path.isdir(fullpath): - r,ra = rmfiles(fullpath, ext) + r, ra = rmfiles(fullpath, ext) trem += r tall += ra return trem, tall - -if __name__ == '__main__': +if __name__ == "__main__": path = os.curdir - removed, allfiles = rmfiles(path,'pyc') - print('removed {0} pyc files out of {1}'.format(removed, allfiles)) - + removed, allfiles = rmfiles(path, "pyc") + print("removed {0} pyc files out of {1}".format(removed, allfiles)) diff --git a/covrun.py b/covrun.py index daed967..515903e 100644 --- a/covrun.py +++ b/covrun.py @@ -1,10 +1,9 @@ -import sys import os +import sys from runtests import run - -if __name__ == '__main__': +if __name__ == "__main__": if sys.version_info > (3, 3): run(coverage=True, coveralls=True) else: diff --git a/dev/requirements-dev.txt b/dev/requirements-dev.txt new file mode 100644 index 0000000..d433426 --- /dev/null +++ b/dev/requirements-dev.txt @@ -0,0 +1,7 @@ +flake8 +black +isort +pytest +cython +coverage +mypy diff --git a/requirements.txt b/dev/requirements.txt similarity index 100% rename from requirements.txt rename to dev/requirements.txt diff --git a/dev/run-black.sh b/dev/run-black.sh new file mode 100755 index 0000000..febd01a --- /dev/null +++ b/dev/run-black.sh @@ -0,0 +1,2 @@ +#!/bin/bash +black . --exclude "venv|build|docs" $1 diff --git a/docs/source/_ext/sphinxtogithub.py b/docs/source/_ext/sphinxtogithub.py index b66b232..3c12fd4 100644 --- a/docs/source/_ext/sphinxtogithub.py +++ b/docs/source/_ext/sphinxtogithub.py @@ -1,8 +1,8 @@ #! /usr/bin/env python -from optparse import OptionParser import os -import sys import shutil +import sys +from optparse import OptionParser class NoDirectoriesError(Exception): diff --git a/docs/source/conf.py b/docs/source/conf.py index a89cd0f..1fdae07 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -11,7 +11,8 @@ # All configuration values have a default; values that are commented out # serve to show the default. -import sys, os +import os +import sys # If your extensions are in another directory, add it here. source_dir = os.path.split(os.path.abspath(__file__))[0] @@ -20,9 +21,11 @@ sys.path.append(os.path.join(source_dir, "_ext")) sys.path.append(base_dir) import stdnet + version = stdnet.__version__ release = version -import runtests # so that it import pulsar if available +import runtests # so that it import pulsar if available + # -- General configuration ----------------------------------------------------- # Add any Sphinx extension module names here, as strings. 
They can be extensions diff --git a/docs/source/examples/ts.py b/docs/source/examples/ts.py index c111a3e..3852b16 100644 --- a/docs/source/examples/ts.py +++ b/docs/source/examples/ts.py @@ -2,7 +2,7 @@ from stdnet import odm from stdnet.contrib.timeseries.models import TimeSeries - + class FinanceTimeSeries(TimeSeries): ticker = odm.SymbolField(unique = True) diff --git a/examples/data.py b/examples/data.py index d672e88..8b18f56 100644 --- a/examples/data.py +++ b/examples/data.py @@ -1,72 +1,75 @@ from datetime import date, timedelta from random import randint -from stdnet.utils import test, populate, zip, iteritems +from stdnet.utils import iteritems, populate, test, zip -from .models import Instrument, Fund, Position +from .models import Fund, Instrument, Position - -CCYS_TYPES = ['EUR', 'GBP', 'AUD', 'USD', 'CHF', 'JPY'] -INSTS_TYPES = ['equity', 'bond', 'future', 'cash', 'option', 'bond option'] +CCYS_TYPES = ["EUR", "GBP", "AUD", "USD", "CHF", "JPY"] +INSTS_TYPES = ["equity", "bond", "future", "cash", "option", "bond option"] def assertEqual(x, y): - assert x == y, 'no equal' + assert x == y, "no equal" class key_data(test.DataGenerator): - def generate(self, min_len=10, max_len=20, **kwargs): - self.keys = populate('string', self.size, min_len=min_len, - max_len=max_len) - self.values = populate('string', self.size, min_len=min_len+10, - max_len=max_len+20) + self.keys = populate("string", self.size, min_len=min_len, max_len=max_len) + self.values = populate( + "string", self.size, min_len=min_len + 10, max_len=max_len + 20 + ) - def mapping(self, prefix=''): + def mapping(self, prefix=""): for k, v in zip(self.keys, self.values): - yield prefix+k, v + yield prefix + k, v class hash_data(key_data): - sizes = {'tiny': (50, 30), # fields/average field size - 'small': (300, 100), - 'normal': (1000, 300), - 'big': (5000, 1000), - 'huge': (20000, 5000)} - - def generate(self, fieldtype='string', **kwargs): + sizes = { + "tiny": (50, 30), # fields/average field size + "small": (300, 100), + "normal": (1000, 300), + "big": (5000, 1000), + "huge": (20000, 5000), + } + + def generate(self, fieldtype="string", **kwargs): fsize, dsize = self.size - if fieldtype == 'date': - self.fields = populate('date', fsize, - start=date(1971, 12, 30), - end=date.today()) + if fieldtype == "date": + self.fields = populate( + "date", fsize, start=date(1971, 12, 30), end=date.today() + ) else: - self.fields = populate('string', fsize, min_len=5, max_len=30) - self.data = populate('string', fsize, min_len=dsize, max_len=dsize) + self.fields = populate("string", fsize, min_len=5, max_len=30) + self.data = populate("string", fsize, min_len=dsize, max_len=dsize) def items(self): return zip(self.fields, self.data) class finance_data(test.DataGenerator): - sizes = {'tiny': (20, 3, 10, 1), # positions = 20*100*3 = 30 - 'small': (100, 10, 30, 2), # positions = 20*100*3 = 600 - 'normal': (500, 20, 100, 3), # positions = 20*100*3 = 6,000 - 'big': (2000, 30, 200, 5), # positions = 30*200*5 = 30,000 - 'huge': (10000, 50, 300, 8)} # positions = 50*300*8 = 120,000 + sizes = { + "tiny": (20, 3, 10, 1), # positions = 20*100*3 = 30 + "small": (100, 10, 30, 2), # positions = 20*100*3 = 600 + "normal": (500, 20, 100, 3), # positions = 20*100*3 = 6,000 + "big": (2000, 30, 200, 5), # positions = 30*200*5 = 30,000 + "huge": (10000, 50, 300, 8), + } # positions = 50*300*8 = 120,000 def generate(self, insts_types=None, ccys_types=None, **kwargs): inst_len, fund_len, pos_len, num_dates = self.size insts_types = insts_types or 
INSTS_TYPES ccys_types = ccys_types or CCYS_TYPES self.pos_len = pos_len - self.inst_names = populate('string', inst_len, min_len=5, max_len=20) - self.inst_types = populate('choice', inst_len, choice_from=insts_types) - self.inst_ccys = populate('choice', inst_len, choice_from=ccys_types) - self.fund_names = populate('string', fund_len, min_len=5, max_len=20) - self.fund_ccys = populate('choice', fund_len, choice_from=ccys_types) - self.dates = populate('date', num_dates, start=date(2009, 6, 1), - end=date(2010, 6, 6)) + self.inst_names = populate("string", inst_len, min_len=5, max_len=20) + self.inst_types = populate("choice", inst_len, choice_from=insts_types) + self.inst_ccys = populate("choice", inst_len, choice_from=ccys_types) + self.fund_names = populate("string", fund_len, min_len=5, max_len=20) + self.fund_ccys = populate("choice", fund_len, choice_from=ccys_types) + self.dates = populate( + "date", num_dates, start=date(2009, 6, 1), end=date(2010, 6, 6) + ) def create(self, test, use_transaction=True): session = test.session() @@ -78,14 +81,14 @@ def create(self, test, use_transaction=True): with session.begin() as t: for name, ccy in zip(self.fund_names, self.fund_ccys): t.add(models.fund(name=name, ccy=ccy)) - for name, typ, ccy in zip(self.inst_names, self.inst_types, - self.inst_ccys): + for name, typ, ccy in zip( + self.inst_names, self.inst_types, self.inst_ccys + ): t.add(models.instrument(name=name, type=typ, ccy=ccy)) yield t.on_result else: test.register() - for name, typ, ccy in zip(self.inst_names, self.inst_types, - self.inst_ccys): + for name, typ, ccy in zip(self.inst_names, self.inst_types, self.inst_ccys): yield models.instrument.new(name=name, type=typ, ccy=ccy) for name, ccy in zip(self.fund_names, self.fund_ccys): yield models.fund(name=name, ccy=ccy) @@ -102,28 +105,37 @@ def makePositions(self, test, use_transaction=True): if use_transaction: with session.begin() as t: for f in funds: - insts = populate('choice', self.pos_len, - choice_from=instruments) + insts = populate("choice", self.pos_len, choice_from=instruments) for dt in self.dates: for inst in insts: - t.add(Position(instrument=inst, dt=dt, fund=f, - size=randint(-100000, 100000))) + t.add( + Position( + instrument=inst, + dt=dt, + fund=f, + size=randint(-100000, 100000), + ) + ) yield t.on_result else: for f in funds: - insts = populate('choice', self.pos_len, - choice_from=instruments) + insts = populate("choice", self.pos_len, choice_from=instruments) for dt in self.dates: for inst in insts: - yield Position(instrument=inst, dt=dt, fund=f, - size=randint(-100000, 100000)).save() + yield Position( + instrument=inst, + dt=dt, + fund=f, + size=randint(-100000, 100000), + ).save() # self.num_pos = yield session.query(Position).count() yield session class FinanceTest(test.TestCase): - '''A class for testing the Finance application example. It can be run -with different sizes by passing the''' + """A class for testing the Finance application example. 
It can be run + with different sizes by passing the""" + data_cls = finance_data models = (Instrument, Fund, Position) diff --git a/examples/m2m.py b/examples/m2m.py index e5632a1..acdb9af 100644 --- a/examples/m2m.py +++ b/examples/m2m.py @@ -11,5 +11,6 @@ class CompositeElement(odm.StdModel): class Composite(odm.StdModel): name = odm.SymbolField() - elements = odm.ManyToManyField(Element, through=CompositeElement, - related_name='composites') \ No newline at end of file + elements = odm.ManyToManyField( + Element, through=CompositeElement, related_name="composites" + ) diff --git a/examples/models.py b/examples/models.py index a9f266c..3a59b97 100755 --- a/examples/models.py +++ b/examples/models.py @@ -1,13 +1,12 @@ import time -from datetime import datetime, date +from datetime import date, datetime from stdnet import odm class CustomManager(odm.Manager): - def small_query(self, **kwargs): - return self.query(**kwargs).load_only('code', 'group') + return self.query(**kwargs).load_only("code", "group") def something(self): return "I'm a custom manager" @@ -51,9 +50,9 @@ class Instrument2(Base): type = odm.SymbolField() class Meta: - ordering = 'id' - app_label = 'examples2' - name = 'instrument' + ordering = "id" + app_label = "examples2" + name = "instrument" class Fund(Base): @@ -61,13 +60,13 @@ class Fund(Base): class Position(odm.StdModel): - instrument = odm.ForeignKey(Instrument, related_name='positions') - fund = odm.ForeignKey(Fund, related_name='positions') + instrument = odm.ForeignKey(Instrument, related_name="positions") + fund = odm.ForeignKey(Fund, related_name="positions") dt = odm.DateField() size = odm.FloatField(default=1) def __unicode__(self): - return '%s: %s @ %s' % (self.fund, self.instrument, self.dt) + return "%s: %s @ %s" % (self.fund, self.instrument, self.dt) class PortfolioView(odm.StdModel): @@ -77,9 +76,9 @@ class PortfolioView(odm.StdModel): class Folder(odm.StdModel): name = odm.SymbolField() - view = odm.ForeignKey(PortfolioView, related_name='folders') - positions = odm.ManyToManyField(Position, related_name='folders') - parent = odm.ForeignKey('self', related_name='children', required=False) + view = odm.ForeignKey(PortfolioView, related_name="folders") + positions = odm.ManyToManyField(Position, related_name="folders") + parent = odm.ForeignKey("self", related_name="children", required=False) def __unicode__(self): return self.name @@ -97,7 +96,7 @@ class DateValue(odm.StdModel): def score(self): "implement the score function for sorting in the ordered set" - return int(1000*time.mktime(self.dt.timetuple())) + return int(1000 * time.mktime(self.dt.timetuple())) class Calendar(odm.StdModel): @@ -129,15 +128,13 @@ class TestDateModel(odm.StdModel): class SportAtDate(TestDateModel): - class Meta: - ordering = 'dt' + ordering = "dt" class SportAtDate2(TestDateModel): - class Meta: - ordering = '-dt' + ordering = "-dt" class Group(odm.StdModel): @@ -152,11 +149,11 @@ class Person(odm.StdModel): # A model for testing a recursive foreign key class Node(odm.StdModel): - parent = odm.ForeignKey('self', required=False, related_name='children') + parent = odm.ForeignKey("self", required=False, related_name="children") weight = odm.FloatField() def __unicode__(self): - return '%s' % self.weight + return "%s" % self.weight class Page(odm.StdModel): @@ -173,18 +170,19 @@ class Collection(odm.StdModel): class Post(odm.StdModel): dt = odm.DateTimeField(index=False, default=datetime.now) data = odm.CharField(required=True) - user = odm.ForeignKey('examples.user', 
index=False)
+    user = odm.ForeignKey("examples.user", index=False)
 
     def __unicode__(self):
         return self.data
 
 
 class User(odm.StdModel):
-    '''A model for holding information about users'''
+    """A model for holding information about users"""
+
     username = odm.SymbolField(unique=True)
     password = odm.SymbolField(index=False)
     updates = odm.ListField(model=Post)
-    following = odm.ManyToManyField(model='self', related_name='followers')
+    following = odm.ManyToManyField(model="self", related_name="followers")
 
     def __unicode__(self):
         return self.username
@@ -237,6 +235,7 @@ class Environment(odm.StdModel):
 ##############################################
 # Numeric Data
 
+
 class NumericData(odm.StdModel):
     pv = odm.FloatField()
     vega = odm.FloatField(default=0.0)
@@ -257,7 +256,7 @@ class DateData(odm.StdModel):
 class CrossData(odm.StdModel):
     name = odm.SymbolField()
     data = odm.JSONField(as_string=False)
-    extra = odm.ForeignKey('self', required=False)
+    extra = odm.ForeignKey("self", required=False)
 
 
 class FeedBase(odm.StdModel):
@@ -285,7 +284,7 @@ class Task(odm.StdModel):
     timestamp = odm.DateTimeField(default=datetime.now)
 
     class Meta:
-        ordering = '-timestamp'
+        ordering = "-timestamp"
 
     def clone(self, **kwargs):
         instance = super(Task, self).clone(**kwargs)
@@ -301,18 +300,18 @@ class Parent(odm.StdModel):
 class Child(odm.StdModel):
     name = odm.SymbolField()
     parent = odm.ForeignKey(Parent)
-    uncles = odm.ManyToManyField(Parent, related_name='nephews')
+    uncles = odm.ManyToManyField(Parent, related_name="nephews")
 
 
 ####################################################
 # Composite ID
 class WordBook(odm.StdModel):
-    id = odm.CompositeIdField('word', 'book')
+    id = odm.CompositeIdField("word", "book")
     word = odm.SymbolField()
     book = odm.SymbolField()
 
     def __unicode__(self):
-        return '%s:%s' % (self.word, self.book)
+        return "%s:%s" % (self.word, self.book)
 
 
 ############################################################################
@@ -323,12 +322,12 @@ class ObjectAnalytics(odm.StdModel):
 
     @property
     def object(self):
-        if not hasattr(self, '_object'):
+        if not hasattr(self, "_object"):
             self._object = self.model_type.objects.get(id=self.object_id)
         return self._object
 
 
 class AnalyticData(odm.StdModel):
     group = odm.ForeignKey(Group)
-    object = odm.ForeignKey(ObjectAnalytics, related_name='analytics')
+    object = odm.ForeignKey(ObjectAnalytics, related_name="analytics")
     data = odm.JSONField()
diff --git a/examples/observer.py b/examples/observer.py
index ad89bae..ddff15a 100644
--- a/examples/observer.py
+++ b/examples/observer.py
@@ -1,18 +1,20 @@
-'''This example is an implementation of the Observer design-pattern
+"""This example is an implementation of the Observer design pattern
 when observers receive multiple updates from several instances they
 are observing.
-'''
+"""
 from time import time
+
 from stdnet import odm
-from stdnet.odm import struct
 from stdnet.backends import redisb
+from stdnet.odm import struct
 
 
 class update_observer(redisb.RedisScript):
-    '''Script for adding/updating an observer. The ARGV contains, the member
-value, the initial score (usually a timestamp) and the increment for
-subsequent additions.'''
-    script = '''\
+    """Script for adding/updating an observer. The ARGV contains the member
+    value, the initial score (usually a timestamp) and the increment for
+    subsequent additions."""
+
+    script = """\
 local key = KEYS[1]
 local index = 0
 local n = 0
@@ -28,17 +30,18 @@ class update_observer(redisb.RedisScript):
     end
 end
 return n
-'''
+"""
 
 
 class RedisUpdateZset(redisb.Zset):
-    '''Redis backend structure override Zset'''
+    """Redis backend structure override Zset"""
+
     def flush(self):
         cache = self.instance.cache
         result = None
         if cache.toadd:
             flat = tuple(self.flat(cache.toadd.items()))
-            self.client.execute_script('update_observer', (self.id,), *flat)
+            self.client.execute_script("update_observer", (self.id,), *flat)
             result = True
         if cache.toremove:
             flat = tuple((el[1] for el in cache.toremove))
@@ -57,23 +60,23 @@ class UpdateZset(odm.Zset):
     penalty = 0  # penalty in seconds
 
     def __init__(self, *args, **kwargs):
-        self.penalty = kwargs.pop('penalty', self.penalty)
+        self.penalty = kwargs.pop("penalty", self.penalty)
         super(UpdateZset, self).__init__(*args, **kwargs)
 
     def dump_data(self, instances):
         dt = time()
         for n, instance in enumerate(instances):
-            if hasattr(instance, 'pkvalue'):
+            if hasattr(instance, "pkvalue"):
                 instance = instance.pkvalue()
             # put n so that it allows for repeated values
             yield dt, (n, self.penalty, instance)
 
+
 # Register the new structure with redis backend
-redisb.BackendDataServer.struct_map['updatezset'] = RedisUpdateZset
+redisb.BackendDataServer.struct_map["updatezset"] = RedisUpdateZset
 
 
 class UpdatesField(odm.StructureField):
-
     def structure_class(self):
         return UpdateZset
 
@@ -85,7 +88,7 @@ class Observable(odm.StdModel):
 class Observer(odm.StdModel):
     # Underlyings are the Observables this Observer is tracking for updates
     name = odm.CharField()
-    underlyings = odm.ManyToManyField(Observable, related_name='observers')
+    underlyings = odm.ManyToManyField(Observable, related_name="observers")
     # field with a 5 seconds penalty
     updates = UpdatesField(class_field=True, penalty=5)
@@ -101,5 +104,6 @@ def update_observers(signal, sender, instances=None, session=None, **kwargs):
     observers = models.observer
     through = models[observers.underlyings.model]
     return through.backend.execute(
-        through.filter(observable=instances).get_field('observer').all(),
-        observers.updates.update)
+        through.filter(observable=instances).get_field("observer").all(),
+        observers.updates.update,
+    )
diff --git a/examples/permissions.py b/examples/permissions.py
index 99ecf33..e24f1cb 100644
--- a/examples/permissions.py
+++ b/examples/permissions.py
@@ -1,4 +1,4 @@
-'''
+"""
 This section is a practical application of ``stdnet`` for solving
 role-based access control (RBAC). It is an approach for managing user
 permissions on your application which could be a web-site, an organisation
@@ -123,40 +123,38 @@
 
     authenticated_query(query, user, level)
 
-'''
+"""
 from inspect import isclass
 
-from stdnet import odm, FieldError
+from stdnet import FieldError, odm
 
 
 class PermissionManager(odm.Manager):
-
     def for_object(self, object, **params):
         if isclass(object):
             qs = self.query(model_type=object)
         else:
-            qs = self.query(model_type=object.__class__,
-                            object_pk=object.pkvalue())
+            qs = self.query(model_type=object.__class__, object_pk=object.pkvalue())
         if params:
             qs = qs.filter(**params)
         return qs
 
 
 class GroupManager(odm.Manager):
-
     def query(self, session=None):
-        '''Makes sure the :attr:`Group.user` is always loaded.'''
-        return super(GroupManager, self).query(session).load_related('user')
+        """Makes sure the :attr:`Group.user` is always loaded."""
+        return super(GroupManager, self).query(session).load_related("user")
 
     def check_user(self, username, email):
-        '''username and email (if provided) must be unique.'''
+        """username and email (if provided) must be unique."""
         users = self.router.user
         avail = yield users.filter(username=username).count()
         if avail:
-            raise FieldError('Username %s not available' % username)
+            raise FieldError("Username %s not available" % username)
        if email:
             avail = yield users.filter(email=email).count()
             if avail:
-                raise FieldError('Email %s not available' % email)
+                raise FieldError("Email %s not available" % email)
 
     def create_user(self, username=None, email=None, **params):
         yield self.check_user(username, email)
@@ -166,59 +164,58 @@ def create_user(self, username=None, email=None, **params):
         yield self.new(user=user, name=user.username)
 
     def permitted_query(self, query, group, operations):
-        '''Change the ``query`` so that only instances for which
-``group`` has roles with permission on ``operations`` are returned.'''
+        """Change the ``query`` so that only instances for which
+        ``group`` has roles with permission on ``operations`` are returned."""
         session = query.session
         models = session.router
         user = group.user
-        if user.is_superuser:    # super-users have all permissions
+        if user.is_superuser:  # super-users have all permissions
             return query
         roles = group.roles.query()  # query on all roles for group
         # The through model for Role/Permission relationship
         through_model = models.role.permissions.model
-        models[throgh_model].filter(role=roles,
-                                    permission__model_type=query.model,
-                                    permission__operations=operations)
+        models[through_model].filter(
+            role=roles,
+            permission__model_type=query.model,
+            permission__operations=operations,
+        )
         # query on all relevant permissions
-        permissions = router.permission.filter(model_type=query.model,
-                                               level=operations)
+        permissions = models.permission.filter(model_type=query.model, level=operations)
         owner_query = query.filter(user=user)
         # all roles for the query model with appropriate permission level
         roles = models.role.filter(model_type=query.model, level__ge=level)
         # Now we need groups which have these roles
-        groups = Role.groups.throughquery(
-            session).filter(role=roles).get_field('group')
+        groups = Role.groups.throughquery(session).filter(role=roles).get_field("group")
         # I need to know if user is in any of these groups
         if user.groups.filter(id=groups).count():
            # it is, let's get the model with permissions less
            # or equal permission level
-            permitted = models.instancerole.filter(
-                role=roles).get_field('object_id')
+            permitted = models.instancerole.filter(role=roles).get_field("object_id")
             return owner_query.union(model.objects.filter(id=permitted))
         else:
             return owner_query
 
 
 class Subject(object):
-    roles = odm.ManyToManyField('Role', related_name='subjects')
+    roles = odm.ManyToManyField("Role", related_name="subjects")
 
     def create_role(self, name):
-        '''Create a new :class:`Role` owned by this :class:`Subject`'''
+        """Create a new :class:`Role` owned by this :class:`Subject`"""
         models = self.session.router
         return models.role.new(name=name, owner=self)
 
     def assign(self, role):
-        '''Assign :class:`Role` ``role`` to this :class:`Subject`. If this
-:class:`Subject` is the :attr:`Role.owner`, this method does nothing.'''
+        """Assign :class:`Role` ``role`` to this :class:`Subject`. If this
+        :class:`Subject` is the :attr:`Role.owner`, this method does nothing."""
         if role.owner_id != self.id:
             return self.roles.add(role)
 
     def has_permissions(self, object, group, operations):
-        '''Check if this :class:`Subject` has permissions for ``operations``
-on an ``object``. It returns the number of valid permissions.'''
+        """Check if this :class:`Subject` has permissions for ``operations``
+        on an ``object``. It returns the number of valid permissions."""
         if self.is_superuser:
             return 1
         else:
@@ -226,13 +223,13 @@ def has_permissions(self, object, group, operations):
             # valid permissions
             query = models.permission.for_object(object, operation=operations)
             objects = models[models.role.permissions.model]
-            return objects.filter(role=self.role.query(),
-                                  permission=query).count()
+            return objects.filter(role=self.role.query(), permission=query).count()
 
 
 class User(odm.StdModel):
-    '''The user of a system. The only field required is the :attr:`username`.
-which is also unique across all users.'''
+    """The user of a system. The only field required is the :attr:`username`,
+    which is also unique across all users."""
+
     username = odm.SymbolField(unique=True)
     password = odm.CharField(required=False, hidden=True)
     first_name = odm.CharField()
@@ -247,18 +244,18 @@ def __unicode__(self):
 
 
 class Group(odm.StdModel, Subject):
-    id = odm.CompositeIdField('name', 'user')
+    id = odm.CompositeIdField("name", "user")
     name = odm.SymbolField()
-    '''Group name. If the group is for a signle user, it can be the
-user username'''
+    """Group name. If the group is for a single user, it can be the
+user's username"""
     user = odm.ForeignKey(User)
-    '''A group is always `owned` by a :class:`User`. For example the ``admin``
-group for a website is owned by the ``website`` user.'''
+    """A group is always `owned` by a :class:`User`.
For example the ``admin`` +group for a website is owned by the ``website`` user.""" # - users = odm.ManyToManyField(User, related_name='groups') - '''The :class:`stdnet.odm.ManyToManyField` for linking :class:`User` -and :class:`Group`.''' - roles = odm.ManyToManyField('Role', related_name='subjects') + users = odm.ManyToManyField(User, related_name="groups") + """The :class:`stdnet.odm.ManyToManyField` for linking :class:`User` +and :class:`Group`.""" + roles = odm.ManyToManyField("Role", related_name="subjects") manager_class = GroupManager @@ -267,15 +264,16 @@ def __unicode__(self): class Permission(odm.StdModel): - '''A model which implements permission and operation within -this RBAC implementation.''' - id = odm.CompositeIdField('model_type', 'object_pk', 'operation') - '''The name of the role, for example, ``Editor`` for a role which can - edit a certain :attr:`model_type`.''' + """A model which implements permission and operation within + this RBAC implementation.""" + + id = odm.CompositeIdField("model_type", "object_pk", "operation") + """The name of the role, for example, ``Editor`` for a role which can + edit a certain :attr:`model_type`.""" model_type = odm.ModelField() - '''The model (resource) which this permission refers to.''' + """The model (resource) which this permission refers to.""" operation = odm.IntegerField(default=0) - '''The operation assigned to this permission.''' + """The operation assigned to this permission.""" object_pk = odm.SymbolField(required=False) manager_class = PermissionManager @@ -283,36 +281,36 @@ class Permission(odm.StdModel): def __unicode__(self): op = self.operation if self.object_pk: - return '%s - %s - %s' % (self.model_type, self.object_pk, op) + return "%s - %s - %s" % (self.model_type, self.object_pk, op) else: - return '%s - %s' % (self.model_type, op) + return "%s - %s" % (self.model_type, op) class Role(odm.StdModel): - '''A :class:`Role` is uniquely identified by its :attr:`name` and -:attr:`owner`.''' - id = odm.CompositeIdField('name', 'owner') + """A :class:`Role` is uniquely identified by its :attr:`name` and + :attr:`owner`.""" + + id = odm.CompositeIdField("name", "owner") name = odm.SymbolField() - '''The name of this role.''' + """The name of this role.""" owner = odm.ForeignKey(Group) - '''The owner of this role-permission.''' - permissions = odm.ManyToManyField(Permission, related_name='roles') - '''the set of all :class:`Permission` assigned to this :class:`Role`.''' + """The owner of this role-permission.""" + permissions = odm.ManyToManyField(Permission, related_name="roles") + """the set of all :class:`Permission` assigned to this :class:`Role`.""" def __unicode__(self): return self.name def add_permission(self, resource, operation): - '''Add a new :class:`Permission` for ``resource`` to perform an -``operation``. The resource can be either an object or a model.''' + """Add a new :class:`Permission` for ``resource`` to perform an + ``operation``. 
The resource can be either an object or a model.""" if isclass(resource): model_type = resource - pk = '' + pk = "" else: model_type = resource.__class__ pk = resource.pkvalue() - p = Permission(model_type=model_type, object_pk=pk, - operation=operation) + p = Permission(model_type=model_type, object_pk=pk, operation=operation) session = self.session if session.transaction: session.add(p) @@ -325,14 +323,14 @@ def add_permission(self, resource, operation): return t.add_callback(lambda r: p) def assignto(self, subject): - '''Assign this :class:`Role` to ``subject``.''' + """Assign this :class:`Role` to ``subject``.""" return subject.assign(self) def register_for_permissions(model): - if 'group' not in model._meta.dfields: + if "group" not in model._meta.dfields: group = odm.ForeignKey(Group, related_name=model.__name__.lower()) - group.register_with_model('group', model) - group = model._meta.dfields['group'] + group.register_with_model("group", model) + group = model._meta.dfields["group"] if not isinstance(group, odm.ForeignKey) or group.relmodel != Group: - raise RuntimeError('group field of wrong type') + raise RuntimeError("group field of wrong type") diff --git a/examples/spelling/spelling.py b/examples/spelling/spelling.py index 6295aa8..a027d4f 100755 --- a/examples/spelling/spelling.py +++ b/examples/spelling/spelling.py @@ -6,15 +6,15 @@ # To train you can use # http://norvig.com/big.txt # +import collections import os import re -import collections CURDIR = os.path.split(os.path.abspath(__file__))[0] def words(text): - return re.findall('[a-z]+', text.lower()) + return re.findall("[a-z]+", text.lower()) def train(features): @@ -23,10 +23,11 @@ def train(features): model[f] += 1 return model -NWORDS = train(words(open(os.path.join(CURDIR, 'big.txt')).read())) + +NWORDS = train(words(open(os.path.join(CURDIR, "big.txt")).read())) -alphabet = 'abcdefghijklmnopqrstuvwxyz' +alphabet = "abcdefghijklmnopqrstuvwxyz" def edits1(word): @@ -47,6 +48,5 @@ def known(words): def correct(word): - candidates = (known([word]) or known(edits1(word)) or known_edits2(word) - or [word]) + candidates = known([word]) or known(edits1(word)) or known_edits2(word) or [word] return max(candidates, key=NWORDS.get) diff --git a/examples/sql.py b/examples/sql.py index 543e0fd..a1a6233 100644 --- a/examples/sql.py +++ b/examples/sql.py @@ -1,12 +1,12 @@ -from sqlalchemy.ext.declarative import declarative_base from sqlalchemy import Column, Integer, String +from sqlalchemy.ext.declarative import declarative_base Base = declarative_base() class User(Base): - __tablename__ = 'users' + __tablename__ = "users" id = Column(Integer, primary_key=True) fullname = Column(String) password = Column(String) - email = Column(String) \ No newline at end of file + email = Column(String) diff --git a/examples/tsmodels.py b/examples/tsmodels.py index b6d1898..ee6cf55 100644 --- a/examples/tsmodels.py +++ b/examples/tsmodels.py @@ -1,6 +1,6 @@ from stdnet import odm -from stdnet.utils import encoders, todatetime, todate, missing_intervals from stdnet.apps.columnts import ColumnTSField +from stdnet.utils import encoders, missing_intervals, todate, todatetime class TimeSeries(odm.StdModel): @@ -20,26 +20,33 @@ def __get_start(self): r = self.data.front() if r: return r[0] + data_start = property(__get_start) def __get_end(self): r = self.data.back() if r: return r[0] + data_end = property(__get_end) def size(self): - '''number of dates in timeseries''' + """number of dates in timeseries""" return self.data.size() def 
intervals(self, startdate, enddate, parseinterval=None):
-        '''Given a ``startdate`` and an ``enddate`` dates, evaluate the
+        """Given ``startdate`` and ``enddate`` dates, evaluate the
         date intervals from which data is not available. It returns a list of
         two-dimensional tuples containing start and end date for the
-        interval. The list could contain 0, 1 or 2 tuples.'''
-        return missing_intervals(startdate, enddate, self.data_start,
-                                 self.data_end, dateconverter=self.todate,
-                                 parseinterval=parseinterval)
+        interval. The list could contain 0, 1 or 2 tuples."""
+        return missing_intervals(
+            startdate,
+            enddate,
+            self.data_start,
+            self.data_end,
+            dateconverter=self.todate,
+            parseinterval=parseinterval,
+        )
 
 
 class DateTimeSeries(TimeSeries):
@@ -50,8 +57,9 @@ def todate(self, v):
 
 
 class BigTimeSeries(DateTimeSeries):
-    data = odm.TimeSeriesField(pickler=encoders.DateConverter(),
-                               value_pickler=encoders.PythonPickle())
+    data = odm.TimeSeriesField(
+        pickler=encoders.DateConverter(), value_pickler=encoders.PythonPickle()
+    )
 
 
 class ColumnTimeSeries(odm.StdModel):
diff --git a/examples/wordsearch/basicwords.py b/examples/wordsearch/basicwords.py
index f7a0842..63792cf 100644
--- a/examples/wordsearch/basicwords.py
+++ b/examples/wordsearch/basicwords.py
@@ -1,4 +1,4 @@
-basic_english_words = 'a,able,about,account,acid,across,act,addition,\
+basic_english_words = "a,able,about,account,acid,across,act,addition,\
 adjustment,advertisement,after,again,against,agreement,air,all,almost,\
 among,amount,amusement,and,angle,angry,animal,answer,ant,any,apparatus,\
 apple,approval,arch,argument,arm,army,art,as,at,attack,attempt,attention,\
@@ -73,4 +73,6 @@
 wall,war,warm,wash,waste,watch,water,wave,wax,way,weather,week,weight,well,\
 west,wet,wheel,when,where,while,whip,whistle,white,who,why,wide,will,wind,\
 window,wine,wing,winter,wire,wise,with,woman,wood,wool,word,work,worm,wound,\
-writing,wrong,year,yellow,yes,yesterday,you,young'.split(',')
+writing,wrong,year,yellow,yes,yesterday,you,young".split(
+    ","
+)
diff --git a/requirements_dev.txt b/requirements_dev.txt
deleted file mode 100644
index 0392700..0000000
--- a/requirements_dev.txt
+++ /dev/null
@@ -1,6 +0,0 @@
-redis
-pep8
-cython
-coverage
-mock
-pulsar==0.7.4
diff --git a/runtests.py b/runtests.py
index a03d759..910891d 100755
--- a/runtests.py
+++ b/runtests.py
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
-'''Stdnet asynchronous test suite. Requires pulsar.'''
-import sys
+"""Stdnet asynchronous test suite. Requires pulsar."""
 import os
+import sys
 from multiprocessing import current_process
 
 ## This is for dev environment with pulsar and dynts.
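# A minimal sketch of the in-process coverage bootstrap that covrun.py and the
# run() helper in this file rely on. It assumes only the third-party
# `coverage` package (listed in dev/requirements-dev.txt); the stop/save tail
# is illustrative of what the suite does on exit and is not part of this patch.
import coverage
from multiprocessing import current_process

p = current_process()
# data_suffix=True writes a distinct .coverage.<suffix> data file per process,
# so results from parallel test workers can later be merged with
# `coverage combine`.
p._coverage = coverage.coverage(data_suffix=True)
p._coverage.start()
# ... run the test suite ...
p._coverage.stop()
p._coverage.save()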
@@ -11,18 +11,20 @@ try: import pulsar except ImportError: - pdir = p.join(dir, 'pulsar') + pdir = p.join(dir, "pulsar") if os.path.isdir(pdir): sys.path.append(pdir) import pulsar + from pulsar.apps.test import TestSuite from pulsar.apps.test.plugins import bench, profile from pulsar.utils.path import Path + # try: import dynts except ImportError: - pdir = p.join(dir, 'dynts') + pdir = p.join(dir, "dynts") if os.path.isdir(pdir): sys.path.append(pdir) try: @@ -32,36 +34,41 @@ def run(**params): - args = params.get('argv', sys.argv) - if '--coverage' in args or params.get('coverage'): + args = params.get("argv", sys.argv) + if "--coverage" in args or params.get("coverage"): import coverage + p = current_process() p._coverage = coverage.coverage(data_suffix=True) p._coverage.start() runtests(**params) - + def runtests(**params): import stdnet from stdnet.utils import test + # strip_dirs = [Path(stdnet.__file__).parent.parent, os.getcwd()] # - suite = TestSuite(description='Stdnet Asynchronous test suite', - modules=('tests.all',), - plugins=(test.StdnetPlugin(), - bench.BenchMark(), - profile.Profile()), - **params) - suite.bind_event('tests', test.create_tests) + suite = TestSuite( + description="Stdnet Asynchronous test suite", + modules=("tests.all",), + plugins=(test.StdnetPlugin(), bench.BenchMark(), profile.Profile()), + **params + ) + suite.bind_event("tests", test.create_tests) suite.start() # if suite.cfg.coveralls: from pulsar.utils.cov import coveralls - coveralls(strip_dirs=strip_dirs, - stream=suite.stream, - repo_token='ZQinNe5XNbzQ44xYGTljP8R89jrQ5xTKB') + + coveralls( + strip_dirs=strip_dirs, + stream=suite.stream, + repo_token="ZQinNe5XNbzQ44xYGTljP8R89jrQ5xTKB", + ) -if __name__ == '__main__': +if __name__ == "__main__": run() diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8f6e244 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,29 @@ +[flake8] +exclude = __pycache__,.eggs,venv,build,dist,docs,dev +max-line-length = 88 +ignore = A001,A002,A003,C815,C812,W503,E203 + +[isort] +line_length=88 +src_paths=stdnet,tests +multi_line_output=3 +include_trailing_comma=True + +[mypy] +python_version = 3.8 +ignore_missing_imports=True +disallow_untyped_calls=False +warn_return_any=False +# disallow_untyped_defs=True +warn_no_return=True + +[tool:pytest] +testpaths = tests +filterwarnings= default + ignore:::.*raven_aiohttp.* + ignore:::.*asynctest.* + ignore:::.*aioconsole.* + ignore:::.*aiohttp.* + ignore:::.*mockaioredis.* + ignore:This message has already been written once.*:UserWarning + ignore:numpy.dtype size changed, may indicate binary incompatibility:RuntimeWarning diff --git a/setup.py b/setup.py index 6f37a04..f24bfac 100644 --- a/setup.py +++ b/setup.py @@ -1,60 +1,65 @@ import os import sys +from distutils.command.install import INSTALL_SCHEMES +from distutils.command.install_data import install_data from setuptools import setup -from distutils.command.install_data import install_data -from distutils.command.install import INSTALL_SCHEMES if sys.version_info < (2, 6): raise Exception("stdnet requires Python 2.6 or higher.") -package_name = 'stdnet' -package_fullname = 'python-%s' % package_name +package_name = "stdnet" +package_fullname = "python-%s" % package_name root_dir = os.path.split(os.path.abspath(__file__))[0] package_dir = os.path.join(root_dir, package_name) + def get_module(): if root_dir not in sys.path: - sys.path.insert(0,root_dir) + sys.path.insert(0, root_dir) return __import__(package_name) + mod = get_module() # Try to import lib build 
-#try: +# try: # from extensions.setup import libparams, BuildFailed -#except ImportError: +# except ImportError: # libparams = None libparams = False + def read(fname): return open(os.path.join(root_dir, fname)).read() + def requirements(): - req = read('requirements.txt').replace('\r','').split('\n') + req = read("requirements.txt").replace("\r", "").split("\n") result = [] for r in req: - r = r.replace(' ','') + r = r.replace(" ", "") if r: result.append(r) return result -class osx_install_data(install_data): +class osx_install_data(install_data): def finalize_options(self): - self.set_undefined_options('install', ('install_lib', 'install_dir')) + self.set_undefined_options("install", ("install_lib", "install_dir")) install_data.finalize_options(self) # Tell distutils to put the data_files in platform-specific installation # locations. See here for an explanation: for scheme in INSTALL_SCHEMES.values(): - scheme['data'] = scheme['purelib'] + scheme["data"] = scheme["purelib"] def read(fname): return open(os.path.join(os.path.dirname(__file__), fname)).read() - + + def fullsplit(path, result=None): """ Split a pathname into components (the opposite of os.path.join) in a @@ -63,76 +68,85 @@ def fullsplit(path, result=None): if result is None: result = [] head, tail = os.path.split(path) - if head == '': + if head == "": return [tail] + result if head == path: return result return fullsplit(head, [tail] + result) - + + # Compile the list of packages available, because distutils doesn't have # an easy way to do this. -def get_rel_dir(d,base,res=''): +def get_rel_dir(d, base, res=""): if d == base: return res - br,r = os.path.split(d) + br, r = os.path.split(d) if res: - r = os.path.join(r,res) - return get_rel_dir(br,base,r) + r = os.path.join(r, res) + return get_rel_dir(br, base, r) + packages, data_files = [], [] pieces = fullsplit(root_dir) -if pieces[-1] == '': +if pieces[-1] == "": len_root_dir = len(pieces) - 1 else: len_root_dir = len(pieces) for dirpath, _, filenames in os.walk(package_dir): - if '__init__.py' in filenames: - packages.append('.'.join(fullsplit(dirpath)[len_root_dir:])) - elif filenames and not dirpath.endswith('__pycache__'): + if "__init__.py" in filenames: + packages.append(".".join(fullsplit(dirpath)[len_root_dir:])) + elif filenames and not dirpath.endswith("__pycache__"): rel_dir = get_rel_dir(dirpath, package_dir) data_files.extend((os.path.join(rel_dir, f) for f in filenames)) -if len(sys.argv) > 1 and sys.argv[1] == 'bdist_wininst': +if len(sys.argv) > 1 and sys.argv[1] == "bdist_wininst": for file_info in data_files: - file_info[0] = '\\PURELIB\\%s' % file_info[0] - + file_info[0] = "\\PURELIB\\%s" % file_info[0] + def run_setup(params=None, argv=None): - params = params or {'cmdclass': {}} + params = params or {"cmdclass": {}} if sys.platform == "darwin": - params['cmdclass']['install_data'] = osx_install_data + params["cmdclass"]["install_data"] = osx_install_data else: - params['cmdclass']['install_data'] = install_data + params["cmdclass"]["install_data"] = install_data argv = argv if argv is not None else sys.argv if len(argv) > 1: - if argv[1] == 'install' and sys.version_info >= (3,0): - packages.remove('stdnet.utils.fallbacks.py2') - params.update({'name': package_fullname, - 'version': mod.__version__, - 'author': mod.__author__, - 'author_email': mod.__contact__, - 'url': mod.__homepage__, - 'license': mod.__license__, - 'description': mod.__doc__, - 'long_description': read('README.rst'), - 'packages': packages, - 'package_data': {package_name: 
data_files}, - 'classifiers': mod.CLASSIFIERS, - 'install_requires': requirements()}) + if argv[1] == "install" and sys.version_info >= (3, 0): + packages.remove("stdnet.utils.fallbacks.py2") + params.update( + { + "name": package_fullname, + "version": mod.__version__, + "author": mod.__author__, + "author_email": mod.__contact__, + "url": mod.__homepage__, + "license": mod.__license__, + "description": mod.__doc__, + "long_description": read("README.rst"), + "packages": packages, + "package_data": {package_name: data_files}, + "classifiers": mod.CLASSIFIERS, + "install_requires": requirements(), + } + ) setup(**params) - + + def status_msgs(*msgs): - print('*' * 75) + print("*" * 75) for msg in msgs: print(msg) - print('*' * 75) + print("*" * 75) + if libparams is False: run_setup() elif libparams is None: - status_msgs('WARNING: C extensions could not be compiled, ' - 'Cython is not installed.') + status_msgs( + "WARNING: C extensions could not be compiled, " "Cython is not installed." + ) run_setup() status_msgs("Plain-Python build succeeded.") else: @@ -140,11 +154,11 @@ def status_msgs(*msgs): run_setup(libparams) except BuildFailed as exc: status_msgs( - exc.msg, - "WARNING: C extensions could not be compiled, " + - "speedups are not enabled.", - "Failure information, if any, is above.", - "Retrying the build without C extensions now." - ) + exc.msg, + "WARNING: C extensions could not be compiled, " + + "speedups are not enabled.", + "Failure information, if any, is above.", + "Retrying the build without C extensions now.", + ) run_setup() - status_msgs("Plain-Python build succeeded.") \ No newline at end of file + status_msgs("Plain-Python build succeeded.") diff --git a/stdnet/__init__.py b/stdnet/__init__.py index 3eaf0e0..2efd983 100755 --- a/stdnet/__init__.py +++ b/stdnet/__init__.py @@ -1,11 +1,11 @@ -'''Object data mapper and advanced query manager for non relational +"""Object data mapper and advanced query manager for non relational databases. 
-''' +""" +from .backends import * from .utils.exceptions import * from .utils.version import get_version, stdnet_version -from .backends import * -VERSION = stdnet_version(0, 9, 0, 'alpha', 3) +VERSION = stdnet_version(0, 9, 0, "alpha", 3) __version__ = version = get_version(VERSION) @@ -13,22 +13,23 @@ __author__ = "Luca Sbardella" __contact__ = "luca.sbardella@gmail.com" __homepage__ = "https://github.com/lsbardel/python-stdnet" -CLASSIFIERS = ['Development Status :: 4 - Beta', - 'Environment :: Plugins', - 'Environment :: Console', - 'Environment :: Web Environment', - 'Intended Audience :: Developers', - 'License :: OSI Approved :: BSD License', - 'Operating System :: OS Independent', - 'Programming Language :: Python', - 'Programming Language :: Python :: 2', - 'Programming Language :: Python :: 2.6', - 'Programming Language :: Python :: 2.7', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.2', - 'Programming Language :: Python :: 3.3', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Utilities', - 'Topic :: Database', - 'Topic :: Internet' - ] +CLASSIFIERS = [ + "Development Status :: 4 - Beta", + "Environment :: Plugins", + "Environment :: Console", + "Environment :: Web Environment", + "Intended Audience :: Developers", + "License :: OSI Approved :: BSD License", + "Operating System :: OS Independent", + "Programming Language :: Python", + "Programming Language :: Python :: 2", + "Programming Language :: Python :: 2.6", + "Programming Language :: Python :: 2.7", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.2", + "Programming Language :: Python :: 3.3", + "Programming Language :: Python :: Implementation :: PyPy", + "Topic :: Utilities", + "Topic :: Database", + "Topic :: Internet", +] diff --git a/stdnet/apps/__init__.py b/stdnet/apps/__init__.py index debecb8..2a27531 100644 --- a/stdnet/apps/__init__.py +++ b/stdnet/apps/__init__.py @@ -1,4 +1,4 @@ -'''\ +"""\ Collection of applications which may be relevant or not to the user. They show off the main features of the core library. -''' +""" diff --git a/stdnet/apps/columnts/__init__.py b/stdnet/apps/columnts/__init__.py index 08f8933..604d3e3 100644 --- a/stdnet/apps/columnts/__init__.py +++ b/stdnet/apps/columnts/__init__.py @@ -1,4 +1,4 @@ -'''\ +"""\ **backends**: :ref:`Redis `. An application which implements a specialised remote @@ -92,6 +92,6 @@ class Ticker(odm.StdModel): .. _timeseries: http://en.wikipedia.org/wiki/Time_series -''' +""" from . 
import redis from .models import * diff --git a/stdnet/apps/columnts/models.py b/stdnet/apps/columnts/models.py index a0f2781..5da6353 100644 --- a/stdnet/apps/columnts/models.py +++ b/stdnet/apps/columnts/models.py @@ -1,10 +1,9 @@ -'''Multivariate numeric timeseries interface.''' -from stdnet import odm, SessionNotAvailable, InvalidTransaction +"""Multivariate numeric timeseries interface.""" +from stdnet import InvalidTransaction, SessionNotAvailable, odm +from stdnet.utils import encoders, iteritems, iterpair, zip from stdnet.utils.skiplist import skiplist -from stdnet.utils import encoders, iteritems, zip, iterpair - -__all__ = ['TimeseriesCache', 'ColumnTS', 'ColumnTSField', 'as_dict'] +__all__ = ["TimeseriesCache", "ColumnTS", "ColumnTSField", "as_dict"] class TimeseriesCache(object): @@ -41,42 +40,44 @@ def as_dict(times, fields): class ColumnTS(odm.TS): - '''A specialised :class:`stdnet.odm.TS` structure for numeric -multivariate timeseries.''' - default_multi_stats = ['covariance'] + """A specialised :class:`stdnet.odm.TS` structure for numeric + multivariate timeseries.""" + + default_multi_stats = ["covariance"] cache_class = TimeseriesCache pickler = encoders.DateTimeConverter() value_pickler = encoders.Double() def front(self, *fields): - '''Return the front pair of the structure''' + """Return the front pair of the structure""" v, f = tuple(self.irange(0, 0, fields=fields)) if v: return (v[0], dict(((field, f[field][0]) for field in f))) def back(self, *fields): - '''Return the back pair of the structure''' + """Return the back pair of the structure""" v, f = tuple(self.irange(-1, -1, fields=fields)) if v: return (v[0], dict(((field, f[field][0]) for field in f))) def info(self, start=None, end=None, fields=None): - '''Provide data information for this :class:`ColumnTS`. If no -parameters are specified it returns the number of data points for each -fields, as well as the start and end date.''' + """Provide data information for this :class:`ColumnTS`. If no + parameters are specified it returns the number of data points for each + fields, as well as the start and end date.""" start = self.pickler.dumps(start) if start else None end = self.pickler.dumps(end) if end else None backend = self.read_backend return backend.execute( - backend.structure(self).info(start, end, fields), self._stats) + backend.structure(self).info(start, end, fields), self._stats + ) def fields(self): - '''Tuple of ordered fields for this :class:`ColumnTS`.''' + """Tuple of ordered fields for this :class:`ColumnTS`.""" return self.backend_structure().fields() def numfields(self): - '''Number of fields.''' + """Number of fields.""" return self.backend_structure().numfields() @odm.commit_when_no_transaction @@ -96,101 +97,100 @@ def update(self, mapping): def evaluate(self, script, *series, **params): backend = self.backend return backend.execute( - backend.structure(self).run_script('evaluate', series, script, - **params), self._evaluate) + backend.structure(self).run_script("evaluate", series, script, **params), + self._evaluate, + ) def istats(self, start=0, end=-1, fields=None): - '''Perform a multivariate statistic calculation of this -:class:`ColumnTS` from *start* to *end*. - -:param start: Optional index (rank) where to start the analysis. -:param end: Optional index (rank) where to end the analysis. -:param fields: Optional subset of :meth:`fields` to perform analysis on. - If not provided all fields are included in the analysis. 
-'''
+        """Perform a multivariate statistic calculation of this
+        :class:`ColumnTS` from *start* to *end*.
+
+        :param start: Optional index (rank) where to start the analysis.
+        :param end: Optional index (rank) where to end the analysis.
+        :param fields: Optional subset of :meth:`fields` to perform analysis on.
+            If not provided all fields are included in the analysis."""
         backend = self.read_backend
         return backend.execute(
-            backend.structure(self).istats(start, end, fields), self._stats)
+            backend.structure(self).istats(start, end, fields), self._stats
+        )
 
     def stats(self, start, end, fields=None):
-        '''Perform a multivariate statistic calculation of this
-:class:`ColumnTS` from a *start* date/datetime to an
-*end* date/datetime.
-
-:param start: Start date for analysis.
-:param end: End date for analysis.
-:param fields: Optional subset of :meth:`fields` to perform analysis on.
-    If not provided all fields are included in the analysis.
-'''
+        """Perform a multivariate statistic calculation of this
+        :class:`ColumnTS` from a *start* date/datetime to an
+        *end* date/datetime.
+
+        :param start: Start date for analysis.
+        :param end: End date for analysis.
+        :param fields: Optional subset of :meth:`fields` to perform analysis on.
+            If not provided all fields are included in the analysis."""
         start = self.pickler.dumps(start)
         end = self.pickler.dumps(end)
         backend = self.read_backend
         return backend.execute(
-            backend.structure(self).stats(start, end, fields), self._stats)
-
-    def imulti_stats(self, start=0, end=-1, series=None, fields=None,
-                     stats=None):
-        '''Perform cross multivariate statistics calculation of
-this :class:`ColumnTS` and other optional *series* from *start*
-to *end*.
-
-:parameter start: the start rank.
-:parameter start: the end rank
-:parameter field: name of field to perform multivariate statistics.
-:parameter series: a list of two elements tuple containing the id of the
-    a :class:`columnTS` and a field name.
-:parameter stats: list of statistics to evaluate.
-    Default: ['covariance']
-'''
+            backend.structure(self).stats(start, end, fields), self._stats
+        )
+
+    def imulti_stats(self, start=0, end=-1, series=None, fields=None, stats=None):
+        """Perform cross multivariate statistics calculation of
+        this :class:`ColumnTS` and other optional *series* from *start*
+        to *end*.
+
+        :parameter start: the start rank.
+        :parameter end: the end rank.
+        :parameter fields: optional subset of fields on which to perform
+            the statistics.
+        :parameter series: a list of two-element tuples, each containing the
+            id of a :class:`ColumnTS` and a field name.
+        :parameter stats: list of statistics to evaluate.
+            Default: ['covariance']"""
        stats = stats or self.default_multi_stats
         backend = self.read_backend
         return backend.execute(
-            backend.structure(self).imulti_stats(start, end, fields, series,
-                                                 stats), self._stats)
-
-    def multi_stats(self, start, end, series=None, fields=None, stats=None):
-        '''Perform cross multivariate statistics calculation of
-this :class:`ColumnTS` and other *series*.
-
-:parameter start: the start date.
-:parameter start: the end date
-:parameter field: name of field to perform multivariate statistics.
-:parameter series: a list of two elements tuple containing the id of the
-    a :class:`columnTS` and a field name.
-:parameter stats: list of statistics to evaluate.
-    Default: ['covariance']
-'''
+            backend.structure(self).imulti_stats(start, end, fields, series, stats),
+            self._stats,
+        )
+
+    def multi_stats(self, start, end, series=None, fields=None, stats=None):
+        """Perform cross multivariate statistics calculation of
+        this :class:`ColumnTS` and other *series*.
+
+        :parameter start: the start date.
+        :parameter end: the end date.
+        :parameter fields: optional subset of fields on which to perform
+            the statistics.
+        :parameter series: a list of two-element tuples, each containing the
+            id of a :class:`ColumnTS` and a field name.
+        :parameter stats: list of statistics to evaluate.
+            Default: ['covariance']"""
         stats = stats or self.default_multi_stats
         start = self.pickler.dumps(start)
         end = self.pickler.dumps(end)
         backend = self.read_backend
         return backend.execute(
-            backend.structure(self).multi_stats(start, end, fields, series,
-                                                stats), self._stats)
+            backend.structure(self).multi_stats(start, end, fields, series, stats),
+            self._stats,
+        )
 
     def merge(self, *series, **kwargs):
-        '''Merge this :class:`ColumnTS` with several other *series*.
+        """Merge this :class:`ColumnTS` with several other *series*.
 
-:parameters series: a list of tuples where the nth element is a tuple
-    of the form::
+        :parameter series: a list of tuples where the nth element is a tuple
+            of the form::
 
-    (wight_n, ts_n1, ts_n2, ..., ts_nMn)
+            (weight_n, ts_n1, ts_n2, ..., ts_nMn)
 
-The result will be calculated using the formula::
+        The result will be calculated using the formula::
 
-    ts = weight_1*ts_11*ts_12*...*ts_1M1 + weight_2*ts_21*ts_22*...*ts_2M2 +
-         ...
+        ts = weight_1*ts_11*ts_12*...*ts_1M1 + weight_2*ts_21*ts_22*...*ts_2M2 +
+             ..."""
         session = self.session
         if not session:
-            raise SessionNotAvailable('No session available')
+            raise SessionNotAvailable("No session available")
         self.check_router(session.router, *series)
         return self._merge(*series, **kwargs)
 
     @classmethod
     def merged_series(cls, *series, **kwargs):
-        '''Merge ``series`` and return the results without storing data
-in the backend server.'''
+        """Merge ``series`` and return the results without storing data
+        in the backend server."""
         router, backend = cls.check_router(None, *series)
         if backend:
             target = router.register(cls(), backend)
@@ -198,8 +198,8 @@ def merged_series(cls, *series, **kwargs):
             target._merge(*series, **kwargs)
             backend = target.backend
             return backend.execute(
-                backend.structure(target).irange_and_delete(),
-                target.load_data)
+                backend.structure(target).irange_and_delete(), target.load_data
+            )
 
     # INTERNALS
     @classmethod
@@ -207,29 +207,30 @@ def check_router(cls, router, *series):
         backend = None
         for serie in series:
             if len(serie) < 2:
-                raise ValueError('merge requires tuples of length 2 or more')
+                raise ValueError("merge requires tuples of length 2 or more")
             for s in serie[1:]:
                 if not s.session:
-                    raise SessionNotAvailable('No session available')
+                    raise SessionNotAvailable("No session available")
                 if router is None:
                     router = s.session.router
                 else:
                     if router is not s.session.router:
-                        raise InvalidTransaction('mistmaching routers')
+                        raise InvalidTransaction("mismatching routers")
                 if backend is None:
                     backend = s.backend_structure().backend
                 else:
                     if backend is not s.backend_structure().backend:
-                        raise InvalidTransaction('merging is possible only on '
-                                                 'the same backend')
+                        raise InvalidTransaction(
+                            "merging is possible only on the same backend"
+                        )
         return router, backend
 
     def _merge(self, *series, **kwargs):
-        fields = kwargs.get('fields') or ()
+        fields = kwargs.get("fields") or ()
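The weighted merge documented above is easiest to see in a short usage sketch. The snippet below is illustrative only: it assumes a reachable Redis server at the default address, the variable names are invented, and session/transaction handling (which ``merge`` requires) is elided::

    from datetime import date

    from stdnet import odm
    from stdnet.apps.columnts import ColumnTS

    models = odm.Router("redis://127.0.0.1:6379?db=7")
    ts1 = models.register(ColumnTS())
    ts2 = models.register(ColumnTS())
    target = models.register(ColumnTS())

    ts1.update({date(2013, 1, 1): {"price": 1.0}})
    ts2.update({date(2013, 1, 1): {"price": 3.0}})

    # target <- 2*ts1 - 1*ts2, following the weighted formula above
    target.merge((2.0, ts1), (-1.0, ts2))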
self.backend_structure().merge(series, fields) def load_data(self, result): - #Overwrite :meth:`stdnet.odm.PairMixin.load_data` method + # Overwrite :meth:`stdnet.odm.PairMixin.load_data` method loads = self.pickler.loads vloads = self.value_pickler.loads dt = [loads(t) for t in result[0]] @@ -243,9 +244,9 @@ def load_get_data(self, result): return dict(((f, vloads(v)) for f, v in iterpair(result))) def _stats(self, result): - if result and 'start' in result: - result['start'] = self.pickler.loads(result['start']) - result['stop'] = self.pickler.loads(result['stop']) + if result and "start" in result: + result["start"] = self.pickler.loads(result["start"]) + result["stop"] = self.pickler.loads(result["stop"]) return result def _evaluate(self, result): @@ -264,12 +265,13 @@ def _add(self, dt, *args): elif len(args) == 2: add(timestamp, args[0], dump(args[1])) else: - raise TypeError('Expected a mapping or a field value pair') + raise TypeError("Expected a mapping or a field value pair") class ColumnTSField(odm.StructureField): - '''A multivariate timeseries field.''' - type = 'columnts' + """A multivariate timeseries field.""" + + type = "columnts" def structure_class(self): return ColumnTS diff --git a/stdnet/apps/columnts/npts.py b/stdnet/apps/columnts/npts.py index 690b071..89de6d8 100644 --- a/stdnet/apps/columnts/npts.py +++ b/stdnet/apps/columnts/npts.py @@ -1,29 +1,28 @@ -'''Experimental! +"""Experimental! This is an experimental module for converting ColumnTS into dynts.timeseries. It requires dynts_. .. _dynts: https://github.com/quantmind/dynts -''' +""" from collections import Mapping -from . import models as columnts - import numpy as ny - from dynts import timeseries, tsname +from . import models as columnts + class ColumnTS(columnts.ColumnTS): - '''Integrate stdnet timeseries with dynts_ TimeSeries''' + """Integrate stdnet timeseries with dynts_ TimeSeries""" def front(self, *fields): - '''Return the front pair of the structure''' + """Return the front pair of the structure""" ts = self.irange(0, 0, fields=fields) if ts: return ts.start(), ts[0] def back(self, *fields): - '''Return the back pair of the structure''' + """Return the back pair of the structure""" ts = self.irange(-1, -1, fields=fields) if ts: return ts.end(), ts[0] @@ -56,6 +55,5 @@ def _get(self, result): class ColumnTSField(columnts.ColumnTSField): - def structure_class(self): return ColumnTS diff --git a/stdnet/apps/columnts/redis.py b/stdnet/apps/columnts/redis.py index 8c63947..68302ba 100644 --- a/stdnet/apps/columnts/redis.py +++ b/stdnet/apps/columnts/redis.py @@ -1,138 +1,188 @@ -'''Redis implementation of ColumnTS''' -import os +"""Redis implementation of ColumnTS""" import json +import os from stdnet.backends import redisb from stdnet.utils.encoders import safe_number class RedisColumnTS(redisb.RedisStructure): - '''Redis backend for :class:`ColumnTS`''' + """Redis backend for :class:`ColumnTS`""" + def __contains__(self, timestamp): - return self.client.execute_script('timeseries_run', (self.id,), - 'exists', timestamp) + return self.client.execute_script( + "timeseries_run", (self.id,), "exists", timestamp + ) def size(self): - return self.client.execute_script('timeseries_run', (self.id,), 'size') + return self.client.execute_script("timeseries_run", (self.id,), "size") @property def fieldsid(self): - return self.id + ':fields' + return self.id + ":fields" def fieldid(self, field): - return self.id + ':field:' + field + return self.id + ":field:" + field def flush(self): cache = self.instance.cache 
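Every operation on ``RedisColumnTS`` is funnelled through the single ``timeseries_run`` Lua entry point, with the subcommand name passed as the first script argument. A stripped-down sketch of the same one-script dispatch pattern in plain redis-py (the Lua body and key layout here are invented for illustration)::

    import redis

    # One script, many subcommands: ARGV[1] selects the operation,
    # mirroring how timeseries_run routes "size", "exists", ...
    DISPATCH = """
    local command = ARGV[1]
    if command == 'size' then
        return redis.call('ZCARD', KEYS[1])
    elseif command == 'exists' then
        return redis.call('ZSCORE', KEYS[1], ARGV[2]) and 1 or 0
    end
    return redis.error_reply('unknown command ' .. command)
    """

    client = redis.Redis()
    run = client.register_script(DISPATCH)
    run(keys=["mytimeseries"], args=["size"])           # -> 0 on a new key
    run(keys=["mytimeseries"], args=["exists", "123"])  # -> 0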
sargs = self.flat() if sargs: - return self.client.execute_script('timeseries_run', (self.id,), - 'session', *sargs) + return self.client.execute_script( + "timeseries_run", (self.id,), "session", *sargs + ) elif cache.merged_series: - return self.client.execute_script('timeseries_run', (self.id,), - 'merge', cache.merged_series) + return self.client.execute_script( + "timeseries_run", (self.id,), "merge", cache.merged_series + ) def allkeys(self): - return self.client.keys(self.id + '*') + return self.client.keys(self.id + "*") def fields(self): - '''Return a tuple of ordered fields for this :class:`ColumnTS`.''' - key = self.id + ':fields' + """Return a tuple of ordered fields for this :class:`ColumnTS`.""" + key = self.id + ":fields" encoding = self.client.encoding - return tuple(sorted((f.decode(encoding) - for f in self.client.smembers(key)))) + return tuple(sorted((f.decode(encoding) for f in self.client.smembers(key)))) def info(self, start, end, fields): fields = fields or () - return self.client.execute_script('timeseries_run', (self.id,), 'info', - start or -1, end or -1, *fields, - return_type='json') + return self.client.execute_script( + "timeseries_run", + (self.id,), + "info", + start or -1, + end or -1, + *fields, + return_type="json" + ) def field(self, field): - '''Fetch an entire row field string from redis''' + """Fetch an entire row field string from redis""" return self.client.get(self.fieldid(field)) def numfields(self): - '''Number of fields''' + """Number of fields""" return self.client.scard(self.fieldsid) def get(self, dte): - return self.client.execute_script('timeseries_run', (self.id,), 'get', - dte, return_type='get') + return self.client.execute_script( + "timeseries_run", (self.id,), "get", dte, return_type="get" + ) def pop(self, dte): - return self.client.execute_script('timeseries_run', (self.id,), - 'pop', dte, return_type='get') + return self.client.execute_script( + "timeseries_run", (self.id,), "pop", dte, return_type="get" + ) def ipop(self, index): - return self.client.execute_script('timeseries_run', (self.id,), - 'ipop', index, return_type='get') + return self.client.execute_script( + "timeseries_run", (self.id,), "ipop", index, return_type="get" + ) def irange(self, start=0, end=-1, fields=None, **kwargs): fields = fields or () - return self.client.execute_script('timeseries_run', (self.id,), - 'irange', - start, end, *fields, fields=fields, - return_type='range', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "irange", + start, + end, + *fields, + fields=fields, + return_type="range", + **kwargs + ) def range(self, start, end, fields=None, **kwargs): fields = fields or () - return self.client.execute_script('timeseries_run', (self.id,), - 'range', - start, end, *fields, fields=fields, - return_type='range', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "range", + start, + end, + *fields, + fields=fields, + return_type="range", + **kwargs + ) def irange_and_delete(self): - return self.client.execute_script('timeseries_run', (self.id,), - 'irange_and_delete', - return_type='range') + return self.client.execute_script( + "timeseries_run", (self.id,), "irange_and_delete", return_type="range" + ) def pop_range(self, start, end, **kwargs): - return self.client.execute_script('timeseries_run', (self.id,), - 'pop_range', start, end, - return_type='range', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "pop_range", + start, + end, + 
return_type="range", + **kwargs + ) def ipop_range(self, start=0, end=-1, **kwargs): - return self.client.execute_script('timeseries_run', (self.id,), - 'ipop_range', start, end, - return_type='range', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "ipop_range", + start, + end, + return_type="range", + **kwargs + ) def times(self, start, end, **kwargs): - return self.client.execute_script('timeseries_run', (self.id,), - 'times', start, end, **kwargs) + return self.client.execute_script( + "timeseries_run", (self.id,), "times", start, end, **kwargs + ) def itimes(self, start=0, end=-1, **kwargs): - return self.client.execute_script('timeseries_run', (self.id,), - 'itimes', start, end, **kwargs) + return self.client.execute_script( + "timeseries_run", (self.id,), "itimes", start, end, **kwargs + ) def stats(self, start, end, fields=None, **kwargs): fields = fields or () - return self.client.execute_script('timeseries_run', (self.id,), - 'stats', - start, end, *fields, - return_type='json', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "stats", + start, + end, + *fields, + return_type="json", + **kwargs + ) def istats(self, start, end, fields=None, **kwargs): fields = fields or () - return self.client.execute_script('timeseries_run', (self.id,), - 'istats', start, end, *fields, - return_type='json', **kwargs) + return self.client.execute_script( + "timeseries_run", + (self.id,), + "istats", + start, + end, + *fields, + return_type="json", + **kwargs + ) def multi_stats(self, start, end, fields, series, stats): - return self._multi_stats(start, end, 'multi_stats', fields, series, - stats) + return self._multi_stats(start, end, "multi_stats", fields, series, stats) def imulti_stats(self, start, end, fields, series, stats): - return self._multi_stats(start, end, 'imulti_stats', fields, series, - stats) + return self._multi_stats(start, end, "imulti_stats", fields, series, stats) def merge(self, series, fields): all_series = [] - argv = {'series': all_series, 'fields': fields} + argv = {"series": all_series, "fields": fields} for elems in series: ser = [] - d = {'weight': elems[0], - 'series': ser} + d = {"weight": elems[0], "series": ser} all_series.append(d) for ts in elems[1:]: ser.append(ts.backend_structure().id) @@ -145,8 +195,7 @@ def run_script(self, script_name, series, *args, **params): if params: args = list(args) args.append(json.dumps(params)) - return self.client.execute_script('timeseries_run', keys, - script_name, *args) + return self.client.execute_script("timeseries_run", keys, script_name, *args) ############################################################### INTERNALS def flat(self): @@ -158,12 +207,13 @@ def flat(self): data = [] for t, v in cache.fields[field]: times.append(t) - data.append('%s' % v) - fields.append({'times': times, - 'fields': {field: data}}) - data = {'delete_times': list(cache.deleted_timestamps), - 'delete_fields': list(cache.delete_fields), - 'add': fields} + data.append("%s" % v) + fields.append({"times": times, "fields": {field: data}}) + data = { + "delete_times": list(cache.deleted_timestamps), + "delete_fields": list(cache.delete_fields), + "add": fields, + } return [json.dumps(data)] args = [len(cache.deleted_timestamps)] @@ -182,33 +232,34 @@ def flat(self): def _multi_stats(self, start, end, command, fields, series, stats): all = [(self.id, fields)] if series: - all.extend(((ts.backend_structure().id, fields) - for ts, fields in series)) + 
all.extend(((ts.backend_structure().id, fields) for ts, fields in series))
         keys = []
         argv = []
         for s in all:
             if not len(s) == 2:
-                raise ValueError('Series must be a list of two elements tuple')
+                raise ValueError("Series must be a list of two-element tuples")
             id, fields = s
             keys.append(id)
             fields = fields if fields is not None else ()
             argv.append(fields)
         fields = json.dumps(argv)
         return self.client.execute_script(
-            'timeseries_run', keys, command, start, end, fields,
-            return_type='json')
+            "timeseries_run", keys, command, start, end, fields, return_type="json"
+        )
 
 
 # Add the redis structure to the struct map in the backend class
-redisb.BackendDataServer.struct_map['columnts'] = RedisColumnTS
+redisb.BackendDataServer.struct_map["columnts"] = RedisColumnTS
 
 
 ############################################################## SCRIPT
 class timeseries_run(redisb.RedisScript):
-    script = (redisb.read_lua_file('tabletools'),
-              redisb.read_lua_file('columnts.columnts'),
-              redisb.read_lua_file('columnts.stats'),
-              redisb.read_lua_file('columnts.runts'))
+    script = (
+        redisb.read_lua_file("tabletools"),
+        redisb.read_lua_file("columnts.columnts"),
+        redisb.read_lua_file("columnts.stats"),
+        redisb.read_lua_file("columnts.runts"),
+    )
 
     def callback(self, response, redis_client=None, return_type=None, **opts):
         if return_type and response:
diff --git a/stdnet/apps/searchengine/__init__.py b/stdnet/apps/searchengine/__init__.py
index f3682ef..ded385f 100644
--- a/stdnet/apps/searchengine/__init__.py
+++ b/stdnet/apps/searchengine/__init__.py
@@ -1,4 +1,4 @@
-'''\
+"""\
 Stdnet provides a redis-based implementation for the
 :class:`stdnet.odm.SearchEngine` so that you can have your models stored
 and indexed in redis and if you like in the same redis instance.
@@ -61,58 +61,65 @@
 .. autoclass:: WordItem
    :members:
    :member-order: bysource
-'''
+"""
 import re
 from inspect import isclass
 from itertools import chain
 
-from stdnet import odm, getdb
+from stdnet import getdb, odm
 from stdnet.utils import grouper
 
-from .models import WordItem
 from . import processors
+from .models import WordItem
 
 
 class SearchEngine(odm.SearchEngine):
     """A python implementation for the :class:`stdnet.odm.SearchEngine`
-driver.
+    driver.
 
-:parameter min_word_length: minimum number of words required by the engine
-    to work.
+    :parameter min_word_length: minimum length for a word to be indexed
+        by the engine.
 
-    Default ``3``.
+        Default ``3``.
 
-:parameter stop_words: list of words not included in the search engine.
+    :parameter stop_words: list of words not included in the search engine.
 
-    Default ``stdnet.apps.searchengine.ignore.STOP_WORDS``
+        Default ``stdnet.apps.searchengine.ignore.STOP_WORDS``
 
-:parameter metaphone: If ``True`` the double metaphone_ algorithm will be
-    used to store and search for words. The metaphone should be the last
-    world middleware to be added.
+    :parameter metaphone: If ``True`` the double metaphone_ algorithm will be
+        used to store and search for words. The metaphone should be the last
+        word middleware to be added.
 
-    Default ``True``.
+        Default ``True``.
 
-:parameter splitters: string whose characters are used to split text
-    into words. If this parameter is set to `"_-"`,
-    for example, than the word `bla_pippo_ciao-moon` will
-    be split into `bla`, `pippo`, `ciao` and `moon`.
-    Set to empty string for no splitting.
-    Splitting will always occur on white spaces.
+    :parameter splitters: string whose characters are used to split text
+        into words. If this parameter is set to `"_-"`,
+        for example, then the word `bla_pippo_ciao-moon` will
+        be split into `bla`, `pippo`, `ciao` and `moon`.
+        Set to empty string for no splitting.
+        Splitting will always occur on white spaces.
 
-    Default
-    ``stdnet.apps.searchengine.ignore.PUNCTUATION_CHARS``.
+        Default
+        ``stdnet.apps.searchengine.ignore.PUNCTUATION_CHARS``.
+
+    .. _metaphone: http://en.wikipedia.org/wiki/Metaphone"""
 
-.. _metaphone: http://en.wikipedia.org/wiki/Metaphone
-"""
     REGISTERED_MODELS = {}
     ITEM_PROCESSORS = []
 
-    def __init__(self, backend=None, min_word_length=3, stop_words=None,
-                 metaphone=True, stemming=True, splitters=None, **kwargs):
+    def __init__(
+        self,
+        backend=None,
+        min_word_length=3,
+        stop_words=None,
+        metaphone=True,
+        stemming=True,
+        splitters=None,
+        **kwargs
+    ):
         super(SearchEngine, self).__init__(backend=backend, **kwargs)
         self.MIN_WORD_LENGTH = min_word_length
-        splitters = (splitters if splitters is not None else
-                     processors.PUNCTUATION_CHARS)
+        splitters = splitters if splitters is not None else processors.PUNCTUATION_CHARS
         if splitters:
             self.punctuation_regex = re.compile(r"[%s]" % re.escape(splitters))
         else:
@@ -142,9 +149,9 @@ def flush(self):
 
     def add_item(self, item, words, transaction):
         for word in words:
-            transaction.add(WordItem(word=word,
-                                     model_type=item.__class__,
-                                     object_id=item.id))
+            transaction.add(
+                WordItem(word=word, model_type=item.__class__, object_id=item.id)
+            )
 
     def remove_item(self, item_or_model, transaction, ids=None):
         query = transaction.query(WordItem)
@@ -153,8 +160,9 @@ def remove_item(self, item_or_model, transaction, ids=None):
         if ids is not None:
             wi = wi.filter(object_id=ids)
         else:
-            wi = query.filter(model_type=item_or_model.__class__,
-                              object_id=item_or_model.id)
+            wi = query.filter(
+                model_type=item_or_model.__class__, object_id=item_or_model.id
+            )
         transaction.delete(wi)
 
     def search(self, text, include=None, exclude=None, lookup=None):
@@ -162,32 +170,31 @@ def search(self, text, include=None, exclude=None, lookup=None):
         return self._search(words, include, exclude, lookup)
 
     def search_model(self, q, text, lookup=None):
-        '''Implements :meth:`stdnet.odm.SearchEngine.search_model`.
-It return a new :class:`stdnet.odm.QueryElem` instance from
-the input :class:`Query` and the *text* to search.'''
+        """Implements :meth:`stdnet.odm.SearchEngine.search_model`.
+        It returns a new :class:`stdnet.odm.QueryElem` instance from
+        the input :class:`Query` and the *text* to search."""
         words = self.words_from_text(text, for_search=True)
         if not words:
             return q
         qs = self._search(words, include=(q.model,), lookup=lookup)
-        qs = tuple((q.get_field('object_id') for q in qs))
-        return odm.intersect((q,)+qs)
+        qs = tuple((q.get_field("object_id") for q in qs))
+        return odm.intersect((q,) + qs)
 
     def worditems(self, model=None):
         q = self.router.worditem.query()
         if model:
             if not isclass(model):
-                return q.filter(model_type=model.__class__,
-                                object_id=model.id)
+                return q.filter(model_type=model.__class__, object_id=model.id)
             else:
                 return q.filter(model_type=model)
         else:
             return q
 
     def index_items_from_model(self, items, model):
-        self.logger.debug('Indexing %s objects of %s model.',
-                          len(items), model._meta)
+        self.logger.debug("Indexing %s objects of %s model.", len(items), model._meta)
         return self.router.worditem.backend.execute(
-            self._index_items_from_model(items, model))
+            self._index_items_from_model(items, model)
+        )
 
     def reindex(self):
         backend = self.router.worditem.backend
@@ -205,8 +212,8 @@ def _reindex(self):
         yield total
 
     def _search(self, words, include=None, exclude=None, lookup=None):
-        '''Full text search. Return a list of queries to intersect.'''
-        lookup = lookup or 'contains'
+        """Full text search. Return a list of queries to intersect."""
+        lookup = lookup or "contains"
         query = self.router.worditem.query()
         if include:
             query = query.filter(model_type__in=include)
@@ -215,11 +222,11 @@ def _search(self, words, include=None, exclude=None, lookup=None):
         if not words:
             return [query]
         qs = []
-        if lookup == 'in':
+        if lookup == "in":
             # we are looking for items with at least one word in it
             qs.append(query.filter(word__in=words))
-        elif lookup == 'contains':
-            #we want to match every single words
+        elif lookup == "contains":
+            # we want to match every single word
             for word in words:
                 qs.append(query.filter(word=word))
         else:
diff --git a/stdnet/apps/searchengine/models.py b/stdnet/apps/searchengine/models.py
index 3d1caa5..143a7fa 100644
--- a/stdnet/apps/searchengine/models.py
+++ b/stdnet/apps/searchengine/models.py
@@ -1,14 +1,13 @@
-'''\
+"""\
 Search Engine and Tagging models. Just two of them, one for storing
 Words and one for linking other objects to Words.
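Together, the engine above and the models below implement a classic inverted index: ``add_item`` writes one ``WordItem`` per (word, object) pair, and the default ``contains`` lookup in ``_search`` intersects the per-word result sets. A backend-free sketch of the same idea, with plain dictionaries standing in for the Redis indices::

    from collections import defaultdict

    index = defaultdict(set)  # word -> set of object ids

    def add_item(object_id, words):
        for word in words:
            index[word].add(object_id)

    def search_contains(words):
        # "contains" lookup: every word must match, so intersect the sets
        ids = [index[w] for w in words]
        return set.intersection(*ids) if ids else set()

    add_item(1, ["python", "redis"])
    add_item(2, ["python", "search"])
    assert search_contains(["python"]) == {1, 2}
    assert search_contains(["python", "redis"]) == {1}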
-''' +""" from inspect import isclass from stdnet import odm class WordItemManager(odm.Manager): - def for_model(self, model): q = self.query() if not isclass(model): @@ -18,9 +17,10 @@ def for_model(self, model): class WordItem(odm.StdModel): - '''A model for associating a word with general -:class:`stdnet.odm.StdModel` instance.''' - id = odm.CompositeIdField('word', 'model_type', 'object_id') + """A model for associating a word with general + :class:`stdnet.odm.StdModel` instance.""" + + id = odm.CompositeIdField("word", "model_type", "object_id") word = odm.SymbolField() model_type = odm.ModelField() object_id = odm.SymbolField() @@ -34,11 +34,10 @@ class Meta: ordering = -odm.autoincrement() def object(self, session): - '''Instance of :attr:`model_type` with id :attr:`object_id`.''' - if not hasattr(self, '_object'): + """Instance of :attr:`model_type` with id :attr:`object_id`.""" + if not hasattr(self, "_object"): pkname = self.model_type._meta.pkname() - query = session.query(self.model_type).filter(**{pkname: - self.object_id}) + query = session.query(self.model_type).filter(**{pkname: self.object_id}) return query.items(callback=self.__set_object) else: return self._object diff --git a/stdnet/apps/searchengine/processors/__init__.py b/stdnet/apps/searchengine/processors/__init__.py index fc262fc..77927e9 100644 --- a/stdnet/apps/searchengine/processors/__init__.py +++ b/stdnet/apps/searchengine/processors/__init__.py @@ -1,10 +1,9 @@ -from .ignore import STOP_WORDS, PUNCTUATION_CHARS +from .ignore import PUNCTUATION_CHARS, STOP_WORDS from .metaphone import dm as double_metaphone from .porter import PorterStemmer class stopwords: - def __init__(self, stp=None): self.stp = stp if stp is not None else STOP_WORDS @@ -16,7 +15,7 @@ def __call__(self, words): def metaphone_processor(words): - '''Double metaphone word processor.''' + """Double metaphone word processor.""" for word in words: for w in double_metaphone(word): if w: @@ -26,8 +25,8 @@ def metaphone_processor(words): def tolerant_metaphone_processor(words): - '''Double metaphone word processor slightly modified so that when no -words are returned by the algorithm, the original word is returned.''' + """Double metaphone word processor slightly modified so that when no + words are returned by the algorithm, the original word is returned.""" for word in words: r = 0 for w in double_metaphone(word): @@ -41,8 +40,8 @@ def tolerant_metaphone_processor(words): def stemming_processor(words): - '''Porter Stemmer word processor''' + """Porter Stemmer word processor""" stem = PorterStemmer().stem for word in words: - word = stem(word, 0, len(word)-1) + word = stem(word, 0, len(word) - 1) yield word diff --git a/stdnet/apps/searchengine/processors/ignore.py b/stdnet/apps/searchengine/processors/ignore.py index 393ab29..4fd6424 100644 --- a/stdnet/apps/searchengine/processors/ignore.py +++ b/stdnet/apps/searchengine/processors/ignore.py @@ -1,7 +1,8 @@ __test__ = False # from # http://www.textfixer.com/resources/common-english-words.txt -STOP_WORDS = set('''a,able,about,across,after,all,almost,also,am,among,an,and,\ +STOP_WORDS = set( + """a,able,about,across,after,all,almost,also,am,among,an,and,\ any,are,as,at,be,because,been,but,by,can,cannot,could,dear,did,do,does,either,\ else,ever,every,for,from,get,got,had,has,have,he,her,hers,him,his,how,however,\ i,if,in,into,is,it,its,just,least,let,like,likely,may,me,might,most,must,my,\ @@ -9,12 +10,15 @@ she,should,since,so,some,than,that,the,their,them,then,there,these,they,this,\ 
tis,to,too,twas,us,wants,was,we,were,what,when,where,which,while,who,whom,\ why,will,with,would,yet,you,your -'''.split(',')) +""".split( + "," + ) +) -ALPHABET = 'abcdefghijklmnopqrstuvwxyz' -NUMBERS = '0123456789' -ALPHA_NUMERIC = ALPHABET+NUMBERS +ALPHABET = "abcdefghijklmnopqrstuvwxyz" +NUMBERS = "0123456789" +ALPHA_NUMERIC = ALPHABET + NUMBERS # Consider these characters to be punctuation # they will be replaced with spaces prior to word extraction diff --git a/stdnet/apps/searchengine/processors/metaphone.py b/stdnet/apps/searchengine/processors/metaphone.py index 6a9b91d..ac66c06 100644 --- a/stdnet/apps/searchengine/processors/metaphone.py +++ b/stdnet/apps/searchengine/processors/metaphone.py @@ -17,6 +17,7 @@ # excellent communication. # The script was also updated to use utf-8 rather than latin-1. import sys + try: NNNN = unicode('N') decode = lambda x: x.decode('utf-8', 'ignore') diff --git a/stdnet/apps/searchengine/processors/porter.py b/stdnet/apps/searchengine/processors/porter.py index 9ace91b..eae7d99 100644 --- a/stdnet/apps/searchengine/processors/porter.py +++ b/stdnet/apps/searchengine/processors/porter.py @@ -35,7 +35,6 @@ class PorterStemmer(object): - def __init__(self): """The main part of the stemming algorithm starts here. b is a buffer holding a word to be stemmed. The letters are in b[k0], @@ -49,18 +48,23 @@ def __init__(self): self.b = "" # buffer for word to be stemmed self.k = 0 self.k0 = 0 - self.j = 0 # j is a general offset into the string + self.j = 0 # j is a general offset into the string def cons(self, i): """cons(i) is TRUE <=> b[i] is a consonant.""" - if (self.b[i] == 'a' or self.b[i] == 'e' or self.b[i] == 'i' or - self.b[i] == 'o' or self.b[i] == 'u'): + if ( + self.b[i] == "a" + or self.b[i] == "e" + or self.b[i] == "i" + or self.b[i] == "o" + or self.b[i] == "u" + ): return 0 - if self.b[i] == 'y': + if self.b[i] == "y": if i == self.k0: return 1 else: - return (not self.cons(i - 1)) + return not self.cons(i - 1) return 1 def m(self): @@ -111,7 +115,7 @@ def doublec(self, j): """doublec(j) is TRUE <=> j,(j-1) contain a double consonant.""" if j < (self.k0 + 1): return 0 - if (self.b[j] != self.b[j-1]): + if self.b[j] != self.b[j - 1]: return 0 return self.cons(j) @@ -123,22 +127,26 @@ def cvc(self, i): cav(e), lov(e), hop(e), crim(e), but snow, box, tray. """ - if (i < (self.k0 + 2) or not self.cons(i) or self.cons(i-1) or - not self.cons(i-2)): + if ( + i < (self.k0 + 2) + or not self.cons(i) + or self.cons(i - 1) + or not self.cons(i - 2) + ): return 0 ch = self.b[i] - if ch == 'w' or ch == 'x' or ch == 'y': + if ch == "w" or ch == "x" or ch == "y": return 0 return 1 def ends(self, s): """ends(s) is TRUE <=> k0,...k ends with the string s.""" length = len(s) - if s[length - 1] != self.b[self.k]: # tiny speed-up + if s[length - 1] != self.b[self.k]: # tiny speed-up return 0 if length > (self.k - self.k0 + 1): return 0 - if self.b[self.k-length+1:self.k+1] != s: + if self.b[self.k - length + 1 : self.k + 1] != s: return 0 self.j = self.k - length return 1 @@ -147,7 +155,7 @@ def setto(self, s): """setto(s) sets (j+1),...k to the characters in the string s, readjusting k.""" length = len(s) - self.b = self.b[:self.j+1] + s + self.b[self.j+length+1:] + self.b = self.b[: self.j + 1] + s + self.b[self.j + length + 1 :] self.k = self.j + length def r(self, s): @@ -158,156 +166,214 @@ def r(self, s): def step1ab(self): """step1ab() gets rid of plurals and -ed or -ing. e.g. 
- caresses -> caress - ponies -> poni - ties -> ti - caress -> caress - cats -> cat + caresses -> caress + ponies -> poni + ties -> ti + caress -> caress + cats -> cat - feed -> feed - agreed -> agree - disabled -> disable + feed -> feed + agreed -> agree + disabled -> disable - matting -> mat - mating -> mate - meeting -> meet - milling -> mill - messing -> mess + matting -> mat + mating -> mate + meeting -> meet + milling -> mill + messing -> mess - meetings -> meet + meetings -> meet """ - if self.b[self.k] == 's': + if self.b[self.k] == "s": if self.ends("sses"): self.k = self.k - 2 elif self.ends("ies"): self.setto("i") - elif self.b[self.k - 1] != 's': + elif self.b[self.k - 1] != "s": self.k = self.k - 1 if self.ends("eed"): if self.m() > 0: self.k = self.k - 1 elif (self.ends("ed") or self.ends("ing")) and self.vowelinstem(): self.k = self.j - if self.ends("at"): self.setto("ate") - elif self.ends("bl"): self.setto("ble") - elif self.ends("iz"): self.setto("ize") + if self.ends("at"): + self.setto("ate") + elif self.ends("bl"): + self.setto("ble") + elif self.ends("iz"): + self.setto("ize") elif self.doublec(self.k): self.k = self.k - 1 ch = self.b[self.k] - if ch == 'l' or ch == 's' or ch == 'z': + if ch == "l" or ch == "s" or ch == "z": self.k = self.k + 1 - elif (self.m() == 1 and self.cvc(self.k)): + elif self.m() == 1 and self.cvc(self.k): self.setto("e") def step1c(self): """step1c() turns terminal y to i when there is another vowel in the stem.""" - if (self.ends("y") and self.vowelinstem()): - self.b = self.b[:self.k] + 'i' + self.b[self.k+1:] + if self.ends("y") and self.vowelinstem(): + self.b = self.b[: self.k] + "i" + self.b[self.k + 1 :] def step2(self): """step2() maps double suffices to single ones. so -ization ( = -ize plus -ation) maps to -ize etc. note that the string before the suffix must give m() > 0. 
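The ``if``/``elif`` chain that follows is, in effect, a suffix-rewrite table keyed on the penultimate letter of the buffer; rendered as data (abridged, for documentation only) it reads::

    STEP2_SUFFIXES = {
        "a": (("ational", "ate"), ("tional", "tion")),
        "c": (("enci", "ence"), ("anci", "ance")),
        "e": (("izer", "ize"),),
        "o": (("ization", "ize"), ("ation", "ate"), ("ator", "ate")),
        # ... and similarly for "l", "s", "t" and "g"
    }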
""" - if self.b[self.k - 1] == 'a': - if self.ends("ational"): self.r("ate") - elif self.ends("tional"): self.r("tion") - elif self.b[self.k - 1] == 'c': - if self.ends("enci"): self.r("ence") - elif self.ends("anci"): self.r("ance") - elif self.b[self.k - 1] == 'e': - if self.ends("izer"): self.r("ize") - elif self.b[self.k - 1] == 'l': - if self.ends("bli"): self.r("ble") # --DEPARTURE-- + if self.b[self.k - 1] == "a": + if self.ends("ational"): + self.r("ate") + elif self.ends("tional"): + self.r("tion") + elif self.b[self.k - 1] == "c": + if self.ends("enci"): + self.r("ence") + elif self.ends("anci"): + self.r("ance") + elif self.b[self.k - 1] == "e": + if self.ends("izer"): + self.r("ize") + elif self.b[self.k - 1] == "l": + if self.ends("bli"): + self.r("ble") # --DEPARTURE-- # To match the published algorithm, replace this phrase with # if self.ends("abli"): self.r("able") - elif self.ends("alli"): self.r("al") - elif self.ends("entli"): self.r("ent") - elif self.ends("eli"): self.r("e") - elif self.ends("ousli"): self.r("ous") - elif self.b[self.k - 1] == 'o': - if self.ends("ization"): self.r("ize") - elif self.ends("ation"): self.r("ate") - elif self.ends("ator"): self.r("ate") - elif self.b[self.k - 1] == 's': - if self.ends("alism"): self.r("al") - elif self.ends("iveness"): self.r("ive") - elif self.ends("fulness"): self.r("ful") - elif self.ends("ousness"): self.r("ous") - elif self.b[self.k - 1] == 't': - if self.ends("aliti"): self.r("al") - elif self.ends("iviti"): self.r("ive") - elif self.ends("biliti"): self.r("ble") - elif self.b[self.k - 1] == 'g': # --DEPARTURE-- - if self.ends("logi"): self.r("log") + elif self.ends("alli"): + self.r("al") + elif self.ends("entli"): + self.r("ent") + elif self.ends("eli"): + self.r("e") + elif self.ends("ousli"): + self.r("ous") + elif self.b[self.k - 1] == "o": + if self.ends("ization"): + self.r("ize") + elif self.ends("ation"): + self.r("ate") + elif self.ends("ator"): + self.r("ate") + elif self.b[self.k - 1] == "s": + if self.ends("alism"): + self.r("al") + elif self.ends("iveness"): + self.r("ive") + elif self.ends("fulness"): + self.r("ful") + elif self.ends("ousness"): + self.r("ous") + elif self.b[self.k - 1] == "t": + if self.ends("aliti"): + self.r("al") + elif self.ends("iviti"): + self.r("ive") + elif self.ends("biliti"): + self.r("ble") + elif self.b[self.k - 1] == "g": # --DEPARTURE-- + if self.ends("logi"): + self.r("log") # To match the published algorithm, delete this phrase def step3(self): """step3() dels with -ic-, -full, -ness etc. 
similar strategy to step2.""" - if self.b[self.k] == 'e': - if self.ends("icate"): self.r("ic") - elif self.ends("ative"): self.r("") - elif self.ends("alize"): self.r("al") - elif self.b[self.k] == 'i': - if self.ends("iciti"): self.r("ic") - elif self.b[self.k] == 'l': - if self.ends("ical"): self.r("ic") - elif self.ends("ful"): self.r("") - elif self.b[self.k] == 's': - if self.ends("ness"): self.r("") + if self.b[self.k] == "e": + if self.ends("icate"): + self.r("ic") + elif self.ends("ative"): + self.r("") + elif self.ends("alize"): + self.r("al") + elif self.b[self.k] == "i": + if self.ends("iciti"): + self.r("ic") + elif self.b[self.k] == "l": + if self.ends("ical"): + self.r("ic") + elif self.ends("ful"): + self.r("") + elif self.b[self.k] == "s": + if self.ends("ness"): + self.r("") def step4(self): """step4() takes off -ant, -ence etc., in context vcvc.""" - if self.b[self.k - 1] == 'a': - if self.ends("al"): pass - else: return - elif self.b[self.k - 1] == 'c': - if self.ends("ance"): pass - elif self.ends("ence"): pass - else: return - elif self.b[self.k - 1] == 'e': - if self.ends("er"): pass - else: return - elif self.b[self.k - 1] == 'i': - if self.ends("ic"): pass - else: return - elif self.b[self.k - 1] == 'l': - if self.ends("able"): pass - elif self.ends("ible"): pass - else: return - elif self.b[self.k - 1] == 'n': - if self.ends("ant"): pass - elif self.ends("ement"): pass - elif self.ends("ment"): pass - elif self.ends("ent"): pass - else: return - elif self.b[self.k - 1] == 'o': - if (self.ends("ion") and - (self.b[self.j] == 's' or self.b[self.j] == 't')): + if self.b[self.k - 1] == "a": + if self.ends("al"): + pass + else: + return + elif self.b[self.k - 1] == "c": + if self.ends("ance"): + pass + elif self.ends("ence"): + pass + else: + return + elif self.b[self.k - 1] == "e": + if self.ends("er"): + pass + else: + return + elif self.b[self.k - 1] == "i": + if self.ends("ic"): + pass + else: + return + elif self.b[self.k - 1] == "l": + if self.ends("able"): + pass + elif self.ends("ible"): + pass + else: + return + elif self.b[self.k - 1] == "n": + if self.ends("ant"): + pass + elif self.ends("ement"): + pass + elif self.ends("ment"): + pass + elif self.ends("ent"): + pass + else: + return + elif self.b[self.k - 1] == "o": + if self.ends("ion") and (self.b[self.j] == "s" or self.b[self.j] == "t"): pass elif self.ends("ou"): pass # takes care of -ous else: return - elif self.b[self.k - 1] == 's': - if self.ends("ism"): pass - else: return - elif self.b[self.k - 1] == 't': - if self.ends("ate"): pass - elif self.ends("iti"): pass - else: return - elif self.b[self.k - 1] == 'u': - if self.ends("ous"): pass - else: return - elif self.b[self.k - 1] == 'v': - if self.ends("ive"): pass - else: return - elif self.b[self.k - 1] == 'z': - if self.ends("ize"): pass - else: return + elif self.b[self.k - 1] == "s": + if self.ends("ism"): + pass + else: + return + elif self.b[self.k - 1] == "t": + if self.ends("ate"): + pass + elif self.ends("iti"): + pass + else: + return + elif self.b[self.k - 1] == "u": + if self.ends("ous"): + pass + else: + return + elif self.b[self.k - 1] == "v": + if self.ends("ive"): + pass + else: + return + elif self.b[self.k - 1] == "z": + if self.ends("ize"): + pass + else: + return else: return if self.m() > 1: @@ -318,12 +384,12 @@ def step5(self): m() > 1. 
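A quick sanity check of the full pipeline through the ``stem()`` entry point defined below; the two integer arguments are the inclusive start and end offsets of the buffer, exactly as the search engine's ``stemming_processor`` calls it::

    stemmer = PorterStemmer()
    for word in ("caresses", "ponies", "meetings"):
        print(word, "->", stemmer.stem(word, 0, len(word) - 1))
    # caresses -> caress, ponies -> poni, meetings -> meet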
""" self.j = self.k - if self.b[self.k] == 'e': + if self.b[self.k] == "e": a = self.m() - if a > 1 or (a == 1 and not self.cvc(self.k-1)): + if a > 1 or (a == 1 and not self.cvc(self.k - 1)): self.k = self.k - 1 - if self.b[self.k] == 'l' and self.doublec(self.k) and self.m() > 1: - self.k = self.k -1 + if self.b[self.k] == "l" and self.doublec(self.k) and self.m() > 1: + self.k = self.k - 1 def stem(self, p, i, j): """In stem(p,i,j), p is a char pointer, and the string to be stemmed @@ -339,7 +405,7 @@ def stem(self, p, i, j): self.k = j self.k0 = i if self.k <= self.k0 + 1: - return self.b # --DEPARTURE-- + return self.b # --DEPARTURE-- # With this line, strings of length 1 or 2 don't go through the # stemming process, although no mention is made of this in the # published algorithm. Remove the line to match the published @@ -350,4 +416,4 @@ def stem(self, p, i, j): self.step3() self.step4() self.step5() - return self.b[self.k0:self.k+1] + return self.b[self.k0 : self.k + 1] diff --git a/stdnet/apps/tasks/__init__.py b/stdnet/apps/tasks/__init__.py index 9ad5dc4..c33587e 100644 --- a/stdnet/apps/tasks/__init__.py +++ b/stdnet/apps/tasks/__init__.py @@ -1,5 +1,4 @@ -from pulsar.apps import data -from pulsar.apps import tasks +from pulsar.apps import data, tasks from .models import TaskData @@ -9,7 +8,6 @@ class Store(data.Store): class TaskBackend(tasks.TaskBackend): - def get_task(self, task_id=None, timeout=1): task_manager = self.task_manager() # @@ -21,7 +19,7 @@ def get_task(self, task_id=None, timeout=1): yield task_data.as_task() -tasks.task_backends['stdnet'] = TaskBackend +tasks.task_backends["stdnet"] = TaskBackend -data.register_store('redis', 'stdnet.apps.tasks.Store') +data.register_store("redis", "stdnet.apps.tasks.Store") diff --git a/stdnet/apps/tasks/models.py b/stdnet/apps/tasks/models.py index b709608..d226e9e 100644 --- a/stdnet/apps/tasks/models.py +++ b/stdnet/apps/tasks/models.py @@ -1,4 +1,3 @@ - from stdnet import odm @@ -23,7 +22,7 @@ class TaskData(odm.StdModel): executing = odm.SetField(class_field=True) class Meta: - app_label = 'tasks' + app_label = "tasks" def as_task(self): params = dict(self.meta or {}) @@ -32,4 +31,4 @@ def as_task(self): return backends.Task(self.id, **params) def __unicode__(self): - return '%s (%s)' % (self.name, self.status) + return "%s (%s)" % (self.name, self.status) diff --git a/stdnet/backends/__init__.py b/stdnet/backends/__init__.py index 7de2f72..0ed395c 100755 --- a/stdnet/backends/__init__.py +++ b/stdnet/backends/__init__.py @@ -4,74 +4,80 @@ try: from pulsar import maybe_async as async -except ImportError: # pragma noproxy +except ImportError: # pragma noproxy def async(gen): raise NotImplementedError +from stdnet.utils import ( + int_or_float, + iteritems, + raise_error_trace, + to_string, + urlencode, + urlparse, +) from stdnet.utils.exceptions import * -from stdnet.utils import raise_error_trace from stdnet.utils.importer import import_module -from stdnet.utils import (iteritems, int_or_float, to_string, urlencode, - urlparse) - -__all__ = ['BackendStructure', - 'BackendDataServer', - 'BackendQuery', - 'session_result', - 'session_data', - 'instance_session_result', - 'query_result', - 'range_lookups', - 'getdb', - 'settings', - 'async'] - - -query_result = namedtuple('query_result', 'key count') +__all__ = [ + "BackendStructure", + "BackendDataServer", + "BackendQuery", + "session_result", + "session_data", + "instance_session_result", + "query_result", + "range_lookups", + "getdb", + "settings", + "async", +] + + 
+query_result = namedtuple("query_result", "key count") # tuple containing information about a commit/delete operation on the backend # server. Id is the id in the session, persistent is a boolean indicating # if the instance is persistent on the backend, bid is the id in the backend. -instance_session_result = namedtuple('instance_session_result', - 'iid persistent id deleted score') -session_data = namedtuple('session_data', - 'meta dirty deletes queries structures') -session_result = namedtuple('session_result', 'meta results') +instance_session_result = namedtuple( + "instance_session_result", "iid persistent id deleted score" +) +session_data = namedtuple("session_data", "meta dirty deletes queries structures") +session_result = namedtuple("session_result", "meta results") pass_through = lambda x: x str_lower_case = lambda x: to_string(x).lower() range_lookups = { - 'gt': int_or_float, - 'ge': int_or_float, - 'lt': int_or_float, - 'le': int_or_float, - 'contains': pass_through, - 'startswith': pass_through, - 'endswith': pass_through, - 'icontains': str_lower_case, - 'istartswith': str_lower_case, - 'iendswith': str_lower_case} + "gt": int_or_float, + "ge": int_or_float, + "lt": int_or_float, + "le": int_or_float, + "contains": pass_through, + "startswith": pass_through, + "endswith": pass_through, + "icontains": str_lower_case, + "istartswith": str_lower_case, + "iendswith": str_lower_case, +} def get_connection_string(scheme, address, params): if address: - address = ':'.join((str(b) for b in address)) + address = ":".join((str(b) for b in address)) else: - address = '' + address = "" if params: - address += '?' + urlencode(params) - return scheme + '://' + address + address += "?" + urlencode(params) + return scheme + "://" + address class Settings(object): - def __init__(self): - self.DEFAULT_BACKEND = 'redis://127.0.0.1:6379?db=7' - self.CHARSET = 'utf-8' + self.DEFAULT_BACKEND = "redis://127.0.0.1:6379?db=7" + self.CHARSET = "utf-8" self.REDIS_PY_PARSER = False self.ASYNC_BINDINGS = False @@ -80,21 +86,21 @@ def __init__(self): class BackendStructure(object): - '''Interface for :class:`stdnet.odm.Structure` backends. + """Interface for :class:`stdnet.odm.Structure` backends. -.. attribute:: instance + .. attribute:: instance - The :class:`stdnet.odm.Structure` which this backend represents. + The :class:`stdnet.odm.Structure` which this backend represents. -.. attribute:: backend + .. attribute:: backend - The :class:`BackendDataServer` + The :class:`BackendDataServer` -.. attribute:: client + .. attribute:: client - The client of the :class:`BackendDataServer` + The client of the :class:`BackendDataServer` + """ -''' def __init__(self, instance, backend, client): self.instance = instance self.backend = backend @@ -121,7 +127,7 @@ def size(self): class BackendDataServer(object): - '''Generic interface for a backend databases. + """Generic interface for a backend databases. It should not be initialised directly, the :func:`getdb` function should be used instead. @@ -157,34 +163,33 @@ class BackendDataServer(object): The default model Manager for this backend. If not provided, the :class:`stdnet.odm.Manager` is used. Default ``None``. 
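For reference, the ``get_connection_string`` helper above rebuilds exactly the URL form used throughout the backends, including the default backend string::

    >>> get_connection_string("redis", ("127.0.0.1", 6379), {"db": 7})
    'redis://127.0.0.1:6379?db=7'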
- ''' + """ + Query = None structure_module = None default_manager = None default_port = 8000 struct_map = {} - def __init__(self, name=None, address=None, charset=None, namespace='', - **params): - self.__name = name or 'dummy' - address = address or ':' + def __init__(self, name=None, address=None, charset=None, namespace="", **params): + self.__name = name or "dummy" + address = address or ":" if not isinstance(address, (list, tuple)): - address = address.split(':') + address = address.split(":") else: address = list(address) if not address[0]: - address[0] = '127.0.0.1' + address[0] = "127.0.0.1" if len(address) == 2: if not address[1]: address[1] = self.default_port else: address[1] = int(address[1]) - self.charset = charset or 'utf-8' + self.charset = charset or "utf-8" self.params = params self.namespace = namespace self.client = self.setup_connection(address) - self.connection_string = get_connection_string( - self.name, address, self.params) + self.connection_string = get_connection_string(self.name, address, self.params) @property def name(self): @@ -208,29 +213,28 @@ def issame(self, other): def basekey(self, meta, *args): """Calculate the key to access model data. -:parameter meta: a :class:`stdnet.odm.Metaclass`. -:parameter args: optional list of strings to prepend to the basekey. -:rtype: a native string -""" - key = '%s%s' % (self.namespace, meta.modelkey) - postfix = ':'.join((str(p) for p in args if p is not None)) - return '%s:%s' % (key, postfix) if postfix else key + :parameter meta: a :class:`stdnet.odm.Metaclass`. + :parameter args: optional list of strings to prepend to the basekey. + :rtype: a native string""" + key = "%s%s" % (self.namespace, meta.modelkey) + postfix = ":".join((str(p) for p in args if p is not None)) + return "%s:%s" % (key, postfix) if postfix else key def disconnect(self): - '''Disconnect the connection.''' + """Disconnect the connection.""" pass def __repr__(self): return self.connection_string + __str__ = __repr__ def make_objects(self, meta, data, related_fields=None): - '''Generator of :class:`stdnet.odm.StdModel` instances with data -from database. + """Generator of :class:`stdnet.odm.StdModel` instances with data + from database. -:parameter meta: instance of model :class:`stdnet.odm.Metaclass`. -:parameter data: iterator over instances data. -''' + :parameter meta: instance of model :class:`stdnet.odm.Metaclass`. + :parameter data: iterator over instances data.""" make_object = meta.make_object related_data = [] if related_fields: @@ -242,8 +246,12 @@ def make_objects(self, meta, data, related_fields=None): else: multi = False relmodel = field.relmodel - related = dict(((obj.id, obj) for obj in - self.make_objects(relmodel._meta, fdata))) + related = dict( + ( + (obj.id, obj) + for obj in self.make_objects(relmodel._meta, fdata) + ) + ) related_data.append((field, related, multi)) for state in data: instance = make_object(state, self) @@ -261,15 +269,17 @@ def objects_from_db(self, meta, data, related_fields=None): return list(self.make_objects(meta, data, related_fields)) def structure(self, instance, client=None): - '''Create a backend :class:`stdnet.odm.Structure` handler. + """Create a backend :class:`stdnet.odm.Structure` handler. :param instance: a :class:`stdnet.odm.Structure` :param client: Optional client handler. 
-        '''
+        """
         struct = self.struct_map.get(instance._meta.name)
         if struct is None:
-            raise ModelNotAvailable('"%s" is not available for backend '
-                                    '"%s"' % (instance._meta.name, self))
+            raise ModelNotAvailable(
+                '"%s" is not available for backend '
+                '"%s"' % (instance._meta.name, self)
+            )
         client = client if client is not None else self.client
         return struct(instance, self, client)
 
@@ -287,54 +297,54 @@ def execute(self, result, callback=None):
 
     # VIRTUAL METHODS
     def is_async(self):
-        '''Check if the backend handler is asynchronous.'''
+        """Check if the backend handler is asynchronous."""
         return False
 
     def setup_model(self, meta):
-        '''Invoked when registering a model with a backend. This is a chance to
-perform model specific operation in the server. For example, mongo db ensure
-indices are created.'''
+        """Invoked when registering a model with a backend. This is a chance to
+        perform model-specific operations in the server. For example, MongoDB
+        ensures indices are created."""
         pass
 
     def clean(self, meta):
-        '''Remove temporary keys for a model'''
+        """Remove temporary keys for a model"""
         pass
 
     def ping(self):
-        '''Ping the server'''
+        """Ping the server"""
         pass
 
     def instance_keys(self, obj):
-        '''Return a list of database keys used by instance *obj*'''
+        """Return a list of database keys used by instance *obj*"""
         return [self.basekey(obj._meta, obj.pkvalue())]
 
     def auto_id_to_python(self, value):
-        '''Return a proper python value for the auto id.'''
+        """Return a proper python value for the auto id."""
         return value
 
     # PURE VIRTUAL METHODS
     def setup_connection(self, address):
-        '''Callback during initialization. Implementation should override
-this function for customizing their handling of connection parameters. It
-must return a instance of the backend handler.'''
+        """Callback during initialization. Implementations should override
+        this function to customize their handling of connection parameters. It
+        must return an instance of the backend handler."""
         raise NotImplementedError()
 
     def execute_session(self, session, callback):
-        '''Execute a :class:`stdnet.odm.Session` in the backend server.'''
+        """Execute a :class:`stdnet.odm.Session` in the backend server."""
         raise NotImplementedError()
 
     def model_keys(self, meta):
-        '''Return a list of database keys used by model *model*'''
+        """Return a list of database keys used by the model *meta*"""
         raise NotImplementedError()
 
     def flush(self, meta=None):
-        '''Flush the database or drop all instances of a model/collection'''
+        """Flush the database or drop all instances of a model/collection"""
         raise NotImplementedError()
 
 
 class BackendQuery(object):
-    '''Asynchronous query interface class.
+    """Asynchronous query interface class.
 
     Implements the database queries specified by :class:`stdnet.odm.Query`.
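Concrete backends only fill in the four pure-virtual hooks declared further down (``_build``, ``_execute_query``, ``_has`` and ``_items``); caching, slicing and asynchronous execution live in this base class. A minimal, purely illustrative skeleton (the list-backed storage is invented, and the exact return contract of ``_execute_query`` is backend-specific)::

    class ListBackendQuery(BackendQuery):
        def _build(self, **kwargs):
            # Translate self.queryelem into a backend request; trivial here.
            self._ids = [1, 2, 3]

        def _execute_query(self):
            # Run the query server-side without fetching items; simplified
            # here to return the number of matches.
            return len(self._ids)

        def _has(self, val):
            return val in self._ids

        def _items(self, slic):
            return self._ids[slic] if slic else list(self._ids)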
@@ -346,9 +356,10 @@ class BackendQuery(object): flag indicating if the query has been executed in the backend server - ''' + """ + def __init__(self, queryelem, timeout=0, **kwargs): - '''Initialize the query for the backend database.''' + """Initialize the query for the backend database.""" self.queryelem = queryelem self.expire = max(timeout, 10) self.timeout = timeout @@ -385,7 +396,7 @@ def executed(self): @property def cache(self): - '''Cached results.''' + """Cached results.""" return self.__slice_cache def __len__(self): @@ -414,25 +425,24 @@ def items(self, slic=None, callback=None): def delete(self, qs): with self.session.begin() as t: t.delete(qs) - return self.backend.execute(t.on_result, - lambda _: t.deleted.get(self.meta)) + return self.backend.execute(t.on_result, lambda _: t.deleted.get(self.meta)) # VIRTUAL METHODS - MUST BE IMPLEMENTED BY BACKENDS - def _has(self, val): # pragma: no cover + def _has(self, val): # pragma: no cover raise NotImplementedError - def _items(self, slic): # pragma: no cover + def _items(self, slic): # pragma: no cover raise NotImplementedError - def _build(self, **kwargs): # pragma: no cover + def _build(self, **kwargs): # pragma: no cover raise NotImplementedError - def _execute_query(self): # pragma: no cover - '''Execute the query without fetching data from server. + def _execute_query(self): # pragma: no cover + """Execute the query without fetching data from server. Must be implemented by data-server backends and return a generator. - ''' + """ raise NotImplementedError # PRIVATE METHODS @@ -469,14 +479,14 @@ def _slice_items(self, slic): def parse_backend(backend): """Converts the "backend" into the database connection parameters. -It returns a (scheme, host, params) tuple.""" + It returns a (scheme, host, params) tuple.""" r = urlparse.urlsplit(backend) scheme, host = r.scheme, r.netloc path, query = r.path, r.query if path and not query: - query, path = path, '' + query, path = path, "" if query: - if query.find('?'): + if query.find("?"): path = query else: query = query[1:] @@ -490,14 +500,14 @@ def parse_backend(backend): def _getdb(scheme, host, params): try: - module = import_module('stdnet.backends.%sb' % scheme) + module = import_module("stdnet.backends.%sb" % scheme) except ImportError: raise NotImplementedError - return getattr(module, 'BackendDataServer')(scheme, host, **params) + return getattr(module, "BackendDataServer")(scheme, host, **params) def getdb(backend=None, **kwargs): - '''get a :class:`BackendDataServer`.''' + """get a :class:`BackendDataServer`.""" if isinstance(backend, BackendDataServer): return backend backend = backend or settings.DEFAULT_BACKEND @@ -505,8 +515,8 @@ def getdb(backend=None, **kwargs): return None scheme, address, params = parse_backend(backend) params.update(kwargs) - if 'timeout' in params: - params['timeout'] = int(params['timeout']) + if "timeout" in params: + params["timeout"] = int(params["timeout"]) return _getdb(scheme, address, params) diff --git a/stdnet/backends/redisb/__init__.py b/stdnet/backends/redisb/__init__.py index bb8a036..7b27a05 100755 --- a/stdnet/backends/redisb/__init__.py +++ b/stdnet/backends/redisb/__init__.py @@ -1,33 +1,40 @@ -'''Redis backend implementation''' +"""Redis backend implementation""" import json from functools import partial -from .client import * - import stdnet -from stdnet import FieldValueError, CommitException, QuerySetError -from stdnet.utils import (gen_unique_id, zip, ispy3k, - native_str, flat_mapping, unique_tuple) -from stdnet.backends 
import (BackendStructure, session_result,
-                                  instance_session_result)
+from stdnet import CommitException, FieldValueError, QuerySetError
+from stdnet.backends import BackendStructure, instance_session_result, session_result
+from stdnet.utils import (
+    flat_mapping,
+    gen_unique_id,
+    ispy3k,
+    native_str,
+    unique_tuple,
+    zip,
+)
+
+from .client import *

-MIN_FLOAT = -1.e99
+MIN_FLOAT = -1.0e99
 ############################################################################
 # prefixes for data
-OBJ = 'obj'  # the hash table for a instance
-TMP = 'tmp'  # temorary key
-ODM_SCRIPTS = ('odmrun', 'move2set', 'zdiffstore')
+OBJ = "obj"  # the hash table for an instance
+TMP = "tmp"  # temporary key
+ODM_SCRIPTS = ("odmrun", "move2set", "zdiffstore")
 ############################################################################

 if ispy3k:
+
     def decode(value, encoding):
         if isinstance(value, bytes):
             return value.decode(encoding)
         else:
             return value

-else:  # pragma nocover
+
+else:  # pragma nocover

     def decode(value, encoding):
         return value
@@ -40,25 +47,25 @@ def pairs_to_dict(response, encoding):


 class odmrun(RedisScript):
-    script = (read_lua_file('tabletools'),
-              # timeseries must be included before utils
-              read_lua_file('commands.timeseries'),
-              read_lua_file('commands.utils'),
-              read_lua_file('odm'))
+    script = (
+        read_lua_file("tabletools"),
+        # timeseries must be included before utils
+        read_lua_file("commands.timeseries"),
+        read_lua_file("commands.utils"),
+        read_lua_file("odm"),
+    )
     required_scripts = ODM_SCRIPTS

-    def callback(self, response, meta=None, backend=None, odm_command=None,
-                 **opts):
-        if odm_command == 'delete':
-            res = (instance_session_result(r, False, r, True, 0)
-                   for r in response)
+    def callback(self, response, meta=None, backend=None, odm_command=None, **opts):
+        if odm_command == "delete":
+            res = (instance_session_result(r, False, r, True, 0) for r in response)
             return session_result(meta, res)
-        elif odm_command == 'commit':
+        elif odm_command == "commit":
             res = self._wrap_commit(response, **opts)
             return session_result(meta, res)
-        elif odm_command == 'load':
+        elif odm_command == "load":
             return self.load_query(response, backend, meta, **opts)
-        elif odm_command == 'structure':
+        elif odm_command == "structure":
             return self.flush_structure(response, backend, meta, **opts)
         else:
             return response
@@ -67,14 +74,22 @@ def _wrap_commit(self, response, iids=None, redis_client=None, **options):
         for id, iid in zip(response, iids):
             id, flag, info = id
             if int(flag):
-                yield instance_session_result(iid, True, id, False,
-                                              float(info))
+                yield instance_session_result(iid, True, id, False, float(info))
             else:
                 msg = info.decode(redis_client.encoding)
                 yield CommitException(msg)

-    def load_query(self, response, backend, meta, get=None, fields=None,
-                   fields_attributes=None, redis_client=None, **options):
+    def load_query(
+        self,
+        response,
+        backend,
+        meta,
+        get=None,
+        fields=None,
+        fields_attributes=None,
+        redis_client=None,
+        **options
+    ):
         if get:
             tpy = meta.dfields.get(get).to_python
             return [tpy(v, backend) for v in response]
@@ -87,14 +102,15 @@ def load_query(self, response, backend, meta, get=None, fields=None,
             for fname, rdata, fields in related:
                 fname = native_str(fname, encoding)
                 fields = tuple(native_str(f, encoding) for f in fields)
-                related_fields[fname] =\
-                    self.load_related(meta, fname, rdata, fields, encoding)
+                related_fields[fname] = self.load_related(
+                    meta, fname, rdata, fields, encoding
+                )
         return backend.objects_from_db(meta, data, related_fields)

     def
build(self, response, meta, fields, fields_attributes, encoding): fields = tuple(fields) if fields else None if fields: - if len(fields) == 1 and fields[0] in (meta.pkname(), ''): + if len(fields) == 1 and fields[0] in (meta.pkname(), ""): for id in response: yield id, (), {} else: @@ -105,24 +121,24 @@ def build(self, response, meta, fields, fields_attributes, encoding): yield id, None, pairs_to_dict(fdata, encoding) def load_related(self, meta, fname, data, fields, encoding): - '''Parse data for related objects.''' + """Parse data for related objects.""" field = meta.dfields[fname] if field in meta.multifields: fmeta = field.structure_class()._meta - if fmeta.name in ('hashtable', 'zset'): - return ((native_str(id, encoding), - pairs_to_dict(fdata, encoding)) for - id, fdata in data) + if fmeta.name in ("hashtable", "zset"): + return ( + (native_str(id, encoding), pairs_to_dict(fdata, encoding)) + for id, fdata in data + ) else: - return ((native_str(id, encoding), fdata) for - id, fdata in data) + return ((native_str(id, encoding), fdata) for id, fdata in data) else: # this is data for stdmodel instances return self.build(data, meta, fields, fields, encoding) class check_structures(RedisScript): - script = read_lua_file('structures') + script = read_lua_file("structures") ############################################################################ @@ -131,7 +147,7 @@ class check_structures(RedisScript): class RedisQuery(stdnet.BackendQuery): card = None _meta_info = None - script_dep = {'script_dependency': ('build_query', 'move2set')} + script_dep = {"script_dependency": ("build_query", "move2set")} def zism(self, r): return r is not None @@ -155,42 +171,43 @@ def _build(self, pipe=None, **kwargs): key, meta, keys, args = None, self.meta, [], [] pkname = meta.pkname() for child in qs: - if getattr(child, 'backend', None) == backend: - lookup, value = 'set', child + if getattr(child, "backend", None) == backend: + lookup, value = "set", child else: lookup, value = child - if lookup == 'set': + if lookup == "set": be = value.backend_query(pipe=pipe) keys.append(be.query_key) - args.extend(('set', be.query_key)) + args.extend(("set", be.query_key)) else: if isinstance(value, tuple): value = self.dump_nested(*value) - args.extend((lookup, '' if value is None else value)) + args.extend((lookup, "" if value is None else value)) temp_key = True - if qs.keyword == 'set': + if qs.keyword == "set": if qs.name == pkname and not args: - key = backend.basekey(meta, 'id') + key = backend.basekey(meta, "id") temp_key = False else: key = backend.tempkey(meta) keys.insert(0, key) - backend.odmrun(pipe, 'query', meta, keys, self.meta_info, - qs.name, *args) + backend.odmrun( + pipe, "query", meta, keys, self.meta_info, qs.name, *args + ) else: key = backend.tempkey(meta) - p = 'z' if meta.ordering else 's' - pipe.execute_script('move2set', keys, p) - if qs.keyword == 'intersect': - command = getattr(pipe, p+'interstore') - elif qs.keyword == 'union': - command = getattr(pipe, p+'unionstore') - elif qs.keyword == 'diff': - command = getattr(pipe, p+'diffstore') + p = "z" if meta.ordering else "s" + pipe.execute_script("move2set", keys, p) + if qs.keyword == "intersect": + command = getattr(pipe, p + "interstore") + elif qs.keyword == "union": + command = getattr(pipe, p + "unionstore") + elif qs.keyword == "diff": + command = getattr(pipe, p + "diffstore") else: - raise ValueError('Could not perform %s operation' % qs.keyword) + raise ValueError("Could not perform %s operation" % qs.keyword) command(key, 
keys) - where = self.queryelem.data.get('where') + where = self.queryelem.data.get("where") # where query if where: # First key is the current key @@ -212,25 +229,25 @@ def _build(self, pipe=None, **kwargs): if not temp_key: temp_key = True key = backend.tempkey(meta) - okey = backend.basekey(meta, OBJ, '*->' + field_attribute) - pipe.sort(bkey, by='nosort', get=okey, store=key) - self.card = getattr(pipe, 'llen') + okey = backend.basekey(meta, OBJ, "*->" + field_attribute) + pipe.sort(bkey, by="nosort", get=okey, store=key) + self.card = getattr(pipe, "llen") if temp_key: pipe.expire(key, self.expire) self.query_key = key def _execute_query(self): - '''Execute the query without fetching data. Returns the number of -elements in the query.''' + """Execute the query without fetching data. Returns the number of + elements in the query.""" pipe = self.pipe if not self.card: if self.meta.ordering: - self.ismember = getattr(self.backend.client, 'zrank') - self.card = getattr(pipe, 'zcard') + self.ismember = getattr(self.backend.client, "zrank") + self.card = getattr(pipe, "zcard") self._check_member = self.zism else: - self.ismember = getattr(self.backend.client, 'sismember') - self.card = getattr(pipe, 'scard') + self.ismember = getattr(self.backend.client, "sismember") + self.card = getattr(pipe, "scard") self._check_member = self.sism else: self.ismember = None @@ -239,7 +256,7 @@ def _execute_query(self): yield result[-1] def order(self, last): - '''Perform ordering with respect model fields.''' + """Perform ordering with respect model fields.""" desc = last.desc field = last.name nested = last.nested @@ -249,13 +266,10 @@ def order(self, last): nested_args.extend((self.backend.basekey(meta), nested.name)) last = nested nested = nested.nested - method = 'ALPHA' if last.field.internal_type == 'text' else '' + method = "ALPHA" if last.field.internal_type == "text" else "" if field == last.model._meta.pkname(): - field = '' - return {'field': field, - 'method': method, - 'desc': desc, - 'nested': nested_args} + field = "" + return {"field": field, "method": method, "desc": desc, "nested": nested_args} def dump_nested(self, value, nested): nested_args = [] @@ -284,19 +298,19 @@ def _items(self, slic): # the load_query lua script backend = self.backend meta = self.meta - name = '' + name = "" order = () start, stop = self.get_redis_slice(slic) if self.queryelem.ordering: order = self.order(self.queryelem.ordering) elif meta.ordering: - name = 'DESC' if meta.ordering.desc else 'ASC' + name = "DESC" if meta.ordering.desc else "ASC" elif start or stop is not None: order = self.order(meta.get_sorting(meta.pkname())) # Wen using the sort algorithm redis requires the number of element # not the stop index if order: - name = 'explicit' + name = "explicit" N = self.execute_query() if stop is None: stop = N @@ -313,8 +327,10 @@ def _items(self, slic): # if the get_field is available, we only load that field if get: if slic: - raise QuerySetError('Cannot slice a queryset in conjunction ' - 'with get_field. Use load_only instead.') + raise QuerySetError( + "Cannot slice a queryset in conjunction " + "with get_field. Use load_only instead." 
+ ) if get == meta.pk.name: fields_attributes = fields = pkname_tuple else: @@ -322,44 +338,55 @@ def _items(self, slic): else: fields = self.queryelem.fields or None if fields: - fields = unique_tuple(fields, - self.queryelem.select_related or ()) + fields = unique_tuple(fields, self.queryelem.select_related or ()) if fields == pkname_tuple: fields_attributes = fields elif fields: fields, fields_attributes = meta.backend_fields(fields) else: fields_attributes = () - options = {'ordering': name, - 'order': order, - 'start': start, - 'stop': stop, - 'fields': fields_attributes, - 'related': dict(self.related_lua_args()), - 'get': get} + options = { + "ordering": name, + "order": order, + "start": start, + "stop": stop, + "fields": fields_attributes, + "related": dict(self.related_lua_args()), + "get": get, + } joptions = json.dumps(options) - options.update({'fields': fields, - 'fields_attributes': fields_attributes}) - return backend.odmrun(backend.client, 'load', meta, (self.query_key,), - self.meta_info, joptions, **options) + options.update({"fields": fields, "fields_attributes": fields_attributes}) + return backend.odmrun( + backend.client, + "load", + meta, + (self.query_key,), + self.meta_info, + joptions, + **options + ) def related_lua_args(self): - '''Generator of load_related arguments''' + """Generator of load_related arguments""" related = self.queryelem.select_related if related: meta = self.meta for rel in related: field = meta.dfields[rel] relmodel = field.relmodel - bk = self.backend.basekey(relmodel._meta) if relmodel else '' + bk = self.backend.basekey(relmodel._meta) if relmodel else "" fields = list(related[rel]) if meta.pkname() in fields: fields.remove(meta.pkname()) if not fields: - fields.append('') - ftype = field.type if field in meta.multifields else '' - data = {'field': field.attname, 'type': ftype, - 'bk': bk, 'fields': fields} + fields.append("") + ftype = field.type if field in meta.multifields else "" + data = { + "field": field.attname, + "type": ftype, + "bk": bk, + "fields": fields, + } yield field.name, data @@ -367,7 +394,6 @@ def related_lua_args(self): ## STRUCTURES ############################################################################ class RedisStructure(BackendStructure): - def __init__(self, *args, **kwargs): super(RedisStructure, self).__init__(*args, **kwargs) instance = self.instance @@ -375,12 +401,13 @@ def __init__(self, *args, **kwargs): if field: model = field.model if instance._pkvalue: - id = self.backend.basekey(model._meta, 'obj', - instance._pkvalue, field.name) + id = self.backend.basekey( + model._meta, "obj", instance._pkvalue, field.name + ) else: - id = self.backend.basekey(model._meta, 'struct', field.name) + id = self.backend.basekey(model._meta, "struct", field.name) else: - id = '%s.%s' % (instance._meta.name, instance.id) + id = "%s.%s" % (instance._meta.name, instance.id) self.id = id @property @@ -392,7 +419,6 @@ def delete(self): class String(RedisStructure): - def flush(self): cache = self.instance.cache result = None @@ -410,7 +436,6 @@ def incr(self, num=1): class Set(RedisStructure): - def flush(self): cache = self.instance.cache result = None @@ -430,7 +455,8 @@ def items(self): class Zset(RedisStructure): - '''Redis ordered set structure''' + """Redis ordered set structure""" + def flush(self): cache = self.instance.cache result = None @@ -469,29 +495,37 @@ def count(self, start, stop): def range(self, start, end, withscores=True, **options): return self.backend.execute( - self.client.zrangebyscore(self.id, 
start, end, - withscores=withscores, **options), - partial(self._range, withscores)) + self.client.zrangebyscore( + self.id, start, end, withscores=withscores, **options + ), + partial(self._range, withscores), + ) def irange(self, start=0, stop=-1, desc=False, withscores=True, **options): return self.backend.execute( - self.client.zrange(self.id, start, stop, desc=desc, - withscores=withscores, **options), - partial(self._range, withscores)) + self.client.zrange( + self.id, start, stop, desc=desc, withscores=withscores, **options + ), + partial(self._range, withscores), + ) def ipop_range(self, start, stop=None, withscores=True, **options): - '''Remove and return a range from the ordered set by rank (index).''' + """Remove and return a range from the ordered set by rank (index).""" return self.backend.execute( - self.client.zpopbyrank(self.id, start, stop, - withscores=withscores, **options), - partial(self._range, withscores)) + self.client.zpopbyrank( + self.id, start, stop, withscores=withscores, **options + ), + partial(self._range, withscores), + ) def pop_range(self, start, stop=None, withscores=True, **options): - '''Remove and return a range from the ordered set by score.''' + """Remove and return a range from the ordered set by score.""" return self.backend.execute( - self.client.zpopbyscore(self.id, start, stop, - withscores=withscores, **options), - partial(self._range, withscores)) + self.client.zpopbyscore( + self.id, start, stop, withscores=withscores, **options + ), + partial(self._range, withscores), + ) # PRIVATE def _range(self, withscores, result): @@ -502,7 +536,6 @@ def _range(self, withscores, result): class List(RedisStructure): - def pop_front(self): return self.client.lpop(self.id) @@ -538,7 +571,6 @@ def range(self, start=0, end=-1): class Hash(RedisStructure): - def flush(self): cache = self.instance.cache result = None @@ -581,136 +613,150 @@ def items(self): class TS(RedisStructure): - '''Redis timeseries implementation is based on the ts.lua script''' + """Redis timeseries implementation is based on the ts.lua script""" + def flush(self): cache = self.instance.cache result = None if cache.toadd: - result = self.client.execute_script('ts_commands', (self.id,), - 'add', *cache.toadd.flat()) + result = self.client.execute_script( + "ts_commands", (self.id,), "add", *cache.toadd.flat() + ) if cache.toremove: - raise NotImplementedError('Cannot remove. TSDEL not implemented') + raise NotImplementedError("Cannot remove. 
TSDEL not implemented") return result def __contains__(self, timestamp): - return self.client.execute_script('ts_commands', (self.id,), 'exists', - timestamp) + return self.client.execute_script( + "ts_commands", (self.id,), "exists", timestamp + ) def size(self): - return self.client.execute_script('ts_commands', (self.id,), 'size') + return self.client.execute_script("ts_commands", (self.id,), "size") def count(self, start, stop): - return self.client.execute_script('ts_commands', (self.id,), 'count', - start, stop) + return self.client.execute_script( + "ts_commands", (self.id,), "count", start, stop + ) def times(self, time_start, time_stop, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), 'times', - time_start, time_stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "times", time_start, time_stop, **kwargs + ) def itimes(self, start=0, stop=-1, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), 'itimes', - start, stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "itimes", start, stop, **kwargs + ) def get(self, dte): - return self.client.execute_script('ts_commands', (self.id,), - 'get', dte) + return self.client.execute_script("ts_commands", (self.id,), "get", dte) def rank(self, dte): - return self.client.execute_script('ts_commands', (self.id,), - 'rank', dte) + return self.client.execute_script("ts_commands", (self.id,), "rank", dte) def pop(self, dte): - return self.client.execute_script('ts_commands', (self.id,), - 'pop', dte) + return self.client.execute_script("ts_commands", (self.id,), "pop", dte) def ipop(self, index): - return self.client.execute_script('ts_commands', (self.id,), - 'ipop', index) + return self.client.execute_script("ts_commands", (self.id,), "ipop", index) def range(self, time_start, time_stop, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), 'range', - time_start, time_stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "range", time_start, time_stop, **kwargs + ) def irange(self, start=0, stop=-1, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), 'irange', - start, stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "irange", start, stop, **kwargs + ) def pop_range(self, time_start, time_stop, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), - 'pop_range', - time_start, time_stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "pop_range", time_start, time_stop, **kwargs + ) def ipop_range(self, start=0, stop=-1, **kwargs): - return self.client.execute_script('ts_commands', (self.id,), - 'ipop_range', start, stop, **kwargs) + return self.client.execute_script( + "ts_commands", (self.id,), "ipop_range", start, stop, **kwargs + ) class NumberArray(RedisStructure): - def flush(self): cache = self.instance.cache result = None if cache.back: - self.client.execute_script('numberarray_pushback', (self.id,), - *cache.back) + self.client.execute_script("numberarray_pushback", (self.id,), *cache.back) result = True return result def get(self, index): - return self.client.execute_script('numberarray_getset', (self.id,), - 'get', index+1) + return self.client.execute_script( + "numberarray_getset", (self.id,), "get", index + 1 + ) def set(self, value): - return self.client.execute_script('numberarray_getset', (self.id,), - 'set', index+1, value) + return self.client.execute_script( + 
"numberarray_getset", (self.id,), "set", index + 1, value + ) def range(self): - return self.client.execute_script('numberarray_all_raw', (self.id,),) + return self.client.execute_script( + "numberarray_all_raw", + (self.id,), + ) def resize(self, size, value=None): if value is not None: argv = (size, value) else: argv = (size,) - return self.client.execute_script('numberarray_resize', (self.id,), - *argv) + return self.client.execute_script("numberarray_resize", (self.id,), *argv) def size(self): - return self.client.strlen(self.id)//8 + return self.client.strlen(self.id) // 8 class ts_commands(RedisScript): - script = (read_lua_file('commands.timeseries'), - read_lua_file('tabletools'), - read_lua_file('ts')) + script = ( + read_lua_file("commands.timeseries"), + read_lua_file("tabletools"), + read_lua_file("ts"), + ) class numberarray_resize(RedisScript): - script = (read_lua_file('numberarray'), - '''return array:new(KEYS[1]):resize(unpack(ARGV))''') + script = ( + read_lua_file("numberarray"), + """return array:new(KEYS[1]):resize(unpack(ARGV))""", + ) class numberarray_all_raw(RedisScript): - script = (read_lua_file('numberarray'), - '''return array:new(KEYS[1]):all_raw()''') + script = (read_lua_file("numberarray"), """return array:new(KEYS[1]):all_raw()""") class numberarray_getset(RedisScript): - script = (read_lua_file('numberarray'), - '''local a = array:new(KEYS[1]) + script = ( + read_lua_file("numberarray"), + """local a = array:new(KEYS[1]) if ARGV[1] == 'get' then return a:get(ARGV[2],true) else a:set(ARGV[2],ARGV[3],true) -end''') +end""", + ) class numberarray_pushback(RedisScript): - script = (read_lua_file('numberarray'), - '''local a = array:new(KEYS[1]) + script = ( + read_lua_file("numberarray"), + """local a = array:new(KEYS[1]) for _,v in ipairs(ARGV) do a:push_back(v,true) -end''') +end""", + ) ############################################################################ @@ -720,24 +766,26 @@ class BackendDataServer(stdnet.BackendDataServer): Query = RedisQuery _redis_clients = {} default_port = 6379 - struct_map = {'set': Set, - 'list': List, - 'zset': Zset, - 'hashtable': Hash, - 'ts': TS, - 'numberarray': NumberArray, - 'string': String} + struct_map = { + "set": Set, + "list": List, + "zset": Zset, + "hashtable": Hash, + "ts": TS, + "numberarray": NumberArray, + "string": String, + } def setup_connection(self, address): if len(address) == 2: address = tuple(address) elif len(address) == 1: address = address[0] - if 'db' not in self.params: - self.params['db'] = 0 + if "db" not in self.params: + self.params["db"] = 0 rpy = redis_client(address=address, **self.params) if self.namespace: - self.params['namespace'] = self.namespace + self.params["namespace"] = self.namespace return rpy def auto_id_to_python(self, value): @@ -753,20 +801,19 @@ def disconnect(self): self.client.connection_pool.disconnect() def meta(self, meta): - '''Extract model metadata for lua script stdnet/lib/lua/odm.lua''' + """Extract model metadata for lua script stdnet/lib/lua/odm.lua""" data = meta.as_dict() - data['namespace'] = self.basekey(meta) + data["namespace"] = self.basekey(meta) return data - def odmrun(self, client, odm_command, meta, keys, meta_info, - *args, **options): - options.update({'backend': self, 'meta': meta, - 'odm_command': odm_command}) - return client.execute_script('odmrun', keys, odm_command, meta_info, - *args, **options) + def odmrun(self, client, odm_command, meta, keys, meta_info, *args, **options): + options.update({"backend": self, "meta": meta, 
"odm_command": odm_command}) + return client.execute_script( + "odmrun", keys, odm_command, meta_info, *args, **options + ) def where_run(self, client, meta_info, keys, where, load_only): - where = read_lua_file('where', context={'where_clause': where}) + where = read_lua_file("where", context={"where_clause": where}) numkeys = len(keys) keys.append(meta_info) if load_only: @@ -774,7 +821,7 @@ def where_run(self, client, meta_info, keys, where, load_only): return client.eval(where, numkeys, *keys) def execute_session(self, session_data): - '''Execute a session in redis.''' + """Execute a session in redis.""" pipe = self.client.pipeline() for sm in session_data: # loop through model sessions meta = sm.meta @@ -791,8 +838,7 @@ def execute_session(self, session_data): for instance in sm.dirty: state = instance.get_state() if not meta.is_valid(instance): - raise FieldValueError( - json.dumps(instance._dbdata['errors'])) + raise FieldValueError(json.dumps(instance._dbdata["errors"])) score = MIN_FLOAT if meta.ordering: if meta.ordering.auto: @@ -801,16 +847,17 @@ def execute_session(self, session_data): v = getattr(instance, meta.ordering.name, None) if v is not None: score = meta.ordering.field.scorefun(v) - data = instance._dbdata['cleaned_data'] + data = instance._dbdata["cleaned_data"] action = state.action - prev_id = state.iid if state.persistent else '' - id = instance.pkvalue() or '' + prev_id = state.iid if state.persistent else "" + id = instance.pkvalue() or "" data = flat_mapping(data) lua_data.extend((action, prev_id, id, score, len(data))) lua_data.extend(data) processed.append(state.iid) - self.odmrun(pipe, 'commit', meta, (), meta_info, - *lua_data, iids=processed) + self.odmrun( + pipe, "commit", meta, (), meta_info, *lua_data, iids=processed + ) return pipe.execute() def accumulate_delete(self, pipe, backend_query): @@ -830,8 +877,9 @@ def accumulate_delete(self, pipe, backend_query): rmanager = getattr(meta.model, name) # the related manager model is the same as current model if rmanager.model == meta.model: - self.odmrun(pipe, 'aggregate', meta, keys, meta_info, - rmanager.field.attname) + self.odmrun( + pipe, "aggregate", meta, keys, meta_info, rmanager.field.attname + ) # only consider models which are registered with the router elif rmanager.model in session.router: rel_managers.append(rmanager) @@ -841,22 +889,21 @@ def accumulate_delete(self, pipe, backend_query): if rmanager.field.required: rq = rmanager.query_from_query(query).backend_query(pipe=pipe) self.accumulate_delete(pipe, rq) - self.odmrun(pipe, 'delete', meta, keys, meta_info) + self.odmrun(pipe, "delete", meta, keys, meta_info) def tempkey(self, meta, name=None): - return self.basekey(meta, TMP, name if name is not None else - gen_unique_id()) + return self.basekey(meta, TMP, name if name is not None else gen_unique_id()) def flush(self, meta=None): - '''Flush all model keys from the database''' + """Flush all model keys from the database""" pattern = self.basekey(meta) if meta else self.namespace - return self.client.delpattern('%s*' % pattern) + return self.client.delpattern("%s*" % pattern) def clean(self, meta): - return self.client.delpattern(self.tempkey(meta, '*')) + return self.client.delpattern(self.tempkey(meta, "*")) def model_keys(self, meta): - pattern = '%s*' % self.basekey(meta) + pattern = "%s*" % self.basekey(meta) return self.execute(self.client.keys(pattern), self._decode_keys) def instance_keys(self, obj): @@ -872,7 +919,7 @@ def flush_structure(self, sm, pipe): for instance in 
sm.structures: be = self.structure(instance, pipe) be.action = instance.action - if be.action == 'update': + if be.action == "update": be.flush() else: be.delete() diff --git a/stdnet/backends/redisb/client/__init__.py b/stdnet/backends/redisb/client/__init__.py index 8ced62a..b6a5cc3 100644 --- a/stdnet/backends/redisb/client/__init__.py +++ b/stdnet/backends/redisb/client/__init__.py @@ -3,33 +3,50 @@ except ImportError: async = None -from .extensions import (RedisScript, read_lua_file, redis, get_script, - RedisDb, RedisKey, RedisDataFormatter) from .client import Redis +from .extensions import ( + RedisDataFormatter, + RedisDb, + RedisKey, + RedisScript, + get_script, + read_lua_file, + redis, +) RedisError = redis.RedisError -__all__ = ['redis_client', 'RedisScript', 'read_lua_file', 'RedisError', - 'RedisDb', 'RedisKey', 'RedisDataFormatter', 'get_script'] +__all__ = [ + "redis_client", + "RedisScript", + "read_lua_file", + "RedisError", + "RedisDb", + "RedisKey", + "RedisDataFormatter", + "get_script", +] -def redis_client(address=None, connection_pool=None, timeout=None, - parser=None, **kwargs): - '''Get a new redis client. +def redis_client( + address=None, connection_pool=None, timeout=None, parser=None, **kwargs +): + """Get a new redis client. :param address: a ``host``, ``port`` tuple. :param connection_pool: optional connection pool. :param timeout: socket timeout. :param timeout: socket timeout. - ''' + """ if not connection_pool: if timeout == 0: if not async: - raise ImportError('Asynchronous connection requires async ' - 'bindings installed.') + raise ImportError( + "Asynchronous connection requires async " "bindings installed." + ) return async.pool.redis(address, **kwargs) else: - kwargs['socket_timeout'] = timeout + kwargs["socket_timeout"] = timeout return Redis(address[0], address[1], **kwargs) else: return Redis(connection_pool=connection_pool) diff --git a/stdnet/backends/redisb/client/async.py b/stdnet/backends/redisb/client/async.py index f61c4d5..39a12ee 100644 --- a/stdnet/backends/redisb/client/async.py +++ b/stdnet/backends/redisb/client/async.py @@ -1,4 +1,4 @@ -'''The :mod:`stdnet.backends.redisb.async` module implements an asynchronous +"""The :mod:`stdnet.backends.redisb.async` module implements an asynchronous connector for redis-py_. It uses pulsar_ asynchronous framework. To use this connector, add ``timeout=0`` to redis :ref:`connection string `:: @@ -11,17 +11,15 @@ db = getdb('redis://127.0.0.1:6378?password=bla&timeout=0') -''' +""" from pulsar.apps import redis from pulsar.apps.redis.client import BasePipeline -from .extensions import (RedisExtensionsMixin, get_script, RedisError, - all_loaded_scripts) +from .extensions import RedisError, RedisExtensionsMixin, all_loaded_scripts, get_script from .prefixed import PrefixedRedisMixin class Redis(RedisExtensionsMixin, redis.Redis): - @property def is_async(self): return True @@ -30,18 +28,17 @@ def address(self): return self.connection_info[0] def prefixed(self, prefix): - '''Return a new :class:`PrefixedRedis` client. - ''' + """Return a new :class:`PrefixedRedis` client.""" return PrefixedRedis(self, prefix) def pipeline(self, transaction=True, shard_hint=None): return Pipeline(self, self.response_callbacks, transaction, shard_hint) def execute_script(self, name, keys, *args, **options): - '''Execute a script. + """Execute a script. makes sure all required scripts are loaded. 
- ''' + """ script = get_script(name) if not script: raise redis.RedisError('No such script "%s"' % name) @@ -62,12 +59,11 @@ class PrefixedRedis(PrefixedRedisMixin, Redis): class Pipeline(BasePipeline, Redis): - def execute_script(self, name, keys, *args, **options): - '''Execute a script. + """Execute a script. makes sure all required scripts are loaded. - ''' + """ script = get_script(name) if not script: raise redis.RedisError('No such script "%s"' % name) @@ -84,7 +80,6 @@ def execute_script(self, name, keys, *args, **options): class RedisPool(redis.RedisPool): - def redis(self, address, db=0, password=None, timeout=None, **kw): timeout = int(timeout or self.timeout) info = redis.connection_info(address, db, password, timeout) diff --git a/stdnet/backends/redisb/client/client.py b/stdnet/backends/redisb/client/client.py index 031362d..9e47020 100644 --- a/stdnet/backends/redisb/client/client.py +++ b/stdnet/backends/redisb/client/client.py @@ -1,4 +1,4 @@ -'''The :mod:`stdnet.backends.redisb.client` implements several extensions +"""The :mod:`stdnet.backends.redisb.client` implements several extensions to the standard redis client in redis-py_ @@ -23,36 +23,31 @@ :members: :member-order: bysource -''' -import os +""" import io +import os import socket from copy import copy -from .extensions import RedisExtensionsMixin, redis, BasePipeline +from .extensions import BasePipeline, RedisExtensionsMixin, redis from .prefixed import PrefixedRedisMixin class Redis(RedisExtensionsMixin, redis.StrictRedis): - @property def encoding(self): - return self.connection_pool.connection_kwargs.get('encoding', 'utf-8') + return self.connection_pool.connection_kwargs.get("encoding", "utf-8") def address(self): kw = self.connection_pool.connection_kwargs - return (kw['host'], kw['port']) + return (kw["host"], kw["port"]) def prefixed(self, prefix): - '''Return a new :class:`PrefixedRedis` client. 
- ''' + """Return a new :class:`PrefixedRedis` client.""" return PrefixedRedis(self, prefix) def pipeline(self, transaction=True, shard_hint=None): - return Pipeline( - self, - transaction, - shard_hint) + return Pipeline(self, transaction, shard_hint) class PrefixedRedis(PrefixedRedisMixin, Redis): @@ -60,7 +55,6 @@ class PrefixedRedis(PrefixedRedisMixin, Redis): class Pipeline(BasePipeline, Redis): - def __init__(self, client, transaction, shard_hint): self.client = client self.response_callbacks = client.response_callbacks diff --git a/stdnet/backends/redisb/client/extensions.py b/stdnet/backends/redisb/client/extensions.py index d7abe52..f645483 100644 --- a/stdnet/backends/redisb/client/extensions.py +++ b/stdnet/backends/redisb/client/extensions.py @@ -1,25 +1,26 @@ import os -from hashlib import sha1 from collections import namedtuple -from datetime import datetime from copy import copy +from datetime import datetime +from hashlib import sha1 -from stdnet.utils.structures import OrderedDict -from stdnet.utils import iteritems, format_int from stdnet import odm +from stdnet.utils import format_int, iteritems +from stdnet.utils.structures import OrderedDict try: import redis -except ImportError: # pragma nocover +except ImportError: # pragma nocover from stdnet import ImproperlyConfigured - raise ImproperlyConfigured('Redis backend requires redis python client') + + raise ImproperlyConfigured("Redis backend requires redis python client") from redis.client import BasePipeline RedisError = redis.RedisError p = os.path -DEFAULT_LUA_PATH = p.join(p.dirname(p.dirname(p.abspath(__file__))), 'lua') -redis_connection = namedtuple('redis_connection', 'address db') +DEFAULT_LUA_PATH = p.join(p.dirname(p.dirname(p.abspath(__file__))), "lua") +redis_connection = namedtuple("redis_connection", "address db") ########################################################### # GLOBAL REGISTERED SCRIPT DICTIONARY @@ -33,6 +34,8 @@ def registered_scripts(): def get_script(script): return _scripts.get(script) + + ########################################################### @@ -44,10 +47,10 @@ def script_callback(response, script=None, **options): def read_lua_file(dotted_module, path=None, context=None): - '''Load lua script from the stdnet/lib/lua directory''' + """Load lua script from the stdnet/lib/lua directory""" path = path or DEFAULT_LUA_PATH - bits = dotted_module.split('.') - bits[-1] += '.lua' + bits = dotted_module.split(".") + bits[-1] += ".lua" name = os.path.join(path, *bits) with open(name) as f: data = f.read() @@ -57,25 +60,26 @@ def read_lua_file(dotted_module, path=None, context=None): def parse_info(response): - '''Parse the response of Redis's INFO command into a Python dict. -In doing so, convert byte data into unicode.''' + """Parse the response of Redis's INFO command into a Python dict. 
+    In doing so, convert byte data into unicode."""
     info = {}
-    response = response.decode('utf-8')
+    response = response.decode("utf-8")

     def get_value(value):
-        if ',' and '=' not in value:
+        if "=" not in value:
             return value
         sub_dict = {}
-        for item in value.split(','):
-            k, v = item.split('=')
+        for item in value.split(","):
+            k, v = item.split("=")
             try:
                 sub_dict[k] = int(v)
             except ValueError:
                 sub_dict[k] = v
         return sub_dict
+
     data = info
     for line in response.splitlines():
-        keyvalue = line.split(':')
+        keyvalue = line.split(":")
         if len(keyvalue) == 2:
             key, value = keyvalue
             try:
@@ -95,13 +99,12 @@ def dict_update(original, data):


 class RedisExtensionsMixin(object):
-    '''Extension for Redis clients.
-    '''
-    prefix = ''
+    """Extension for Redis clients."""
+
+    prefix = ""
     RESPONSE_CALLBACKS = dict_update(
         redis.StrictRedis.RESPONSE_CALLBACKS,
-        {'EVALSHA': script_callback,
-         'INFO': parse_info}
+        {"EVALSHA": script_callback, "INFO": parse_info},
     )

     @property
@@ -113,12 +116,11 @@ def is_pipeline(self):
         return False

     def address(self):
-        '''Address of redis server.
-        '''
+        """Address of redis server."""
         raise NotImplementedError

     def execute_script(self, name, keys, *args, **options):
-        '''Execute a registered lua script at ``name``.
+        """Execute a registered lua script at ``name``.

         The script must be implemented via subclassing :class:`RedisScript`.

@@ -128,7 +130,7 @@ def execute_script(self, name, keys, *args, **options):
         :param options: key-value parameters passed to the
             :meth:`RedisScript.callback` method once the script has finished
             execution.
-        '''
+        """
         script = get_script(name)
         if not script:
             raise RedisError('No such script "%s"' % name)
@@ -144,49 +146,56 @@ def execute_script(self, name, keys, *args, **options):
         return script(self, keys, args, options)

     def countpattern(self, pattern):
-        '''delete all keys matching *pattern*.
-        '''
-        return self.execute_script('countpattern', (), pattern)
+        """Count all keys matching *pattern*."""
+        return self.execute_script("countpattern", (), pattern)

     def delpattern(self, pattern):
-        '''delete all keys matching *pattern*.
-        '''
-        return self.execute_script('delpattern', (), pattern)
+        """Delete all keys matching *pattern*."""
+        return self.execute_script("delpattern", (), pattern)

     def zdiffstore(self, dest, keys, withscores=False):
-        '''Compute the difference of multiple sorted.
+        """Compute the difference of multiple sorted sets.

         The difference of sets specified by ``keys`` into a new sorted set
         in ``dest``.
-        '''
+        """
         keys = (dest,) + tuple(keys)
-        wscores = 'withscores' if withscores else ''
-        return self.execute_script('zdiffstore', keys, wscores,
-                                   withscores=withscores)
+        wscores = "withscores" if withscores else ""
+        return self.execute_script("zdiffstore", keys, wscores, withscores=withscores)

     def zpopbyrank(self, name, start, stop=None, withscores=False, desc=False):
-        '''Pop a range by rank.
-        '''
+        """Pop a range by rank."""
         stop = stop if stop is not None else start
-        return self.execute_script('zpop', (name,), 'rank', start,
-                                   stop, int(desc), int(withscores),
-                                   withscores=withscores)
-
-    def zpopbyscore(self, name, start, stop=None, withscores=False,
-                    desc=False):
-        '''Pop a range by score.
- ''' + return self.execute_script( + "zpop", + (name,), + "rank", + start, + stop, + int(desc), + int(withscores), + withscores=withscores, + ) + + def zpopbyscore(self, name, start, stop=None, withscores=False, desc=False): + """Pop a range by score.""" stop = stop if stop is not None else start - return self.execute_script('zpop', (name,), 'score', start, - stop, int(desc), int(withscores), - withscores=withscores) + return self.execute_script( + "zpop", + (name,), + "score", + start, + stop, + int(desc), + int(withscores), + withscores=withscores, + ) class RedisScriptMeta(type): - def __new__(cls, name, bases, attrs): super_new = super(RedisScriptMeta, cls).__new__ - abstract = attrs.pop('abstract', False) + abstract = attrs.pop("abstract", False) new_class = super_new(cls, name, bases, attrs) if not abstract: self = new_class(new_class.script, new_class.__name__) @@ -194,8 +203,8 @@ def __new__(cls, name, bases, attrs): return new_class -class RedisScript(RedisScriptMeta('_RS', (object,), {'abstract': True})): - '''Class which helps the sending and receiving lua scripts. +class RedisScript(RedisScriptMeta("_RS", (object,), {"abstract": True})): + """Class which helps the sending and receiving lua scripts. It uses the ``evalsha`` command. @@ -215,14 +224,15 @@ class RedisScript(RedisScriptMeta('_RS', (object,), {'abstract': True})): it is not set by the user. .. _SHA-1: http://en.wikipedia.org/wiki/SHA-1 - ''' + """ + abstract = True script = None required_scripts = () def __init__(self, script, name): if isinstance(script, (list, tuple)): - script = '\n'.join(script) + script = "\n".join(script) self.__name = name self.script = script rs = set((name,)) @@ -235,63 +245,65 @@ def name(self): @property def sha1(self): - if not hasattr(self, '_sha1'): - self._sha1 = sha1(self.script.encode('utf-8')).hexdigest() + if not hasattr(self, "_sha1"): + self._sha1 = sha1(self.script.encode("utf-8")).hexdigest() return self._sha1 def __repr__(self): return self.name if self.name else self.__class__.__name__ + __str__ = __repr__ def preprocess_args(self, client, args): return args def callback(self, response, **options): - '''Called back after script execution. + """Called back after script execution. This is the only method user should override when writing a new :class:`RedisScript`. By default it returns ``response``. :parameter response: the response obtained from the script execution. :parameter options: Additional options for the callback. 
- ''' + """ return response def __call__(self, client, keys, args, options): args = self.preprocess_args(client, args) numkeys = len(keys) keys_args = tuple(keys) + args - options.update({'script': self, 'redis_client': client}) - return client.execute_command('EVALSHA', self.sha1, numkeys, - *keys_args, **options) + options.update({"script": self, "redis_client": client}) + return client.execute_command( + "EVALSHA", self.sha1, numkeys, *keys_args, **options + ) ############################################################################ ## BATTERY INCLUDED REDIS SCRIPTS ############################################################################ class countpattern(RedisScript): - script = '''\ + script = """\ return # redis.call('keys', ARGV[1]) -''' +""" def preprocess_args(self, client, args): if args and client.prefix: - args = tuple(('%s%s' % (client.prefix, a) for a in args)) + args = tuple(("%s%s" % (client.prefix, a) for a in args)) return args class delpattern(countpattern): - script = '''\ + script = """\ local n = 0 for i,key in ipairs(redis.call('keys', ARGV[1])) do n = n + redis.call('del', key) end return n -''' +""" class zpop(RedisScript): - script = read_lua_file('commands.zpop') + script = read_lua_file("commands.zpop") def callback(self, response, withscores=False, **options): if not response or not withscores: @@ -300,20 +312,19 @@ def callback(self, response, withscores=False, **options): class zdiffstore(RedisScript): - script = read_lua_file('commands.zdiffstore') + script = read_lua_file("commands.zdiffstore") class move2set(RedisScript): - script = (read_lua_file('commands.utils'), - read_lua_file('commands.move2set')) + script = (read_lua_file("commands.utils"), read_lua_file("commands.move2set")) class keyinfo(RedisScript): - script = read_lua_file('commands.keyinfo') + script = read_lua_file("commands.keyinfo") def preprocess_args(self, client, args): if args and client.prefix: - a = ['%s%s' % (client.prefix, args[0])] + a = ["%s%s" % (client.prefix, args[0])] a.extend(args[1:]) args = tuple(a) return args @@ -322,16 +333,19 @@ def callback(self, response, redis_client=None, **options): client = redis_client if client.is_pipeline: client = client.client - encoding = 'utf-8' + encoding = "utf-8" all_keys = [] for key, typ, length, ttl, enc, idle in response: - key = key.decode(encoding)[len(client.prefix):] - key = RedisKey(key=key, client=client, - type=typ.decode(encoding), - length=length, - ttl=ttl if ttl != -1 else False, - encoding=enc.decode(encoding), - idle=idle) + key = key.decode(encoding)[len(client.prefix) :] + key = RedisKey( + key=key, + client=client, + type=typ.decode(encoding), + length=length, + ttl=ttl if ttl != -1 else False, + encoding=enc.decode(encoding), + idle=idle, + ) all_keys.append(key) return all_keys @@ -339,8 +353,8 @@ def callback(self, response, redis_client=None, **options): ############################################################################### ## key info models -class RedisDbQuery(odm.QueryBase): +class RedisDbQuery(odm.QueryBase): @property def client(self): return self.session.router[self.model].backend.client @@ -356,38 +370,47 @@ def items(self): def get(self, db=None): if db is not None: info = yield self.client.info() - data = info.get('db%s' % db) + data = info.get("db%s" % db) if data: yield self.instance(db, data) def keyspace(self, info): n = 0 - keyspace = info['Keyspace'] + keyspace = info["Keyspace"] while keyspace: - info = keyspace.pop('db%s' % n, None) + info = keyspace.pop("db%s" % n, None) if info: 
yield n, info n += 1 def instance(self, db, data): - rdb = self.model(db=int(db), keys=data['keys'], - expires=data['expires']) + rdb = self.model(db=int(db), keys=data["keys"], expires=data["expires"]) rdb.session = self.session return rdb class RedisDbManager(odm.Manager): - '''Handler for gathering information from redis.''' - names = ('Server', 'Memory', 'Persistence', - 'Replication', 'Clients', 'Stats', 'CPU') - converters = {'last_save_time': ('date', None), - 'uptime_in_seconds': ('timedelta', 'uptime'), - 'uptime_in_days': None} + """Handler for gathering information from redis.""" + + names = ( + "Server", + "Memory", + "Persistence", + "Replication", + "Clients", + "Stats", + "CPU", + ) + converters = { + "last_save_time": ("date", None), + "uptime_in_seconds": ("timedelta", "uptime"), + "uptime_in_days": None, + } query_class = RedisDbQuery def __init__(self, *args, **kwargs): - self.formatter = kwargs.pop('formatter', RedisDataFormatter()) + self.formatter = kwargs.pop("formatter", RedisDataFormatter()) self._panels = OrderedDict() super(RedisDbManager, self).__init__(*args, **kwargs) @@ -414,10 +437,9 @@ def makepanel(self, name, info): for k, v in iteritems(info[name]): add = True if k in self.converters or isinstance(v, int): - fdata = self.converters.get(k, ('int', None)) + fdata = self.converters.get(k, ("int", None)) if fdata: - formatter = getattr(self.formatter, - 'format_{0}'.format(fdata[0])) + formatter = getattr(self.formatter, "format_{0}".format(fdata[0])) k = fdata[1] or k v = formatter(v) else: @@ -425,17 +447,17 @@ def makepanel(self, name, info): elif v in boolval: v = nicebool(v) if add: - pa.append({'name': nicename(k), - 'value': v}) + pa.append({"name": nicename(k), "value": v}) return pa def delete(self, instance): - '''Delete an instance''' + """Delete an instance""" flushdb(self.client) if flushdb else self.client.flushdb() class KeyQuery(odm.QueryBase): - '''A lazy query for keys in a redis database.''' + """A lazy query for keys in a redis database.""" + db = None def count(self): @@ -460,16 +482,16 @@ def __getitem__(self, slic): o.slice = slic return o.all() else: - return self[slic:slic+1][0] + return self[slic : slic + 1][0] def __iter__(self): db = self.db c = db.client if self.slice: start, num = self.get_start_num(self.slice) - qs = c.execute_script('keyinfo', (), self.pattern, start, num) + qs = c.execute_script("keyinfo", (), self.pattern, start, num) else: - qs = c.execute_script('keyinfo', (), self.pattern) + qs = c.execute_script("keyinfo", (), self.pattern) for q in qs: q.database = db yield q @@ -486,7 +508,7 @@ def get_start_num(self, slic): if N is None: N = self.count() start += N - return start+1, stop-start + return start + 1, stop - start class RedisKeyManager(odm.Manager): @@ -504,15 +526,15 @@ class RedisDb(odm.StdModel): manager_class = RedisDbManager def __unicode__(self): - return '%s' % self.db + return "%s" % self.db class Meta: - attributes = ('keys', 'expires') + attributes = ("keys", "expires") class RedisKey(odm.StdModel): key = odm.SymbolField(primary_key=True) - db = odm.ForeignKey(RedisDb, related_name='all_keys') + db = odm.ForeignKey(RedisDb, related_name="all_keys") manager_class = RedisKeyManager @@ -520,13 +542,12 @@ def __unicode__(self): return self.key class Meta: - attributes = 'type', 'length', 'ttl', 'encoding', 'idle', 'client' + attributes = "type", "length", "ttl", "encoding", "idle", "client" class RedisDataFormatter(object): - def format_bool(self, val): - return 'yes' if val else 'no' + return "yes" 
if val else "no"

     def format_name(self, name):
         return name
@@ -537,9 +558,9 @@ def format_int(self, val):

     def format_date(self, dte):
         try:
             d = datetime.fromtimestamp(dte)
-            return d.isoformat().split('.')[0]
+            return d.isoformat().split(".")[0]
         except:
-            return ''
+            return ""

     def format_timedelta(self, td):
         return td
diff --git a/stdnet/backends/redisb/client/prefixed.py b/stdnet/backends/redisb/client/prefixed.py
index c75b4e6..2f418b4 100644
--- a/stdnet/backends/redisb/client/prefixed.py
+++ b/stdnet/backends/redisb/client/prefixed.py
@@ -1,23 +1,25 @@
 def raise_error(exception=NotImplementedError):
     raise exception()

-prefix_all = lambda pfix, args: ['%s%s' % (pfix, a) for a in args]
-prefix_alternate = lambda pfix, args: [a if n//2*2 == n else '%s%s' % (pfix, a)
-                                       for n, a in enumerate(args, 1)]
-prefix_not_last = lambda pfix, args: ['%s%s' % (pfix, a)
-                                      for a in args[:-1]] + [args[-1]]
-prefix_not_first = lambda pfix, args: [args[0]] +\
-    ['%s%s' % (pfix, a) for a in args[1:]]
+
+prefix_all = lambda pfix, args: ["%s%s" % (pfix, a) for a in args]
+prefix_alternate = lambda pfix, args: [
+    a if n // 2 * 2 == n else "%s%s" % (pfix, a) for n, a in enumerate(args, 1)
+]
+prefix_not_last = lambda pfix, args: ["%s%s" % (pfix, a) for a in args[:-1]] + [
+    args[-1]
+]
+prefix_not_first = lambda pfix, args: [args[0]] + ["%s%s" % (pfix, a) for a in args[1:]]


 def prefix_zinter(pfix, args):
     dest, numkeys, params = args[0], args[1], args[2:]
-    args = ['%s%s' % (pfix, dest), numkeys]
+    args = ["%s%s" % (pfix, dest), numkeys]
     nk = 0
     for p in params:
         if nk < numkeys:
             nk += 1
-            p = '%s%s' % (pfix, p)
+            p = "%s%s" % (pfix, p)
         args.append(p)
     return args
@@ -27,9 +29,9 @@ def prefix_sort(pfix, args):
     nargs = []
     for a in args:
         if prefix:
-            a = '%s%s' % (pfix, a)
+            a = "%s%s" % (pfix, a)
             prefix = False
-        elif a in ('BY', 'GET', 'STORE'):
+        elif a in ("BY", "GET", "STORE"):
             prefix = True
         nargs.append(a)
     return nargs
@@ -37,67 +39,91 @@ def pop_list_result(pfix, result):
     if result:
-        return (result[0][len(pfix):], result[1])
+        return (result[0][len(pfix) :], result[1])


 def prefix_eval_keys(pfix, args):
     n = args[1]
     if n:
-        keys = tuple(('%s%s' % (pfix, a) for a in args[2:n+2]))
-        return args[:2] + keys + args[n+2:]
+        keys = tuple(("%s%s" % (pfix, a) for a in args[2 : n + 2]))
+        return args[:2] + keys + args[n + 2 :]
     else:
         return args


 class PrefixedRedisMixin(object):
-    '''A class for a prefixed redis client. It append a prefix to all keys.
-
-.. attribute:: prefix
-
-    The prefix to append to all keys
-
-'''
-    EXCLUDE_COMMANDS = frozenset(('BGREWRITEOF', 'BGSAVE', 'CLIENT', 'CONFIG',
-                                  'DBSIZE', 'DEBUG', 'DISCARD', 'ECHO', 'EXEC',
-                                  'INFO', 'LASTSAVE', 'PING',
-                                  'PSUBSCRIBE', 'PUBLISH', 'PUNSUBSCRIBE',
-                                  'QUIT', 'RANDOMKEY', 'SAVE', 'SCRIPT',
-                                  'SELECT', 'SHUTDOWN', 'SLAVEOF',
-                                  'SLOWLOG', 'SUBSCRIBE', 'SYNC',
-                                  'TIME', 'UNSUBSCRIBE', 'UNWATCH'))
+    """A class for a prefixed redis client. It appends a prefix to all keys.
+
+    ..
attribute:: prefix + + The prefix to append to all keys + """ + + EXCLUDE_COMMANDS = frozenset( + ( + "BGREWRITEOF", + "BGSAVE", + "CLIENT", + "CONFIG", + "DBSIZE", + "DEBUG", + "DISCARD", + "ECHO", + "EXEC", + "INFO", + "LASTSAVE", + "PING", + "PSUBSCRIBE", + "PUBLISH", + "PUNSUBSCRIBE", + "QUIT", + "RANDOMKEY", + "SAVE", + "SCRIPT", + "SELECT", + "SHUTDOWN", + "SLAVEOF", + "SLOWLOG", + "SUBSCRIBE", + "SYNC", + "TIME", + "UNSUBSCRIBE", + "UNWATCH", + ) + ) SPECIAL_COMMANDS = { - 'BITOP': prefix_not_first, - 'BLPOP': prefix_not_last, - 'BRPOP': prefix_not_last, - 'BRPOPLPUSH': prefix_not_last, - 'RPOPLPUSH': prefix_all, - 'DEL': prefix_all, - 'EVAL': prefix_eval_keys, - 'EVALSHA': prefix_eval_keys, - 'FLUSHDB': lambda prefix, args: raise_error(), - 'FLUSHALL': lambda prefix, args: raise_error(), - 'MGET': prefix_all, - 'MSET': prefix_alternate, - 'MSETNX': prefix_alternate, - 'MIGRATE': prefix_all, - 'RENAME': prefix_all, - 'RENAMENX': prefix_all, - 'SDIFF': prefix_all, - 'SDIFFSTORE': prefix_all, - 'SINTER': prefix_all, - 'SINTERSTORE': prefix_all, - 'SMOVE': prefix_not_last, - 'SORT': prefix_sort, - 'SUNION': prefix_all, - 'SUNIONSTORE': prefix_all, - 'WATCH': prefix_all, - 'ZINTERSTORE': prefix_zinter, - 'ZUNIONSTORE': prefix_zinter + "BITOP": prefix_not_first, + "BLPOP": prefix_not_last, + "BRPOP": prefix_not_last, + "BRPOPLPUSH": prefix_not_last, + "RPOPLPUSH": prefix_all, + "DEL": prefix_all, + "EVAL": prefix_eval_keys, + "EVALSHA": prefix_eval_keys, + "FLUSHDB": lambda prefix, args: raise_error(), + "FLUSHALL": lambda prefix, args: raise_error(), + "MGET": prefix_all, + "MSET": prefix_alternate, + "MSETNX": prefix_alternate, + "MIGRATE": prefix_all, + "RENAME": prefix_all, + "RENAMENX": prefix_all, + "SDIFF": prefix_all, + "SDIFFSTORE": prefix_all, + "SINTER": prefix_all, + "SINTERSTORE": prefix_all, + "SMOVE": prefix_not_last, + "SORT": prefix_sort, + "SUNION": prefix_all, + "SUNIONSTORE": prefix_all, + "WATCH": prefix_all, + "ZINTERSTORE": prefix_zinter, + "ZUNIONSTORE": prefix_zinter, } RESPONSE_CALLBACKS = { - 'KEYS': lambda pfix, response: [r[len(pfix):] for r in response], - 'BLPOP': pop_list_result, - 'BRPOP': pop_list_result + "KEYS": lambda pfix, response: [r[len(pfix) :] for r in response], + "BLPOP": pop_list_result, + "BRPOP": pop_list_result, } def __init__(self, client, prefix): @@ -133,19 +159,19 @@ def preprocess_command(self, cmnd, *args, **options): def handle(self, prefix, args): if args: args = list(args) - args[0] = '%s%s' % (prefix, args[0]) + args[0] = "%s%s" % (prefix, args[0]) return args def dbsize(self): - return self.client.countpattern('%s*' % self.prefix) + return self.client.countpattern("%s*" % self.prefix) def flushdb(self): - return self.client.delpattern('%s*' % self.prefix) + return self.client.delpattern("%s*" % self.prefix) def _parse_response(self, request, response, command_name, args, options): if command_name in self.RESPONSE_CALLBACKS: if not isinstance(response, Exception): - response = self.RESPONSE_CALLBACKS[command_name](self.prefix, - response) - return self.client._parse_response(request, response, command_name, - args, options) + response = self.RESPONSE_CALLBACKS[command_name](self.prefix, response) + return self.client._parse_response( + request, response, command_name, args, options + ) diff --git a/stdnet/odm/__init__.py b/stdnet/odm/__init__.py index 74c526f..d5f577a 100755 --- a/stdnet/odm/__init__.py +++ b/stdnet/odm/__init__.py @@ -1,12 +1,12 @@ -from .query import * -from .session import * -from .related import * -from 
.fields import * from .base import * +from .fields import * +from .globals import * from .mapper import * from .models import * +from .query import * +from .related import * +from .search import * +from .session import * from .struct import * from .structfields import * -from .globals import * from .utils import * -from .search import * diff --git a/stdnet/odm/base.py b/stdnet/odm/base.py index 720fe76..c0ac2be 100755 --- a/stdnet/odm/base.py +++ b/stdnet/odm/base.py @@ -1,19 +1,24 @@ -'''Defines Metaclasses and Base classes for stdnet Models.''' +"""Defines Metaclasses and Base classes for stdnet Models.""" import sys from copy import copy, deepcopy from inspect import isclass -from stdnet.utils.exceptions import * from stdnet.utils import UnicodeMixin, unique_tuple +from stdnet.utils.exceptions import * from stdnet.utils.structures import OrderedDict -from .globals import hashmodel, JSPLITTER, orderinginfo -from .fields import Field, AutoIdField +from .fields import AutoIdField, Field +from .globals import JSPLITTER, hashmodel, orderinginfo from .related import class_prepared - -__all__ = ['ModelMeta', 'Model', 'ModelBase', 'ModelState', - 'autoincrement', 'ModelType'] +__all__ = [ + "ModelMeta", + "Model", + "ModelBase", + "ModelState", + "autoincrement", + "ModelType", +] def get_fields(bases, attrs): @@ -26,9 +31,14 @@ def get_fields(bases, attrs): fields = sorted(fields, key=lambda x: x[1].creation_counter) # for base in bases: - if hasattr(base, '_meta'): - fields = list((name, deepcopy(field)) for name, field - in base._meta.dfields.items()) + fields + if hasattr(base, "_meta"): + fields = ( + list( + (name, deepcopy(field)) + for name, field in base._meta.dfields.items() + ) + + fields + ) # return OrderedDict(fields) @@ -37,117 +47,128 @@ def make_app_label(new_class, app_label=None): if app_label is None: model_module = sys.modules[new_class.__module__] try: - bits = model_module.__name__.split('.') + bits = model_module.__name__.split(".") app_label = bits.pop() - if app_label == 'models': + if app_label == "models": app_label = bits.pop() except: - app_label = '' + app_label = "" return app_label class ModelMeta(object): - '''A class for storing meta data for a :class:`Model` class. -To override default behaviour you can specify the ``Meta`` class as an inner -class of :class:`Model` in the following way:: + """A class for storing meta data for a :class:`Model` class. + To override default behaviour you can specify the ``Meta`` class as an inner + class of :class:`Model` in the following way:: - from datetime import datetime - from stdnet import odm + from datetime import datetime + from stdnet import odm - class MyModel(odm.StdModel): - timestamp = odm.DateTimeField(default = datetime.now) - ... + class MyModel(odm.StdModel): + timestamp = odm.DateTimeField(default = datetime.now) + ... - class Meta: - ordering = '-timestamp' - name = 'custom' + class Meta: + ordering = '-timestamp' + name = 'custom' -:parameter register: if ``True`` (default), this :class:`ModelMeta` is - registered in the global models hashtable. -:parameter abstract: Check the :attr:`abstract` attribute. -:parameter ordering: Check the :attr:`ordering` attribute. -:parameter app_label: Check the :attr:`app_label` attribute. -:parameter name: Check the :attr:`name` attribute. -:parameter modelkey: Check the :attr:`modelkey` attribute. -:parameter attributes: Check the :attr:`attributes` attribute. 
+ :parameter register: if ``True`` (default), this :class:`ModelMeta` is + registered in the global models hashtable. + :parameter abstract: Check the :attr:`abstract` attribute. + :parameter ordering: Check the :attr:`ordering` attribute. + :parameter app_label: Check the :attr:`app_label` attribute. + :parameter name: Check the :attr:`name` attribute. + :parameter modelkey: Check the :attr:`modelkey` attribute. + :parameter attributes: Check the :attr:`attributes` attribute. -This is the list of attributes and methods available. All attributes, -but the ones mantioned above, are initialized by the object relational -mapper. + This is the list of attributes and methods available. All attributes, + but the ones mantioned above, are initialized by the object relational + mapper. -.. attribute:: abstract + .. attribute:: abstract - If ``True``, This is an abstract Meta class. + If ``True``, This is an abstract Meta class. -.. attribute:: model + .. attribute:: model - :class:`Model` for which this class is the database metadata container. + :class:`Model` for which this class is the database metadata container. -.. attribute:: name + .. attribute:: name - Usually it is the :class:`Model` class name in lower-case, but it - can be customised. + Usually it is the :class:`Model` class name in lower-case, but it + can be customised. -.. attribute:: app_label + .. attribute:: app_label - Unless specified it is the name of the directory or file - (if at top level) containing the :class:`Model` definition. It can be - customised. + Unless specified it is the name of the directory or file + (if at top level) containing the :class:`Model` definition. It can be + customised. -.. attribute:: modelkey + .. attribute:: modelkey - The modelkey which is by default given by ``app_label.name``. + The modelkey which is by default given by ``app_label.name``. -.. attribute:: ordering + .. attribute:: ordering - Optional name of a :class:`Field` in the :attr:`model`. - If provided, model indices will be sorted with respect to the value of the - specified field. It can also be a :class:`autoincrement` instance. - Check the :ref:`sorting ` documentation for more details. + Optional name of a :class:`Field` in the :attr:`model`. + If provided, model indices will be sorted with respect to the value of the + specified field. It can also be a :class:`autoincrement` instance. + Check the :ref:`sorting ` documentation for more details. - Default: ``None``. + Default: ``None``. -.. attribute:: dfields + .. attribute:: dfields - dictionary of :class:`Field` instances. + dictionary of :class:`Field` instances. -.. attribute:: fields + .. attribute:: fields - list of all :class:`Field` instances. + list of all :class:`Field` instances. -.. attribute:: scalarfields + .. attribute:: scalarfields - Ordered list of all :class:`Field` which are not :class:`StructureField`. - The order is the same as in the :class:`Model` definition. The :attr:`pk` - field is not included. + Ordered list of all :class:`Field` which are not :class:`StructureField`. + The order is the same as in the :class:`Model` definition. The :attr:`pk` + field is not included. -.. attribute:: indices + .. attribute:: indices - List of :class:`Field` which are indices (:attr:`Field.index` attribute - set to ``True``). + List of :class:`Field` which are indices (:attr:`Field.index` attribute + set to ``True``). -.. attribute:: pk + .. attribute:: pk - The :class:`Field` representing the primary key. + The :class:`Field` representing the primary key. -.. 
attribute:: related + .. attribute:: related - Dictionary of :class:`related.RelatedManager` for the :attr:`model`. It is - created at runtime by the object data mapper. + Dictionary of :class:`related.RelatedManager` for the :attr:`model`. It is + created at runtime by the object data mapper. -.. attribute:: manytomany + .. attribute:: manytomany - List of :class:`ManyToManyField` names for the :attr:`model`. This - information is useful during registration. + List of :class:`ManyToManyField` names for the :attr:`model`. This + information is useful during registration. -.. attribute:: attributes + .. attribute:: attributes - Additional attributes for :attr:`model`. -''' - def __init__(self, model, fields, app_label=None, modelkey=None, - name=None, register=True, pkname=None, ordering=None, - attributes=None, abstract=False, **kwargs): + Additional attributes for :attr:`model`.""" + + def __init__( + self, + model, + fields, + app_label=None, + modelkey=None, + name=None, + register=True, + pkname=None, + ordering=None, + attributes=None, + abstract=False, + **kwargs + ): self.model = model self.abstract = abstract self.attributes = unique_tuple(attributes or ()) @@ -163,7 +184,7 @@ def __init__(self, model, fields, app_label=None, modelkey=None, self.name = (name or model.__name__).lower() if not modelkey: if self.app_label: - modelkey = '{0}.{1}'.format(self.app_label, self.name) + modelkey = "{0}.{1}".format(self.app_label, self.name) else: modelkey = self.name self.modelkey = modelkey @@ -172,13 +193,12 @@ def __init__(self, model, fields, app_label=None, modelkey=None, # # Check if PK field exists pk = None - pkname = pkname or 'id' + pkname = pkname or "id" for name in fields: field = fields[name] if field.primary_key: if pk is not None: - raise FieldError("Primary key already available %s." - % name) + raise FieldError("Primary key already available %s." % name) pk = field pkname = name if pk is None and not self.abstract: @@ -195,11 +215,11 @@ def __init__(self, model, fields, app_label=None, modelkey=None, @property def type(self): - '''Model type, either ``structure`` or ``object``.''' + """Model type, either ``structure`` or ``object``.""" return self.model._model_type def make_object(self, state=None, backend=None): - '''Create a new instance of :attr:`model` from a *state* tuple.''' + """Create a new instance of :attr:`model` from a *state* tuple.""" model = self.model obj = model.__new__(model) self.load_state(obj, state, backend) @@ -217,8 +237,9 @@ def load_state(self, obj, state=None, backend=None): for field in obj.loadedfields(): value = field.value_from_data(obj, data) setattr(obj, field.attname, field.to_python(value, backend)) - if backend or ('__dbdata__' in data and - data['__dbdata__'][pk.name] == pkvalue): + if backend or ( + "__dbdata__" in data and data["__dbdata__"][pk.name] == pkvalue + ): obj.dbdata[pk.name] = pkvalue def __repr__(self): @@ -228,22 +249,21 @@ def __str__(self): return self.__repr__() def pkname(self): - '''Primary key name. A shortcut for ``self.pk.name``.''' + """Primary key name. A shortcut for ``self.pk.name``.""" return self.pk.name def pk_to_python(self, value, backend): - '''Convert the primary key ``value`` to a valid python representation. - ''' + """Convert the primary key ``value`` to a valid python representation.""" return self.pk.to_python(value, backend) def is_valid(self, instance): - '''Perform validation for *instance* and stores serialized data, -indexes and errors into local cache. 
-Return ``True`` if the instance is ready to be saved to database.''' + """Perform validation for *instance* and stores serialized data, + indexes and errors into local cache. + Return ``True`` if the instance is ready to be saved to database.""" dbdata = instance.dbdata - data = dbdata['cleaned_data'] = {} - errors = dbdata['errors'] = {} - #Loop over scalar fields first + data = dbdata["cleaned_data"] = {} + errors = dbdata["errors"] = {} + # Loop over scalar fields first for field, value in instance.fieldvalue_pairs(): name = field.attname try: @@ -251,9 +271,10 @@ def is_valid(self, instance): except Exception as e: errors[name] = str(e) else: - if (svalue is None or svalue is '') and field.required: - errors[name] = ("Field '{0}' is required for '{1}'." - .format(name, self)) + if (svalue is None or svalue is "") and field.required: + errors[name] = "Field '{0}' is required for '{1}'.".format( + name, self + ) else: if isinstance(svalue, dict): data.update(svalue) @@ -266,7 +287,7 @@ def get_sorting(self, sortby, errorClass=None): if isinstance(sortby, autoincrement): f = self.pk return orderinginfo(sortby, f, desc, self.model, None, True) - elif sortby.startswith('-'): + elif sortby.startswith("-"): desc = True sortby = sortby[1:] if sortby == self.pkname(): @@ -275,8 +296,7 @@ def get_sorting(self, sortby, errorClass=None): else: if sortby in self.dfields: f = self.dfields[sortby] - return orderinginfo(f.attname, f, desc, self.model, - None, False) + return orderinginfo(f.attname, f, desc, self.model, None, False) sortbys = sortby.split(JSPLITTER) s0 = sortbys[0] if len(sortbys) > 1 and s0 in self.dfields: @@ -286,12 +306,14 @@ def get_sorting(self, sortby, errorClass=None): sortby = f.attname return orderinginfo(sortby, f, desc, self.model, nested, False) errorClass = errorClass or ValueError - raise errorClass('"%s" cannot order by attribute "%s". It is not a ' - 'scalar field.' % (self, sortby)) + raise errorClass( + '"%s" cannot order by attribute "%s". It is not a ' + "scalar field." 
% (self, sortby) + ) def backend_fields(self, fields): - '''Return a two elements tuple containing a list -of fields names and a list of field attribute names.''' + """Return a two elements tuple containing a list + of fields names and a list of field attribute names.""" dfields = self.dfields processed = set() names = [] @@ -309,50 +331,51 @@ def backend_fields(self, fields): bname = name.split(JSPLITTER)[0] if bname in dfields: field = dfields[bname] - if field.type in ('json object', 'related object'): + if field.type in ("json object", "related object"): processed.add(name) names.append(name) atts.append(name) return names, atts def as_dict(self): - '''Model metadata in a dictionary''' + """Model metadata in a dictionary""" pk = self.pk id_type = 3 - if pk.type == 'auto': + if pk.type == "auto": id_type = 1 - return {'id_name': pk.name, - 'id_type': id_type, - 'sorted': bool(self.ordering), - 'autoincr': self.ordering and self.ordering.auto, - 'multi_fields': [field.name for field in self.multifields], - 'indices': dict(((idx.attname, idx.unique) - for idx in self.indices))} + return { + "id_name": pk.name, + "id_type": id_type, + "sorted": bool(self.ordering), + "autoincr": self.ordering and self.ordering.auto, + "multi_fields": [field.name for field in self.multifields], + "indices": dict(((idx.attname, idx.unique) for idx in self.indices)), + } class autoincrement(object): - '''An :class:`autoincrement` is used in a :class:`StdModel` Meta -class to specify a model with :ref:`incremental sorting `. + """An :class:`autoincrement` is used in a :class:`StdModel` Meta + class to specify a model with :ref:`incremental sorting `. -.. attribute:: incrby + .. attribute:: incrby - The amount to increment the score by when a duplicate element is saved. + The amount to increment the score by when a duplicate element is saved. - Default: 1. + Default: 1. -For example, the :class:`stdnet.apps.searchengine.Word` model is defined as:: + For example, the :class:`stdnet.apps.searchengine.Word` model is defined as:: - class Word(odm.StdModel): - id = odm.SymbolField(primary_key = True) + class Word(odm.StdModel): + id = odm.SymbolField(primary_key = True) - class Meta: - ordering = -autoincrement() + class Meta: + ordering = -autoincrement() -This means every time we save a new instance of Word, and that instance has -an id already available, the score of that word is incremented by the -:attr:`incrby` attribute. + This means every time we save a new instance of Word, and that instance has + an id already available, the score of that word is incremented by the + :attr:`incrby` attribute. 
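+
+    A minimal sketch of the effect (an illustrative example, not part of the
+    library docs; it assumes the ``Word`` model above is registered and
+    ``session`` is an open :class:`Session`)::
+
+        session.add(Word(id='python'))  # first save of this id
+        session.add(Word(id='python'))  # same id again: score incremented by incrby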
+ """ -''' def __init__(self, incrby=1, desc=False): self.incrby = incrby self._asce = -1 if desc else 1 @@ -367,20 +390,23 @@ def desc(self): return True if self._asce == -1 else False def __repr__(self): - return ('' if self._asce == 1 else '-' - ) + '{0}({1})'.format(self.__class__.__name__, self.incrby) + return ("" if self._asce == 1 else "-") + "{0}({1})".format( + self.__class__.__name__, self.incrby + ) def __str__(self): return self.__repr__() class ModelType(type): - '''Model metaclass''' + """Model metaclass""" + def __new__(cls, name, bases, attrs): - meta = attrs.pop('Meta', None) + meta = attrs.pop("Meta", None) if isclass(meta): - meta = dict(((k, v) for k, v in meta.__dict__.items() - if not k.startswith('__'))) + meta = dict( + ((k, v) for k, v in meta.__dict__.items() if not k.startswith("__")) + ) else: meta = meta or {} cls.extend_meta(meta, attrs) @@ -392,73 +418,75 @@ def __new__(cls, name, bases, attrs): @classmethod def extend_meta(cls, meta, attrs): - for name in ('register', 'abstract', 'attributes'): + for name in ("register", "abstract", "attributes"): if name in attrs: meta[name] = attrs.pop(name) class ModelState(object): - '''The database state of a :class:`Model`.''' + """The database state of a :class:`Model`.""" + def __init__(self, instance, iid=None, action=None): - self._action = action or 'add' + self._action = action or "add" self.deleted = False self.score = 0 dbdata = instance.dbdata pkname = instance._meta.pkname() pkvalue = iid or getattr(instance, pkname, None) if pkvalue and pkname in dbdata: - if self._action == 'add': + if self._action == "add": self._action = instance.get_state_action() elif not pkvalue: - self._action = 'add' - pkvalue = 'new.{0}'.format(id(instance)) + self._action = "add" + pkvalue = "new.{0}".format(id(instance)) self._iid = pkvalue @property def action(self): - '''Action to be performed by the backend server when committing -changes to the instance of :class:`Model` for which this is a state.''' + """Action to be performed by the backend server when committing + changes to the instance of :class:`Model` for which this is a state.""" return self._action @property def persistent(self): - '''``True`` if the instance is persistent in the backend server.''' - return self._action != 'add' + """``True`` if the instance is persistent in the backend server.""" + return self._action != "add" @property def iid(self): - '''Instance primary key or a temporary key if not yet available.''' + """Instance primary key or a temporary key if not yet available.""" return self._iid def __repr__(self): - return '%s%s' % (self.iid, ' deleted' if self.deleted else '') + return "%s%s" % (self.iid, " deleted" if self.deleted else "") + __str__ = __repr__ class Model(UnicodeMixin): - '''This is the base class for both :class:`StdModel` and :class:`Structure` -classes. It implements the :attr:`uuid` attribute which provides the universal -unique identifier for an instance of a model. + """This is the base class for both :class:`StdModel` and :class:`Structure` + classes. It implements the :attr:`uuid` attribute which provides the universal + unique identifier for an instance of a model. -.. attribute:: _meta + .. attribute:: _meta - A class attribute which is an instance of :class:`ModelMeta`, it - containes all the information needed by a - :class:`stdnet.BackendDataServer`. + A class attribute which is an instance of :class:`ModelMeta`, it + containes all the information needed by a + :class:`stdnet.BackendDataServer`. -.. attribute:: session + .. 
attribute:: session + + The :class:`Session` which loaded the instance. Only available, + when the instance has been loaded from a :class:`stdnet.BackendDataServer` + via a :ref:`query operation `.""" - The :class:`Session` which loaded the instance. Only available, - when the instance has been loaded from a :class:`stdnet.BackendDataServer` - via a :ref:`query operation `. -''' _dbdata = None _model_type = None DoesNotExist = ObjectNotFound - '''Exception raised when an instance of a model does not exist.''' + """Exception raised when an instance of a model does not exist.""" DoesNotValidate = ObjectNotValidated - '''Exception raised when an instance of a model does not validate. Usually -raised when trying to save an invalid instance.''' + """Exception raised when an instance of a model does not validate. Usually +raised when trying to save an invalid instance.""" def __eq__(self, other): if other.__class__ == self.__class__: @@ -473,30 +501,30 @@ def __hash__(self): return hash(self.get_uuid(self.get_state().iid)) def get_state(self, **kwargs): - '''Return the current :class:`ModelState` for this :class:`Model`. -If ``kwargs`` parameters are passed a new :class:`ModelState` is created, -otherwise it returns the cached value.''' + """Return the current :class:`ModelState` for this :class:`Model`. + If ``kwargs`` parameters are passed a new :class:`ModelState` is created, + otherwise it returns the cached value.""" dbdata = self.dbdata - if 'state' not in dbdata or kwargs: - dbdata['state'] = ModelState(self, **kwargs) - return dbdata['state'] + if "state" not in dbdata or kwargs: + dbdata["state"] = ModelState(self, **kwargs) + return dbdata["state"] def pkvalue(self): - '''Value of primary key''' + """Value of primary key""" return self._meta.pk.get_value(self) @classmethod def get_uuid(cls, pk): - return '%s.%s' % (cls._meta.hash, pk) + return "%s.%s" % (cls._meta.hash, pk) @property def uuid(self): - '''Universally unique identifier for an instance of a :class:`Model`. - ''' + """Universally unique identifier for an instance of a :class:`Model`.""" pk = self.pkvalue() if not pk: raise self.DoesNotExist( - 'Object not saved. Cannot obtain universally unique id') + "Object not saved. Cannot obtain universally unique id" + ) return self.get_uuid(pk) @property @@ -506,64 +534,65 @@ def dbdata(self): return self._dbdata def __get_session(self): - return self.dbdata.get('session') + return self.dbdata.get("session") def __set_session(self, session): - self.dbdata['session'] = session - session = property(__get_session, __set_session, - doc='The current :class:`Session` for this model.') + self.dbdata["session"] = session + + session = property( + __get_session, __set_session, doc="The current :class:`Session` for this model." + ) @property def backend(self, client=None): - '''The :class:`stdnet.BackendDatServer` for this instance. + """The :class:`stdnet.BackendDatServer` for this instance. It can be ``None``. - ''' + """ session = self.session if session: return session.model(self).backend @property def read_backend(self, client=None): - '''The read :class:`stdnet.BackendDatServer` for this instance. + """The read :class:`stdnet.BackendDatServer` for this instance. It can be ``None``. - ''' + """ session = self.session if session: return session.model(self).read_backend def get_attr_value(self, name): - '''Provided for compatibility with :meth:`StdModel.get_attr_value`. 
-For this class it simply get the attribute at name:: + """Provided for compatibility with :meth:`StdModel.get_attr_value`. + For this class it simply get the attribute at name:: - return getattr(self, name) -''' + return getattr(self, name)""" return getattr(self, name) def get_state_action(self): - return 'update' + return "update" def save(self): - '''Save the model by adding it to the :attr:`session`. If the -:attr:`session` is not available, it raises a :class:`SessionNotAvailable` -exception.''' + """Save the model by adding it to the :attr:`session`. If the + :attr:`session` is not available, it raises a :class:`SessionNotAvailable` + exception.""" return self.session.add(self) def delete(self): - '''Delete the model. If the :attr:`session` is not available, -it raises a :class:`SessionNotAvailable` exception.''' + """Delete the model. If the :attr:`session` is not available, + it raises a :class:`SessionNotAvailable` exception.""" return self.session.delete(self) -ModelBase = ModelType('ModelBase', (Model,), {'abstract': True}) +ModelBase = ModelType("ModelBase", (Model,), {"abstract": True}) def raise_kwargs(model, kwargs): if kwargs: - keys = ', '.join(kwargs) + keys = ", ".join(kwargs) if len(kwargs) > 1: - keys += ' are' + keys += " are" else: - keys += ' is an' + keys += " is an" raise ValueError("%s invalid keyword for %s." % (keys, model._meta)) diff --git a/stdnet/odm/fields.py b/stdnet/odm/fields.py index 1a78227..0977c81 100755 --- a/stdnet/odm/fields.py +++ b/stdnet/odm/fields.py @@ -1,132 +1,143 @@ import logging +from base64 import b64encode from copy import copy from datetime import date, datetime -from base64 import b64encode from stdnet import range_lookups -from stdnet.utils import (DefaultJSONEncoder, DefaultJSONHook, timestamp2date, - date2timestamp, UnicodeMixin, to_string, string_type, - encoders, flat_to_nested, dict_flat_generator) +from stdnet.utils import ( + DefaultJSONEncoder, + DefaultJSONHook, + UnicodeMixin, + date2timestamp, + dict_flat_generator, + encoders, + flat_to_nested, + string_type, + timestamp2date, + to_string, +) from stdnet.utils.exceptions import * from . import related -from .globals import get_model_from_hash, get_hash_from_model, JSPLITTER - -logger = logging.getLogger('stdnet.odm') - -__all__ = ['Field', - 'AutoIdField', - 'AtomField', - 'IntegerField', - 'BooleanField', - 'FloatField', - 'DateField', - 'DateTimeField', - 'SymbolField', - 'CharField', - 'ByteField', - 'ForeignKey', - 'JSONField', - 'PickleObjectField', - 'ModelField', - 'ManyToManyField', - 'CompositeIdField', - 'JSPLITTER'] - -NONE_EMPTY = (None, '') +from .globals import JSPLITTER, get_hash_from_model, get_model_from_hash + +logger = logging.getLogger("stdnet.odm") + +__all__ = [ + "Field", + "AutoIdField", + "AtomField", + "IntegerField", + "BooleanField", + "FloatField", + "DateField", + "DateTimeField", + "SymbolField", + "CharField", + "ByteField", + "ForeignKey", + "JSONField", + "PickleObjectField", + "ModelField", + "ManyToManyField", + "CompositeIdField", + "JSPLITTER", +] + +NONE_EMPTY = (None, "") class Field(UnicodeMixin): - '''This is the base class of all StdNet Fields. -Each field is specified as a :class:`StdModel` class attribute. + """This is the base class of all StdNet Fields. + Each field is specified as a :class:`StdModel` class attribute. -.. attribute:: index + .. attribute:: index - Probably the most important field attribute, it establish if - the field creates indexes for queries. 
-    If you don't need to query the field you should set this value to
-    ``False``, it will save you memory.
+        Probably the most important field attribute, it establishes whether
+        the field creates indexes for queries.
+        If you don't need to query the field you should set this value to
+        ``False``, it will save you memory.

-    .. note:: if ``index`` is set to ``False`` executing queries
-              against the field will
-              throw a :class:`stdnet.QuerySetError` exception.
-              No database queries are allowed for non indexed fields
-              as a design decision (explicit better than implicit).
+        .. note:: if ``index`` is set to ``False`` executing queries
+            against the field will
+            throw a :class:`stdnet.QuerySetError` exception.
+            No database queries are allowed for non indexed fields
+            as a design decision (explicit better than implicit).

-    Default ``True``.
+        Default ``True``.

-.. attribute:: unique
+    .. attribute:: unique

-    If ``True``, the field must be unique throughout the model.
-    In this case :attr:`Field.index` is also ``True``.
-    Enforced at :class:`stdnet.BackendDataServer` level.
+        If ``True``, the field must be unique throughout the model.
+        In this case :attr:`Field.index` is also ``True``.
+        Enforced at :class:`stdnet.BackendDataServer` level.

-    Default ``False``.
+        Default ``False``.

-.. attribute:: primary_key
+    .. attribute:: primary_key

-    If ``True``, this field is the primary key for the model.
-    A primary key field has the following properties:
+        If ``True``, this field is the primary key for the model.
+        A primary key field has the following properties:

-    * :attr:`Field.unique` is also ``True``.
-    * There can be only one in a model.
-    * It's attribute name in the model must be **id**.
-    * If not specified a :class:`AutoIdField` will be added.
+        * :attr:`Field.unique` is also ``True``.
+        * There can be only one in a model.
+        * Its attribute name in the model must be **id**.
+        * If not specified a :class:`AutoIdField` will be added.

-    Default ``False``.
+        Default ``False``.

-.. attribute:: required
+    .. attribute:: required

-    If ``False``, the field is allowed to be null.
+        If ``False``, the field is allowed to be null.

-    Default ``True``.
+        Default ``True``.

-.. attribute:: default
+    .. attribute:: default

-    Default value for this field. It can be a callable attribute with arity 0.
+        Default value for this field. It can be a callable attribute with arity 0.

-    Default ``None``.
+        Default ``None``.

-.. attribute:: name
+    .. attribute:: name

-    Field name, created by the ``odm`` at runtime.
+        Field name, created by the ``odm`` at runtime.

-.. attribute:: attname
+    .. attribute:: attname

-    The attribute name for the field, created by the :meth:`get_attname`
-    method at runtime. For most field, its value is the same as the
-    :attr:`name`. It is the field sorted in the backend database.
+        The attribute name for the field, created by the :meth:`get_attname`
+        method at runtime. For most fields, its value is the same as the
+        :attr:`name`. It is the field stored in the backend database.

-.. attribute:: model
+    .. attribute:: model

-    The :class:`StdModel` holding the field.
-    Created by the ``odm`` at runtime.
+        The :class:`StdModel` holding the field.
+        Created by the ``odm`` at runtime.

-.. attribute:: charset
+    .. attribute:: charset

-    The charset used for encoding decoding text.
+        The charset used for encoding and decoding text.

-.. attribute:: hidden
+    .. attribute:: hidden

-    If ``True`` the field will be hidden from search algorithms.
+        If ``True`` the field will be hidden from search algorithms.
- Default ``False``. + Default ``False``. -.. attribute:: python_type + .. attribute:: python_type - The python ``type`` for the :class:`Field`. + The python ``type`` for the :class:`Field`. -.. attribute:: as_cache + .. attribute:: as_cache - If ``True`` the field contains data which is considered cache and - therefore always reproducible. Field marked as cache, have :attr:`required` - always ``False``. + If ``True`` the field contains data which is considered cache and + therefore always reproducible. Field marked as cache, have :attr:`required` + always ``False``. - This attribute is used by the :class:`StdModel.fieldvalue_pairs` method - which returns a dictionary of field names and values. + This attribute is used by the :class:`StdModel.fieldvalue_pairs` method + which returns a dictionary of field names and values. + + Default ``False``.""" - Default ``False``. -''' _default = None type = None python_type = None @@ -136,8 +147,16 @@ class Field(UnicodeMixin): internal_type = None creation_counter = 0 - def __init__(self, unique=False, primary_key=False, required=True, - index=None, hidden=None, as_cache=False, **extras): + def __init__( + self, + unique=False, + primary_key=False, + required=True, + index=None, + hidden=None, + as_cache=False, + **extras + ): self.primary_key = primary_key index = index if index is not None else self.index if primary_key: @@ -145,7 +164,7 @@ def __init__(self, unique=False, primary_key=False, required=True, self.required = True self.index = True self.as_cache = False - extras['default'] = None + extras["default"] = None else: self.unique = unique self.required = required @@ -155,12 +174,12 @@ def __init__(self, unique=False, primary_key=False, required=True, self.required = False self.unique = False self.index = False - self.charset = extras.pop('charset', self.charset) + self.charset = extras.pop("charset", self.charset) self.hidden = hidden if hidden is not None else self.hidden self.meta = None self.name = None self.model = None - self._default = extras.pop('default', self._default) + self._default = extras.pop("default", self._default) self.encoder = self.get_encoder(extras) self._handle_extras(**extras) self.creation_counter = Field.creation_counter @@ -180,23 +199,30 @@ def get_encoder(self, params): def error_extras(self, extras): keys = list(extras) if keys: - raise TypeError("__init__() got an unexepcted keyword\ - argument '{0}'".format(keys[0])) + raise TypeError( + "__init__() got an unexepcted keyword\ + argument '{0}'".format( + keys[0] + ) + ) def __unicode__(self): - return to_string('%s.%s' % (self.meta, self.name)) + return to_string("%s.%s" % (self.meta, self.name)) def value_from_data(self, instance, data): return data.pop(self.attname, None) def register_with_model(self, name, model): - '''Called during the creation of a the :class:`StdModel` -class when :class:`Metaclass` is initialised. It fills -:attr:`Field.name` and :attr:`Field.model`. This is an internal -function users should never call.''' + """Called during the creation of a the :class:`StdModel` + class when :class:`Metaclass` is initialised. It fills + :attr:`Field.name` and :attr:`Field.model`. This is an internal + function users should never call.""" if self.name: - raise FieldError('Field %s is already registered\ - with a model' % self) + raise FieldError( + "Field %s is already registered\ + with a model" + % self + ) self.name = name self.attname = self.get_attname() self.model = model @@ -210,35 +236,35 @@ class when :class:`Metaclass` is initialised. 
It fills
             model._meta.pk = self

     def add_to_fields(self):
-        '''Add this :class:`Field` to the fields of :attr:`model`.'''
+        """Add this :class:`Field` to the fields of :attr:`model`."""
         meta = self.model._meta
         meta.scalarfields.append(self)
         if self.index:
             meta.indices.append(self)

     def get_attname(self):
-        '''Generate the :attr:`attname` at runtime'''
+        """Generate the :attr:`attname` at runtime"""
         return self.name

     def get_cache_name(self):
-        '''name for the private attribute which contains a cached value
-for this field. Used only by realted fields.'''
-        return '_%s_cache' % self.name
+        """Name for the private attribute which contains a cached value
+        for this field. Used only by related fields."""
+        return "_%s_cache" % self.name

     def id(self, obj):
-        '''Field id for object *obj*, if applicable. Default is ``None``.'''
+        """Field id for object *obj*, if applicable. Default is ``None``."""
         return None

     def get_default(self):
         "Returns the default value for this field."
         default = self._default
-        if hasattr(default, '__call__'):
+        if hasattr(default, "__call__"):
             return default()
         else:
             return default

     def __deepcopy__(self, memodict):
-        '''Nothing to deepcopy here'''
+        """Nothing to deepcopy here"""
         field = copy(self)
         field.name = None
         field.model = None
@@ -249,18 +275,18 @@ def filter(self, session, name, value):
         pass

     def get_sorting(self, name, errorClass):
-        raise errorClass('Cannot use nested sorting on field {0}'.format(self))
+        raise errorClass("Cannot use nested sorting on field {0}".format(self))

     def get_lookup(self, remaining, errorClass=ValueError):
-        '''called by the :class:`Query` method when it needs to build
-lookup on fields with additional nested fields. This is the case of
-:class:`ForeignKey` and :class:`JSONField`.
+        """Called by the :class:`Query` method when it needs to build
+        lookup on fields with additional nested fields. This is the case of
+        :class:`ForeignKey` and :class:`JSONField`.

-:param remaining: the :ref:`double underscored` fields if this :class:`Field`
-:param errorClass: Optional exception class to use if the *remaining* field
-    is not valid.'''
+        :param remaining: the :ref:`double underscored` fields of this :class:`Field`
+        :param errorClass: Optional exception class to use if the *remaining* field
+            is not valid."""
         if remaining:
-            raise errorClass('Cannot use nested lookup on field %s' % self)
+            raise errorClass("Cannot use nested lookup on field %s" % self)
         return (self.attname, None)

     def todelete(self):
@@ -270,57 +296,56 @@ def todelete(self):
     ########################################################################
     ## FIELD VALUES
     ########################################################################
     def get_value(self, instance, *bits):
-        '''Retrieve the value :class:`Field` from a :class:`StdModel`
-``instance``.
-
-:param instance: The :class:`StdModel` ``instance`` invoking this function.
-:param bits: Additional information for nested fields which derives from
-    the :ref:`double underscore ` notation.
-:return: the value of this :class:`Field` in the ``instance``. can raise
-    :class:`AttributeError`.
-
-This method is used by the :meth:`StdModel.get_attr_value` method when
-retrieving values form a :class:`StdModel` instance.
-'''
+        """Retrieve the value of this :class:`Field` from a :class:`StdModel`
+        ``instance``.
+
+        :param instance: The :class:`StdModel` ``instance`` invoking this function.
+        :param bits: Additional information for nested fields which derives from
+            the :ref:`double underscore ` notation.
+        :return: the value of this :class:`Field` in the ``instance``; it can raise
+            :class:`AttributeError`.
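+
+        A minimal sketch (an illustrative example, assuming the ``File``
+        model with a ``folder`` :class:`ForeignKey`, as defined further
+        down)::
+
+            field = File._meta.dfields['folder']
+            field.get_value(instance)          # the related Folder instance
+            field.get_value(instance, 'name')  # same as instance.folder.name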
+
+        This method is used by the :meth:`StdModel.get_attr_value` method when
+        retrieving values from a :class:`StdModel` instance."""
         if bits:
             raise AttributeError
         else:
             return getattr(instance, self.attname)

     def set_value(self, instance, value):
-        '''Set the ``value`` for this :class:`Field` in a ``instance``
-of a :class:`StdModel`.'''
+        """Set the ``value`` for this :class:`Field` in an ``instance``
+        of a :class:`StdModel`."""
         setattr(instance, self.attname, self.to_python(value))

     def set_get_value(self, instance, value):
-        '''Set the ``value`` for this :class:`Field` in a ``instance``
-of a :class:`StdModel` and return the database representation. This method
-is invoked by the validation lagorithm before saving instances.'''
+        """Set the ``value`` for this :class:`Field` in an ``instance``
+        of a :class:`StdModel` and return the database representation. This method
+        is invoked by the validation algorithm before saving instances."""
         value = self.to_python(value)
         setattr(instance, self.attname, value)
         return value

     def to_python(self, value, backend=None):
         """Converts the input value into the expected Python
-data type, raising :class:`stdnet.FieldValueError` if the data
-can't be converted.
-Returns the converted value. Subclasses should override this."""
+        data type, raising :class:`stdnet.FieldValueError` if the data
+        can't be converted.
+        Returns the converted value. Subclasses should override this."""
         return value

     def serialise(self, value, lookup=None):
-        '''Convert ``value`` to a valid database representation for this field.
+        """Convert ``value`` to a valid database representation for this field.

-This method is invoked by the Query algorithm.'''
+        This method is invoked by the Query algorithm."""
         return self.to_python(value)

     def json_serialise(self, value):
-        '''Return a representation of this field which is compatible with
-    JSON.'''
+        """Return a representation of this field which is compatible with
+        JSON."""
         return None

     def scorefun(self, value):
-        '''Function which evaluate a score from the field value. Used by
-the ordering alorithm'''
+        """Function which evaluates a score from the field value. Used by
+        the ordering algorithm."""
         return self.to_python(value)

     ########################################################################
@@ -328,46 +353,48 @@ def scorefun(self, value):
     ########################################################################

     def _handle_extras(self, **extras):
-        '''Callback to hadle extra arguments during initialization.'''
+        """Callback to handle extra arguments during initialization."""
         self.error_extras(extras)


 class AtomField(Field):
-    '''The base class for fields containing ``atoms``.
-An atom is an irreducible
-value with a specific data type. it can be of four different types:
-
-* boolean
-* integer
-* date
-* datetime
-* floating point
-* symbol
-'''
+    """The base class for fields containing ``atoms``.
+    An atom is an irreducible
+    value with a specific data type. It can be of six different types:
+
+    * boolean
+    * integer
+    * date
+    * datetime
+    * floating point
+    * symbol"""
+
     def to_python(self, value, backend=None):
-        if hasattr(value, '_meta'):
+        if hasattr(value, "_meta"):
             return value.pkvalue()
         else:
             return value
+
     json_serialise = to_python


 class AutoIdField(AtomField):
-    '''An :class:`AtomField` for primary keys which are automatically
-generated by the backend server.
-You usually won't need to use this directly; a ``primary_key`` field -of this type, named ``id``, will automatically be added to your model -if you don't specify otherwise. -Check the :ref:`primary key tutorial ` for -further information on primary keys.''' - type = 'auto' + """An :class:`AtomField` for primary keys which are automatically + generated by the backend server. + You usually won't need to use this directly; a ``primary_key`` field + of this type, named ``id``, will automatically be added to your model + if you don't specify otherwise. + Check the :ref:`primary key tutorial ` for + further information on primary keys.""" + + type = "auto" def __init__(self, *args, **kwargs): - kwargs['primary_key'] = True + kwargs["primary_key"] = True super(AutoIdField, self).__init__(*args, **kwargs) def to_python(self, value, backend=None): - if hasattr(value, '_meta'): + if hasattr(value, "_meta"): return value.pkvalue() elif backend: return backend.auto_id_to_python(value) @@ -376,9 +403,10 @@ def to_python(self, value, backend=None): class BooleanField(AtomField): - '''A boolean :class:`AtomField`''' - type = 'bool' - internal_type = 'numeric' + """A boolean :class:`AtomField`""" + + type = "bool" + internal_type = "numeric" python_type = bool _default = False @@ -398,13 +426,15 @@ def set_get_value(self, instance, value): def serialise(self, value, lookup=None): return 1 if value else 0 + scorefun = serialise class IntegerField(AtomField): - '''An integer :class:`AtomField`.''' - type = 'integer' - internal_type = 'numeric' + """An integer :class:`AtomField`.""" + + type = "integer" + internal_type = "numeric" python_type = int def to_python(self, value, backend=None): @@ -415,20 +445,22 @@ def to_python(self, value, backend=None): class FloatField(IntegerField): - '''An floating point :class:`AtomField`. By default -its :attr:`Field.index` is set to ``False``. - ''' - type = 'float' - internal_type = 'numeric' + """An floating point :class:`AtomField`. By default + its :attr:`Field.index` is set to ``False``. + """ + + type = "float" + internal_type = "numeric" index = False python_type = float class DateField(AtomField): - '''An :class:`AtomField` represented in Python by -a :class:`datetime.date` instance.''' - type = 'date' - internal_type = 'numeric' + """An :class:`AtomField` represented in Python by + a :class:`datetime.date` instance.""" + + type = "date" + internal_type = "numeric" python_type = date _default = None @@ -453,16 +485,18 @@ def serialise(self, value, lookup=None): if isinstance(value, date): value = date2timestamp(value) else: - raise FieldValueError('Field %s is not a valid date' % self) + raise FieldValueError("Field %s is not a valid date" % self) return value + scorefun = serialise json_serialise = serialise class DateTimeField(DateField): - '''A date :class:`AtomField` represented in Python by -a :class:`datetime.datetime` instance.''' - type = 'datetime' + """A date :class:`AtomField` represented in Python by + a :class:`datetime.datetime` instance.""" + + type = "datetime" python_type = datetime index = False @@ -479,15 +513,16 @@ def to_python(self, value, backend=None): class SymbolField(AtomField): - '''An :class:`AtomField` which contains a ``symbol``. -A symbol holds a unicode string as a single unit. -A symbol is irreducible, and are often used to hold names, codes -or other entities. They are indexes by default.''' - type = 'text' + """An :class:`AtomField` which contains a ``symbol``. + A symbol holds a unicode string as a single unit. 
+    A symbol is irreducible, and is often used to hold names, codes
+    or other entities. They are indexed by default."""
+
+    type = "text"
     python_type = string_type
-    internal_type = 'text'
-    charset = 'utf-8'
-    _default = ''
+    internal_type = "text"
+    charset = "utf-8"
+    _default = ""

     def get_encoder(self, params):
         return encoders.Default(self.charset)
@@ -499,31 +534,33 @@ def to_python(self, value, backend=None):
         return self.get_default()

     def scorefun(self, value):
-        raise FieldValueError('Could not obtain score')
+        raise FieldValueError("Could not obtain score")


 class CharField(SymbolField):
-    '''A text :class:`SymbolField` which is never an index.
-It contains unicode and by default and :attr:`Field.required`
-is set to ``False``.'''
+    """A text :class:`SymbolField` which is never an index.
+    It contains unicode text and, by default, :attr:`Field.required`
+    is set to ``False``."""
+
     def __init__(self, *args, **kwargs):
-        kwargs['index'] = False
-        kwargs['unique'] = False
-        kwargs['primary_key'] = False
-        self.max_length = kwargs.pop('max_length', None)  # not used for now
-        required = kwargs.get('required', None)
+        kwargs["index"] = False
+        kwargs["unique"] = False
+        kwargs["primary_key"] = False
+        self.max_length = kwargs.pop("max_length", None)  # not used for now
+        required = kwargs.get("required", None)
         if required is None:
-            kwargs['required'] = False
+            kwargs["required"] = False
         super(CharField, self).__init__(*args, **kwargs)


 class ByteField(CharField):
-    '''A :class:`CharField` which contains binary data.
-In python this is converted to `bytes`.'''
-    type = 'bytes'
-    internal_type = 'bytes'
+    """A :class:`CharField` which contains binary data.
+    In python this is converted to `bytes`."""
+
+    type = "bytes"
+    internal_type = "bytes"
     python_type = bytes
-    _default = b''
+    _default = b""

     def json_serialise(self, value):
         if value is not None:
@@ -534,16 +571,16 @@ def get_encoder(self, params):


 class PickleObjectField(ByteField):
-    '''A field which implements automatic conversion to and form a picklable
-python object.
-This field is python specific and therefore not of much use
-if accessed from external programs. Consider the :class:`ForeignKey`
-or :class:`JSONField` fields as more general alternatives.
-
-.. note:: The best way to use this field is when its :class:`Field.as_cache`
-    attribute is ``True``.
-'''
-    type = 'object'
+    """A field which implements automatic conversion to and from a picklable
+    python object.
+    This field is python specific and therefore not of much use
+    if accessed from external programs. Consider the :class:`ForeignKey`
+    or :class:`JSONField` fields as more general alternatives.
+
+    .. note:: The best way to use this field is when its :class:`Field.as_cache`
+        attribute is ``True``."""
+
+    type = "object"
     _default = None

     def set_get_value(self, instance, value):
@@ -562,44 +599,43 @@ def get_encoder(self, params):


 class ForeignKey(Field):
-    '''A field defining a :ref:`one-to-many ` objects
-    relationship.
-Requires a positional argument: the class to which the model is related.
+    """A field defining a :ref:`one-to-many ` objects
+    relationship.
+    Requires a positional argument: the class to which the model is related.
+    For example::

-    class Folder(odm.StdModel):
-        name = odm.SymobolField()
+        class Folder(odm.StdModel):
+            name = odm.SymbolField()

-    class File(odm.StdModel):
-        folder = odm.ForeignKey(Folder, related_name = 'files')
+        class File(odm.StdModel):
+            folder = odm.ForeignKey(Folder, related_name='files')

-To create a recursive relationship, an object that has a many-to-one
-relationship with itself use::
+    To create a recursive relationship, that is, an object with a
+    many-to-one relationship to itself, use::

-    odm.ForeignKey('self')
+        odm.ForeignKey('self')

-Behind the scenes, stdnet appends "_id" to the field name to create
-its field name in the back-end data-server. In the above example,
-the database field for the ``File`` model will have a ``folder_id`` field.
+    Behind the scenes, stdnet appends "_id" to the field name to create
+    its field name in the back-end data-server. In the above example,
+    the database field for the ``File`` model will have a ``folder_id`` field.

-.. attribute:: related_name
+    .. attribute:: related_name

-    Optional name to use for the relation from the related object
-    back to ``self``.
-'''
-    type = 'related object'
-    internal_type = 'numeric'
+        Optional name to use for the relation from the related object
+        back to ``self``."""
+
+    type = "related object"
+    internal_type = "numeric"
     python_type = int
     proxy_class = related.LazyForeignKey
     related_manager_class = related.One2ManyRelatedManager

-    def __init__(self, model, related_name=None, related_manager_class=None,
-                 **kwargs):
+    def __init__(self, model, related_name=None, related_manager_class=None, **kwargs):
         if related_manager_class:
             self.related_manager_class = related_manager_class
         super(ForeignKey, self).__init__(**kwargs)
         if not model:
-            raise FieldError('Model not specified')
+            raise FieldError("Model not specified")
         self.relmodel = model
         self.related_name = related_name

@@ -612,15 +648,17 @@ def _set_relmodel(self, relmodel):
         self.relmodel = relmodel
         meta = self.relmodel._meta
         if not self.related_name:
-            self.related_name = '%s_%s_set' % (self.model._meta.name,
-                                               self.name)
-        if (self.related_name not in meta.related and
-                self.related_name not in meta.dfields):
+            self.related_name = "%s_%s_set" % (self.model._meta.name, self.name)
+        if (
+            self.related_name not in meta.related
+            and self.related_name not in meta.dfields
+        ):
             self._register_with_related_model()
         else:
-            raise FieldError('Duplicated related name "{0} in model "{1}" '
-                             'and field {2}'.format(self.related_name,
-                                                    meta, self))
+            raise FieldError(
+                'Duplicated related name "{0}" in model "{1}" '
+                "and field {2}".format(self.related_name, meta, self)
+            )

     def _register_with_related_model(self):
         manager = self.related_manager_class(self)
@@ -629,12 +667,11 @@ def _register_with_related_model(self):
         self.relmodel_manager = manager

     def get_attname(self):
-        return '%s_id' % self.name
+        return "%s_id" % self.name

     def get_value(self, instance, *bits):
         related = getattr(instance, self.name)
-        return related.get_attr_value(JSPLITTER.join(bits)
-                                      ) if bits else related
+        return related.get_attr_value(JSPLITTER.join(bits)) if bits else related

     def set_value(self, instance, value):
         if isinstance(value, self.relmodel):
@@ -652,7 +689,7 @@ def scorefun(self, value):
         if isinstance(value, self.relmodel):
             return value.scorefun()
         else:
-            raise FieldValueError('cannot evaluate score of {0}'.format(value))
+            raise FieldValueError("cannot evaluate score of {0}".format(value))

     def to_python(self, value, backend=None):
         if isinstance(value, self.relmodel):
@@
-661,6 +698,7 @@ def to_python(self, value, backend=None): return self.relmodel._meta.pk.to_python(value, backend) else: return value + json_serialise = to_python def filter(self, session, name, value): @@ -677,7 +715,7 @@ def get_lookup(self, name, errorClass=ValueError): fname = bits.pop(0) field = self.relmodel._meta.dfields.get(fname) meta = self.relmodel._meta - if field: # it is a field + if field: # it is a field nested = [(self.attname, meta)] remaining = JSPLITTER.join(bits) name, _nested = field.get_lookup(remaining, errorClass) @@ -685,79 +723,80 @@ def get_lookup(self, name, errorClass=ValueError): nested.extend(_nested) return (name, nested) else: - raise errorClass('%s not a valid field for %s' % (fname, meta)) + raise errorClass("%s not a valid field for %s" % (fname, meta)) else: return super(ForeignKey, self).get_lookup(name, errorClass) class JSONField(CharField): - '''A JSON field which implements automatic conversion to -and from an object and a JSON string. It is the responsability of the -user making sure the object is JSON serializable. + """A JSON field which implements automatic conversion to + and from an object and a JSON string. It is the responsability of the + user making sure the object is JSON serializable. + + There are few extra parameters which can be used to customize the + behaviour and how the field is stored in the back-end server. -There are few extra parameters which can be used to customize the -behaviour and how the field is stored in the back-end server. + :parameter encoder_class: The JSON class used for encoding. -:parameter encoder_class: The JSON class used for encoding. + Default: :class:`stdnet.utils.jsontools.JSONDateDecimalEncoder`. - Default: :class:`stdnet.utils.jsontools.JSONDateDecimalEncoder`. + :parameter decoder_hook: A JSON decoder function. -:parameter decoder_hook: A JSON decoder function. + Default: :class:`stdnet.utils.jsontools.date_decimal_hook`. - Default: :class:`stdnet.utils.jsontools.date_decimal_hook`. + :parameter as_string: Set the :attr:`as_string` attribute. -:parameter as_string: Set the :attr:`as_string` attribute. + Default ``True``. - Default ``True``. + .. attribute:: as_string -.. attribute:: as_string + A boolean indicating if data should be serialized + into a single JSON string or it should be used to create several + fields prefixed with the field name and the double underscore ``__``. - A boolean indicating if data should be serialized - into a single JSON string or it should be used to create several - fields prefixed with the field name and the double underscore ``__``. + Default ``True``. - Default ``True``. + Effectively, a :class:`JSONField` with ``as_string`` attribute set to + ``False`` is a multifield, in the sense that it generates several + field-value pairs. For example, lets consider the following:: - Effectively, a :class:`JSONField` with ``as_string`` attribute set to - ``False`` is a multifield, in the sense that it generates several - field-value pairs. For example, lets consider the following:: + class MyModel(odm.StdModel): + name = odm.SymbolField() + data = odm.JSONField(as_string=False) - class MyModel(odm.StdModel): - name = odm.SymbolField() - data = odm.JSONField(as_string=False) + And:: - And:: + >>> m = MyModel(name='bla', + ... data={'pv': {'': 0.5, 'mean': 1, 'std': 3.5}}) + >>> m.cleaned_data + {'name': 'bla', 'data__pv': 0.5, 'data__pv__mean': '1', + 'data__pv__std': '3.5', 'data': '""'} + >>> - >>> m = MyModel(name='bla', - ... 
data={'pv': {'': 0.5, 'mean': 1, 'std': 3.5}}) - >>> m.cleaned_data - {'name': 'bla', 'data__pv': 0.5, 'data__pv__mean': '1', - 'data__pv__std': '3.5', 'data': '""'} - >>> + The reason for setting ``as_string`` to ``False`` is to allow + the :class:`JSONField` to define several fields at runtime, + without introducing new :class:`Field` in your model class. + These fields behave exactly like standard fields and therefore you + can, for example, sort queries with respect to them:: - The reason for setting ``as_string`` to ``False`` is to allow - the :class:`JSONField` to define several fields at runtime, - without introducing new :class:`Field` in your model class. - These fields behave exactly like standard fields and therefore you - can, for example, sort queries with respect to them:: + >>> MyModel.objects.query().sort_by('data__pv__std') + >>> MyModel.objects.query().sort_by('-data__pv') - >>> MyModel.objects.query().sort_by('data__pv__std') - >>> MyModel.objects.query().sort_by('-data__pv') + which can be rather useful feature.""" - which can be rather useful feature. -''' - type = 'json object' - internal_type = 'serialized' + type = "json object" + internal_type = "serialized" _default = {} def get_encoder(self, params): - self.as_string = params.pop('as_string', True) + self.as_string = params.pop("as_string", True) if not self.as_string and not isinstance(self._default, dict): self._default = {} return encoders.Json( charset=self.charset, - json_encoder=params.pop('encoder_class', DefaultJSONEncoder), - object_hook=params.pop('decoder_hook', DefaultJSONHook)) + json_encoder=params.pop("encoder_class", DefaultJSONEncoder), + object_hook=params.pop("decoder_hook", DefaultJSONHook), + ) def to_python(self, value, backend=None): if value is None: @@ -777,10 +816,14 @@ def set_get_value(self, instance, value): return self.serialise(value) else: # unwind as a dictionary - value = dict(dict_flat_generator(value, - attname=self.attname, - dumps=self.serialise, - error=FieldValueError)) + value = dict( + dict_flat_generator( + value, + attname=self.attname, + dumps=self.serialise, + error=FieldValueError, + ) + ) # If the dictionary is empty we modify so that # an update is possible. if not value: @@ -789,7 +832,7 @@ def set_get_value(self, instance, value): # TODO Better implementation of this is a ack! # set the root value to an empty string to distinguish # from None. - value[self.attname] = self.serialise('') + value[self.attname] = self.serialise("") return value def serialise(self, value, lookup=None): @@ -801,9 +844,9 @@ def value_from_data(self, instance, data): if self.as_string: return data.pop(self.attname, None) else: - return flat_to_nested(data, instance=instance, - attname=self.attname, - loads=self.encoder.loads) + return flat_to_nested( + data, instance=instance, attname=self.attname, loads=self.encoder.loads + ) def get_sorting(self, name, errorClass): pass @@ -821,23 +864,24 @@ def get_value(self, instance, *bits): try: for bit in bits: value = value[bit] - if isinstance(value, dict) and '' in value: - value = value[''] + if isinstance(value, dict) and "" in value: + value = value[""] return value except Exception: raise AttributeError class ModelField(SymbolField): - '''A filed which can be used to store the model classes (not only -:class:`StdModel` models). 
If a class has a attribute ``_meta``
-with a unique hash attribute ``hash`` and it is
-registered in the model hash table, it can be used.'''
-    type = 'model'
-    internal_type = 'text'
+    """A field which can be used to store the model classes (not only
+    :class:`StdModel` models). If a class has an attribute ``_meta``
+    with a unique hash attribute ``hash`` and it is
+    registered in the model hash table, it can be used."""
+
+    type = "model"
+    internal_type = "text"

     def to_python(self, value, backend=None):
-        if value and not hasattr(value, '_meta'):
+        if value and not hasattr(value, "_meta"):
             value = self.encoder.loads(value)
             return get_model_from_hash(value)
         else:
@@ -851,6 +895,7 @@ def serialise(self, value, lookup=None):
             return get_hash_from_model(value)
         else:
             return v
+
     json_serialise = serialise

     def set_get_value(self, instance, value):
@@ -862,45 +907,45 @@ def set_get_value(self, instance, value):


 class ManyToManyField(Field):
-    '''A :ref:`many-to-many ` relationship.
-Like :class:`ForeignKey`, it requires a positional argument, the class
-to which the model is related and it accepts **related_name** as extra
-argument.
+    """A :ref:`many-to-many ` relationship.
+    Like :class:`ForeignKey`, it requires a positional argument, the class
+    to which the model is related, and it accepts **related_name** as extra
+    argument.

-.. attribute:: related_name
+    .. attribute:: related_name

-    Optional name to use for the relation from the related object
-    back to ``self``. For example::
+        Optional name to use for the relation from the related object
+        back to ``self``. For example::

-        class Group(odm.StdModel):
-            name = odm.SymbolField(unique=True)
+            class Group(odm.StdModel):
+                name = odm.SymbolField(unique=True)

-        class User(odm.StdModel):
-            name = odm.SymbolField(unique=True)
-            groups = odm.ManyToManyField(Group, related_name='users')
+            class User(odm.StdModel):
+                name = odm.SymbolField(unique=True)
+                groups = odm.ManyToManyField(Group, related_name='users')

-        To use it::
+        To use it::

-            >>> g = Group(name='developers').save()
-            >>> g.users.add(User(name='john').save())
-            >>> u.users.add(User(name='mark').save())
+            >>> g = Group(name='developers').save()
+            >>> g.users.add(User(name='john').save())
+            >>> g.users.add(User(name='mark').save())

-        and to remove::
+        and to remove::

-            >>> u.following.remove(User.objects.get(name='john'))
+            >>> g.users.remove(User.objects.get(name='john'))

-.. attribute:: through
+    .. attribute:: through

-    An optional :class:`StdModel` to use for creating the many-to-many
-    relationship can be passed to the constructor, via the **through** keyword.
-    If such a model is not passed, under the hood, a :class:`ManyToManyField`
-    creates a new *model* with name constructed from the field name
-    and the model holding the field. In the example above it would be
-    *group_user*.
-    This model contains two :class:`ForeignKeys`, one to model holding the
-    :class:`ManyToManyField` and the other to the *related_model*.
+        An optional :class:`StdModel` to use for creating the many-to-many
+        relationship can be passed to the constructor, via the **through** keyword.
+        If such a model is not passed, under the hood, a :class:`ManyToManyField`
+        creates a new *model* with name constructed from the field name
+        and the model holding the field. In the example above it would be
+        *group_user*.
+        This model contains two :class:`ForeignKeys`, one to the model holding the
+        :class:`ManyToManyField` and the other to the *related_model*.
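+
+        A minimal sketch of an explicit ``through`` model (``Membership`` and
+        ``Account`` are illustrative names, not part of the library)::
+
+            class Membership(odm.StdModel):
+                user = odm.ForeignKey(User)
+                group = odm.ForeignKey(Group)
+
+            class Account(odm.StdModel):
+                name = odm.SymbolField(unique=True)
+                groups = odm.ManyToManyField(Group, through=Membership)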
+    """
-'''
 
    def __init__(self, model, through=None, related_name=None, **kwargs):
        self.through = through
        self.relmodel = model
@@ -915,7 +960,7 @@ def register_with_model(self, name, model):
 
    def _set_relmodel(self, relmodel):
        self.relmodel = relmodel
        if not self.related_name:
-            self.related_name = '%s_set' % self.model._meta.name
+            self.related_name = "%s_set" % self.model._meta.name
        related.Many2ManyThroughModel(self)
 
    def get_attname(self):
@@ -925,36 +970,37 @@ def todelete(self):
        return False
 
    def add_to_fields(self):
-        #A many to many field is a dummy field. All it does it provides a proxy
-        #for the through model. Remove it from the fields dictionary
-        #and addit to the list of many_to_many
+        # A many to many field is a dummy field. All it does is provide a proxy
+        # for the through model. Remove it from the fields dictionary
+        # and add it to the list of many_to_many.
        self.meta.dfields.pop(self.name)
        self.meta.manytomany.append(self.name)


 class CompositeIdField(AutoIdField):
-    '''This field can be used when an instance of a model is uniquely
-identified by a combination of two or more :class:`Field` in the model
-itself. It requires a number of positional arguments greater or equal 2.
-These arguments must be fields names in the model where the
-:class:`CompositeIdField` is defined.
+    """This field can be used when an instance of a model is uniquely
+    identified by a combination of two or more :class:`Field` in the model
+    itself. It requires a number of positional arguments greater than or
+    equal to 2. These arguments must be field names in the model where the
+    :class:`CompositeIdField` is defined.
+
+    .. attribute:: fields
 
-.. attribute:: fields
+        list of :class:`Field` names which are used to uniquely identify a
+        model instance
 
-    list of :class:`Field` names which are used to uniquely identify a
-    model instance
+    Check the :ref:`composite id tutorial ` for more
+    information and tips on how to use it."""
-
-Check the :ref:`composite id tutorial ` for more
-information and tips on how to use it.
-'''
-    type = 'composite'
+
+    type = "composite"
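+
+    # For illustration, a hypothetical model keyed on two of its fields::
+    #
+    #     class Position(odm.StdModel):
+    #         instrument = odm.SymbolField()
+    #         fund = odm.SymbolField()
+    #         id = odm.CompositeIdField('instrument', 'fund')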
 
    def __init__(self, *fields, **kwargs):
        super(CompositeIdField, self).__init__(**kwargs)
        self.fields = fields
        if len(self.fields) < 2:
-            raise FieldError('At least tow fields are required by composite '
-                             'CompositeIdField')
+            raise FieldError(
+                "At least two fields are required by a " "CompositeIdField"
+            )
 
    def get_value(self, instance, *bits):
        if bits:
@@ -966,12 +1012,12 @@ def register_with_model(self, name, model):
        fields = []
        for field in self.fields:
            if field not in model._meta.dfields:
-                raise FieldError('Composite id field "%s" in in "%s" model.' %
-                                 (field, model._meta))
+                raise FieldError(
+                    'Composite id field "%s" is not in "%s" model.' % (field, model._meta)
+                )
            field = model._meta.dfields[field]
-            if field.internal_type not in ('text', 'numeric'):
-                raise FieldError('Composite id field "%s" not valid type.' %
-                                 field)
+            if field.internal_type not in ("text", "numeric"):
+                raise FieldError('Composite id field "%s" is not a valid type.' % field)
            fields.append(field)
        self.fields = tuple(fields)
        return super(CompositeIdField, self).register_with_model(name, model)
diff --git a/stdnet/odm/globals.py b/stdnet/odm/globals.py
index d075594..718d9a6 100755
--- a/stdnet/odm/globals.py
+++ b/stdnet/odm/globals.py
@@ -1,31 +1,28 @@
 import hashlib
 from collections import namedtuple
 
-from stdnet.utils import to_bytes, JSPLITTER
+from stdnet.utils import JSPLITTER, to_bytes
 
-__all__ = ['get_model_from_hash',
-           'get_hash_from_model',
-           'hashmodel',
-           'JSPLITTER']
+__all__ = ["get_model_from_hash", "get_hash_from_model", "hashmodel", "JSPLITTER"]
 
 # Information about a lookup in a query
-lookup_value = namedtuple('lookup_value', 'lookup value')
+lookup_value = namedtuple("lookup_value", "lookup value")
 
 # Utilities for sorting and range lookups
-orderinginfo = namedtuple('orderinginfo', 'name field desc model nested auto')
+orderinginfo = namedtuple("orderinginfo", "name field desc model nested auto")
 
 # attribute name, field, model where to do lookup, nested lookup_info
-range_lookup_info = namedtuple('range_lookup_info', 'name field model nested')
+range_lookup_info = namedtuple("range_lookup_info", "name field model nested")


 class ModelDict(dict):
-
    def from_hash(self, hash):
        return self.get(hash)
 
    def to_hash(self, model):
        return model._meta.hash
 
+
 _model_dict = ModelDict()
 
@@ -38,36 +35,39 @@ def get_hash_from_model(model):


 def hashmodel(model, library=None):
-    '''Calculate the Hash id of metaclass ``meta``'''
-    library = library or 'python-stdnet'
+    """Calculate the hash id of the model metaclass ``meta``."""
+    library = library or "python-stdnet"
    meta = model._meta
-    sha = hashlib.sha1(to_bytes('{0}({1})'.format(library, meta)))
+    sha = hashlib.sha1(to_bytes("{0}({1})".format(library, meta)))
    hash = sha.hexdigest()[:8]
    meta.hash = hash
    if hash in _model_dict:
-        raise KeyError('Model "{0}" already in hash table.\
-            Rename your model or the module containing the model.'.format(meta))
+        raise KeyError(
+            'Model "{0}" already in hash table.\
+            Rename your model or the module containing the model.'.format(
+                meta
+            )
+        )
    _model_dict[hash] = model


 def _make_id(target):
-    if hasattr(target, '__func__'):
+    if hasattr(target, "__func__"):
        return (id(target.__self__), id(target.__func__))
    return id(target)


 class Event:
-
    def __init__(self):
        self.callbacks = []
 
    def bind(self, callback, sender=None):
-        '''Bind a ``callback`` for a given ``sender``.'''
+        """Bind a ``callback`` for a given ``sender``."""
        key = (_make_id(callback), _make_id(sender))
        self.callbacks.append((key, callback))
 
    def fire(self, sender=None, **params):
-        '''Fire callbacks from a ``sender``.'''
+        """Fire callbacks from a ``sender``."""
        keys = (_make_id(None), _make_id(sender))
        results = []
        for (_, key), callback in self.callbacks:
diff --git a/stdnet/odm/mapper.py b/stdnet/odm/mapper.py
index 5254297..6b13933 100755
--- a/stdnet/odm/mapper.py
+++ b/stdnet/odm/mapper.py
@@ -1,64 +1,63 @@
-from inspect import ismodule, isclass
+from inspect import isclass, ismodule
 
+from stdnet import getdb
 from stdnet.utils import native_str
 from stdnet.utils.importer import import_module
-from stdnet import getdb
 
-from .base import ModelType, Model
-from .session import Manager, Session, ModelDictionary, StructureManager
-from .struct import Structure
+from .base import Model, ModelType
 from .globals import Event, get_model_from_hash
+from .session import Manager, ModelDictionary, Session, StructureManager
+from .struct import Structure
 
-
-__all__ = ['Router', 'model_iterator']
+__all__ = ["Router", "model_iterator"]
"model_iterator"] class Router(object): - '''A router is a mapping of :class:`Model` to the registered -:class:`Manager` of that model:: + """A router is a mapping of :class:`Model` to the registered + :class:`Manager` of that model:: + + from stdnet import odm - from stdnet import odm + models = odm.Router() + models.register(MyModel, ...) - models = odm.Router() - models.register(MyModel, ...) + # dictionary Notation + query = models[MyModel].query() - # dictionary Notation - query = models[MyModel].query() + # or dotted notation (lowercase) + query = models.mymodel.query() - # or dotted notation (lowercase) - query = models.mymodel.query() + The ``models`` instance in the above snipped can be set globally if + one wishes to do so. -The ``models`` instance in the above snipped can be set globally if -one wishes to do so. + .. attribute:: pre_commit -.. attribute:: pre_commit + A signal which can be used to register ``callbacks`` before instances are + committed:: - A signal which can be used to register ``callbacks`` before instances are - committed:: + models.pre_commit.bind(callback, sender=MyModel) - models.pre_commit.bind(callback, sender=MyModel) + .. attribute:: pre_delete -.. attribute:: pre_delete + A signal which can be used to register ``callbacks`` before instances are + deleted:: - A signal which can be used to register ``callbacks`` before instances are - deleted:: + models.pre_delete.bind(callback, sender=MyModel) - models.pre_delete.bind(callback, sender=MyModel) + .. attribute:: post_commit -.. attribute:: post_commit + A signal which can be used to register ``callbacks`` after instances are + committed:: - A signal which can be used to register ``callbacks`` after instances are - committed:: + models.post_commit.bind(callback, sender=MyModel) - models.post_commit.bind(callback, sender=MyModel) + .. attribute:: post_delete -.. attribute:: post_delete + A signal which can be used to register ``callbacks`` after instances are + deleted:: - A signal which can be used to register ``callbacks`` after instances are - deleted:: + models.post_delete.bind(callback, sender=MyModel)""" - models.post_delete.bind(callback, sender=MyModel) -''' def __init__(self, default_backend=None, install_global=False): self._registered_models = ModelDictionary() self._registered_names = {} @@ -73,24 +72,24 @@ def __init__(self, default_backend=None, install_global=False): @property def default_backend(self): - '''The default backend for this :class:`Router`. This is used when -calling the :meth:`register` method without explicitly passing a backend.''' + """The default backend for this :class:`Router`. This is used when + calling the :meth:`register` method without explicitly passing a backend.""" return self._default_backend @property def registered_models(self): - '''List of registered :class:`Model`.''' + """List of registered :class:`Model`.""" return list(self._registered_models) @property def search_engine(self): - '''The :class:`SearchEngine` for this :class:`Router`. This -must be created by users. Check :ref:`full text search ` -tutorial for information.''' + """The :class:`SearchEngine` for this :class:`Router`. This + must be created by users. 
Check :ref:`full text search ` + tutorial for information.""" return self._search_engine def __repr__(self): - return '%s %s' % (self.__class__.__name.__, self._registered_models) + return "%s %s" % (self.__class__.__name.__, self._registered_models) def __str__(self): return str(self._registered_models) @@ -110,42 +109,43 @@ def structure(self, model): return self._structures.get(model) def set_search_engine(self, engine): - '''Set the search ``engine`` for this :class:`Router`.''' + """Set the search ``engine`` for this :class:`Router`.""" self._search_engine = engine self._search_engine.set_router(self) - def register(self, model, backend=None, read_backend=None, - include_related=True, **params): - '''Register a :class:`Model` with this :class:`Router`. If the -model was already registered it does nothing. - -:param model: a :class:`Model` class. -:param backend: a :class:`stdnet.BackendDataServer` or a - :ref:`connection string `. -:param read_backend: Optional :class:`stdnet.BackendDataServer` for read - operations. This is useful when the server has a master/slave - configuration, where the master accept write and read operations - and the ``slave`` read only operations. -:param include_related: ``True`` if related models to ``model`` needs to be - registered. Default ``True``. -:param params: Additional parameters for the :func:`getdb` function. -:return: the number of models registered. -''' + def register( + self, model, backend=None, read_backend=None, include_related=True, **params + ): + """Register a :class:`Model` with this :class:`Router`. If the + model was already registered it does nothing. + + :param model: a :class:`Model` class. + :param backend: a :class:`stdnet.BackendDataServer` or a + :ref:`connection string `. + :param read_backend: Optional :class:`stdnet.BackendDataServer` for read + operations. This is useful when the server has a master/slave + configuration, where the master accept write and read operations + and the ``slave`` read only operations. + :param include_related: ``True`` if related models to ``model`` needs to be + registered. Default ``True``. + :param params: Additional parameters for the :func:`getdb` function. + :return: the number of models registered.""" backend = backend or self._default_backend backend = getdb(backend=backend, **params) if read_backend: read_backend = getdb(read_backend) registered = 0 if isinstance(model, Structure): - self._structures[model] = StructureManager(model, backend, - read_backend, self) + self._structures[model] = StructureManager( + model, backend, read_backend, self + ) return model for model in models_from_model(model, include_related=include_related): if model in self._registered_models: continue registered += 1 default_manager = backend.default_manager or Manager - manager_class = getattr(model, 'manager_class', default_manager) + manager_class = getattr(model, "manager_class", default_manager) manager = manager_class(model, backend, read_backend, self) self._registered_models[model] = manager if isinstance(model, ModelType): @@ -160,35 +160,36 @@ def register(self, model, backend=None, read_backend=None, return backend def from_uuid(self, uuid, session=None): - '''Retrieve a :class:`Model` from its universally unique identifier -``uuid``. If the ``uuid`` does not match any instance an exception will raise. -''' - elems = uuid.split('.') + """Retrieve a :class:`Model` from its universally unique identifier + ``uuid``. 
 
    def from_uuid(self, uuid, session=None):
-        '''Retrieve a :class:`Model` from its universally unique identifier
-``uuid``. If the ``uuid`` does not match any instance an exception will raise.
-'''
-        elems = uuid.split('.')
+        """Retrieve a :class:`Model` from its universally unique identifier
+        ``uuid``. If the ``uuid`` does not match any instance an exception is raised."""
+        elems = uuid.split(".")
        if len(elems) == 2:
            model = get_model_from_hash(elems[0])
            if not model:
                raise Model.DoesNotExist(
-                    'model id "{0}" not available'.format(elems[0]))
+                    'model id "{0}" not available'.format(elems[0])
+                )
            if not session or session.router is not self:
                session = self.session()
            return session.query(model).get(id=elems[1])
        raise Model.DoesNotExist('uuid "{0}" not recognized'.format(uuid))
 
    def flush(self, exclude=None, include=None, dryrun=False):
-        '''Flush :attr:`registered_models`.
+        """Flush :attr:`registered_models`.
 
        :param exclude: optional list of model names to exclude.
        :param include: optional list of model names to include.
        :param dryrun: Doesn't remove anything, simply collect managers to flush.
        :return:
-        '''
+        """
        exclude = exclude or []
        results = []
        for manager in self._registered_models.values():
            m = manager._meta
-            if include is not None and not (m.modelkey in include or
-                                            m.app_label in include):
+            if include is not None and not (
+                m.modelkey in include or m.app_label in include
+            ):
                continue
            if not (m.modelkey in exclude or m.app_label in exclude):
                if dryrun:
@@ -198,9 +199,9 @@ def flush(self, exclude=None, include=None, dryrun=False):
        return results
 
    def unregister(self, model=None):
-        '''Unregister a ``model`` if provided, otherwise it unregister all
-registered models. Return a list of unregistered model managers or ``None``
-if no managers were removed.'''
+        """Unregister a ``model`` if provided, otherwise it unregisters all
+        registered models. Return a list of unregistered model managers or ``None``
+        if no managers were removed."""
        if model is not None:
            try:
                manager = self._registered_models.pop(model)
@@ -215,48 +216,45 @@ def unregister(self, model=None):
        return managers
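+
+    # Illustration (the application label is hypothetical; with ``dryrun``
+    # nothing is removed, managers are only collected)::
+    #
+    #     models.flush(include=['myapp'], dryrun=True)
+    #     models.unregister(MyModel)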
 
    def register_applications(self, applications, models=None, backends=None):
-        '''A higher level registration functions for group of models located
-on application modules.
-It uses the :func:`model_iterator` function to iterate
-through all :class:`Model` models available in ``applications``
-and register them using the :func:`register` low level method.
-
-:parameter applications: A String or a list of strings representing
-    python dotted paths where models are implemented.
-:parameter models: Optional list of models to include. If not provided
-    all models found in *applications* will be included.
-:parameter backends: optional dictionary which map a model or an
-    application to a backend :ref:`connection string `.
-:rtype: A list of registered :class:`Model`.
-
-For example::
-
-
-    mapper.register_application_models('mylib.myapp')
-    mapper.register_application_models(['mylib.myapp', 'another.path'])
-    mapper.register_application_models(pythonmodule)
-    mapper.register_application_models(['mylib.myapp',pythonmodule])
-
-'''
-        return list(self._register_applications(applications, models,
-                                                 backends))
+        """A higher level registration function for groups of models located
+        in application modules.
+        It uses the :func:`model_iterator` function to iterate
+        through all :class:`Model` models available in ``applications``
+        and register them using the :func:`register` low level method.
+
+        :parameter applications: a string or a list of strings representing
+            python dotted paths where models are implemented.
+        :parameter models: Optional list of models to include. If not provided
+            all models found in *applications* will be included.
+        :parameter backends: optional dictionary which maps a model or an
+            application to a backend :ref:`connection string `.
+        :rtype: A list of registered :class:`Model`.
+
+        For example::
+
+            mapper.register_applications('mylib.myapp')
+            mapper.register_applications(['mylib.myapp', 'another.path'])
+            mapper.register_applications(pythonmodule)
+            mapper.register_applications(['mylib.myapp', pythonmodule])
+        """
+        return list(self._register_applications(applications, models, backends))
 
    def session(self):
-        '''Obatain a new :class:`Session` for this ``Router``.'''
+        """Obtain a new :class:`Session` for this ``Router``."""
        return Session(self)
 
    def create_all(self):
-        '''Loop though :attr:`registered_models` and issue the
-:meth:`Manager.create_all` method.'''
+        """Loop through :attr:`registered_models` and issue the
+        :meth:`Manager.create_all` method."""
        for manager in self._registered_models.values():
            manager.create_all()
 
    def add(self, instance):
-        '''Add an ``instance`` to its backend database. This is a shurtcut
-method for::
-
-    self.session().add(instance)
-'''
+        """Add an ``instance`` to its backend database. This is a shortcut
+        method for::
+
+            self.session().add(instance)"""
        return self.session().add(instance)
 
    # PRIVATE METHODS
@@ -271,7 +269,7 @@ def _register_applications(self, applications, models, backends):
                name = model._meta.app_label
                kwargs = backends.get(name, self._default_backend)
                if not isinstance(kwargs, dict):
-                    kwargs = {'backend': kwargs}
+                    kwargs = {"backend": kwargs}
                else:
                    kwargs = kwargs.copy()
                if self.register(model, include_related=False, **kwargs):


 def models_from_model(model, include_related=False, exclude=None):
-    '''Generator of all model in model.'''
+    """Generator of all models reachable from ``model``."""
    if exclude is None:
        exclude = set()
    if model and model not in exclude:
@@ -289,18 +287,18 @@ def models_from_model(model, include_related=False, exclude=None):
        if include_related:
            exclude.add(model)
            for field in model._meta.fields:
-                if hasattr(field, 'relmodel'):
-                    through = getattr(field, 'through', None)
+                if hasattr(field, "relmodel"):
+                    through = getattr(field, "through", None)
                    for rmodel in (field.relmodel, field.model, through):
                        for m in models_from_model(
-                                rmodel, include_related=include_related,
-                                exclude=exclude):
+                            rmodel, include_related=include_related, exclude=exclude
+                        ):
                            yield m
            for manytomany in model._meta.manytomany:
                related = getattr(model, manytomany)
-                for m in models_from_model(related.model,
-                                           include_related=include_related,
-                                           exclude=exclude):
+                for m in models_from_model(
+                    related.model, include_related=include_related, exclude=exclude
+                ):
                    yield m
    elif not isinstance(model, ModelType) and isclass(model):
        # This is a class which is not o ModelType
@@ -308,24 +306,23 @@
 def model_iterator(application, include_related=True, exclude=None):
-    '''A generator of :class:`StdModel` classes found in *application*.
-
-:parameter application: A python dotted path or an iterable over python
-    dotted-paths where models are defined.
+    """A generator of :class:`StdModel` classes found in *application*.
 
-Only models defined in these paths are considered.
+    :parameter application: A python dotted path or an iterable over python
+        dotted-paths where models are defined.
+
+    Only models defined in these paths are considered.
 
-For example::
-
-    from stdnet.odm import model_iterator
+    For example::
 
-    APPS = ('stdnet.contrib.searchengine',
-            'stdnet.contrib.timeseries')
+        from stdnet.odm import model_iterator
 
-    for model in model_iterator(APPS):
-        ...
-
-'''
+        APPS = ('stdnet.contrib.searchengine',
+                'stdnet.contrib.timeseries')
+
+        for model in model_iterator(APPS):
+            ...
+    """
    if exclude is None:
        exclude = set()
    application = native_str(application)
@@ -339,22 +336,21 @@ def model_iterator(application, include_related=True, exclude=None):
            # the module is not there
            mod = None
        if mod:
-            label = application.split('.')[-1]
+            label = application.split(".")[-1]
            try:
-                mod_models = import_module('.models', application)
+                mod_models = import_module(".models", application)
            except ImportError:
                mod_models = mod
-            label = getattr(mod_models, 'app_label', label)
+            label = getattr(mod_models, "app_label", label)
            models = set()
            for name in dir(mod_models):
                value = getattr(mod_models, name)
-                meta = getattr(value, '_meta', None)
+                meta = getattr(value, "_meta", None)
                if isinstance(value, ModelType) and meta:
                    for model in models_from_model(
-                            value, include_related=include_related,
-                            exclude=exclude):
-                        if (model._meta.app_label == label
-                                and model not in models):
+                        value, include_related=include_related, exclude=exclude
+                    ):
+                        if model._meta.app_label == label and model not in models:
                            models.add(model)
                            yield model
        else:
diff --git a/stdnet/odm/models.py b/stdnet/odm/models.py
index 3ccdadb..e8da569 100755
--- a/stdnet/odm/models.py
+++ b/stdnet/odm/models.py
@@ -1,19 +1,19 @@
 from functools import partial
 
-from stdnet.utils import zip, JSPLITTER, EMPTYJSON, iteritems
+from stdnet.utils import EMPTYJSON, JSPLITTER, iteritems, zip
 from stdnet.utils.exceptions import *
 
-from .base import ModelType, ModelBase, raise_kwargs
+from .base import ModelBase, ModelType, raise_kwargs
 from .session import Manager
 
-
-__all__ = ['StdModel', 'create_model', 'model_to_dict']
+__all__ = ["StdModel", "create_model", "model_to_dict"]


 class StdModel(ModelBase):
-    '''A :class:`Model` which contains data in :class:`Field`. This represents
-the main class of :mod:`stdnet.odm` module.'''
-    _model_type = 'object'
+    """A :class:`Model` which contains data in :class:`Field`. This represents
+    the main class of the :mod:`stdnet.odm` module."""
+
+    _model_type = "object"
    abstract = True
    _loadedfields = None
 
@@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
        if args:
            N = len(args)
            if N > len(attributes):
-                raise ValueError('Too many attributes')
+                raise ValueError("Too many attributes")
            attrs, attributes = attributes[:N], attributes[N:]
            for name, value in zip(attrs, args):
                setattr(self, name, value)
@@ -39,10 +39,10 @@ def __init__(self, *args, **kwargs):
 
    @property
    def has_all_data(self):
-        '''``True`` if this :class:`StdModel` instance has all back-end data
-loaded. This applies to persistent instances only. This property is used when
-committing changes. If all data is available, the commit will replace the
-previous object data entirely, otherwise it will only update it.'''
+        """``True`` if this :class:`StdModel` instance has all back-end data
+        loaded. This applies to persistent instances only. This property is used when
+        committing changes. If all data is available, the commit will replace the
+        previous object data entirely, otherwise it will only update it."""
        return self.get_state().persistent and self._loadedfields is None
 
    def set(self, name, value):
@@ -52,10 +52,10 @@ def set(self, name, value):
        elif name in meta.attributes:
            setattr(self, name, value)
        else:
-            raise AttributeError('Model has no field/attribute %s' % name)
+            raise AttributeError("Model has no field/attribute %s" % name)
 
    def loadedfields(self):
-        '''Generator of fields loaded from database'''
+        """Generator of fields loaded from the database."""
        if self._loadedfields is None:
            for field in self._meta.scalarfields:
                yield field
@@ -72,21 +72,21 @@ def loadedfields(self):
                    name = name.split(JSPLITTER)[0]
                    if name in fields and name not in processed:
                        field = fields[name]
-                        if field.type == 'json object':
+                        if field.type == "json object":
                            processed.add(name)
                            yield field
 
    def fieldvalue_pairs(self, exclude_cache=False):
-        '''Generator of fields,values pairs. Fields correspond to
-the ones which have been loaded (usually all of them) or
-not loaded but modified.
-Check the :ref:`load_only ` query function for more
-details.
-
-If *exclude_cache* evaluates to ``True``, fields with :attr:`Field.as_cache`
-attribute set to ``True`` won't be included.
-
-:rtype: a generator of two-elements tuples'''
+        """Generator of field, value pairs. Fields correspond to
+        the ones which have been loaded (usually all of them) or
+        not loaded but modified.
+        Check the :ref:`load_only ` query function for more
+        details.
+
+        If *exclude_cache* evaluates to ``True``, fields with :attr:`Field.as_cache`
+        attribute set to ``True`` won't be included.
+
+        :rtype: a generator of two-element tuples"""
        for field in self._meta.scalarfields:
            if exclude_cache and field.as_cache:
                continue
@@ -95,20 +95,20 @@ def fieldvalue_pairs(self, exclude_cache=False):
            yield field, getattr(self, name)
 
    def clear_cache_fields(self):
-        '''Set cache fields to ``None``. Check :attr:`Field.as_cache`
-for information regarding fields which are considered cache.'''
+        """Set cache fields to ``None``. Check :attr:`Field.as_cache`
+        for information regarding fields which are considered cache."""
        for field in self._meta.scalarfields:
            if field.as_cache:
                setattr(self, field.name, None)
 
    def get_attr_value(self, name):
-        '''Retrieve the ``value`` for the attribute ``name``. The ``name``
-can be nested following the :ref:`double underscore `
-notation, for example ``group__name``. If the attribute is not available it
-raises :class:`AttributeError`.'''
+        """Retrieve the ``value`` for the attribute ``name``. The ``name``
+        can be nested following the :ref:`double underscore `
+        notation, for example ``group__name``. If the attribute is not available it
+        raises :class:`AttributeError`."""
        if name in self._meta.dfields:
            return self._meta.dfields[name].get_value(self)
-        elif not name.startswith('__') and JSPLITTER in name:
+        elif not name.startswith("__") and JSPLITTER in name:
            bits = name.split(JSPLITTER)
            fname = bits[0]
            if fname in self._meta.dfields:
@@ -119,103 +119,99 @@ def get_attr_value(self, name):
        return getattr(self, name)
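+
+    # Illustration of the double underscore notation (assuming a model
+    # with a foreign key field ``group``)::
+    #
+    #     >>> instance.get_attr_value('group__name')
+    #     'planet'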
 
    def clone(self, **data):
-        '''Utility method for cloning the instance as a new object.
-
-:parameter data: additional which override field data.
-:rtype: a new instance of this class.
-'''
+        """Utility method for cloning the instance as a new object.
+
+        :parameter data: additional parameters which override field data.
+        :rtype: a new instance of this class."""
        meta = self._meta
        session = self.session
        pkname = meta.pkname()
        pkvalue = data.pop(pkname, None)
        fields = self.todict(exclude_cache=True)
        fields.update(data)
-        fields.pop('__dbdata__', None)
+        fields.pop("__dbdata__", None)
        obj = self._meta.make_object((pkvalue, None, fields))
        obj.session = session
        return obj
 
    def is_valid(self):
-        '''Kick off the validation algorithm by checking all
-:attr:`StdModel.loadedfields` against their respective validation algorithm.
-
-:rtype: Boolean indicating if the model validates.'''
+        """Kick off the validation algorithm by checking all
+        :attr:`StdModel.loadedfields` against their respective validation algorithm.
+
+        :rtype: Boolean indicating if the model validates."""
        return self._meta.is_valid(self)
 
    def todict(self, exclude_cache=False):
-        '''Return a dictionary of serialised scalar field for pickling.
-If the *exclude_cache* flag is ``True``, fields with :attr:`Field.as_cache`
-attribute set to ``True`` will be excluded.'''
+        """Return a dictionary of serialised scalar fields for pickling.
+        If the *exclude_cache* flag is ``True``, fields with :attr:`Field.as_cache`
+        attribute set to ``True`` will be excluded."""
        odict = {}
        for field, value in self.fieldvalue_pairs(exclude_cache=exclude_cache):
            value = field.serialise(value)
            if value:
                odict[field.name] = value
-        if self._dbdata and 'id' in self._dbdata:
-            odict['__dbdata__'] = {'id': self._dbdata['id']}
+        if self._dbdata and "id" in self._dbdata:
+            odict["__dbdata__"] = {"id": self._dbdata["id"]}
        return odict
 
    def _to_json(self, exclude_cache):
        pk = self.pkvalue()
        if pk:
            yield self._meta.pkname(), pk
-            for field, value in self.fieldvalue_pairs(exclude_cache=
-                                                      exclude_cache):
+            for field, value in self.fieldvalue_pairs(exclude_cache=exclude_cache):
                value = field.json_serialise(value)
                if value not in EMPTYJSON:
                    yield field.name, value
 
    def tojson(self, exclude_cache=True):
-        '''Return a JSON serialisable dictionary representation.'''
+        """Return a JSON serialisable dictionary representation."""
        return dict(self._to_json(exclude_cache))
 
    def load_fields(self, *fields):
-        '''Load extra fields to this :class:`StdModel`.'''
+        """Load extra fields to this :class:`StdModel`."""
        if self._loadedfields is not None:
            if self.session is None:
-                raise SessionNotAvailable('No session available')
+                raise SessionNotAvailable("No session available")
            meta = self._meta
            kwargs = {meta.pkname(): self.pkvalue()}
            obj = session.query(self).load_only(fields).get(**kwargs)
            for name in fields:
                field = meta.dfields.get(name)
                if field is not None:
-                    setattr(self, field.attname,
-                            getattr(obj, field.attname, None))
+                    setattr(self, field.attname, getattr(obj, field.attname, None))
 
    def get_state_action(self):
-        return 'override' if self._loadedfields is None else 'update'
+        return "override" if self._loadedfields is None else "update"
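+
+    # Illustration of the partial loading helpers above (field names are
+    # hypothetical)::
+    #
+    #     >>> obj = models.mymodel.query().load_only('name').get(id=1)
+    #     >>> obj.load_fields('description')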
 
    def load_related_model(self, name, load_only=None, dont_load=None):
-        '''Load a the :class:`ForeignKey` field ``name`` if this is part of the
-fields of this model and if the related object is not already loaded.
-It is used by the lazy loading mechanism of :ref:`one-to-many `
-relationships.
-
-:parameter name: the :attr:`Field.name` of the :class:`ForeignKey` to load.
-:parameter load_only: Optional parameters which specify the fields to load.
-:parameter dont_load: Optional parameters which specify the fields not to load.
-:return: the related :class:`StdModel` instance.
-'''
+        """Load the :class:`ForeignKey` field ``name`` if this is part of the
+        fields of this model and if the related object is not already loaded.
+        It is used by the lazy loading mechanism of :ref:`one-to-many `
+        relationships.
+
+        :parameter name: the :attr:`Field.name` of the :class:`ForeignKey` to load.
+        :parameter load_only: Optional parameters which specify the fields to load.
+        :parameter dont_load: Optional parameters which specify the fields not to load.
+        :return: the related :class:`StdModel` instance."""
        field = self._meta.dfields.get(name)
        if not field:
            raise ValueError('Field "%s" not available' % name)
-        elif not field.type == 'related object':
+        elif not field.type == "related object":
            raise ValueError('Field "%s" not a foreign key' % name)
        return self._load_related_model(field, load_only, dont_load)
 
    @classmethod
    def get_field(cls, name):
-        '''Returns the :class:`Field` instance at ``name`` if available,
-otherwise it returns ``None``.'''
+        """Returns the :class:`Field` instance at ``name`` if available,
+        otherwise it returns ``None``."""
        return cls._meta.dfields.get(name)
 
    @classmethod
    def from_base64_data(cls, **kwargs):
-        '''Load a :class:`StdModel` from possibly base64encoded data.
-
-This method is used to load models from data obtained from the :meth:`tojson`
-method.'''
+        """Load a :class:`StdModel` from possibly base64-encoded data.
+
+        This method is used to load models from data obtained from the :meth:`tojson`
+        method."""
        o = cls()
        meta = cls._meta
        pkname = meta.pkname()
@@ -232,8 +228,8 @@ def from_base64_data(cls, **kwargs):
 
    @classmethod
    def pk(cls):
-        '''Returns the primary key :class:`Field` for this model. This is a
-proxy for the :attr:`Metaclass.pk` attribute.'''
+        """Returns the primary key :class:`Field` for this model. This is a
+        proxy for the :attr:`Metaclass.pk` attribute."""
        return cls._meta.pk
 
    @classmethod
@@ -242,7 +238,7 @@ def get_unique_instance(cls, items):
        if len(items) == 1:
            return items[0]
        else:
-            raise QuerySetError('Non unique results')
+            raise QuerySetError("Non unique results")
        else:
            raise cls.DoesNotExist()
 
@@ -288,22 +284,21 @@ def __set_related_value(self, field, items=None):


 def create_model(name, *attributes, **params):
-    '''Create a :class:`Model` class for objects requiring
-and interface similar to :class:`StdModel`. We refers to this type
-of models as :ref:`local models ` since instances of such
-models are not persistent on a :class:`stdnet.BackendDataServer`.
-
-:param name: Name of the model class.
-:param attributes: positiona attribute names. These are the only attribute
-    available to the model during the default constructor.
-:param params: key-valued parameter to pass to the :class:`ModelMeta`
-    constructor.
-:return: a local :class:`Model` class.
-    '''
+    """Create a :class:`Model` class for objects requiring
+    an interface similar to :class:`StdModel`. We refer to this type
+    of models as :ref:`local models ` since instances of such
+    models are not persistent on a :class:`stdnet.BackendDataServer`.
+
+    :param name: Name of the model class.
+    :param attributes: positional attribute names. These are the only attributes
+        available to the model during the default constructor.
+    :param params: key-valued parameters to pass to the :class:`ModelMeta`
+        constructor.
+    :return: a local :class:`Model` class.
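+
+    For example, an illustrative local model::
+
+        Task = create_model('Task', 'name', 'expiry')
+        task = Task(name='check queue', expiry=30)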
+    """
-    params['register'] = False
-    params['attributes'] = attributes
-    kwargs = {'manager_class': params.pop('manager_class', Manager),
-              'Meta': params}
+    params["register"] = False
+    params["attributes"] = attributes
+    kwargs = {"manager_class": params.pop("manager_class", Manager), "Meta": params}
    return ModelType(name, (StdModel,), kwargs)
diff --git a/stdnet/odm/query.py b/stdnet/odm/query.py
index b729cb1..da397b7 100755
--- a/stdnet/odm/query.py
+++ b/stdnet/odm/query.py
@@ -1,7 +1,7 @@
+from collections import Mapping
 from copy import copy
-from inspect import isgenerator
 from functools import partial
-from collections import Mapping
+from inspect import isgenerator
 
 from stdnet import range_lookups
 from stdnet.utils import JSPLITTER, iteritems, unique_tuple
@@ -9,9 +9,16 @@
 
 from .globals import lookup_value
 
-
-__all__ = ['Q', 'QueryBase', 'Query', 'QueryElement', 'EmptyQuery',
-           'intersect', 'union', 'difference']
+__all__ = [
+    "Q",
+    "QueryBase",
+    "Query",
+    "QueryElement",
+    "EmptyQuery",
+    "intersect",
+    "union",
+    "difference",
+]
 
 iterables = (tuple, list, set, frozenset, Mapping)
 
@@ -29,7 +36,7 @@ def update_dictionary(result, extra):
            v.update(v2) if iterable(v2) else v.add(v2)
            if len(k.split(JSPLITTER)) == 1:
                result.pop(k)
-                k = k + JSPLITTER + 'in'
+                k = k + JSPLITTER + "in"
        result[k] = v
    return result
 
@@ -43,32 +50,42 @@ def get_lookups(attname, field_lookups):


 class Q(object):
-    '''Base class for :class:`Query` and :class:`QueryElement`.
-
-.. attribute:: meta
-
-    The :attr:`StdModel._meta` attribute.
-
-.. attribute:: model
-
-    the :class:`StdModel` class for this query.
-
-.. attribute:: backend
-
-    the :class:`stdnet.BackendDataServer` class for this query.
-'''
-    keyword = ''
-    name = ''
+    """Base class for :class:`Query` and :class:`QueryElement`.
+
+    .. attribute:: meta
+
+        The :attr:`StdModel._meta` attribute.
+
+    .. attribute:: model
+
+        the :class:`StdModel` class for this query.
+
+    .. attribute:: backend
+
+        the :class:`stdnet.BackendDataServer` class for this query."""
 
-    def __init__(self, meta, session, select_related=None,
-                 ordering=None, fields=None,
-                 get_field=None, name=None, keyword=None):
+    keyword = ""
+    name = ""
+
+    def __init__(
+        self,
+        meta,
+        session,
+        select_related=None,
+        ordering=None,
+        fields=None,
+        get_field=None,
+        name=None,
+        keyword=None,
+    ):
        self._meta = meta
        self.session = session
-        self.data = {'select_related': select_related,
-                     'ordering': ordering,
-                     'fields': fields,
-                     'get_field': get_field}
+        self.data = {
+            "select_related": select_related,
+            "ordering": ordering,
+            "fields": fields,
+            "get_field": get_field,
+        }
        self.name = name if name is not None else meta.pk.name
        self.keyword = keyword if keyword is not None else self.keyword
 
@@ -82,46 +99,46 @@ def model(self):
 
    @property
    def select_related(self):
-        return self.data['select_related']
+        return self.data["select_related"]
 
    @property
    def ordering(self):
-        return self.data['ordering']
+        return self.data["ordering"]
 
    @property
    def fields(self):
-        return self.data['fields']
+        return self.data["fields"]
 
    @property
    def _get_field(self):
-        return self.data['get_field']
+        return self.data["get_field"]
 
    @property
    def backend(self):
        return self.session.model(self._meta).read_backend
 
    def get_field(self, field):
-        '''A :class:`Q` performs a series of operations and ultimately
-generate of set of matched elements ``ids``. If on the other hand, a
-different field is required, it can be specified with the :meth:`get_field`
-method. For example, lets say a model has a field called ``object_id``
-which contains ids of another model, we could use::
-
-    qs = session.query(MyModel).get_field('object_id')
-
-to obtain a set containing the values of matched elements ``object_id``
-fields.
-
-:parameter field: the name of the field which will be used to obtained the
-    matched elements value. Must be an index.
-:rtype: a new :class:`Q` instance.
-'''
+        """A :class:`Q` performs a series of operations and ultimately
+        generates a set of matched elements ``ids``. If, on the other hand, a
+        different field is required, it can be specified with the :meth:`get_field`
+        method. For example, let's say a model has a field called ``object_id``
+        which contains ids of another model, we could use::
+
+            qs = session.query(MyModel).get_field('object_id')
+
+        to obtain a set containing the values of matched elements ``object_id``
+        fields.
+
+        :parameter field: the name of the field which will be used to obtain the
+            matched elements' values. Must be an index.
+        :rtype: a new :class:`Q` instance."""
        if field != self._get_field:
            if field not in self._meta.dfields:
-                raise QuerySetError('Model "{0}" has no field "{1}".'
-                                    .format(self._meta, field))
+                raise QuerySetError(
+                    'Model "{0}" has no field "{1}".'.format(self._meta, field)
+                )
            q = self._clone()
-            q.data['get_field'] = field
+            q.data["get_field"] = field
            return q
        else:
            return self
@@ -138,55 +155,56 @@ def clear(self):
        pass
 
    def backend_query(self, **kwargs):
-        '''Build the :class:`stdnet.utils.async.BackendQuery` for this
-        instance.
-This is a virtual method with different implementation in :class:`Query`
-and :class:`QueryElement`.'''
+        """Build the :class:`stdnet.utils.async.BackendQuery` for this
+        instance.
+        This is a virtual method with different implementations in :class:`Query`
+        and :class:`QueryElement`."""
        raise NotImplementedError
 
    def _clone(self):
        cls = self.__class__
        q = cls.__new__(cls)
        d = self.__dict__.copy()
-        d['data'] = d['data'].copy()
+        d["data"] = d["data"].copy()
        if self.unions:
-            d['unions'] = copy(self.unions)
+            d["unions"] = copy(self.unions)
        q.__dict__ = d
        q.clear()
        return q


 class QueryElement(Q):
-    '''An element of a :class:`Query`.
-
-.. attribute:: qs
-
-    the :class:`Query` which contains this :class:`QueryElement`.
-
-.. attribute:: underlying
-
-    the element contained in the :class:`QueryElement`. This underlying is
-    an iterable or another :class:`Query`.
-
-.. attribute:: valid
-
-    if ``False`` this :class:`QueryElement` has no underlying elements and
-    won't produce any query.
+    """An element of a :class:`Query`.
+
+    .. attribute:: qs
+
+        the :class:`Query` which contains this :class:`QueryElement`.
+
+    .. attribute:: underlying
+
+        the element contained in the :class:`QueryElement`. This underlying is
+        an iterable or another :class:`Query`.
+
+    .. attribute:: valid
+
+        if ``False`` this :class:`QueryElement` has no underlying elements and
+        won't produce any query."""
-'''
 
    def __init__(self, *args, **kwargs):
        self.__backend_query = None
-        underlying = kwargs.pop('underlying', None)
+        underlying = kwargs.pop("underlying", None)
        super(QueryElement, self).__init__(*args, **kwargs)
        self.underlying = underlying if underlying is not None else ()
 
    def __repr__(self):
-        v = ''
+        v = ""
        if self.underlying is not None:
-            v = '('+', '.join((str(v) for v in self))+')'
+            v = "(" + ", ".join((str(v) for v in self)) + ")"
        k = self.keyword
        if self.name:
-            k += '-' + self.name
+            k += "-" + self.name
        return k + v
+
    __str__ = __repr__
 
    def __iter__(self):
@@ -213,39 +231,40 @@ def executed(self):
 
    @property
    def valid(self):
        if isinstance(self.underlying, QueryElement):
-            return self.keyword == 'set'
+            return self.keyword == "set"
        else:
            return len(self.underlying) > 0


 class QuerySet(QueryElement):
-    '''A :class:`QueryElement` which represents a lookup on a field.'''
-    keyword = 'set'
-    name = 'id'
+    """A :class:`QueryElement` which represents a lookup on a field."""
+
+    keyword = "set"
+    name = "id"


 class Select(QueryElement):
    """Forms the basis of select type set operations."""
+
    pass


 def make_select(keyword, queries):
    first = queries[0]
    queries = [q.construct() for q in queries]
-    return Select(first.meta, first.session, keyword=keyword,
-                  underlying=queries)
+    return Select(first.meta, first.session, keyword=keyword, underlying=queries)


 def intersect(queries):
-    return make_select('intersect', queries)
+    return make_select("intersect", queries)


 def union(queries):
-    return make_select('union', queries)
+    return make_select("union", queries)


 def difference(queries):
-    return make_select('diff', queries)
+    return make_select("diff", queries)


 def queryset(qs, **kwargs):
@@ -253,7 +272,6 @@ def queryset(qs, **kwargs):


 class QueryBase(Q):
-
    def __iter__(self):
        return iter(self.items())
 
@@ -261,13 +279,14 @@ def __len__(self):
        return self.count()
 
    def all(self):
-        '''Return a ``list`` of all matched elements in this :class:`Query`.'''
+        """Return a ``list`` of all matched elements in this :class:`Query`."""
        return self.items()


 class EmptyQuery(QueryBase):
-    '''Degenerate :class:`QueryBase` simulating and empty set.'''
-    keyword = 'empty'
+    """Degenerate :class:`QueryBase` simulating an empty set."""
+
+    keyword = "empty"
 
    def items(self, slic=None):
        return []
@@ -288,93 +307,93 @@ def union(self, query, *queries):
 
    def intersect(self, *queries):
        return self


 class Query(QueryBase):
-    '''A :class:`Query` is produced in terms of a given :class:`Session`,
-using the :meth:`Session.query` method::
-
-    qs = session.query(MyModel)
-
-A query is equivalent to a collection of SELECT statements for a standard
-relational database. It has a a generative interface whereby successive calls
-return a new :class:`Query` object, a copy of the former with additional
-criteria and options associated with it.
-
-**ATTRIBUTES**
-
-.. attribute:: _meta
-
-    The :attr:`StdModel._meta` attribute.
-
-.. attribute:: model
-
-    the :class:`StdModel` class for this query.
-
-.. attribute:: session
-
-    The :class:`Session` which created the :class:`Query` via the
-    :meth:`Session.query` method.
-
-.. attribute:: backend
-
-    the :class:`stdnet.BackendDataServer` holding the data to query.
-
-.. attribute:: _get_field
-
-    When iterating over a :class:`Query`, you get back instances of
-    the :attr:`model` class. However, if ``_get_field`` is specified
-    you get back values of the field specified.
-    This can be changed via the :meth:`get_field` method::
-
-        qs = query.get_field('name').all()
-
-    the results is a list of name values (provided the model has a
-    ``name`` field of course).
-
-    Default: ``None``.
-
-.. attribute:: fargs
-
-    Dictionary containing the ``filter`` lookup parameters each one of
-    them corresponding to a ``where`` clause of a select. This value is
-    manipulated via the :meth:`filter` method.
-
-    Default: ``{}``.
-
-.. attribute:: eargs
-
-    Dictionary containing the ``exclude`` lookup parameters each one
-    of them corresponding to a ``where`` clause of a select. This value is
-    manipulated via the :meth:`exclude` method.
-
-    Default: ``{}``.
-
-.. attribute:: ordering
-
-    optional ordering field.
-
-.. attribute:: text
-
-    optional text to filter result on.
-
-    Default: ``""``.
-
-**METHODS**
-'''
+    """A :class:`Query` is produced in terms of a given :class:`Session`,
+    using the :meth:`Session.query` method::
+
+        qs = session.query(MyModel)
+
+    A query is equivalent to a collection of SELECT statements for a standard
+    relational database. It has a generative interface whereby successive calls
+    return a new :class:`Query` object, a copy of the former with additional
+    criteria and options associated with it.
+
+    **ATTRIBUTES**
+
+    .. attribute:: _meta
+
+        The :attr:`StdModel._meta` attribute.
+
+    .. attribute:: model
+
+        the :class:`StdModel` class for this query.
+
+    .. attribute:: session
+
+        The :class:`Session` which created the :class:`Query` via the
+        :meth:`Session.query` method.
+
+    .. attribute:: backend
+
+        the :class:`stdnet.BackendDataServer` holding the data to query.
+
+    .. attribute:: _get_field
+
+        When iterating over a :class:`Query`, you get back instances of
+        the :attr:`model` class. However, if ``_get_field`` is specified
+        you get back values of the field specified.
+        This can be changed via the :meth:`get_field` method::
+
+            qs = query.get_field('name').all()
+
+        the result is a list of name values (provided the model has a
+        ``name`` field of course).
+
+        Default: ``None``.
+
+    .. attribute:: fargs
+
+        Dictionary containing the ``filter`` lookup parameters each one of
+        them corresponding to a ``where`` clause of a select. This value is
+        manipulated via the :meth:`filter` method.
+
+        Default: ``{}``.
+
+    .. attribute:: eargs
+
+        Dictionary containing the ``exclude`` lookup parameters each one
+        of them corresponding to a ``where`` clause of a select. This value is
+        manipulated via the :meth:`exclude` method.
+
+        Default: ``{}``.
+
+    .. attribute:: ordering
+
+        optional ordering field.
+
+    .. attribute:: text
+
+        optional text to filter result on.
+
+        Default: ``""``.
+
+    **METHODS**"""
 
    start = None
    stop = None
-    lookups = ('in', 'contains')
+    lookups = ("in", "contains")
 
    def __init__(self, *args, **kwargs):
-        '''A :class:`Query` is not initialized directly but via the
-:meth:`Session.query` or :meth:`Manager.query` methods.'''
-        self.fargs = kwargs.pop('fargs', None)
-        self.eargs = kwargs.pop('eargs', None)
-        self.unions = kwargs.pop('unions', ())
-        self.searchengine = kwargs.pop('searchengine', None)
-        self.intersections = kwargs.pop('intersections', ())
-        self.text = kwargs.pop('text', None)
-        self.exclude_fields = kwargs.pop('exclude_fields', None)
+        """A :class:`Query` is not initialized directly but via the
+        :meth:`Session.query` or :meth:`Manager.query` methods."""
+        self.fargs = kwargs.pop("fargs", None)
+        self.eargs = kwargs.pop("eargs", None)
+        self.unions = kwargs.pop("unions", ())
+        self.searchengine = kwargs.pop("searchengine", None)
+        self.intersections = kwargs.pop("intersections", ())
+        self.text = kwargs.pop("text", None)
+        self.exclude_fields = kwargs.pop("exclude_fields", None)
        super(Query, self).__init__(*args, **kwargs)
        self.clear()
 
@@ -390,26 +409,26 @@ def __repr__(self):
        if seq is None:
            s = self.__class__.__name__
            if self.fargs:
-                s = '%s.filter(%s)' % (s, self.fargs)
+                s = "%s.filter(%s)" % (s, self.fargs)
            if self.eargs:
-                s = '%s.exclude(%s)' % (s, self.eargs)
+                s = "%s.exclude(%s)" % (s, self.eargs)
            return s
        else:
            return repr(seq)
+
    __str__ = __repr__
 
    def filter(self, **kwargs):
-        '''Create a new :class:`Query` with additional clauses corresponding to
-``where`` or ``limit`` in a ``SQL SELECT`` statement.
-
-:parameter kwargs: dictionary of limiting clauses.
-:rtype: a new :class:`Query` instance.
-
-For example::
-
-    qs = session.query(MyModel)
-    result = qs.filter(group = 'planet')
-'''
+        """Create a new :class:`Query` with additional clauses corresponding to
+        ``where`` or ``limit`` in a ``SQL SELECT`` statement.
+
+        :parameter kwargs: dictionary of limiting clauses.
+        :rtype: a new :class:`Query` instance.
+
+        For example::
+
+            qs = session.query(MyModel)
+            result = qs.filter(group = 'planet')"""
        if kwargs:
            q = self._clone()
            if self.fargs:
@@ -420,19 +439,18 @@ def filter(self, **kwargs):
        return self
 
    def exclude(self, **kwargs):
-        '''Returns a new :class:`Query` with additional clauses corresponding
-to ``EXCEPT`` in a ``SQL SELECT`` statement.
-
-:parameter kwargs: dictionary of limiting clauses.
-:rtype: a new :class:`Query` instance.
-
-Using an equivalent example to the :meth:`filter` method::
-
-    qs = session.query(MyModel)
-    result1 = qs.exclude(group = 'planet')
-    result2 = qs.exclude(group__in = ('planet','stars'))
-
-'''
+        """Returns a new :class:`Query` with additional clauses corresponding
+        to ``EXCEPT`` in a ``SQL SELECT`` statement.
+
+        :parameter kwargs: dictionary of limiting clauses.
+        :rtype: a new :class:`Query` instance.
+
+        Using an equivalent example to the :meth:`filter` method::
+
+            qs = session.query(MyModel)
+            result1 = qs.exclude(group = 'planet')
+            result2 = qs.exclude(group__in = ('planet','stars'))
+        """
        if kwargs:
            q = self._clone()
            if self.eargs:
@@ -443,116 +461,112 @@ def exclude(self, **kwargs):
        return self
 
    def union(self, *queries):
-        '''Return a new :class:`Query` obtained form the union of this
-:class:`Query` with one or more *queries*.
-For example, lets say we want to have the union
-of two queries obtained from the :meth:`filter` method::
-
-    query = session.query(MyModel)
-    qs = query.filter(field1 = 'bla').union(query.filter(field2 = 'foo'))
-'''
+        """Return a new :class:`Query` obtained from the union of this
+        :class:`Query` with one or more *queries*.
+        For example, let's say we want to have the union
+        of two queries obtained from the :meth:`filter` method::
+
+            query = session.query(MyModel)
+            qs = query.filter(field1 = 'bla').union(query.filter(field2 = 'foo'))"""
        q = self._clone()
        q.unions += queries
        return q
 
    def intersect(self, *queries):
-        '''Return a new :class:`Query` obtained form the intersection of this
-:class:`Query` with one or more *queries*. Workds the same way as
-the :meth:`union` method.'''
+        """Return a new :class:`Query` obtained from the intersection of this
+        :class:`Query` with one or more *queries*. Works the same way as
+        the :meth:`union` method."""
        q = self._clone()
        q.intersections += queries
        return q
 
    def sort_by(self, ordering):
-        '''Sort the query by the given field
-
-:parameter ordering: a string indicating the class:`Field` name to sort by.
-    If prefixed with ``-``, the sorting will be in descending order, otherwise
-    in ascending order.
-:return type: a new :class:`Query` instance.
-'''
+        """Sort the query by the given field.
+
+        :parameter ordering: a string indicating the :class:`Field` name to sort by.
+            If prefixed with ``-``, the sorting will be in descending order, otherwise
+            in ascending order.
+        :return type: a new :class:`Query` instance."""
        if ordering:
            ordering = self._meta.get_sorting(ordering, QuerySetError)
        q = self._clone()
-        q.data['ordering'] = ordering
+        q.data["ordering"] = ordering
        return q
 
    def search(self, text, lookup=None):
-        '''Search *text* in model. A search engine needs to be installed
-for this function to be available.
-
-:parameter text: a string to search.
-:return type: a new :class:`Query` instance.
-'''
+        """Search *text* in model. A search engine needs to be installed
+        for this function to be available.
+
+        :parameter text: a string to search.
+        :return type: a new :class:`Query` instance."""
        q = self._clone()
        q.text = (text, lookup)
        return q
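+
+    # Illustration, chaining the generative methods above (field names and
+    # search text are hypothetical)::
+    #
+    #     qs = session.query(MyModel).filter(group='planet')
+    #     qs = qs.sort_by('-name').search('mars')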
 
    def where(self, code, load_only=None):
-        '''For :ref:`backend ` supporting scripting, it is possible
-to construct complex queries which execute the scripting *code* against
-each element in the query. The *coe* should reference an instance of
-:attr:`model` by ``this`` keyword.
-
-:parameter code: a valid expression in the scripting language of the database.
-:parameter load_only: Load only the selected fields when performing the query
-    (this is different from the :meth:`load_only` method which is used when
-    fetching data from the database). This field is an optimization which is
-    used by the :ref:`redis backend ` only and can be safely
-    ignored in most use-cases.
-:return: a new :class:`Query`
-'''
+        """For :ref:`backend ` supporting scripting, it is possible
+        to construct complex queries which execute the scripting *code* against
+        each element in the query. The *code* should reference an instance of
+        :attr:`model` by the ``this`` keyword.
+
+        :parameter code: a valid expression in the scripting language of the database.
+        :parameter load_only: Load only the selected fields when performing the query
+            (this is different from the :meth:`load_only` method which is used when
+            fetching data from the database). This field is an optimization which is
+            used by the :ref:`redis backend ` only and can be safely
+            ignored in most use-cases.
+        :return: a new :class:`Query`"""
        if code:
            q = self._clone()
-            q.data['where'] = (code, load_only)
+            q.data["where"] = (code, load_only)
            return q
        else:
            return self
 
    def search_queries(self, q):
-        '''Return a new :class:`QueryElem` for *q* applying a text search.'''
+        """Return a new :class:`QueryElement` for *q* applying a text search."""
        if self.text:
            searchengine = self.session.router.search_engine
            if searchengine:
                return searchengine.search_model(q, *self.text)
            else:
-                raise QuerySetError('Search not available for %s' % self._meta)
+                raise QuerySetError("Search not available for %s" % self._meta)
        else:
            return q
 
    def load_related(self, related, *related_fields):
-        '''It returns a new :class:`Query` that automatically
-follows the foreign-key relationship ``related``.
-
-:parameter related: A field name corresponding to a :class:`ForeignKey`
-    in :attr:`Query.model`.
-:parameter related_fields: optional :class:`Field` names for the ``related``
-    model to load. If not provided, all fields will be loaded.
-
-This function is :ref:`performance boost ` when
-accessing the related fields of all (most) objects in your query.
-
-If Your model contains more than one foreign key, you can use this function
-in a generative way::
-
-    qs = myquery.load_related('rel1').load_related('rel2','field1','field2')
-
-:rtype: a new :class:`Query`.'''
+        """It returns a new :class:`Query` that automatically
+        follows the foreign-key relationship ``related``.
+
+        :parameter related: A field name corresponding to a :class:`ForeignKey`
+            in :attr:`Query.model`.
+        :parameter related_fields: optional :class:`Field` names for the ``related``
+            model to load. If not provided, all fields will be loaded.
+
+        This function is a :ref:`performance boost ` when
+        accessing the related fields of all (most) objects in your query.
+
+        If your model contains more than one foreign key, you can use this function
+        in a generative way::
+
+            qs = myquery.load_related('rel1').load_related('rel2','field1','field2')
+
+        :rtype: a new :class:`Query`."""
        field = self._get_related_field(related)
        if not field:
-            raise FieldError('"%s" is not a related field for "%s"' %
-                             (related, self._meta))
+            raise FieldError(
+                '"%s" is not a related field for "%s"' % (related, self._meta)
+            )
        q = self._clone()
        return q._add_to_load_related(field, *related_fields)
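+
+    # Illustration for :meth:`load_only` and :meth:`dont_load` below
+    # (field names are hypothetical)::
+    #
+    #     qs = session.query(MyModel).load_only('name', 'group')
+    #     qs = session.query(MyModel).dont_load('data')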
 
    def load_only(self, *fields):
-        '''This is provides a :ref:`performance boost `
-in cases when you need to load a subset of fields of your model. The boost
-achieved is less than the one obtained when using
-:meth:`Query.load_related`, since it does not reduce the number of requests
-to the database. However, it can save you lots of bandwidth when excluding
-data intensive fields you don't need.
-'''
+        """This provides a :ref:`performance boost `
+        in cases when you need to load a subset of fields of your model. The boost
+        achieved is less than the one obtained when using
+        :meth:`Query.load_related`, since it does not reduce the number of requests
+        to the database. However, it can save you lots of bandwidth when excluding
+        data intensive fields you don't need."""
        q = self._clone()
        new_fields = []
        for field in fields:
@@ -569,14 +583,13 @@ def load_only(self, *fields):
            # loaded.
            new_fields.append(self._meta.pkname())
        fs = unique_tuple(q.fields, new_fields)
-        q.data['fields'] = fs if fs else None
+        q.data["fields"] = fs if fs else None
        return q
 
    def dont_load(self, *fields):
-        '''Works like :meth:`load_only` to provides a
-:ref:`performance boost ` in cases when you need
-to load all fields except a subset specified by *fields*.
-'''
+        """Works like :meth:`load_only` to provide a
+        :ref:`performance boost ` in cases when you need
+        to load all fields except a subset specified by *fields*."""
        q = self._clone()
        fs = unique_tuple(q.exclude_fields, fields)
        q.exclude_fields = fs if fs else None
@@ -588,67 +601,64 @@ def __getitem__(self, slic):
        return self.backend_query()[slic]
 
    def items(self, callback=None):
-        '''Retrieve all items for this :class:`Query`.'''
+        """Retrieve all items for this :class:`Query`."""
        return self.backend_query().items(callback=callback)
 
    def get(self, **kwargs):
-        '''Return an instance of a model matching the query. A special case is
-the query on ``id`` which provides a direct access to the :attr:`session`
-instances. If the given primary key is present in the session, the object
-is returned directly without performing any query.'''
-        return self.filter(**kwargs).items(
-            callback=self.model.get_unique_instance)
+        """Return an instance of a model matching the query. A special case is
+        the query on ``id`` which provides a direct access to the :attr:`session`
+        instances. If the given primary key is present in the session, the object
+        is returned directly without performing any query."""
+        return self.filter(**kwargs).items(callback=self.model.get_unique_instance)
 
    def count(self):
-        '''Return the number of objects in ``self``.
-This method is efficient since the :class:`Query` does not
-receive any data from the server apart from the number of matched elements.
-It construct the queries and count the
-objects on the server side.'''
+        """Return the number of objects in ``self``.
+        This method is efficient since the :class:`Query` does not
+        receive any data from the server apart from the number of matched elements.
+        It constructs the queries and counts the
+        objects on the server side."""
        return self.backend_query().count()
 
    def delete(self):
-        '''Delete all matched elements of the :class:`Query`. It returns the
-list of ids deleted.'''
+        """Delete all matched elements of the :class:`Query`. It returns the
+        list of ids deleted."""
        return self.session.delete(self)
 
    def construct(self):
-        '''Build the :class:`QueryElement` representing this query.'''
+        """Build the :class:`QueryElement` representing this query."""
        if self.__construct is None:
            self.__construct = self._construct()
        return self.__construct
 
    def backend_query(self, **kwargs):
-        '''Build and return the :class:`stdnet.utils.async.BackendQuery`.
-This is a lazy method in the sense that it is evaluated once only and its
-result stored for future retrieval.'''
+        """Build and return the :class:`stdnet.utils.async.BackendQuery`.
+        This is a lazy method in the sense that it is evaluated once only and its
+        result stored for future retrieval."""
        q = self.construct()
        return q if isinstance(q, EmptyQuery) else q.backend_query(**kwargs)
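+
+    # Illustration of the retrieval methods above (model and field names
+    # are hypothetical)::
+    #
+    #     qs = session.query(MyModel).filter(group='planet')
+    #     n = qs.count()      # server-side count, no data transferred
+    #     ids = qs.delete()   # returns the ids of deleted instances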
- -:parameter fieldname: :class:`Field` name to test -:parameter vale: :class:`Field` value -:parameter instance: optional instance of :attr:`model` -:parameter exception: optional exception class to raise if the test fails. - Default: :attr:`ModelMixin.DoesNotValidate`. -:return: *value* -''' + """Test if a given field *fieldname* has a unique *value* + in :attr:`model`. The field must be an index of the model. + If the field value is not unique and the *instance* is not the same + an exception is raised. + + :parameter fieldname: :class:`Field` name to test + :parameter vale: :class:`Field` value + :parameter instance: optional instance of :attr:`model` + :parameter exception: optional exception class to raise if the test fails. + Default: :attr:`ModelMixin.DoesNotValidate`. + :return: *value*""" qs = self.filter(**{fieldname: value}) - callback = partial(self._test_unique, fieldname, value, - instance, exception) + callback = partial(self._test_unique, fieldname, value, instance, exception) return qs.backend_query().items(callback=callback) def map_reduce(self, map_script, reduce_script, **kwargs): - '''Perform a map/reduce operation on this query.''' + """Perform a map/reduce operation on this query.""" pass ######################################################################## - # PRIVATE METHODS + # PRIVATE METHODS ######################################################################## def clear(self): self.__construct = None @@ -679,24 +689,24 @@ def _construct(self): else: eargs = None if eargs: - q = difference([q]+eargs) + q = difference([q] + eargs) if self.intersections: - q = intersect((q,)+self.intersections) + q = intersect((q,) + self.intersections) if self.unions: - q = union((q,)+self.unions) + q = union((q,) + self.unions) q = self.search_queries(q) data = self.data.copy() if self.exclude_fields: - fields = data['fields'] + fields = data["fields"] if not fields: fields = tuple((f.name for f in self._meta.scalarfields)) fields = tuple((f for f in fields if f not in self.exclude_fields)) - data['fields'] = fields + data["fields"] = fields q.data = data return q def aggregate(self, kwargs): - '''Aggregate lookup parameters.''' + """Aggregate lookup parameters.""" meta = self._meta fields = meta.dfields field_lookups = {} @@ -704,42 +714,49 @@ def aggregate(self, kwargs): bits = name.split(JSPLITTER) field_name = bits.pop(0) if field_name not in fields: - raise QuerySetError('Could not filter on model "{0}".\ - Field "{1}" does not exist.'.format(meta, field_name)) + raise QuerySetError( + 'Could not filter on model "{0}".\ + Field "{1}" does not exist.'.format( + meta, field_name + ) + ) field = fields[field_name] attname = field.attname lookup = None if bits: bits = [n.lower() for n in bits] - if bits[-1] == 'in': + if bits[-1] == "in": bits.pop() elif bits[-1] in range_lookups: lookup = bits.pop() remaining = JSPLITTER.join(bits) if lookup: # this is a range lookup - attname, nested = field.get_lookup(remaining, - QuerySetError) + attname, nested = field.get_lookup(remaining, QuerySetError) lookups = get_lookups(attname, field_lookups) lookups.append(lookup_value(lookup, (value, nested))) continue - elif remaining: # Not a range lookup, must be a nested filter + elif remaining: # Not a range lookup, must be a nested filter value = field.filter(self.session, remaining, value) lookups = get_lookups(attname, field_lookups) # If we are here the field must be an index if not field.index: - raise QuerySetError("%s %s is not an index. Cannot query." 
% - (field.__class__.__name__, field_name)) + raise QuerySetError( + "%s %s is not an index. Cannot query." + % (field.__class__.__name__, field_name) + ) if not iterable(value): value = (value,) for v in value: if isinstance(v, Q): - v = lookup_value('set', v.construct()) + v = lookup_value("set", v.construct()) else: - v = lookup_value('value', field.serialise(v, lookup)) + v = lookup_value("value", field.serialise(v, lookup)) lookups.append(v) # - return [queryset(self, name=name, underlying=field_lookups[name]) - for name in sorted(field_lookups)] + return [ + queryset(self, name=name, underlying=field_lookups[name]) + for name in sorted(field_lookups) + ] def _test_unique(self, fieldname, value, instance, exception, items): if items: @@ -748,8 +765,9 @@ def _test_unique(self, fieldname, value, instance, exception, items): return value else: exception = exception or self.model.DoesNotValidate - raise exception('An instance with %s %s is already available' - % (fieldname, value)) + raise exception( + "An instance with %s %s is already available" % (fieldname, value) + ) else: return value @@ -757,7 +775,7 @@ def _get_related_field(self, related): meta = self._meta if related in meta.dfields: field = meta.dfields[related] - if hasattr(field, 'relmodel'): + if hasattr(field, "relmodel"): return field def _add_to_load_related(self, field, *related_fields): @@ -767,7 +785,7 @@ def _add_to_load_related(self, field, *related_fields): d = dict(((k, tuple(v)) for k, v in self.select_related.items())) else: d = {} - self.data['select_related'] = d + self.data["select_related"] = d if field.name in d: d[field.name] = unique_tuple(d[field.name], rf) else: diff --git a/stdnet/odm/related.py b/stdnet/odm/related.py index 51e8cbb..2d88a01 100644 --- a/stdnet/odm/related.py +++ b/stdnet/odm/related.py @@ -1,15 +1,15 @@ from functools import partial +from stdnet import ManyToManyError, QuerySetError from stdnet.utils import encoders -from stdnet import QuerySetError, ManyToManyError from .globals import Event -from .session import Manager, LazyProxy +from .session import LazyProxy, Manager -__all__ = ['LazyForeignKey', 'ModelFieldPickler'] +__all__ = ["LazyForeignKey", "ModelFieldPickler"] -RECURSIVE_RELATIONSHIP_CONSTANT = 'self' +RECURSIVE_RELATIONSHIP_CONSTANT = "self" pending_lookups = {} @@ -17,7 +17,8 @@ class ModelFieldPickler(encoders.Encoder): - '''An encoder for :class:`StdModel` instances.''' + """An encoder for :class:`StdModel` instances.""" + def __init__(self, model): self.model = model @@ -65,7 +66,7 @@ def load_relmodel(field, callback): def do_pending_lookups(event, sender, **kwargs): """Handle any pending relations to the sending model. -Sent from class_prepared.""" + Sent from class_prepared.""" key = (sender._meta.app_label, sender._meta.name) for callback in pending_lookups.pop(key, []): callback(sender) @@ -75,59 +76,62 @@ def do_pending_lookups(event, sender, **kwargs): def Many2ManyThroughModel(field): - '''Create a Many2Many through model with two foreign key fields and a -CompositeFieldId depending on the two foreign keys.''' - from stdnet.odm import ModelType, StdModel, ForeignKey, CompositeIdField + """Create a Many2Many through model with two foreign key fields and a + CompositeFieldId depending on the two foreign keys.""" + from stdnet.odm import CompositeIdField, ForeignKey, ModelType, StdModel + name_model = field.model._meta.name name_relmodel = field.relmodel._meta.name # The two models are the same. 
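
# A sketch of the lookup grammar parsed by ``Query.aggregate`` above: a
# double underscore separates a field name from a range lookup (``gt``,
# ``ge``, ``lt``, ``le``) or from ``in``. It reuses the hypothetical models
# from the earlier sketch; only indexed fields can be filtered, otherwise
# a QuerySetError is raised.
query = models[Position].query()
big = query.filter(size__gt=100)           # range lookup on an indexed field
few = query.filter(instrument__in=(1, 2))  # membership lookup on a foreign key
count = big.count()                        # counted server side, no data transfer
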
if name_model == name_relmodel: - name_relmodel += '2' + name_relmodel += "2" through = field.through # Create the through model if through is None: - name = '{0}_{1}'.format(name_model, name_relmodel) + name = "{0}_{1}".format(name_model, name_relmodel) class Meta: app_label = field.model._meta.app_label - through = ModelType(name, (StdModel,), {'Meta': Meta}) + + through = ModelType(name, (StdModel,), {"Meta": Meta}) field.through = through # The first field - field1 = ForeignKey(field.model, - related_name=field.name, - related_manager_class=makeMany2ManyRelatedManager( - field.relmodel, - name_model, - name_relmodel) - ) + field1 = ForeignKey( + field.model, + related_name=field.name, + related_manager_class=makeMany2ManyRelatedManager( + field.relmodel, name_model, name_relmodel + ), + ) field1.register_with_model(name_model, through) # The second field - field2 = ForeignKey(field.relmodel, - related_name=field.related_name, - related_manager_class=makeMany2ManyRelatedManager( - field.model, - name_relmodel, - name_model) - ) + field2 = ForeignKey( + field.relmodel, + related_name=field.related_name, + related_manager_class=makeMany2ManyRelatedManager( + field.model, name_relmodel, name_model + ), + ) field2.register_with_model(name_relmodel, through) pk = CompositeIdField(name_model, name_relmodel) - pk.register_with_model('id', through) + pk.register_with_model("id", through) class LazyForeignKey(LazyProxy): - '''Descriptor for a :class:`ForeignKey` field.''' + """Descriptor for a :class:`ForeignKey` field.""" + def load(self, instance, session=None, backend=None): return instance._load_related_model(self.field) def __set__(self, instance, value): if instance is None: - raise AttributeError("%s must be accessed via instance" % - self._field.name) + raise AttributeError("%s must be accessed via instance" % self._field.name) field = self.field if value is not None and not isinstance(value, field.relmodel): raise ValueError( - 'Cannot assign "%r": "%s" must be a "%s" instance.' % - (value, field, field.relmodel._meta.name)) + 'Cannot assign "%r": "%s" must be a "%s" instance.' + % (value, field, field.relmodel._meta.name) + ) cache_name = self.field.get_cache_name() # If we're setting the value of a OneToOneField to None, @@ -151,19 +155,19 @@ def __set__(self, instance, value): class RelatedManager(Manager): - '''Base class for managers handling relationships between models. -While standard :class:`Manager` are class properties of a model, -related managers are accessed by instances to easily retrieve instances -of a related model. + """Base class for managers handling relationships between models. + While standard :class:`Manager` are class properties of a model, + related managers are accessed by instances to easily retrieve instances + of a related model. -.. attribute:: relmodel + .. attribute:: relmodel - The :class:`StdModel` this related manager relates to. + The :class:`StdModel` this related manager relates to. -.. attribute:: related_instance + .. attribute:: related_instance + + An instance of the :attr:`relmodel`.""" - An instance of the :attr:`relmodel`. 
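
# Sketch of how the related managers described here are reached from
# instances, using the hypothetical models from the first sketch. A
# ForeignKey puts a One2ManyRelatedManager on the target model under
# ``related_name``; a ManyToManyField (e.g. a hypothetical
# ``Group.members``) exposes add/remove on persistent instances only.
instrument = models[Instrument].get(id=1)
open_positions = instrument.positions.query()  # reverse of Position.instrument
# group.members.add(instrument)     # M2M add: both pks must already exist
# group.members.remove(instrument)
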
-''' def __init__(self, field, model=None, instance=None): self.field = field model = model or field.model @@ -174,25 +178,28 @@ def __get__(self, instance, instance_type=None): return self.__class__(self.field, self.model, instance) def session(self, session=None): - '''Override :meth:`Manager.session` so that this + """Override :meth:`Manager.session` so that this :class:`RelatedManager` can retrieve the session from the :attr:`related_instance` if available. - ''' + """ if self.related_instance: session = self.related_instance.session # we have a session, we either create a new one return the same session if session is None: - raise QuerySetError('Related manager can be accessed only from\ - a loaded instance of its related model.') + raise QuerySetError( + "Related manager can be accessed only from\ + a loaded instance of its related model." + ) return session class One2ManyRelatedManager(RelatedManager): - '''A specialised :class:`RelatedManager` for handling one-to-many -relationships under the hood. -If a model has a :class:`ForeignKey` field, instances of -that model will have access to the related (foreign) objects -via a simple attribute of the model.''' + """A specialised :class:`RelatedManager` for handling one-to-many + relationships under the hood. + If a model has a :class:`ForeignKey` field, instances of + that model will have access to the related (foreign) objects + via a simple attribute of the model.""" + @property def relmodel(self): return self.field.relmodel @@ -213,45 +220,48 @@ def query_from_query(self, query, params=None): class Many2ManyRelatedManager(One2ManyRelatedManager): - '''A specialized :class:`Manager` for handling -many-to-many relationships under the hood. -When a model has a :class:`ManyToManyField`, instances -of that model will have access to the related objects via a simple -attribute of the model.''' + """A specialized :class:`Manager` for handling + many-to-many relationships under the hood. + When a model has a :class:`ManyToManyField`, instances + of that model will have access to the related objects via a simple + attribute of the model.""" + def session_instance(self, name, value, session, **kwargs): if self.related_instance is None: raise ManyToManyError('Cannot use "%s" method from class' % name) elif not self.related_instance.pkvalue(): - raise ManyToManyError('Cannot use "%s" method on a non persistent ' - 'instance.' % name) + raise ManyToManyError( + 'Cannot use "%s" method on a non persistent ' "instance." % name + ) elif not isinstance(value, self.formodel): raise ManyToManyError( - '%s is not an instance of %s' % (value, self.formodel._meta)) + "%s is not an instance of %s" % (value, self.formodel._meta) + ) elif not value.pkvalue(): - raise ManyToManyError('Cannot use "%s" a non persistent instance.' - % name) - kwargs.update({self.name_formodel: value, - self.name_relmodel: self.related_instance}) + raise ManyToManyError('Cannot use "%s" a non persistent instance.' % name) + kwargs.update( + {self.name_formodel: value, self.name_relmodel: self.related_instance} + ) return self.session(session), self.model(**kwargs) def add(self, value, session=None, **kwargs): - '''Add ``value``, an instance of :attr:`formodel` to the -:attr:`through` model. This method can only be accessed by an instance of the -model for which this related manager is an attribute.''' - s, instance = self.session_instance('add', value, session, **kwargs) + """Add ``value``, an instance of :attr:`formodel` to the + :attr:`through` model. 
This method can only be accessed by an instance of the + model for which this related manager is an attribute.""" + s, instance = self.session_instance("add", value, session, **kwargs) return s.add(instance) def remove(self, value, session=None): - '''Remove *value*, an instance of ``self.model`` from the set of -elements contained by the field.''' - s, instance = self.session_instance('remove', value, session) + """Remove *value*, an instance of ``self.model`` from the set of + elements contained by the field.""" + s, instance = self.session_instance("remove", value, session) # update state so that the instance does look persistent - instance.get_state(iid=instance.pkvalue(), action='update') + instance.get_state(iid=instance.pkvalue(), action="update") return s.delete(instance) def throughquery(self, session=None): - '''Return a :class:`Query` on the ``throughmodel``, the model -used to hold the :ref:`many-to-many relationship `.''' + """Return a :class:`Query` on the ``throughmodel``, the model + used to hold the :ref:`many-to-many relationship `.""" return super(Many2ManyRelatedManager, self).query(session) def query(self, session=None): @@ -263,7 +273,7 @@ def query(self, session=None): def makeMany2ManyRelatedManager(formodel, name_relmodel, name_formodel): - '''formodel is the model which the manager .''' + """formodel is the model which the manager .""" class _Many2ManyRelatedManager(Many2ManyRelatedManager): pass diff --git a/stdnet/odm/search.py b/stdnet/odm/search.py index 6db9544..2894793 100755 --- a/stdnet/odm/search.py +++ b/stdnet/odm/search.py @@ -1,58 +1,56 @@ import logging +from inspect import isclass, isgenerator -from inspect import isgenerator, isclass +__all__ = ["SearchEngine"] -__all__ = ['SearchEngine'] - - -LOGGER = logging.getLogger('stdnet.search') +LOGGER = logging.getLogger("stdnet.search") class SearchEngine(object): """Stdnet search engine driver. This is an abstract class which -expose the base functionalities for full text-search on model instances. -Stdnet also provides a :ref:`python implementation ` -of this interface. + expose the base functionalities for full text-search on model instances. + Stdnet also provides a :ref:`python implementation ` + of this interface. + + The main methods to be implemented are :meth:`add_item`, + :meth:`remove_index` and :meth:`search_model`. -The main methods to be implemented are :meth:`add_item`, -:meth:`remove_index` and :meth:`search_model`. + .. attribute:: word_middleware -.. attribute:: word_middleware + A list of middleware functions for preprocessing text + to be indexed. A middleware function has arity 1 by + accepting an iterable of words and + returning an iterable of words. Word middleware functions + are added to the search engine via the + :meth:`add_word_middleware` method. - A list of middleware functions for preprocessing text - to be indexed. A middleware function has arity 1 by - accepting an iterable of words and - returning an iterable of words. Word middleware functions - are added to the search engine via the - :meth:`add_word_middleware` method. 
+        For example this function removes a group of words from the index::

-    For example this function remove a group of words from the index::
+            se = SearchEngine()

-        se = SearchEngine()
+            class stopwords(object):

-        class stopwords(object):
+                def __init__(self, *swords):
+                    self.swords = set(swords)

-            def __init__(self, *swords):
-                self.swords = set(swords)
+                def __call__(self, words):
+                    for word in words:
+                        if word not in self.swords:
+                            yield word

-            def __call__(self, words):
-                for word in words:
-                    if word not in self.swords:
-                        yield word
+            se.add_word_middleware(stopwords('and','or','this','that',...))

-        se.add_word_middleware(stopwords('and','or','this','that',...))
+    .. attribute:: max_in_session

-.. attribute:: max_in_session
+        Maximum number of instances to be reindexed in one session.
+        Default ``1000``."""

-    Maximum number of instances to be reindexed in one session.
-    Default ``1000``.
-"""

     def __init__(self, backend=None, logger=None, max_in_session=None):
         self._backend = backend
         self.REGISTERED_MODELS = {}
         self.ITEM_PROCESSORS = []
-        self.last_indexed = 'last_indexed'
+        self.last_indexed = "last_indexed"
         self.word_middleware = []
         self.add_processor(stdnet_processor(self))
         self.logger = logger or LOGGER
@@ -61,17 +59,16 @@ def __init__(self, backend=None, logger=None, max_in_session=None):

     @property
     def backend(self):
-        '''Backend for this search engine.'''
+        """Backend for this search engine."""
         return self._backend

     def register(self, model, related=None):
-        '''Register a :class:`StdModel` with this search :class:`SearchEngine`.
-When registering a model, every time an instance is created, it will be
-indexed by the search engine.
+        """Register a :class:`StdModel` with this search :class:`SearchEngine`.
+        When registering a model, every time an instance is created, it will be
+        indexed by the search engine.

-:param model: a :class:`StdModel` class.
-:param related: a list of related fields to include in the index.
-'''
+        :param model: a :class:`StdModel` class.
+        :param related: a list of related fields to include in the index."""
         update_model = UpdateSE(self, related)
         self.REGISTERED_MODELS[model] = update_model
         self.router.post_commit.bind(update_model, model)
@@ -84,22 +81,21 @@ def get_related_fields(self, item):
         return registered.related if registered else ()

     def words_from_text(self, text, for_search=False):
-        '''Generator of indexable words in *text*.
-This functions loop through the :attr:`word_middleware` attribute
-to process the text.
-
-:param text: string from which to extract words.
-:param for_search: flag indicating if the the words will be used for search
-    or to index the database. This flug is used in conjunction with the
-    middleware flag *for_search*. If this flag is ``True`` (i.e. we need to
-    search the database for the words in *text*), only the
-    middleware functions in :attr:`word_middleware` enabled for searching are
-    used.
-
-    Default: ``False``.
-
-return a *list* of cleaned words.
-'''
+        """Generator of indexable words in *text*.
+        This function loops through the :attr:`word_middleware` attribute
+        to process the text.
+
+        :param text: string from which to extract words.
+        :param for_search: flag indicating if the words will be used for search
+            or to index the database. This flag is used in conjunction with the
+            middleware flag *for_search*. If this flag is ``True`` (i.e. we need to
+            search the database for the words in *text*), only the
+            middleware functions in :attr:`word_middleware` enabled for searching are
+            used.
+
+            Default: ``False``.
+
+        return a *list* of cleaned words."""
         if not text:
             return []
         word_gen = self.split_text(text)
@@ -112,8 +108,8 @@ def words_from_text(self, text, for_search=False):
         return word_gen

     def split_text(self, text):
-        '''Split text into words and return an iterable over them.
-Can and should be reimplemented by subclasses.'''
+        """Split text into words and return an iterable over them.
+        Can and should be reimplemented by subclasses."""
         return text.split()

     def add_processor(self, processor):
@@ -121,37 +117,33 @@ def add_processor(self, processor):
         self.ITEM_PROCESSORS.append(processor)

     def add_word_middleware(self, middleware, for_search=True):
-        '''Add a *middleware* function to the list of :attr:`word_middleware`,
-for preprocessing words to be indexed.
-
-:param middleware: a callable receving an iterable over words.
-:param for_search: flag indicating if the *middleware* can be used for the
-    text to search. Default: ``True``.
-'''
-        if hasattr(middleware, '__call__'):
+        """Add a *middleware* function to the list of :attr:`word_middleware`,
+        for preprocessing words to be indexed.
+
+        :param middleware: a callable receiving an iterable over words.
+        :param for_search: flag indicating if the *middleware* can be used for the
+            text to search. Default: ``True``."""
+        if hasattr(middleware, "__call__"):
             self.word_middleware.append((middleware, for_search))

     def index_item(self, item):
         """This is the main function for indexing items.
-It extracts content from the given *item* and add it to the index.
+        It extracts content from the given *item* and adds it to the index.

-:param item: an instance of a :class:`stdnet.odm.StdModel`.
-"""
+        :param item: an instance of a :class:`stdnet.odm.StdModel`."""
         self.index_items_from_model((item,), item.__class__)

     def query(self, model):
-        '''Return a query for ``model`` when it needs to be indexed.
-        '''
+        """Return a query for ``model`` when it needs to be indexed."""
         session = self.router.session()
-        fields = tuple((f.name for f in model._meta.scalarfields
-                        if f.type == 'text'))
+        fields = tuple((f.name for f in model._meta.scalarfields if f.type == "text"))
         qs = session.query(model).load_only(*fields)
         for related in self.get_related_fields(model):
             qs = qs.load_related(related)
-        return qs 
+        return qs

     def session(self):
-        '''Create a session for the search engine'''
+        """Create a session for the search engine"""
         return self.router.session()

     # INTERNALS
@@ -170,7 +162,7 @@ def item_field_iterator(self, item):
     def _item_data(self, items):
         fi = self.item_field_iterator
         for item in items:
-            if item is None:    # stop if we get a None
+            if item is None:  # stop if we get a None
                 break
             data = fi(item)
             if data:
@@ -190,33 +182,32 @@ def index_items_from_model(self, items, model):
         raise NotImplementedError

     def remove_item(self, item_or_model, session, ids=None):
-        '''Remove an item from the search indices'''
+        """Remove an item from the search indices"""
        raise NotImplementedError

     def add_item(self, item, words, transaction):
-        '''Create indices for *item* and each word in *words*. Must be
-implemented by subclasses.
-
-:param item: a *model* instance to be indexed. It does not need to be
-    a :class:`stdnet.odm.StdModel`.
-:param words: iterable over words. It has been obtained from the
-    text in *item* via the :attr:`word_middleware`.
-:param transaction: The :class:`Transaction` used.
-'''
+        """Create indices for *item* and each word in *words*. Must be
+        implemented by subclasses.
+
+        :param item: a *model* instance to be indexed. It does not need to be
+            a :class:`stdnet.odm.StdModel`.
+        :param words: iterable over words. It has been obtained from the
+            text in *item* via the :attr:`word_middleware`.
+        :param transaction: The :class:`Transaction` used."""
         raise NotImplementedError

     def search(self, text, include=None, exclude=None, lookup=None):
-        '''Full text search. Must be implemented by subclasses.
+        """Full text search. Must be implemented by subclasses.

-:param test: text to search
-:param include: optional list of models to include in the search. If not
-    provided all :attr:`REGISTERED_MODELS` will be used.
-:param exclude: optional list of models to exclude for the search.
-:param lookup: currently not used.'''
+        :param text: text to search
+        :param include: optional list of models to include in the search. If not
+            provided all :attr:`REGISTERED_MODELS` will be used.
+        :param exclude: optional list of models to exclude for the search.
+        :param lookup: currently not used."""
         raise NotImplementedError

     def search_model(self, query, text, lookup=None):
-        '''Search *text* in *model* instances.
+        """Search *text* in *model* instances.

         This is the functions needing implementation by custom search engines.

@@ -224,58 +215,61 @@ def search_model(self, query, text, lookup=None):
         :param text: text to search
         :param lookup: Optional lookup, one of ``contains`` or ``in``.
         :return: An updated :class:`Query`.
-        '''
+        """
         raise NotImplementedError

     def flush(self, full=False):
-        '''Clean the search engine'''
+        """Clean the search engine"""
         raise NotImplementedError

     def reindex(self):
-        '''Re-index models.
+        """Re-index models.

         Remove existing indices indexes and rebuilding them by iterating
         through all the instances of :attr:`REGISTERED_MODELS`.
-        '''
+        """
         raise NotImplementedError


 class UpdateSE(object):
-
     def __init__(self, se, related=None):
         self.se = se
         self.related = related or ()

     def __call__(self, signal, sender, instances=None, **kwargs):
-        '''An update on ``instances`` has occurred.
+        """An update on ``instances`` has occurred.

         Propagate it to the search engine index models.
-        '''
+        """
         if sender:
             # get a new session
             models = self.se.router
             se_session = models.session()
             if signal == models.post_delete:
                 return self.remove(instances, sender, se_session)
-            else: 
+            else:
                 return self.index(instances, sender, se_session)

     def index(self, instances, sender, session):
         return self.se.index_items_from_model(instances, sender)

     def remove(self, instances, sender, session):
-        self.se.logger.debug('Removing from search index %s instances of %s',
-                             len(instances), sender._meta)
+        self.se.logger.debug(
+            "Removing from search index %s instances of %s",
+            len(instances),
+            sender._meta,
+        )
         remove_item = self.se.remove_item
-        with session.begin(name='Remove search indexes') as t:
+        with session.begin(name="Remove search indexes") as t:
             remove_item(sender, t, instances)
         return t.on_result


 class stdnet_processor(object):
-    '''A search engine processor for stdnet models.
-An engine processor is a callable
-which return an iterable over text.'''
+    """A search engine processor for stdnet models.
+ An engine processor is a callable + which return an iterable over text.""" + def __init__(self, se): self.se = se @@ -285,7 +279,7 @@ def __call__(self, item): for field in item._meta.fields: if field.hidden: continue - if field.type == 'text': + if field.type == "text": if hasattr(item, field.attname): data.append(getattr(item, field.attname)) else: diff --git a/stdnet/odm/session.py b/stdnet/odm/session.py index 378ecd4..9a2c1d2 100644 --- a/stdnet/odm/session.py +++ b/stdnet/odm/session.py @@ -1,19 +1,20 @@ from itertools import chain - -from stdnet import session_result, session_data, async -from stdnet.utils import itervalues, iteritems -from stdnet.utils.structures import OrderedDict -from stdnet.utils.exceptions import * -from .query import Q, Query, EmptyQuery +from stdnet import async, session_data, session_result +from stdnet.utils import iteritems, itervalues +from stdnet.utils.exceptions import * +from stdnet.utils.structures import OrderedDict +from .query import EmptyQuery, Q, Query -__all__ = ['Session', - 'SessionModel', - 'Manager', - 'LazyProxy', - 'Transaction', - 'ModelDictionary'] +__all__ = [ + "Session", + "SessionModel", + "Manager", + "LazyProxy", + "Transaction", + "ModelDictionary", +] def is_query(query): @@ -21,7 +22,6 @@ def is_query(query): class ModelDictionary(dict): - def __contains__(self, model): return super(ModelDictionary, self).__contains__(self.meta(model)) @@ -38,12 +38,13 @@ def pop(self, model, *args): return super(ModelDictionary, self).pop(self.meta(model), *args) def meta(self, model): - return getattr(model, '_meta', model) + return getattr(model, "_meta", model) class SessionModel(object): - '''A :class:`SessionModel` is the container of all objects for a given -:class:`Model` in a stdnet :class:`Session`.''' + """A :class:`SessionModel` is the container of all objects for a given + :class:`Model` in a stdnet :class:`Session`.""" + def __init__(self, manager): self.manager = manager self._new = OrderedDict() @@ -54,83 +55,90 @@ def __init__(self, manager): self._structures = set() def __len__(self): - return (len(self._new) + len(self._modified) + len(self._deleted) + - len(self.commit_structures) + len(self.delete_structures)) + return ( + len(self._new) + + len(self._modified) + + len(self._deleted) + + len(self.commit_structures) + + len(self.delete_structures) + ) def __repr__(self): return self._meta.__repr__() + __str__ = __repr__ @property def backend(self): - '''The backend for this :class:`SessionModel`.''' + """The backend for this :class:`SessionModel`.""" return self.manager.backend @property def read_backend(self): - '''The read-only backend for this :class:`SessionModel`.''' + """The read-only backend for this :class:`SessionModel`.""" return self.manager.read_backend @property def model(self): - '''The :class:`Model` for this :class:`SessionModel`.''' + """The :class:`Model` for this :class:`SessionModel`.""" return self.manager.model @property def _meta(self): - '''The :class:`Metaclass` for this :class:`SessionModel`.''' + """The :class:`Metaclass` for this :class:`SessionModel`.""" return self.manager._meta @property def new(self): - '''The set of all new instances within this ``SessionModel``. This -instances will be inserted in the database.''' + """The set of all new instances within this ``SessionModel``. This + instances will be inserted in the database.""" return tuple(itervalues(self._new)) @property def modified(self): - '''The set of all modified instances within this ``Session``. 
This -instances will.''' + """The set of all modified instances within this ``Session``. This + instances will.""" return tuple(itervalues(self._modified)) @property def deleted(self): - '''The set of all instance pks marked as `deleted` within this -:class:`Session`.''' + """The set of all instance pks marked as `deleted` within this + :class:`Session`.""" return tuple((p.pkvalue() for p in itervalues(self._deleted))) @property def dirty(self): - '''The set of all instances which have changed, but not deleted, -within this :class:`SessionModel`.''' + """The set of all instances which have changed, but not deleted, + within this :class:`SessionModel`.""" return tuple(self.iterdirty()) def iterdirty(self): - '''Ordered iterator over dirty elements.''' + """Ordered iterator over dirty elements.""" return iter(chain(itervalues(self._new), itervalues(self._modified))) def __contains__(self, instance): iid = instance.get_state().iid - return (iid in self._new or - iid in self._modified or - iid in self._deleted or - instance in self._structures) - - def add(self, instance, modified=True, persistent=None, - force_update=False): - '''Add a new instance to this :class:`SessionModel`. - -:param modified: Optional flag indicating if the ``instance`` has been - modified. By default its value is ``True``. -:param force_update: if ``instance`` is persistent, it forces an update of the - data rather than a full replacement. This is used by the - :meth:`insert_update_replace` method. -:rtype: The instance added to the session''' - if instance._meta.type == 'structure': + return ( + iid in self._new + or iid in self._modified + or iid in self._deleted + or instance in self._structures + ) + + def add(self, instance, modified=True, persistent=None, force_update=False): + """Add a new instance to this :class:`SessionModel`. + + :param modified: Optional flag indicating if the ``instance`` has been + modified. By default its value is ``True``. + :param force_update: if ``instance`` is persistent, it forces an update of the + data rather than a full replacement. This is used by the + :meth:`insert_update_replace` method. + :rtype: The instance added to the session""" + if instance._meta.type == "structure": return self._add_structure(instance) state = instance.get_state() if state.deleted: - raise ValueError('State is deleted. Cannot add.') + raise ValueError("State is deleted. Cannot add.") self.pop(state.iid) pers = persistent if persistent is not None else state.persistent pkname = instance._meta.pkname() @@ -141,7 +149,7 @@ def add(self, instance, modified=True, persistent=None, instance._dbdata[pkname] = instance.pkvalue() state = instance.get_state(iid=instance.pkvalue()) else: - action = 'update' if force_update else None + action = "update" if force_update else None state = instance.get_state(action=action, iid=state.iid) iid = state.iid if state.persistent: @@ -152,8 +160,8 @@ def add(self, instance, modified=True, persistent=None, return instance def delete(self, instance, session): - '''delete an *instance*''' - if instance._meta.type == 'structure': + """delete an *instance*""" + if instance._meta.type == "structure": return self._delete_structure(instance) inst = self.pop(instance) instance = inst if inst is not None else instance @@ -168,13 +176,12 @@ def delete(self, instance, session): return instance def pop(self, instance): - '''Remove ``instance`` from the :class:`SessionModel`. Instance -could be a :class:`Model` or an id. + """Remove ``instance`` from the :class:`SessionModel`. 
Instance + could be a :class:`Model` or an id. -:parameter instance: a :class:`Model` or an ``id``. -:rtype: the :class:`Model` removed from session or ``None`` if - it was not in the session. -''' + :parameter instance: a :class:`Model` or an ``id``. + :rtype: the :class:`Model` removed from session or ``None`` if + it was not in the session.""" if isinstance(instance, self.model): iid = instance.get_state().iid else: @@ -186,29 +193,28 @@ def pop(self, instance): if instance is None: instance = inst elif inst is not instance: - raise ValueError('Critical error: %s is duplicated' % iid) + raise ValueError("Critical error: %s is duplicated" % iid) return instance def expunge(self, instance): - '''Remove *instance* from the :class:`Session`. Instance could be a -:class:`Model` or an id. + """Remove *instance* from the :class:`Session`. Instance could be a + :class:`Model` or an id. -:parameter instance: a :class:`Model` or an *id* -:rtype: the :class:`Model` removed from session or ``None`` if - it was not in the session. -''' + :parameter instance: a :class:`Model` or an *id* + :rtype: the :class:`Model` removed from session or ``None`` if + it was not in the session.""" instance = self.pop(instance) instance.session = None return instance def post_commit(self, results): - '''\ + """\ Process results after a commit. :parameter results: iterator over :class:`stdnet.instance_session_result` items. :rtype: a two elements tuple containing a list of instances saved and - a list of ids of instances deleted.''' + a list of ids of instances deleted.""" tpy = self._meta.pk_to_python instances = [] deleted = [] @@ -217,8 +223,11 @@ def post_commit(self, results): # all committed instances for result in results: if isinstance(result, Exception): - errors.append(result.__class__('Exception while committing %s.' - ' %s' % (self._meta, result))) + errors.append( + result.__class__( + "Exception while committing %s." " %s" % (self._meta, result) + ) + ) continue instance = self.pop(result.iid) id = tpy(result.id, self.backend) @@ -226,30 +235,34 @@ def post_commit(self, results): deleted.append(id) else: if instance is None: - raise InvalidTransaction('{0} session received id "{1}"\ - which is not in the session.'.format(self, result.iid)) + raise InvalidTransaction( + '{0} session received id "{1}"\ + which is not in the session.'.format( + self, result.iid + ) + ) setattr(instance, instance._meta.pkname(), id) - instance = self.add(instance, - modified=False, - persistent=result.persistent) + instance = self.add( + instance, modified=False, persistent=result.persistent + ) instance.get_state().score = result.score if instance.get_state().persistent: instances.append(instance) return instances, deleted, errors def flush(self): - '''Completely flush :attr:`model` from the database. No keys -associated with the model will exists after this operation.''' + """Completely flush :attr:`model` from the database. No keys + associated with the model will exists after this operation.""" return self.backend.flush(self._meta) def clean(self): - '''Remove empty keys for a :attr:`model` from the database. No -empty keys associated with the model will exists after this operation.''' + """Remove empty keys for a :attr:`model` from the database. No + empty keys associated with the model will exists after this operation.""" return self.backend.clean(self._meta) def keys(self): - '''Retrieve all keys for a :attr:`model`. Uses the -:attr:`Manager.read_backend`.''' + """Retrieve all keys for a :attr:`model`. 
Uses the + :attr:`Manager.read_backend`.""" return self.read_backend.model_keys(self._meta) ## INTERNALS @@ -281,34 +294,30 @@ def backends_data(self, session): queries = self._queries if dirty or has_delete or queries is not None or structures: if transaction.signal_delete and has_delete: - models.pre_delete.fire(model, instances=deletes, - session=session) + models.pre_delete.fire(model, instances=deletes, session=session) if dirty and transaction.signal_commit: - models.pre_commit.fire(model, instances=dirty, - session=session) + models.pre_commit.fire(model, instances=dirty, session=session) if be == rbe: - yield be, session_data(meta, dirty, deletes, queries, - structures) + yield be, session_data(meta, dirty, deletes, queries, structures) else: if dirty or has_delete or structures: - yield be, session_data(meta, dirty, deletes, (), - structures) + yield be, session_data(meta, dirty, deletes, (), structures) if queries: yield rbe, session_data(meta, (), (), queries, ()) def _add_structure(self, instance): - instance.action = 'update' + instance.action = "update" self._structures.add(instance) return instance def _delete_structure(self, instance): - instance.action = 'delete' + instance.action = "delete" self._structures.add(instance) return instance class Transaction(object): - '''Transaction class for pipelining commands to the backend server. + """Transaction class for pipelining commands to the backend server. An instance of this class is usually obtained via the :meth:`Session.begin` or the :meth:`Manager.transaction` methods:: @@ -363,12 +372,12 @@ class Transaction(object): Dictionary of list of ids saved in the backend server after a commit operation. This dictionary is only available once the transaction has :attr:`finished`. - ''' + """ + on_result = None - def __init__(self, session, name=None, signal_commit=True, - signal_delete=True): - self.name = name or 'transaction' + def __init__(self, session, name=None, signal_commit=True, signal_delete=True): + self.name = name or "transaction" self.session = session self.signal_commit = signal_commit self.signal_delete = signal_delete @@ -377,32 +386,32 @@ def __init__(self, session, name=None, signal_commit=True, @property def executed(self): - '''``True`` when this transaction has been executed. + """``True`` when this transaction has been executed. A transaction can be executed once only via the :meth:`commit` method. An executed transaction if :attr:`finished` once a response from the backend server has been processed. 
- ''' + """ return self.session is None def add(self, instance, **kwargs): - '''A convenience proxy for :meth:`Session.add` method.''' + """A convenience proxy for :meth:`Session.add` method.""" return self.session.add(instance, **kwargs) def delete(self, instance): - '''A convenience proxy for :meth:`Session.delete` method.''' + """A convenience proxy for :meth:`Session.delete` method.""" return self.session.delete(instance) def expunge(self, instance=None): - '''A convenience proxy for :meth:`Session.expunge` method.''' + """A convenience proxy for :meth:`Session.expunge` method.""" return self.session.expunge(instance) def query(self, model, **kwargs): - '''A convenience proxy for :meth:`Session.query` method.''' + """A convenience proxy for :meth:`Session.query` method.""" return self.session.query(model, **kwargs) def model(self, model): - '''A convenience proxy for :meth:`Session.model` method.''' + """A convenience proxy for :meth:`Session.model` method.""" return self.session.model(model) def __enter__(self): @@ -425,10 +434,11 @@ def rollback(self): self.session = None def commit(self, callback=None): - '''Close the transaction and commit session to the backend.''' + """Close the transaction and commit session to the backend.""" if self.executed: - raise InvalidTransaction('Invalid operation. ' - 'Transaction already executed.') + raise InvalidTransaction( + "Invalid operation. " "Transaction already executed." + ) session = self.session self.session = None self.on_result = self._commit(session, callback) @@ -484,14 +494,15 @@ def _post_commit(self, session, response): signals.append((models.post_commit.fire, sm, saved)) # Once finished we send signals for fire, sm, instances in signals: - for result in fire(sm.model, instances=instances, - session=session, transaction=self): + for result in fire( + sm.model, instances=instances, session=session, transaction=self + ): yield result if exceptions: nf = len(exceptions) if nf > 1: - error = 'There were %s exceptions during commit.\n\n' % nf - error += '\n\n'.join((str(e) for e in exceptions)) + error = "There were %s exceptions during commit.\n\n" % nf + error += "\n\n".join((str(e) for e in exceptions)) else: error = str(exceptions[0]) raise CommitException(error, failures=nf) @@ -508,7 +519,7 @@ def _async_commit(self, session, responses, callback): class Session(object): - '''The middleware for persistent operations on the back-end. + """The middleware for persistent operations on the back-end. It is created via the :meth:`Router.session` method. @@ -520,7 +531,8 @@ class Session(object): .. attribute:: router Instance of the :class:`Router` which created this :class:`Session`. - ''' + """ + def __init__(self, router): self.transaction = None self._models = OrderedDict() @@ -530,7 +542,7 @@ def __str__(self): return str(self._router) def __repr__(self): - return '%s: %s' % (self.__class__.__name__, self._router) + return "%s: %s" % (self.__class__.__name__, self._router) def __iter__(self): for sm in self._models.values(): @@ -545,31 +557,29 @@ def router(self): @property def dirty(self): - '''The set of instances in this :class:`Session` which have -been modified.''' - return frozenset(chain(*tuple((sm.dirty for sm - in itervalues(self._models))))) + """The set of instances in this :class:`Session` which have + been modified.""" + return frozenset(chain(*tuple((sm.dirty for sm in itervalues(self._models))))) def begin(self, **options): - '''Begin a new :class:`Transaction`. 
If this :class:`Session` -is already in a :ref:`transactional state `, -an error will occur. It returns the :attr:`transaction` attribute. + """Begin a new :class:`Transaction`. If this :class:`Session` + is already in a :ref:`transactional state `, + an error will occur. It returns the :attr:`transaction` attribute. -This method is mostly used within a ``with`` statement block:: + This method is mostly used within a ``with`` statement block:: - with session.begin() as t: - t.add(...) - ... + with session.begin() as t: + t.add(...) + ... -which is equivalent to:: + which is equivalent to:: - t = session.begin() - t.add(...) - ... - session.commit() + t = session.begin() + t.add(...) + ... + session.commit() -``options`` parameters are passed to the :class:`Transaction` constructor. -''' + ``options`` parameters are passed to the :class:`Transaction` constructor.""" if self.transaction is not None: raise InvalidTransaction("A transaction is already begun.") else: @@ -588,17 +598,17 @@ def commit(self): return self.transaction.commit() def query(self, model, **kwargs): - '''Create a new :class:`Query` for *model*.''' + """Create a new :class:`Query` for *model*.""" sm = self.model(model) query_class = sm.manager.query_class or Query return query_class(sm._meta, self, **kwargs) def empty(self, model): - '''Returns an empty :class:`Query` for ``model``.''' + """Returns an empty :class:`Query` for ``model``.""" return EmptyQuery(self.manager(model)._meta, self) def update_or_create(self, model, **kwargs): - '''Update or create a new instance of ``model``. + """Update or create a new instance of ``model``. This method can raise an exception if the ``kwargs`` dictionary contains field data that does not validate. @@ -607,12 +617,12 @@ def update_or_create(self, model, **kwargs): :param kwargs: dictionary of parameters. :returns: A two elements tuple containing the instance and a boolean indicating if the instance was created or not. - ''' + """ backend = self.model(model).backend return backend.execute(self._update_or_create(model, **kwargs)) def add(self, instance, modified=True, **params): - '''Add an ``instance`` to the session. + """Add an ``instance`` to the session. If the session is not in a :ref:`transactional state `, this operation @@ -627,7 +637,7 @@ def add(self, instance, modified=True, **params): If the instance is persistent (it is already stored in the database), an updated will be performed, otherwise a new entry will be created once the :meth:`commit` method is invoked. - ''' + """ sm = self.model(instance) instance.session = self o = sm.add(instance, modified=modified, **params) @@ -638,7 +648,7 @@ def add(self, instance, modified=True, **params): return o def delete(self, instance_or_query): - '''Delete an ``instance`` or a ``query``. + """Delete an ``instance`` or a ``query``. Adds ``instance_or_query`` to this :class:`Session` list of data to be deleted. If the session is not in a @@ -647,34 +657,33 @@ def delete(self, instance_or_query): :parameter instance_or_query: a :class:`Model` instance or a :class:`Query`. - ''' + """ sm = self.model(instance_or_query) # not an instance of a Model. Assume it is a query. 
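
# Sketch of the transactional pattern documented in ``Session.begin``
# above, reusing the hypothetical router from the earlier sketches. Both
# forms of delete are shown: a single instance and a whole query.
session = models.session()
with session.begin() as t:
    t.add(Instrument(name='EURUSD'))
    t.add(Instrument(name='GBPUSD'))
# outside a transaction, delete commits immediately; it also accepts a
# Query, in which case every matched instance is removed
session.delete(models[Instrument].query().filter(name='GBPUSD'))
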
if is_query(instance_or_query): if instance_or_query.session is not self: - raise ValueError('Adding a query generated by another session') + raise ValueError("Adding a query generated by another session") sm._delete_query.append(instance_or_query) else: instance_or_query = sm.delete(instance_or_query, self) if not self.transaction: transaction = self.begin() - return transaction.commit( - lambda: transaction.deleted.get(sm._meta)) + return transaction.commit(lambda: transaction.deleted.get(sm._meta)) else: return instance_or_query def flush(self, model): - '''Completely flush a :class:`Model` from the database. No keys -associated with the model will exists after this operation.''' + """Completely flush a :class:`Model` from the database. No keys + associated with the model will exists after this operation.""" return self.model(model).flush() def clean(self, model): - '''Remove empty keys for a :class:`Model` from the database. No -empty keys associated with the model will exists after this operation.''' + """Remove empty keys for a :class:`Model` from the database. No + empty keys associated with the model will exists after this operation.""" return self.model(model).clean() def keys(self, model): - '''Retrieve all keys for a *model*.''' + """Retrieve all keys for a *model*.""" return self.model(model).keys() def __contains__(self, instance): @@ -682,9 +691,9 @@ def __contains__(self, instance): return instance in sm if sm is not None else False def model(self, model, create=True): - '''Returns the :class:`SessionModel` for ``model`` which -can be :class:`Model`, or a :class:`MetaClass`, or an instance -of :class:`Model`.''' + """Returns the :class:`SessionModel` for ``model`` which + can be :class:`Model`, or a :class:`MetaClass`, or an instance + of :class:`Model`.""" manager = self.manager(model) sm = self._models.get(manager) if sm is None and create: @@ -693,8 +702,8 @@ def model(self, model, create=True): return sm def expunge(self, instance=None): - '''Remove ``instance`` from this :class:`Session`. If ``instance`` -is not given, it removes all instances from this :class:`Session`.''' + """Remove ``instance`` from this :class:`Session`. If ``instance`` + is not given, it removes all instances from this :class:`Session`.""" if instance is not None: sm = self._models.get(instance._meta) if sm: @@ -703,15 +712,15 @@ def expunge(self, instance=None): self._models.clear() def manager(self, model): - '''Retrieve the :class:`Manager` for ``model`` which can be any of the -values valid for the :meth:`model` method.''' + """Retrieve the :class:`Manager` for ``model`` which can be any of the + values valid for the :meth:`model` method.""" try: return self.router[model] except KeyError: - meta = getattr(model, '_meta', model) - if meta.type == 'structure': + meta = getattr(model, "_meta", model) + if meta.type == "structure": # this is a structure - if hasattr(model, 'model'): + if hasattr(model, "model"): structure_model = model.model if structure_model: return self.manager(structure_model) @@ -773,19 +782,20 @@ def _update_or_create(self, model, **kwargs): class LazyProxy(object): - '''Base class for descriptors used by :class:`ForeignKey` and -:class:`StructureField`. + """Base class for descriptors used by :class:`ForeignKey` and + :class:`StructureField`. + + .. attribute:: field -.. attribute:: field + The :class:`Field` which create this descriptor. Either a + :class:`ForeignKey` or a :class:`StructureField`.""" - The :class:`Field` which create this descriptor. 
Either a - :class:`ForeignKey` or a :class:`StructureField`. -''' def __init__(self, field): self.field = field def __repr__(self): return self.field.name + __str__ = __repr__ @property @@ -793,12 +803,12 @@ def name(self): return self.field.name def load(self, instance, session): - '''Load the lazy data for this descriptor. Implemented by -subclasses.''' + """Load the lazy data for this descriptor. Implemented by + subclasses.""" raise NotImplementedError def load_from_manager(self, manager): - raise NotImplementedError('cannot access %s from manager' % self) + raise NotImplementedError("cannot access %s from manager" % self) def __get__(self, instance, instance_type=None): if not self.field.class_field: @@ -810,64 +820,64 @@ def __get__(self, instance, instance_type=None): class Manager(object): - '''Before a :class:`StdModel` can be used in conjunction -with a :ref:`backend server `, a :class:`Manager` must be associated -with it via a :class:`Router`. Check the -:ref:`registration tutorial ` for further info:: + """Before a :class:`StdModel` can be used in conjunction + with a :ref:`backend server `, a :class:`Manager` must be associated + with it via a :class:`Router`. Check the + :ref:`registration tutorial ` for further info:: - class MyModel(odm.StdModel): - group = odm.SymbolField() - flag = odm.BooleanField() + class MyModel(odm.StdModel): + group = odm.SymbolField() + flag = odm.BooleanField() - models = odm.Router() - models.register(MyModel) + models = odm.Router() + models.register(MyModel) - manager = models[MyModel] + manager = models[MyModel] -Managers are used as :class:`Session` and :class:`Query` factories -for a given :class:`StdModel`:: + Managers are used as :class:`Session` and :class:`Query` factories + for a given :class:`StdModel`:: - session = router[MyModel].session() - query = router[MyModel].query() + session = router[MyModel].session() + query = router[MyModel].query() -A model can specify a :ref:`custom manager ` by -creating a :class:`Manager` subclass with additional methods:: + A model can specify a :ref:`custom manager ` by + creating a :class:`Manager` subclass with additional methods:: - class MyModelManager(odm.Manager): + class MyModelManager(odm.Manager): - def special_query(self, **kwargs): - ... + def special_query(self, **kwargs): + ... + + At this point we need to tell the model about the custom manager, and we do + so by setting the ``manager_class`` attribute in the :class:`StdModel`:: -At this point we need to tell the model about the custom manager, and we do -so by setting the ``manager_class`` attribute in the :class:`StdModel`:: + class MyModel(odm.StdModel): + ... - class MyModel(odm.StdModel): - ... + manager_class = MyModelManager - manager_class = MyModelManager + .. attribute:: model -.. attribute:: model + The :class:`StdModel` for this :class:`Manager`. This attribute is + assigned by the Object data mapper at runtime. - The :class:`StdModel` for this :class:`Manager`. This attribute is - assigned by the Object data mapper at runtime. + .. attribute:: router -.. attribute:: router + The :class:`Router` which contain this this :class:`Manager`. - The :class:`Router` which contain this this :class:`Manager`. + .. attribute:: backend -.. attribute:: backend + The :class:`stdnet.BackendDataServer` for this :class:`Manager`. - The :class:`stdnet.BackendDataServer` for this :class:`Manager`. + .. attribute:: read_backend -.. attribute:: read_backend + A :class:`stdnet.BackendDataServer` for read-only operations (Queries). 
- A :class:`stdnet.BackendDataServer` for read-only operations (Queries). + .. attribute:: query_class -.. attribute:: query_class + Class for querying. Default is :class:`Query`.""" - Class for querying. Default is :class:`Query`. -''' session_factory = Session query_class = None @@ -894,7 +904,7 @@ def read_backend(self): return self._read_backend or self._backend def __getattr__(self, attrname): - if attrname.startswith('__'): # required for copy + if attrname.startswith("__"): # required for copy raise AttributeError else: result = getattr(self.model, attrname) @@ -904,11 +914,12 @@ def __getattr__(self, attrname): def __str__(self): if self.backend: - return '{0}({1} - {2})'.format(self.__class__.__name__, - self._meta, - self.backend) + return "{0}({1} - {2})".format( + self.__class__.__name__, self._meta, self.backend + ) else: - return '{0}({1})'.format(self.__class__.__name__, self._meta) + return "{0}({1})".format(self.__class__.__name__, self._meta) + __repr__ = __str__ def __call__(self, *args, **kwargs): @@ -917,70 +928,68 @@ def __call__(self, *args, **kwargs): return self.model(*args, **kwargs) def session(self, session=None): - '''Returns a new :class:`Session`. This is a shortcut for the -:meth:`Router.session` method.''' + """Returns a new :class:`Session`. This is a shortcut for the + :meth:`Router.session` method.""" return self._router.session() def new(self, *args, **kwargs): - '''Create a new instance of :attr:`model` and commit it to the backend -server. This a shortcut method for the more verbose:: + """Create a new instance of :attr:`model` and commit it to the backend + server. This a shortcut method for the more verbose:: - instance = manager.session().add(MyModel(**kwargs)) -''' + instance = manager.session().add(MyModel(**kwargs))""" return self.session().add(self.model(*args, **kwargs)) def save(self, instance): - '''Save an existing instance of :attr:`model`. This a shortcut -method for the more verbose:: + """Save an existing instance of :attr:`model`. This a shortcut + method for the more verbose:: - instance = manager.session().add(instance) -''' + instance = manager.session().add(instance)""" return self.session().add(instance) def update_or_create(self, **kwargs): - '''Invokes the :class:`Session.update_or_create` method.''' + """Invokes the :class:`Session.update_or_create` method.""" return self.session().update_or_create(self.model, **kwargs) def all(self): - '''Return all instances for this manager. -Equivalent to:: + """Return all instances for this manager. + Equivalent to:: - self.query().all() - ''' + self.query().all() + """ return self.query().all() def create_all(self): - '''A method which can implement table creation. For sql models. Does -nothing for redis or mongo.''' + """A method which can implement table creation. For sql models. 
Does
+        nothing for redis or mongo."""
         pass

     def query(self, session=None):
-        '''Returns a new :class:`Query` for :attr:`Manager.model`.'''
+        """Returns a new :class:`Query` for :attr:`Manager.model`."""
         if session is None or session.router is not self.router:
             session = self.session()
         return session.query(self.model)

     def empty(self):
-        '''Returns an empty :class:`Query` for :attr:`Manager.model`.'''
+        """Returns an empty :class:`Query` for :attr:`Manager.model`."""
         return self.session().empty(self.model)

     def filter(self, **kwargs):
-        '''Returns a new :class:`Query` for :attr:`Manager.model` with
-a filter.'''
+        """Returns a new :class:`Query` for :attr:`Manager.model` with
+        a filter."""
         return self.query().filter(**kwargs)

     def exclude(self, **kwargs):
-        '''Returns a new :class:`Query` for :attr:`Manager.model` with
-a exclude filter.'''
+        """Returns a new :class:`Query` for :attr:`Manager.model` with
+        an exclude filter."""
         return self.query().exclude(**kwargs)

     def search(self, text, lookup=None):
-        '''Returns a new :class:`Query` for :attr:`Manager.model` with
-a full text search value.'''
+        """Returns a new :class:`Query` for :attr:`Manager.model` with
+        a full text search value."""
         return self.query().search(text, lookup=lookup)

     def get(self, **kwargs):
-        '''Shortcut for ``self.query().get**kwargs)``.'''
+        """Shortcut for ``self.query().get(**kwargs)``."""
         return self.query().get(**kwargs)

     def flush(self):
@@ -993,7 +1002,7 @@ def keys(self):
         return self.session().keys(self.model)

     def pkvalue(self, instance):
-        '''Return the primary key value for ``instance``.'''
+        """Return the primary key value for ``instance``."""
         return instance.pkvalue()

     def __hash__(self):
@@ -1001,6 +1010,5 @@


 class StructureManager(Manager):
-
     def __hash__(self):
         return hash(self.model)
diff --git a/stdnet/odm/struct.py b/stdnet/odm/struct.py
index 79b3af6..de96bae 100755
--- a/stdnet/odm/struct.py
+++ b/stdnet/odm/struct.py
@@ -1,40 +1,43 @@
 from uuid import uuid4

-from stdnet.utils import iteritems, encoders, BytesIO, iterpair, ispy3k
-from stdnet.utils.zset import zset
+from stdnet.utils import BytesIO, encoders, ispy3k, iteritems, iterpair
 from stdnet.utils.skiplist import skiplist
+from stdnet.utils.zset import zset

 from .base import ModelBase

-
-__all__ = ['Structure',
-           'StructureCache',
-           'Sequence',
-           'OrderedMixin',
-           'KeyValueMixin',
-           'String',
-           'List',
-           'Set',
-           'Zset',
-           'HashTable',
-           'TS',
-           'NumberArray',
-           # Mixins
-           'OrderedMixin',
-           'PairMixin',
-           'KeyValueMixin',
-           'commit_when_no_transaction']
+__all__ = [
+    "Structure",
+    "StructureCache",
+    "Sequence",
+    "OrderedMixin",
+    "KeyValueMixin",
+    "String",
+    "List",
+    "Set",
+    "Zset",
+    "HashTable",
+    "TS",
+    "NumberArray",
+    # Mixins
+    "OrderedMixin",
+    "PairMixin",
+    "KeyValueMixin",
+    "commit_when_no_transaction",
+]


 passthrough = lambda r: r


 def commit_when_no_transaction(f):
-    '''Decorator for committing changes when the instance session is
-not in a transaction.'''
+    """Decorator for committing changes when the instance session is
+    not in a transaction."""
+
     def _(self, *args, **kwargs):
         r = f(self, *args, **kwargs)
         return self.session.add(self) if self.session is not None else r
+
     _.__name__ = f.__name__
     _.__doc__ = f.__doc__
     return _
@@ -43,19 +46,20 @@ def _(self, *args, **kwargs):
 ###########################################################################
 ##    CACHE CLASSES FOR STRUCTURES
 ###########################################################################
-class StructureCache(object):
-    '''Interface for all
:attr:`Structure.cache` classes.''' +class StructureCache: + """Interface for all :attr:`Structure.cache` classes.""" + def __init__(self): self.clear() def __str__(self): if self.cache is None: - return '' + return "" else: return str(self.cache) def clear(self): - '''Clear the cache for data''' + """Clear the cache for data""" self.cache = None def items(self): @@ -66,7 +70,6 @@ def set_cache(self, data): class stringcache(StructureCache): - def getvalue(self): return self.data.getvalue() @@ -79,7 +82,6 @@ def clear(self): class listcache(StructureCache): - def push_front(self, value): self.front.append(value) @@ -98,7 +100,6 @@ def set_cache(self, data): class setcache(StructureCache): - def __contains__(self, v): if v not in self.toremove: return v in self.cache or v in self.toadd @@ -126,7 +127,6 @@ def set_cache(self, data): class zsetcache(setcache): - def clear(self): self.cache = None self.toadd = zset() @@ -139,7 +139,6 @@ def set_cache(self, data): class hashcache(zsetcache): - def clear(self): self.cache = None self.toadd = {} @@ -164,7 +163,6 @@ def remove(self, keys, add_to_remove=True): class tscache(hashcache): - def clear(self): self.cache = None self.toadd = skiplist() @@ -180,7 +178,7 @@ def set_cache(self, data): ## STRUCTURE CLASSES ############################################################################ class Structure(ModelBase): - '''A :class:`Model` for remote data-structures. + """A :class:`Model` for remote data-structures. Remote structures are the backend of :ref:`structured fields ` but they @@ -210,25 +208,35 @@ class Structure(ModelBase): The :class:`StructureField` which this owns this :class:`Structure`. Default ``None``. - ''' - _model_type = 'structure' + """ + + _model_type = "structure" abstract = True pickler = None value_pickler = None - def __init__(self, value_pickler=None, name='', field=False, - session=None, pkvalue=None, id=None, **kwargs): + def __init__( + self, + value_pickler=None, + name="", + field=False, + session=None, + pkvalue=None, + id=None, + **kwargs + ): self._field = field self._pkvalue = pkvalue self.id = id self.name = name - self.value_pickler = (value_pickler or self.value_pickler or - encoders.NumericDefault()) + self.value_pickler = ( + value_pickler or self.value_pickler or encoders.NumericDefault() + ) self.setup(**kwargs) self.session = session if not self.id and not self._field: self.id = self.makeid() - self.dbdata['id'] = self.id + self.dbdata["id"] = self.id def makeid(self): return str(uuid4())[:8] @@ -242,21 +250,20 @@ def field(self): @property def model(self): - '''The :class:`StdModel` which contains the :attr:`field` of this -:class:`Structure`. Only available if :attr:`field` is defined.''' + """The :class:`StdModel` which contains the :attr:`field` of this + :class:`Structure`. Only available if :attr:`field` is defined.""" if self._field: return self._field.model @property def cache(self): - if 'cache' not in self._dbdata: - self._dbdata['cache'] = self.cache_class() - return self._dbdata['cache'] + if "cache" not in self._dbdata: + self._dbdata["cache"] = self.cache_class() + return self._dbdata["cache"] @property def backend(self): - '''Returns the :class:`stdnet.BackendStructure`. - ''' + """Returns the :class:`stdnet.BackendStructure`.""" session = self.session if session is not None: if self._field: @@ -266,8 +273,7 @@ def backend(self): @property def read_backend(self): - '''Returns the :class:`stdnet.BackendStructure`. 
- ''' + """Returns the :class:`stdnet.BackendStructure`.""" session = self.session if session is not None: if self._field: @@ -276,7 +282,7 @@ def read_backend(self): return session.model(self).read_backend def __repr__(self): - return '%s %s' % (self.__class__.__name__, self.cache) + return "%s %s" % (self.__class__.__name__, self.cache) def __str__(self): return self.__repr__() @@ -286,7 +292,7 @@ def __iter__(self): return iter(self.items()) def size(self): - '''Number of elements in the :class:`Structure`.''' + """Number of elements in the :class:`Structure`.""" if self.cache.cache is None: return self.read_backend_structure().size() else: @@ -299,22 +305,21 @@ def __len__(self): return self.size() def items(self): - '''All items of this :class:`Structure`. Implemented by subclasses.''' + """All items of this :class:`Structure`. Implemented by subclasses.""" raise NotImplementedError def load_data(self, data, callback=None): - '''Load ``data`` from the :class:`stdnet.BackendDataServer`.''' + """Load ``data`` from the :class:`stdnet.BackendDataServer`.""" return self.backend.execute( - self.value_pickler.load_iterable(data, self.session), callback) + self.value_pickler.load_iterable(data, self.session), callback + ) def backend_structure(self, client=None): - '''Returns the :class:`stdnet.BackendStructure`. - ''' + """Returns the :class:`stdnet.BackendStructure`.""" return self.backend.structure(self, client) def read_backend_structure(self, client=None): - '''Returns the :class:`stdnet.BackendStructure` for reading. - ''' + """Returns the :class:`stdnet.BackendStructure` for reading.""" return self.read_backend.structure(self, client) def _items(self, data): @@ -326,17 +331,17 @@ def _items(self, data): ## Mixins Structures ############################################################################ class PairMixin(object): - '''A mixin for structures with which holds pairs. It is the parent class -of :class:`KeyValueMixin` and it is used as base class for the ordered set -structure :class:`Zset`. + """A mixin for structures which hold pairs. It is the parent class + of :class:`KeyValueMixin` and it is used as base class for the ordered set + structure :class:`Zset`. -.. attribute:: pickler + .. attribute:: pickler - An :ref:`encoder ` for the additional value in the pair. - The additional value is a field key for :class:`Hashtable`, - a numeric score for :class:`Zset` and a tim value for :class:`TS`. + An :ref:`encoder ` for the additional value in the pair. + The additional value is a field key for :class:`Hashtable`, + a numeric score for :class:`Zset` and a time value for :class:`TS`. 
+ """ -''' pickler = encoders.NoEncoder() def setup(self, pickler=None, **kwargs): @@ -346,24 +351,25 @@ def __setitem__(self, key, value): self.add(key, value) def items(self): - '''Iterator over items (pairs) of :class:`PairMixin`.''' + """Iterator over items (pairs) of :class:`PairMixin`.""" if self.cache.cache is None: backend = self.read_backend - backend.execute(backend.structure(self).items(), - lambda data: self.load_data(data, self._items)) + backend.execute( + backend.structure(self).items(), + lambda data: self.load_data(data, self._items), + ) return self.cache.items() def values(self): - '''Iteratir over values of :class:`PairMixin`.''' + """Iteratir over values of :class:`PairMixin`.""" if self.cache.cache is None: backend = self.read_backend - return backend.execute(backend.structure(self).values(), - self.load_values) + return backend.execute(backend.structure(self).values(), self.load_values) else: return self.cache.cache.values() def pair(self, pair): - '''Add a *pair* to the structure.''' + """Add a *pair* to the structure.""" if len(pair) == 1: # if only one value is passed, the value must implement a # score function which retrieve the first value of the pair @@ -371,8 +377,7 @@ def pair(self, pair): # hashtable) return (pair[0].score(), pair[0]) elif len(pair) != 2: - raise TypeError('add expected 2 arguments, got {0}' - .format(len(pair))) + raise TypeError("add expected 2 arguments, got {0}".format(len(pair))) else: return pair @@ -381,10 +386,10 @@ def add(self, *pair): @commit_when_no_transaction def update(self, mapping): - '''Add *mapping* dictionary to hashtable. -Equivalent to python dictionary update method. + """Add *mapping* dictionary to hashtable. + Equivalent to python dictionary update method. -:parameter mapping: a dictionary of field values.''' + :parameter mapping: a dictionary of field values.""" self.cache.update(self.dump_data(mapping)) def dump_data(self, mapping): @@ -396,7 +401,7 @@ def dump_data(self, mapping): data = [] for pair in mapping: if not isinstance(pair, tuple): - pair = pair, + pair = (pair,) k, v = p(pair) data.append((tokey(k), dumps(v))) return data @@ -410,10 +415,10 @@ def _iterable(): for k, v in iterpair(mapping): data1.append(loads(k)) yield v + res = self.value_pickler.load_iterable(_iterable(), self.session) callback = callback or passthrough - return self.backend.execute( - res, lambda data2: callback(zip(data1, data2))) + return self.backend.execute(res, lambda data2: callback(zip(data1, data2))) else: vloads = self.value_pickler.loads data = [(loads(k), vloads(v)) for k, v in iterpair(mapping)] @@ -429,22 +434,22 @@ def load_values(self, iterable): class KeyValueMixin(PairMixin): - '''A mixin for ordered and unordered key-valued pair containers. -A key-value pair container has the :meth:`values` and :meth:`items` -methods, while its iterator is over keys.''' + """A mixin for ordered and unordered key-valued pair containers. + A key-value pair container has the :meth:`values` and :meth:`items` + methods, while its iterator is over keys.""" + def __iter__(self): return iter(self.keys()) def keys(self): if self.cache.cache is None: backend = self.read_backend - return backend.execute(backend.structure(self).keys(), - self.load_keys) + return backend.execute(backend.structure(self).keys(), self.load_keys) else: return self.cache.cache def __delitem__(self, key): - '''Remove an element. Same as the :meth:`remove` method`.''' + """Remove an element. 
def __delitem__(self, key): - '''Remove an element. Same as the :meth:`remove` method.""" return self.pop(key) def __getitem__(self, key): @@ -457,17 +462,16 @@ def __getitem__(self, key): return self.cache.cache[key] def get(self, key, default=None): - '''Retrieve a single element from the structure. -If the element is not available return the default value. + """Retrieve a single element from the structure. + If the element is not available, return the default value. -:parameter key: lookup field -:parameter default: default value when the field is not available''' + :parameter key: lookup field + :parameter default: default value when the field is not available""" if self.cache.cache is None: dkey = self.pickler.dumps(key) backend = self.read_backend res = backend.structure(self).get(dkey) - return backend.execute( - res, lambda r: self._load_get_data(r, key, default)) + return backend.execute(res, lambda r: self._load_get_data(r, key, default)) else: return self.cache.cache.get(key, default) @@ -476,15 +480,15 @@ def pop(self, key, *args): dkey = self.pickler.dumps(key) backend = self.backend res = backend.structure(self).pop(dkey) - return backend.execute( - res, lambda r: self._load_get_data(r, key, *args)) + return backend.execute(res, lambda r: self._load_get_data(r, key, *args)) else: - raise TypeError('pop expected at most 2 arguments, got {0}' - .format(len(args)+1)) + raise TypeError( + "pop expected at most 2 arguments, got {0}".format(len(args) + 1) + ) @commit_when_no_transaction def remove(self, *keys): - '''Remove *keys* from the key-value container.''' + """Remove *keys* from the key-value container.""" dumps = self.pickler.dumps self.cache.remove([dumps(v) for v in keys]) @@ -502,51 +506,49 @@ def _load_get_data(self, value, key, *args): class OrderedMixin(object): - '''A mixin for a :class:`Structure` which maintains ordering with respect -a numeric value, the score.''' + """A mixin for a :class:`Structure` which maintains ordering with respect to + a numeric value, the score.""" def front(self): - '''Return the front pair of the structure''' + """Return the front pair of the structure""" v = tuple(self.irange(0, 0)) if v: return v[0] def back(self): - '''Return the back pair of the structure''' + """Return the back pair of the structure""" v = tuple(self.irange(-1, -1)) if v: return v[0] def count(self, start, stop): - '''Count the number of elements bewteen *start* and *stop*.''' + """Count the number of elements between *start* and *stop*.""" s1 = self.pickler.dumps(start) s2 = self.pickler.dumps(stop) return self.backend_structure().count(s1, s2) def range(self, start, stop, callback=None, withscores=True, **options): - '''Return a range with scores between start and end.''' + """Return a range with scores between start and end.""" s1 = self.pickler.dumps(start) s2 = self.pickler.dumps(stop) backend = self.read_backend - res = backend.structure(self).range(s1, s2, withscores=withscores, - **options) + res = backend.structure(self).range(s1, s2, withscores=withscores, **options) if not callback: callback = self.load_data if withscores else self.load_values return backend.execute(res, callback) - def irange(self, start=0, end=-1, callback=None, withscores=True, - **options): - '''Return the range by rank between start and end.''' + def irange(self, start=0, end=-1, callback=None, withscores=True, **options): + """Return the range by rank between start and end.""" backend = self.read_backend - res = backend.structure(self).irange(start, end, - withscores=withscores, - **options) + res = backend.structure(self).irange( 
start, end, withscores=withscores, **options + ) if not callback: callback = self.load_data if withscores else self.load_values return backend.execute(res, callback) def pop_range(self, start, stop, callback=None, withscores=True): - '''pop a range by score from the :class:`OrderedMixin`''' + """pop a range by score from the :class:`OrderedMixin`""" s1 = self.pickler.dumps(start) s2 = self.pickler.dumps(stop) backend = self.backend @@ -556,19 +558,19 @@ def pop_range(self, start, stop, callback=None, withscores=True): return backend.execute(res, callback) def ipop_range(self, start=0, stop=-1, callback=None, withscores=True): - '''pop a range from the :class:`OrderedMixin`''' + """pop a range from the :class:`OrderedMixin`""" backend = self.backend - res = backend.structure(self).ipop_range(start, stop, - withscores=withscores) + res = backend.structure(self).ipop_range(start, stop, withscores=withscores) if not callback: callback = self.load_data if withscores else self.load_values return backend.execute(res, callback) class Sequence(object): - '''Mixin for a :class:`Structure` which implements a kind of sequence -container. The elements in a sequence container are ordered following a linear -sequence.''' + """Mixin for a :class:`Structure` which implements a kind of sequence + container. The elements in a sequence container are ordered following a linear + sequence.""" + cache_class = listcache def items(self): @@ -576,20 +578,22 @@ def items(self): backend = self.read_backend return backend.execute( backend.structure(self).range(), - lambda data: self.load_data(data, self._items)) + lambda data: self.load_data(data, self._items), + ) return self.cache.items() @commit_when_no_transaction def push_back(self, value): - '''Appends a copy of *value* at the end of the :class:`Sequence`.''' + """Appends a copy of *value* at the end of the :class:`Sequence`.""" self.cache.push_back(self.value_pickler.dumps(value)) return self def pop_back(self): - '''Remove the last element from the :class:`Sequence`.''' + """Remove the last element from the :class:`Sequence`.""" backend = self.backend - return backend.execute(backend.structure(self).pop_back(), - self.value_pickler.loads) + return backend.execute( + backend.structure(self).pop_back(), self.value_pickler.loads + ) def __getitem__(self, index): backend = self.read_backend @@ -605,84 +609,92 @@ def __setitem__(self, index, value): ## STRUCTURES ############################################################################ + class Set(Structure): - '''An unordered set :class:`Structure`. Equivalent to a python ``set``.''' + """An unordered set :class:`Structure`. 
Equivalent to a python ``set``.""" + cache_class = setcache @commit_when_no_transaction def add(self, value): - '''Add *value* to the set''' + """Add *value* to the set""" return self.cache.update((self.value_pickler.dumps(value),)) @commit_when_no_transaction def update(self, values): - '''Add iterable *values* to the set''' + """Add iterable *values* to the set""" d = self.value_pickler.dumps return self.cache.update(tuple((d(v) for v in values))) @commit_when_no_transaction def discard(self, value): - '''Remove an element *value* from a set if it is a member.''' + """Remove an element *value* from a set if it is a member.""" return self.cache.remove((self.value_pickler.dumps(value),)) + remove = discard @commit_when_no_transaction def difference_update(self, values): - '''Remove an iterable of *values* from the set.''' + """Remove an iterable of *values* from the set.""" d = self.value_pickler.dumps return self.cache.remove(tuple((d(v) for v in values))) class List(Sequence, Structure): - '''A doubly-linked list :class:`Structure`. It expands the -:class:`Sequence` mixin with functionalities to add and remove from -the front of the list in an efficient manner.''' + """A doubly-linked list :class:`Structure`. It expands the + :class:`Sequence` mixin with functionality to add and remove from + the front of the list in an efficient manner.""" + def pop_front(self): - '''Remove the first element from of the list.''' + """Remove the first element from the list.""" backend = self.backend - return backend.execute(backend.structure(self).pop_front(), - self.value_pickler.loads) + return backend.execute( + backend.structure(self).pop_front(), self.value_pickler.loads + ) def block_pop_back(self, timeout=10): - '''Remove the last element from of the list. If no elements are -available, blocks for at least ``timeout`` seconds.''' + """Remove the last element from the list. If no elements are + available, blocks for at least ``timeout`` seconds.""" value = yield self.backend_structure().block_pop_back(timeout) if value is not None: yield self.value_pickler.loads(value) def block_pop_front(self, timeout=10): - '''Remove the first element from of the list. If no elements are -available, blocks for at least ``timeout`` seconds.''' + """Remove the first element from the list. If no elements are + available, blocks for at least ``timeout`` seconds.""" value = yield self.backend_structure().block_pop_front(timeout) if value is not None: yield self.value_pickler.loads(value) @commit_when_no_transaction def push_front(self, value): - '''Appends a copy of ``value`` to the beginning of the list.''' + """Appends a copy of ``value`` to the beginning of the list.""" self.cache.push_front(self.value_pickler.dumps(value)) class Zset(OrderedMixin, PairMixin, Set): - '''An ordered version of :class:`Set`. It derives from -:class:`OrderedMixin` and :class:`PairMixin`.''' + """An ordered version of :class:`Set`. It derives from + :class:`OrderedMixin` and :class:`PairMixin`.""" + pickler = encoders.Double() cache_class = zsetcache def rank(self, value): - '''The rank of a given *value*. This is the position of *value* -in the :class:`OrderedMixin` container.''' + """The rank of a given *value*. This is the position of *value* + in the :class:`OrderedMixin` container.""" value = self.value_pickler.dumps(value) return self.backend_structure().rank(value)
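A minimal sketch of how the three structures above compose, assuming ``odm`` is imported from stdnet and a ``session`` already bound to a backend (setup omitted):

    s = odm.Set(session=session)       # unordered set
    s.update((1, 2, 3))                # add an iterable of values
    s.discard(2)                       # remove a member if present

    l = odm.List(session=session)      # doubly-linked list
    l.push_front('hello')              # efficient insertion at the front
    l.push_back('world')

    z = odm.Zset(session=session)      # ordered set of (score, value) pairs
    z.add(2.5, 'b')                    # score first, as documented for PairMixin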
class HashTable(KeyValueMixin, Structure): - '''A :class:`Structure` which is the networked equivalent to -a Python ``dict``. Derives from :class:`KeyValueMixin`.''' + """A :class:`Structure` which is the networked equivalent to + a Python ``dict``. Derives from :class:`KeyValueMixin`.""" + pickler = encoders.Default() cache_class = hashcache if not ispy3k: + def iteritems(self): return self.items() @@ -691,10 +703,11 @@ def itervalues(self): class TS(OrderedMixin, KeyValueMixin, Structure): - '''A timeseries is a :class:`Structure` which derives from -:class:`OrderedMixin` and :class:`KeyValueMixin`. -It represents an ordered associative container where keys are timestamps -and values are objects.''' + """A timeseries is a :class:`Structure` which derives from + :class:`OrderedMixin` and :class:`KeyValueMixin`. + It represents an ordered associative container where keys are timestamps + and values are objects.""" + pickler = encoders.DateTimeConverter() value_pickler = encoders.Json() cache_class = tscache @@ -706,20 +719,19 @@ def keys(self): return self.itimes() def rank(self, dte): - '''The rank of a given *dte* in the timeseries''' + """The rank of a given *dte* in the timeseries""" timestamp = self.pickler.dumps(dte) return self.backend_structure().rank(timestamp) def ipop(self, index): - '''Pop a value at *index* from the :class:`TS`. Return ``None`` if -index is not out of bound.''' + """Pop a value at *index* from the :class:`TS`. Return ``None`` if + *index* is out of bound.""" backend = self.backend res = backend.structure(self).ipop(index) - return backend.execute(res, - lambda r: self._load_get_data(r, index, None)) + return backend.execute(res, lambda r: self._load_get_data(r, index, None)) def times(self, start, stop, callback=None, **kwargs): - '''The times between times *start* and *stop*.''' + """The times between times *start* and *stop*.""" s1 = self.pickler.dumps(start) s2 = self.pickler.dumps(stop) backend = self.read_backend @@ -727,15 +739,15 @@ def times(self, start, stop, callback=None, **kwargs): return backend.execute(res, callback or self.load_keys) def itimes(self, start=0, stop=-1, callback=None, **kwargs): - '''The times between rank *start* and *stop*.''' + """The times between rank *start* and *stop*.""" backend = self.read_backend res = backend.structure(self).itimes(start, stop, **kwargs) return backend.execute(res, callback or self.load_keys) class String(Sequence, Structure): - '''A String :class:`Sequence` of bytes. -''' + """A String :class:`Sequence` of bytes.""" + cache_class = stringcache value_pickler = encoders.Bytes() @@ -744,7 +756,6 @@ def incr(self, v=1): class Array(Sequence, Structure): - def resize(self, size): return self.backend_structure().resize(size) @@ -753,5 +764,6 @@ def capacity(self): class NumberArray(Array): - '''A compact :class:`Array` containing numbers.''' + """A compact :class:`Array` containing numbers.""" + value_pickler = encoders.CompactDouble() diff --git a/stdnet/odm/structfields.py b/stdnet/odm/structfields.py index 0bf0559..e6057a2 100755 --- a/stdnet/odm/structfields.py +++ b/stdnet/odm/structfields.py @@ -1,23 +1,25 @@ from stdnet.utils import encoders from stdnet.utils.exceptions import * -from .struct import * +from . import related from .fields import Field from .session import LazyProxy -from . 
import related - +from .struct import * -__all__ = ['StructureField', - 'StructureFieldProxy', - 'StringField', - 'SetField', - 'ListField', - 'HashField', - 'TimeSeriesField'] +__all__ = [ + "StructureField", + "StructureFieldProxy", + "StringField", + "SetField", + "ListField", + "HashField", + "TimeSeriesField", +] class StructureFieldProxy(LazyProxy): - '''A descriptor for a :class:`StructureField`.''' + """A descriptor for a :class:`StructureField`.""" + def __init__(self, field, factory): super(StructureFieldProxy, self).__init__(field) self.factory = factory @@ -27,7 +29,7 @@ def load_from_manager(self, manager): # don't cache when this is a class field return self.get_structure(None, manager.session()) else: - raise NotImplementedError('Cannot access %s from manager' % self) + raise NotImplementedError("Cannot access %s from manager" % self) def load(self, instance, session): cache_name = self.field.get_cache_name() @@ -41,100 +43,105 @@ def load(self, instance, session): structure = self.get_structure(instance, session) setattr(instance, cache_name, structure) if cache_val is not None: - structure.load_data(cache_val, - structure.cache.set_cache) + structure.load_data(cache_val, structure.cache.set_cache) if session: # override session only if a new session is given structure.session = session return structure def get_structure(self, instance, session): if session is None: - raise StructureFieldError('No session available, Cannot access.') + raise StructureFieldError("No session available, Cannot access.") pkvalue = None if not self.field.class_field: pkvalue = instance.pkvalue() if pkvalue is None: - raise StructureFieldError('Cannot access "%s". The "%s" model ' - 'is not persistent' % - (instance, self.field)) - return self.factory(pkvalue=pkvalue, session=session, - field=self.field, pickler=self.field.pickler, - value_pickler=self.field.value_pickler, - **self.field.struct_params) + raise StructureFieldError( + 'Cannot access "%s". The "%s" model ' + "is not persistent" % (instance, self.field) + ) + return self.factory( + pkvalue=pkvalue, + session=session, + field=self.field, + pickler=self.field.pickler, + value_pickler=self.field.value_pickler, + **self.field.struct_params + ) class StructureField(Field): - '''Virtual base class for :class:`Field` which are proxies to -:ref:`data structures ` such as :class:`List`, -:class:`Set`, :class:`Zset`, :class:`HashTable` and timeseries -:class:`TS`. + """Virtual base class for :class:`Field` which are proxies to + :ref:`data structures ` such as :class:`List`, + :class:`Set`, :class:`Zset`, :class:`HashTable` and timeseries + :class:`TS`. -Sometimes you want to structure your data model without breaking it up -into multiple entities. For example, you might want to define model -that contains a list of messages an instance receive:: + Sometimes you want to structure your data model without breaking it up + into multiple entities. For example, you might want to define model + that contains a list of messages an instance receive:: - from stdnet import orm + from stdnet import orm - class MyModel(odm.StdModel): - ... - messages = odm.ListField() + class MyModel(odm.StdModel): + ... 
+ messages = odm.ListField() -By defining structured fields in a model, an instance of that model can access -a stand alone structure in the back-end server with very little effort:: + By defining structured fields in a model, an instance of that model can access + a stand alone structure in the back-end server with very little effort:: - m = MyModel.objects.get(id=1) - m.messages.push_back('Hello there!') + m = MyModel.objects.get(id=1) + m.messages.push_back('Hello there!') -Behind the scenes, this functionality is implemented by Python descriptors_. + Behind the scenes, this functionality is implemented by Python descriptors_. -:parameter model: an optional :class:`StdModel` class. If - specified, the structured will contains ids of instances of the model and - it can be accessed via the :attr:`relmodel` attribute. - It can also be specified as a string if class specification is - not possible. + :parameter model: an optional :class:`StdModel` class. If + specified, the structured will contains ids of instances of the model and + it can be accessed via the :attr:`relmodel` attribute. + It can also be specified as a string if class specification is + not possible. -**Additional Field attributes** + **Additional Field attributes** -.. attribute:: relmodel + .. attribute:: relmodel - Optional :class:`StdModel` class contained in the structure. + Optional :class:`StdModel` class contained in the structure. -.. attribute:: value_pickler + .. attribute:: value_pickler - An :class:`stdnet.utils.encoders.Encoder` used to encode and decode - values. + An :class:`stdnet.utils.encoders.Encoder` used to encode and decode + values. - Default: :class:`stdnet.utils.encoders.Json`. + Default: :class:`stdnet.utils.encoders.Json`. -.. attribute:: pickler + .. attribute:: pickler - Same as the :attr:`value_pickler` attribute, this serializer is applied - to keys, rather than values, in :class:`StructureField` - of :class:`PairMixin` type (these include :class:`HashField`, - :class:`TimeSeriesField` and ordered :class:`SetField`) + Same as the :attr:`value_pickler` attribute, this serializer is applied + to keys, rather than values, in :class:`StructureField` + of :class:`PairMixin` type (these include :class:`HashField`, + :class:`TimeSeriesField` and ordered :class:`SetField`) - Default: ``None``. + Default: ``None``. -.. attribute:: class_field + .. attribute:: class_field - If ``True`` this :class:`StructureField` is a class field (it belongs to - the model class rather than model instances). For example:: + If ``True`` this :class:`StructureField` is a class field (it belongs to + the model class rather than model instances). For example:: - class MyModel(odm.StdModel): - ... - updates = odm.List(class_field=True) + class MyModel(odm.StdModel): + ... + updates = odm.List(class_field=True) - MyModel.updates.push_back(1) + MyModel.updates.push_back(1) - Default: ``False``. + Default: ``False``. + + .. _descriptors: http://users.rcn.com/python/download/Descriptor.htm""" -.. 
_descriptors: http://users.rcn.com/python/download/Descriptor.htm -''' default_pickler = None default_value_pickler = encoders.Json() - def __init__(self, model=None, pickler=None, value_pickler=None, - class_field=False, **kwargs): + def __init__( + self, model=None, pickler=None, value_pickler=None, class_field=False, **kwargs + ): # Force required to be false super(StructureField, self).__init__(**kwargs) self.relmodel = model @@ -166,18 +173,16 @@ def _set_relmodel(self, relmodel): def _register_with_model(self): data_structure_class = self.structure_class() - self.value_pickler = (self.value_pickler or - data_structure_class.value_pickler) - self.pickler = (self.pickler or data_structure_class.pickler or - self.default_pickler) + self.value_pickler = self.value_pickler or data_structure_class.value_pickler + self.pickler = ( + self.pickler or data_structure_class.pickler or self.default_pickler + ) if not self.value_pickler: if self.relmodel: self.value_pickler = related.ModelFieldPickler(self.relmodel) else: self.value_pickler = self.default_value_pickler - setattr(self.model, - self.name, - StructureFieldProxy(self, data_structure_class)) + setattr(self.model, self.name, StructureFieldProxy(self, data_structure_class)) def add_to_fields(self): self.model._meta.multifields.append(self) @@ -192,7 +197,7 @@ def todelete(self): return True def structure_class(self): - '''Returns the :class:`Structure` class for this field.''' + """Returns the :class:`Structure` class for this field.""" raise NotImplementedError def set_cache(self, instance, data): @@ -200,34 +205,34 @@ def set_cache(self, instance, data): class SetField(StructureField): - '''A field maintaining an unordered or ordered collection of values. -It is initiated without any argument other than an optional model class:: - - class User(odm.StdModel): - username = odm.AtomField(unique = True) - password = odm.AtomField() - following = odm.SetField(model = 'self') - -It can be used in the following way:: - - >>> user = User(username='lsbardel', password='mypassword').save() - >>> user2 = User(username='pippo', password='pippopassword').save() - >>> user.following.add(user2) - >>> user.save() - >>> user2 in user.following - True - -.. attribute:: ordered - - A flag indicating if the elements in this set are ordered with respect - a score. If ordered, the :meth:`StructureField.structure_class` method - returns a :class:`Zset` otherwise a :class:`Set`. - Default ``False``. -''' + """A field maintaining an unordered or ordered collection of values. + It is initiated without any argument other than an optional model class:: + + class User(odm.StdModel): + username = odm.AtomField(unique = True) + password = odm.AtomField() + following = odm.SetField(model = 'self') + + It can be used in the following way:: + + >>> user = User(username='lsbardel', password='mypassword').save() + >>> user2 = User(username='pippo', password='pippopassword').save() + >>> user.following.add(user2) + >>> user.save() + >>> user2 in user.following + True + + .. attribute:: ordered + + A flag indicating if the elements in this set are ordered with respect + a score. If ordered, the :meth:`StructureField.structure_class` method + returns a :class:`Zset` otherwise a :class:`Set`. 
+ Default ``False``.""" + ordered = False def __init__(self, *args, **kwargs): - self.ordered = kwargs.pop('ordered', self.ordered) + self.ordered = kwargs.pop("ordered", self.ordered) super(SetField, self).__init__(*args, **kwargs) def structure_class(self): @@ -235,8 +240,8 @@ def structure_class(self): class ListField(StructureField): - '''A field maintaining a list of values. - + """A field maintaining a list of values. + When accessed from the model instance, it returns a of :class:`List` structure. For example:: @@ -257,20 +262,21 @@ class UserMessage(odm.StdModel): >>> type(u.messages) >>> u.messages.size() - 2 - ''' - type = 'list' + 2 + """ + + type = "list" def structure_class(self): return List class HashField(StructureField): - '''A Hash table field, the networked equivalent of a python dictionary. -Keys are string while values are string/numeric. -it returns an instance of :class:`HashTable` structure. -''' - type = 'hash' + """A Hash table field, the networked equivalent of a python dictionary. + Keys are string while values are string/numeric. + it returns an instance of :class:`HashTable` structure.""" + + type = "hash" default_pickler = encoders.NoEncoder() default_value_pickler = encoders.Json() @@ -283,8 +289,9 @@ def structure_class(self): class TimeSeriesField(HashField): - '''A timeseries field based on :class:`TS` data structure.''' - type = 'ts' + """A timeseries field based on :class:`TS` data structure.""" + + type = "ts" default_pickler = None def structure_class(self): diff --git a/stdnet/odm/utils.py b/stdnet/odm/utils.py index 903dcbf..2e7ede1 100755 --- a/stdnet/odm/utils.py +++ b/stdnet/odm/utils.py @@ -1,60 +1,64 @@ -import logging +import csv import json +import logging import sys -import csv from inspect import isclass from stdnet.utils import StringIO from .globals import get_model_from_hash -__all__ = ['get_serializer', - 'register_serializer', - 'unregister_serializer', - 'all_serializers', - 'Serializer', - 'JsonSerializer'] +__all__ = [ + "get_serializer", + "register_serializer", + "unregister_serializer", + "all_serializers", + "Serializer", + "JsonSerializer", +] -LOGGER = logging.getLogger('stdnet.odm') +LOGGER = logging.getLogger("stdnet.odm") _serializers = {} -if sys.version_info < (2, 7): # pragma: no cover +if sys.version_info < (2, 7): # pragma: no cover def writeheader(dw): # hack to handle writeheader in python 2.6 dw.writerow(dict(((k, k) for k in dw.fieldnames))) + + else: + def writeheader(dw): dw.writeheader() def get_serializer(name, **options): - '''Retrieve a serializer register as *name*. If the serializer is not -available a ``ValueError`` exception will raise. -A common usage pattern:: - - qs = MyModel.objects.query().sort_by('id') - s = odm.get_serializer('json') - s.dump(qs) -''' + """Retrieve a serializer register as *name*. If the serializer is not + available a ``ValueError`` exception will raise. + A common usage pattern:: + + qs = MyModel.objects.query().sort_by('id') + s = odm.get_serializer('json') + s.dump(qs)""" if name in _serializers: serializer = _serializers[name] return serializer(**options) else: - raise ValueError('Unknown serializer {0}.'.format(name)) + raise ValueError("Unknown serializer {0}.".format(name)) def register_serializer(name, serializer): - '''\ + """\ Register a new serializer to the library. :parameter name: serializer name (it can override existing serializers). :parameter serializer: an instance or a derived class of a :class:`stdnet.odm.Serializer` class or a callable. 
-''' +""" if not isclass(serializer): serializer = serializer.__class__ _serializers[name] = serializer @@ -69,19 +73,19 @@ def all_serializers(): class Serializer(object): - '''The stdnet serializer base class. During initialization, the *options* -dictionary is used to override the :attr:`default_options`. These are specific -to each :class:`Serializer` implementation. + """The stdnet serializer base class. During initialization, the *options* + dictionary is used to override the :attr:`default_options`. These are specific + to each :class:`Serializer` implementation. + + .. attribute:: default_options -.. attribute:: default_options + Dictionary of default options which are overwritten during initialisation. + By default it is an empty dictionary. - Dictionary of default options which are overwritten during initialisation. - By default it is an empty dictionary. + .. attribute:: options -.. attribute:: options + Dictionary of options.""" - Dictionary of options. -''' default_options = {} arguments = () @@ -92,42 +96,42 @@ def __init__(self, **options): @property def data(self): - '''CList of data to dump into a stream.''' - if not hasattr(self, '_data'): + """CList of data to dump into a stream.""" + if not hasattr(self, "_data"): self._data = [] return self._data def dump(self, qs): - '''Add a :class:`Query` ``qs`` into the collection of :attr:`data` -to dump into a stream. No writing is done until the :meth:`write` method.''' + """Add a :class:`Query` ``qs`` into the collection of :attr:`data` + to dump into a stream. No writing is done until the :meth:`write` method.""" raise NotImplementedError def write(self, stream=None): - '''Write the serialized data into a stream. If *stream* is not -provided, a python ``StringIO`` is used. + """Write the serialized data into a stream. If *stream* is not + provided, a python ``StringIO`` is used. -:return: the stream object.''' + :return: the stream object.""" raise NotImplementedError def load(self, models, stream, model=None): - '''Load a stream of data into the database. + """Load a stream of data into the database. -:param models: the :class:`Router` which must contains all the model this - method will load. -:param stream: bytes or an object with a ``read`` method returning bytes. -:param model: Optional :class:`StdModel` we need to load. If not provided all - models in ``stream`` are loaded. + :param models: the :class:`Router` which must contains all the model this + method will load. + :param stream: bytes or an object with a ``read`` method returning bytes. + :param model: Optional :class:`StdModel` we need to load. If not provided all + models in ``stream`` are loaded. -This method must be implemented by subclasses. -''' + This method must be implemented by subclasses.""" raise NotImplementedError class JsonSerializer(Serializer): - '''The default :class:`Serializer` of :mod:`stdnet`. It -serialise/unserialise models into json data. It has one option given -by the *indent* of the ``json`` string for pretty serialisation.''' - arguments = ('indent',) + """The default :class:`Serializer` of :mod:`stdnet`. It + serialise/unserialise models into json data. 
It has one option given + by the *indent* of the ``json`` string for pretty serialisation.""" + + arguments = ("indent",) def get_data(self, qs): data = [] @@ -135,9 +139,7 @@ def get_data(self, qs): data.append(obj.tojson()) meta = obj._meta if data: - return {'model': str(meta), - 'hash': meta.hash, - 'data': data} + return {"model": str(meta), "hash": meta.hash, "data": data} def dump(self, qs): data = self.get_data(qs) @@ -151,43 +153,43 @@ def write(self, stream=None): return stream def load(self, models, stream, model=None): - if hasattr(stream, 'read'): + if hasattr(stream, "read"): stream = stream.read() data = json.loads(stream, **self.options) for model_data in data: - model = get_model_from_hash(model_data['hash']) + model = get_model_from_hash(model_data["hash"]) if model: model = self.on_load_model(model, model_data) if model: manager = models[model] - LOGGER.info('Loading model %s', model._meta) + LOGGER.info("Loading model %s", model._meta) session = manager.session() with session.begin(signal_commit=False) as t: - for item_data in model_data['data']: + for item_data in model_data["data"]: t.add(model.from_base64_data(**item_data)) else: - LOGGER.error('Could not load model %s', - model_data.get('model')) + LOGGER.error("Could not load model %s", model_data.get("model")) self.on_finished_load() def on_load_model(self, model, model_data): - '''Callback when a *model* is about to be loaded. If it returns the -model, the model will get loaded otherwise it will skip the loading.''' + """Callback when a *model* is about to be loaded. If it returns the + model, the model will get loaded otherwise it will skip the loading.""" return model def on_finished_load(self): - '''Callback when loading of data is finished''' + """Callback when loading of data is finished""" pass class CsvSerializer(Serializer): - '''A csv serializer for single model. It serialize/unserialize a model -query into a csv file.''' - default_options = {'lineterminator': '\n'} + """A csv serializer for single model. 
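A minimal sketch of the dump/write cycle described above, assuming ``odm`` is imported from stdnet and ``qs`` is a :class:`Query` over a registered model (data illustrative):

    s = odm.get_serializer('json', indent=2)   # options override default_options
    s.dump(qs)                                  # buffers data; nothing is written yet
    stream = s.write()                          # a StringIO unless a stream is passed
    print(stream.getvalue())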
It serialize/unserialize a model + query into a csv file.""" + + default_options = {"lineterminator": "\n"} def dump(self, qs): if self.data: - raise ValueError('Cannot serialize more than one model into CSV') + raise ValueError("Cannot serialize more than one model into CSV") fields = None data = [] for obj in qs: @@ -199,18 +201,15 @@ def dump(self, qs): data.append(js) meta = obj._meta ordered_fields = [meta.pkname()] - ordered_fields.extend((f.name for f in meta.scalarfields - if f.name in fields)) - data = {'fieldnames': ordered_fields, - 'hash': meta.hash, - 'data': data} + ordered_fields.extend((f.name for f in meta.scalarfields if f.name in fields)) + data = {"fieldnames": ordered_fields, "hash": meta.hash, "data": data} self.data.append(data) def write(self, stream=None): stream = stream or StringIO() if self.data: - fieldnames = self.data[0]['fieldnames'] - data = self.data[0]['data'] + fieldnames = self.data[0]["fieldnames"] + data = self.data[0]["data"] if data: w = csv.DictWriter(stream, fieldnames, **self.options) writeheader(w) @@ -220,7 +219,7 @@ def write(self, stream=None): def load(self, models, stream, model=None): if not model: - raise ValueError('Model is required when loading from csv file') + raise ValueError("Model is required when loading from csv file") r = csv.DictReader(stream, **self.options) with models.session().begin() as t: for item_data in r: @@ -228,5 +227,5 @@ def load(self, models, stream, model=None): return t.on_result -register_serializer('json', JsonSerializer) -register_serializer('csv', CsvSerializer) +register_serializer("json", JsonSerializer) +register_serializer("csv", CsvSerializer) diff --git a/stdnet/utils/__init__.py b/stdnet/utils/__init__.py index e4ea04d..b429472 100755 --- a/stdnet/utils/__init__.py +++ b/stdnet/utils/__init__.py @@ -1,12 +1,13 @@ +from collections import Mapping from inspect import istraceback from itertools import chain -from collections import Mapping from uuid import uuid4 from .py2py3 import * if ispy3k: # pragma: no cover import pickle + unichr = chr def raise_error_trace(err, traceback): @@ -15,14 +16,16 @@ def raise_error_trace(err, traceback): else: raise err -else: # pragma: no cover + +else: # pragma: no cover import cPickle as pickle + unichr = unichr from .fallbacks.py2 import raise_error_trace +from .dates import * from .jsontools import * from .populate import populate -from .dates import * def gen_unique_id(short=True): @@ -46,21 +49,24 @@ def int_or_float(v): def grouper(n, iterable, padvalue=None): - '''grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), - ('g','x','x')''' - return zip_longest(*[iter(iterable)]*n, fillvalue=padvalue) + """grouper(3, 'abcdefg', 'x') --> ('a','b','c'), ('d','e','f'), + ('g','x','x')""" + return zip_longest(*[iter(iterable)] * n, fillvalue=padvalue) def _format_int(val): positive = val >= 0 - sval = ''.join(reversed(','.join(( - ''.join(g) for g in grouper(3, reversed(str(abs(val))), ''))))) - return sval if positive else '-'+sval + sval = "".join( + reversed( + ",".join(("".join(g) for g in grouper(3, reversed(str(abs(val))), ""))) + ) + ) + return sval if positive else "-" + sval def format_int(val): try: # for python 2.7 and up - return '{:,}'.format(val) + return "{:,}".format(val) except ValueError: # pragma nocover _format_int(val) @@ -80,7 +86,7 @@ def _flat2d_gen(iterable): def flat2d(iterable): - if hasattr(iterable, '__len__'): + if hasattr(iterable, "__len__"): return chain(*iterable) else: return _flat2d_gen(iterable) diff --git 
a/stdnet/utils/dates.py b/stdnet/utils/dates.py index 35bef09..2bc31ed 100644 --- a/stdnet/utils/dates.py +++ b/stdnet/utils/dates.py @@ -1,13 +1,12 @@ -from time import mktime -from datetime import datetime, timedelta from collections import namedtuple +from datetime import datetime, timedelta +from time import mktime -class Interval(namedtuple('IntervalBase', 'start end')): - +class Interval(namedtuple("IntervalBase", "start end")): def __init__(self, start, end): if start > end: - raise ValueError('Bad interval.') + raise ValueError("Bad interval.") def __reduce__(self): return tuple, (tuple(self),) @@ -25,12 +24,10 @@ def __eq__(self, other): return self.start == other.start and self.end == other.end def union(self, other): - return Interval(min(self.start, other.start), - max(self.end, other.end)) + return Interval(min(self.start, other.start), max(self.end, other.end)) class Intervals(list): - def __init__(self, data=None): super(Intervals, self).__init__() if data: @@ -70,20 +67,20 @@ def check(self): while merged and len(self) > 1: merged = False for idx, interval in enumerate(self[:-1]): - other = self[idx+1] + other = self[idx + 1] if interval < other: continue elif interval > other: raise ValueError() else: self[idx] = interval.union(other) - self.pop(idx+1) + self.pop(idx + 1) merged = True break def date2timestamp(dte): - '''Convert a *dte* into a valid unix timestamp.''' + """Convert a *dte* into a valid unix timestamp.""" seconds = mktime(dte.timetuple()) if isinstance(dte, datetime): return seconds + dte.microsecond / 1000000.0 @@ -117,14 +114,19 @@ def default_parse_interval(dt, delta=0): return dt -def missing_intervals(startdate, enddate, start, end, - dateconverter=None, - parseinterval=None, - intervals=None): - '''Given a ``startdate`` and an ``enddate`` dates, evaluate the -date intervals from which data is not available. It return a list of -two-dimensional tuples containing start and end date for the interval. -The list could countain 0,1 or 2 tuples.''' +def missing_intervals( + startdate, + enddate, + start, + end, + dateconverter=None, + parseinterval=None, + intervals=None, +): + """Given ``startdate`` and ``enddate`` dates, evaluate the + date intervals for which data is not available. It returns a list of + two-dimensional tuples containing start and end date for the interval. + The list could contain 0, 1 or 2 tuples.""" parseinterval = parseinterval or default_parse_interval dateconverter = dateconverter or todate startdate = dateconverter(parseinterval(startdate, 0)) @@ -162,7 +164,7 @@ def missing_intervals(startdate, enddate, start, end, def dategenerator(start, end, step=1, desc=False): - '''Generates dates between *atrt* and *end*.''' + """Generates dates between *start* and *end*.""" delta = timedelta(abs(step)) end = max(start, end) if desc:
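The interval helpers patched above compose as in this small sketch (values illustrative):

    from datetime import date
    from stdnet.utils.dates import Interval

    a = Interval(date(2020, 1, 1), date(2020, 6, 1))
    b = Interval(date(2020, 5, 1), date(2020, 12, 1))
    # overlapping intervals merge into their envelope
    assert a.union(b) == Interval(date(2020, 1, 1), date(2020, 12, 1))
    # Interval(date(2020, 2, 1), date(2020, 1, 1)) would raise ValueError("Bad interval.")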
diff --git a/stdnet/utils/encoders.py b/stdnet/utils/encoders.py index 87f2585..c6b1d51 100644 --- a/stdnet/utils/encoders.py +++ b/stdnet/utils/encoders.py @@ -1,4 +1,4 @@ -'''Classes used for encoding and decoding :class:`stdnet.odm.Field` values. .. autoclass:: Encoder @@ -23,49 +23,53 @@ .. autoclass:: DateTimeConverter .. autoclass:: DateConverter -''' +""" import json import logging - -from datetime import datetime, date +from datetime import date, datetime from struct import pack, unpack -from stdnet.utils import (JSONDateDecimalEncoder, pickle, - JSONDateDecimalEncoder, DefaultJSONHook, - ispy3k, date2timestamp, timestamp2date, - string_type) +from stdnet.utils import ( + DefaultJSONHook, + JSONDateDecimalEncoder, + date2timestamp, + ispy3k, + pickle, + string_type, + timestamp2date, +) -nan = float('nan') +nan = float("nan") -LOGGER = logging.getLogger('stdnet.encoders') +LOGGER = logging.getLogger("stdnet.encoders") class Encoder(object): - '''Virtaul class for encoding data in -:ref:`data structures `. It exposes two methods -for encoding and decoding data to and from the data server. + """Virtual class for encoding data in + :ref:`data structures `. It exposes two methods + for encoding and decoding data to and from the data server. + + .. attribute:: type -.. attribute:: type + The type of data once loaded into python""" - The type of data once loaded into python -''' type = None def dumps(self, x): - '''Serialize data for database''' + """Serialize data for database""" raise NotImplementedError() def loads(self, x): - '''Unserialize data from database''' + """Unserialize data from database""" raise NotImplementedError() def require_session(self): - '''``True`` if this :class:`Encoder` requires a -:class:`stdnet.odm.Session`.''' + """``True`` if this :class:`Encoder` requires a + :class:`stdnet.odm.Session`.""" return False def load_iterable(self, iterable, session=None): - '''Load an ``iterable``. + """Load an ``iterable``. By default it returns a generator of data loaded via the :meth:`loads` method. @@ -73,7 +77,7 @@ def load_iterable(self, iterable, session=None): :param iterable: an iterable over data to load. :param session: Optional :class:`stdnet.odm.Session`. :return: an iterable over decoded data. - ''' + """ data = [] load = self.loads for v in iterable: @@ -82,15 +86,17 @@ def load_iterable(self, iterable, session=None): class Default(Encoder): - '''The default unicode encoder. It converts bytes to unicode when loading -data from the server. Viceversa when sending data.''' + """The default unicode encoder. It converts bytes to unicode when loading + data from the server. Vice versa when sending data.""" + type = string_type - def __init__(self, charset='utf-8', encoding_errors='strict'): + def __init__(self, charset="utf-8", encoding_errors="strict"): self.charset = charset self.encoding_errors = encoding_errors if ispy3k: + def dumps(self, x): if isinstance(x, bytes): return x @@ -104,6 +110,7 @@ def loads(self, x): return str(x) else: # pragma nocover + def dumps(self, x): if not isinstance(x, unicode): x = str(x) @@ -127,16 +134,18 @@ def safe_number(v): class NumericDefault(Default): - '''It decodes values into unicode unless they are numeric, in which case -they are decoded as such.''' + """It decodes values into unicode unless they are numeric, in which case + they are decoded as such.""" + def loads(self, x): x = super(NumericDefault, self).loads(x) return safe_number(x) class Double(Encoder): - '''It decodes values into doubles. If the decoding fails it decodes the -value into ``nan`` (not a number).''' + """It decodes values into doubles. 
If the decoding fails it decodes the + value into ``nan`` (not a number).""" + type = float def loads(self, x): @@ -144,14 +153,16 @@ def loads(self, x): return float(x) except (ValueError, TypeError): return nan + dumps = loads class Bytes(Encoder): - '''The binary encoder''' + """The binary encoder""" + type = bytes - def __init__(self, charset='utf-8', encoding_errors='strict'): + def __init__(self, charset="utf-8", encoding_errors="strict"): self.charset = charset self.encoding_errors = encoding_errors @@ -164,7 +175,8 @@ def dumps(self, x): class NoEncoder(Encoder): - '''A dummy encoder class''' + """A dummy encoder class""" + def dumps(self, x): return x @@ -173,8 +185,9 @@ def loads(self, x): class PythonPickle(Encoder): - '''A safe pickle serializer. By default we use protocol 2 for compatibility -between python 2 and python 3.''' + """A safe pickle serializer. By default we use protocol 2 for compatibility + between python 2 and python 3.""" + type = bytes def __init__(self, protocol=2): @@ -185,7 +198,7 @@ def dumps(self, x): try: return pickle.dumps(x, self.protocol) except: - LOGGER.exception('Could not pickle %s', x) + LOGGER.exception("Could not pickle %s", x) def loads(self, x): if x is None: @@ -194,19 +207,22 @@ def loads(self, x): try: return pickle.loads(x) except (pickle.UnpicklingError, EOFError, ValueError): - return x.decode('utf-8', 'ignore') + return x.decode("utf-8", "ignore") else: return x class Json(Default): - '''A JSON encoder for maintaning python types when dealing with -remote data structures.''' - def __init__(self, - charset='utf-8', - encoding_errors='strict', - json_encoder=None, - object_hook=None): + """A JSON encoder for maintaning python types when dealing with + remote data structures.""" + + def __init__( + self, + charset="utf-8", + encoding_errors="strict", + json_encoder=None, + object_hook=None, + ): super(Json, self).__init__(charset, encoding_errors) self.json_encoder = json_encoder or JSONDateDecimalEncoder self.object_hook = object_hook or DefaultJSONHook @@ -221,7 +237,8 @@ def loads(self, x): class DateTimeConverter(Encoder): - '''Convert to and from python ``datetime`` objects and unix timestamps''' + """Convert to and from python ``datetime`` objects and unix timestamps""" + type = datetime def dumps(self, value): @@ -233,7 +250,7 @@ def loads(self, value): class DateConverter(DateTimeConverter): type = date - '''Convert to and from python ``date`` objects and unix timestamps''' + """Convert to and from python ``date`` objects and unix timestamps""" def loads(self, value): return timestamp2date(value).date() @@ -241,8 +258,8 @@ def loads(self, value): class CompactDouble(Encoder): type = float - nil = b'\x00'*8 - nan = float('nan') + nil = b"\x00" * 8 + nan = float("nan") def dumps(self, value): if value is None: @@ -251,10 +268,10 @@ def dumps(self, value): if value != value: return self.nil else: - return pack('>d', value) + return pack(">d", value) def loads(self, value): if value == self.nil: return self.nan else: - return unpack('>d', value)[0] + return unpack(">d", value)[0] diff --git a/stdnet/utils/fallbacks/_collections.py b/stdnet/utils/fallbacks/_collections.py index 548bf3a..636ac1e 100755 --- a/stdnet/utils/fallbacks/_collections.py +++ b/stdnet/utils/fallbacks/_collections.py @@ -1,19 +1,19 @@ from UserDict import DictMixin -__all__ = ['OrderedDict'] +__all__ = ["OrderedDict"] class OrderedDict(dict, DictMixin): - '''Drop-in substitute for Py2.7's new collections.OrderedDict. 
-The recipe has big-oh performance that matches regular dictionaries -(amortized O(1) insertion/deletion/lookup and O(n) -iteration/repr/copy/equality_testing). + """Drop-in substitute for Py2.7's new collections.OrderedDict. + The recipe has big-oh performance that matches regular dictionaries + (amortized O(1) insertion/deletion/lookup and O(n) + iteration/repr/copy/equality_testing). -From http://code.activestate.com/recipes/576693/''' + From http://code.activestate.com/recipes/576693/""" def __init__(self, *args, **kwds): if len(args) > 1: - raise TypeError('expected at most 1 arguments, got %d' % len(args)) + raise TypeError("expected at most 1 arguments, got %d" % len(args)) try: self.__end except AttributeError: @@ -22,8 +22,8 @@ def __init__(self, *args, **kwds): def clear(self): self.__end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.__map = {} # key --> [key, prev, next] + end += [None, end, end] # sentinel node for doubly linked list + self.__map = {} # key --> [key, prev, next] dict.clear(self) def __setitem__(self, key, value): @@ -55,7 +55,7 @@ def __reversed__(self): def popitem(self, last=True): if not self: - raise KeyError('dictionary is empty') + raise KeyError("dictionary is empty") if last: key = reversed(self).next() else: @@ -87,8 +87,8 @@ def keys(self): def __repr__(self): if not self: - return '%s()' % (self.__class__.__name__,) - return '%s(%r)' % (self.__class__.__name__, self.items()) + return "%s()" % (self.__class__.__name__,) + return "%s(%r)" % (self.__class__.__name__, self.items()) def copy(self): return self.__class__(self) diff --git a/stdnet/utils/fallbacks/_importlib.py b/stdnet/utils/fallbacks/_importlib.py index c005cb7..fc92331 100755 --- a/stdnet/utils/fallbacks/_importlib.py +++ b/stdnet/utils/fallbacks/_importlib.py @@ -4,15 +4,14 @@ def _resolve_name(name, package, level): """Return the absolute name of the module to be imported.""" - if not hasattr(package, 'rindex'): + if not hasattr(package, "rindex"): raise ValueError("'package' not set to a string") dot = len(package) for x in xrange(level, 1, -1): try: - dot = package.rindex('.', 0, dot) + dot = package.rindex(".", 0, dot) except ValueError: - raise ValueError("attempted relative import beyond top-level " - "package") + raise ValueError("attempted relative import beyond top-level " "package") return "%s.%s" % (package[:dot], name) @@ -24,12 +23,12 @@ def import_module(name, package=None): relative import to an absolute import. """ - if name.startswith('.'): + if name.startswith("."): if not package: raise TypeError("relative imports require the 'package' argument") level = 0 for character in name: - if character != '.': + if character != ".": break level += 1 name = _resolve_name(name[level:], package, level) diff --git a/stdnet/utils/importer.py b/stdnet/utils/importer.py index 7a89bf6..30b2fa8 100755 --- a/stdnet/utils/importer.py +++ b/stdnet/utils/importer.py @@ -1,4 +1,4 @@ -try: # pragma nocover +try: # pragma nocover from importlib import * except ImportError: # pragma nocover from .fallbacks._importlib import * diff --git a/stdnet/utils/jsontools.py b/stdnet/utils/jsontools.py index 71fc4f3..a507556 100644 --- a/stdnet/utils/jsontools.py +++ b/stdnet/utils/jsontools.py @@ -1,4 +1,4 @@ -''' +""" JSONDateDecimalEncoder ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -17,26 +17,33 @@ addmul_number_dicts ~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. 
autofunction:: addmul_number_dicts -''' +""" +import json import time +from collections import Mapping from datetime import date, datetime from decimal import Decimal -from collections import Mapping -import json from stdnet.utils import iteritems - -__ALL__ = ['JSPLITTER', 'EMPTYJSON', - 'date2timestamp', 'totimestamp', - 'totimestamp2', 'todatetime', - 'JSONDateDecimalEncoder', 'date_decimal_hook', - 'DefaultJSONEncoder', 'DefaultJSONHook', - 'flat_to_nested', 'dict_flat_generator', - 'addmul_number_dicts'] - -JSPLITTER = '__' -EMPTYJSON = (b'', '', None) +__ALL__ = [ + "JSPLITTER", + "EMPTYJSON", + "date2timestamp", + "totimestamp", + "totimestamp2", + "todatetime", + "JSONDateDecimalEncoder", + "date_decimal_hook", + "DefaultJSONEncoder", + "DefaultJSONHook", + "flat_to_nested", + "dict_flat_generator", + "addmul_number_dicts", +] + +JSPLITTER = "__" +EMPTYJSON = (b"", "", None) date2timestamp = lambda dte: int(time.mktime(dte.timetuple())) @@ -45,7 +52,7 @@ def totimestamp(dte): def totimestamp2(dte): - return totimestamp(dte) + 0.000001*dte.microsecond + return totimestamp(dte) + 0.000001 * dte.microsecond def todatetime(tstamp): @@ -54,24 +61,24 @@ def todatetime(tstamp): class JSONDateDecimalEncoder(json.JSONEncoder): """The default JSON encoder used by stdnet. It provides -JSON serialization for four additional classes: + JSON serialization for the following additional classes: -* `datetime.date` as a ``{'__date__': timestamp}`` dictionary -* `datetime.datetime` as a ``{'__datetime__': timestamp}`` dictionary -* `decimal.Decimal` as a ``{'__decimal__': number}`` dictionary + * `datetime.date` as a ``{'__date__': timestamp}`` dictionary + * `datetime.datetime` as a ``{'__datetime__': timestamp}`` dictionary + * `decimal.Decimal` as a ``{'__decimal__': number}`` dictionary + + .. seealso:: It is the default encoder for :class:`stdnet.odm.JSONField`""" -.. seealso:: It is the default encoder for :class:`stdnet.odm.JSONField` -""" def default(self, obj): - if hasattr(obj, 'tojson'): + if hasattr(obj, "tojson"): # handle the Model instances return obj.tojson() if isinstance(obj, datetime): - return {'__datetime__': totimestamp2(obj)} + return {"__datetime__": totimestamp2(obj)} elif isinstance(obj, date): - return {'__date__': totimestamp(obj)} + return {"__date__": totimestamp(obj)} elif isinstance(obj, Decimal): - return {'__decimal__': str(obj)} + return {"__decimal__": str(obj)} elif ndarray and isinstance(obj, ndarray): return obj.tolist() else: @@ -79,14 +86,14 @@ def default(self, obj): def date_decimal_hook(dct): - '''The default JSON decoder hook. It is the inverse of -:class:`stdnet.utils.jsontools.JSONDateDecimalEncoder`.''' - if '__datetime__' in dct: - return todatetime(dct['__datetime__']) - elif '__date__' in dct: - return todatetime(dct['__date__']).date() - elif '__decimal__' in dct: - return Decimal(dct['__decimal__']) + """The default JSON decoder hook. It is the inverse of + :class:`stdnet.utils.jsontools.JSONDateDecimalEncoder`.""" + if "__datetime__" in dct: + return todatetime(dct["__datetime__"]) + elif "__date__" in dct: + return todatetime(dct["__date__"]).date() + elif "__decimal__" in dct: + return Decimal(dct["__decimal__"]) else: return dct @@ -95,18 +102,17 @@ def date_decimal_hook(dct): DefaultJSONHook = date_decimal_hook
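# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): a usage sketch of the encoder/hook
# pair above; the import path mirrors this module (stdnet/utils/jsontools.py).
import json
from datetime import datetime
from decimal import Decimal

from stdnet.utils.jsontools import JSONDateDecimalEncoder, date_decimal_hook

payload = {"when": datetime(2012, 1, 23, 9, 30), "price": Decimal("585.52")}
raw = json.dumps(payload, cls=JSONDateDecimalEncoder)
# datetimes serialize as {"__datetime__": <timestamp>}, Decimals as
# {"__decimal__": "585.52"}; the hook reverses both on the way back.
back = json.loads(raw, object_hook=date_decimal_hook)
assert back == payload
# ---------------------------------------------------------------------------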
-def flat_to_nested(data, instance=None, attname=None, - separator=None, loads=None): -'''Convert a flat representation of a dictionary to -a nested representation. Fields in the flat representation are separated -by the *splitter* parameters. +def flat_to_nested(data, instance=None, attname=None, separator=None, loads=None): + """Convert a flat representation of a dictionary to + a nested representation. Fields in the flat representation are separated + by the *separator* parameter. -:parameter data: a flat dictionary of key value pairs. -:parameter instance: optional instance of a model. -:parameter attribute: optional attribute of a model. -:parameter separator: optional separator. Default ``"__"``. -:parameter loads: optional data unserializer. -:rtype: a nested dictionary''' + :parameter data: a flat dictionary of key value pairs. + :parameter instance: optional instance of a model. + :parameter attname: optional attribute name of a model. + :parameter separator: optional separator. Default ``"__"``. + :parameter loads: optional data deserializer. + :rtype: a nested dictionary""" separator = separator or JSPLITTER val = {} flat_vals = {} @@ -139,13 +145,13 @@ def flat_to_nested(data, instance=None, attname=None, else: nd = d[k] if not isinstance(nd, dict): - nd = {'': nd} + nd = {"": nd} d[k] = nd d = nd if lk not in d: d[lk] = value else: - d[lk][''] = value + d[lk][""] = value if instance and flat_vals: for attr, value in iteritems(flat_vals): @@ -154,16 +160,21 @@ def flat_to_nested(data, instance=None, attname=None, return val -def dict_flat_generator(value, attname=None, splitter=JSPLITTER, - dumps=None, prefix=None, error=ValueError, - recursive=True): - '''Convert a nested dictionary into a flat dictionary representation''' +def dict_flat_generator( + value, + attname=None, + splitter=JSPLITTER, + dumps=None, + prefix=None, + error=ValueError, + recursive=True, +): + """Convert a nested dictionary into a flat dictionary representation""" if not isinstance(value, dict) or not recursive: if not prefix: - raise error('Cannot assign a non dictionary to a JSON field') + raise error("Cannot assign a non dictionary to a JSON field") else: - name = '%s%s%s' % (attname, splitter, - prefix) if attname else prefix + name = "%s%s%s" % (attname, splitter, prefix) if attname else prefix yield name, dumps(value) if dumps else value else: # loop over dictionary @@ -171,10 +182,10 @@ def dict_flat_generator(value, attname=None, splitter=JSPLITTER, val = value[field] key = prefix if field: - key = '%s%s%s' % (prefix, splitter, - field) if prefix else field + key = "%s%s%s" % (prefix, splitter, field) if prefix else field - for k, v2 in dict_flat_generator(val, attname, splitter, dumps, - key, error, field): + for k, v2 in dict_flat_generator( + val, attname, splitter, dumps, key, error, field + ): yield k, v2 @@ -199,23 +210,23 @@ def value_type(data): def addmul_number_dicts(series): - '''Multiply dictionaries by a numeric values and add them together. + """Multiply dictionaries by numeric values and add them together. -:parameter series: a tuple of two elements tuples. Each serie is of the form:: + :parameter series: a tuple of two-element tuples. Each series is of the form:: - (weight,dictionary) + (weight,dictionary) - where ``weight`` is a number and ``dictionary`` is a dictionary with - numeric values. -:parameter skip: optional list of field names to skip. + where ``weight`` is a number and ``dictionary`` is a dictionary with + numeric values. + :parameter skip: optional list of field names to skip. -Only common fields are aggregated. If a field has a non-numeric value it is -not included either.''' + Only common fields are aggregated. If a field has a non-numeric value it is + not included either."""
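# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): a hypothetical illustration of the
# weighted aggregation documented above; only fields common to every series
# survive.
#
#   series = ((0.5, {"a": 2.0, "b": 1.0}),
#             (2.0, {"a": 3.0, "c": 4.0}))
#   addmul_number_dicts(series)  ->  {"a": 0.5 * 2.0 + 2.0 * 3.0}  ==  {"a": 7.0}
#   "b" and "c" are dropped because they are not common to both dictionaries.
# ---------------------------------------------------------------------------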
if not series: return vtype = value_type((s[1] for s in series)) if vtype == 1: - return sum((weight*float(d) for weight, d in series)) + return sum((weight * float(d) for weight, d in series)) elif vtype == 3: keys = set(series[0][1]) for serie in series[1:]: diff --git a/stdnet/utils/populate.py b/stdnet/utils/populate.py index 4506ed2..8d25b64 100755 --- a/stdnet/utils/populate.py +++ b/stdnet/utils/populate.py @@ -1,81 +1,88 @@ -from datetime import date, timedelta -from random import uniform, randint, choice import string +from datetime import date, timedelta +from random import choice, randint, uniform from stdnet.utils import ispy3k if ispy3k: # pragma nocover characters = string.ascii_letters + string.digits -else: # pragma nocover +else: # pragma nocover characters = string.letters + string.digits range = xrange def_converter = lambda x: x -def populate(datatype='string', size=10, start=None, end=None, - converter=None, choice_from=None, **kwargs): - '''Utility function for populating lists with random data. -Useful for populating database with data for fuzzy testing. -Supported data-types +def populate( + datatype="string", + size=10, + start=None, + end=None, + converter=None, + choice_from=None, + **kwargs ): + """Utility function for populating lists with random data. + Useful for populating a database with data for fuzzy testing. + Supported data-types -* *string* - For example:: + * *string* + For example:: - populate('string',100, min_len=3, max_len=10) + populate('string',100, min_len=3, max_len=10) - create a 100 elements list with random strings - with random length between 3 and 10 + create a 100-element list with random strings + with random length between 3 and 10 -* *date* - For example:: + * *date* + For example:: - from datetime import date - populate('date',200, start = date(1997,1,1), end = date.today()) + from datetime import date + populate('date',200, start = date(1997,1,1), end = date.today()) - create a 200 elements list with random datetime.date objects - between *start* and *end* + create a 200-element list with random datetime.date objects + between *start* and *end* -* *integer* - For example:: + * *integer* + For example:: - populate('integer',200, start = 0, end = 1000) + populate('integer',200, start = 0, end = 1000) - create a 200 elements list with random int between *start* and *end* + create a 200-element list with random integers between *start* and *end* -* *float* - For example:: + * *float* + For example:: - populate('float', 200, start = 0, end = 10) + populate('float', 200, start = 0, end = 10) - create a 200 elements list with random floats between *start* and *end* + create a 200-element list with random floats between *start* and *end* -* *choice* (elements of an iterable) - For example:: + * *choice* (elements of an iterable) + For example:: - populate('choice', 200, choice_from = ['pippo','pluto','blob']) + populate('choice', 200, choice_from = ['pippo','pluto','blob']) - create a 200 elements list with random elements from *choice_from*. - ''' + create a 200-element list with random elements from *choice_from*. 
+ """ data = [] converter = converter or def_converter - if datatype == 'date': + if datatype == "date": date_end = end or date.today() date_start = start or date(1990, 1, 1) delta = date_end - date_start for s in range(size): data.append(converter(random_date(date_start, delta.days))) - elif datatype == 'integer': + elif datatype == "integer": start = start or 0 end = end or 1000000 for s in range(size): data.append(converter(randint(start, end))) - elif datatype == 'float': + elif datatype == "float": start = start or 0 end = end or 10 for s in range(size): data.append(converter(uniform(start, end))) - elif datatype == 'choice' and choice_from: + elif datatype == "choice" and choice_from: for s in range(size): data.append(choice(list(choice_from))) else: @@ -86,7 +93,7 @@ def populate(datatype='string', size=10, start=None, end=None, def random_string(min_len=3, max_len=20, **kwargs): len = randint(min_len, max_len) if max_len > min_len else min_len - return ''.join((choice(characters) for s in range(len))) + return "".join((choice(characters) for s in range(len))) def random_date(date_start, delta): diff --git a/stdnet/utils/py2py3.py b/stdnet/utils/py2py3.py index d4d70be..cdc8515 100755 --- a/stdnet/utils/py2py3.py +++ b/stdnet/utils/py2py3.py @@ -1,6 +1,6 @@ -'''\ +"""\ Simple python script which helps writing python 2.6 \ -forward compatible code with python 3''' +forward compatible code with python 3""" import os import sys import types @@ -19,55 +19,59 @@ long = int range = range - from urllib import parse as urlparse - from io import StringIO, BytesIO + from io import BytesIO, StringIO from itertools import zip_longest + from urllib import parse as urlparse urlencode = urlparse.urlencode class UnicodeMixin(object): - def __unicode__(self): - return '{0} object'.format(self.__class__.__name__) + return "{0} object".format(self.__class__.__name__) def __str__(self): return self.__unicode__() def __repr__(self): - return '%s: %s' % (self.__class__.__name__, self) + return "%s: %s" % (self.__class__.__name__, self) - def native_str(s, encoding='utf-8'): + def native_str(s, encoding="utf-8"): if isinstance(s, bytes): return s.decode(encoding) return s + # Python 2 -else: # pragma: no cover +else: # pragma: no cover string_type = unicode itervalues = lambda d: d.itervalues() iteritems = lambda d: d.iteritems() int_type = (types.IntType, types.LongType) - from itertools import izip as zip, imap as map, izip_longest as zip_longest + from itertools import imap as map + from itertools import izip as zip + from itertools import izip_longest as zip_longest + range = xrange long = long - import urlparse from urllib import urlencode + + import urlparse from cStringIO import StringIO + BytesIO = StringIO class UnicodeMixin(object): - def __unicode__(self): - return unicode('{0} object'.format(self.__class__.__name__)) + return unicode("{0} object".format(self.__class__.__name__)) def __str__(self): - return self.__unicode__().encode('utf-8', 'ignore') + return self.__unicode__().encode("utf-8", "ignore") def __repr__(self): - return '%s: %s' % (self.__class__.__name__, self) + return "%s: %s" % (self.__class__.__name__, self) - def native_str(s, encoding='utf-8'): + def native_str(s, encoding="utf-8"): if isinstance(s, unicode): return s.encode(encoding) return s @@ -77,13 +81,13 @@ def native_str(s, encoding='utf-8'): is_int = lambda x: isinstance(x, int_type) -def to_bytes(s, encoding=None, errors='strict'): +def to_bytes(s, encoding=None, errors="strict"): """Returns a bytestring version of 
's', -encoded as specified in 'encoding'.""" - encoding = encoding or 'utf-8' + encoded as specified in 'encoding'.""" + encoding = encoding or "utf-8" if isinstance(s, bytes): - if encoding != 'utf-8': - return s.decode('utf-8', errors).encode(encoding, errors) + if encoding != "utf-8": + return s.decode("utf-8", errors).encode(encoding, errors) else: return s if not is_string(s): @@ -91,9 +95,9 @@ def to_bytes(s, encoding=None, errors='strict'): return s.encode(encoding, errors) -def to_string(s, encoding=None, errors='strict'): +def to_string(s, encoding=None, errors="strict"): """Inverse of to_bytes""" - encoding = encoding or 'utf-8' + encoding = encoding or "utf-8" if isinstance(s, bytes): return s.decode(encoding, errors) if not is_string(s): diff --git a/stdnet/utils/skiplist.py b/stdnet/utils/skiplist.py index 8354274..060c0ae 100644 --- a/stdnet/utils/skiplist.py +++ b/stdnet/utils/skiplist.py @@ -3,8 +3,8 @@ # 576930-efficient-running-median-using-an-indexable-skipli/ # import sys -from random import random from math import log +from random import random ispy3k = int(sys.version[0]) >= 3 @@ -12,23 +12,22 @@ range = xrange -__all__ = ['skiplist'] +__all__ = ["skiplist"] class Node(object): - __slots__ = ('score', 'value', 'next', 'width') + __slots__ = ("score", "value", "next", "width") def __init__(self, score, value, next, width): - self.score, self.value, self.next, self.width = (score, value, - next, width) + self.score, self.value, self.next, self.width = (score, value, next, width) -SKIPLIST_MAXLEVEL = 32 # Should be enough for 2^32 elements +SKIPLIST_MAXLEVEL = 32 # Should be enough for 2^32 elements class skiplist(object): - '''Sorted collection supporting O(lg n) insertion, -removal, and lookup by rank.''' + """Sorted collection supporting O(lg n) insertion, + removal, and lookup by rank.""" def __init__(self, data=None, unique=False): self.unique = unique @@ -39,10 +38,9 @@ def __init__(self, data=None, unique=False): def clear(self): self.__size = 0 self.__level = 1 - self.__head = Node('HEAD', - None, - [None]*SKIPLIST_MAXLEVEL, - [1]*SKIPLIST_MAXLEVEL) + self.__head = Node( + "HEAD", None, [None] * SKIPLIST_MAXLEVEL, [1] * SKIPLIST_MAXLEVEL + ) def __repr__(self): return list(self).__repr__() @@ -57,27 +55,28 @@ def __getitem__(self, index): node = self.__head traversed = 0 index += 1 - for i in range(self.__level-1, -1, -1): + for i in range(self.__level - 1, -1, -1): while node.next[i] and (traversed + node.width[i]) <= index: traversed += node.width[i] node = node.next[i] if traversed == index: return node.value - raise IndexError('skiplist index out of range') + raise IndexError("skiplist index out of range") def extend(self, iterable): i = self.insert for score_values in iterable: i(*score_values) + update = extend def rank(self, score): - '''Return the 0-based index (rank) of ``score``. If the score is not -available it returns a negative integer which absolute score is the -left most closest index with score less than *score*.''' + """Return the 0-based index (rank) of ``score``. 
If the score is not + available it returns a negative integer whose absolute value is the + left-most closest index with score less than *score*.""" node = self.__head rank = 0 - for i in range(self.__level-1, -1, -1): + for i in range(self.__level - 1, -1, -1): while node.next[i] and node.next[i].score <= score: rank += node.width[i] node = node.next[i] @@ -89,13 +88,13 @@ def rank(self, score): def insert(self, score, value): # find first node on each level where node.next[levels].score > score if score != score: - raise ValueError('Cannot insert score {0}'.format(score)) + raise ValueError("Cannot insert score {0}".format(score)) chain = [None] * SKIPLIST_MAXLEVEL rank = [0] * SKIPLIST_MAXLEVEL node = self.__head - for i in range(self.__level-1, -1, -1): - #store rank that is crossed to reach the insert position - rank[i] = 0 if i == self.__level-1 else rank[i+1] + for i in range(self.__level - 1, -1, -1): + # store rank that is crossed to reach the insert position + rank[i] = 0 if i == self.__level - 1 else rank[i + 1] while node.next[i] and node.next[i].score <= score: rank[i] += node.width[i] node = node.next[i] @@ -113,7 +112,7 @@ def insert(self, score, value): self.__level = level # create the new node - node = Node(score, value, [None]*level, [None]*level) + node = Node(score, value, [None] * level, [None] * level) for i in range(level): prevnode = chain[i] steps = rank[0] - rank[i] @@ -133,14 +132,14 @@ def remove(self, score): # find first node on each level where node.next[levels].score >= score chain = [None] * SKIPLIST_MAXLEVEL node = self.__head - for i in range(self.__level-1, -1, -1): + for i in range(self.__level - 1, -1, -1): while node.next[i] and node.next[i].score < score: node = node.next[i] chain[i] = node node = node.next[0] if score != node.score: - raise KeyError('Not Found') + raise KeyError("Not Found") for i in range(self.__level): if chain[i].next[i] == node: @@ -152,7 +151,7 @@ def remove(self, score): self.__size -= 1 def __iter__(self): - 'Iterate over values in sorted order' + "Iterate over values in sorted order" node = self.__head.next[0] while node: yield node.score, node.value diff --git a/stdnet/utils/structures.py b/stdnet/utils/structures.py index 23def35..310a01c 100755 --- a/stdnet/utils/structures.py +++ b/stdnet/utils/structures.py @@ -1,5 +1,5 @@ import sys from collections import * -if sys.version_info < (2, 7): # pragma nocover +if sys.version_info < (2, 7): # pragma nocover from .fallbacks._collections import * diff --git a/stdnet/utils/test.py b/stdnet/utils/test.py index 11683f4..63a0878 100755 --- a/stdnet/utils/test.py +++ b/stdnet/utils/test.py @@ -1,4 +1,4 @@ -'''Test case classes and plugins for stdnet testing. Requires pulsar_. +"""Test case classes and plugins for stdnet testing. Requires pulsar_. TestCase @@ -18,13 +18,13 @@ .. _pulsar: https://pypi.python.org/pypi/pulsar -''' +""" +import logging import os import sys -import logging import pulsar -from pulsar.apps.test import unittest, mock, TestSuite, TestPlugin, sequential +from pulsar.apps.test import TestPlugin, TestSuite, mock, sequential, unittest from stdnet import getdb, settings from stdnet.utils import gen_unique_id @@ -32,30 +32,26 @@ from .populate import populate skipUnless = unittest.skipUnless -LOGGER = logging.getLogger('stdnet.test') +LOGGER = logging.getLogger("stdnet.test")
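# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): a usage sketch for the skiplist
# reformatted earlier in this patch (stdnet/utils/skiplist.py); rank() is
# 0-based and iteration yields (score, value) pairs in score order.
from stdnet.utils.skiplist import skiplist

sl = skiplist()
sl.insert(3.0, "c")
sl.insert(1.0, "a")
sl.insert(2.0, "b")
assert list(sl) == [(1.0, "a"), (2.0, "b"), (3.0, "c")]
assert sl[1] == "b"        # O(log n) lookup by rank
assert sl.rank(2.0) == 1   # 0-based rank of an existing score
sl.remove(2.0)             # removal by score; raises KeyError if missing
# ---------------------------------------------------------------------------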
class DataGenerator(object): - '''A generator of data. It must be initialised with the :attr:`size` -parameter obtained from the command line which is avaiable as a class -attribute in :class:`TestCase`. + """A generator of data. It must be initialised with the :attr:`size` + parameter obtained from the command line which is available as a class + attribute in :class:`TestCase`. + + .. attribute:: sizes -.. attribute:: sizes + A dictionary of sizes for this generator. It is a class attribute with + the following entries: ``tiny``, ``small``, ``normal``, ``big`` + and ``huge``. - A dictionary of sizes for this generator. It is a class attribute with - the following entries: ``tiny``, ``small``, ``normal``, ``big`` - and ``huge``. + .. attribute:: size -.. attribute:: size + The actual size of the data to be generated. Obtained from the + :attr:`sizes` and the input ``size`` code during initialisation.""" - The actual size of the data to be generated. Obtained from the - :attr:`sizes` and the input ``size`` code during initialisation. -''' - sizes = {'tiny': 10, - 'small': 100, - 'normal': 1000, - 'big': 10000, - 'huge': 1000000} + sizes = {"tiny": 10, "small": 100, "normal": 1000, "big": 10000, "huge": 1000000} def __init__(self, size, sizes=None): self.sizes = sizes or self.sizes @@ -64,31 +60,33 @@ def __init__(self, size, sizes=None): self.generate() def generate(self): - '''Called during initialisation to generate the data. ``kwargs`` -are additional key-valued parameter passed during initialisation. Must -be implemented by subclasses.''' + """Called during initialisation to generate the data. ``kwargs`` + are additional keyword parameters passed during initialisation. Must + be implemented by subclasses.""" pass def create(self, test, use_transaction=True): pass - def populate(self, datatype='string', size=None, **kwargs): - '''A shortcut for the :func:`stdnet.utils.populate` function. -If ``size`` is not given, the :attr:`size` is used.''' + def populate(self, datatype="string", size=None, **kwargs): + """A shortcut for the :func:`stdnet.utils.populate` function. + If ``size`` is not given, the :attr:`size` is used.""" size = size or self.size return populate(datatype, size, **kwargs) def random_string(self, min_len=5, max_len=30): - '''Return a random string''' - return populate('string', 1, min_len=min_len, max_len=max_len)[0] + """Return a random string""" + return populate("string", 1, min_len=min_len, max_len=max_len)[0] def create_backend(self, prefix): from stdnet import odm - self.namespace = '%s%s-' % (prefix, gen_unique_id()) + + self.namespace = "%s%s-" % (prefix, gen_unique_id()) if self.connection_string: - server = getdb(self.connection_string, namespace=self.namespace, - **self.backend_params()) + server = getdb( + self.connection_string, namespace=self.namespace, **self.backend_params() ) self.backend = server yield server.flush() self.mapper = odm.Router(self.backend) @@ -97,65 +95,65 @@ def create_backend(self, prefix): class TestCase(unittest.TestCase): - '''A :class:`unittest.TestCase` subclass for testing stdnet with -synchronous and asynchronous connections. It contains -several class methods for testing in a parallel test suite. + """A :class:`unittest.TestCase` subclass for testing stdnet with + synchronous and asynchronous connections. It contains + several class methods for testing in a parallel test suite. + + .. attribute:: multipledb -.. attribute:: multipledb + class attribute which indicates which backend can run the test. 
There are + several options: - class attribute which indicates which backend can run the test. There are - several options: + * ``multipledb = False`` The test case does not require a backend and + only one :class:`TestCase` class is added to the test-suite regardless + of which backend has been tested. + * ``multipledb = True``, the default falue. Create as many + :class:`TestCase` classes as the number of backend tested, each backend + will run the tests. + * ``multipledb = string, list, tuple``, Only those backend will run tests. - * ``multipledb = False`` The test case does not require a backend and - only one :class:`TestCase` class is added to the test-suite regardless - of which backend has been tested. - * ``multipledb = True``, the default falue. Create as many - :class:`TestCase` classes as the number of backend tested, each backend - will run the tests. - * ``multipledb = string, list, tuple``, Only those backend will run tests. + .. attribute:: backend -.. attribute:: backend + A :class:`stdnet.BackendDataServer` for this + :class:`TestCase` class. It is a class attribute which is different + for each :class:`TestCase` class and it is created by the + :meth:`setUpClass` method. - A :class:`stdnet.BackendDataServer` for this - :class:`TestCase` class. It is a class attribute which is different - for each :class:`TestCase` class and it is created by the - :meth:`setUpClass` method. + .. attribute:: data_cls -.. attribute:: data_cls + A :class:`DataGenerator` class for creating data. The data is created + during the :meth:`setUpClass` class method. - A :class:`DataGenerator` class for creating data. The data is created - during the :meth:`setUpClass` class method. + .. attribute:: data -.. attribute:: data + The :class:`DataGenerator` instance created from :attr:`data_cls`. - The :class:`DataGenerator` instance created from :attr:`data_cls`. + .. attribute:: model -.. attribute:: model + The default :class:`StdModel` for this test. A class attribute. - The default :class:`StdModel` for this test. A class attribute. + .. attribute:: models -.. attribute:: models + A tuple of models which can be registered by this test. The :attr:`model` + is always the model at index 0 in :attr:`models`. - A tuple of models which can be registered by this test. The :attr:`model` - is always the model at index 0 in :attr:`models`. + .. attribute:: mapper -.. attribute:: mapper + A :class:`stdnet.odm.Router` with all :attr:`models` registered with + :attr:`backend`.""" - A :class:`stdnet.odm.Router` with all :attr:`models` registered with - :attr:`backend`. -''' models = () model = None connection_string = None backend = None sizes = None - prefix = 'stdtest' + prefix = "stdtest" data_cls = DataGenerator @classmethod def backend_params(cls): - '''Optional :attr:`backend` parameters for tests in this -:class:`TestCase` class.''' + """Optional :attr:`backend` parameters for tests in this + :class:`TestCase` class.""" return {} @classmethod @@ -168,20 +166,20 @@ def setup_models(cls): @classmethod def setUpClass(cls): - '''Set up this :class:`TestCase` before test methods are run. here -is where a :attr:`backend` server instance is created and it is unique for this -:class:`TestCase` class. It create the :attr:`mapper`, -a :class:`stdnet.odm.Router` with all :attr:`models` registered. -There shouldn't be any reason to override this method, use :meth:`after_setup` -class method instead.''' + """Set up this :class:`TestCase` before test methods are run. 
@classmethod def backend_params(cls): - '''Optional :attr:`backend` parameters for tests in this -:class:`TestCase` class.''' + """Optional :attr:`backend` parameters for tests in this + :class:`TestCase` class.""" return {} @classmethod def setup_models(cls): @@ -168,20 +166,20 @@ def setup_models(cls): @classmethod def setUpClass(cls): - '''Set up this :class:`TestCase` before test methods are run. here -is where a :attr:`backend` server instance is created and it is unique for this -:class:`TestCase` class. It create the :attr:`mapper`, -a :class:`stdnet.odm.Router` with all :attr:`models` registered. -There shouldn't be any reason to override this method, use :meth:`after_setup` -class method instead.''' + """Set up this :class:`TestCase` before test methods are run. Here + is where a :attr:`backend` server instance is created and it is unique for this + :class:`TestCase` class. It creates the :attr:`mapper`, + a :class:`stdnet.odm.Router` with all :attr:`models` registered. + There shouldn't be any reason to override this method, use the :meth:`after_setup` + class method instead.""" cls.setup_models() yield create_backend(cls, cls.prefix) yield cls.after_setup() @classmethod def after_setup(cls): - '''This class method can be used to setup this :class:`TestCase` class -after the :meth:`setUpClass` was called. By default it does nothing.''' + """This class method can be used to set up this :class:`TestCase` class + after :meth:`setUpClass` was called. By default it does nothing.""" pass @classmethod @@ -191,33 +189,32 @@ def tearDownClass(cls): @classmethod def session(cls, **kwargs): - '''Create a new :class:`stdnet.odm.Session` bind to the -:attr:`TestCase.backend` attribute.''' + """Create a new :class:`stdnet.odm.Session` bound to the + :attr:`TestCase.backend` attribute.""" return cls.mapper.session() @classmethod def query(cls, model=None): - '''Shortcut function to create a query for a model.''' + """Shortcut function to create a query for a model.""" return cls.session().query(model or cls.model) @classmethod def multi_async(cls, iterable, **kwargs): - '''Treat ``iterable`` as a container of asynchronous results.''' + """Treat ``iterable`` as a container of asynchronous results.""" return pulsar.multi_async(iterable, **kwargs) def assertEqualId(self, instance, value, exact=False): - '''Assert the value of a primary key in a backend agnostic way. + """Assert the value of a primary key in a backend agnostic way. -:param instance: the :class:`StdModel` to check the primary key ``value``. -:param value: the value of the id to check against. -:param exact: if ``True`` the exact value must be matched. For redis backend - this parameter is not used. -''' + :param instance: the :class:`StdModel` to check the primary key ``value``. + :param value: the value of the id to check against. + :param exact: if ``True`` the exact value must be matched. For redis backend + this parameter is not used.""" pk = instance.pkvalue() - if exact or self.backend.name == 'redis': + if exact or self.backend.name == "redis": self.assertEqual(pk, value) - elif self.backend.name == 'mongo': - if instance._meta.pk.type == 'auto': + elif self.backend.name == "mongo": + if instance._meta.pk.type == "auto": self.assertTrue(pk) else: self.assertEqual(pk, value) @@ -227,8 +224,9 @@ def assertEqualId(self, instance, value, exact=False): class TestWrite(TestCase): - '''A variant of :class:`TestCase` which clean the backend at each -test function. Useful when testing write operations.''' + """A variant of :class:`TestCase` which cleans the backend at each + test function. Useful when testing write operations.""" + @classmethod def setUpClass(cls): cls.setup_models() @@ -249,29 +247,33 @@ def session(self): return self.mapper.session() def query(self, model=None): - '''Shortcut function to create a query for a model.''' + """Shortcut function to create a query for a model.""" return self.session().query(model or self.model)
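# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): illustrative invocations only, using
# the flags defined by the plugin just below; runtests.py appears in this
# patch's file list and the connection string is an assumption:
#
#   python runtests.py --server redis://127.0.0.1:6379?db=7
#   python runtests.py --py-redis-parser   # pure-python redis parser
#   python runtests.py --sync              # switch off asynchronous bindings
# ---------------------------------------------------------------------------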
class StdnetPlugin(TestPlugin): name = "server" flags = ["-s", "--server"] - nargs = '*' - desc = 'Back-end data server where to run tests.' + nargs = "*" + desc = "Back-end data server where to run tests." default = [settings.DEFAULT_BACKEND] validator = pulsar.validate_list py_redis_parser = pulsar.Setting( - flags=['--py-redis-parser'], - desc=('Run tests using the python redis parser rather ' - 'the C implementation.'), + flags=["--py-redis-parser"], + desc=( + "Run tests using the python redis parser rather than " "the C implementation." + ), action="store_true", - default=False) + default=False, + ) - sync = pulsar.Setting(flags=['--sync'], - desc='Switch off asynchronous bindings', - action="store_true", - default=False) + sync = pulsar.Setting( + flags=["--sync"], + desc="Switch off asynchronous bindings", + action="store_true", + default=False, + ) def configure(self, cfg): if cfg.sync: @@ -287,23 +289,21 @@ def on_start(self): s = getdb(s) s.ping() except Exception: - LOGGER.error('Could not obtain server %s' % s, - exc_info=True) + LOGGER.error("Could not obtain server %s" % s, exc_info=True) else: if s.name not in names: names.add(s.name) servers.append(s.connection_string) if not servers: - raise pulsar.HaltServer('No server available. BAILING OUT') + raise pulsar.HaltServer("No server available. BAILING OUT") settings.servers = servers - -class testmaker(object): + +class testmaker(object): def __init__(self, test, name, server): self.test = test - self.cls_name = '%s_%s' % (test.__name__, name) - self.server = server + self.cls_name = "%s_%s" % (test.__name__, name) + self.server = server def __call__(self): new_test = type(self.cls_name, (self.test,), {}) @@ -312,11 +312,11 @@ def __call__(self): def create_tests(suite, tests=None): - servers = getattr(settings, 'servers', None) + servers = getattr(settings, "servers", None) if isinstance(suite, TestSuite) and servers: for tag, test in list(tests): tests.pop(0) - multipledb = getattr(test, 'multipledb', True) + multipledb = getattr(test, "multipledb", True) toadd = True if isinstance(multipledb, str): multipledb = [multipledb] @@ -324,7 +324,7 @@ def create_tests(suite, tests=None): toadd = False if multipledb: for server in servers: - name = server.split('://')[0] + name = server.split("://")[0] if multipledb is True or name in multipledb: toadd = False tests.append((tag, testmaker(test, name, server))) diff --git a/stdnet/utils/version.py b/stdnet/utils/version.py index d8af506..b900be0 100644 --- a/stdnet/utils/version.py +++ b/stdnet/utils/version.py @@ -4,32 +4,32 @@ from collections import namedtuple -class stdnet_version(namedtuple('stdnet_version', - 'major minor micro releaselevel serial')): +class stdnet_version( + namedtuple("stdnet_version", "major minor micro releaselevel serial") +): __impl = None def __new__(cls, *args, **kwargs): if cls.__impl is None: - cls.__impl = super(stdnet_version, cls).__new__(cls, *args, - **kwargs) + cls.__impl = super(stdnet_version, cls).__new__(cls, *args, **kwargs) return cls.__impl else: - raise TypeError('cannot create stdnet_version instances') + raise TypeError("cannot create stdnet_version instances") def get_version(version): "Returns a PEP 386-compliant version number from *version*." assert len(version) == 5 - assert version[3] in ('alpha', 'beta', 'rc', 'final') + assert version[3] in ("alpha", "beta", "rc", "final") parts = 2 if version[2] == 0 else 3 - main = '.'.join(map(str, version[:parts])) - sub = '' - if version[3] == 'alpha' and version[4] == 0: + main = ".".join(map(str, version[:parts])) + sub = "" + if version[3] == "alpha" and version[4] == 0: git_changeset = get_git_changeset() if git_changeset: - sub = '.dev%s' % git_changeset - elif version[3] != 'final': - mapping = {'alpha': 'a', 'beta': 'b', 'rc': 'c'} + sub = ".dev%s" % git_changeset + elif version[3] != "final": + mapping = {"alpha": "a", "beta": "b", "rc": "c"} sub = mapping[version[3]] + str(version[4]) return main + sub
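# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): worked examples of the scheme
# implemented above -- a zero micro version drops the third component and
# pre-release levels map through {"alpha": "a", "beta": "b", "rc": "c"}:
#
#   get_version((0, 9, 0, "beta", 2))   ->  "0.9b2"
#   get_version((0, 9, 1, "final", 0))  ->  "0.9.1"
#   get_version((0, 9, 0, "alpha", 0))  ->  "0.9.dev<git commit timestamp>"
# ---------------------------------------------------------------------------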
@@ -42,13 +42,17 @@ def get_git_changeset(): so it's sufficient for generating the development version numbers. """ repo_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) - git_show = subprocess.Popen('git show --pretty=format:%ct --quiet HEAD', - stdout=subprocess.PIPE, stderr=subprocess.PIPE, - shell=True, cwd=repo_dir, - universal_newlines=True) - timestamp = git_show.communicate()[0].partition('\n')[0] + git_show = subprocess.Popen( + "git show --pretty=format:%ct --quiet HEAD", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True, + cwd=repo_dir, + universal_newlines=True, + ) + timestamp = git_show.communicate()[0].partition("\n")[0] try: timestamp = datetime.datetime.utcfromtimestamp(int(timestamp)) except ValueError: return None - return timestamp.strftime('%Y%m%d%H%M%S') + return timestamp.strftime("%Y%m%d%H%M%S") diff --git a/stdnet/utils/zset.py b/stdnet/utils/zset.py index 8f1966a..8d127d4 100644 --- a/stdnet/utils/zset.py +++ b/stdnet/utils/zset.py @@ -4,11 +4,12 @@ ispy3k = int(sys.version[0]) >= 3 -__all__ = ['zset'] +__all__ = ["zset"] class zset(object): - '''Ordered-set equivalent of redis zset.''' + """Ordered-set equivalent of redis zset.""" + def __init__(self): self.clear() @@ -26,8 +27,7 @@ def __iter__(self): yield value def items(self): - '''Iterable over ordered score, value pairs of this :class:`zset` - ''' + """Iterable over ordered score, value pairs of this :class:`zset`""" return iter(self._sl) def add(self, score, val): @@ -43,27 +43,27 @@ def add(self, score, val): return r def update(self, score_vals): - '''Update the :class:`zset` with an iterable over pairs of -scores and values.''' + """Update the :class:`zset` with an iterable over pairs of + scores and values.""" add = self.add for score, value in score_vals: add(score, value)
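# ---------------------------------------------------------------------------
# Editor's note (not part of the patch): expected behaviour of this class,
# mirroring redis sorted sets -- values ordered by score, score lookup by
# value through the internal dict:
#
#   z = zset()
#   z.update([(2.0, "b"), (1.0, "a"), (3.0, "c")])
#   list(z)       ->  ["a", "b", "c"]   # iteration in score order
#   z.rank("b")   ->  1                 # 0-based rank by value
#   z.remove("c") ->  3.0               # returns the removed score
# ---------------------------------------------------------------------------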
def remove(self, item): - '''Remove ``item`` for the :class:`zset` it it exists. -If found it returns the score of the item removed.''' + """Remove ``item`` from the :class:`zset` if it exists. + If found it returns the score of the item removed.""" score = self._dict.pop(item, None) if score is not None: self._sl.remove(score) return score def clear(self): - '''Clear this :class:`zset`.''' + """Clear this :class:`zset`.""" self._sl = skiplist() self._dict = {} def rank(self, item): - '''Return the rank (index) of ``item`` in this :class:`zset`.''' + """Return the rank (index) of ``item`` in this :class:`zset`.""" score = self._dict.get(item) if score is not None: return self._sl.rank(score) diff --git a/tests/all/apps/columnts/evaluate.py b/tests/all/apps/columnts/evaluate.py index 3028609..1f2c979 100644 --- a/tests/all/apps/columnts/evaluate.py +++ b/tests/all/apps/columnts/evaluate.py @@ -1,25 +1,22 @@ from datetime import date -from stdnet.utils import test from stdnet.apps.columnts import ColumnTS +from stdnet.utils import test from .main import ColumnMixin class TestEvaluate(ColumnMixin, test.TestCase): - def test_simple(self): ts = self.empty() - l = yield ts.evaluate('return self:length()') + l = yield ts.evaluate("return self:length()") self.assertEqual(l, 0) - yield ts.update({date(2012,5,15): {'open':605}, - date(2012,5,16): {'open':617}}) - yield self.async.assertEqual(ts.evaluate('return self:length()'), 2) - yield self.async.assertEqual(ts.evaluate('return self:fields()'), - [b'open']) - #Return the change from last open with respect previous open - change = "return self:rank_value(-1,'open')-"\ "self:rank_value(-2,'open')" + yield ts.update( + {date(2012, 5, 15): {"open": 605}, date(2012, 5, 16): {"open": 617}} + ) + yield self.async.assertEqual(ts.evaluate("return self:length()"), 2) + yield self.async.assertEqual(ts.evaluate("return self:fields()"), [b"open"]) + # Return the change from the last open with respect to the previous open + change = "return self:rank_value(-1,'open')-" "self:rank_value(-2,'open')" change = yield ts.evaluate(change) self.assertEqual(change, 12) - \ No newline at end of file diff --git a/tests/all/apps/columnts/field.py b/tests/all/apps/columnts/field.py index 29662df..4c9110d 100644 --- a/tests/all/apps/columnts/field.py +++ b/tests/all/apps/columnts/field.py @@ -1,9 +1,8 @@ -from stdnet.utils import test +from examples.tsmodels import ColumnTimeSeries +from stdnet.utils import test from tests.all.multifields.struct import MultiFieldMixin -from examples.tsmodels import ColumnTimeSeries - from .npts import ColumnTimeSeriesNumpy, skipUnless @@ -12,13 +11,13 @@ class TestColumnTSField(MultiFieldMixin, test.TestCase): def testModel(self): meta = self.model._meta - self.assertTrue(len(meta.multifields),1) + self.assertTrue(len(meta.multifields), 1) m = meta.multifields[0] - self.assertEqual(m.name,'data') + self.assertEqual(m.name, "data") self.assertTrue(isinstance(m.value_pickler, encoders.Double)) -@skipUnless(ColumnTimeSeriesNumpy, 'Requires stdnet-redis and dynts') +@skipUnless(ColumnTimeSeriesNumpy, "Requires stdnet-redis and dynts") class TestColumnTSField(TestColumnTSField): model = ColumnTimeSeriesNumpy @@ -27,7 +26,7 @@ def setUp(self): def testMeta(self): meta = self.model._meta - self.assertTrue(len(meta.multifields),1) + self.assertTrue(len(meta.multifields), 1) m = meta.multifields[0] - self.assertEqual(m.name, 'data') + self.assertEqual(m.name, "data") self.assertTrue(isinstance(m.value_pickler, encoders.Double)) diff --git a/tests/all/apps/columnts/main.py b/tests/all/apps/columnts/main.py index d0c5fca..68a224e 100644 --- a/tests/all/apps/columnts/main.py +++ b/tests/all/apps/columnts/main.py @@ -1,85 +1,81 @@ import os -from 
random import randint from datetime import date, datetime, timedelta +from random import randint from struct import unpack -from stdnet import SessionNotAvailable, CommitException -from stdnet.utils import test, encoders, populate, ispy3k, iteritems +from stdnet import CommitException, SessionNotAvailable from stdnet.apps.columnts import ColumnTS, as_dict from stdnet.backends import redisb - +from stdnet.utils import encoders, ispy3k, iteritems, populate, test from tests.all.structures.base import StructMixin - -nan = float('nan') +nan = float("nan") this_path = os.path.split(os.path.abspath(__file__))[0] -bin_to_float = lambda f : unpack('>d', f)[0] +bin_to_float = lambda f: unpack(">d", f)[0] if ispy3k: # pragma nocover bitflag = lambda value: value -else: # pragma nocover +else: # pragma nocover bitflag = ord + class timeseries_test1(redisb.RedisScript): - script = (redisb.read_lua_file('tabletools'), - redisb.read_lua_file('columnts.columnts'), - redisb.read_lua_file('test1',this_path)) + script = ( + redisb.read_lua_file("tabletools"), + redisb.read_lua_file("columnts.columnts"), + redisb.read_lua_file("test1", this_path), + ) class ColumnData(test.DataGenerator): - sizes = {'tiny': 100, - 'small': 300, - 'normal': 2000, - 'big': 10000, - 'huge': 1000000} + sizes = {"tiny": 100, "small": 300, "normal": 2000, "big": 10000, "huge": 1000000} def generate(self): size = self.size - self.data1 = tsdata(self, ('a','b','c','d','f','g')) - self.data2 = tsdata(self, ('a','b','c','d','f','g')) - self.data3 = tsdata(self, ('a','b','c','d','f','g')) - self.missing = tsdata(self, ('a','b','c','d','f','g'), missing=True) - self.data_mul1 = tsdata(self, ('eurusd',)) - self.data_mul2 = tsdata(self, ('gbpusd',)) + self.data1 = tsdata(self, ("a", "b", "c", "d", "f", "g")) + self.data2 = tsdata(self, ("a", "b", "c", "d", "f", "g")) + self.data3 = tsdata(self, ("a", "b", "c", "d", "f", "g")) + self.missing = tsdata(self, ("a", "b", "c", "d", "f", "g"), missing=True) + self.data_mul1 = tsdata(self, ("eurusd",)) + self.data_mul2 = tsdata(self, ("gbpusd",)) class tsdata(object): - def __init__(self, g, fields, start=None, end=None, missing=False): end = end or date.today() if not start: start = end - timedelta(days=g.size) # random dates - self.dates = g.populate('date', start=start, end=end) + self.dates = g.populate("date", start=start, end=end) self.unique_dates = set(self.dates) self.fields = {} self.sorted_fields = {} for field in fields: - vals = g.populate('float') + vals = g.populate("float") if missing: N = len(vals) - for num in range(randint(0, N//2)): - index = randint(0, N-1) + for num in range(randint(0, N // 2)): + index = randint(0, N - 1) vals[index] = nan self.fields[field] = vals self.sorted_fields[field] = [] self.values = [] date_dict = {} - for i,dt in enumerate(self.dates): - vals = dict(((f,v[i]) for f,v in iteritems(self.fields))) - self.values.append((dt,vals)) + for i, dt in enumerate(self.dates): + vals = dict(((f, v[i]) for f, v in iteritems(self.fields))) + self.values.append((dt, vals)) date_dict[dt] = vals sdates = [] - for i,dt in enumerate(sorted(date_dict)): + for i, dt in enumerate(sorted(date_dict)): sdates.append(dt) fields = date_dict[dt] for field in fields: self.sorted_fields[field].append(fields[field]) - self.sorted_values = (sdates,self.sorted_fields) + self.sorted_values = (sdates, self.sorted_fields) self.length = len(sdates) def create(self, test, id=None): - '''Create one ColumnTS with six fields and cls.size dates''' + """Create one ColumnTS with six fields 
and cls.size dates""" models = test.mapper ts = models.register(test.structure()) models.session().add(ts) @@ -91,20 +87,30 @@ def create(self, test, id=None): class ColumnMixin(object): - '''Used by all tests on ColumnTS''' + """Used by all tests on ColumnTS""" + structure = ColumnTS - name = 'columnts' + name = "columnts" data_cls = ColumnData def create_one(self): ts = self.structure() - d1 = date(2012,1,23) - data = {d1: {'open':586, 'high':588.66, - 'low':583.16, 'close':585.52}, - date(2012,1,20): {'open':590.53, 'high':591, - 'low':581.7, 'close':585.99}, - date(2012,1,19): {'open':640.99, 'high':640.99, - 'low':631.46, 'close':639.57}} + d1 = date(2012, 1, 23) + data = { + d1: {"open": 586, "high": 588.66, "low": 583.16, "close": 585.52}, + date(2012, 1, 20): { + "open": 590.53, + "high": 591, + "low": 581.7, + "close": 585.99, + }, + date(2012, 1, 19): { + "open": 640.99, + "high": 640.99, + "low": 631.46, + "close": 639.57, + }, + } ts.add(d1, data[d1]) self.data = data data = self.data.copy() @@ -112,7 +118,7 @@ def create_one(self): data = tuple(data.items()) ts.update(data) # test bad add - self.assertRaises(TypeError, ts.add, date(2012,1,20), 1, 2, 3) + self.assertRaises(TypeError, ts.add, date(2012, 1, 20), 1, 2, 3) return ts def empty(self): @@ -125,18 +131,18 @@ def empty(self): def check_stats(self, stat_field, data): N = len(data) - cdata = list((d for d in data if d==d)) - cdata2 = list((d*d for d in cdata)) - dd = list((a-b for a,b in zip(cdata[1:],cdata[:-1]))) - dd2 = list((d*d for d in dd)) + cdata = list((d for d in data if d == d)) + cdata2 = list((d * d for d in cdata)) + dd = list((a - b for a, b in zip(cdata[1:], cdata[:-1]))) + dd2 = list((d * d for d in dd)) NC = len(cdata) - self.assertEqual(stat_field['N'],NC) - self.assertAlmostEqual(stat_field['min'], min(cdata)) - self.assertAlmostEqual(stat_field['max'], max(cdata)) - self.assertAlmostEqual(stat_field['sum'], sum(cdata)/NC) - self.assertAlmostEqual(stat_field['sum2'], sum(cdata2)/NC) - self.assertAlmostEqual(stat_field['dsum'], sum(dd)/(NC-1)) - self.assertAlmostEqual(stat_field['dsum2'], sum(dd2)/(NC-1)) + self.assertEqual(stat_field["N"], NC) + self.assertAlmostEqual(stat_field["min"], min(cdata)) + self.assertAlmostEqual(stat_field["max"], max(cdata)) + self.assertAlmostEqual(stat_field["sum"], sum(cdata) / NC) + self.assertAlmostEqual(stat_field["sum2"], sum(cdata2) / NC) + self.assertAlmostEqual(stat_field["dsum"], sum(dd) / (NC - 1)) + self.assertAlmostEqual(stat_field["dsum2"], sum(dd2) / (NC - 1)) def as_dict(self, serie): times, fields = yield serie.irange() @@ -144,7 +150,7 @@ def as_dict(self, serie): def makeGoogle(self): ts = self.mapper.register(self.create_one()) - self.assertTrue(len(ts.cache.fields['open']), 2) + self.assertTrue(len(ts.cache.fields["open"]), 2) self.assertTrue(len(ts.cache.fields), 4) yield self.mapper.session().add(ts) yield self.async.assertEqual(ts.size(), 3) @@ -161,17 +167,16 @@ def makeGoogle(self): class TestTimeSeries(ColumnMixin, StructMixin, test.TestCase): - def testLuaClass(self): ts = self.empty() backend = ts.backend_structure() self.assertEqual(backend.instance, ts) c = backend.client - r = yield c.execute_script('timeseries_test1', (backend.id,)) - self.assertEqual(r, b'OK') + r = yield c.execute_script("timeseries_test1", (backend.id,)) + self.assertEqual(r, b"OK") def testEmpty2(self): - '''Check an empty timeseries''' + """Check an empty timeseries""" ts = self.empty() yield self.async.assertEqual(ts.numfields(), 0) yield 
self.async.assertEqual(ts.fields(), ()) @@ -185,36 +190,36 @@ def testFrontBack(self): d2 = date.today() d1 = d2 - timedelta(days=2) with ts.session.begin() as t: - ts.add(d2,'foo',-5.2) - ts.add(d1,'foo',789.3) + ts.add(d2, "foo", -5.2) + ts.add(d1, "foo", 789.3) yield t.on_result - yield self.async.assertEqual(ts.size(),2) - yield self.async.assertEqual(ts.front(), (d1, {'foo':789.3})) - yield self.async.assertEqual(ts.back(), (d2, {'foo':-5.2})) + yield self.async.assertEqual(ts.size(), 2) + yield self.async.assertEqual(ts.front(), (d1, {"foo": 789.3})) + yield self.async.assertEqual(ts.back(), (d2, {"foo": -5.2})) def test_ddd_simple(self): ts = self.empty() with ts.session.begin() as t: - ts.add(date.today(), 'pv', 56) + ts.add(date.today(), "pv", 56) self.assertTrue(ts.cache.fields) - ts.add(date.today()-timedelta(days=2), 'pv', 53.8) - self.assertTrue(len(ts.cache.fields['pv']), 2) + ts.add(date.today() - timedelta(days=2), "pv", 53.8) + self.assertTrue(len(ts.cache.fields["pv"]), 2) yield t.on_result - yield self.async.assertEqual(ts.fields(), ('pv',)) + yield self.async.assertEqual(ts.fields(), ("pv",)) yield self.async.assertEqual(ts.numfields(), 1) yield self.async.assertEqual(ts.size(), 2) # # Check that a string is available at the field key bts = ts.backend_structure() keys = yield bts.allkeys() - keys = tuple((b.decode('utf-8') for b in keys)) + keys = tuple((b.decode("utf-8") for b in keys)) self.assertEqual(len(keys), 3) self.assertTrue(bts.id in keys) self.assertTrue(bts.fieldsid in keys) - self.assertTrue(bts.fieldid('pv') in keys) - raw_data = bts.field('pv') + self.assertTrue(bts.fieldid("pv") in keys) + raw_data = bts.field("pv") self.assertTrue(raw_data) - self.assertEqual(len(raw_data),18) + self.assertEqual(len(raw_data), 18) a1 = raw_data[:9] a2 = raw_data[9:] n = bitflag(a1[0]) @@ -224,54 +229,54 @@ def test_ddd_simple(self): self.assertEqual(bin_to_float(a2[1:]), 56) # data = ts.irange() - self.assertEqual(len(data),2) - dt,fields = data - self.assertEqual(len(dt),2) - self.assertTrue('pv' in fields) - for v, t in zip(fields['pv'],[53.8, 56]): + self.assertEqual(len(data), 2) + dt, fields = data + self.assertEqual(len(dt), 2) + self.assertTrue("pv" in fields) + for v, t in zip(fields["pv"], [53.8, 56]): self.assertAlmostEqual(v, t) def test_add_nil(self): ts = self.empty() with ts.session.begin() as t: - ts.add(date.today(), 'pv', 56) - ts.add(date.today()-timedelta(days=2), 'pv', nan) + ts.add(date.today(), "pv", 56) + ts.add(date.today() - timedelta(days=2), "pv", nan) yield t.on_result yield self.async.assertEqual(ts.size(), 2) dt, fields = yield ts.irange() self.assertEqual(len(dt), 2) - self.assertTrue('pv' in fields) - n = fields['pv'][0] + self.assertTrue("pv" in fields) + n = fields["pv"][0] self.assertNotEqual(n, n) def testGoogleDrop(self): ts = yield self.makeGoogle() - yield self.async.assertEqual(ts.fields(), ('close','high','low','open')) + yield self.async.assertEqual(ts.fields(), ("close", "high", "low", "open")) yield self.async.assertEqual(ts.numfields(), 4) yield self.async.assertEqual(ts.size(), 3) def testRange(self): ts = yield self.makeGoogle() data = ts.irange() - self.assertEqual(len(data),2) - dt,fields = data - self.assertEqual(len(fields),4) - high = list(zip(dt,fields['high'])) - self.assertEqual(high[0],(datetime(2012,1,19),640.99)) - self.assertEqual(high[1],(datetime(2012,1,20),591)) - self.assertEqual(high[2],(datetime(2012,1,23),588.66)) + self.assertEqual(len(data), 2) + dt, fields = data + self.assertEqual(len(fields), 4) + 
high = list(zip(dt, fields["high"])) + self.assertEqual(high[0], (datetime(2012, 1, 19), 640.99)) + self.assertEqual(high[1], (datetime(2012, 1, 20), 591)) + self.assertEqual(high[2], (datetime(2012, 1, 23), 588.66)) def testRangeField(self): ts = yield self.makeGoogle() - data = ts.irange(fields=('low','high','badone')) - self.assertEqual(len(data),2) - dt,fields = data - self.assertEqual(len(fields),2) - low = list(zip(dt,fields['low'])) - high = list(zip(dt,fields['high'])) - self.assertEqual(high[0],(datetime(2012,1,19),640.99)) - self.assertEqual(high[1],(datetime(2012,1,20),591)) - self.assertEqual(high[2],(datetime(2012,1,23),588.66)) + data = ts.irange(fields=("low", "high", "badone")) + self.assertEqual(len(data), 2) + dt, fields = data + self.assertEqual(len(fields), 2) + low = list(zip(dt, fields["low"])) + high = list(zip(dt, fields["high"])) + self.assertEqual(high[0], (datetime(2012, 1, 19), 640.99)) + self.assertEqual(high[1], (datetime(2012, 1, 20), 591)) + self.assertEqual(high[2], (datetime(2012, 1, 23), 588.66)) def testRaises(self): ts = yield self.makeGoogle() @@ -281,21 +286,35 @@ def testRaises(self): self.assertRaises(SessionNotAvailable, ts.merge, (5, ts)) def testUpdateDict(self): - '''Test updating via a dictionary.''' + """Test updating via a dictionary.""" ts = yield self.makeGoogle() - data = {date(2012,1,23):{'open':586.00, 'high':588.66, - 'low':583.16, 'close':585.52}, - date(2012,1,25):{'open':586.32, 'high':687.68, - 'low':578, 'close':580.93}, - date(2012,1,24):{'open':586.32, 'high':687.68, - 'low':578, 'close':580.93}} + data = { + date(2012, 1, 23): { + "open": 586.00, + "high": 588.66, + "low": 583.16, + "close": 585.52, + }, + date(2012, 1, 25): { + "open": 586.32, + "high": 687.68, + "low": 578, + "close": 580.93, + }, + date(2012, 1, 24): { + "open": 586.32, + "high": 687.68, + "low": 578, + "close": 580.93, + }, + } ts.update(data) self.assertEqual(ts.size(), 5) - dates, fields = ts.range(date(2012,1,23), date(2012,1,25)) - self.assertEqual(len(dates),3) - self.assertEqual(dates[0].date(),date(2012,1,23)) - self.assertEqual(dates[1].date(),date(2012,1,24)) - self.assertEqual(dates[2].date(),date(2012,1,25)) + dates, fields = ts.range(date(2012, 1, 23), date(2012, 1, 25)) + self.assertEqual(len(dates), 3) + self.assertEqual(dates[0].date(), date(2012, 1, 23)) + self.assertEqual(dates[1].date(), date(2012, 1, 24)) + self.assertEqual(dates[2].date(), date(2012, 1, 25)) for field in fields: for d, v1 in zip(dates, fields[field]): v2 = data[d.date()][field] @@ -307,34 +326,33 @@ def __testBadQuery(self): id = ts.dbid() client = ts.session.backend.client client.delete(id) - client.rpush(id, 'bla') - client.rpush(id, 'foo') + client.rpush(id, "bla") + client.rpush(id, "foo") self.assertEqual(client.llen(id), 2) - self.assertRaises(redisb.ScriptError, ts.add, - date(2012,1,23), {'open':586}) + self.assertRaises(redisb.ScriptError, ts.add, date(2012, 1, 23), {"open": 586}) self.assertRaises(redisb.ScriptError, ts.irange) self.assertRaises(redisb.RedisInvalidResponse, ts.size) def test_get(self): ts = yield self.makeGoogle() - v = yield ts.get(date(2012,1,23)) + v = yield ts.get(date(2012, 1, 23)) self.assertTrue(v) - self.assertEqual(len(v),4) - v2 = ts[date(2012,1,23)] - self.assertEqual(v,v2) - self.assertEqual(ts.get(date(2014,1,1)),None) - self.assertRaises(KeyError, lambda: ts[date(2014,1,1)]) + self.assertEqual(len(v), 4) + v2 = ts[date(2012, 1, 23)] + self.assertEqual(v, v2) + self.assertEqual(ts.get(date(2014, 1, 1)), None) + 
self.assertRaises(KeyError, lambda: ts[date(2014, 1, 1)]) def testSet(self): ts = yield self.makeGoogle() - ts[date(2012,1,27)] = {'open': 600} + ts[date(2012, 1, 27)] = {"open": 600} self.assertEqual(len(ts), 4) - res = ts[date(2012,1,27)] - self.assertEqual(len(res),4) - self.assertEqual(res['open'], 600) - self.assertNotEqual(res['close'],res['close']) - self.assertNotEqual(res['high'],res['high']) - self.assertNotEqual(res['low'],res['low']) + res = ts[date(2012, 1, 27)] + self.assertEqual(len(res), 4) + self.assertEqual(res["open"], 600) + self.assertNotEqual(res["close"], res["close"]) + self.assertNotEqual(res["high"], res["high"]) + self.assertNotEqual(res["low"], res["low"]) def test_times(self): ts = yield self.makeGoogle() @@ -346,7 +364,6 @@ def test_times(self): class TestOperations(ColumnMixin, test.TestCase): - @classmethod def after_setup(cls): cls.ts1 = yield cls.data.data1.create(cls) @@ -372,19 +389,19 @@ def test_merge2series(self): yield self.async.assertTrue(ts3.size()) yield self.async.assertEqual(ts3.numfields(), 6) times, fields = ts3.irange() - for i,dt in enumerate(times): + for i, dt in enumerate(times): dt = dt.date() v1 = ts1.get(dt) v2 = ts2.get(dt) if dt in data.data1.unique_dates and dt in data.data2.unique_dates: for field, values in fields.items(): - res = 2*v2[field] - v1[field] - self.assertAlmostEqual(values[i],res) + res = 2 * v2[field] - v1[field] + self.assertAlmostEqual(values[i], res) else: self.assertTrue(v1 is None or v2 is None) for values in fields.values(): v = values[i] - self.assertNotEqual(v,v) + self.assertNotEqual(v, v) def test_merge3series(self): data = self.data @@ -401,8 +418,9 @@ def test_merge3series(self): self.assertEqual(ts4.session, session) yield t.on_result length = yield ts4.size() - self.assertTrue(length >= max(data.data1.length, data.data2.length, - data.data3.length)) + self.assertTrue( + length >= max(data.data1.length, data.data2.length, data.data3.length) + ) yield self.async.assertEqual(ts2.numfields(), 6) # results = yield self.as_dict(ts4) @@ -418,7 +436,7 @@ def test_merge3series(self): if v1 is not None and v2 is not None and v3 is not None: for field in result: vc = result[field] - res = 0.5*v1[field] + 1.3*v2[field] - 2.65*v3[field] + res = 0.5 * v1[field] + 1.3 * v2[field] - 2.65 * v3[field] self.assertAlmostEqual(vc, res) else: for v in result.values(): @@ -448,14 +466,14 @@ def test_add_multiply1(self): m1 = mul1.get(dt) result = results[dt] if v1 is not None and v2 is not None and m1 is not None: - m1 = m1['eurusd'] + m1 = m1["eurusd"] for field in result: vc = result[field] - res = 1.5*m1*v1[field] - 1.2*v2[field] + res = 1.5 * m1 * v1[field] - 1.2 * v2[field] self.assertAlmostEqual(vc, res) else: for v in result.values(): - self.assertNotEqual(v,v) + self.assertNotEqual(v, v) def test_add_multiply2(self): data = self.data @@ -471,41 +489,39 @@ def test_add_multiply2(self): self.assertTrue(length >= max(data.data1.length, data.data2.length)) yield self.async.assertEqual(ts.numfields(), 6) times, fields = ts.irange() - for i,dt in enumerate(times): + for i, dt in enumerate(times): dt = dt.date() v1 = ts1.get(dt) v2 = ts2.get(dt) m1 = mul1.get(dt) m2 = mul2.get(dt) - if v1 is not None and v2 is not None and m1 is not None\ - and m2 is not None: - m1 = m1['eurusd'] - m2 = m2['gbpusd'] - for field,values in fields.items(): - res = 1.5*m1*v1[field] - 1.2*m2*v2[field] - self.assertAlmostEqual(values[i],res) + if v1 is not None and v2 is not None and m1 is not None and m2 is not None: + m1 = m1["eurusd"] + m2 = 
m2["gbpusd"] + for field, values in fields.items(): + res = 1.5 * m1 * v1[field] - 1.2 * m2 * v2[field] + self.assertAlmostEqual(values[i], res) else: for values in fields.values(): v = values[i] - self.assertNotEqual(v,v) + self.assertNotEqual(v, v) def test_multiply_no_store(self): data = self.data ts1, ts2 = self.ts1, self.ts2 - times, fields = yield self.structure.merged_series((1.5, ts1), - (-1.2, ts2)) - for i,dt in enumerate(times): + times, fields = yield self.structure.merged_series((1.5, ts1), (-1.2, ts2)) + for i, dt in enumerate(times): dt = dt.date() v1 = ts1.get(dt) v2 = ts2.get(dt) if v1 is not None and v2 is not None: - for field,values in fields.items(): - res = 1.5*v1[field] - 1.2*v2[field] - self.assertAlmostEqual(values[i],res) + for field, values in fields.items(): + res = 1.5 * v1[field] - 1.2 * v2[field] + self.assertAlmostEqual(values[i], res) else: for values in fields.values(): v = values[i] - self.assertNotEqual(v,v) + self.assertNotEqual(v, v) def test_merge_fields(self): data = self.data @@ -514,36 +530,36 @@ def test_merge_fields(self): session = self.mapper.session() with session.begin() as t: t.add(ts) - ts.merge((1.5, mul1, ts1), (-1.2, mul2, ts2), - fields=('a','b','c','badone')) - self.assertEqual(ts.session,session) + ts.merge( + (1.5, mul1, ts1), (-1.2, mul2, ts2), fields=("a", "b", "c", "badone") + ) + self.assertEqual(ts.session, session) yield t.on_result length = yield ts.size() self.assertTrue(length >= max(data.data1.length, data.data2.length)) yield self.async.assertEqual(ts.numfields(), 3) - yield self.async.assertEqual(ts.fields(), ('a','b','c')) + yield self.async.assertEqual(ts.fields(), ("a", "b", "c")) times, fields = yield ts.irange() - for i,dt in enumerate(times): + for i, dt in enumerate(times): dt = dt.date() v1 = ts1.get(dt) v2 = ts2.get(dt) m1 = mul1.get(dt) m2 = mul2.get(dt) - if v1 is not None and v2 is not None and m1 is not None\ - and m2 is not None: - m1 = m1['eurusd'] - m2 = m2['gbpusd'] - for field,values in fields.items(): - res = 1.5*m1*v1[field] - 1.2*m2*v2[field] - self.assertAlmostEqual(values[i],res) + if v1 is not None and v2 is not None and m1 is not None and m2 is not None: + m1 = m1["eurusd"] + m2 = m2["gbpusd"] + for field, values in fields.items(): + res = 1.5 * m1 * v1[field] - 1.2 * m2 * v2[field] + self.assertAlmostEqual(values[i], res) else: for values in fields.values(): v = values[i] - self.assertNotEqual(v,v) + self.assertNotEqual(v, v) class a: -#class TestMissingValues(TestOperations): + # class TestMissingValues(TestOperations): @classmethod def after_setup(cls): @@ -551,8 +567,7 @@ def after_setup(cls): def test_missing(self): result = self.ts1.istats(0, -1) - stats = result['stats'] + stats = result["stats"] self.assertEqual(len(stats), 6) for stat in stats: - self.check_stats(stats[stat],self.fields[stat]) - + self.check_stats(stats[stat], self.fields[stat]) diff --git a/tests/all/apps/columnts/manipulate.py b/tests/all/apps/columnts/manipulate.py index b4794b8..07f17b9 100644 --- a/tests/all/apps/columnts/manipulate.py +++ b/tests/all/apps/columnts/manipulate.py @@ -6,10 +6,9 @@ class TestManipulate(ColumnMixin, test.TestCase): - def create(self): return self.data.data1.create(self) - + def pop_range(self, byrank, ts, start, end, num_popped, sl, sl2=None): all_dates, all_fields = yield ts.irange() self.assertEqual(len(all_fields), 6) @@ -21,7 +20,7 @@ def pop_range(self, byrank, ts, start, end, num_popped, sl, sl2=None): dt, fs = yield ts.pop_range(start, end) self.assertEqual(len(dt), num_popped) 
size = yield ts.size() - self.assertEqual(size, len(all_dates)-num_popped) + self.assertEqual(size, len(all_dates) - num_popped) self.assertEqual(dates, dt) self.assertEqual(fields, fs) # @@ -37,29 +36,29 @@ def pop_range(self, byrank, ts, start, end, num_popped, sl, sl2=None): for f in fields: self.assertEqual(len(fields[f]), len(fs[f])) self.assertEqual(fields, fs) - + def test_ipop_range_back(self): ts = yield self.create() - yield self.pop_range(True, ts, -2, -1, 2, slice(0,-2)) - + yield self.pop_range(True, ts, -2, -1, 2, slice(0, -2)) + def test_ipop_range_middle(self): ts = yield self.create() all_dates, all_fields = yield ts.irange() - yield self.pop_range(True, ts, -10, -5, 6, slice(0,-10), slice(-4, None)) - + yield self.pop_range(True, ts, -10, -5, 6, slice(0, -10), slice(-4, None)) + def test_ipop_range_start(self): ts = yield self.create() # popping the first 11 records yield self.pop_range(True, ts, 0, 10, 11, slice(11, None)) - + def test_pop_range_back(self): ts = yield self.create() start, end = yield ts.itimes(-2) - yield self.pop_range(False, ts, start, end, 2, slice(0,-2)) - + yield self.pop_range(False, ts, start, end, 2, slice(0, -2)) + def test_contains(self): ts = yield self.create() - all_dates = yield ts.itimes() + all_dates = yield ts.itimes() dt = all_dates[10] self.assertTrue(dt in ts) # now lets pop dt @@ -68,4 +67,4 @@ def test_contains(self): self.assertFalse(dt in ts) # dn = datetime.now() - self.assertFalse(dn in ts) \ No newline at end of file + self.assertFalse(dn in ts) diff --git a/tests/all/apps/columnts/npts.py b/tests/all/apps/columnts/npts.py index 9acda03..bad2f74 100644 --- a/tests/all/apps/columnts/npts.py +++ b/tests/all/apps/columnts/npts.py @@ -1,59 +1,60 @@ import os from stdnet import odm -from stdnet.utils import test, encoders +from stdnet.utils import encoders, test + try: - from stdnet.apps.columnts import npts from dynts import tsname - + + from stdnet.apps.columnts import npts + nptsColumnTS = npts.ColumnTS - + class ColumnTimeSeriesNumpy(odm.StdModel): - ticker = odm.SymbolField(unique = True) + ticker = odm.SymbolField(unique=True) data = npts.ColumnTSField() - + + except ImportError: nptsColumnTS = None ColumnTimeSeriesNumpy = None - -from . import main +from . 
import main skipUnless = test.unittest.skipUnless -@skipUnless(nptsColumnTS, 'Requires dynts') +@skipUnless(nptsColumnTS, "Requires dynts") class TestDynTsIntegration(main.TestOperations): - ColumnTS = nptsColumnTS - + ColumnTS = nptsColumnTS + def testGetFields(self): ts1 = self.create() ts = ts1.irange() - self.assertEqual(ts.count(),6) - d1,v1 = ts1.front() - d2,v2 = ts1.back() - self.assertTrue(d2>d1) - + self.assertEqual(ts.count(), 6) + d1, v1 = ts1.front() + d2, v2 = ts1.back() + self.assertTrue(d2 > d1) + def testEmpty(self): session = self.session() ts1 = session.add(self.ColumnTS()) ts = ts1.irange() - self.assertEqual(len(ts),0) + self.assertEqual(len(ts), 0) self.assertFalse(ts1.front()) self.assertFalse(ts1.back()) - + def testgetFieldInOrder(self): ts1 = self.create() - ts = ts1.irange(fields = ('a','b','c')) + ts = ts1.irange(fields=("a", "b", "c")) self.assertEqual(ts.count(), 3) - self.assertEqual(ts.name, tsname('a','b','c')) - + self.assertEqual(ts.name, tsname("a", "b", "c")) + def testgetItem(self): ts1 = self.create() dates = list(ts1) N = len(dates) self.assertTrue(N) - n = N//2 + n = N // 2 dte = dates[n] v = ts1[dte] - \ No newline at end of file diff --git a/tests/all/apps/columnts/readonly.py b/tests/all/apps/columnts/readonly.py index 23c9495..f729bc0 100644 --- a/tests/all/apps/columnts/readonly.py +++ b/tests/all/apps/columnts/readonly.py @@ -1,13 +1,12 @@ from datetime import date -from stdnet.utils import test from stdnet.apps.columnts import ColumnTS +from stdnet.utils import test from .main import ColumnMixin, nan class TestReadOnly(ColumnMixin, test.TestCase): - @classmethod def after_setup(cls): cls.ts1 = yield cls.data.data1.create(cls) @@ -15,49 +14,49 @@ def after_setup(cls): cls.ts3 = yield cls.data.data3.create(cls) cls.mul1 = yield cls.data.data_mul1.create(cls) cls.mul2 = yield cls.data.data_mul2.create(cls) - + def test_info_simple(self): ts = yield self.empty() info = yield ts.info() - self.assertEqual(info['size'], 0) - self.assertFalse('start' in info) + self.assertEqual(info["size"], 0) + self.assertFalse("start" in info) d1 = date(2012, 5, 15) d2 = date(2012, 5, 16) - yield ts.update({d1: {'open':605}, - d2: {'open':617}}) + yield ts.update({d1: {"open": 605}, d2: {"open": 617}}) info = yield ts.info() - self.assertEqual(info['size'], 2) - self.assertEqual(info['fields']['open']['missing'], 0) - self.assertEqual(info['start'].date(), d1) - self.assertEqual(info['stop'].date(), d2) - d3 = date(2012,5,14) - d4 = date(2012,5,13) - yield ts.update({d3: {'open':nan,'close':607}, - d4: {'open':nan,'close':nan}}) + self.assertEqual(info["size"], 2) + self.assertEqual(info["fields"]["open"]["missing"], 0) + self.assertEqual(info["start"].date(), d1) + self.assertEqual(info["stop"].date(), d2) + d3 = date(2012, 5, 14) + d4 = date(2012, 5, 13) + yield ts.update( + {d3: {"open": nan, "close": 607}, d4: {"open": nan, "close": nan}} + ) info = yield ts.info() - self.assertEqual(info['size'], 4) - self.assertEqual(info['start'].date(), d4) - self.assertEqual(info['stop'].date(), d2) - self.assertEqual(info['fields']['open']['missing'], 2) - self.assertEqual(info['fields']['close']['missing'], 3) - + self.assertEqual(info["size"], 4) + self.assertEqual(info["start"].date(), d4) + self.assertEqual(info["stop"].date(), d2) + self.assertEqual(info["fields"]["open"]["missing"], 2) + self.assertEqual(info["fields"]["close"]["missing"], 3) + def test_istats(self): data = self.data ts1 = self.ts1 - dt,fields = yield ts1.irange() + dt, fields = yield ts1.irange() 
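For orientation before the assertions that follow: irange() on a ColumnTS returns a pair, a list of timestamps plus a dict mapping each field name to a list of values aligned index by index with those timestamps. A sketch of that shape (the numbers echo the Google fixture used by the range tests in main.py; in the tests themselves the pair comes from ts.irange()):

    from datetime import datetime

    # illustrative data only; in the tests this pair comes from ts.irange()
    dates = [datetime(2012, 1, 19), datetime(2012, 1, 20)]
    fields = {
        "open": [640.00, 585.00],  # one value per timestamp in dates
        "high": [640.99, 591.00],
    }
    for values in fields.values():
        assert len(values) == len(dates)
    # pairing a field with the timestamps, as the range tests do
    high = list(zip(dates, fields["high"]))
    assert high[0] == (datetime(2012, 1, 19), 640.99)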
self.assertEqual(len(fields), 6) - result = yield ts1.istats(0,-1) + result = yield ts1.istats(0, -1) self.assertTrue(result) - self.assertEqual(result['start'],dt[0]) - self.assertEqual(result['stop'],dt[-1]) - self.assertEqual(result['len'],len(dt)) - stats = result['stats'] - for field in ('a','b','c','d','f','g'): + self.assertEqual(result["start"], dt[0]) + self.assertEqual(result["stop"], dt[-1]) + self.assertEqual(result["len"], len(dt)) + stats = result["stats"] + for field in ("a", "b", "c", "d", "f", "g"): self.assertTrue(field in stats) stat_field = stats[field] res = data.data1.sorted_fields[field] self.check_stats(stat_field, res) - + def test_stats(self): data = self.data ts1 = self.ts1 @@ -71,32 +70,30 @@ def test_stats(self): # Perform the statistics between start and end result = yield ts1.stats(start, end) self.assertTrue(result) - self.assertEqual(result['start'], start) - self.assertEqual(result['stop'], end) - self.assertEqual(result['len'], len(dt)) - stats = result['stats'] - for field in ('a','b','c','d','f','g'): + self.assertEqual(result["start"], start) + self.assertEqual(result["stop"], end) + self.assertEqual(result["len"], len(dt)) + stats = result["stats"] + for field in ("a", "b", "c", "d", "f", "g"): self.assertTrue(field in stats) stat_field = stats[field] res = data.data1.sorted_fields[field][idx:-idx] self.check_stats(stat_field, res) - + def testSimpleMultiStats(self): ts1 = self.ts1 - dt,fields = yield ts1.irange() + dt, fields = yield ts1.irange() result = ts1.imulti_stats() self.assertTrue(result) - self.assertEqual(result['type'],'multi') - self.assertEqual(result['start'],dt[0]) - self.assertEqual(result['stop'],dt[-1]) - self.assertEqual(result['N'],len(dt)) - + self.assertEqual(result["type"], "multi") + self.assertEqual(result["start"], dt[0]) + self.assertEqual(result["stop"], dt[-1]) + self.assertEqual(result["N"], len(dt)) + def __test(self): - ts.update({date(2012,5,15): {'open':605}, - date(2012,5,16): {'open':617}}) - self.assertEqual(ts.evaluate('return self:length()'), 2) - self.assertEqual(ts.evaluate('return self:fields()'), [b'open']) - #Return the change from last open with respect prevois open - change = "return self:rank_value(-1,'open')-"\ - "self:rank_value(-2,'open')" + ts.update({date(2012, 5, 15): {"open": 605}, date(2012, 5, 16): {"open": 617}}) + self.assertEqual(ts.evaluate("return self:length()"), 2) + self.assertEqual(ts.evaluate("return self:fields()"), [b"open"]) + # Return the change from the last open with respect to the previous open + change = "return self:rank_value(-1,'open')-" "self:rank_value(-2,'open')" + self.assertEqual(ts.evaluate(change), 12) diff --git a/tests/all/apps/searchengine/add.py b/tests/all/apps/searchengine/add.py index 7268454..8418ab2 100644 --- a/tests/all/apps/searchengine/add.py +++ b/tests/all/apps/searchengine/add.py @@ -1,27 +1,25 @@ from stdnet.utils import test -from .meta import Item, RelatedItem, SearchMixin, SearchEngine, processors +from .meta import Item, RelatedItem, SearchEngine, SearchMixin, processors class SearchWriteMixin(SearchMixin): - @classmethod def after_setup(cls): pass def setUp(self): self.mapper.set_search_engine(self.make_engine()) - self.mapper.search_engine.register(Item, ('related',)) + self.mapper.search_engine.register(Item, ("related",)) self.mapper.search_engine.register(RelatedItem) class TestSearchAddToEngine(SearchWriteMixin, test.TestWrite): - def testSimpleAdd(self): return self.simpleadd() def testDoubleEntries(self): - '''Test an item indexed twice.''' + """Test
an item indexed twice.""" models = self.mapper session = models.session() engine = models.search_engine @@ -36,15 +34,15 @@ def testSearchWords(self): models = self.mapper engine = models.search_engine yield self.simpleadd() - words = list(engine.words_from_text('python gains')) + words = list(engine.words_from_text("python gains")) self.assertTrue(len(words) >= 2) def testSearchModelSimple(self): item, _ = yield self.simpleadd() - qs = self.query(Item).search('python gains') - self.assertEqual(qs.text, ('python gains',None)) + qs = self.query(Item).search("python gains") + self.assertEqual(qs.text, ("python gains", None)) q = qs.construct() - self.assertEqual(q.keyword, 'intersect') + self.assertEqual(q.keyword, "intersect") self.assertEqual(len(q), 4) qs = yield qs.all() self.assertEqual(len(qs), 1) @@ -52,29 +50,29 @@ def testSearchModelSimple(self): def testSearchModel(self): yield self.simpleadd() - yield self.simpleadd('pink', content='the dark side of the moon') - yield self.simpleadd('queen', content='we will rock you') - yield self.simpleadd('python', content='nothing here') - qs = self.query(Item).search('python') + yield self.simpleadd("pink", content="the dark side of the moon") + yield self.simpleadd("queen", content="we will rock you") + yield self.simpleadd("python", content="nothing here") + qs = self.query(Item).search("python") qc = qs.construct() - self.assertEqual(len(qc),3) - self.assertEqual(qc.keyword, 'intersect') + self.assertEqual(len(qc), 3) + self.assertEqual(qc.keyword, "intersect") yield self.async.assertEqual(qs.count(), 2) - qs = yield self.query(Item).search('python learn').all() + qs = yield self.query(Item).search("python learn").all() self.assertEqual(len(qs), 1) - self.assertEqual(qs[0].name, 'python') + self.assertEqual(qs[0].name, "python") def testRelatedModel(self): session = self.session() with session.begin() as t: - r = t.add(RelatedItem(name='planet earth is wonderful')) + r = t.add(RelatedItem(name="planet earth is wonderful")) yield t.on_result - yield self.simpleadd('king', content='england') - yield self.simpleadd('nothing', content='empty', related=r) - qs = self.query(Item).search('planet') + yield self.simpleadd("king", content="england") + yield self.simpleadd("nothing", content="empty", related=r) + qs = self.query(Item).search("planet") qc = qs.construct() self.assertEqual(len(qc), 2) - self.assertEqual(qc.keyword, 'intersect') + self.assertEqual(qc.keyword, "intersect") yield self.async.assertEqual(qs.count(), 1) def testFlush(self): @@ -111,21 +109,20 @@ def test_skip_indexing_when_missing_fields(self): engine = models.search_engine session = models.session() item, wis = yield self.simpleadd() - obj = yield models.item.query().load_only('id').get(id=item.id) + obj = yield models.item.query().load_only("id").get(id=item.id) yield session.add(obj) wis2 = yield engine.worditems(obj).all() self.assertEqual(wis, wis2) def testAddWithNumbers(self): - item, wi = yield self.simpleadd(name='20y', content='') + item, wi = yield self.simpleadd(name="20y", content="") wi = list(wi) - self.assertEqual(len(wi),1) + self.assertEqual(len(wi), 1) wi = wi[0] - self.assertEqual(str(wi.word),'20y') + self.assertEqual(str(wi.word), "20y") class TestCoverage(SearchWriteMixin, test.TestWrite): - @classmethod def make_engine(cls): eg = SearchEngine(metaphone=False) @@ -133,14 +130,12 @@ def make_engine(cls): return eg def testAdd(self): - item, wi = yield self.simpleadd('pink', - content='the dark side of the moon 10y') + item, wi = yield 
self.simpleadd("pink", content="the dark side of the moon 10y") wi = set((str(w.word) for w in wi)) self.assertEqual(len(wi), 4) - self.assertFalse('10y' in wi) + self.assertFalse("10y" in wi) def testRepr(self): - item, wi = yield self.simpleadd('pink', - content='the dark side of the moon 10y') + item, wi = yield self.simpleadd("pink", content="the dark side of the moon 10y") for w in wi: self.assertEqual(str(w), str(w.word)) diff --git a/tests/all/apps/searchengine/meta.py b/tests/all/apps/searchengine/meta.py index 9a0bbef..c07895a 100644 --- a/tests/all/apps/searchengine/meta.py +++ b/tests/all/apps/searchengine/meta.py @@ -1,77 +1,80 @@ from random import randint -from stdnet import odm, QuerySetError -from stdnet.utils import test -from stdnet.odm.search import UpdateSE -from stdnet.utils import test, to_string, range -from stdnet.apps.searchengine import SearchEngine, processors - +from examples.models import SimpleModel from examples.wordsearch.basicwords import basic_english_words from examples.wordsearch.models import Item, RelatedItem -from examples.models import SimpleModel -python_content = 'Python is a programming language that lets you work more\ +from stdnet import QuerySetError, odm +from stdnet.apps.searchengine import SearchEngine, processors +from stdnet.odm.search import UpdateSE +from stdnet.utils import range, test, to_string + +python_content = "Python is a programming language that lets you work more\ quickly and integrate your systems more effectively.\ You can learn to use Python and see almost immediate gains\ - in productivity and lower maintenance costs.' - - -NAMES = {'maurice':('MRS', None), - 'aubrey':('APR', None), - 'cambrillo':('KMPRL','KMPR'), - 'heidi':('HT', None), - 'katherine':('K0RN','KTRN'), - 'Thumbail':('0MPL','TMPL'), - 'catherine':('K0RN','KTRN'), - 'richard':('RXRT','RKRT'), - 'bob':('PP', None), - 'eric':('ARK', None), - 'geoff':('JF','KF'), - 'Through':('0R','TR'), - 'Schwein':('XN', 'XFN'), - 'dave':('TF', None), - 'ray':('R', None), - 'steven':('STFN', None), - 'bryce':('PRS', None), - 'randy':('RNT', None), - 'bryan':('PRN', None), - 'Rapelje':('RPL', None), - 'brian':('PRN', None), - 'otto':('AT', None), - 'auto':('AT', None), - 'Dallas':('TLS', None), - 'maisey':('MS', None), - 'zhang':('JNK', None), - 'Chile':('XL', None), - 'Jose':('HS', None), - 'Arnow':('ARN','ARNF'), - 'solilijs':('SLLS', None), - 'Parachute':('PRKT', None), - 'Nowhere':('NR', None), - 'Tux':('TKS', None)} - - + in productivity and lower maintenance costs." 
+ + +NAMES = { + "maurice": ("MRS", None), + "aubrey": ("APR", None), + "cambrillo": ("KMPRL", "KMPR"), + "heidi": ("HT", None), + "katherine": ("K0RN", "KTRN"), + "Thumbail": ("0MPL", "TMPL"), + "catherine": ("K0RN", "KTRN"), + "richard": ("RXRT", "RKRT"), + "bob": ("PP", None), + "eric": ("ARK", None), + "geoff": ("JF", "KF"), + "Through": ("0R", "TR"), + "Schwein": ("XN", "XFN"), + "dave": ("TF", None), + "ray": ("R", None), + "steven": ("STFN", None), + "bryce": ("PRS", None), + "randy": ("RNT", None), + "bryan": ("PRN", None), + "Rapelje": ("RPL", None), + "brian": ("PRN", None), + "otto": ("AT", None), + "auto": ("AT", None), + "Dallas": ("TLS", None), + "maisey": ("MS", None), + "zhang": ("JNK", None), + "Chile": ("XL", None), + "Jose": ("HS", None), + "Arnow": ("ARN", "ARNF"), + "solilijs": ("SLLS", None), + "Parachute": ("PRKT", None), + "Nowhere": ("NR", None), + "Tux": ("TKS", None), +} + + class SeearchData(test.DataGenerator): - sizes = {'tiny': (10, 10), - 'small': (50, 20), - 'normal': (500, 30), - 'big': (10000, 40), - 'huge': (1000000, 50)} - + sizes = { + "tiny": (10, 10), + "small": (50, 20), + "normal": (500, 30), + "big": (10000, 40), + "huge": (1000000, 50), + } + def generate(self): size, num = self.size - self.names = self.populate('choice', size=size, - choice_from=basic_english_words) + self.names = self.populate("choice", size=size, choice_from=basic_english_words) self.groups = [] - self.empty = size*[''] + self.empty = size * [""] for i in range(size): - text = ' '.join(self.populate('choice', num,\ - choice_from=basic_english_words)) + text = " ".join( + self.populate("choice", num, choice_from=basic_english_words) + ) self.groups.append(text) - + def make_items(self, test, content=False, related=None): - '''Bulk creation of Item for testing search engine. Return a set -of words which have been included in the Items.''' + """Bulk creation of Item for testing search engine. Return a set + of words which have been included in the Items.""" session = test.session() words = set() contents = self.groups if content else self.empty @@ -81,39 +84,47 @@ def make_items(self, test, content=False, related=None): words.add(name) if content: words.update(content.split()) - t.add(Item(name=name, counter=randint(0,10), - content=content, related=related)) + t.add( + Item( + name=name, + counter=randint(0, 10), + content=content, + related=related, + ) + ) yield t.on_result test.words = words yield test.words - + class SearchMixin(object): - '''Mixin for testing the search engine. No tests implemented here, -just registration and some utility functions. All search-engine tests -below will derive from this class.''' - multipledb = 'redis' + """Mixin for testing the search engine. No tests implemented here, + just registration and some utility functions. 
All search-engine tests + below will derive from this class.""" + + multipledb = "redis" metaphone = True stemming = True models = (Item, RelatedItem) data_cls = SeearchData - + @classmethod def after_setup(cls): cls.mapper.set_search_engine(cls.make_engine()) - cls.mapper.search_engine.register(Item, ('related',)) + cls.mapper.search_engine.register(Item, ("related",)) cls.mapper.search_engine.register(RelatedItem) - + @classmethod def make_engine(cls): return SearchEngine(metaphone=cls.metaphone, stemming=cls.stemming) - - def make_item(self, name='python', counter=10, content=None, related=None): + + def make_item(self, name="python", counter=10, content=None, related=None): content = content if content is not None else python_content - return self.mapper.item.new(name=name, counter=counter, content=content, - related=related) - - def simpleadd(self, name='python', counter=10, content=None, related=None): + return self.mapper.item.new( + name=name, counter=counter, content=content, related=related + ) + + def simpleadd(self, name="python", counter=10, content=None, related=None): models = self.mapper engine = models.search_engine item = yield self.make_item(name, counter, content, related) @@ -124,75 +135,79 @@ def simpleadd(self, name='python', counter=10, content=None, related=None): for object in objects: self.assertEqual(object, item) yield item, wis - - + + class TestMeta(SearchMixin, test.TestCase): - '''Test internal functions, not the API.''' + """Test internal functions, not the API.""" + def test_mapper(self): models = self.mapper self.assertTrue(models.search_engine) self.assertEqual(models.search_engine.router, models) - + def testSplitting(self): eg = SearchEngine(metaphone=False, stemming=False) - self.assertEqual(list(eg.words_from_text('bla-ciao+pippo')),\ - ['bla','ciao','pippo']) - self.assertEqual(list(eg.words_from_text('bla.-ciao:;pippo')),\ - ['bla','ciao','pippo']) - self.assertEqual(list(eg.words_from_text(' bla ; @ciao ;:`')),\ - ['bla','ciao']) - self.assertEqual(list(eg.words_from_text('bla bla____bla')),\ - ['bla','bla','bla']) - + self.assertEqual( + list(eg.words_from_text("bla-ciao+pippo")), ["bla", "ciao", "pippo"] + ) + self.assertEqual( + list(eg.words_from_text("bla.-ciao:;pippo")), ["bla", "ciao", "pippo"] + ) + self.assertEqual(list(eg.words_from_text(" bla ; @ciao ;:`")), ["bla", "ciao"]) + self.assertEqual( + list(eg.words_from_text("bla bla____bla")), ["bla", "bla", "bla"] + ) + def testSplitters(self): eg = SearchEngine(splitters=False) self.assertEqual(eg.punctuation_regex, None) - words = list(eg.split_text('pippo:pluto')) - self.assertEqual(len(words),1) - self.assertEqual(words[0],'pippo:pluto') - words = list(eg.split_text('pippo: pluto')) - self.assertEqual(len(words),2) - self.assertEqual(words[0],'pippo:') - + words = list(eg.split_text("pippo:pluto")) + self.assertEqual(len(words), 1) + self.assertEqual(words[0], "pippo:pluto") + words = list(eg.split_text("pippo: pluto")) + self.assertEqual(len(words), 2) + self.assertEqual(words[0], "pippo:") + def testMetaphone(self): - '''Test metaphone algorithm''' + """Test metaphone algorithm""" for name in NAMES: d = processors.double_metaphone(name) - self.assertEqual(d,NAMES[name]) - + self.assertEqual(d, NAMES[name]) + def testRegistered(self): models = self.mapper self.assertTrue(Item in models.search_engine.REGISTERED_MODELS) self.assertTrue(RelatedItem in models.search_engine.REGISTERED_MODELS) self.assertFalse(SimpleModel in models.search_engine.REGISTERED_MODELS) - 
self.assertEqual(models.search_engine.REGISTERED_MODELS[Item].related, - ('related',)) self.assertEqual( - models.search_engine.REGISTERED_MODELS[RelatedItem].related, ()) - + models.search_engine.REGISTERED_MODELS[Item].related, ("related",) + ) + self.assertEqual( + models.search_engine.REGISTERED_MODELS[RelatedItem].related, () + ) + def testNoSearchEngine(self): models = odm.Router(self.backend) models.register(SimpleModel) self.assertFalse(models.search_engine) query = models.simplemodel.query() - qs = query.search('bla') + qs = query.search("bla") self.assertRaises(QuerySetError, qs.all) - + class TestCoverageBaseClass(test.TestCase): - def testAbstracts(self): e = odm.SearchEngine() - self.assertRaises(NotImplementedError, e.search, 'bla') - self.assertRaises(NotImplementedError, e.search_model, None, 'bla') + self.assertRaises(NotImplementedError, e.search, "bla") + self.assertRaises(NotImplementedError, e.search_model, None, "bla") self.assertRaises(NotImplementedError, e.flush) self.assertRaises(NotImplementedError, e.add_item, None, None, None) self.assertRaises(NotImplementedError, e.remove_item, None, None, None) self.assertRaises(AttributeError, e.session) - self.assertEqual(e.split_text('ciao luca'), ['ciao','luca']) - + self.assertEqual(e.split_text("ciao luca"), ["ciao", "luca"]) + def testItemFieldIterator(self): e = odm.SearchEngine() self.assertRaises(ValueError, e.item_field_iterator, None) u = UpdateSE(e) - self.assertEqual(u.se, e) \ No newline at end of file + self.assertEqual(u.se, e) diff --git a/tests/all/apps/searchengine/search.py b/tests/all/apps/searchengine/search.py index 2a8e1f9..655bde7 100644 --- a/tests/all/apps/searchengine/search.py +++ b/tests/all/apps/searchengine/search.py @@ -1,82 +1,78 @@ -'''search a mock database.''' -from stdnet import odm -from stdnet.utils import test, populate - +"""search a mock database.""" from examples.wordsearch.models import Item, RelatedItem +from stdnet import odm +from stdnet.utils import populate, test + from .meta import SearchMixin class TestBigSearch(SearchMixin, test.TestCase): - @classmethod def after_setup(cls): cls.mapper.set_search_engine(cls.make_engine()) - cls.mapper.search_engine.register(Item, ('related',)) + cls.mapper.search_engine.register(Item, ("related",)) cls.mapper.search_engine.register(RelatedItem) return cls.data.make_items(cls, content=True) - + def test_meta_session(self): models = self.mapper self.assertFalse(models.search_engine.backend) session = models.search_engine.session() self.assertEqual(session.router, models) - + def testSearchWords(self): engine = self.mapper.search_engine - words = list(engine.words_from_text('python gains')) - self.assertTrue(len(words)>=2) - + words = list(engine.words_from_text("python gains")) + self.assertTrue(len(words) >= 2) + def test_items(self): engine = self.mapper.search_engine wis = engine.worditems(Item) yield self.async.assertTrue(wis.count()) - + def __test_big_search(self): - #TODO: - #this test sometimes fails. Need to be fixed + # TODO: + # this test sometimes fails. 
Need to be fixed models = self.mapper - sw = ' '.join(populate('choice', 1, choice_from=self.words)) + sw = " ".join(populate("choice", 1, choice_from=self.words)) qs = yield models.item.search(sw).all() self.assertTrue(qs) for item in qs: self.assertTrue(sw in item.name or sw in item.content) - + def testSearch(self): engine = self.mapper.search_engine - text = ' '.join(populate('choice', 1, choice_from=self.words)) + text = " ".join(populate("choice", 1, choice_from=self.words)) result = yield engine.search(text) self.assertTrue(result) - + def testNoWords(self): models = self.mapper query = models.item.query() - q1 = yield query.search('').all() + q1 = yield query.search("").all() all = yield query.all() self.assertTrue(q1) self.assertEqual(set(q1), set(all)) - + def testInSearch(self): models = self.mapper query = models.item.query() - sw = ' '.join(populate('choice', 5, choice_from=self.words)) + sw = " ".join(populate("choice", 5, choice_from=self.words)) res1 = yield query.search(sw).all() - res2 = yield query.search(sw, lookup='in').all() + res2 = yield query.search(sw, lookup="in").all() self.assertTrue(res2) self.assertTrue(len(res1) < len(res2)) - + def testEmptySearch(self): engine = self.mapper.search_engine - queries = engine.search('') + queries = engine.search("") self.assertEqual(len(queries), 1) qs = yield queries[0].all() qs2 = yield engine.worditems().all() self.assertTrue(qs) self.assertEqual(set(qs), set(qs2)) - + def test_bad_lookup(self): engine = self.mapper.search_engine - self.assertRaises(ValueError, engine.search, - 'first second ', lookup='foo') - - + self.assertRaises(ValueError, engine.search, "first second ", lookup="foo") diff --git a/tests/all/backends/interface.py b/tests/all/backends/interface.py index 0211568..62373c3 100644 --- a/tests/all/backends/interface.py +++ b/tests/all/backends/interface.py @@ -1,12 +1,19 @@ -from stdnet import odm, getdb, BackendDataServer, ModelNotAvailable,\ - SessionNotAvailable, BackendStructure -from stdnet.utils import test - from examples.models import SimpleModel +from stdnet import ( + BackendDataServer, + BackendStructure, + ModelNotAvailable, + SessionNotAvailable, + getdb, + odm, +) +from stdnet.utils import test + class DummyBackendDataServer(BackendDataServer): default_port = 9090 + def setup_connection(self, address): pass @@ -18,9 +25,9 @@ def get_backend(self, **kwargs): return DummyBackendDataServer(**kwargs) def testVirtuals(self): - self.assertRaises(NotImplementedError, BackendDataServer, '', '') + self.assertRaises(NotImplementedError, BackendDataServer, "", "") b = self.get_backend() - self.assertEqual(str(b), 'dummy://127.0.0.1:9090') + self.assertEqual(str(b), "dummy://127.0.0.1:9090") self.assertFalse(b.clean(None)) self.assertRaises(NotImplementedError, b.execute_session, None, None) self.assertRaises(NotImplementedError, b.model_keys, None) @@ -31,9 +38,9 @@ def testMissingStructure(self): self.assertRaises(AttributeError, l.backend_structure) def testRedis(self): - b = getdb('redis://') - self.assertEqual(b.name, 'redis') - self.assertEqual(b.connection_string, 'redis://127.0.0.1:6379?db=0') + b = getdb("redis://") + self.assertEqual(b.name, "redis") + self.assertEqual(b.connection_string, "redis://127.0.0.1:6379?db=0") def testBackendStructure_error(self): s = BackendStructure(None, None, None) diff --git a/tests/all/backends/redis/async.py b/tests/all/backends/redis/async.py index a7ba60c..0e05b55 100644 --- a/tests/all/backends/redis/async.py +++ b/tests/all/backends/redis/async.py @@ -1,12 
+1,12 @@ -'''Test the asynchronous redis client''' +"""Test the asynchronous redis client""" from copy import copy import pulsar +from examples.data import FinanceTest from stdnet.utils import test from stdnet.utils.async import async_binding -from examples.data import FinanceTest def check_connection(self, command_name): redis = self.mapper.default_backend.client @@ -17,30 +17,28 @@ def check_connection(self, command_name): consumer = conn.current_consumer request = consumer.current_request self.assertEqual(client.available_connections, 1) - - -@test.skipUnless(async_binding, 'Requires asynchronous binding') + + +@test.skipUnless(async_binding, "Requires asynchronous binding") class TestRedisAsyncClient(test.TestWrite): - multipledb = 'redis' - + multipledb = "redis" + @classmethod def after_setup(cls): return cls.data.create(cls) - + @classmethod def backend_params(cls): - return {'timeout': 0} - + return {"timeout": 0} + def test_client(self): redis = self.mapper.default_backend.client self.assertFalse(redis.full_response) redis = copy(redis) redis.full_response = True - ping = yield redis.execute_command('PING').on_finished + ping = yield redis.execute_command("PING").on_finished self.assertTrue(ping.result) self.assertTrue(ping.connection) - echo = yield redis.echo('Hello!').on_finished - self.assertEqual(echo.result, b'Hello!') + echo = yield redis.echo("Hello!").on_finished + self.assertEqual(echo.result, b"Hello!") self.assertTrue(echo.connection) - - \ No newline at end of file diff --git a/tests/all/backends/redis/client.py b/tests/all/backends/redis/client.py index 9ff1293..702b4f3 100755 --- a/tests/all/backends/redis/client.py +++ b/tests/all/backends/redis/client.py @@ -1,38 +1,41 @@ -'''Test additional commands for redis client.''' +"""Test additional commands for redis client.""" import json from hashlib import sha1 from stdnet import getdb from stdnet.backends import redisb -from stdnet.utils import test, flatzset +from stdnet.utils import flatzset, test + def get_version(info): - if 'redis_version' in info: - return info['redis_version'] + if "redis_version" in info: + return info["redis_version"] else: - return info['Server']['redis_version'] - - + return info["Server"]["redis_version"] + + class test_script(redisb.RedisScript): - script = (redisb.read_lua_file('commands.utils'), - '''\ + script = ( + redisb.read_lua_file("commands.utils"), + """\ local js = cjson.decode(ARGV[1]) -return cjson.encode(js)''') - +return cjson.encode(js)""", + ) + def callback(self, request, result, args, **options): return json.loads(result.decode(request.encoding)) class TestCase(test.TestWrite): - multipledb = 'redis' - + multipledb = "redis" + def setUp(self): client = self.backend.client self.client = client.prefixed(self.namespace) - + def tearDown(self): return self.client.flushdb() - + def make_hash(self, key, d): for k, v in d.items(): self.client.hset(key, k, v) @@ -44,148 +47,195 @@ def make_list(self, name, l): def make_zset(self, name, d): self.client.zadd(name, *flatzset(kwargs=d)) - - + + class TestExtraClientCommands(TestCase): - def test_coverage(self): c = self.backend.client - self.assertEqual(c.prefix, '') + self.assertEqual(c.prefix, "") size = yield c.dbsize() self.assertTrue(size >= 0) - + def test_script_meta(self): - script = redisb.get_script('test_script') + script = redisb.get_script("test_script") self.assertTrue(script.script) - sha = sha1(script.script.encode('utf-8')).hexdigest() - self.assertEqual(script.sha1,sha) - + sha = 
sha1(script.script.encode("utf-8")).hexdigest() + self.assertEqual(script.sha1, sha) + def test_del_pattern(self): c = self.client - items = ('bla',1, - 'bla1','ciao', - 'bla2','foo', - 'xxxx','moon', - 'blaaaaaaaaaaaaaa','sun', - 'xyyyy','earth') - yield self.async.assertTrue(c.execute_command('MSET', *items)) - N = yield c.delpattern('bla*') + items = ( + "bla", + 1, + "bla1", + "ciao", + "bla2", + "foo", + "xxxx", + "moon", + "blaaaaaaaaaaaaaa", + "sun", + "xyyyy", + "earth", + ) + yield self.async.assertTrue(c.execute_command("MSET", *items)) + N = yield c.delpattern("bla*") self.assertEqual(N, 4) - yield self.async.assertFalse(c.exists('bla')) - yield self.async.assertFalse(c.exists('bla1')) - yield self.async.assertFalse(c.exists('bla2')) - yield self.async.assertFalse(c.exists('blaaaaaaaaaaaaaa')) - yield self.async.assertEqual(c.get('xxxx'), b'moon') - N = yield c.delpattern('x*') + yield self.async.assertFalse(c.exists("bla")) + yield self.async.assertFalse(c.exists("bla1")) + yield self.async.assertFalse(c.exists("bla2")) + yield self.async.assertFalse(c.exists("blaaaaaaaaaaaaaa")) + yield self.async.assertEqual(c.get("xxxx"), b"moon") + N = yield c.delpattern("x*") self.assertEqual(N, 2) - + def testMove2Set(self): - yield self.multi_async((self.client.sadd('foo', 1, 2, 3, 4, 5), - self.client.lpush('bla', 4, 5, 6, 7, 8))) - r = yield self.client.execute_script('move2set', ('foo', 'bla'), 's') + yield self.multi_async( + ( + self.client.sadd("foo", 1, 2, 3, 4, 5), + self.client.lpush("bla", 4, 5, 6, 7, 8), + ) + ) + r = yield self.client.execute_script("move2set", ("foo", "bla"), "s") self.assertEqual(len(r), 2) self.assertEqual(r[0], 2) self.assertEqual(r[1], 1) - yield self.multi_async((self.client.sinterstore('res1', 'foo', 'bla'), - self.client.sunionstore('res2', 'foo', 'bla'))) - m1 = yield self.client.smembers('res1') - m2 = yield self.client.smembers('res2') + yield self.multi_async( + ( + self.client.sinterstore("res1", "foo", "bla"), + self.client.sunionstore("res2", "foo", "bla"), + ) + ) + m1 = yield self.client.smembers("res1") + m2 = yield self.client.smembers("res2") m1 = sorted((int(r) for r in m1)) m2 = sorted((int(r) for r in m2)) - self.assertEqual(m1, [4,5]) - self.assertEqual(m2, [1,2,3,4,5,6,7,8]) - + self.assertEqual(m1, [4, 5]) + self.assertEqual(m2, [1, 2, 3, 4, 5, 6, 7, 8]) + def testMove2ZSet(self): client = self.client - yield self.multi_async((client.zadd('foo',1,'a',2,'b',3,'c',4,'d',5,'e'), - client.lpush('bla','d','e','f','g'))) - r = yield client.execute_script('move2set', ('foo','bla'), 'z') + yield self.multi_async( + ( + client.zadd("foo", 1, "a", 2, "b", 3, "c", 4, "d", 5, "e"), + client.lpush("bla", "d", "e", "f", "g"), + ) + ) + r = yield client.execute_script("move2set", ("foo", "bla"), "z") self.assertEqual(len(r), 2) self.assertEqual(r[0], 2) self.assertEqual(r[1], 1) - yield self.multi_async((client.zinterstore('res1', ('foo', 'bla')), - client.zunionstore('res2', ('foo', 'bla')))) - m1 = yield client.zrange('res1', 0, -1) - m2 = yield client.zrange('res2', 0, -1) - self.assertEqual(sorted(m1), [b'd', b'e']) - self.assertEqual(sorted(m2), [b'a',b'b',b'c',b'd',b'e',b'f',b'g']) - + yield self.multi_async( + ( + client.zinterstore("res1", ("foo", "bla")), + client.zunionstore("res2", ("foo", "bla")), + ) + ) + m1 = yield client.zrange("res1", 0, -1) + m2 = yield client.zrange("res2", 0, -1) + self.assertEqual(sorted(m1), [b"d", b"e"]) + self.assertEqual(sorted(m2), [b"a", b"b", b"c", b"d", b"e", b"f", b"g"]) + def testMoveSetSet(self): - r 
= yield self.multi_async((self.client.sadd('foo',1,2,3,4,5), - self.client.sadd('bla',4,5,6,7,8))) - r = yield self.client.execute_script('move2set', ('foo', 'bla'), 's') + r = yield self.multi_async( + ( + self.client.sadd("foo", 1, 2, 3, 4, 5), + self.client.sadd("bla", 4, 5, 6, 7, 8), + ) + ) + r = yield self.client.execute_script("move2set", ("foo", "bla"), "s") self.assertEqual(len(r), 2) self.assertEqual(r[0], 2) self.assertEqual(r[1], 0) - + def testMove2List2(self): - yield self.multi_async((self.client.lpush('foo',1,2,3,4,5), - self.client.lpush('bla',4,5,6,7,8))) - r = yield self.client.execute_script('move2set', ('foo','bla'), 's') + yield self.multi_async( + ( + self.client.lpush("foo", 1, 2, 3, 4, 5), + self.client.lpush("bla", 4, 5, 6, 7, 8), + ) + ) + r = yield self.client.execute_script("move2set", ("foo", "bla"), "s") self.assertEqual(len(r), 2) self.assertEqual(r[0], 2) self.assertEqual(r[1], 2) - + def test_bad_execute_script(self): - self.assertRaises(redisb.RedisError, self.client.execute_script, 'foo', ()) - + self.assertRaises(redisb.RedisError, self.client.execute_script, "foo", ()) + # ZSET SCRIPTING COMMANDS def test_zdiffstore(self): - yield self.multi_async((self.make_zset('aa', {'a1': 1, 'a2': 1, 'a3': 1}), - self.make_zset('ba', {'a1': 2, 'a3': 2, 'a4': 2}), - self.make_zset('ca', {'a1': 6, 'a3': 5, 'a4': 4}))) - n = yield self.client.zdiffstore('za', ['aa', 'ba', 'ca']) + yield self.multi_async( + ( + self.make_zset("aa", {"a1": 1, "a2": 1, "a3": 1}), + self.make_zset("ba", {"a1": 2, "a3": 2, "a4": 2}), + self.make_zset("ca", {"a1": 6, "a3": 5, "a4": 4}), + ) + ) + n = yield self.client.zdiffstore("za", ["aa", "ba", "ca"]) self.assertEqual(n, 1) - r = yield self.client.zrange('za', 0, -1, withscores=True) - self.assertEquals(list(r), [(b'a2', 1)]) - + r = yield self.client.zrange("za", 0, -1, withscores=True) + self.assertEquals(list(r), [(b"a2", 1)]) + def test_zdiffstore_withscores(self): - yield self.multi_async((self.make_zset('ab', {'a1': 6, 'a2': 1, 'a3': 2}), - self.make_zset('bb', {'a1': 1, 'a3': 1, 'a4': 2}), - self.make_zset('cb', {'a1': 3, 'a3': 1, 'a4': 4}))) - n = yield self.client.zdiffstore('zb', ['ab', 'bb', 'cb'], withscores=True) + yield self.multi_async( + ( + self.make_zset("ab", {"a1": 6, "a2": 1, "a3": 2}), + self.make_zset("bb", {"a1": 1, "a3": 1, "a4": 2}), + self.make_zset("cb", {"a1": 3, "a3": 1, "a4": 4}), + ) + ) + n = yield self.client.zdiffstore("zb", ["ab", "bb", "cb"], withscores=True) self.assertEqual(n, 2) - r = yield self.client.zrange('zb', 0, -1, withscores=True) - self.assertEquals(list(r), [(b'a2', 1), (b'a1', 2)]) - + r = yield self.client.zrange("zb", 0, -1, withscores=True) + self.assertEquals(list(r), [(b"a2", 1), (b"a1", 2)]) + def test_zdiffstore2(self): c = self.client - yield self.multi_async((c.zadd('s1', 1, 'a', 2, 'b', 3, 'c', 4, 'd'), - c.zadd('s2', 6, 'a', 9, 'b', 100, 'c'))) - r = yield c.zdiffstore('s3', ('s1', 's2')) - self.async.assertEqual(c.zcard('s3'), 1) - r = yield c.zrange('s3', 0, -1) - self.assertEqual(r, [b'd']) - + yield self.multi_async( + ( + c.zadd("s1", 1, "a", 2, "b", 3, "c", 4, "d"), + c.zadd("s2", 6, "a", 9, "b", 100, "c"), + ) + ) + r = yield c.zdiffstore("s3", ("s1", "s2")) + self.async.assertEqual(c.zcard("s3"), 1) + r = yield c.zrange("s3", 0, -1) + self.assertEqual(r, [b"d"]) + def test_zdiffstore_withscores2(self): c = self.client - yield self.multi_async((c.zadd('s1', 1, 'a', 2, 'b', 3, 'c', 4, 'd'), - c.zadd('s2', 6, 'a', 2, 'b', 100, 'c'))) - r = yield c.zdiffstore('s3', ('s1', 
's2'), withscores=True) - self.async.assertEqual(c.zcard('s3'), 3) - r = yield c.zrange('s3', 0, -1, withscores=True) - self.assertEqual(dict(r), {b'a': -5.0, b'c': -97.0, b'd': 4.0}) - + yield self.multi_async( + ( + c.zadd("s1", 1, "a", 2, "b", 3, "c", 4, "d"), + c.zadd("s2", 6, "a", 2, "b", 100, "c"), + ) + ) + r = yield c.zdiffstore("s3", ("s1", "s2"), withscores=True) + self.async.assertEqual(c.zcard("s3"), 3) + r = yield c.zrange("s3", 0, -1, withscores=True) + self.assertEqual(dict(r), {b"a": -5.0, b"c": -97.0, b"d": 4.0}) + def test_zpop_byrank(self): - yield self.client.zadd('foo',1,'a',2,'b',3,'c',4,'d',5,'e') - res = yield self.client.zpopbyrank('foo',0) - rem = yield self.client.zrange('foo',0,-1) - self.assertEqual(len(rem),4) - self.assertEqual(rem,[b'b',b'c',b'd',b'e']) - self.assertEqual(res,[b'a']) - res = yield self.client.zpopbyrank('foo',0,2) - self.assertEqual(res,[b'b',b'c',b'd']) - rem = yield self.client.zrange('foo',0,-1) - self.assertEqual(rem,[b'e']) - + yield self.client.zadd("foo", 1, "a", 2, "b", 3, "c", 4, "d", 5, "e") + res = yield self.client.zpopbyrank("foo", 0) + rem = yield self.client.zrange("foo", 0, -1) + self.assertEqual(len(rem), 4) + self.assertEqual(rem, [b"b", b"c", b"d", b"e"]) + self.assertEqual(res, [b"a"]) + res = yield self.client.zpopbyrank("foo", 0, 2) + self.assertEqual(res, [b"b", b"c", b"d"]) + rem = yield self.client.zrange("foo", 0, -1) + self.assertEqual(rem, [b"e"]) + def test_zpop_byscore(self): - yield self.client.zadd('foo', 1, 'a', 2, 'b', 3, 'c', 4, 'd', 5, 'e') - res = yield self.client.zpopbyscore('foo', 2) - rem = yield self.client.zrange('foo', 0, -1) + yield self.client.zadd("foo", 1, "a", 2, "b", 3, "c", 4, "d", 5, "e") + res = yield self.client.zpopbyscore("foo", 2) + rem = yield self.client.zrange("foo", 0, -1) self.assertEqual(len(rem), 4) - self.assertEqual(rem, [b'a', b'c', b'd', b'e']) - self.assertEqual(res, [b'b']) - res = yield self.client.zpopbyscore('foo', 0, 4.5) - self.assertEqual(res, [b'a', b'c', b'd']) - rem = yield self.client.zrange('foo', 0, -1) - self.assertEqual(rem, [b'e']) \ No newline at end of file + self.assertEqual(rem, [b"a", b"c", b"d", b"e"]) + self.assertEqual(res, [b"b"]) + res = yield self.client.zpopbyscore("foo", 0, 4.5) + self.assertEqual(res, [b"a", b"c", b"d"]) + rem = yield self.client.zrange("foo", 0, -1) + self.assertEqual(rem, [b"e"]) diff --git a/tests/all/backends/redis/info.py b/tests/all/backends/redis/info.py index 8a785b9..ef44f7b 100644 --- a/tests/all/backends/redis/info.py +++ b/tests/all/backends/redis/info.py @@ -1,57 +1,61 @@ import time -from stdnet.backends.redisb import RedisDb, RedisKey, RedisDataFormatter +from stdnet.backends.redisb import RedisDataFormatter, RedisDb, RedisKey from . 
import client class TestInfo(client.TestCase): models = (RedisDb, RedisKey) - - def get_manager(self, key='test', value='bla'): + + def get_manager(self, key="test", value="bla"): yield self.client.set(key, value) yield self.mapper.redisdb - + def test_dataFormatter(self): f = RedisDataFormatter() - self.assertEqual(f.format_date('bla'), '') + self.assertEqual(f.format_date("bla"), "") d = f.format_date(time.time()) self.assertTrue(d) - + def testKeyInfo(self): - yield self.client.set('planet', 'mars') - yield self.client.lpush('foo', 1, 2, 3, 4, 5) - yield self.client.lpush('bla', 4, 5, 6, 7, 8) - keys = yield self.client.execute_script('keyinfo', (), '*') + yield self.client.set("planet", "mars") + yield self.client.lpush("foo", 1, 2, 3, 4, 5) + yield self.client.lpush("bla", 4, 5, 6, 7, 8) + keys = yield self.client.execute_script("keyinfo", (), "*") self.assertEqual(len(keys), 3) d = dict(((k.key, k) for k in keys)) - self.assertEqual(d['planet'].length, 4) - self.assertEqual(d['planet'].type, 'string') - self.assertEqual(d['planet'].encoding, 'raw') - + self.assertEqual(d["planet"].length, 4) + self.assertEqual(d["planet"].type, "string") + self.assertEqual(d["planet"].encoding, "raw") + def testKeyInfo2(self): client = self.client - yield self.multi_async((client.set('planet', 'mars'), - client.lpush('foo', 1, 2, 3, 4, 5), - client.lpush('bla', 4, 5, 6, 7, 8))) - keys = yield client.execute_script('keyinfo', ('planet', 'bla')) + yield self.multi_async( + ( + client.set("planet", "mars"), + client.lpush("foo", 1, 2, 3, 4, 5), + client.lpush("bla", 4, 5, 6, 7, 8), + ) + ) + keys = yield client.execute_script("keyinfo", ("planet", "bla")) self.assertEqual(len(keys), 2) - + def test_manager(self): redisdb = yield self.get_manager() self.assertTrue(redisdb.client) self.assertEqual(redisdb.backend.client, redisdb.client) self.assertTrue(redisdb.formatter) - self.assertEqual(redisdb.formatter.format_name('ciao'), 'ciao') - self.assertEqual(redisdb.formatter.format_bool(0), 'no') - self.assertEqual(redisdb.formatter.format_bool('bla'), 'yes') - + self.assertEqual(redisdb.formatter.format_name("ciao"), "ciao") + self.assertEqual(redisdb.formatter.format_bool(0), "no") + self.assertEqual(redisdb.formatter.format_bool("bla"), "yes") + def test_info_pannel_names(self): info = yield self.client.info() self.assertTrue(info) for name in self.mapper.redisdb.names: self.assertTrue(name in info) - + def test_databases(self): redisdb = yield self.get_manager() dbs = yield redisdb.all() @@ -60,12 +64,12 @@ def test_databases(self): for db in dbs: self.assertIsInstance(db.db, int) self.assertTrue(db.expires <= db.keys) - + def test_makepanel_empty(self): redisdb = yield self.get_manager() - p = redisdb.makepanel('sdkjcbnskbcd', {}) + p = redisdb.makepanel("sdkjcbnskbcd", {}) self.assertEqual(p, None) - + def test_panels(self): redisdb = yield self.get_manager() p = yield redisdb.panels() @@ -74,43 +78,43 @@ def test_panels(self): val = p.pop(name) self.assertIsInstance(val, list) self.assertFalse(p) - + def __test_database(self): redisdb = yield self.get_manager() dbs = yield redisdb.all() for db in dbs: dbkeys = yield db.all_keys.all() self.assertIsInstance(dbkeys, list) - + def __testInfoKeys(self): redisdb = yield self.get_manager() dbs = RedisDb.objects.all(info) for db in dbs: keys = RedisKey.objects.query(db) self.assertEqual(keys.db, db) - self.assertEqual(keys.pattern, '*') - + self.assertEqual(keys.pattern, "*") + def __test_search(self): redisdb = yield self.get_manager() - yield 
redisdb.client.set('blaxxx', 'test') + yield redisdb.client.set("blaxxx", "test") query = db.query() - q = query.search('blax*') + q = query.search("blax*") self.assertNotEqual(query, q) self.assertEqual(q.db, db) - self.assertEqual(q.pattern, 'blax*') + self.assertEqual(q.pattern, "blax*") self.assertTrue(query.count()) self.assertEqual(query.count(), len(query)) self.assertEqual(q.count(), 1) keys = list(q) self.assertEqual(len(keys), 1) key = q[0] - self.assertEqual(str(key), 'blaxxx') - + self.assertEqual(str(key), "blaxxx") + def __testQuerySlice(self): redisdb = yield self.get_manager() db = self.newdb(info) - db.client.set('blaxxx', 'test') - db.client.set('blaxyy', 'test2') + db.client.set("blaxxx", "test") + db.client.set("blaxyy", "test2") all = db.all() self.assertTrue(isinstance(all, list)) self.assertEqual(len(all), 2) @@ -119,20 +123,20 @@ def __testQuerySlice(self): self.assertEqual(all[-1:1], db.query()[-1:1]) self.assertEqual(all[-1:2], db.query()[-1:2]) self.assertEqual(all[-2:1], db.query()[-2:1]) - self.assertEqual(db.query().search('*yy').delete(), 1) + self.assertEqual(db.query().search("*yy").delete(), 1) self.assertEqual(db.query().delete(), 1) self.assertEqual(db.all(), []) - + def __testRedisKeyManager(self): redisdb = yield self.get_manager() db = self.newdb(info) - db.client.set('blaxxx', 'test') - db.client.set('blaxyy', 'test2') + db.client.set("blaxxx", "test") + db.client.set("blaxyy", "test2") all = db.all() self.assertEqual(len(all), 2) self.assertEqual(RedisKey.objects.delete(all), 2) self.assertEqual(db.all(), []) - + def __testRedisDbDelete(self): redisdb = yield self.get_manager() dbs = RedisDb.objects.all(info) @@ -140,6 +144,4 @@ def __testRedisDbDelete(self): flushdb = lambda client: called.append(client) for db in dbs: db.delete(flushdb) - self.assertEqual(len(called),len(dbs)) - - \ No newline at end of file + self.assertEqual(len(called), len(dbs)) diff --git a/tests/all/backends/redis/prefixed.py b/tests/all/backends/redis/prefixed.py index ac4298d..4a169d6 100644 --- a/tests/all/backends/redis/prefixed.py +++ b/tests/all/backends/redis/prefixed.py @@ -1,36 +1,35 @@ -from stdnet.utils import test -from stdnet.utils import gen_unique_id +from stdnet.utils import gen_unique_id, test + - class TestRedisPrefixed(test.TestCase): - multipledb = 'redis' - + multipledb = "redis" + def get_client(self, prefix=None): prefix = prefix or gen_unique_id() c = self.backend.client.prefixed(prefix + self.namespace) if c.prefix not in self.clients: self.clients[c.prefix] = c return self.clients[c.prefix] - + def setUp(self): self.clients = {} - + def tearDown(self): for c in self.clients.values(): yield c.flushdb() - + def test_meta(self): - c = self.get_client('yyy') + c = self.get_client("yyy") self.assertTrue(c.prefix) - self.assertTrue(c.prefix.startswith('yyy')) + self.assertTrue(c.prefix.startswith("yyy")) self.assertTrue(c.client) self.assertFalse(c.client.prefix) - + def test_delete(self): c1 = self.get_client() c2 = self.get_client() - yield c1.set('bla', 'foo') - yield c2.set('bla', 'foo') + yield c1.set("bla", "foo") + yield c2.set("bla", "foo") yield self.async.assertEqual(c1.dbsize(), 1) yield self.async.assertEqual(c2.dbsize(), 1) yield c1.flushdb() @@ -38,11 +37,8 @@ def test_delete(self): yield self.async.assertEqual(c2.dbsize(), 1) yield c2.flushdb() yield self.async.assertEqual(c2.dbsize(), 0) - + def test_error(self): c = self.get_client() - self.assertRaises(NotImplementedError, c.execute_command, 'FLUSHDB') - 
-        self.assertRaises(NotImplementedError, c.execute_command, 'FLUSHALL')
-
-
-
\ No newline at end of file
+        self.assertRaises(NotImplementedError, c.execute_command, "FLUSHDB")
+        self.assertRaises(NotImplementedError, c.execute_command, "FLUSHALL")
diff --git a/tests/all/benchmarks/__init__.py b/tests/all/benchmarks/__init__.py
index 272ec8d..c3e368a 100644
--- a/tests/all/benchmarks/__init__.py
+++ b/tests/all/benchmarks/__init__.py
@@ -1,18 +1,15 @@
 import os

-from stdnet.utils import test
+from examples.data import CCYS_TYPES, INSTS_TYPES, finance_data
+from examples.models import Fund, Instrument, PortfolioView, Position, UserDefaultView

-from examples.models import Instrument, Fund, Position, PortfolioView,\
-    UserDefaultView
-from examples.data import finance_data, INSTS_TYPES, CCYS_TYPES
+from stdnet.utils import test


 class Benchmarks(test.TestWrite):
     __benchmark__ = True
     data_cls = finance_data
     models = (Instrument, Fund, Position)
-
+
     def test_create(self):
         session = yield self.data.create(self)
-
-
\ No newline at end of file
diff --git a/tests/all/fields/fk.py b/tests/all/fields/fk.py
index 1dfb9c9..5a745ed 100644
--- a/tests/all/fields/fk.py
+++ b/tests/all/fields/fk.py
@@ -1,36 +1,36 @@
+from examples.models import Group, Person
+
 import stdnet
-from stdnet import odm, FieldError
+from stdnet import FieldError, odm
 from stdnet.utils import test

-from examples.models import Person, Group
-

 class TestForeignKey(test.TestCase):
     models = (Person, Group)
-
+
     @classmethod
     def after_setup(cls):
         session = cls.mapper.session()
         with session.begin() as t:
-            t.add(Group(name='bla'))
+            t.add(Group(name="bla"))
         yield t.on_result
-        g = yield session.query(Group).get(name='bla')
+        g = yield session.query(Group).get(name="bla")
         with session.begin() as t:
-            t.add(Person(name='foo', group=g))
+            t.add(Person(name="foo", group=g))
         yield t.on_result
-
+
     def testSimple(self):
         session = self.session()
         query = session.query(Person)
         yield self.async.assertEqual(query.count(), 1)
-        p = yield query.get(name='foo')
+        p = yield query.get(name="foo")
         self.assertTrue(p.group_id)
         p.group = None
         self.assertEqual(p.group_id, None)
-
+
     def testOldRelatedNone(self):
         models = self.mapper
-        p = yield models.person.get(name='foo')
+        p = yield models.person.get(name="foo")
         g = yield p.group
         self.assertTrue(g)
         self.assertEqual(g, p.group)
@@ -38,20 +38,20 @@ def testOldRelatedNone(self):
         p.group = None
         self.assertEqual(p.group_id, None)
         yield self.async.assertRaises(stdnet.FieldValueError, p.session.add, p)
-
+
     def testCoverage(self):
         self.assertRaises(FieldError, odm.ForeignKey, None)
-
+

 class TestForeignKeyWrite(test.TestWrite):
     models = (Person, Group)
-
+
     def test_create(self):
         models = self.mapper
-        group = yield models.group.new(name='quant')
-        self.assertEqual(group.name, 'quant')
+        group = yield models.group.new(name="quant")
+        self.assertEqual(group.name, "quant")
         self.assertEqualId(group, 1)
-        person = yield models.person.new(name='luca', group=group)
+        person = yield models.person.new(name="luca", group=group)
         self.assertEqualId(person, 1)
         self.assertEqual(group.id, person.group_id)
-        self.assertEqual(group, person.group)
\ No newline at end of file
+        self.assertEqual(group, person.group)
diff --git a/tests/all/fields/fknotrequired.py b/tests/all/fields/fknotrequired.py
index 10f9904..203dff1 100644
--- a/tests/all/fields/fknotrequired.py
+++ b/tests/all/fields/fknotrequired.py
@@ -1,69 +1,68 @@
+from examples.models import CrossData, Feed1, Feed2
 from pulsar import multi_async

 import stdnet
-from stdnet import odm, FieldError
+from stdnet import FieldError, odm
 from stdnet.utils import test

-from examples.models import Feed1, Feed2, CrossData
-

 class NonRequiredForeignKey(test.TestCase):
     models = (Feed1, Feed2, CrossData)
-
+
     def create_feeds(self, *names):
         session = self.session()
         with self.mapper.session().begin() as t:
             for name in names:
                 t.add(Feed1(name=name))
         return t.on_result
-
+
     def create_feeds_with_data(self, *names, **kwargs):
         models = self.mapper
         yield self.create_feeds(*names)
         all = yield models.feed1.filter(name=names).all()
-        params = {'pv': 30, 'delta': 40, 'name': 'live'}
+        params = {"pv": 30, "delta": 40, "name": "live"}
         params.update(kwargs)
-        name = params.pop('name')
+        name = params.pop("name")
         with models.session().begin() as t:
             for feed in all:
                 feed.live = yield models.crossdata.new(name=name, data=params)
                 t.add(feed)
         yield t.on_result
-
+
     def test_nodata(self):
-        yield self.create_feeds('bla', 'foo')
+        yield self.create_feeds("bla", "foo")
         session = self.session()
-        feeds = yield session.query(Feed1).filter(name=('bla', 'foo')).all()
+        feeds = yield session.query(Feed1).filter(name=("bla", "foo")).all()
         for feed in feeds:
             live, prev = yield multi_async((feed.live, feed.prev))
             self.assertFalse(live)
             self.assertFalse(prev)
-
+
     def test_width_data(self):
         models = self.mapper
-        yield self.create_feeds_with_data('test1')
-        feed = yield models.feed1.get(name='test1')
+        yield self.create_feeds_with_data("test1")
+        feed = yield models.feed1.get(name="test1")
         live = yield feed.live
         self.assertEqual(live.data__pv, 30)
-
+
     def test_load_only(self):
         models = self.mapper
-        yield self.create_feeds_with_data('test2')
-        feed = yield models.feed1.query().load_only('live__data__pv').get(name='test2')
+        yield self.create_feeds_with_data("test2")
+        feed = yield models.feed1.query().load_only("live__data__pv").get(name="test2")
         self.assertFalse(feed.live.has_all_data)
-        self.assertEqual(feed.live.data, {'pv': 30})
-
+        self.assertEqual(feed.live.data, {"pv": 30})
+
     def test_filter(self):
         models = self.mapper
-        yield self.create_feeds_with_data('test3', pv=400)
+        yield self.create_feeds_with_data("test3", pv=400)
         feeds = models.feed1.filter(live__data__pv__gt=300)
         yield self.async.assertEqual(feeds.count(), 1)
-
+
     def test_delete(self):
         models = self.mapper
-        yield self.create_feeds_with_data('test4', 'test5', name='pippo')
-        yield models.crossdata.filter(name='pippo').delete()
-        feeds = yield models.feed1.query().filter(name=('test4', 'test5')).all()
+        yield self.create_feeds_with_data("test4", "test5", name="pippo")
+        yield models.crossdata.filter(name="pippo").delete()
+        feeds = yield models.feed1.query().filter(name=("test4", "test5")).all()
         self.assertEqual(len(feeds), 2)
         for feed in feeds:
             live = yield feed.live
@@ -72,77 +71,89 @@ def test_delete(self):
             self.assertFalse(feed.live_id)
             self.assertFalse(prev)
             self.assertFalse(feed.prev_id)
-
+
     def test_load_related(self):
         models = self.mapper
-        yield self.create_feeds('jkjkjk')
-        feed = yield models.feed1.query().load_related('live', 'id').get(name='jkjkjk')
+        yield self.create_feeds("jkjkjk")
+        feed = yield models.feed1.query().load_related("live", "id").get(name="jkjkjk")
         self.assertEqual(feed.live, None)

     def test_load_only_missing_related(self):
-        '''load_only on a related field which is missing.'''
+        """load_only on a related field which is missing."""
         models = self.mapper
-        yield self.create_feeds('ooo', 'ooo2')
-        qs = yield models.feed1.query().load_only('live__pv').filter(name__startswith='ooo')
+        yield self.create_feeds("ooo", "ooo2")
+        qs = (
+            yield models.feed1.query()
+            .load_only("live__pv")
+            .filter(name__startswith="ooo")
+        )
         yield self.async.assertEqual(qs.count(), 2)
         qs = yield qs.all()
         for feed in qs:
             self.assertEqual(feed.live, None)
-
+
     def test_load_only_some_missing_related(self):
-        '''load_only on a related field which is missing.'''
+        """load_only on a related field which is missing for some instances."""
         models = self.mapper
-        yield self.create_feeds_with_data('aaa1', 'aaa2', name='palo')
-        qs = yield models.feed1.query().filter(name__startswith='aaa')\
-            .load_only('name', 'live__data__pv').all()
+        yield self.create_feeds_with_data("aaa1", "aaa2", name="palo")
+        qs = (
+            yield models.feed1.query()
+            .filter(name__startswith="aaa")
+            .load_only("name", "live__data__pv")
+            .all()
+        )
         self.assertEqual(len(qs), 2)
         for feed in qs:
-            self.assertEqual(feed.live.data, {'pv': 30})
+            self.assertEqual(feed.live.data, {"pv": 30})

     def test_has_attribute(self):
         models = self.mapper
-        yield self.create_feeds_with_data('bbba', 'bbbc')
-        qs = yield models.feed1.query().filter(name__startswith='bbb')\
-            .load_only('name', 'live__data__pv').all()
+        yield self.create_feeds_with_data("bbba", "bbbc")
+        qs = (
+            yield models.feed1.query()
+            .filter(name__startswith="bbb")
+            .load_only("name", "live__data__pv")
+            .all()
+        )
         self.assertEqual(len(qs), 2)
         for feed in qs:
-            name = feed.get_attr_value('name')
-            self.assertTrue(name.startswith('bbb'))
-            self.assertEqual(feed.get_attr_value('live__data__pv'), 30)
-            self.assertRaises(AttributeError, feed.get_attr_value, 'a__b')
-
+            name = feed.get_attr_value("name")
+            self.assertTrue(name.startswith("bbb"))
+            self.assertEqual(feed.get_attr_value("live__data__pv"), 30)
+            self.assertRaises(AttributeError, feed.get_attr_value, "a__b")
+
     def test_load_related_when_deleted(self):
-        '''Use load_related on foreign key which was deleted.'''
+        """Use load_related on foreign key which was deleted."""
         models = self.mapper
         session = models.session()
-        yield self.create_feeds_with_data('ccc1')
-        feed = yield models.feed1.get(name='ccc1')
+        yield self.create_feeds_with_data("ccc1")
+        feed = yield models.feed1.get(name="ccc1")
         live = yield feed.live
         self.assertTrue(feed.live)
         self.assertEqual(feed.live.id, feed.live_id)
         # Now we delete the feed
         deleted = yield session.delete(feed.live)
-        yield self.async.assertEqual(models.crossdata.filter(
-            id=feed.live.id).count(), 0)
+        yield self.async.assertEqual(
+            models.crossdata.filter(id=feed.live.id).count(), 0
+        )
         # we still have a reference to it!
         self.assertTrue(feed.live_id)
         self.assertTrue(feed.live)
         #
         # Now reload the feed
-        feed = yield models.feed1.get(name='ccc1')
+        feed = yield models.feed1.get(name="ccc1")
         live = yield feed.live
         self.assertFalse(live)
         self.assertFalse(feed.live_id)
         #
-        feed = yield models.feed1.query().load_related('live').get(name='ccc1')
+        feed = yield models.feed1.query().load_related("live").get(name="ccc1")
         self.assertFalse(feed.live)
         self.assertFalse(feed.live_id)
-
+
     def test_sort_by_missing_fk_data(self):
-        yield self.create_feeds('ddd1', 'ddd2')
-        query = self.session().query(Feed1).filter(name__startswith='ddd')
-        feed1s = yield query.sort_by('live').all()
-        feed2s = yield query.sort_by('live__data__pv').all()
+        yield self.create_feeds("ddd1", "ddd2")
+        query = self.session().query(Feed1).filter(name__startswith="ddd")
+        feed1s = yield query.sort_by("live").all()
+        feed2s = yield query.sort_by("live__data__pv").all()
         self.assertEqual(len(feed1s), 2)
         self.assertEqual(len(feed2s), 2)
-
\ No newline at end of file
diff --git a/tests/all/fields/id.py b/tests/all/fields/id.py
index 0248375..5ef86fa 100644
--- a/tests/all/fields/id.py
+++ b/tests/all/fields/id.py
@@ -1,50 +1,50 @@
-'''AutoId, CompositeId and custom Id tests.'''
-from uuid import uuid4
+"""AutoId, CompositeId and custom Id tests."""
 from random import randint
+from uuid import uuid4

 import pulsar
+from examples.models import Instrument, SimpleModel, Task, WordBook

 import stdnet
 from stdnet import FieldError
 from stdnet.utils import test

-from examples.models import Task, WordBook, SimpleModel, Instrument
-

 def genid():
     return str(uuid4())[:8]


 class Id(test.TestCase):
-    '''Test primary key when it is not an AutoIdField.
-Use the manager for convenience.'''
+    """Test primary key when it is not an AutoIdField.
+    Use the manager for convenience."""
+
     model = Task
-
-    def make(self, name='pluto'):
+
+    def make(self, name="pluto"):
         return self.mapper.task.new(id=genid(), name=name)
-
+
     def test_create(self):
         t1 = yield self.make()
         yield pulsar.async_sleep(0.5)
         t2 = yield self.make()
         self.assertNotEqual(t1.id, t2.id)
         self.assertTrue(t1.timestamp < t2.timestamp)
-
+
     def test_change_id(self):
         session = self.session()
         t1 = yield self.make()
         id1 = t1.id
-        self.assertEqual(id1, t1._dbdata['id'])
+        self.assertEqual(id1, t1._dbdata["id"])
         self.assertTrue(t1.get_state().persistent)
         id2 = genid()
         t1.id = id2
-        self.assertEqual(id1, t1._dbdata['id'])
-        self.assertNotEqual(id2, t1._dbdata['id'])
+        self.assertEqual(id1, t1._dbdata["id"])
+        self.assertNotEqual(id2, t1._dbdata["id"])
         yield session.add(t1)
         self.assertEqual(id2, t1.id)
-        self.assertEqual(id2, t1._dbdata['id'])
+        self.assertEqual(id2, t1._dbdata["id"])
         yield self.async.assertEqual(self.query().filter(id=(id1, id2)).count(), 1)
-
+
     def test_clone(self):
         t1 = yield self.make()
         session = t1.session
@@ -59,7 +59,7 @@ def test_clone(self):
         self.assertEqual(tasks[0].id, t2.id)
         self.assertEqual(tasks[1].id, t1.id)
         self.assertTrue(tasks[0].timestamp > tasks[1].timestamp)
-
+
     def test_delete_and_clone(self):
         t1 = yield self.make()
         session = t1.session
@@ -72,66 +72,66 @@ def test_delete_and_clone(self):
         tasks = yield self.query().filter(id=(t1.id, t2.id)).all()
         self.assertEqual(len(tasks), 1)
         self.assertEqual(tasks[0].id, t2.id)
-
+
     def test_fail(self):
         session = self.session()
-        t = Task(name='pluto')
+        t = Task(name="pluto")
         yield self.async.assertRaises(Exception, session.add, t)


 class TestAutoId(test.TestCase):
     models = (SimpleModel, Instrument)
-
+
     def random_id(self, id=None):
-        if self.backend.name == 'mongo':
+        if self.backend.name == "mongo":
             from bson.objectid import ObjectId
+
             return ObjectId()
         else:
             if id:
-                return id+1
+                return id + 1
             else:
-                return randint(1,1000)
-
+                return randint(1, 1000)
+
     def testMeta(self):
         pk = self.model._meta.pk
-        self.assertEqual(pk.name, 'id')
-        self.assertEqual(pk.type, 'auto')
+        self.assertEqual(pk.name, "id")
+        self.assertEqual(pk.type, "auto")
         self.assertEqual(pk.internal_type, None)
         self.assertEqual(pk.python_type, None)
-        self.assertEqual(str(pk), 'examples.simplemodel.id')
-        self.assertRaises(FieldError, pk.register_with_model,
-                          'bla', SimpleModel)
-
+        self.assertEqual(str(pk), "examples.simplemodel.id")
+        self.assertRaises(FieldError, pk.register_with_model, "bla", SimpleModel)
+
     def testCreateWithValue(self):
         # create an instance with an id
         models = self.mapper
         id = self.random_id()
-        m1 = yield models.simplemodel.new(id=id, code='bla')
+        m1 = yield models.simplemodel.new(id=id, code="bla")
         self.assertEqual(m1.id, id)
-        self.assertEqual(m1.code, 'bla')
-        m2 = yield models.simplemodel.new(code='foo')
+        self.assertEqual(m1.code, "bla")
+        m2 = yield models.simplemodel.new(code="foo")
         id2 = self.random_id(id)
         self.assertEqualId(m2, id2)
-        self.assertEqual(m2.code, 'foo')
+        self.assertEqual(m2.code, "foo")
         qs = yield models.simplemodel.query().all()
         self.assertEqual(len(qs), 2)
         self.assertEqual(set(qs), set((m1, m2)))
-
+
     def testCreateWithValue2(self):
         models = self.mapper
         id = self.random_id()
-        m1 = yield models[Instrument].new(name='test1', type='bla', ccy='eur')
-        m2 = yield models.instrument.new(id=id, name='test2', type='foo', ccy='eur')
+        m1 = yield models[Instrument].new(name="test1", type="bla", ccy="eur")
+        m2 = yield models.instrument.new(id=id, name="test2", type="foo", ccy="eur")
         self.assertEqualId(m1, 1)
         self.assertEqual(m2.id, id)
         qs = yield models.instrument.query().all()
         self.assertEqual(len(qs), 2)
-        self.assertEqual(set(qs), set((m1,m2)))
-
-
+        self.assertEqual(set(qs), set((m1, m2)))
+
+
 class CompositeId(test.TestCase):
     model = WordBook
-
+
     def create(self, word, book):
         session = self.session()
         m = yield session.add(self.model(word=word, book=book))
@@ -142,40 +142,39 @@ def create(self, word, book):
         self.assertEqual(m.book, book)
         self.assertEqual(m.id, id)
         yield m
-
+
     def testMeta(self):
         id = self.model._meta.pk
-        self.assertEqual(id.type, 'composite')
+        self.assertEqual(id.type, "composite")
         fields = id.fields
         self.assertEqual(len(fields), 2)
-        self.assertEqual(fields[0], self.model._meta.dfields['word'])
-        self.assertEqual(fields[1], self.model._meta.dfields['book'])
-
+        self.assertEqual(fields[0], self.model._meta.dfields["word"])
+        self.assertEqual(fields[1], self.model._meta.dfields["book"])
+
     def test_value(self):
-        m = self.model(book='world', word='hello')
+        m = self.model(book="world", word="hello")
         self.assertFalse(m.id)
         value = m.pkvalue()
         self.assertTrue(value)
-        self.assertEqual(value, hash(('hello', 'world')))
-        m = self.model(book='hello', word='world')
+        self.assertEqual(value, hash(("hello", "world")))
+        m = self.model(book="hello", word="world")
         self.assertNotEqual(value, m.pkvalue())
-
+
     def test_create(self):
-        return self.create('hello', 'world')
-
+        return self.create("hello", "world")
+
     def test_change(self):
-        m = yield self.create('ciao', 'libro')
+        m = yield self.create("ciao", "libro")
         session = m.session
         id = m.id
-        m.word = 'beautiful'
+        m.word = "beautiful"
         self.assertNotEqual(m.pkvalue(), id)
         yield session.add(m)
         self.assertNotEqual(m.id, id)
-        self.assertEqual(m.word, 'beautiful')
+        self.assertEqual(m.word, "beautiful")
         query = self.query()
         yield self.async.assertEqual(query.filter(id=id).count(), 0)
         yield self.async.assertEqual(query.filter(id=m.id).count(), 1)
-        yield self.async.assertEqual(query.filter(word='ciao', book='libro').count(), 0)
-        m2 = yield query.get(word='beautiful', book='libro')
+        yield self.async.assertEqual(query.filter(word="ciao", book="libro").count(), 0)
+        m2 = yield query.get(word="beautiful", book="libro")
         self.assertEqual(m, m2)
-
\ No newline at end of file
diff --git a/tests/all/fields/integer.py b/tests/all/fields/integer.py
index c71c776..4ebad7c 100644
--- a/tests/all/fields/integer.py
+++ b/tests/all/fields/integer.py
@@ -1,8 +1,8 @@
+from examples.models import Page
+
 from stdnet import FieldValueError
 from stdnet.utils import test

-from examples.models import Page
-

 class TestIntegerField(test.TestCase):
     model = Page
@@ -11,9 +11,9 @@ def test_default_value(self):
         models = self.mapper
         p = Page()
         self.assertEqual(p.in_navigation, 1)
-        p = Page(in_navigation='4')
+        p = Page(in_navigation="4")
         self.assertEqual(p.in_navigation, 4)
-        self.assertRaises(FieldValueError, p=Page, in_navigation='foo')
+        self.assertRaises(FieldValueError, Page, in_navigation="foo")
         yield self.session().add(p)
         self.assertEqual(p.in_navigation, 4)
         p = yield models.page.get(id=p.id)
@@ -23,7 +23,7 @@ def testNotValidated(self):
         models = self.mapper
         p = yield models.page.new()
         self.assertEqual(p.in_navigation, 1)
-        self.assertRaises(ValueError, Page, in_navigation='bla')
+        self.assertRaises(ValueError, Page, in_navigation="bla")

     def testZeroValue(self):
         models = self.mapper
diff --git a/tests/all/fields/jsonfield.py b/tests/all/fields/jsonfield.py
index 0d367b8..fa110a7 100755
--- a/tests/all/fields/jsonfield.py
+++ b/tests/all/fields/jsonfield.py
@@ -1,47 +1,51 @@
-import os
 import json
+import os
 import time
 from copy import deepcopy
 from datetime import date, datetime
 from decimal import Decimal
-from random import random, randint, choice
+from random import choice, randint, random
+
+from examples.models import Role, Statistics, Statistics3

 import stdnet
-from stdnet.utils import test, zip, to_string, unichr, ispy3k, range
-from stdnet.utils import date2timestamp
+from stdnet.utils import date2timestamp, ispy3k, range, test, to_string, unichr, zip
 from stdnet.utils.populate import populate

-from examples.models import Statistics, Statistics3, Role
-

 class make_random(object):
-    rtype = ['number','list',None] + ['dict']*3
+    rtype = ["number", "list", None] + ["dict"] * 3
+
     def __init__(self):
         self.count = 0
-
-    def make(self, size = 5, maxsize = 10, nesting = 1, level = 0):
-        keys = populate(size = size)
+
+    def make(self, size=5, maxsize=10, nesting=1, level=0):
+        keys = populate(size=size)
         if level:
-            keys.append('')
+            keys.append("")
         for key in keys:
-            t = choice(self.rtype) if level else 'dict'
-            if nesting and t == 'dict':
-                yield key,dict(self.make(size = randint(0,maxsize),
-                                         maxsize = maxsize,
-                                         nesting = nesting - 1,
-                                         level = level + 1))
+            t = choice(self.rtype) if level else "dict"
+            if nesting and t == "dict":
+                yield key, dict(
+                    self.make(
+                        size=randint(0, maxsize),
+                        maxsize=maxsize,
+                        nesting=nesting - 1,
+                        level=level + 1,
+                    )
+                )
             else:
-                if t == 'list':
+                if t == "list":
                     v = [random() for i in range(10)]
-                elif t == 'number':
+                elif t == "number":
                     v = random()
-                elif t == 'dict':
+                elif t == "dict":
                     v = random()
                 else:
                     v = t
-                yield key,v
-
-
+                yield key, v
+
+
 class TestJsonField(test.TestCase):
     models = [Statistics, Role]

@@ -53,47 +57,50 @@ def test_default(self):
         self.assertEqual(a.data, {})
         a = yield models.statistics.get(id=a.id)
         self.assertEqual(a.data, {})
-
+
     def testMetaData(self):
-        field = Statistics._meta.dfields['data']
-        self.assertEqual(field.type,'json object')
-        self.assertEqual(field.index,False)
-        self.assertEqual(field.as_string,True)
-
+        field = Statistics._meta.dfields["data"]
+        self.assertEqual(field.type, "json object")
+        self.assertEqual(field.index, False)
+        self.assertEqual(field.as_string, True)
+
     def testCreate(self):
         models = self.mapper
-        mean = Decimal('56.4')
-        started = date(2010,1,1)
+        mean = Decimal("56.4")
+        started = date(2010, 1, 1)
         timestamp = datetime.now()
-        a = yield models.statistics.new(dt=date.today(),
-                                        data={'mean': mean,
-                                              'std': 5.78,
-                                              'started': started,
-                                              'timestamp':timestamp})
-        self.assertEqual(a.data['mean'], mean)
+        a = yield models.statistics.new(
+            dt=date.today(),
+            data={
+                "mean": mean,
+                "std": 5.78,
+                "started": started,
+                "timestamp": timestamp,
+            },
+        )
+        self.assertEqual(a.data["mean"], mean)
         a = yield models.statistics.get(id=a.id)
         self.assertEqual(len(a.data), 4)
-        self.assertEqual(a.data['mean'], mean)
-        self.assertEqual(a.data['started'], started)
-        self.assertAlmostEqual(date2timestamp(a.data['timestamp']),
-                               date2timestamp(timestamp), 5)
-
+        self.assertEqual(a.data["mean"], mean)
+        self.assertEqual(a.data["started"], started)
+        self.assertAlmostEqual(
+            date2timestamp(a.data["timestamp"]), date2timestamp(timestamp), 5
+        )
+
     def testCreateFromString(self):
         models = self.mapper
-        mean = 'mean'
+        mean = "mean"
         timestamp = time.time()
-        data = {'mean': mean,
-                'std': 5.78,
-                'timestamp': timestamp}
+        data = {"mean": mean, "std": 5.78, "timestamp": timestamp}
         datas = json.dumps(data)
         a = yield models.statistics.new(dt=date.today(), data=datas)
         a = yield models.statistics.get(id=a.id)
-        self.assertEqual(a.data['mean'], mean)
+        self.assertEqual(a.data["mean"], mean)
         a = yield models.statistics.get(id=a.id)
-        self.assertEqual(len(a.data),3)
-        self.assertEqual(a.data['mean'],mean)
-        self.assertAlmostEqual(a.data['timestamp'], timestamp)
-
+        self.assertEqual(len(a.data), 3)
+        self.assertEqual(a.data["mean"], mean)
+        self.assertAlmostEqual(a.data["timestamp"], timestamp)
+
     def test_default(self):
         models = self.mapper
         a = Statistics(dt=date.today())
@@ -102,193 +109,192 @@ def test_default(self):
         self.assertEqual(a.data, {})
         a = yield models.statistics.get(id=a.id)
         self.assertEqual(a.data, {})
-
+
     def testValueError(self):
         models = self.mapper
-        a = models.statistics(dt=date.today(), data={'mean': self})
+        a = models.statistics(dt=date.today(), data={"mean": self})
         yield self.async.assertRaises(stdnet.FieldValueError, models.session().add, a)
-        self.assertTrue('data' in a._dbdata['errors'])
-
+        self.assertTrue("data" in a._dbdata["errors"])
+
     def testDefaultValue(self):
         models = self.mapper
-        role = models.role(name='test')
+        role = models.role(name="test")
         self.assertEqual(role.permissions, [])
-        role.permissions.append('ciao')
+        role.permissions.append("ciao")
         role.permissions.append(4)
         yield models.session().add(role)
         self.assertTrue(role.id)
         role = yield models.role.get(id=role.id)
-        self.assertEqual(role.permissions, ['ciao', 4])
+        self.assertEqual(role.permissions, ["ciao", 4])


 class TestJsonFieldAsData(test.TestCase):
-    '''Test a model with a JSONField which expand as instance fields.
-The `as_string` atttribute is set to ``False``.'''
+    """Test a model with a JSONField which expands as instance fields.
+    The `as_string` attribute is set to ``False``."""
+
     model = Statistics3
-    def_data = {'mean': 1.0,
-                'std': 5.78,
-                'pv': 3.2,
-                'name': 'bla',
-                'dt': date.today()}
-
-    def_baddata = {'': 3.2,
-                   'ts': {'a':[1,2,3,4,5,6,7],
-                          'b':[10,11,12]},
-                   'mean': {'1y':1.0,'2y':1.1},
-                   'std': {'1y':4.0,'2y':5.1},
-                   'dt': datetime.now()}
-
-    def_data2 = {'pv': {'':3.2,
-                        'ts': {'a':[1,2,3,4,5,6,7],
-                               'b':[10,11,12]},
-                        'mean': {'1y':1.0,'2y':1.1},
-                        'std': {'1y':4.0,'2y':5.1}},
-                 'dt': datetime.now()}
-
+    def_data = {"mean": 1.0, "std": 5.78, "pv": 3.2, "name": "bla", "dt": date.today()}
+
+    def_baddata = {
+        "": 3.2,
+        "ts": {"a": [1, 2, 3, 4, 5, 6, 7], "b": [10, 11, 12]},
+        "mean": {"1y": 1.0, "2y": 1.1},
+        "std": {"1y": 4.0, "2y": 5.1},
+        "dt": datetime.now(),
+    }
+
+    def_data2 = {
+        "pv": {
+            "": 3.2,
+            "ts": {"a": [1, 2, 3, 4, 5, 6, 7], "b": [10, 11, 12]},
+            "mean": {"1y": 1.0, "2y": 1.1},
+            "std": {"1y": 4.0, "2y": 5.1},
+        },
+        "dt": datetime.now(),
+    }
+
     def make(self, data=None, name=None):
         data = data or self.def_data
         name = name or self.data.random_string()
         return self.model(name=name, data=data)
-
+
     def testMeta(self):
-        field = self.model._meta.dfields['data']
+        field = self.model._meta.dfields["data"]
         self.assertFalse(field.as_string)
-
+
     def testMake(self):
         m = self.make()
         self.assertTrue(m.is_valid())
-        data = m._dbdata['cleaned_data']
-        data.pop('data')
+        data = m._dbdata["cleaned_data"]
+        data.pop("data")
         self.assertEqual(len(data), 6)
-        self.assertEqual(float(data['data__mean']), 1.0)
-        self.assertEqual(float(data['data__std']), 5.78)
-        self.assertEqual(float(data['data__pv']), 3.2)
-
+        self.assertEqual(float(data["data__mean"]), 1.0)
+        self.assertEqual(float(data["data__std"]), 5.78)
+        self.assertEqual(float(data["data__pv"]), 3.2)
+
     def testGet(self):
         models = self.mapper
         session = models.session()
         m = yield session.add(self.make())
         m = yield models.statistics3.get(id=m.id)
-        self.assertEqual(m.data['mean'], 1.0)
-        self.assertEqual(m.data['std'], 5.78)
-        self.assertEqual(m.data['pv'], 3.2)
-        self.assertEqual(m.data['dt'], date.today())
-        self.assertEqual(m.data['name'], 'bla')
-
+        self.assertEqual(m.data["mean"], 1.0)
+        self.assertEqual(m.data["std"], 5.78)
+        self.assertEqual(m.data["pv"], 3.2)
+        self.assertEqual(m.data["dt"], date.today())
+        self.assertEqual(m.data["name"], "bla")
+
     def testmakeEmptyError(self):
-        '''Here we test when we have a key which is empty.'''
+        """Here we test when we have a key which is empty."""
         models = self.mapper
         session = models.session()
         m = self.make(self.def_baddata)
         self.assertFalse(m.is_valid())
         yield self.async.assertRaises(stdnet.FieldValueError, session.add, m)
-
+
     def testmakeEmpty(self):
         models = self.mapper
         session = models.session()
         m = self.make(self.def_data2)
         self.assertTrue(m.is_valid())
-        cdata = m._dbdata['cleaned_data']
-        self.assertEqual(len(cdata),10)
-        self.assertTrue('data' in cdata)
-        self.assertEqual(cdata['data__pv__mean__1y'],'1.0')
+        cdata = m._dbdata["cleaned_data"]
+        self.assertEqual(len(cdata), 10)
+        self.assertTrue("data" in cdata)
+        self.assertEqual(cdata["data__pv__mean__1y"], "1.0")
         obj = yield session.add(m)
         obj = yield models.statistics3.get(id=obj.id)
-        self.assertEqual(obj.data['dt'].date(), date.today())
+        self.assertEqual(obj.data["dt"].date(), date.today())
         self.assertEqual(obj.data__dt.date(), date.today())
-        self.assertEqual(obj.data['pv']['mean']['1y'], 1.0)
+        self.assertEqual(obj.data["pv"]["mean"]["1y"], 1.0)
         self.assertEqual(obj.data__pv__mean__1y, 1.0)
         self.assertEqual(obj.data__dt.date(), date.today())
-
+
     def testmakeEmpty2(self):
         models = self.mapper
         session = models.session()
-        m = self.make({'ts': [1,2,3,4]})
+        m = self.make({"ts": [1, 2, 3, 4]})
         obj = yield models.add(m)
         obj = yield models.statistics3.get(id=obj.id)
-        self.assertEqual(obj.data, {'ts': [1, 2, 3, 4]})
-
+        self.assertEqual(obj.data, {"ts": [1, 2, 3, 4]})
+
     def __testFuzzySmall(self):
-        #TODO: This does not pass in pypy
+        # TODO: This does not pass in pypy
         models = self.mapper
         session = models.session()
         r = make_random()
-        data = dict(r.make(nesting = 0))
+        data = dict(r.make(nesting=0))
         m = self.make(data)
         self.assertTrue(m.is_valid())
-        cdata = m._dbdata['cleaned_data']
-        cdata.pop('data')
+        cdata = m._dbdata["cleaned_data"]
+        cdata.pop("data")
         for k in cdata:
-            if k is not 'name':
-                self.assertTrue(k.startswith('data__'))
+            if k != "name":
+                self.assertTrue(k.startswith("data__"))
         obj = yield session.add(m)
         obj = yield models.statistics3.get(id=obj.id)
         self.assertEqualDict(data, obj.data)
-
+
     def __testFuzzyMedium(self):
-        #TODO: This does not pass in pypy
+        # TODO: This does not pass in pypy
         models = self.mapper
         session = models.session()
         r = make_random()
-        data = dict(r.make(nesting = 1))
+        data = dict(r.make(nesting=1))
         m = self.make(data)
         self.assertTrue(m.is_valid())
-        cdata = m._dbdata['cleaned_data']
-        cdata.pop('data')
+        cdata = m._dbdata["cleaned_data"]
+        cdata.pop("data")
         for k in cdata:
-            if k is not 'name':
-                self.assertTrue(k.startswith('data__'))
+            if k != "name":
+                self.assertTrue(k.startswith("data__"))
         obj = yield session.add(m)
-        #obj = self.model.objects.get(id=obj.id)
-        #self.assertEqualDict(data,obj.data)
-
+        # obj = self.model.objects.get(id=obj.id)
+        # self.assertEqualDict(data,obj.data)
+
     def __testFuzzy(self):
-        #TODO: This does not pass in pypy
+        # TODO: This does not pass in pypy
         models = self.mapper
         session = models.session()
         r = make_random()
-        data = dict(r.make(nesting = 3))
+        data = dict(r.make(nesting=3))
         m = self.make(deepcopy(data))
         self.assertTrue(m.is_valid())
-        cdata = m._dbdata['cleaned_data']
-        cdata.pop('data')
+        cdata = m._dbdata["cleaned_data"]
+        cdata.pop("data")
         for k in cdata:
-            if k is not 'name':
-                self.assertTrue(k.startswith('data__'))
+            if k != "name":
+                self.assertTrue(k.startswith("data__"))
         obj = yield session.add(m)
-        #obj = self.model.objects.get(id=obj.id)
-        #self.assertEqualDict(data,obj.data)
-
+        # obj = self.model.objects.get(id=obj.id)
+        # self.assertEqualDict(data,obj.data)
+
     def testEmptyDict(self):
         models = self.mapper
         session = models.session()
-        r = yield session.add(self.model(name='bla', data = {'bla':'ciao'}))
-        self.assertEqual(r.data, {'bla':'ciao'})
+        r = yield session.add(self.model(name="bla", data={"bla": "ciao"}))
+        self.assertEqual(r.data, {"bla": "ciao"})
         r.data = None
         yield session.add(r)
         r = yield models.statistics3.get(id=r.id)
         self.assertEqual(r.data, {})
-
+
     def testFromEmpty(self):
-        '''Test the change of a data jsonfield from empty to populated.'''
+        """Test the change of a data jsonfield from empty to populated."""
         models = self.mapper
         session = models.session()
-        r = yield models.statistics3.new(name = 'bla')
+        r = yield models.statistics3.new(name="bla")
         self.assertEqual(r.data, {})
-        r.data = {'bla':'ciao'}
+        r.data = {"bla": "ciao"}
         yield session.add(r)
         r = yield models.statistics3.get(id=r.id)
-        self.assertEqual(r.data, {'bla':'ciao'})
-
-    def assertEqualDict(self,data1,data2):
+        self.assertEqual(r.data, {"bla": "ciao"})
+
+    def assertEqualDict(self, data1, data2):
         for k in list(data1):
             v1 = data1.pop(k)
-            v2 = data2.pop(k,{})
-            if isinstance(v1,dict):
-                self.assertEqualDict(v1,v2)
+            v2 = data2.pop(k, {})
+            if isinstance(v1, dict):
+                self.assertEqualDict(v1, v2)
             else:
-                self.assertAlmostEqual(v1,v2)
+                self.assertAlmostEqual(v1, v2)
         self.assertFalse(data1)
         self.assertFalse(data2)
-
-
\ No newline at end of file
diff --git a/tests/all/fields/meta.py b/tests/all/fields/meta.py
index bacfc1a..facf7ec 100644
--- a/tests/all/fields/meta.py
+++ b/tests/all/fields/meta.py
@@ -1,11 +1,10 @@
-'''Field metadata and full coverage.'''
+"""Field metadata and full coverage."""
 import stdnet
-from stdnet import odm, FieldError
+from stdnet import FieldError, odm
 from stdnet.utils import test


 class TestFields(test.TestCase):
-
     def testBaseClass(self):
         self.assertRaises(TypeError, odm.Field, kaputt=True)
         f = odm.Field()
@@ -20,13 +19,14 @@ def bad_class():
             class MyBadClass(odm.StdModel):
                 id = odm.IntegerField(primary_key=True)
                 code = odm.SymbolField(primary_key=True)
+
         self.assertRaises(FieldError, bad_class)

     def test_defaults(self):
         f = odm.Field()
         self.assertEqual(f.default, None)
-        f = odm.Field(default = 'bla')
-        self.assertEqual(f.default, 'bla')
+        f = odm.Field(default="bla")
+        self.assertEqual(f.default, "bla")

     def test_id(self):
         f = odm.Field()
diff --git a/tests/all/fields/pickle.py b/tests/all/fields/pickle.py
index 633a28d..9af64e8 100644
--- a/tests/all/fields/pickle.py
+++ b/tests/all/fields/pickle.py
@@ -1,36 +1,35 @@
-from stdnet.utils import test
-
 from examples.models import Environment
+from stdnet.utils import test


 class TestPickleObjectField(test.TestCase):
     model = Environment
-
+
     def testMetaData(self):
-        field = self.model._meta.dfields['data']
-        self.assertEqual(field.type,'object')
-        self.assertEqual(field.internal_type,'bytes')
-        self.assertEqual(field.index,False)
-        self.assertEqual(field.name,field.attname)
+        field = self.model._meta.dfields["data"]
+        self.assertEqual(field.type, "object")
+        self.assertEqual(field.internal_type, "bytes")
+        self.assertEqual(field.index, False)
+        self.assertEqual(field.name, field.attname)
         return field
-
+
     def testOkObject(self):
         session = self.session()
-        v = self.model(data=['ciao','pippo'])
-        self.assertEqual(v.data, ['ciao','pippo'])
+        v = self.model(data=["ciao", "pippo"])
+        self.assertEqual(v.data, ["ciao", "pippo"])
         yield session.add(v)
-        self.assertEqual(v.data, ['ciao','pippo'])
+        self.assertEqual(v.data, ["ciao", "pippo"])
         v = yield session.query(self.model).get(id=v.id)
-        self.assertEqual(v.data, ['ciao','pippo'])
-
+        self.assertEqual(v.data, ["ciao", "pippo"])
+
     def testRecursive(self):
-        '''Silly test to test both pickle field and picklable instance'''
+        """Silly test to test both pickle field and picklable instance"""
         session = self.session()
-        v = yield session.add(self.model(data=('ciao','pippo', 4, {})))
+        v = yield session.add(self.model(data=("ciao", "pippo", 4, {})))
         v2 = self.model(data=v)
         self.assertEqual(v2.data, v)
         yield session.add(v2)
         self.assertEqual(v2.data, v)
         v2 = yield session.query(self.model).get(id=v2.id)
-        self.assertEqual(v2.data, v)
\ No newline at end of file
+        self.assertEqual(v2.data, v)
diff --git a/tests/all/fields/pk.py b/tests/all/fields/pk.py
index 7e14fdc..cbf7b8f 100644
--- a/tests/all/fields/pk.py
+++ b/tests/all/fields/pk.py
@@ -1,24 +1,24 @@
+from examples.models import Child, Parent
+
 import stdnet
-from stdnet import odm, FieldError
+from stdnet import FieldError, odm
 from stdnet.utils import test

-from examples.models import Parent, Child
-

 class TestForeignKey(test.TestCase):
     models = (Parent, Child)

     def test_custom_pk(self):
         models = self.mapper
-        parent = yield models.parent.new(name='test')
-        self.assertEqual(parent.pkvalue(), 'test')
-        self.assertEqual(parent.pk().name, 'name')
+        parent = yield models.parent.new(name="test")
+        self.assertEqual(parent.pkvalue(), "test")
+        self.assertEqual(parent.pk().name, "name")

     def test_add_parent_and_child(self):
         models = self.mapper
         with models.session().begin() as t:
-            parent = models.parent(name='test2')
-            child = models.child(parent=parent, name='foo')
+            parent = models.parent(name="test2")
+            child = models.child(parent=parent, name="foo")
             self.assertEqual(child.parent, parent)
             self.assertEqual(child.parent_id, parent.pkvalue())
             t.add(parent)
@@ -30,14 +30,14 @@ class TestQuery(test.TestCase):
     models = (Parent, Child)

     def test_non_id_pk(self):
-        '''
+        """
         Models with non-'id' primary keys should be queryable
         (regression test)
-        '''
+        """
         models = self.mapper
         with models.session().begin() as t:
-            parent = models.parent(name='test2')
-            child = models.child(parent=parent, name='foo')
+            parent = models.parent(name="test2")
+            child = models.child(parent=parent, name="foo")
             t.add(parent)
             t.add(child)
         yield t.on_result
@@ -51,16 +51,16 @@ class TestManyToMany(test.TestCase):
     models = (Parent, Child)

     def test_non_id_pk(self):
-        '''
+        """
         Models with non-'id' primary keys should be queryable
         from a ManyToMany relation
         (regression test)
-        '''
+        """
         models = self.mapper
         with models.session().begin() as t:
-            parent = models.parent(name='test2')
-            uncle = models.parent(name='test3')
-            child = models.child(parent=parent, name='foo')
+            parent = models.parent(name="test2")
+            uncle = models.parent(name="test3")
+            child = models.child(parent=parent, name="foo")
             t.add(parent)
             t.add(uncle)
             t.add(child)
diff --git a/tests/all/fields/scalar.py b/tests/all/fields/scalar.py
index d178c79..e7aa48a 100755
--- a/tests/all/fields/scalar.py
+++ b/tests/all/fields/scalar.py
@@ -1,19 +1,17 @@
-'''Scalar fields such as Char, Float and Date, DateTime, Byte fields.
-'''
+"""Scalar fields such as Char, Float and Date, DateTime, Byte fields.
+"""
 import os
 from datetime import date

+from examples.models import DateData, NumericData, SimpleModel, TestDateModel
+
 import stdnet
 from stdnet import FieldValueError
-from stdnet.utils import (test, populate, zip, is_string, to_string, unichr,
-                          ispy3k)
-
-from examples.models import TestDateModel, DateData, SimpleModel, NumericData
+from stdnet.utils import is_string, ispy3k, populate, test, to_string, unichr, zip

 NUM_DATES = 100

-names = populate('string', NUM_DATES, min_len=5, max_len=20)
-dates = populate('date', NUM_DATES, start=date(2010, 5, 1),
-                 end=date(2010, 6, 1))
+names = populate("string", NUM_DATES, min_len=5, max_len=20)
+dates = populate("date", NUM_DATES, start=date(2010, 5, 1), end=date(2010, 6, 1))


 class TestDateModel2(TestDateModel):
@@ -86,14 +84,14 @@ class TestCharFields(test.TestCase):

     def testUnicode(self):
         models = self.mapper
-        unicode_string = unichr(500) + to_string('ciao') + unichr(300)
+        unicode_string = unichr(500) + to_string("ciao") + unichr(300)
         m = yield models.simplemodel.new(code=unicode_string)
         m = yield models.simplemodel.get(id=m.id)
         self.assertEqual(m.code, unicode_string)
         if ispy3k:
             self.assertEqual(str(m), unicode_string)
         else:
-            code = unicode_string.encode('utf-8')
+            code = unicode_string.encode("utf-8")
             self.assertEqual(str(m), code)


@@ -102,24 +100,23 @@ class TestNumericData(test.TestCase):

     def testDefaultValue(self):
         models = self.mapper
-        d = yield models.numericdata.new(pv=1.)
-        self.assertAlmostEqual(d.pv, 1.)
-        self.assertAlmostEqual(d.vega, 0.)
-        self.assertAlmostEqual(d.delta, 1.)
+        d = yield models.numericdata.new(pv=1.0)
+        self.assertAlmostEqual(d.pv, 1.0)
+        self.assertAlmostEqual(d.vega, 0.0)
+        self.assertAlmostEqual(d.delta, 1.0)
         self.assertEqual(d.gamma, None)

     def testDefaultValue2(self):
         models = self.mapper
-        d = yield models.numericdata.new(pv=0., delta=0.)
-        self.assertAlmostEqual(d.pv, 0.)
-        self.assertAlmostEqual(d.vega, 0.)
-        self.assertAlmostEqual(d.delta, 0.)
+        d = yield models.numericdata.new(pv=0.0, delta=0.0)
+        self.assertAlmostEqual(d.pv, 0.0)
+        self.assertAlmostEqual(d.vega, 0.0)
+        self.assertAlmostEqual(d.delta, 0.0)
         self.assertEqual(d.gamma, None)

     def testFieldError(self):
         models = self.mapper
-        yield self.async.assertRaises(stdnet.FieldValueError,
-                                      models.numericdata.new)
+        yield self.async.assertRaises(stdnet.FieldValueError, models.numericdata.new)


 class TestDateData(test.TestCase):
@@ -148,30 +145,30 @@ class TestBoolField(test.TestCase):

     def testMeta(self):
         self.assertEqual(len(self.model._meta.indices), 1)
         index = self.model._meta.indices[0]
-        self.assertEqual(index.type, 'bool')
+        self.assertEqual(index.type, "bool")
         self.assertEqual(index.index, True)
         self.assertEqual(index.name, index.attname)
         return index

     def testSerializeAndScoreFun(self):
         index = self.testMeta()
-        for fname in ('scorefun', 'serialise'):
+        for fname in ("scorefun", "serialise"):
             func = getattr(index, fname)
             self.assertEqual(func(True), 1)
             self.assertEqual(func(False), 0)
             self.assertEqual(func(4), 1)
             self.assertEqual(func(0), 0)
-            self.assertEqual(func('bla'), 1)
-            self.assertEqual(func(''), 0)
+            self.assertEqual(func("bla"), 1)
+            self.assertEqual(func(""), 0)
             self.assertEqual(func(None), 0)

     def test_bool_value(self):
         models = self.mapper
         session = models.session()
-        d = yield session.add(models.numericdata(pv=1.))
+        d = yield session.add(models.numericdata(pv=1.0))
         d = yield models.numericdata.get(id=d.id)
         self.assertEqual(d.ok, False)
-        d.ok = 'jasxbhjaxsbjxsb'
+        d.ok = "jasxbhjaxsbjxsb"
         yield self.async.assertRaises(FieldValueError, session.add, d)
         d.ok = True
         yield session.add(d)
@@ -183,26 +180,26 @@ class TestByteField(test.TestCase):
     model = SimpleModel

     def testMetaData(self):
-        field = SimpleModel._meta.dfields['somebytes']
-        self.assertEqual(field.type, 'bytes')
-        self.assertEqual(field.internal_type, 'bytes')
+        field = SimpleModel._meta.dfields["somebytes"]
+        self.assertEqual(field.type, "bytes")
+        self.assertEqual(field.internal_type, "bytes")
         self.assertEqual(field.index, False)
         self.assertEqual(field.name, field.attname)
         return field

     def testValue(self):
         models = self.mapper
-        v = models.simplemodel(code='cgfgcgf', somebytes=to_string('hello'))
-        self.assertEqual(v.somebytes, b'hello')
+        v = models.simplemodel(code="cgfgcgf", somebytes=to_string("hello"))
+        self.assertEqual(v.somebytes, b"hello")
         self.assertFalse(v.id)
         yield models.session().add(v)
         v = yield models.simplemodel.get(id=v.id)
-        self.assertEqual(v.somebytes, b'hello')
+        self.assertEqual(v.somebytes, b"hello")

     def testValueByte(self):
         models = self.mapper
         b = os.urandom(8)
-        v = SimpleModel(code='sdcscdsc', somebytes=b)
+        v = SimpleModel(code="sdcscdsc", somebytes=b)
         self.assertFalse(is_string(v.somebytes))
         self.assertEqual(v.somebytes, b)
         yield models.session().add(v)
@@ -213,9 +210,9 @@ def testValueByte(self):
     def testToJson(self):
         models = self.mapper
         b = os.urandom(8)
-        v = yield models.simplemodel.new(code='xxsdcscdsc', somebytes=b)
+        v = yield models.simplemodel.new(code="xxsdcscdsc", somebytes=b)
         data = v.tojson()
-        value = data['somebytes']
+        value = data["somebytes"]
         self.assertTrue(is_string(value))
         v2 = models.simplemodel.from_base64_data(**data)
         self.assertTrue(v2)
@@ -223,10 +220,9 @@ def testToJson(self):

 class TestErrorAtomFields(test.TestCase):
-
     def testSessionNotAvailable(self):
         session = self.session()
-        m = TestDateModel2(name=names[1], dt=dates[0], person='sdcbsc')
+        m = TestDateModel2(name=names[1], dt=dates[0], person="sdcbsc")
         self.assertRaises(stdnet.InvalidTransaction, session.add, m)

     def testNotSaved(self):
diff --git a/tests/all/lib/autoincrement.py b/tests/all/lib/autoincrement.py
index bbf6e12..b5b2852 100644
--- a/tests/all/lib/autoincrement.py
+++ b/tests/all/lib/autoincrement.py
@@ -1,40 +1,43 @@
+from examples.models import SimpleModel
+
 from stdnet import odm
 from stdnet.apps.searchengine.models import WordItem
 from stdnet.utils import test

-from examples.models import SimpleModel
-

 class TestCase(test.TestWrite):
-    multipledb = 'redis'
+    multipledb = "redis"
    models = (WordItem, SimpleModel)
-
+
    def testAutoIncrement(self):
        a = odm.autoincrement()
        self.assertEqual(a.incrby, 1)
        self.assertEqual(a.desc, False)
-        self.assertEqual(str(a), 'autoincrement(1)')
+        self.assertEqual(str(a), "autoincrement(1)")
        a = odm.autoincrement(3)
        self.assertEqual(a.incrby, 3)
        self.assertEqual(a.desc, False)
-        self.assertEqual(str(a), 'autoincrement(3)')
+        self.assertEqual(str(a), "autoincrement(3)")
        b = -a
-        self.assertEqual(str(a), 'autoincrement(3)')
+        self.assertEqual(str(a), "autoincrement(3)")
        self.assertEqual(b.desc, True)
-        self.assertEqual(str(b), '-autoincrement(3)')
-
+        self.assertEqual(str(b), "-autoincrement(3)")
+
    def testSimple(self):
        session = self.session()
-        m = yield session.add(SimpleModel(code='pluto'))
-        w = yield session.add(WordItem(word='ciao', model_type=SimpleModel,
-                                       object_id=m.id))
+        m = yield session.add(SimpleModel(code="pluto"))
+        w = yield session.add(
+            WordItem(word="ciao", model_type=SimpleModel, object_id=m.id)
+        )
        yield self.async.assertEqual(session.query(WordItem).count(), 1)
-        w = yield session.add(WordItem(word='ciao', model_type=SimpleModel,
-                                       object_id=m.id))
+        w = yield session.add(
+            WordItem(word="ciao", model_type=SimpleModel, object_id=m.id)
+        )
        yield self.async.assertEqual(session.query(WordItem).count(), 1)
        self.assertEqual(w.get_state().score, 2)
        #
-        w = yield session.add(WordItem(word='ciao', model_type=SimpleModel,
-                                       object_id=m.id))
+        w = yield session.add(
+            WordItem(word="ciao", model_type=SimpleModel, object_id=m.id)
+        )
        yield self.async.assertEqual(session.query(WordItem).count(), 1)
-        self.assertEqual(w.get_state().score, 3)
\ No newline at end of file
+        self.assertEqual(w.get_state().score, 3)
diff --git a/tests/all/lib/local.py b/tests/all/lib/local.py
index d17dc3b..370a663 100644
--- a/tests/all/lib/local.py
+++ b/tests/all/lib/local.py
@@ -1,40 +1,40 @@
 from stdnet import odm
 from stdnet.utils import test

+
 class TestModel(test.TestCase):
     multipledb = False
-
+
     def test_create(self):
-        User = odm.create_model('User', 'name', 'email', 'name')
+        User = odm.create_model("User", "name", "email", "name")
         self.assertTrue(isinstance(User, odm.ModelType))
-        self.assertEqual(User._meta.attributes, ('name', 'email'))
-
+        self.assertEqual(User._meta.attributes, ("name", "email"))
+
     def test_create_name(self):
-        User = odm.create_model('UserBase', 'name', 'email', 'name',
-                                abstract=True)
-        self.assertEqual(User.__name__, 'UserBase')
+        User = odm.create_model("UserBase", "name", "email", "name", abstract=True)
+        self.assertEqual(User.__name__, "UserBase")
         self.assertTrue(User._meta.abstract)
         self.assertRaises(AttributeError, User._meta.pkname)
-
+
     def test_init(self):
-        User = odm.create_model('User', 'name', 'email')
-        user = User(name='luca')
-        self.assertEqual(user.name, 'luca')
+        User = odm.create_model("User", "name", "email")
+        user = User(name="luca")
+        self.assertEqual(user.name, "luca")
         self.assertEqual(user.email, None)
-        self.assertRaises(ValueError, User, bla='foo')
-
+        self.assertRaises(ValueError, User, bla="foo")
+
     def test_init_args(self):
-        User = odm.create_model('User', 'name', 'email')
-        user = User('luca')
-        self.assertEqual(user.name, 'luca')
+        User = odm.create_model("User", "name", "email")
+        user = User("luca")
+        self.assertEqual(user.name, "luca")
         self.assertEqual(user.email, None)
-        user = User('bla', 'bla@foo')
-        self.assertEqual(user.name, 'bla')
-        self.assertEqual(user.email, 'bla@foo')
-        self.assertRaises(ValueError, User, 'foo', 'jhjh', 'gjgj')
-
+        user = User("bla", "bla@foo")
+        self.assertEqual(user.name, "bla")
+        self.assertEqual(user.email, "bla@foo")
+        self.assertRaises(ValueError, User, "foo", "jhjh", "gjgj")
+
     def test_router(self):
         models = odm.Router()
-        User = odm.create_model('User', 'name', 'email', 'name')
+        User = odm.create_model("User", "name", "email", "name")
         models.register(User)
-        self.assertEqual(models.user.model, User)
\ No newline at end of file
+        self.assertEqual(models.user.model, User)
diff --git a/tests/all/lib/me.py b/tests/all/lib/me.py
index fe03b5c..ef94d1d 100755
--- a/tests/all/lib/me.py
+++ b/tests/all/lib/me.py
@@ -1,21 +1,20 @@
-from stdnet.utils import test
-from stdnet import settings
 import stdnet as me
+from stdnet import settings
+from stdnet.utils import test


 class TestInitFile(test.TestCase):
     multipledb = False
-
+
     def test_version(self):
         self.assertTrue(len(me.VERSION), 5)
         version = me.__version__
         self.assertTrue(version)
         self.assertEqual(me.__version__, me.get_version(me.VERSION))
-
+
     def testStdnetVersion(self):
         self.assertRaises(TypeError, me.stdnet_version, 1, 2, 3, 4, 5)

     def test_meta(self):
         for m in ("__author__", "__contact__", "__homepage__", "__doc__"):
             self.assertTrue(getattr(me, m, None))
-
diff --git a/tests/all/lib/meta.py b/tests/all/lib/meta.py
index 861b367..de1d577 100755
--- a/tests/all/lib/meta.py
+++ b/tests/all/lib/meta.py
@@ -1,42 +1,41 @@
-'''Tests meta classes and corner cases of the library'''
+"""Tests meta classes and corner cases of the library"""
 import inspect
 from datetime import datetime

-from stdnet import odm
-from stdnet.utils import test, pickle
-from stdnet.odm import model_iterator, ModelType
+from examples.data import FinanceTest, Fund, Instrument, Position
+from examples.models import ComplexModel, SimpleModel

-from examples.models import SimpleModel, ComplexModel
-from examples.data import FinanceTest, Instrument, Fund, Position
+from stdnet import odm
+from stdnet.odm import ModelType, model_iterator
+from stdnet.utils import pickle, test


 class TestInspectionAndComparison(FinanceTest):
-
     def test_simple(self):
         d = odm.model_to_dict(Instrument)
         self.assertFalse(d)
         inst = yield self.session().add(
-            Instrument(name='erz12', type='future', ccy='EUR'))
+            Instrument(name="erz12", type="future", ccy="EUR")
+        )
         d = odm.model_to_dict(inst)
-        self.assertTrue(len(d),3)
+        self.assertEqual(len(d), 3)

     def testEqual(self):
         session = self.session()
-        inst = yield session.add(
-            Instrument(name='erm12', type='future', ccy='EUR'))
+        inst = yield session.add(Instrument(name="erm12", type="future", ccy="EUR"))
         id = inst.id
         b = yield self.query().get(id=id)
         self.assertEqual(b.id, id)
         self.assertTrue(inst == b)
         self.assertFalse(inst != b)
-        f = yield session.add(Fund(name='bla', ccy='EUR'))
+        f = yield session.add(Fund(name="bla", ccy="EUR"))
         self.assertFalse(inst == f)
         self.assertTrue(inst != f)

     def testNotEqual(self):
         session = self.session()
-        inst = yield session.add(Instrument(name='erz22', type='future', ccy='EUR'))
-        inst2 = yield session.add(Instrument(name='edz24', type='future', ccy='USD'))
+        inst = yield session.add(Instrument(name="erz22", type="future", ccy="EUR"))
+        inst2 = yield session.add(Instrument(name="edz24", type="future", ccy="USD"))
         id = inst.id
         b = yield self.query().get(id=id)
         self.assertEqual(b.id, id)
@@ -44,8 +43,8 @@ def testNotEqual(self):
         self.assertTrue(inst2 != b)

     def testHash(self):
-        '''Test model instance hash'''
-        inst = Instrument(name='erh12', type='future', ccy='EUR')
+        """Test model instance hash"""
+        inst = Instrument(name="erh12", type="future", ccy="EUR")
         h0 = hash(inst)
         self.assertTrue(h0)
         inst = yield self.session().add(inst)
@@ -58,20 +57,19 @@ def testmodelFromHash(self):
         self.assertEqual(m, Instrument)

     def testUniqueId(self):
-        '''Test model instance unique id across different model'''
-        inst = Instrument(name='erk12', type='future', ccy='EUR')
-        self.assertRaises(inst.DoesNotExist, lambda : inst.uuid)
+        """Test model instance unique id across different model"""
+        inst = Instrument(name="erk12", type="future", ccy="EUR")
+        self.assertRaises(inst.DoesNotExist, lambda: inst.uuid)
         yield self.session().add(inst)
-        v = inst.uuid.split('.')  # <>.<>
-        self.assertEqual(len(v),2)
-        self.assertEqual(v[0],inst._meta.hash)
-        self.assertEqual(v[1],str(inst.id))
+        v = inst.uuid.split(".")  # <>.<>
+        self.assertEqual(len(v), 2)
+        self.assertEqual(v[0], inst._meta.hash)
+        self.assertEqual(v[1], str(inst.id))

     def testModelValueError(self):
-        self.assertRaises(ValueError, Instrument, bla='foo')
-        self.assertRaises(ValueError, Instrument, name='bee', bla='foo')
-        self.assertRaises(ValueError, Instrument, name='bee', bla='foo',
-                          foo='pippo')
+        self.assertRaises(ValueError, Instrument, bla="foo")
+        self.assertRaises(ValueError, Instrument, name="bee", bla="foo")
+        self.assertRaises(ValueError, Instrument, name="bee", bla="foo", foo="pippo")


 class PickleSupport(test.TestCase):
@@ -79,7 +77,8 @@ class PickleSupport(test.TestCase):

     def testSimple(self):
         inst = yield self.session().add(
-            Instrument(name='erz12', type='future', ccy='EUR'))
+            Instrument(name="erz12", type="future", ccy="EUR")
+        )
         p = pickle.dumps(inst)
         inst2 = pickle.loads(p)
         self.assertEqual(inst, inst2)
@@ -89,20 +88,18 @@ def testSimple(self):

     def testTempDictionary(self):
         session = self.session()
-        inst = yield session.add(
-            Instrument(name='erz17', type='future', ccy='EUR'))
-        self.assertTrue('cleaned_data' in inst._dbdata)
+        inst = yield session.add(Instrument(name="erz17", type="future", ccy="EUR"))
+        self.assertTrue("cleaned_data" in inst._dbdata)
         p = pickle.dumps(inst)
         inst2 = pickle.loads(p)
-        self.assertFalse('cleaned_data' in inst2._dbdata)
+        self.assertFalse("cleaned_data" in inst2._dbdata)
         yield session.add(inst2)
-        self.assertTrue('cleaned_data' in inst._dbdata)
+        self.assertTrue("cleaned_data" in inst._dbdata)


 class TestRegistration(test.TestCase):
-
     def testModelIterator(self):
-        g = model_iterator('examples')
+        g = model_iterator("examples")
         self.assertTrue(inspect.isgenerator(g))
         d = list(g)
         self.assertTrue(d)
@@ -116,9 +113,10 @@ class TestStdModelMethods(test.TestCase):

     def testClone(self):
         session = self.session()
-        s = yield session.add(SimpleModel(code='pluto', group='planet',
-                                          cached_data='blabla'))
-        self.assertEqual(s.cached_data,b'blabla')
+        s = yield session.add(
+            SimpleModel(code="pluto", group="planet", cached_data="blabla")
+        )
+        self.assertEqual(s.cached_data, b"blabla")
         id = self.assertEqualId(s, 1)
         c = s.clone()
         self.assertEqual(c.id, None)
@@ -126,11 +124,11 @@ def test_clear_cache_fields(self):
         fields = self.model._meta.dfields
-        self.assertTrue(fields['timestamp'].as_cache)
-        self.assertFalse(fields['timestamp'].required)
-        self.assertFalse(fields['timestamp'].index)
+        self.assertTrue(fields["timestamp"].as_cache)
+        self.assertFalse(fields["timestamp"].required)
+        self.assertFalse(fields["timestamp"].index)
         session = self.session()
-        m = yield session.add(self.model(code='bla', timestamp=datetime.now()))
+        m = yield session.add(self.model(code="bla", timestamp=datetime.now()))
         self.assertTrue(m.timestamp)
         m.clear_cache_fields()
         self.assertEqual(m.timestamp, None)
@@ -145,17 +143,18 @@ class TestComplexModel(test.TestCase):

     def testJsonClear(self):
         session = self.session()
-        m = yield session.add(self.model(name ='bla',
-                                         data = {'italy':'rome', 'england':'london'}))
-        m = yield self.query().load_only('name').get(id=1)
+        m = yield session.add(
+            self.model(name="bla", data={"italy": "rome", "england": "london"})
+        )
+        m = yield self.query().load_only("name").get(id=1)
         self.assertFalse(m.has_all_data)
-        m.data = {'france':'paris'}
+        m.data = {"france": "paris"}
         yield session.add(m)
         m = yield self.query().get(id=1)
-        self.assertEqual(m.data,{'italy':'rome',
-                                 'england':'london',
-                                 'france':'paris'})
-        self.assertEqual(m.data__italy,'rome')
+        self.assertEqual(
+            m.data, {"italy": "rome", "england": "london", "france": "paris"}
+        )
+        self.assertEqual(m.data__italy, "rome")
         m.data = None
         yield session.add(m)
         m = yield self.query().get(id=1)
diff --git a/tests/all/lib/register.py b/tests/all/lib/register.py
index 4c2600f..3203541 100755
--- a/tests/all/lib/register.py
+++ b/tests/all/lib/register.py
@@ -1,16 +1,15 @@
-'''Test router registration'''
-from stdnet import odm, AlreadyRegistered
-from stdnet.utils import test
-
 from examples.models import SimpleModel
+from stdnet import AlreadyRegistered, odm
+from stdnet.utils import test
+

 class TestRegistration(test.TestWrite):
-
     def register(self):
         router = odm.Router(self.backend)
         self.assertEqual(router.default_backend, self.backend)
-        router.register_applications('examples')
+        router.register_applications("examples")
         self.assertTrue(router)
         return router

@@ -18,7 +17,7 @@ def test_registered_models(self):
         router = self.register()
         for meta in router.registered_models:
             name = meta.name
-            self.assertEqual(meta.app_label, 'examples')
+            self.assertEqual(meta.app_label, "examples")
             manager = router[meta]
             model = manager.model
             self.assertEqual(manager, getattr(router, name))
@@ -36,7 +35,7 @@ def test_unregister_all(self):
         N = len(router.registered_models)
         managers = router.unregister()
         self.assertEqual(N, len(managers))
-        self.assertFalse(router.registered_models)
+        self.assertFalse(router.registered_models)

     def testFlushModel(self):
         router = self.register()
@@ -44,10 +43,10 @@ def testFlushModel(self):

     def test_flush_exclude(self):
         models = self.register()
-        s = yield models.simplemodel.new(code='test')
+        s = yield models.simplemodel.new(code="test")
         all = yield models.simplemodel.all()
         self.assertEqual(len(all), 1)
-        yield models.flush(exclude=('examples.simplemodel',))
+        yield models.flush(exclude=("examples.simplemodel",))
         all = yield models.simplemodel.all()
         self.assertEqual(len(all), 1)
         self.assertEqual(all[0], s)
@@ -57,18 +56,20 @@ def test_flush_exclude(self):

     def testFromUuid(self):
         models = self.register()
-        s = yield models.simplemodel.new(code='test')
+        s = yield models.simplemodel.new(code="test")
         uuid = s.uuid
-        s2 = yield models.from_uuid(s.uuid)
+        s2 = yield models.from_uuid(s.uuid)
         self.assertEqual(s, s2)
-        yield self.async.assertRaises(odm.StdModel.DoesNotExist,
-                                      models.from_uuid, 'ccdscscds')
-        yield self.async.assertRaises(odm.StdModel.DoesNotExist,
-                                      models.from_uuid, 'ccdscscds.1')
-        a,b = tuple(uuid.split('.'))
-        yield self.async.assertRaises(odm.StdModel.DoesNotExist,
-                                      models.from_uuid, '{0}.5'.format(a))
+        yield self.async.assertRaises(
+            odm.StdModel.DoesNotExist, models.from_uuid, "ccdscscds"
+        )
+        yield self.async.assertRaises(
+            odm.StdModel.DoesNotExist, models.from_uuid, "ccdscscds.1"
+        )
+        a, b = tuple(uuid.split("."))
+        yield self.async.assertRaises(
+            odm.StdModel.DoesNotExist, models.from_uuid, "{0}.5".format(a)
+        )

     def testFailedHashModel(self):
         self.assertRaises(KeyError, odm.hashmodel, SimpleModel)
-
diff --git a/tests/all/multifields/hash.py b/tests/all/multifields/hash.py
index 5645cc7..7e03bd6 100755
--- a/tests/all/multifields/hash.py
+++ b/tests/all/multifields/hash.py
@@ -1,13 +1,12 @@
-'''tests for odm.HashField'''
-from stdnet.utils import test, zip, iteritems, to_string
-
+"""tests for odm.HashField"""
 from examples.models import Dictionary
+from stdnet.utils import iteritems, test, to_string, zip
+
 from .struct import MultiFieldMixin
-
+

 class HashData(test.DataGenerator):
-
     def generate(self):
         self.keys = self.populate()
         self.values = self.populate(min_len=20, max_len=300)
@@ -15,18 +14,18 @@


 class TestHashField(MultiFieldMixin, test.TestCase):
-    multipledb = 'redis'
+    multipledb = "redis"
     model = Dictionary
     data_cls = HashData
-
+
     def defaults(self):
-        return {'name': self.name}
-
+        return {"name": self.name}
+
     def adddata(self, d):
         yield d.data.update(self.data.data)
         size = yield d.data.size()
         self.assertEqual(len(self.data.data), size)
-
+
     def create(self, fill=False):
         with self.session().begin() as t:
             d = t.add(self.model(name=self.name))
@@ -34,7 +33,7 @@ def create(self, fill=False):
         if fill:
             yield d.data.update(self.data.data)
         yield d
-
+
     def test_update(self):
         d = yield self.create(True)
         data = d.data
@@ -45,7 +44,7 @@ def test_update(self):
         self.assertTrue(data.cache.cache)
         self.assertNotEqual(data.cache.cache, items)
         self.assertEqual(data.cache.cache, dict(items))
-
+
     def test_add(self):
         d = yield self.create()
         self.assertTrue(d.session)
@@ -64,7 +63,7 @@ def testKeys(self):
         for k in d.data:
             data.pop(k)
         self.assertEqual(len(data), 0)
-
+
     def testItems(self):
         d = yield self.create(True)
         data = self.data.data.copy()
@@ -72,12 +71,12 @@ def testItems(self):
         for k, v in items:
             self.assertEqual(v, data.pop(k))
         self.assertEqual(len(data), 0)
-
+
     def testValues(self):
         d = yield self.create(True)
         values = yield d.data.values()
         self.assertEqual(len(self.data.data), len(values))
-
+
     def createN(self):
         with self.session().begin() as t:
             for name in self.names:
@@ -89,27 +88,27 @@ def createN(self):
         with self.session().begin() as t:
             for m in qs:
                 t.add(m.data)
-                m.data['ciao'] = 'bla'
-                m.data['hello'] = 'foo'
-                m.data['hi'] = 'pippo'
-                m.data['salut'] = 'luna'
+                m.data["ciao"] = "bla"
+                m.data["hello"] = "foo"
+                m.data["hi"] = "pippo"
+                m.data["salut"] = "luna"
         yield t.on_result
-
+
     def testloadNotSelected(self):
-        '''Get the model and check that no data-structure data
-has been loaded.'''
+        """Get the model and check that no data-structure data
+        has been loaded."""
         yield self.createN()
-        cache = self.model._meta.dfields['data'].get_cache_name()
+        cache = self.model._meta.dfields["data"].get_cache_name()
         qs = yield self.query().all()
         self.assertTrue(qs)
         for m in qs:
             data = getattr(m, cache, None)
             self.assertFalse(data)
-
+
     def test_load_related(self):
-        '''Use load_selected to load stastructure data'''
+        """Use load_related to load data-structure data"""
         yield self.createN()
-        cache = self.model._meta.dfields['data'].get_cache_name()
-        all = yield self.query().load_related('data').all()
+        cache = self.model._meta.dfields["data"].get_cache_name()
+        all = yield self.query().load_related("data").all()
         for m in all:
-            self.assertTrue(m.data.cache.cache)
\ No newline at end of file
+            self.assertTrue(m.data.cache.cache)
diff --git a/tests/all/multifields/list.py b/tests/all/multifields/list.py
index 5b291c3..6a11b58 100644
--- a/tests/all/multifields/list.py
+++ b/tests/all/multifields/list.py
@@ -1,18 +1,18 @@
-'''tests for odm.ListField'''
-from stdnet import StructureFieldError
-from stdnet.utils import test, zip, to_string
-
+"""tests for odm.ListField"""
 from examples.models import SimpleList
+from stdnet import StructureFieldError
+from stdnet.utils import test, to_string, zip
+
 from .struct import MultiFieldMixin, StringData


 class TestListField(MultiFieldMixin, test.TestCase):
     model = SimpleList
-    attrname = 'names'
+    attrname = "names"

     def adddata(self, li):
-        '''Add elements to a list without using transactions.'''
+        """Add elements to a list without using transactions."""
         with li.session.begin():
             names = li.names
             for elem in self.data.names:
@@ -47,11 +47,11 @@ def test_push_back(self):
             self.assertEqual(el, ne)

     def testPushNoSave(self):
-        '''Push a new value to a list field should rise an error if the object
-is not saved on databse.'''
+        """Pushing a new value to a list field should raise an error if the
+        object is not saved on the database."""
         obj = self.model()
-        push_back = lambda : obj.names.push_back('this should fail')
-        push_front = lambda : obj.names.push_front('this should also fail')
+        push_back = lambda: obj.names.push_back("this should fail")
+        push_front = lambda: obj.names.push_front("this should also fail")
         self.assertRaises(StructureFieldError, push_back)
         self.assertRaises(StructureFieldError, push_front)

@@ -67,7 +67,7 @@ def test_items(self):


 class TestRedisListField(test.TestCase):
-    multipledb = ['redis']
+    multipledb = ["redis"]
     model = SimpleList
     data_cls = StringData

diff --git a/tests/all/multifields/set.py b/tests/all/multifields/set.py
index 5fdff2d..3b7698c 100755
--- a/tests/all/multifields/set.py
+++ b/tests/all/multifields/set.py
@@ -1,35 +1,34 @@
-'''tests for odm.SetField'''
+"""tests for odm.SetField"""
 from datetime import datetime
 from itertools import chain

-from stdnet import getdb
-from stdnet.utils import test, populate, zip
+from examples.models import Calendar, Collection, DateValue, Group

-from examples.models import Calendar, DateValue, Collection, Group
+from stdnet import getdb
+from stdnet.utils import populate, test, zip


 class ZsetData(test.DataGenerator):
-
     def generate(self):
-        self.dates = self.populate('date')
-        self.values = self.populate('string', min_len=10, max_len=120)
-
-
+        self.dates = self.populate("date")
+        self.values = self.populate("string", min_len=10, max_len=120)
+
+
 class TestSetField(test.TestCase):
     models = (Collection, Group)
-
+
     def test_simple(self):
         m = yield self.session().add(self.model())
         yield m.numbers.add(1)
         yield m.numbers.update((1, 2, 3, 4, 5))
         yield self.async.assertEqual(m.numbers.size(), 5)
-
-
+
+
 class TestOrderedSet(test.TestCase):
-    multipledb = 'redis'
+    multipledb = "redis"
     models = (Calendar, DateValue)
     data_cls = ZsetData
-
+
     def fill(self, update=False):
         session = self.session()
         c = yield session.add(Calendar(name=self.data.random_string()))
@@ -46,13 +45,13 @@ def fill(self, update=False):
                 c.data.add(value)
             yield t.on_result
         yield c
-
+
     def test_add(self):
         return self.fill()
-
+
     def test_update(self):
         return self.fill(True)
-
+
     def test_order(self):
         c = yield self.fill()
         yield self.async.assertEqual(c.data.size(), self.data.size)
@@ -62,7 +61,7 @@ def test_order(self):
             if dprec:
                 self.assertTrue(event.dt >= dprec)
             dprec = event.dt
-
+
     def test_rank(self):
         c = yield self.fill()
         data = c.data
@@ -74,5 +73,3 @@ def test_rank(self):
         for v in vals:
             ranks.append(data.rank(v))
         ranks = yield self.multi_async(ranks)
-
-
diff --git a/tests/all/multifields/string.py b/tests/all/multifields/string.py
index 4118059..be910fd 100644
--- a/tests/all/multifields/string.py
+++ b/tests/all/multifields/string.py
@@ -1,28 +1,24 @@
-'''tests for odm.StringField'''
-from stdnet.utils import test, populate, zip, iteritems, to_string
-
+"""tests for odm.StringField"""
 from examples.models import SimpleString
+from stdnet.utils import iteritems, populate, test, to_string, zip
+
 from .struct import MultiFieldMixin


 class TestStringField(MultiFieldMixin, test.TestCase):
-    multipledb = 'redis'
+    multipledb = "redis"
     model = SimpleString
-
+
     def adddata(self, li):
-        '''Add elements to a list without using transactions.'''
+        """Add elements to a list without using transactions."""
         for elem in self.data.names:
             li.data.push_back(elem)
-        yield self.async.assertEqual(li.data.size(),
-                                     len(''.join(self.data.names)))
-
+        yield self.async.assertEqual(li.data.size(), len("".join(self.data.names)))
+
     def test_incr(self):
         m = yield self.session().add(self.model())
         self.async.assertEqual(m.data.incr(), 1)
         self.async.assertEqual(m.data.incr(), 2)
         self.async.assertEqual(m.data.incr(3), 5)
         self.async.assertEqual(m.data.incr(-7), -2)
-
-
\ No newline at end of file
diff --git a/tests/all/multifields/struct.py b/tests/all/multifields/struct.py
index 6d26389..0bea76e 100644
--- a/tests/all/multifields/struct.py
+++ b/tests/all/multifields/struct.py
@@ -2,27 +2,27 @@
 from time import sleep

 from stdnet import StructureFieldError
-from stdnet.utils import test, populate, zip, to_string
+from stdnet.utils import populate, test, to_string, zip


 class StringData(test.DataGenerator):
-
     def generate(self):
         self.names = self.populate()
-
+

 class MultiFieldMixin(object):
-    '''Test class which add a couple of tests for multi fields.'''
-    attrname = 'data'
+    """Test class which adds a couple of tests for multi fields."""
+
+    attrname = "data"
     data_cls = StringData
-
+
     def setUp(self):
-        self.names = test.populate('string', size=10)
+        self.names = test.populate("string", size=10)
         self.name = self.names[0]
-
+
     def defaults(self):
         return {}
-
+
     def get_object_and_field(self, save=True, **kwargs):
         models = self.mapper
         params = self.defaults()
@@ -31,17 +31,18 @@ def get_object_and_field(self, save=True, **kwargs):
         if save:
             yield models.session().add(m)
         yield m, getattr(m, self.attrname)
-
+
     def adddata(self, obj):
         raise NotImplementedError
-
+
     def test_RaiseStructFieldError(self):
-        yield self.async.assertRaises(StructureFieldError,
-                                      self.get_object_and_field, False)
-
+        yield self.async.assertRaises(
+            StructureFieldError, self.get_object_and_field, False
+        )
+
     def test_multi_field_meta(self):
-        '''Here we check for multifield specific stuff like the instance
-related keys (keys which are related to the instance rather than the model).'''
+        """Here we check for multifield specific stuff like the instance
+        related keys (keys which are related to the instance rather than the model)."""
         # get instance and field, the field has no data here
         models = self.mapper
         #
@@ -56,9 +57,9 @@ def test_multi_field_meta(self):
         self.assertEqual(be.backend, models[self.model].backend)
         self.assertEqual(be.instance, field)
         #
-        if be.backend.name == 'redis':
+        if be.backend.name == "redis":
             yield self.check_redis_structure(obj, be)
-
+
     def check_redis_structure(self, obj, be):
         session = obj.session
         backend = be.backend
@@ -74,7 +75,7 @@ def check_redis_structure(self, obj, be):
         # Lets add data
         yield self.adddata(obj)
         # The field id should be in the server keys
-        if backend.name == 'redis':
+        if backend.name == "redis":
             lkeys = yield backend.model_keys(self.model._meta)
             self.assertTrue(be.id in lkeys)
         #
diff --git a/tests/all/multifields/timeseries.py b/tests/all/multifields/timeseries.py
index c25bdb0..3a0d930 100644
--- a/tests/all/multifields/timeseries.py
+++ b/tests/all/multifields/timeseries.py
@@ -1,52 +1,51 @@
-'''tests for odm.TimeSeriesField'''
+"""tests for odm.TimeSeriesField"""
 import os
 from datetime import date, datetime
 from random import uniform

-from stdnet import odm
-from stdnet.utils import test, todate, zip, dategenerator,\
-    default_parse_interval
+from examples.tsmodels import DateTimeSeries, TimeSeries

-from examples.tsmodels import TimeSeries, DateTimeSeries
+from stdnet import odm
+from stdnet.utils import dategenerator, default_parse_interval, test, todate, zip

 from .struct import MultiFieldMixin


 class TsData(test.DataGenerator):
-
     def generate(self):
-        self.dates = self.populate('date')
-        self.values = self.populate('float', start=10, end=400)
-        self.dates2 = self.populate('date', start=date(2009,1,1),
-                                    end=date(2010,1,1))
+        self.dates = self.populate("date")
+        self.values = self.populate("float", start=10, end=400)
+        self.dates2 = self.populate(
+            "date", start=date(2009, 1, 1), end=date(2010, 1, 1)
+        )
         self.big_strings = self.populate(min_len=300, max_len=1000)
-        self.alldata = list(zip(self.dates, self.values))
-        self.alldata2 = list(zip(self.dates2, self.values))
-        self.testdata = dict(self.alldata)
+        self.alldata = list(zip(self.dates, self.values))
+        self.alldata2 = list(zip(self.dates2, self.values))
+        self.testdata = dict(self.alldata)
         self.testdata2 = dict(self.alldata2)


 class TestDateTimeSeries(MultiFieldMixin, test.TestCase):
-    multipledb = 'redis'
+    multipledb = "redis"
     model = TimeSeries
     mkdate = datetime
     data_cls = TsData
-
+
     def defaults(self):
-        return {'ticker': self.name}
-
+        return {"ticker": self.name}
+
     def adddata(self, obj, data=None):
         data = data or self.data.testdata
         yield obj.data.update(data)
         yield self.async.assertEqual(obj.data.size(), len(data))
-
+
     def make(self, name=None):
         return self.session().add(self.model(ticker=name or self.name))
-
+
     def get(self, name=None):
         name = name or self.name
         return self.query().get(ticker=name)
-
+
     def filldata(self, data=None, name=None):
         d = yield self.make(name=name)
         yield self.adddata(d, data)
@@ -54,18 +53,18 @@

     def interval(self, a, b, targets, C, D):
         ts = yield self.get()
-        intervals = ts.intervals(a,b)
-        self.assertEqual(len(intervals),len(targets))
-        for interval,target in zip(intervals,targets):
+        intervals = ts.intervals(a, b)
+        self.assertEqual(len(intervals), len(targets))
+        for interval, target in zip(intervals, targets):
             x = interval[0]
             y = interval[1]
             self.assertEqual(x, target[0])
             self.assertEqual(y, target[1])
-            for dt in dategenerator(x,y):
-                ts.data.add(dt,uniform(0,1))
-        self.assertEqual(ts.data_start,C)
-        self.assertEqual(ts.data_end,D)
-
+            for dt in dategenerator(x, y):
+                ts.data.add(dt, uniform(0, 1))
+        self.assertEqual(ts.data_start, C)
+        self.assertEqual(ts.data_end, D)
+
     def testFrontBack(self):
         ts = yield self.make()
         self.assertEqual(ts.data_start, None)
@@ -73,143 +72,148 @@ def testFrontBack(self):
         mkdate = self.mkdate
         ts.data.update(self.data.testdata2)
         start = ts.data_start
-        end = ts.data_end
+        end = ts.data_end
         p = start
         for d in ts.dates():
-            self.assertTrue(d>=p)
+            self.assertTrue(d >= p)
             p = d
-        self.assertEqual(d,end)
-
+        self.assertEqual(d, end)
+
     def testkeys(self):
         ts = yield self.filldata()
         keyp = None
         data = self.data.testdata.copy()
         for key in ts.dates():
             if keyp:
-                self.assertTrue(key,keyp)
+                self.assertTrue(key, keyp)
             keyp = key
             data.pop(todate(key))
-        self.assertEqual(len(data),0)
-
+        self.assertEqual(len(data), 0)
+
     def testitems(self):
         ts = yield self.filldata()
         keyp = None
         data = self.data.testdata.copy()
-        for key,value in ts.items():
+        for key, value in ts.items():
             if keyp:
-                self.assertTrue(key,keyp)
+                self.assertTrue(key, keyp)
             keyp = key
-            self.assertEqual(data.pop(todate(key)),value)
-        self.assertEqual(len(data),0)
-
+            self.assertEqual(data.pop(todate(key)), value)
+        self.assertEqual(len(data), 0)
+
     def testUpdate(self):
         ts = yield self.make()
-        dt1 = self.mkdate(2010,5,6)
-        dt2 = self.mkdate(2010,6,6)
+        dt1 = self.mkdate(2010, 5, 6)
+        dt2 = self.mkdate(2010, 6, 6)
         ts.data[dt1] = 56
         ts.data[dt2] = 88
-        self.assertEqual(ts.data[dt1],56)
-        self.assertEqual(ts.data[dt2],88)
+        self.assertEqual(ts.data[dt1], 56)
+        self.assertEqual(ts.data[dt2], 88)
         ts.data[dt1] = "ciao"
-        self.assertEqual(ts.data[dt1],"ciao")
+        self.assertEqual(ts.data[dt1], "ciao")
+
     def testInterval(self):
-        '''Test interval handling'''
+        """Test interval handling"""
         ts = yield self.make()
         mkdate = self.mkdate
-        self.assertEqual(ts.data_start,None)
-        self.assertEqual(ts.data_end,None)
+        self.assertEqual(ts.data_start, None)
+        self.assertEqual(ts.data_end, None)
         #
         #
-        A1 = mkdate(2010,5,10)
-        B1 = mkdate(2010,5,12)
-        self.interval(A1,B1,[[A1,B1]],A1,B1)
+        A1 = mkdate(2010, 5, 10)
+        B1 = mkdate(2010, 5, 12)
+        self.interval(A1, B1, [[A1, B1]], A1, B1)
         #
         # original -> A1      B1
         # request  -> A2   B2
         # interval -> A2  A1-
         # range    -> A2      B1
-        A2 = mkdate(2010,5,6)
-        B2 = mkdate(2010,5,11)
-        self.interval(A2,B2,[[A2,default_parse_interval(A1,-1)]],A2,B1)
+        A2 = mkdate(2010, 5, 6)
+        B2 = mkdate(2010, 5, 11)
+        self.interval(A2, B2, [[A2, default_parse_interval(A1, -1)]], A2, B1)
         #
         # original -> A2        B1
         # request  -> A3           B3
         # interval -> A3  A2-  B1+ B3
         # range    -> A3           B3
-        A3 = mkdate(2010,5,4)
-        B3 = mkdate(2010,5,14)
-        self.interval(A3,B3,[[A3,default_parse_interval(A2,-1)],
-                             [default_parse_interval(B1,1),B3]],A3,B3)
+        A3 = mkdate(2010, 5, 4)
+        B3 = mkdate(2010, 5, 14)
+        self.interval(
+            A3,
+            B3,
+            [[A3, default_parse_interval(A2, -1)], [default_parse_interval(B1, 1), B3]],
+            A3,
+            B3,
+        )
        #
        # original -> A3           B3
        # request  -> A2   B2
        # interval -> empty
        # range    -> A3           B3
-        self.interval(A2,B2,[],A3,B3)
+        self.interval(A2, B2, [], A3, B3)
        #
        # original -> A3           B3
        # request  -> A4   B4
        # interval -> A4  A3-
        # range    -> A4           B3
-        A4 = mkdate(2010,4,20)
-        B4 = mkdate(2010,5,1)
-        self.interval(A4,B4,[[A4,default_parse_interval(A3,-1)]],A4,B3)
+        A4 = mkdate(2010, 4, 20)
+        B4 = mkdate(2010, 5, 1)
+        self.interval(A4, B4, [[A4, default_parse_interval(A3, -1)]], A4, B3)
        #
        # original -> A4           B3
        # request  -> A2               B5
        # interval -> B3+ B5
        # range    -> A4               B5
-        B5 = mkdate(2010,6,1)
-
self.interval(A2,B5,[[default_parse_interval(B3,1),B5]],A4,B5) + B5 = mkdate(2010, 6, 1) + self.interval(A2, B5, [[default_parse_interval(B3, 1), B5]], A4, B5) # # original -> A4 B5 # request -> A6 B6 # interval -> B5+ B6 # range -> A4 B6 - A6 = mkdate(2010,7,1) - B6 = mkdate(2010,8,1) - self.interval(A6,B6,[[default_parse_interval(B5,1),B6]],A4,B6) - + A6 = mkdate(2010, 7, 1) + B6 = mkdate(2010, 8, 1) + self.interval(A6, B6, [[default_parse_interval(B5, 1), B6]], A4, B6) + def testSetLen(self): ts = yield self.make() mkdate = self.mkdate - dt = mkdate(2010,7,1) - dt2 = mkdate(2010,4,1) + dt = mkdate(2010, 7, 1) + dt2 = mkdate(2010, 4, 1) data = ts.data with ts.session.begin() as t: - data.add(dt,56) + data.add(dt, 56) data[dt2] = 78 yield t.on_result yield self.async.assertEqual(ts.data.size(), 2) - yield ts.data.update({mkdate(2009,3,1):"ciao", mkdate(2009,7,4):"luca"}) - yield self.async.assertEqual(ts.data.size(),4) + yield ts.data.update({mkdate(2009, 3, 1): "ciao", mkdate(2009, 7, 4): "luca"}) + yield self.async.assertEqual(ts.data.size(), 4) yield self.async.assertTrue(dt2 in ts.data) - yield self.async.assertFalse(mkdate(2000,4,13) in ts.data) - + yield self.async.assertFalse(mkdate(2000, 4, 13) in ts.data) + def testGet(self): ts = yield self.make() mkdate = self.mkdate with ts.session.begin() as t: - dt = mkdate(2010,7,1) - dt2 = mkdate(2010,4,1) - ts.data.add(dt,56) + dt = mkdate(2010, 7, 1) + dt2 = mkdate(2010, 4, 1) + ts.data.add(dt, 56) ts.data[dt2] = 78 yield t.on_result yield self.async.assertEqual(ts.size(), 2) yield self.async.assertEqual(ts.data.get(dt), 56) yield self.async.assertEqual(ts.data[dt2], 78) - yield self.async.assertRaises(KeyError, lambda : ts.data[mkdate(2010,3,1)]) - yield self.async.assertEqual(ts.data.get(mkdate(2010,3,1)), None) - + yield self.async.assertRaises(KeyError, lambda: ts.data[mkdate(2010, 3, 1)]) + yield self.async.assertEqual(ts.data.get(mkdate(2010, 3, 1)), None) + def testRange(self): - '''Test the range (by time) command''' + """Test the range (by time) command""" - ts = yield self.filldata(testdata2) - d1 = date(2009,4,1) - d2 = date(2009,11,1) - data = list(ts.data.range(d1,d2)) + ts = yield self.filldata(self.data.testdata2) + d1 = date(2009, 4, 1) + d2 = date(2009, 11, 1) + data = list(ts.data.range(d1, d2)) self.assertTrue(data) - + def testloadrelated(self): yield self.make() session = self.session() @@ -226,95 +230,95 @@ def testloadrelated(self): yield t.on_result yield self.async.assertTrue(m1.size()) yield self.async.assertTrue(m2.size()) - qs = yield qm.filter(ticker=self.names[:2]).load_related('data').all() + qs = yield qm.filter(ticker=self.names[:2]).load_related("data").all() for m in qs: self.assertTrue(m.data.cache.cache) - + def testitems2(self): ts = yield self.filldata(self.data.testdata2) - for k,v in ts.data.items(): + for k, v in ts.data.items(): self.assertEqual(v, self.data.testdata2[todate(k)]) - + def testiRange(self): ts = yield self.filldata(self.data.testdata2) - N = ts.data.size() + N = ts.data.size() self.assertTrue(N) - a = int(N/4) - b = 3*a - r1 = list(ts.data.irange(0,a)) - r2 = list(ts.data.irange(a,b)) - r3 = list(ts.data.irange(b,-1)) - self.assertEqual(r1[-1],r2[0]) - self.assertEqual(r2[-1],r3[0]) - self.assertEqual(r1[0],ts.data.front()) - self.assertEqual(r3[-1],ts.data.back()) - T = len(r1)+len(r2)+len(r3) - self.assertEqual(T,N+2) - self.assertEqual(len(r1),a+1) - self.assertEqual(len(r2),b-a+1) - self.assertEqual(len(r3),N-b) - + a = int(N / 4) + b = 3 * a + r1 = list(ts.data.irange(0, a)) + r2 = list(ts.data.irange(a, b)) + r3 = 
list(ts.data.irange(b, -1)) + self.assertEqual(r1[-1], r2[0]) + self.assertEqual(r2[-1], r3[0]) + self.assertEqual(r1[0], ts.data.front()) + self.assertEqual(r3[-1], ts.data.back()) + T = len(r1) + len(r2) + len(r3) + self.assertEqual(T, N + 2) + self.assertEqual(len(r1), a + 1) + self.assertEqual(len(r2), b - a + 1) + self.assertEqual(len(r3), N - b) + def __testiRangeTransaction(self): ts = self.filldata(self.data.testdata2) - N = ts.data.size() + N = ts.data.size() self.assertTrue(N) - a = int(N/4) - b = 3*a + a = int(N / 4) + b = 3 * a with self.session().begin() as t: - ts.data.irange(0,a,t) - ts.data.irange(a,b,t) - ts.data.irange(b,-1,t) + ts.data.irange(0, a, t) + ts.data.irange(a, b, t) + ts.data.irange(b, -1, t) ts.data.front(t) ts.data.back(t) - c = lambda x : x if isinstance(x,date) else list(x) - r1,r2,r3,front,back = [c(r) for r in t.get_result()] - self.assertEqual(r1[-1],r2[0]) - self.assertEqual(r2[-1],r3[0]) - self.assertEqual(r1[0][0],front) - self.assertEqual(r3[-1][0],back) - T = len(r1)+len(r2)+len(r3) - self.assertEqual(T,N+2) - self.assertEqual(len(r1),a+1) - self.assertEqual(len(r2),b-a+1) - self.assertEqual(len(r3),N-b) - + c = lambda x: x if isinstance(x, date) else list(x) + r1, r2, r3, front, back = [c(r) for r in t.get_result()] + self.assertEqual(r1[-1], r2[0]) + self.assertEqual(r2[-1], r3[0]) + self.assertEqual(r1[0][0], front) + self.assertEqual(r3[-1][0], back) + T = len(r1) + len(r2) + len(r3) + self.assertEqual(T, N + 2) + self.assertEqual(len(r1), a + 1) + self.assertEqual(len(r2), b - a + 1) + self.assertEqual(len(r3), N - b) + def testRange(self): ts = yield self.filldata(self.data.testdata2) - N = ts.data.size() - a = int(N/4) - b = 3*a + N = ts.data.size() + a = int(N / 4) + b = 3 * a r1 = list(ts.data.irange(0, a)) r2 = list(ts.data.irange(a, b)) r3 = list(ts.data.irange(b, -1)) r4 = list(ts.data.range(r2[0][0], r2[-1][0])) self.assertEqual(r4[0], r2[0]) self.assertEqual(r4[-1], r2[-1]) - + def __testRangeTransaction(self): ts = self.filldata(self.data.testdata2) - N = ts.data.size() - a = int(N/4) - b = 3*a + N = ts.data.size() + a = int(N / 4) + b = 3 * a with self.session().begin() as t: ts.data.irange(0, a, t) ts.data.irange(a, b, t) ts.data.irange(b, -1, t) - r1,r2,r3 = [list(r) for r in t.get_result()] + r1, r2, r3 = [list(r) for r in t.get_result()] with self.session().begin() as t: - ts.data.range(r2[0][0],r2[-1][0],t) + ts.data.range(r2[0][0], r2[-1][0], t) r4 = [list(r) for r in t.get_result()][0] - self.assertEqual(r4[0],r2[0]) - self.assertEqual(r4[-1],r2[-1]) - + self.assertEqual(r4[0], r2[0]) + self.assertEqual(r4[-1], r2[-1]) + def testCount(self): ts = yield self.filldata(self.data.testdata2) - N = ts.data.size() - a = int(N/4) - b = 3*a - r1 = list(ts.data.irange(0,a)) - r2 = list(ts.data.irange(a,b)) - r3 = list(ts.data.irange(b,-1)) - self.assertEqual(ts.data.count(r2[0][0],r2[-1][0]),b-a+1) + N = ts.data.size() + a = int(N / 4) + b = 3 * a + r1 = list(ts.data.irange(0, a)) + r2 = list(ts.data.irange(a, b)) + r3 = list(ts.data.irange(b, -1)) + self.assertEqual(ts.data.count(r2[0][0], r2[-1][0]), b - a + 1) class TestDateSeries(TestDateTimeSeries): diff --git a/tests/all/query/contains.py b/tests/all/query/contains.py index f053f4e..dac8682 100644 --- a/tests/all/query/contains.py +++ b/tests/all/query/contains.py @@ -1,53 +1,58 @@ +from examples.models import SimpleModel +from examples.wordsearch.basicwords import basic_english_words + from stdnet.utils import test from stdnet.utils.py2py3 import zip -from examples.models import 
SimpleModel -from examples.wordsearch.basicwords import basic_english_words class TextGenerator(test.DataGenerator): - sizes = {'tiny': (20, 5), - 'small': (100, 20), - 'normal': (500, 50), - 'big': (2000, 100), - 'huge': (10000, 200)} - + sizes = { + "tiny": (20, 5), + "small": (100, 20), + "normal": (500, 50), + "big": (2000, 100), + "huge": (10000, 200), + } + def generate(self): size, words = self.size self.descriptions = [] - self.names = self.populate('string', size, min_len=10, max_len=30) + self.names = self.populate("string", size, min_len=10, max_len=30) for s in range(size): - d = ' '.join(self.populate('choice', words, choice_from=basic_english_words)) + d = " ".join( + self.populate("choice", words, choice_from=basic_english_words) + ) self.descriptions.append(d) - - + + class TestFieldSerach(test.TestCase): model = SimpleModel data_cls = TextGenerator - + @classmethod def after_setup(cls): with cls.session().begin() as t: for name, des in zip(cls.data.names, cls.data.descriptions): t.add(cls.model(code=name, description=des)) yield t.on_result - + def test_contains(self): session = self.session() qs = session.query(self.model) - all = yield qs.filter(description__contains='ll').all() + all = yield qs.filter(description__contains="ll").all() self.assertTrue(all) for m in all: - self.assertTrue('ll' in m.description) - all = yield qs.filter(description__contains='llllll').all() + self.assertTrue("ll" in m.description) + all = yield qs.filter(description__contains="llllll").all() self.assertFalse(all) - + def test_startswith(self): session = self.session() qs = session.query(self.model) all = yield qs.all() count = {} for m in all: - start = m.description.split(' ')[0][:2] + start = m.description.split(" ")[0][:2] if start in count: count[start] += 1 else: @@ -58,4 +63,4 @@ def test_startswith(self): self.assertTrue(all) for m in all: self.assertTrue(m.description.startswith(start)) - self.assertEqual(len(all), count[start]) \ No newline at end of file + self.assertEqual(len(all), count[start]) diff --git a/tests/all/query/delete.py b/tests/all/query/delete.py index 41e98fa..ac84c27 100755 --- a/tests/all/query/delete.py +++ b/tests/all/query/delete.py @@ -1,18 +1,17 @@ -'''Delete objects and queries''' +"""Delete objects and queries""" import datetime from random import randint +from examples.data import FinanceTest, finance_data +from examples.models import Dictionary, Fund, Instrument, Position, SimpleModel + from stdnet import odm from stdnet.utils import test, zip -from examples.models import Instrument, Fund, Position, Dictionary, SimpleModel -from examples.data import finance_data, FinanceTest - class DictData(test.DataGenerator): - def generate(self): - self.keys = self.populate(min_len=5, max_len=20) + self.keys = self.populate(min_len=5, max_len=20) self.values = self.populate(min_len=20, max_len=300) self.data = dict(zip(self.keys, self.values)) @@ -25,7 +24,7 @@ def test_session_delete(self): session = self.session() query = session.query(self.model) with session.begin() as t: - m = t.add(self.model(code='ciao')) + m = t.add(self.model(code="ciao")) yield t.on_result elem = yield query.get(id=m.id) with session.begin() as t: @@ -37,13 +36,13 @@ def test_session_delete(self): def testSimpleQuery(self): session = self.session() with session.begin() as t: - t.add(self.model(code='hello')) - t.add(self.model(code='hello2')) + t.add(self.model(code="hello")) + t.add(self.model(code="hello2")) yield t.on_result - query = 
session.query(self.model).filter(code=('hello','hello2')) + query = session.query(self.model).filter(code=("hello", "hello2")) yield self.async.assertEqual(query.count(), 2) yield query.delete() - query = session.query(self.model).filter(code=('hello','hello2')) + query = session.query(self.model).filter(code=("hello", "hello2")) all = yield query.all() self.assertEqual(all, []) @@ -51,29 +50,29 @@ def test_simple_filter(self): session = self.session() query = session.query(self.model) with session.begin() as t: - t.add(self.model(code='sun', group='star')) - t.add(self.model(code='vega', group='star')) - t.add(self.model(code='sirus', group='star')) - t.add(self.model(code='pluto', group='planet')) + t.add(self.model(code="sun", group="star")) + t.add(self.model(code="vega", group="star")) + t.add(self.model(code="sirus", group="star")) + t.add(self.model(code="pluto", group="planet")) yield t.on_result with session.begin() as t: - t.delete(query.filter(group='star')) + t.delete(query.filter(group="star")) yield t.on_result - yield self.async.assertEqual(query.filter(group='star').count(), 0) - rest = query.exclude(group='star').count() + yield self.async.assertEqual(query.filter(group="star").count(), 0) + rest = query.exclude(group="star").count() self.assertTrue(rest) - qs = query.filter(group='planet') + qs = query.filter(group="planet") yield self.async.assertEqual(qs.count(), 1) class update_model(object): - def __init__(self, test): self.test = test self.session = None - def __call__(self, signal, sender, instances=None, session=None, - transaction=None, **kwargs): + def __call__( + self, signal, sender, instances=None, session=None, transaction=None, **kwargs + ): self.session = session self.instances = instances self.transaction = transaction @@ -94,8 +93,8 @@ def tearDown(self): def testSignal(self): session = self.session() with session.begin() as t: - m = t.add(self.model(code='ciao')) - m = t.add(self.model(code='bla')) + m = t.add(self.model(code="ciao")) + m = t.add(self.model(code="bla")) yield t.on_result deleted = yield session.query(self.model).delete() u = self.update_model @@ -104,7 +103,8 @@ def testSignal(self): class TestDeleteMethod(test.TestWrite): - '''Test the delete method in models and in queries.''' + """Test the delete method in models and in queries.""" + data_cls = finance_data models = (Instrument, Fund, Position) @@ -121,13 +121,13 @@ def testDeleteMultiQueries(self): session = yield self.data.create(self) query = session.query(Instrument) with session.begin() as t: - t.delete(query.filter(ccy='EUR')) - t.delete(query.filter(type=('future','bond'))) + t.delete(query.filter(ccy="EUR")) + t.delete(query.filter(type=("future", "bond"))) yield t.on_result all = yield query.all() for inst in all: - self.assertFalse(inst.type in ('future','bond')) - self.assertNotEqual(inst.ccy,'EUR') + self.assertFalse(inst.type in ("future", "bond")) + self.assertNotEqual(inst.ccy, "EUR") class TestDeleteScalarFields(test.TestWrite): @@ -135,14 +135,14 @@ class TestDeleteScalarFields(test.TestWrite): models = (Instrument, Fund, Position) def test_flush_simple_model(self): - '''Use the class method flush to remove all instances of a - Model including filters.''' + """Use the class method flush to remove all instances of a + Model including filters.""" session = yield self.data.create(self) deleted = yield session.query(Instrument).delete() yield self.async.assertEqual(session.query(Instrument).all(), []) yield self.async.assertEqual(session.query(Position).all(), []) keys = 
yield session.keys(Instrument) - if self.backend == 'redis': + if self.backend == "redis": self.assertTrue(len(keys) > 0) def testFlushRelatedModel(self): @@ -152,28 +152,28 @@ def testFlushRelatedModel(self): yield self.async.assertEqual(session.query(Instrument).all(), []) yield self.async.assertEqual(session.query(Position).all(), []) keys = yield session.keys(Instrument) - if self.backend == 'redis': + if self.backend == "redis": self.assertTrue(len(keys) > 0) def testDeleteSimple(self): - '''Test delete on models without related models''' + """Test delete on models without related models""" session = yield self.data.create(self) t = yield session.query(Instrument).delete() all = yield session.query(Instrument).all() self.assertEqual(all, []) # There should be only keys for indexes and auto id backend = session.model(Instrument).backend - if backend.name == 'redis': + if backend.name == "redis": keys = yield session.keys(Instrument) self.assertEqual(len(keys), 1) - self.assertEqual(keys[0], backend.basekey(Instrument._meta, 'ids')) + self.assertEqual(keys[0], backend.basekey(Instrument._meta, "ids")) yield session.flush(Instrument) keys = yield session.keys(Instrument) self.assertEqual(len(keys), 0) def testDeleteRelatedOneByOne(self): - '''Test delete on models with related models. This is a crucial -test as it involves lots of operations and consistency checks.''' + """Test delete on models with related models. This is a crucial + test as it involves lots of operations and consistency checks.""" # Create Positions which hold foreign keys to Instruments session = yield self.data.makePositions(self) instruments = yield session.query(Instrument).all() @@ -185,8 +185,8 @@ def testDeleteRelatedOneByOne(self): yield self.async.assertEqual(session.query(Position).all(), []) def testDeleteRelated(self): - '''Test delete on models with related models. This is a crucial -test as it involves lots of operations and consistency checks.''' + """Test delete on models with related models. This is a crucial + test as it involves lots of operations and consistency checks.""" # Create Positions which hold foreign keys to Instruments session = yield self.data.makePositions(self) yield session.query(Instrument).delete() @@ -194,15 +194,15 @@ def testDeleteRelated(self): yield self.async.assertEqual(session.query(Position).all(), []) def __testDeleteRelatedCounting(self): - '''Test delete on models with related models. This is a crucial -test as it involves lots of operations and consistency checks.''' + """Test delete on models with related models. 
This is a crucial + test as it involves lots of operations and consistency checks.""" # Create Positions which hold foreign keys to Instruments NP = 20 N = Instrument.objects.all().count() + NP self.makePositions(NP) Instrument.objects.all().delete() - self.assertEqual(Instrument.objects.all().count(),0) - self.assertEqual(Position.objects.all().count(),0) + self.assertEqual(Instrument.objects.all().count(), 0) + self.assertEqual(Position.objects.all().count(), 0) class TestDeleteStructuredFields(test.TestWrite): @@ -212,8 +212,8 @@ class TestDeleteStructuredFields(test.TestWrite): def setUp(self): session = self.session() with session.begin() as t: - t.add(Dictionary(name='test')) - t.add(Dictionary(name='test2')) + t.add(Dictionary(name="test")) + t.add(Dictionary(name="test2")) yield t.on_result yield self.async.assertEqual(session.query(Dictionary).count(), 2) @@ -238,13 +238,13 @@ def testSimpleFlush(self): self.assertEqual(len(keys), 0) def test_flush_with_data(self): - yield self.fill('test') - yield self.fill('test2') + yield self.fill("test") + yield self.fill("test2") session = self.session() yield session.flush(Dictionary) yield self.async.assertEqual(session.query(Dictionary).count(), 0) # Now we check the database if it is empty as it should backend = self.mapper.dictionary.backend - if backend.name == 'redis': + if backend.name == "redis": keys = yield session.keys(Dictionary) self.assertEqual(keys, []) diff --git a/tests/all/query/get_field.py b/tests/all/query/get_field.py index 6e2bb4f..19a26a3 100644 --- a/tests/all/query/get_field.py +++ b/tests/all/query/get_field.py @@ -1,13 +1,11 @@ -'''Test query.get_field method for obtaining a single field from a query.''' -from stdnet.utils import test, zip, is_string +"""Test query.get_field method for obtaining a single field from a query.""" +from examples.data import CCYS_TYPES, INSTS_TYPES, FinanceTest +from examples.models import AnalyticData, Fund, Group, Instrument, ObjectAnalytics -from examples.models import Instrument, ObjectAnalytics,\ - AnalyticData, Group, Fund -from examples.data import FinanceTest, INSTS_TYPES, CCYS_TYPES +from stdnet.utils import is_string, test, zip class TestInstrument(FinanceTest): - @classmethod def after_setup(cls): yield cls.data.create(cls) @@ -15,9 +13,9 @@ def after_setup(cls): def testName(self): session = self.session() all = yield session.query(self.model).all() - qb = dict(((i.name,i) for i in all)) - qs = session.query(self.model).get_field('name') - self.assertEqual(qs._get_field, 'name') + qb = dict(((i.name, i) for i in all)) + qs = session.query(self.model).get_field("name") + self.assertEqual(qs._get_field, "name") result = yield qs.all() self.assertTrue(result) for r in result: @@ -27,8 +25,8 @@ def testName(self): def testId(self): session = self.session() - qs = session.query(self.model).get_field('id') - self.assertEqual(qs._get_field, 'id') + qs = session.query(self.model).get_field("id") + self.assertEqual(qs._get_field, "id") result = yield qs.all() self.assertTrue(result) for r in result: @@ -36,15 +34,14 @@ def testId(self): class TestRelated(FinanceTest): - @classmethod def after_setup(cls): return cls.data.makePositions(cls) def testInstrument(self): models = self.mapper - qs = models.position.query().get_field('instrument') - self.assertEqual(qs._get_field, 'instrument') + qs = models.position.query().get_field("instrument") + self.assertEqual(qs._get_field, "instrument") result = yield qs.all() self.assertTrue(result) for r in result: @@ -52,7 +49,7 @@ def 
testInstrument(self): def testFilter(self): models = self.mapper - qs = models.position.query().get_field('instrument') + qs = models.position.query().get_field("instrument") qi = models.instrument.filter(id=qs) inst = yield qi.all() ids = yield qs.all() @@ -64,18 +61,20 @@ def testFilter(self): class generator(test.DataGenerator): - sizes = {'tiny': (2,10), - 'small': (5,30), - 'normal': (10,50), - 'big': (30,200), - 'huge': (100,1000)} + sizes = { + "tiny": (2, 10), + "small": (5, 30), + "normal": (10, 50), + "big": (30, 200), + "huge": (100, 1000), + } def generate(self): group_len, obj_len = self.size - self.inames = self.populate('string', obj_len, min_len=5, max_len=20) - self.itypes = self.populate('choice', obj_len, choice_from=INSTS_TYPES) - self.iccys = self.populate('choice', obj_len, choice_from=CCYS_TYPES) - self.gnames = self.populate('string', group_len, min_len=5, max_len=20) + self.inames = self.populate("string", obj_len, min_len=5, max_len=20) + self.itypes = self.populate("choice", obj_len, choice_from=INSTS_TYPES) + self.iccys = self.populate("choice", obj_len, choice_from=CCYS_TYPES) + self.gnames = self.populate("string", group_len, min_len=5, max_len=20) def create(self, test): session = test.session() @@ -87,8 +86,8 @@ def create(self, test): for name, ccy in zip(self.inames, self.iccys): t.add(Fund(name=name, ccy=ccy)) yield t.on_result - iall = yield test.session().query(Instrument).load_only('id').all() - fall = yield test.session().query(Fund).load_only('id').all() + iall = yield test.session().query(Instrument).load_only("id").all() + fall = yield test.session().query(Fund).load_only("id").all() with session.begin() as t: for i in iall: t.add(ObjectAnalytics(model_type=Instrument, object_id=i.id)) @@ -98,8 +97,8 @@ def create(self, test): obj_len = self.size[1] groups = yield session.query(Group).all() objs = yield session.query(ObjectAnalytics).all() - groups = self.populate('choice', obj_len, choice_from=groups) - objs = self.populate('choice', obj_len, choice_from=objs) + groups = self.populate("choice", obj_len, choice_from=groups) + objs = self.populate("choice", obj_len, choice_from=objs) with test.session().begin() as t: for g, o in zip(groups, objs): t.add(AnalyticData(group=g, object=o)) @@ -107,7 +106,8 @@ def create(self, test): class TestModelField(test.TestWrite): - '''Test the get_field method when applied to ModelField''' + """Test the get_field method when applied to ModelField""" + models = (ObjectAnalytics, AnalyticData, Group, Instrument, Fund) data_cls = generator @@ -116,29 +116,31 @@ def setUp(cls): def testLoad(self): session = self.session() - q = session.query(ObjectAnalytics)\ - .filter(model_type=Instrument).get_field('id') + q = session.query(ObjectAnalytics).filter(model_type=Instrument).get_field("id") i = session.query(Instrument).filter(id=q) yield self.async.assertEqual(i.count(), session.query(Instrument).count()) def testLoadMissing(self): models = self.mapper - yield models.instrument.filter(id=(1,2,3)).delete() - q = models.objectanalytics.filter(model_type=Instrument).get_field('id') + yield models.instrument.filter(id=(1, 2, 3)).delete() + q = models.objectanalytics.filter(model_type=Instrument).get_field("id") i = yield models.instrument.filter(id=q).all() self.assertTrue(i) def testUnion(self): models = self.mapper + def query(): - model_permissions = models.objectanalytics.filter(id=(1,2,3)) - objects = models.analyticdata.exclude(object__model_type=Instrument)\ - .get_field('object') + model_permissions = 
models.objectanalytics.filter(id=(1, 2, 3)) + objects = models.analyticdata.exclude( + object__model_type=Instrument + ).get_field("object") return model_permissions.union(objects).all() + result1 = yield query() self.assertTrue(result1) # Now remove some instruments - yield models.instrument.filter(id=(1,2,3)).delete() + yield models.instrument.filter(id=(1, 2, 3)).delete() # result2 = yield query() self.assertTrue(result2) diff --git a/tests/all/query/instruments.py b/tests/all/query/instruments.py index 7324b41..2970759 100755 --- a/tests/all/query/instruments.py +++ b/tests/all/query/instruments.py @@ -1,47 +1,46 @@ -'''Test query.filter and query.exclude''' -from stdnet.utils import test - -from examples.models import Instrument2, Fund, Position +"""Test query.filter and query.exclude""" from examples import data +from examples.models import Fund, Instrument2, Position + +from stdnet.utils import test class TestFilter(data.FinanceTest): - @classmethod def after_setup(cls): return cls.data.create(cls) - + def testAll(self): session = self.session() qs = session.query(self.model) c = yield qs.count() self.assertTrue(c > 0) - + def testSimpleFilterId(self): session = self.session() query = session.query(self.model) - all = yield session.query(self.model).load_only('id').all() + all = yield session.query(self.model).load_only("id").all() qs = yield query.filter(id=all[0].id).all() obj = qs[0] self.assertEqual(obj.id, all[0].id) self.assertEqual(obj._loadedfields, None) - + def testSimpleFilter(self): session = self.session() - qs = yield session.query(self.model).filter(ccy='USD').all() + qs = yield session.query(self.model).filter(ccy="USD").all() self.assertTrue(qs) for i in qs: - self.assertEqual(i.ccy, 'USD') - + self.assertEqual(i.ccy, "USD") + def testFilterIn(self): session = self.session() qs = session.query(self.model) - eur = yield qs.filter(ccy='EUR').all() - usd = yield qs.filter(ccy='USD').all() - eur = dict(((o.id,o) for o in eur)) - usd = dict(((o.id,o) for o in usd)) + eur = yield qs.filter(ccy="EUR").all() + usd = yield qs.filter(ccy="USD").all() + eur = dict(((o.id, o) for o in eur)) + usd = dict(((o.id, o) for o in usd)) all = set(eur).union(set(usd)) - CCYS = ('EUR', 'USD') + CCYS = ("EUR", "USD") qs = yield qs.filter(ccy=CCYS).all() us = set() for inst in qs: @@ -49,36 +48,36 @@ def testFilterIn(self): self.assertTrue(inst.ccy in CCYS) zero = all - us self.assertTrue(qs) - self.assertEqual(len(zero),0) - + self.assertEqual(len(zero), 0) + def testDoubleFilter(self): session = self.session() - for ccy in ('EUR','USD','JPY'): - for type in ('equity','bond','future'): + for ccy in ("EUR", "USD", "JPY"): + for type in ("equity", "bond", "future"): qs = session.query(self.model).filter(ccy=ccy, type=type) all = yield qs.all() if all: break if all: break - self.assertTrue(all) + self.assertTrue(all) for inst in all: self.assertEqual(inst.ccy, ccy) self.assertEqual(inst.type, type) - + def testDoubleFilterIn(self): - CCYS = ('EUR','USD') + CCYS = ("EUR", "USD") session = self.session() - qs = yield session.query(self.model).filter(ccy=CCYS, type='future') + qs = yield session.query(self.model).filter(ccy=CCYS, type="future") all = yield qs.all() self.assertTrue(all) for inst in all: self.assertTrue(inst.ccy in CCYS) - self.assertEqual(inst.type, 'future') - + self.assertEqual(inst.type, "future") + def testDoubleInFilter(self): - CCYS = ('EUR','USD','JPY') - types = ('equity','bond','future') + CCYS = ("EUR", "USD", "JPY") + types = ("equity", "bond", "future") session = 
self.session() qs = session.query(self.model).filter(ccy=CCYS, type=types) all = yield qs.all() @@ -86,30 +85,30 @@ def testDoubleInFilter(self): for inst in all: self.assertTrue(inst.ccy in CCYS) self.assertTrue(inst.type in types) - + def testSimpleExcludeFilter(self): session = self.session() - qs = session.query(self.model).exclude(ccy='JPY') + qs = session.query(self.model).exclude(ccy="JPY") all = yield qs.all() self.assertTrue(all) for inst in all: - self.assertNotEqual(inst.ccy, 'JPY') - + self.assertNotEqual(inst.ccy, "JPY") + def testExcludeFilterIn(self): - CCYS = ('EUR','GBP','JPY') + CCYS = ("EUR", "GBP", "JPY") session = self.session() A = yield session.query(self.model).filter(ccy=CCYS).all() B = yield session.query(self.model).exclude(ccy=CCYS).all() for inst in B: self.assertTrue(inst.ccy not in CCYS) - all = dict(((o.id,o) for o in A)) - all.update(dict(((o.id,o) for o in B))) + all = dict(((o.id, o) for o in A)) + all.update(dict(((o.id, o) for o in B))) N = yield session.query(self.model).count() self.assertEqual(len(all), N) - + def testDoubleExclude(self): - CCYS = ('EUR','GBP','JPY') - types = ('equity','bond','future') + CCYS = ("EUR", "GBP", "JPY") + types = ("equity", "bond", "future") session = self.session() qs = session.query(self.model).exclude(ccy=CCYS, type=types) all = yield qs.all() @@ -117,10 +116,10 @@ def testDoubleExclude(self): for inst in all: self.assertTrue(inst.ccy not in CCYS) self.assertTrue(inst.type not in types) - + def testExcludeAndFilter(self): - CCYS = ('EUR','GBP') - types = ('equity','bond','future') + CCYS = ("EUR", "GBP") + types = ("equity", "bond", "future") session = self.session() query = session.query(self.model) qs = query.exclude(ccy=CCYS).filter(type=types) @@ -129,21 +128,21 @@ def testExcludeAndFilter(self): for inst in all: self.assertTrue(inst.ccy not in CCYS) self.assertTrue(inst.type in types) - + def testFilterIds(self): - '''Simple filtering on ids.''' + """Simple filtering on ids.""" session = self.session() - all = yield session.query(self.model).load_only('id').all() + all = yield session.query(self.model).load_only("id").all() ids = set((all[1].id, all[5].id, all[10].id)) query = session.query(self.model) qs = yield query.filter(id=ids).all() self.assertEqual(len(qs), 3) cids = set((o.id for o in qs)) self.assertEqual(cids, ids) - + def testFilterIdExclude(self): - CCYS = ('EUR','GBP') - types = ('equity','bond','future') + CCYS = ("EUR", "GBP") + types = ("equity", "bond", "future") session = self.session() query = session.query(self.model) qs = yield query.filter(type__in=types).all() @@ -160,43 +159,43 @@ def testFilterIdExclude(self): qs1 = yield query.filter(ccy__in=CCYS).exclude(type__in=types).all() qs2 = yield query.filter(ccy__in=CCYS).exclude(id__in=qt).all() self.assertEqual(set(qs1), set(qs2)) - + def testChangeFilter(self): - '''Change the value of a filter field and perform filtering to - check for zero values''' + """Change the value of a filter field and perform filtering to + check for zero values""" session = self.session() query = session.query(self.model) - qs = query.filter(ccy='AUD') + qs = query.filter(ccy="AUD") all = yield qs.all() self.assertTrue(all) with session.begin() as t: for inst in all: - self.assertEqual(inst.ccy, 'AUD') - inst.ccy = 'USD' + self.assertEqual(inst.ccy, "AUD") + inst.ccy = "USD" t.add(inst) yield t.on_result - N = yield query.filter(ccy='AUD').count() + N = yield query.filter(ccy="AUD").count() self.assertEqual(N, 0) - + def testFilterWithSpace(self): session = 
self.session() - qs = session.query(self.model).filter(type='bond option') + qs = session.query(self.model).filter(type="bond option") all = yield qs.all() self.assertTrue(all) for inst in all: - self.assertEqual(inst.type,'bond option') + self.assertEqual(inst.type, "bond option") def testChainedExclude(self): session = self.session() query = session.query(self.model) - qt = query.exclude(id=(1,2,3,4)).exclude(id=(4,5,6)) - self.assertEqual(qt.eargs, {'id__in': set((1,2,3,4,5,6))}) + qt = query.exclude(id=(1, 2, 3, 4)).exclude(id=(4, 5, 6)) + self.assertEqual(qt.eargs, {"id__in": set((1, 2, 3, 4, 5, 6))}) qt = yield qt.all() res = set((q.id for q in qt)) self.assertTrue(res) - self.assertFalse(res.intersection(set((1,2,3,4,5,6)))) + self.assertFalse(res.intersection(set((1, 2, 3, 4, 5, 6)))) qt = query.exclude(id=3).exclude(id=4) - self.assertEqual(qt.eargs, {'id__in': set((3,4))}) + self.assertEqual(qt.eargs, {"id__in": set((3, 4))}) qt = yield qt.all() res = set((q.id for q in qt)) self.assertTrue(res) @@ -209,11 +208,11 @@ def testChainedExclude(self): class TestFilterOrdered(TestFilter): models = (Instrument2, Fund, Position) - + def test_instrument2(self): instrument = self.mapper.instrument self.assertEqual(instrument.model, Instrument2) - self.assertEqual(instrument._meta.app_label, 'examples2') - self.assertEqual(instrument._meta.name, 'instrument') - self.assertEqual(instrument._meta.modelkey, 'examples2.instrument') - self.assertEqual(instrument._meta.ordering.name, 'id') \ No newline at end of file + self.assertEqual(instrument._meta.app_label, "examples2") + self.assertEqual(instrument._meta.name, "instrument") + self.assertEqual(instrument._meta.modelkey, "examples2.instrument") + self.assertEqual(instrument._meta.ordering.name, "id") diff --git a/tests/all/query/load_only.py b/tests/all/query/load_only.py index 5f56c2f..1dbe258 100644 --- a/tests/all/query/load_only.py +++ b/tests/all/query/load_only.py @@ -1,252 +1,246 @@ -'''test load_only and dont_load methods''' -from stdnet.utils import test, zip +"""test load_only and dont_load methods""" +from examples.models import Group, Person, SimpleModel, Statistics3 -from examples.models import SimpleModel, Person, Group, Statistics3 +from stdnet.utils import test, zip class LoadOnlyBase(test.TestCase): model = SimpleModel - + @classmethod def after_setup(cls): with cls.session().begin() as t: - t.add(cls.model(code='a', group='group1', description='blabla')) - t.add(cls.model(code='b', group='group2', description='blabla')) - t.add(cls.model(code='c', group='group1', description='blabla')) - t.add(cls.model(code='d', group='group3', description='blabla')) - t.add(cls.model(code='e', group='group1', description='blabla')) + t.add(cls.model(code="a", group="group1", description="blabla")) + t.add(cls.model(code="b", group="group2", description="blabla")) + t.add(cls.model(code="c", group="group1", description="blabla")) + t.add(cls.model(code="d", group="group3", description="blabla")) + t.add(cls.model(code="e", group="group1", description="blabla")) return t.on_result - - + + class LoadOnly(LoadOnlyBase): - def testMeta(self): s = self.session() query = s.query(self.model) - qs = yield query.load_only('id').all() + qs = yield query.load_only("id").all() for m in qs: - self.assertEqual(m._loadedfields,()) + self.assertEqual(m._loadedfields, ()) self.assertEqual(m.has_all_data, False) - qs = yield query.load_only('code','group').all() + qs = yield query.load_only("code", "group").all() for m in qs: - 
self.assertEqual(m._loadedfields,('code','group')) + self.assertEqual(m._loadedfields, ("code", "group")) self.assertEqual(m.has_all_data, False) all = yield query.all() for m in all: self.assertEqual(m._loadedfields, None) self.assertEqual(m.has_all_data, True) - m = self.model(code = 'bla', group = 'foo') + m = self.model(code="bla", group="foo") self.assertEqual(m.has_all_data, False) - + def test_idonly(self): s = self.session() query = s.query(self.model) - qs = query.load_only('id') + qs = query.load_only("id") self.assertNotEqual(query, qs) - self.assertEqual(qs.fields, ('id',)) + self.assertEqual(qs.fields, ("id",)) qs = yield qs.all() self.assertTrue(all) for m in qs: - self.assertEqual(m._loadedfields,()) - self.assertEqual(tuple(m.loadedfields()),()) - self.assertFalse(hasattr(m,'code')) - self.assertFalse(hasattr(m,'group')) - self.assertFalse(hasattr(m,'description')) - self.assertTrue('id' in m._dbdata) - self.assertEqual(m._dbdata['id'], m.id) - + self.assertEqual(m._loadedfields, ()) + self.assertEqual(tuple(m.loadedfields()), ()) + self.assertFalse(hasattr(m, "code")) + self.assertFalse(hasattr(m, "group")) + self.assertFalse(hasattr(m, "description")) + self.assertTrue("id" in m._dbdata) + self.assertEqual(m._dbdata["id"], m.id) + def test_idonly_None(self): s = self.session() query = s.query(self.model) - qs = yield query.load_only('id').all() + qs = yield query.load_only("id").all() with s.begin(): for m in qs: - self.assertFalse(hasattr(m, 'description')) + self.assertFalse(hasattr(m, "description")) m.description = None s.add(m) # Check that description are empty - qs = yield query.load_only('description').all() + qs = yield query.load_only("description").all() self.assertTrue(qs) for m in qs: self.assertFalse(m.description) - + def testSimple(self): query = self.session().query(self.model) - qs = yield query.load_only('code').all() + qs = yield query.load_only("code").all() self.assertTrue(qs) for m in qs: - self.assertEqual(m._loadedfields,('code',)) + self.assertEqual(m._loadedfields, ("code",)) self.assertTrue(m.code) - self.assertFalse(hasattr(m,'group')) - self.assertFalse(hasattr(m,'description')) - qs = yield query.load_only('code','group').all() + self.assertFalse(hasattr(m, "group")) + self.assertFalse(hasattr(m, "description")) + qs = yield query.load_only("code", "group").all() self.assertTrue(qs) for m in qs: - self.assertEqual(m._loadedfields,('code','group')) + self.assertEqual(m._loadedfields, ("code", "group")) self.assertTrue(m.code) self.assertTrue(m.group) - self.assertFalse(hasattr(m,'description')) - + self.assertFalse(hasattr(m, "description")) + def testSave(self): session = self.session() query = session.query(self.model) - qs = yield query.load_only('group').all() + qs = yield query.load_only("group").all() original = dict(((m.id, m.group) for m in qs)) - yield self.async.assertEqual(query.filter(group='group1').count(), 3) + yield self.async.assertEqual(query.filter(group="group1").count(), 3) # save the models - qs = yield query.load_only('code').all() + qs = yield query.load_only("code").all() with session.begin() as t: for m in qs: t.add(m) yield t.on_result - qs = yield query.load_only('group').all() + qs = yield query.load_only("group").all() for m in qs: self.assertEqual(m.group, original[m.id]) # No check indexes - yield self.async.assertEqual(query.filter(group='group1').count(), 3) + yield self.async.assertEqual(query.filter(group="group1").count(), 3) def test_exclude_fields(self): session = self.session() - query = 
session.query(self.model).dont_load('description') - self.assertEqual(query.exclude_fields, ('description',)) + query = session.query(self.model).dont_load("description") + self.assertEqual(query.exclude_fields, ("description",)) qs = yield query.all() self.assertTrue(qs) for m in qs: - self.assertFalse(hasattr(m,'description')) - query = session.query(self.model).load_only('group')\ - .dont_load('description') + self.assertFalse(hasattr(m, "description")) + query = session.query(self.model).load_only("group").dont_load("description") qs = yield query.all() self.assertTrue(qs) for m in qs: - self.assertEqual(m._loadedfields,('group',)) - - -class LoadOnlyChange(LoadOnlyBase): + self.assertEqual(m._loadedfields, ("group",)) + +class LoadOnlyChange(LoadOnlyBase): def testChangeNotLoaded(self): - '''We load an object with only one field and modify a field not -loaded. The correct behavior should be to updated the field and indexes.''' + """We load an object with only one field and modify a field not + loaded. The correct behavior should be to update the field and indexes.""" session = self.session() query = session.query(self.model) - qs = yield query.load_only('group').all() + qs = yield query.load_only("group").all() original = dict(((m.id, m.group) for m in qs)) # load only code and change the group - qs = yield query.load_only('code').all() + qs = yield query.load_only("code").all() self.assertTrue(qs) with session.begin() as t: for m in qs: - m.group = 'group4' + m.group = "group4" t.add(m) yield t.on_result - qs = query.filter(group='group1') + qs = query.filter(group="group1") yield self.async.assertEqual(qs.count(), 0) - qs = query.filter(group='group2') + qs = query.filter(group="group2") yield self.async.assertEqual(qs.count(), 0) - qs = query.filter(group='group3') + qs = query.filter(group="group3") yield self.async.assertEqual(qs.count(), 0) - qs = query.filter(group='group4') + qs = query.filter(group="group4") yield self.async.assertEqual(qs.count(), 5) qs = yield qs.all() self.assertTrue(qs) for m in qs: - self.assertEqual(m.group,'group4') - - + self.assertEqual(m.group, "group4") + + class LoadOnlyDelete(LoadOnlyBase): - def test_idonly_delete(self): query = self.session().query(self.model) - yield query.load_only('id').delete() - qs = query.filter(group='group1') + yield query.load_only("id").delete() + qs = query.filter(group="group1") yield self.async.assertEqual(qs.count(), 0) qs = yield query.all() self.assertEqual(qs, []) - - + + class LoadOnlyRelated(test.TestCase): models = (Person, Group) - + @classmethod def after_setup(cls): with cls.session().begin() as t: - g1 = t.add(Group(name='bla', description='bla bla')) - g2 = t.add(Group(name='foo', description='foo foo')) + g1 = t.add(Group(name="bla", description="bla bla")) + g2 = t.add(Group(name="foo", description="foo foo")) yield t.on_result with cls.session().begin() as t: - t.add(Person(name='luca', group=g1)) - t.add(Person(name='carl', group=g1)) - t.add(Person(name='bob', group=g1)) + t.add(Person(name="luca", group=g1)) + t.add(Person(name="carl", group=g1)) + t.add(Person(name="bob", group=g1)) yield t.on_result - + def test_simple(self): session = self.session() query = session.query(Person) - qs = yield query.load_only('group').all() + qs = yield query.load_only("group").all() for m in qs: - self.assertEqual(m._loadedfields,('group',)) - self.assertFalse(hasattr(m,'name')) - self.assertTrue(hasattr(m,'group_id')) + self.assertEqual(m._loadedfields, ("group",)) + self.assertFalse(hasattr(m, "name")) + 
self.assertTrue(hasattr(m, "group_id")) self.assertTrue(m.group_id) - self.assertTrue('id' in m._dbdata) - self.assertEqual(m._dbdata['id'],m.id) + self.assertTrue("id" in m._dbdata) + self.assertEqual(m._dbdata["id"], m.id) g = yield m.group self.assertTrue(isinstance(g, Group)) - + def testLoadForeignKeyFields(self): session = self.session() - qs = yield session.query(Person).load_only('group__name').all() - group = Person._meta.dfields['group'] + qs = yield session.query(Person).load_only("group__name").all() + group = Person._meta.dfields["group"] for m in qs: - self.assertEqual(m._loadedfields, ('group',)) - self.assertFalse(hasattr(m, 'name')) - self.assertTrue(hasattr(m, 'group_id')) + self.assertEqual(m._loadedfields, ("group",)) + self.assertFalse(hasattr(m, "name")) + self.assertTrue(hasattr(m, "group_id")) cache_name = group.get_cache_name() g = getattr(m, cache_name, None) self.assertTrue(g) self.assertTrue(isinstance(g, group.relmodel)) # And now check what is loaded with g - self.assertEqual(g._loadedfields, ('name',)) - self.assertFalse(hasattr(g, 'description')) - + self.assertEqual(g._loadedfields, ("name",)) + self.assertFalse(hasattr(g, "description")) + class TestFieldReplace(test.TestCase): model = Statistics3 - + @classmethod def after_setup(cls): with cls.session().begin() as t: - t.add(cls.model(name='a', - data={'pv': {'': 0.5, 'mean': 1, 'std': 3.5}})) + t.add(cls.model(name="a", data={"pv": {"": 0.5, "mean": 1, "std": 3.5}})) return t.on_result - + def test_load_only(self): session = self.session() query = session.query(self.model) - s = yield query.load_only('name', 'data__pv').get(name='a') - self.assertEqual(s.name, 'a') + s = yield query.load_only("name", "data__pv").get(name="a") + self.assertEqual(s.name, "a") self.assertEqual(s.data__pv, 0.5) self.assertFalse(s.has_all_data) - self.assertEqual(s.get_state().action, 'update') + self.assertEqual(s.get_state().action, "update") # Now set extra data - s.data = {'pv': {'mean': 2}} + s.data = {"pv": {"mean": 2}} with session.begin() as t: t.add(s) yield t.on_result - s = yield query.get(name='a') + s = yield query.get(name="a") self.assertTrue(s.has_all_data) - self.assertEqual(s.data, {'pv': {'': 0.5, 'mean': 2, 'std': 3.5}}) - + self.assertEqual(s.data, {"pv": {"": 0.5, "mean": 2, "std": 3.5}}) + def test_replace(self): session = self.session() query = session.query(self.model) - s = yield query.get(name='a') + s = yield query.get(name="a") self.assertTrue(s.has_all_data) - self.assertEqual(s.get_state().action, 'override') - s.data = {'bla': {'foo': -1}} + self.assertEqual(s.get_state().action, "override") + s.data = {"bla": {"foo": -1}} with session.begin() as t: t.add(s) yield t.on_result - s = yield self.query().get(name='a') + s = yield self.query().get(name="a") self.assertTrue(s.has_all_data) - self.assertEqual(s.data, {'bla': {'foo': -1}}) - \ No newline at end of file + self.assertEqual(s.data, {"bla": {"foo": -1}}) diff --git a/tests/all/query/load_related.py b/tests/all/query/load_related.py index 949176b..8a39f2b 100644 --- a/tests/all/query/load_related.py +++ b/tests/all/query/load_related.py @@ -1,8 +1,8 @@ -from stdnet import odm, FieldError -from stdnet.utils import test - +from examples.data import FinanceTest, Fund, Instrument, Position from examples.models import Dictionary, Profile -from examples.data import FinanceTest, Position, Instrument, Fund + +from stdnet import FieldError, odm +from stdnet.utils import test class Role(odm.StdModel): @@ -10,52 +10,51 @@ class Role(odm.StdModel): 
class test_load_related(FinanceTest): - @classmethod def after_setup(cls): yield cls.data.makePositions(cls) - + def testMeta(self): session = self.session() query = session.query(Position) self.assertEqual(query.select_related, None) - pos1 = query.load_related('instrument') + pos1 = query.load_related("instrument") self.assertEqual(len(pos1.select_related), 1) - self.assertEqual(pos1.select_related['instrument'], ()) - pos2 = pos1.load_related('instrument', 'name', 'ccy') - self.assertEqual(pos1.select_related['instrument'], ()) - self.assertEqual(pos2.select_related['instrument'], ('name','ccy')) - pos3 = pos2.load_related('fund','name') - self.assertEqual(len(pos1.select_related),1) - self.assertEqual(len(pos2.select_related),1) - self.assertEqual(len(pos3.select_related),2) - self.assertEqual(pos1.select_related['instrument'], ()) - self.assertEqual(pos2.select_related['instrument'], ('name','ccy')) - self.assertEqual(pos3.select_related['instrument'], ('name','ccy')) - self.assertEqual(pos3.select_related['fund'], ('name',)) + self.assertEqual(pos1.select_related["instrument"], ()) + pos2 = pos1.load_related("instrument", "name", "ccy") + self.assertEqual(pos1.select_related["instrument"], ()) + self.assertEqual(pos2.select_related["instrument"], ("name", "ccy")) + pos3 = pos2.load_related("fund", "name") + self.assertEqual(len(pos1.select_related), 1) + self.assertEqual(len(pos2.select_related), 1) + self.assertEqual(len(pos3.select_related), 2) + self.assertEqual(pos1.select_related["instrument"], ()) + self.assertEqual(pos2.select_related["instrument"], ("name", "ccy")) + self.assertEqual(pos3.select_related["instrument"], ("name", "ccy")) + self.assertEqual(pos3.select_related["fund"], ("name",)) def testSingle(self): session = self.session() query = session.query(Position) - pos = query.load_related('instrument') - fund = Position._meta.dfields['fund'] - inst = Position._meta.dfields['instrument'] + pos = query.load_related("instrument") + fund = Position._meta.dfields["fund"] + inst = Position._meta.dfields["instrument"] pos = yield pos.all() self.assertTrue(pos) for p in pos: cache = inst.get_cache_name() - val = getattr(p,cache,None) + val = getattr(p, cache, None) self.assertTrue(val) - self.assertTrue(isinstance(val,inst.relmodel)) + self.assertTrue(isinstance(val, inst.relmodel)) cache = fund.get_cache_name() - val = getattr(p,cache,None) + val = getattr(p, cache, None) self.assertFalse(val) def test_single_with_fields(self): session = self.session() query = session.query(Position) - pos = query.load_related('instrument', 'name', 'ccy') - inst = Position._meta.dfields['instrument'] + pos = query.load_related("instrument", "name", "ccy") + inst = Position._meta.dfields["instrument"] pos = yield pos.all() self.assertTrue(pos) for p in pos: @@ -63,17 +62,17 @@ def test_single_with_fields(self): val = getattr(p, cache, None) self.assertTrue(val) self.assertTrue(isinstance(val, inst.relmodel)) - self.assertEqual(set(val._loadedfields),set(('name','ccy'))) + self.assertEqual(set(val._loadedfields), set(("name", "ccy"))) self.assertTrue(val.name) self.assertTrue(val.ccy) - self.assertFalse(hasattr(val,'type')) - + self.assertFalse(hasattr(val, "type")) + def test_with_id_only(self): - '''Test load realated when loading only the id''' + """Test load_related when loading only the id""" session = self.session() query = session.query(Position) - pos = query.load_related('instrument', 'id') - inst = Position._meta.dfields['instrument'] + pos = query.load_related("instrument", "id") + 
inst = Position._meta.dfields["instrument"] pos = yield pos.all() self.assertTrue(pos) for p in pos: @@ -81,24 +80,23 @@ def test_with_id_only(self): val = getattr(p, cache, None) self.assertTrue(val) self.assertTrue(isinstance(val, inst.relmodel)) - self.assertFalse(hasattr(val,'name')) - self.assertFalse(hasattr(val,'ccy')) - self.assertFalse(hasattr(val,'type')) + self.assertFalse(hasattr(val, "name")) + self.assertFalse(hasattr(val, "ccy")) + self.assertFalse(hasattr(val, "type")) self.assertEqual(set(val._loadedfields), set()) def testDouble(self): session = self.session() - pos = session.query(Position).load_related('instrument')\ - .load_related('fund') - fund = Position._meta.dfields['fund'] - inst = Position._meta.dfields['instrument'] + pos = session.query(Position).load_related("instrument").load_related("fund") + fund = Position._meta.dfields["fund"] + inst = Position._meta.dfields["instrument"] pos = yield pos.all() self.assertTrue(pos) for p in pos: cache = inst.get_cache_name() - val = getattr(p,cache,None) + val = getattr(p, cache, None) self.assertTrue(val) - self.assertTrue(isinstance(val,inst.relmodel)) + self.assertTrue(isinstance(val, inst.relmodel)) cache = fund.get_cache_name() val = getattr(p, cache, None) self.assertTrue(val) @@ -107,23 +105,22 @@ def testDouble(self): def testError(self): session = self.session() query = session.query(Position) - pos = self.assertRaises(FieldError, query.load_related, 'bla') - pos = self.assertRaises(FieldError, query.load_related, 'size') - pos = query.load_related('instrument', 'id') + pos = self.assertRaises(FieldError, query.load_related, "bla") + pos = self.assertRaises(FieldError, query.load_related, "size") + pos = query.load_related("instrument", "id") self.assertEqual(len(pos.select_related), 1) - self.assertEqual(pos.select_related['instrument'], ('id',)) + self.assertEqual(pos.select_related["instrument"], ("id",)) def testLoadRelatedLoadOnly(self): session = self.session() query = session.query(Position) - inst = Position._meta.dfields['instrument'] - qs = query.load_only('dt','size').load_related('instrument') - self.assertEqual(qs.fields, ('dt','size')) + inst = Position._meta.dfields["instrument"] + qs = query.load_only("dt", "size").load_related("instrument") + self.assertEqual(qs.fields, ("dt", "size")) qs = yield qs.all() self.assertTrue(qs) for p in qs: - self.assertEqual(set(p._loadedfields), - set(('dt','instrument','size'))) + self.assertEqual(set(p._loadedfields), set(("dt", "instrument", "size"))) cache = inst.get_cache_name() val = getattr(p, cache, None) self.assertTrue(val) @@ -131,10 +128,13 @@ def testLoadRelatedLoadOnly(self): def test_with_filter(self): session = self.session() - instruments = session.query(Instrument).filter(ccy='EUR') - qs = session.query(Position).filter(instrument=instruments)\ - .load_related('instrument') - inst = Position._meta.dfields['instrument'] + instruments = session.query(Instrument).filter(ccy="EUR") + qs = ( + session.query(Position) + .filter(instrument=instruments) + .load_related("instrument") + ) + inst = Position._meta.dfields["instrument"] qs = yield qs.all() self.assertTrue(qs) for p in qs: @@ -142,68 +142,63 @@ def test_with_filter(self): val = getattr(p, cache, None) self.assertTrue(val) self.assertTrue(isinstance(val, inst.relmodel)) - self.assertEqual(p.instrument.ccy, 'EUR') + self.assertEqual(p.instrument.ccy, "EUR") class test_load_related_empty(test.TestCase): models = (Role, Profile, Position, Instrument, Fund) - + @classmethod def after_setup(cls): 
with cls.session().begin() as t: - p1 = t.add(Profile(name='k1')) - p2 = t.add(Profile(name='k2')) - p3 = t.add(Profile(name='k3')) + p1 = t.add(Profile(name="k1")) + p2 = t.add(Profile(name="k2")) + p3 = t.add(Profile(name="k3")) yield t.on_result with cls.session().begin() as t: t.add(Role(profile=p1)) t.add(Role(profile=p1)) t.add(Role(profile=p3)) yield t.on_result - + def testEmpty(self): models = self.mapper - insts = yield models.position.query().load_related('instrument').all() + insts = yield models.position.query().load_related("instrument").all() self.assertEqual(insts, []) - + def test_related_no_fields(self): - qs = self.query().load_related('profile') + qs = self.query().load_related("profile") query = yield qs.all() profiles = set((role.profile for role in query)) self.assertEqual(len(profiles), 2) - - + + class load_related_structure(test.TestCase): model = Dictionary - + @classmethod def after_setup(cls): with cls.session().begin() as t: - d1 = t.add(Dictionary(name='english-italian')) - d2 = t.add(Dictionary(name='italian-english')) + d1 = t.add(Dictionary(name="english-italian")) + d2 = t.add(Dictionary(name="italian-english")) yield t.on_result with cls.session().begin() as t: - d1.data.update((('ball','palla'), - ('boat','nave'), - ('cat','gatto'))) - d2.data.update((('palla','ball'), - ('nave','boat'), - ('gatto','cat'))) + d1.data.update((("ball", "palla"), ("boat", "nave"), ("cat", "gatto"))) + d2.data.update((("palla", "ball"), ("nave", "boat"), ("gatto", "cat"))) yield t.on_result def test_hash(self): session = self.session() query = session.query(Dictionary) # Check if data is there first - d = yield query.get(name = 'english-italian') + d = yield query.get(name="english-italian") data = yield d.data.items() remote = dict(data) self.assertEqual(len(remote), 3) # - d = yield query.load_related('data').get(name='english-italian') + d = yield query.load_related("data").get(name="english-italian") data = d.data # the cache should be available cache = data.cache.cache self.assertEqual(len(cache), 3) self.assertEqual(cache, remote) - diff --git a/tests/all/query/manager.py b/tests/all/query/manager.py index 7c7d87d..5ae7066 100755 --- a/tests/all/query/manager.py +++ b/tests/all/query/manager.py @@ -1,13 +1,12 @@ import random +from examples.models import SimpleModel + import stdnet from stdnet.utils import test -from examples.models import SimpleModel - class StringData(test.DataGenerator): - def generate(self): self.names = self.populate() @@ -15,7 +14,7 @@ def generate(self): class TestManager(test.TestCase): model = SimpleModel data_cls = StringData - + @classmethod def after_setup(cls): manager = cls.mapper.simplemodel @@ -23,71 +22,69 @@ def after_setup(cls): for name in cls.data.names: t.add(manager(code=name)) yield t.on_result - + def test_manager(self): models = self.mapper self.assertEqual(models[SimpleModel], models.simplemodel) self.assertEqual(models.simplemodel.model, self.model) - + def testGetOrCreate(self): objects = self.mapper[SimpleModel] - v, created = yield objects.update_or_create(code='test') + v, created = yield objects.update_or_create(code="test") self.assertTrue(created) - self.assertEqual(v.code,'test') - v2, created = yield objects.update_or_create(code='test') + self.assertEqual(v.code, "test") + v2, created = yield objects.update_or_create(code="test") self.assertFalse(created) - self.assertEqual(v,v2) - + self.assertEqual(v, v2) + def test_get(self): objects = self.mapper[SimpleModel] - v, created = yield 
objects.update_or_create(code='test2')
+        v, created = yield objects.update_or_create(code="test2")
         self.assertTrue(created)
-        v1 = yield objects.get(code='test2')
+        v1 = yield objects.get(code="test2")
         self.assertEqual(v1, v)
-
+
     def test_get_error(self):
-        '''Test for a ObjectNotFound exception.'''
+        """Test for an ObjectNotFound exception."""
         objects = self.mapper[SimpleModel]
-        yield self.async.assertRaises(SimpleModel.DoesNotExist,
-                                      objects.get, code='test3')
-        yield self.async.assertRaises(SimpleModel.DoesNotExist,
-                                      objects.get, id=-400)
-
+        yield self.async.assertRaises(
+            SimpleModel.DoesNotExist, objects.get, code="test3"
+        )
+        yield self.async.assertRaises(SimpleModel.DoesNotExist, objects.get, id=-400)
+
     def testEmptyIDFilter(self):
         objects = self.mapper[SimpleModel]
         yield self.async.assertEqual(objects.filter(id=-1).count(), 0)
         yield self.async.assertEqual(objects.filter(id=1).count(), 1)
         yield self.async.assertEqual(objects.filter(id=2).count(), 1)
-
+
     def testUniqueFilter(self):
         objects = self.mapper[SimpleModel]
-        yield self.async.assertEqual(objects.filter(code='test4').count(), 0)
-        yield objects.update_or_create(code='test4')
-        yield self.async.assertEqual(objects.filter(code='test4').count(), 1)
-        yield self.async.assertEqual(objects.filter(code='foo').count(), 0)
-
+        yield self.async.assertEqual(objects.filter(code="test4").count(), 0)
+        yield objects.update_or_create(code="test4")
+        yield self.async.assertEqual(objects.filter(code="test4").count(), 1)
+        yield self.async.assertEqual(objects.filter(code="foo").count(), 0)
+
     def testIndexFilter(self):
         objects = self.mapper.simplemodel
-        yield self.async.assertEqual(objects.filter(group='g1').count(), 0)
-        v, created = yield objects.update_or_create(code='test5', group='g2')
-        yield self.async.assertEqual(objects.filter(group='g1').count(), 0)
-        yield self.async.assertEqual(objects.filter(group='g2').count(), 1)
-        v1 = yield objects.get(group='g2')
+        yield self.async.assertEqual(objects.filter(group="g1").count(), 0)
+        v, created = yield objects.update_or_create(code="test5", group="g2")
+        yield self.async.assertEqual(objects.filter(group="g1").count(), 0)
+        yield self.async.assertEqual(objects.filter(group="g2").count(), 1)
+        v1 = yield objects.get(group="g2")
         self.assertEqual(v, v1)
-        yield self.async.assertRaises(SimpleModel.DoesNotExist,
-                                      objects.get, group='g1')
-        v2, created = yield objects.update_or_create(code='test6', group='g2')
-        yield self.async.assertEqual(objects.filter(group='g2').count(), 2)
-        yield self.async.assertRaises(stdnet.QuerySetError,
-                                      objects.get, group='g2')
-
+        yield self.async.assertRaises(SimpleModel.DoesNotExist, objects.get, group="g1")
+        v2, created = yield objects.update_or_create(code="test6", group="g2")
+        yield self.async.assertEqual(objects.filter(group="g2").count(), 2)
+        yield self.async.assertRaises(stdnet.QuerySetError, objects.get, group="g2")
+
     def testNoFilter(self):
         objects = self.mapper[SimpleModel]
-        filter1 = lambda : objects.filter(description = 'bo').count()
+        filter1 = lambda: objects.filter(description="bo").count()
         yield self.async.assertRaises(stdnet.QuerySetError, filter1)
-
+
     def testContainsAll(self):
-        '''Test filter when performing a all request'''
+        """Test filter when performing an all request"""
         objects = self.mapper[SimpleModel]
         qs = objects.query()
         all = yield qs.all()
@@ -96,7 +93,7 @@ def testContainsAll(self):
         self.assertTrue(1 in qs)
         be = qs.backend_query()
         self.assertEqual(be.cache[None], all)
-
+
     def test_pkvalue(self):
         models = self.mapper
         all = 
yield models.simplemodel.all() diff --git a/tests/all/query/manytomany.py b/tests/all/query/manytomany.py index 5382cb7..128f57a 100644 --- a/tests/all/query/manytomany.py +++ b/tests/all/query/manytomany.py @@ -1,20 +1,20 @@ -from stdnet import odm, ManyToManyError +from examples.m2m import Composite, CompositeElement, Element +from examples.models import Profile, Role + +from stdnet import ManyToManyError, odm from stdnet.utils import test -from examples.models import Role, Profile -from examples.m2m import Composite, Element, CompositeElement - class TestManyToManyBase(object): models = (Role, Profile) - - def addsome(self, role1='admin', role2='coder'): + + def addsome(self, role1="admin", role2="coder"): models = self.mapper session = models.session() with session.begin() as t: - profile = t.add(models.profile(name='p1')) - profile2 = t.add(models.profile(name='p2')) - profile3 = t.add(models.profile(name='p3')) + profile = t.add(models.profile(name="p1")) + profile2 = t.add(models.profile(name="p2")) + profile3 = t.add(models.profile(name="p3")) role1 = t.add(models.role(name=role1)) role2 = t.add(models.role(name=role2)) yield t.on_result @@ -24,9 +24,9 @@ def addsome(self, role1='admin', role2='coder'): yield t.on_result self.assertEqual(len(t.saved), 1) self.assertEqual(len(list(t.saved.values())[0]), 2) - # Check role - t1 = yield role1.profiles.throughquery().load_related('role').all() - t2 = yield role2.profiles.throughquery().load_related('role').all() + # Check role + t1 = yield role1.profiles.throughquery().load_related("role").all() + t2 = yield role2.profiles.throughquery().load_related("role").all() self.assertEqual(len(t1), 1) self.assertEqual(len(t2), 1) self.assertEqual(t1[0].role, role1) @@ -47,62 +47,61 @@ def addsome(self, role1='admin', role2='coder'): self.assertEqual(p2, profile) # # Check with load_only - t1 = yield profile.roles.throughquery().load_related('profile').all() + t1 = yield profile.roles.throughquery().load_related("profile").all() self.assertEqual(len(t1), 2) self.assertEqual(t1[0].profile, profile) self.assertEqual(t1[1].profile, profile) # r = yield profile.roles.query().all() self.assertEqual(len(r), 2) - self.assertEqual(set(r), set((role1,role2))) + self.assertEqual(set(r), set((role1, role2))) yield role1, role2 - + class TestManyToMany(TestManyToManyBase, test.TestCase): - def test_meta(self): - self.assertEqual(Profile._meta.manytomany, ['roles']) + self.assertEqual(Profile._meta.manytomany, ["roles"]) roles = Profile.roles - self.assertEqual(roles.model._meta.name, 'profile_role') + self.assertEqual(roles.model._meta.name, "profile_role") self.assertEqual(roles.relmodel, Profile) - self.assertEqual(roles.name_relmodel, 'profile') + self.assertEqual(roles.name_relmodel, "profile") self.assertEqual(roles.formodel, Role) profiles = Role.profiles - self.assertEqual(profiles.model._meta.name, 'profile_role') + self.assertEqual(profiles.model._meta.name, "profile_role") self.assertEqual(profiles.relmodel, Role) self.assertEqual(profiles.formodel, Profile) - self.assertEqual(profiles.name_relmodel, 'role') + self.assertEqual(profiles.name_relmodel, "role") # through = roles.model self.assertEqual(through, profiles.model) self.assertEqual(len(through._meta.dfields), 3) - + def test_meta_instance(self): p = Profile() self.assertEqual(p.roles.formodel, Role) self.assertEqual(p.roles.related_instance, p) - yield self.addsome('admin', 'coder') - role = yield self.query(Role).get(name='admin') + yield self.addsome("admin", "coder") + role = yield 
self.query(Role).get(name="admin") self.assertEqual(role.profiles.formodel, Profile) self.assertEqual(role.profiles.related_instance, role) - + def testQuery(self): - yield self.addsome('bla', 'foo') - role = yield self.query(Role).get(name='bla') + yield self.addsome("bla", "foo") + role = yield self.query(Role).get(name="bla") profiles = role.profiles.query() self.assertEqual(profiles.model, Profile) self.assertEqual(profiles.session, role.session) - + def test_throughquery(self): - yield self.addsome('bla2', 'foo2') - role = yield self.query(Role).get(name='bla2') + yield self.addsome("bla2", "foo2") + role = yield self.query(Role).get(name="bla2") query = role.profiles.throughquery() self.assertEqual(query.model, role.profiles.model) self.assertEqual(query.session, role.session) - + def test_multiple_add(self): - yield self.addsome('bla3', 'foo3') - role = yield self.query(Role).get(name='bla3') + yield self.addsome("bla3", "foo3") + role = yield self.query(Role).get(name="bla3") profiles = yield role.profiles.query().all() self.assertEqual(len(profiles), 1) # lets add it again @@ -111,25 +110,24 @@ def test_multiple_add(self): profiles = yield role.profiles.query().all() self.assertEqual(len(profiles), 1) self.assertEqual(profile, profiles[0]) - - + + class TestManyToManyAddDelete(TestManyToManyBase, test.TestWrite): - def testAdd(self): return self.addsome() - + def testDelete1(self): - role1, role2 = yield self.addsome('bla', 'foo') + role1, role2 = yield self.addsome("bla", "foo") session = self.session() profiles = yield role1.profiles.query().all() self.assertEqual(len(profiles), 1) profile = profiles[0] yield self.async.assertEqual(profile.roles.query().count(), 2) yield session.delete(profile) - role1, role2 = yield session.query(Role).filter(name=('bla','foo')).all() + role1, role2 = yield session.query(Role).filter(name=("bla", "foo")).all() yield self.async.assertEqual(role1.profiles.query().count(), 0) yield self.async.assertEqual(role2.profiles.query().count(), 0) - + def testDelete2(self): yield self.addsome() session = self.session() @@ -140,16 +138,16 @@ def testDelete2(self): profile = yield session.query(Profile).get(id=1) yield self.async.assertEqual(profile.roles.query().count(), 0) yield session.delete(profile) - + def test_remove(self): session = self.session() with session.begin() as t: - p1 = t.add(Profile(name='l1')) - p2 = t.add(Profile(name='l2')) + p1 = t.add(Profile(name="l1")) + p2 = t.add(Profile(name="l2")) yield t.on_result - role, created = yield session.update_or_create(Role, name='gino') + role, created = yield session.update_or_create(Role, name="gino") self.assertTrue(created) - role, created = yield session.update_or_create(Role, name='gino') + role, created = yield session.update_or_create(Role, name="gino") self.assertFalse(created) self.assertTrue(role.id) with p1.session.begin() as t: @@ -166,21 +164,20 @@ def test_remove(self): # Now remove the role yield p2.roles.remove(role) profiles = role.profiles.query() - yield self.async.assertEqual(profiles.count(),1) + yield self.async.assertEqual(profiles.count(), 1) yield p1.roles.remove(role) profiles = role.profiles.query() yield self.async.assertEqual(profiles.count(), 0) - - + + class TestRegisteredThroughModel(TestManyToManyBase, test.TestCase): - def testMeta(self): models = self.mapper through = Profile.roles.model self.assertTrue(through in models) objects = models[through] name = through.__name__ - self.assertEqual(name, 'profile_role') + self.assertEqual(name, "profile_role") 
self.assertEqual(objects.backend, models[Profile].backend) self.assertEqual(objects.backend, models[Role].backend) self.assertEqual(through.role.field.model, through) @@ -189,60 +186,65 @@ def testMeta(self): self.assertTrue(isinstance(pk, odm.CompositeIdField)) self.assertEqual(pk.fields[0].relmodel, Profile) self.assertEqual(pk.fields[1].relmodel, Role) - + def test_class_add(self): - self.assertRaises(ManyToManyError, Profile.roles.add, Role(name='foo')) + self.assertRaises(ManyToManyError, Profile.roles.add, Role(name="foo")) self.assertRaises(ManyToManyError, Role.profiles.add, Profile()) - + def test_through_query(self): m = self.mapper - p1, p2, p3 = yield self.multi_async((m.profile.new(name='g1'), - m.profile.new(name='g2'), - m.profile.new(name='g3'))) - r1, r2 = yield self.multi_async((m.role.new(name='bla'), - m.role.new(name='foo'))) + p1, p2, p3 = yield self.multi_async( + ( + m.profile.new(name="g1"), + m.profile.new(name="g2"), + m.profile.new(name="g3"), + ) + ) + r1, r2 = yield self.multi_async( + (m.role.new(name="bla"), m.role.new(name="foo")) + ) # Add a role to a profile pr1, pr2 = yield self.multi_async((p1.roles.add(r1), p2.roles.add(r1))) yield self.async.assertEqual(pr1.role, r1) yield self.async.assertEqual(pr2.role, r1) - + class TestManyToManyThrough(test.TestCase): models = (Composite, Element, CompositeElement) - + def testMetaComposite(self): meta = Composite._meta m2m = None for field in meta.fields: - if field.name == 'elements': + if field.name == "elements": m2m = field self.assertTrue(isinstance(m2m, odm.ManyToManyField)) - self.assertFalse('elements' in meta.dfields) - self.assertEqual(m2m.through,CompositeElement) - self.assertTrue('elements' in meta.related) + self.assertFalse("elements" in meta.dfields) + self.assertEqual(m2m.through, CompositeElement) + self.assertTrue("elements" in meta.related) manager = Composite.elements - self.assertEqual(manager.model,CompositeElement) - self.assertEqual(manager.relmodel,Composite) - self.assertEqual(manager.formodel,Element) - self.assertEqual(len(CompositeElement._meta.indices),2) - + self.assertEqual(manager.model, CompositeElement) + self.assertEqual(manager.relmodel, Composite) + self.assertEqual(manager.formodel, Element) + self.assertEqual(len(CompositeElement._meta.indices), 2) + def testMetaElement(self): meta = Element._meta - self.assertTrue('composites' in meta.related) + self.assertTrue("composites" in meta.related) manager = Element.composites - self.assertEqual(manager.model,CompositeElement) - self.assertEqual(manager.relmodel,Element) - self.assertEqual(manager.formodel,Composite) - + self.assertEqual(manager.model, CompositeElement) + self.assertEqual(manager.relmodel, Element) + self.assertEqual(manager.formodel, Composite) + def testAdd(self): session = self.session() with session.begin() as t: - c = t.add(Composite(name='test')) - e1 = t.add(Element(name='foo')) - e2 = t.add(Element(name='bla')) + c = t.add(Composite(name="test")) + e1 = t.add(Element(name="foo")) + e2 = t.add(Element(name="bla")) yield t.on_result yield c.elements.add(e1, weight=1.5) yield c.elements.add(e2, weight=-1) elems = yield c.elements.throughquery().all() for elem in elems: - self.assertTrue(elem.weight) \ No newline at end of file + self.assertTrue(elem.weight) diff --git a/tests/all/query/meta.py b/tests/all/query/meta.py index b1ef3e0..c994aa6 100644 --- a/tests/all/query/meta.py +++ b/tests/all/query/meta.py @@ -1,25 +1,24 @@ -'''Test query meta and corner cases''' +"""Test query meta and corner cases""" 
+from examples.data import FinanceTest +from examples.models import Instrument + from stdnet import QuerySetError, odm from stdnet.utils import test -from examples.models import Instrument -from examples.data import FinanceTest - class TestMeta(FinanceTest): - def test_session_meta(self): models = self.mapper session = models.session() self.assertEqual(session.router, models) self.assertEqual(session.transaction, None) - + def testQueryMeta(self): models = self.mapper qs = models.instrument.query() self.assertIsInstance(qs, odm.Query) self.assertEqual(qs.model, models.instrument.model) - + def test_empty_query(self): empty = self.session().empty(Instrument) self.assertEqual(empty.meta, Instrument._meta) @@ -38,40 +37,38 @@ def test_empty_query(self): self.assertEqual(set(all), set(all2)) q = self.query().filter(ccy__in=()) yield self.async.assertEqual(q.count(), 0) - + def testProperties(self): query = self.query() self.assertFalse(query.executed) - + def test_getfield(self): query = self.query() - self.assertRaises(QuerySetError, query.get_field, 'waaaaaaa') - query = query.get_field('id') - query2 = query.get_field('id') + self.assertRaises(QuerySetError, query.get_field, "waaaaaaa") + query = query.get_field("id") + query2 = query.get_field("id") self.assertEqual(query, query2) - + def testFilterError(self): - query = self.query().filter(whoaaaaa='foo') + query = self.query().filter(whoaaaaa="foo") self.assertRaises(QuerySetError, query.all) - + def testEmptyParameters(self): - query = self.query().filter(ccy='USD') + query = self.query().filter(ccy="USD") self.assertEqual(query, query.filter()) self.assertEqual(query, query.exclude()) - - + + class TestMetaWithData(FinanceTest): - @classmethod def after_setup(cls): return cls.data.create(cls) - + def test_repr(self): models = self.mapper # make sure there is at least one of them - yield models.instrument.new(name='a123345566', ccy='EUR', type='future') - query = self.query().filter(ccy='EUR')\ - .exclude(type=('equity', 'bond')) + yield models.instrument.new(name="a123345566", ccy="EUR", type="future") + query = self.query().filter(ccy="EUR").exclude(type=("equity", "bond")) self.assertTrue(str(query)) # The query is still lazy self.assertFalse(query.executed) diff --git a/tests/all/query/ranges.py b/tests/all/query/ranges.py index 563dff4..39de62f 100644 --- a/tests/all/query/ranges.py +++ b/tests/all/query/ranges.py @@ -1,22 +1,21 @@ +from examples.models import CrossData, Feed1, NumericData + from stdnet.utils import test from stdnet.utils.py2py3 import zip -from examples.models import NumericData, CrossData, Feed1 - class NumberGenerator(test.DataGenerator): - def generate(self): - self.d1 = self.populate('integer', start=-5, end=5) - self.d2 = self.populate('float', start=-10, end=10) - self.d3 = self.populate('float', start=-10, end=10) - self.d4 = self.populate('float', start=-10, end=10) - self.d5 = self.populate('integer', start=-5, end=5) - self.d6 = self.populate('integer', start=-5, end=5) - - + self.d1 = self.populate("integer", start=-5, end=5) + self.d2 = self.populate("float", start=-10, end=10) + self.d3 = self.populate("float", start=-10, end=10) + self.d4 = self.populate("float", start=-10, end=10) + self.d5 = self.populate("integer", start=-5, end=5) + self.d6 = self.populate("integer", start=-5, end=5) + + class NumericTest(test.TestCase): - multipledb = ['redis', 'mongo'] + multipledb = ["redis", "mongo"] data_cls = NumberGenerator model = NumericData @@ -25,14 +24,19 @@ def after_setup(cls): d = cls.data with 
cls.session().begin() as t: for a, b, c, d, e, f in zip(d.d1, d.d2, d.d3, d.d4, d.d5, d.d6): - t.add(cls.model(pv=a, vega=b, delta=c, gamma=d, - data={'test': {'': e, - 'inner': f}})) + t.add( + cls.model( + pv=a, + vega=b, + delta=c, + gamma=d, + data={"test": {"": e, "inner": f}}, + ) + ) yield t.on_result - + class TestNumericRange(NumericTest): - def testGT(self): session = self.session() qs = session.query(NumericData).filter(pv__gt=1) @@ -44,7 +48,7 @@ def testGT(self): self.assertTrue(qs) for v in qs: self.assertTrue(v.pv > -2) - + def testGE(self): session = self.session() qs = yield session.query(NumericData).filter(pv__ge=-2).all() @@ -55,7 +59,7 @@ def testGE(self): self.assertTrue(qs) for v in qs: self.assertTrue(v.pv >= 0) - + def testLT(self): session = self.session() qs = yield session.query(NumericData).filter(pv__lt=2).all() @@ -66,7 +70,7 @@ def testLT(self): self.assertTrue(qs) for v in qs: self.assertTrue(v.pv < -1) - + def testLE(self): session = self.session() qs = yield session.query(NumericData).filter(pv__le=1).all() @@ -77,7 +81,7 @@ def testLE(self): self.assertTrue(qs) for v in qs: self.assertTrue(v.pv <= -1) - + def testMix(self): session = self.session() qs = yield session.query(NumericData).filter(pv__gt=1, pv__lt=0).all() @@ -87,35 +91,39 @@ def testMix(self): for v in qs: self.assertTrue(v.pv < 3) self.assertTrue(v.pv >= -2) - + def testMoreThanOne(self): session = self.session() - qs = yield session.query(NumericData).filter(pv__ge=-2, pv__lt=3)\ - .filter(vega__gt=0).all() + qs = ( + yield session.query(NumericData) + .filter(pv__ge=-2, pv__lt=3) + .filter(vega__gt=0) + .all() + ) self.assertTrue(qs) for v in qs: self.assertTrue(v.pv < 3) self.assertTrue(v.pv >= -2) self.assertTrue(v.vega > 0) - + def testWithString(self): session = self.session() - qs = yield session.query(NumericData).filter(pv__ge='-2').all() + qs = yield session.query(NumericData).filter(pv__ge="-2").all() self.assertTrue(qs) for v in qs: self.assertTrue(v.pv >= -2) - + def testJson(self): session = self.session() qs = yield session.query(NumericData).filter(data__test__gt=1).all() self.assertTrue(qs) for v in qs: self.assertTrue(v.data__test > 1) - qs = yield session.query(NumericData).filter(data__test__gt='-2').all() + qs = yield session.query(NumericData).filter(data__test__gt="-2").all() self.assertTrue(qs) for v in qs: self.assertTrue(v.data__test > -2) - qs = yield session.query(NumericData).filter(data__test__inner__gt='1').all() + qs = yield session.query(NumericData).filter(data__test__inner__gt="1").all() self.assertTrue(qs) for v in qs: self.assertTrue(v.data__test__inner > 1) @@ -123,10 +131,10 @@ def testJson(self): self.assertTrue(qs) for v in qs: self.assertTrue(v.data__test__inner > -2) - + class TestNumericRangeForeignKey(test.TestCase): - multipledb = ['redis', 'mongo'] + multipledb = ["redis", "mongo"] data_cls = NumberGenerator models = (CrossData, Feed1) @@ -136,23 +144,26 @@ def after_setup(cls): da = cls.data with session.begin() as t: for a, b, c, d, e, f in zip(da.d1, da.d2, da.d3, da.d4, da.d5, da.d6): - t.add(CrossData(name='live', - data={'a': a, 'b': b, 'c': c, - 'd': d, 'e': e, 'f': f})) + t.add( + CrossData( + name="live", + data={"a": a, "b": b, "c": c, "d": d, "e": e, "f": f}, + ) + ) yield t.on_result cross = yield cls.query().all() found = False with session.begin() as t: for n, c in enumerate(cross): if c.data__a > -1: - found=True - feed = 'feed%s' % (n+1) + found = True + feed = "feed%s" % (n + 1) t.add(Feed1(name=feed, live=c)) yield t.on_result 
-        assert found, 'not found'
-
+        assert found, "not found"
+
     def test_feeds(self):
-        qs = yield self.query(Feed1).load_related('live').all()
+        qs = yield self.query(Feed1).load_related("live").all()
         self.assertTrue(qs)
         for feed in qs:
             self.assertTrue(feed.live)
@@ -161,16 +172,21 @@ def test_feeds(self):
         self.assertTrue(qs)
         for c in qs:
             self.assertTrue(c.data__a >= -1)
-
+
     def test_gt_direct(self):
         qs1 = self.query().filter(data__a__gt=-1)
-        qs = yield self.query(Feed1).filter(live=qs1).load_related('live').all()
+        qs = yield self.query(Feed1).filter(live=qs1).load_related("live").all()
         self.assertTrue(qs)
         for feed in qs:
             self.assertTrue(feed.live.data__a >= -1)
-
+
     def test_gt(self):
-        qs = yield self.query(Feed1).filter(live__data__a__gt=-1).load_related('live').all()
+        qs = (
+            yield self.query(Feed1)
+            .filter(live__data__a__gt=-1)
+            .load_related("live")
+            .all()
+        )
         self.assertTrue(qs)
         for feed in qs:
-            self.assertTrue(feed.live.data__a >= -1)
\ No newline at end of file
+            self.assertTrue(feed.live.data__a >= -1)
diff --git a/tests/all/query/related.py b/tests/all/query/related.py
index 0e59af0..308206f 100755
--- a/tests/all/query/related.py
+++ b/tests/all/query/related.py
@@ -1,10 +1,10 @@
 import datetime
 from random import randint, uniform

-from stdnet.utils import test
+from examples.data import FinanceTest, Fund, Instrument, Position
+from examples.models import Dictionary, Node, Profile, Role

-from examples.models import Node, Role, Profile, Dictionary
-from examples.data import FinanceTest, Position, Instrument, Fund
+from stdnet.utils import test


 def create(cls, root=None, nesting=None):
@@ -15,58 +15,60 @@ def create(cls, root=None, nesting=None):
         yield t.on_result
         yield create(cls, root, nesting=nesting)
     elif nesting:
-        N = randint(2,9)
+        N = randint(2, 9)
         with models.session().begin() as t:
             for n in range(N):
-                node = t.add(models.node(parent=root, weight=uniform(0,1)))
+                node = t.add(models.node(parent=root, weight=uniform(0, 1)))
         yield t.on_result
-        yield cls.multi_async((create(cls, node, nesting-1) for node\
-in t.saved[node._meta]))
-
+        yield cls.multi_async(
+            (create(cls, node, nesting - 1) for node in t.saved[node._meta])
+        )
+

 class TestSelfForeignKey(test.TestCase):
-    '''The Node model is used only in this test class and should be used only
-in this test class so that we can use the manager in a parallel test suite.'''
+    """The Node model is used only in this test class so that we can
+    use the manager in a parallel test suite."""
+
     model = Node
     nesting = 2
-
+
     @classmethod
     def after_setup(cls):
         return create(cls, nesting=cls.nesting)
-
+
     def test_meta(self):
-        all = yield self.query().load_related('parent').all()
+        all = yield self.query().load_related("parent").all()
        for n in all:
            if n.parent:
                self.assertTrue(isinstance(n.parent, self.model))
-
+
     def test_related_cache(self):
         all = yield self.query().all()
-        pcache = self.model._meta.dfields['parent'].get_cache_name()
+        pcache = self.model._meta.dfields["parent"].get_cache_name()
         for n in all:
             self.assertFalse(hasattr(n, pcache))
         yield self.multi_async((n.parent for n in all))
         for n in all:
             self.assertTrue(hasattr(n, pcache))
             self.assertEqual(getattr(n, pcache), n.parent)
-
+
     def test_self_related(self):
         query = self.query()
         root = yield query.get(parent=None)
-        children = yield root.children.query().load_related('parent').all()
+        children = yield root.children.query().load_related("parent").all()
         self.assertTrue(children)
         for child in children:
self.assertEqual(child.parent, root) - children2 = yield child.children.query().load_related('parent').all() + children2 = yield child.children.query().load_related("parent").all() self.assertTrue(children2) for child2 in children2: self.assertEqual(child2.parent, child) - + def test_self_related_filter_on_self(self): query = self.query() # We should get the nodes just after the root root = yield query.get(parent=None) - qs = yield query.filter(parent__parent=None).load_related('parent').all() + qs = yield query.filter(parent__parent=None).load_related("parent").all() self.assertTrue(qs) for node in qs: self.assertEqual(node.parent, root) @@ -75,10 +77,10 @@ def test_self_related_filter_on_self(self): class TestDeleteSelfRelated(test.TestWrite): model = Node nesting = 2 - + def setUp(self): return create(self, nesting=self.nesting) - + def test_related_delete_all(self): all = yield self.query().all() self.assertTrue(all) @@ -87,14 +89,14 @@ def test_related_delete_all(self): if a.parent is None: root += 1 self.assertEqual(root, 1) - yield self.query().delete() + yield self.query().delete() yield self.async.assertEqual(self.query().count(), 0) - + def test_related_root_delete(self): qs = self.query().filter(parent=None) yield qs.delete() yield self.async.assertEqual(self.query().count(), 0) - + def test_related_filter_delete(self): query = self.query() root = yield query.get(parent=None) @@ -108,63 +110,69 @@ def test_related_filter_delete(self): class TestRealtedQuery(FinanceTest): - @classmethod def after_setup(cls): return cls.data.makePositions(cls) - + def test_related_filter(self): query = self.query(Position) # fetch all position with EUR instruments - instruments = self.query(Instrument).filter(ccy='EUR') - peur1 = yield self.query(Position).filter(instrument=instruments)\ - .load_related('instrument').all() + instruments = self.query(Instrument).filter(ccy="EUR") + peur1 = ( + yield self.query(Position) + .filter(instrument=instruments) + .load_related("instrument") + .all() + ) self.assertTrue(peur1) for p in peur1: - self.assertEqual(p.instrument.ccy,'EUR') - peur = self.query(Position).filter(instrument__ccy='EUR') + self.assertEqual(p.instrument.ccy, "EUR") + peur = self.query(Position).filter(instrument__ccy="EUR") qe = peur.construct() self.assertEqual(qe._get_field, None) - self.assertEqual(len(qe),1) - self.assertEqual(qe.keyword, 'set') + self.assertEqual(len(qe), 1) + self.assertEqual(qe.keyword, "set") peur = yield peur.all() self.assertEqual(set(peur), set(peur1)) - + def test_related_exclude(self): query = self.query(Position) - peur = yield query.exclude(instrument__ccy='EUR').load_related('instrument').all() + peur = ( + yield query.exclude(instrument__ccy="EUR").load_related("instrument").all() + ) self.assertTrue(peur) for p in peur: - self.assertNotEqual(p.instrument.ccy, 'EUR') - + self.assertNotEqual(p.instrument.ccy, "EUR") + def test_load_related_model(self): position = yield self.query(Position).get(id=1) self.assertTrue(position.instrument_id) - cache = position.get_field('instrument').get_cache_name() + cache = position.get_field("instrument").get_cache_name() self.assertFalse(hasattr(position, cache)) - instrument = yield position.load_related_model('instrument', - load_only=('ccy',)) + instrument = yield position.load_related_model("instrument", load_only=("ccy",)) self.assertTrue(isinstance(instrument, Instrument)) - self.assertEqual(instrument._loadedfields, ('ccy',)) + self.assertEqual(instrument._loadedfields, ("ccy",)) 
self.assertEqual(id(instrument), id(position.instrument)) - + def test_related_manager(self): session = self.session() fund = yield session.query(Fund).get(id=1) positions1 = yield session.query(Position).filter(fund=fund).all() - positions = yield fund.positions.query().load_related('fund').all() + positions = yield fund.positions.query().load_related("fund").all() self.assertTrue(positions) for p in positions: self.assertEqual(p.fund, fund) self.assertEqual(set(positions1), set(positions)) - + def test_related_manager_exclude(self): inst = yield self.query().get(id=1) fund = yield self.query(Fund).get(id=1) - pos = yield fund.positions.exclude(instrument=inst).load_related('instrument')\ - .load_related('fund').all() + pos = ( + yield fund.positions.exclude(instrument=inst) + .load_related("instrument") + .load_related("fund") + .all() + ) for p in pos: self.assertNotEqual(p.instrument, inst) self.assertEqual(p.fund, fund) - - diff --git a/tests/all/query/session.py b/tests/all/query/session.py index b82cc1c..9cde365 100644 --- a/tests/all/query/session.py +++ b/tests/all/query/session.py @@ -1,8 +1,8 @@ -'''Sessions and transactions management''' -from stdnet import odm, getdb -from stdnet.utils import test, gen_unique_id +"""Sessions and transactions management""" +from examples.models import Instrument, SimpleModel -from examples.models import SimpleModel, Instrument +from stdnet import getdb, odm +from stdnet.utils import gen_unique_id, test class TestSession(test.TestWrite): @@ -14,7 +14,7 @@ def test_simple_create(self): self.assertFalse(session.transaction) session.begin() self.assertTrue(session.transaction) - m = models.simplemodel(code='pluto', group='planet') + m = models.simplemodel(code="pluto", group="planet") self.assertEqual(m, session.add(m)) self.assertTrue(m in session) sm = session.model(m) @@ -31,8 +31,8 @@ def test_create_objects(self): # Tests a session with two models. 
This was for a bug models = self.mapper with models.session().begin() as t: - t.add(models.simplemodel(code='pluto',group='planet')) - t.add(models.instrument(name='bla',ccy='EUR',type='equity')) + t.add(models.simplemodel(code="pluto", group="planet")) + t.add(models.instrument(name="bla", ccy="EUR", type="equity")) # The transaction is complete when the on_commit is not asynchronous yield t.on_result yield self.async.assertEqual(models.simplemodel.query().count(), 1) @@ -42,35 +42,35 @@ def test_simple_filter(self): models = self.mapper session = models.session() with session.begin() as t: - t.add(SimpleModel(code='pluto', group='planet')) - t.add(SimpleModel(code='venus', group='planet')) - t.add(SimpleModel(code='sun', group='star')) + t.add(SimpleModel(code="pluto", group="planet")) + t.add(SimpleModel(code="venus", group="planet")) + t.add(SimpleModel(code="sun", group="star")) yield t.on_result query = session.query(SimpleModel) yield self.async.assertEqual(query.count(), 3) all = yield query.all() self.assertEqual(len(all), 3) - qs = query.filter(group='planet') + qs = query.filter(group="planet") self.assertFalse(qs.executed) yield self.async.assertEqual(qs.count(), 2) self.assertTrue(qs.executed) - qs = query.filter(group='star') + qs = query.filter(group="star") yield self.async.assertEqual(qs.count(), 1) - qs = query.filter(group='bla') + qs = query.filter(group="bla") yield self.async.assertEqual(qs.count(), 0) def test_modify_index_field(self): session = self.session() with session.begin() as t: - t.add(SimpleModel(code='pluto', group='planet')) + t.add(SimpleModel(code="pluto", group="planet")) yield t.on_result query = session.query(SimpleModel) - qs = query.filter(group='planet') + qs = query.filter(group="planet") yield self.async.assertEqual(qs.count(), 1) el = yield qs[0] id = self.assertEqualId(el, 1) session = self.session() - el.group = 'smallplanet' + el.group = "smallplanet" with session.begin() as t: t.add(el) yield t.on_result @@ -80,13 +80,13 @@ def test_modify_index_field(self): qs = session.query(self.model).filter(id=id) yield self.async.assertEqual(qs.count(), 1) el = yield qs[0] - self.assertEqual(el.code, 'pluto') - self.assertEqual(el.group, 'smallplanet') + self.assertEqual(el.code, "pluto") + self.assertEqual(el.group, "smallplanet") # now filter on group - qs = session.query(self.model).filter(group='smallplanet') + qs = session.query(self.model).filter(group="smallplanet") yield self.async.assertEqual(qs.count(), 1) el = yield qs[0] self.assertEqual(el.id, id) # now filter on old group - qs = session.query(self.model).filter(group='planet') + qs = session.query(self.model).filter(group="planet") yield self.async.assertEqual(qs.count(), 0) diff --git a/tests/all/query/signal.py b/tests/all/query/signal.py index 36c28b6..fb99a15 100644 --- a/tests/all/query/signal.py +++ b/tests/all/query/signal.py @@ -1,8 +1,8 @@ -from stdnet.utils import test -from stdnet import odm - from examples.models import Group, Person +from stdnet import odm +from stdnet.utils import test + class TestSignals(test.TestWrite): models = (Group, Person) @@ -11,29 +11,28 @@ def setUp(self): models = self.mapper models.post_commit.bind(self.addPerson, sender=Group) - def addPerson(self, signal, sender, instances=None, session=None, - **kwargs): + def addPerson(self, signal, sender, instances=None, session=None, **kwargs): models = session.router self.assertEqual(models, self.mapper) session = models.session() with session.begin() as t: for instance in instances: self.counter += 1 - 
if instance.name == 'user': - t.add(models.person(name='luca', group=instance)) + if instance.name == "user": + t.add(models.person(name="luca", group=instance)) return t.on_result def testPostCommit(self): self.counter = 0 session = self.session() with session.begin() as t: - g = Group(name='user') + g = Group(name="user") t.add(g) - t.add(Group(name='admin')) + t.add(Group(name="admin")) yield t.on_result self.assertEqualId(g, 1) - users = session.query(Person).filter(group__name='user') - admins = session.query(Person).filter(group__name='admin') + users = session.query(Person).filter(group__name="user") + admins = session.query(Person).filter(group__name="admin") yield self.async.assertEqual(users.count(), 1) yield self.async.assertEqual(admins.count(), 0) self.assertEqual(self.counter, 2) diff --git a/tests/all/query/slice.py b/tests/all/query/slice.py index 6719dce..8b731df 100644 --- a/tests/all/query/slice.py +++ b/tests/all/query/slice.py @@ -1,12 +1,11 @@ -'''Slice Query to obtain subqueries.''' +"""Slice Query to obtain subqueries.""" +from examples.data import FinanceTest + from stdnet import QuerySetError from stdnet.utils import test -from examples.data import FinanceTest - class TestFilter(FinanceTest): - @classmethod def after_setup(cls): yield cls.data.create(cls) @@ -34,8 +33,8 @@ def testUnsortedSliceComplex(self): N = yield qs.count() self.assertTrue(N) q1 = yield qs[0:-1] - self.assertEqual(len(q1), N-1) - for id, q in enumerate(q1,1): + self.assertEqual(len(q1), N - 1) + for id, q in enumerate(q1, 1): self.assertEqual(q.id, id) q1 = yield qs[2:4] self.assertEqual(len(q1), 2) @@ -51,7 +50,7 @@ def testUnsortedSliceToEnd(self): self.assertEqual(len(q1), N) # This time the result is sorted by ids q1 = yield qs[3:] - self.assertEqual(len(q1), N-3) + self.assertEqual(len(q1), N - 3) for id, q in enumerate(q1, 4): self.assertEqual(q.id, id) @@ -62,15 +61,15 @@ def testSliceBack(self): self.assertTrue(N) q1 = yield qs[-2:] self.assertEqual(len(q1), 2) - self.assertEqual(q1[0].id, N-1) + self.assertEqual(q1[0].id, N - 1) self.assertEqual(q1[1].id, N) # This time the result is sorted by ids q1 = yield qs[-2:-1] - self.assertEqual(len(q1),1) - self.assertEqual(q1[0].id,N-1) + self.assertEqual(len(q1), 1) + self.assertEqual(q1[0].id, N - 1) def testSliceGetField(self): - '''test slice in conjunction with get_field method''' + """test slice in conjunction with get_field method""" session = self.session() - qs = session.query(self.model).get_field('id') + qs = session.query(self.model).get_field("id") yield self.async.assertRaises(QuerySetError, lambda: qs[:2]) diff --git a/tests/all/query/sorting.py b/tests/all/query/sorting.py index b729d2b..b016307 100755 --- a/tests/all/query/sorting.py +++ b/tests/all/query/sorting.py @@ -1,27 +1,27 @@ from datetime import date, datetime -from stdnet import QuerySetError, odm -from stdnet.utils import test, zip, range +from examples.models import Group, Person, SportAtDate, SportAtDate2, TestDateModel -from examples.models import (SportAtDate, SportAtDate2, Person, - TestDateModel, Group) +from stdnet import QuerySetError, odm +from stdnet.utils import range, test, zip class SortGenerator(test.DataGenerator): - def generate(self, **kwargs): - self.dates = self.populate('date', start=date(2005,6,1), - end=date(2012,6,6)) - self.groups = self.populate('choice', - choice_from=['football', 'rugby', 'swimming', - 'running', 'cycling']) - self.persons = self.populate('choice', - choice_from=['pippo', 'pluto', 'saturn', 'luca', 'josh', - 'carl', 
'paul']) + self.dates = self.populate("date", start=date(2005, 6, 1), end=date(2012, 6, 6)) + self.groups = self.populate( + "choice", + choice_from=["football", "rugby", "swimming", "running", "cycling"], + ) + self.persons = self.populate( + "choice", + choice_from=["pippo", "pluto", "saturn", "luca", "josh", "carl", "paul"], + ) class TestSort(test.TestCase): - '''Base class for sorting''' + """Base class for sorting""" + desc = False data_cls = SortGenerator @@ -34,7 +34,7 @@ def after_setup(cls): return t.on_result def checkOrder(self, qs, attr, desc=None): - if hasattr(qs, 'all'): + if hasattr(qs, "all"): all = yield qs.all() else: all = qs @@ -44,57 +44,57 @@ def checkOrder(self, qs, attr, desc=None): for obj in all[1:]: at1 = obj.get_attr_value(attr) if desc: - self.assertTrue(at1<=at0) + self.assertTrue(at1 <= at0) else: - self.assertTrue(at1>=at0) + self.assertTrue(at1 >= at0) at0 = at1 class ExplicitOrderingMixin(object): - def test_size(self): qs = self.query() yield self.async.assertEqual(qs.count(), len(self.data.dates)) def testDateSortBy(self): - return self.checkOrder(self.query().sort_by('dt'), 'dt') + return self.checkOrder(self.query().sort_by("dt"), "dt") def testDateSortByReversed(self): - return self.checkOrder(self.query().sort_by('-dt'),'dt',True) + return self.checkOrder(self.query().sort_by("-dt"), "dt", True) def testNameSortBy(self): - return self.checkOrder(self.query().sort_by('name'),'name') + return self.checkOrder(self.query().sort_by("name"), "name") def testNameSortByReversed(self): - return self.checkOrder(self.query().sort_by('-name'),'name',True) + return self.checkOrder(self.query().sort_by("-name"), "name", True) def testSimpleSortError(self): qs = self.query() - self.assertRaises(QuerySetError, qs.sort_by, 'whaaaa') + self.assertRaises(QuerySetError, qs.sort_by, "whaaaa") def testFilter(self): - qs = self.query().filter(name='rugby').sort_by('dt') - yield self.checkOrder(qs, 'dt') + qs = self.query().filter(name="rugby").sort_by("dt") + yield self.checkOrder(qs, "dt") for v in qs: - self.assertEqual(v.name, 'rugby') + self.assertEqual(v.name, "rugby") def _slicingTest(self, attr, desc, start=0, stop=10, expected_len=10): - p = '-' if desc else '' - qs = self.query().sort_by(p+attr) + p = "-" if desc else "" + qs = self.query().sort_by(p + attr) qs1 = yield qs[start:stop] self.assertEqual(len(qs1), expected_len) self.checkOrder(qs1, attr, desc) def testDateSlicing(self): - return self._slicingTest('dt',False) + return self._slicingTest("dt", False) def testDateSlicingDesc(self): - return self._slicingTest('dt',True) + return self._slicingTest("dt", True) class TestSortBy(TestSort, ExplicitOrderingMixin): - '''Test the sort_by in a model without ordering meta attribute. -Pure explicit ordering.''' + """Test the sort_by in a model without ordering meta attribute. 
+ Pure explicit ordering.""" + model = TestDateModel @@ -111,7 +111,7 @@ def after_setup(cls): t.add(Group(name=g)) yield t.on_result groups = yield session.query(Group).all() - gps = test.populate('choice', d.size, choice_from=groups) + gps = test.populate("choice", d.size, choice_from=groups) with session.begin() as t: for p, g in zip(d.persons, gps): t.add(cls.model(name=p, group=g)) @@ -122,46 +122,46 @@ def test_size(self): return self.async.assertEqual(qs.count(), len(self.data.dates)) def testNameSortBy(self): - return self.checkOrder(self.query().sort_by('name'),'name') + return self.checkOrder(self.query().sort_by("name"), "name") def testNameSortByReversed(self): - return self.checkOrder(self.query().sort_by('-name'),'name',True) + return self.checkOrder(self.query().sort_by("-name"), "name", True) def testSortByFK(self): qs = self.query() - qs = qs.sort_by('group__name') + qs = qs.sort_by("group__name") ordering = qs.ordering - self.assertEqual(ordering.name, 'group_id') - self.assertEqual(ordering.nested.name, 'name') + self.assertEqual(ordering.name, "group_id") + self.assertEqual(ordering.nested.name, "name") self.assertEqual(ordering.model, qs.model) - self.checkOrder(qs, 'group__name') + self.checkOrder(qs, "group__name") class TestOrderingModel(TestSort): - '''Test a model which is always sorted by the ordering meta attribute.''' + """Test a model which is always sorted by the ordering meta attribute.""" + model = SportAtDate def testMeta(self): model = self.model self.assertTrue(model._meta.ordering) ordering = model._meta.ordering - self.assertEqual(ordering.name, 'dt') - self.assertEqual(ordering.field.name, 'dt') + self.assertEqual(ordering.name, "dt") + self.assertEqual(ordering.field.name, "dt") self.assertEqual(ordering.desc, self.desc) def testSimple(self): - yield self.checkOrder(self.query(), 'dt') + yield self.checkOrder(self.query(), "dt") def testFilter(self): - qs = self.query().filter(name=('football','rugby')) - return self.checkOrder(qs,'dt') + qs = self.query().filter(name=("football", "rugby")) + return self.checkOrder(qs, "dt") def testExclude(self): - qs = self.query().exclude(name='rugby') - return self.checkOrder(qs, 'dt') + qs = self.query().exclude(name="rugby") + return self.checkOrder(qs, "dt") class TestOrderingModelDesc(TestOrderingModel): model = SportAtDate2 desc = True - diff --git a/tests/all/query/transaction.py b/tests/all/query/transaction.py index ebf0e7f..709e1a8 100644 --- a/tests/all/query/transaction.py +++ b/tests/all/query/transaction.py @@ -1,20 +1,19 @@ import random -from stdnet import odm, InvalidTransaction -from examples.models import SimpleModel, Dictionary -from stdnet.utils import test, populate +from examples.models import Dictionary, SimpleModel + +from stdnet import InvalidTransaction, odm +from stdnet.utils import populate, test LEN = 100 -names = populate('string',LEN, min_len = 5, max_len = 20) +names = populate("string", LEN, min_len=5, max_len=20) class TransactionReceiver(object): - def __init__(self): self.transactions = [] - def __call__(self, signal, sender, instances=None, session=None, - **kwargs): + def __call__(self, signal, sender, instances=None, session=None, **kwargs): self.transactions.append((sender, instances)) @@ -31,12 +30,11 @@ def testCreate(self): query = session.query(self.model) with session.begin() as t: self.assertEqual(t.session, session) - s = t.add(self.model(code='test', description='just a test')) + s = t.add(self.model(code="test", description="just a test")) self.assertFalse(s.id) - 
s2 = session.add(self.model(code='test2', - description='just a test')) + s2 = session.add(self.model(code="test2", description="just a test")) yield t.on_result - all = yield query.filter(code=('test','test2')).all() + all = yield query.filter(code=("test", "test2")).all() self.assertEqual(len(all), 2) receiver = self.receiver self.assertTrue(len(receiver.transactions), 1) @@ -49,24 +47,21 @@ def testDelete(self): session = self.session() query = session.query(self.model) with session.begin() as t: - s = session.add(self.model(code='bla', - description='just a test')) + s = session.add(self.model(code="bla", description="just a test")) yield t.on_result yield self.async.assertEqual(query.get(id=s.id), s) yield session.delete(s) - yield self.async.assertRaises(self.model.DoesNotExist, - query.get, id=s.id) + yield self.async.assertRaises(self.model.DoesNotExist, query.get, id=s.id) def test_force_update(self): session = self.session() with session.begin() as t: - s = session.add(self.model(code='test10', - description='just a test')) + s = session.add(self.model(code="test10", description="just a test")) yield t.on_result with session.begin() as t: s = t.add(s, force_update=True) state = s.get_state() - self.assertEqual(state.action, 'update') + self.assertEqual(state.action, "update") self.assertTrue(state.persistent) yield t.on_result @@ -75,8 +70,8 @@ class TestMultiFieldTransaction(test.TestCase): model = Dictionary def make(self): - with self.session().begin(name='create models') as t: - self.assertEqual(t.name, 'create models') + with self.session().begin(name="create models") as t: + self.assertEqual(t.name, "create models") for name in names: t.add(self.model(name=name)) return t.on_result @@ -85,18 +80,19 @@ def testHashField(self): yield self.make() session = self.session() query = session.query(self.model) - d1, d2 = yield query.filter(id__in=(1,2)).all() + d1, d2 = yield query.filter(id__in=(1, 2)).all() with session.begin() as t: - d1.data.add('ciao','hello in Italian') - d1.data.add('bla',10000) - d2.data.add('wine','drink to enjoy with or without food') - d2.data.add('foo',98) + d1.data.add("ciao", "hello in Italian") + d1.data.add("bla", 10000) + d2.data.add("wine", "drink to enjoy with or without food") + d2.data.add("foo", 98) self.assertTrue(d1.data.cache.toadd) self.assertTrue(d2.data.cache.toadd) yield t.on_result self.assertFalse(d1.data.cache.toadd) self.assertFalse(d2.data.cache.toadd) - d1, d2 = yield query.filter(id__in=(1,2)).sort_by('id').load_related('data').all() - self.assertEqual(d1.data['ciao'], 'hello in Italian') - self.assertEqual(d2.data['wine'], 'drink to enjoy with or without food') - + d1, d2 = ( + yield query.filter(id__in=(1, 2)).sort_by("id").load_related("data").all() + ) + self.assertEqual(d1.data["ciao"], "hello in Italian") + self.assertEqual(d2.data["wine"], "drink to enjoy with or without food") diff --git a/tests/all/query/unique.py b/tests/all/query/unique.py index a5ec95c..f434d28 100644 --- a/tests/all/query/unique.py +++ b/tests/all/query/unique.py @@ -1,28 +1,27 @@ -'''Test unique fields''' +"""Test unique fields""" from random import randint -from stdnet import odm, CommitException -from stdnet.utils import test, zip, range - from examples.models import SimpleModel +from stdnet import CommitException, odm +from stdnet.utils import range, test, zip + class SportGenerator(test.DataGenerator): - def generate(self): - self.sports = ['football','rugby','swimming','running','cycling'] - self.codes = set(self.populate('string', min_len=5, 
max_len=20)) + self.sports = ["football", "rugby", "swimming", "running", "cycling"] + self.codes = set(self.populate("string", min_len=5, max_len=20)) self.size = len(self.codes) - self.groups = self.populate('choice', choice_from=self.sports) + self.groups = self.populate("choice", choice_from=self.sports) self.codes = list(self.codes) - + def __iter__(self): return zip(self.codes, self.groups) def randomcode(self, num=1): a = set() while len(a) < num: - a.add(self.codes[randint(0, self.size-1)]) + a.add(self.codes[randint(0, self.size - 1)]) if num == 1: return tuple(a)[0] else: @@ -32,7 +31,7 @@ def randomcode(self, num=1): class TestUniqueFilter(test.TestCase): data_cls = SportGenerator model = SimpleModel - + @classmethod def after_setup(cls): with cls.session().begin() as t: @@ -42,9 +41,10 @@ def after_setup(cls): def testBadId(self): session = self.session() - yield self.async.assertRaises(self.model.DoesNotExist, - session.query(self.model).get, id=-1) - + yield self.async.assertRaises( + self.model.DoesNotExist, session.query(self.model).get, id=-1 + ) + def testFilterSimple(self): session = self.session() query = session.query(self.model) @@ -53,23 +53,23 @@ def testFilterSimple(self): qs = yield query.filter(code=code).all() self.assertEqual(len(qs), 1) self.assertEqual(qs[0].code, code) - + def testIdCode(self): session = self.session() query = session.query(self.model) all = yield session.query(self.model).all() all2 = yield self.multi_async((query.get(code=m.code) for m in all)) self.assertEqual(all, all2) - + def testExcludeSimple(self): session = self.session() query = session.query(self.model) for i in range(10): code = self.data.randomcode() all = yield query.exclude(code=code).all() - self.assertEqual(len(all), self.data.size-1) + self.assertEqual(len(all), self.data.size - 1) self.assertFalse(code in set((o.code for o in all))) - + def testFilterCodeIn(self): session = self.session() query = session.query(self.model) @@ -78,7 +78,7 @@ def testFilterCodeIn(self): self.assertTrue(qs) match = set((m.code for m in qs)) self.assertEqual(codes, match) - + def testExcludeCodeIn(self): session = self.session() query = session.query(self.model) @@ -88,61 +88,62 @@ def testExcludeCodeIn(self): match = set((m.code for m in qs)) for code in codes: self.assertFalse(code in match) - + def testExcludeInclude(self): session = self.session() query = session.query(self.model) codes = self.data.randomcode(num=3) qs = yield query.exclude(code__in=codes).filter(code=codes).all() self.assertFalse(qs) - + def testTestUnique(self): session = self.session() query = session.query(self.model) - yield self.async.assertEqual(query.test_unique('code', 'xxxxxxxxxx'), - 'xxxxxxxxxx') - m = yield query.get(id=1) yield self.async.assertEqual( - query.test_unique('code', m.code, m), m.code) - m2 = yield query.get(id = 2) - yield self.async.assertRaises(ValueError, - query.test_unique, 'code', m.code, m2, ValueError) + query.test_unique("code", "xxxxxxxxxx"), "xxxxxxxxxx" + ) + m = yield query.get(id=1) + yield self.async.assertEqual(query.test_unique("code", m.code, m), m.code) + m2 = yield query.get(id=2) + yield self.async.assertRaises( + ValueError, query.test_unique, "code", m.code, m2, ValueError + ) + class a: -#class TestUniqueCreate(test.TestWrite): + # class TestUniqueCreate(test.TestWrite): model = SimpleModel - + def testAddNew(self): session = self.session() - m = yield session.add(self.model(code='me', group='bla')) + m = yield session.add(self.model(code="me", group="bla")) 
self.assertEqualId(m, 1) - self.assertEqual(m.code, 'me') + self.assertEqual(m.code, "me") # Try to create another one - m2 = self.model(code='me', group='foo') + m2 = self.model(code="me", group="foo") yield self.async.assertRaises(CommitException, session.add, m2) self.assertFalse(session.transaction) query = session.query(self.model) yield self.async.assertEqual(query.count(), 1) - m = yield query.get(code='me') + m = yield query.get(code="me") self.assertEqualId(m, 1) - self.assertEqual(m.group, 'bla') + self.assertEqual(m.group, "bla") session.expunge() - m = yield session.add(self.model(code='me2', group='bla')) + m = yield session.add(self.model(code="me2", group="bla")) self.assertEqualId(m, 2) query = session.query(self.model) yield self.async.assertEqual(query.count(), 2) - + def testChangeValue(self): session = self.session() query = session.query(self.model) - m = yield session.add(self.model(code='pippo')) + m = yield session.add(self.model(code="pippo")) self.assertTrue(m.id) - m2 = yield query.get(code='pippo') + m2 = yield query.get(code="pippo") self.assertEqual(m.id, m2.id) # Save with different code - m2.code = 'pippo2' + m2.code = "pippo2" yield session.add(m2) - m3 = yield query.get(code='pippo2') + m3 = yield query.get(code="pippo2") self.assertEqual(m.id, m3.id) - yield self.async.assertRaises(self.model.DoesNotExist, query.get, - code='pippo') \ No newline at end of file + yield self.async.assertRaises(self.model.DoesNotExist, query.get, code="pippo") diff --git a/tests/all/query/where.py b/tests/all/query/where.py index 686ca27..075f397 100644 --- a/tests/all/query/where.py +++ b/tests/all/query/where.py @@ -2,32 +2,33 @@ class TestWhere(ranges.NumericTest): - multipledb = ('redis', 'mongo') - + multipledb = ("redis", "mongo") + def testWhere(self): session = self.session() - qs = session.query(self.model).where('this.vega > this.delta') + qs = session.query(self.model).where("this.vega > this.delta") qs = yield qs.all() self.assertTrue(qs) for m in qs: self.assertTrue(m.vega > m.delta) - + def testConcatenation(self): session = self.session() qs = session.query(self.model) - qs = qs.filter(pv__gt=0).where('this.vega > this.delta') + qs = qs.filter(pv__gt=0).where("this.vega > this.delta") qs = yield qs.all() self.assertTrue(qs) for m in qs: self.assertTrue(m.pv > 0) self.assertTrue(m.vega > m.delta) - + def testLoadOnly(self): - '''load only is only used in redis''' + """load only is only used in redis""" session = self.session() - qs = session.query(self.model).where('this.vega > this.delta', - load_only=('vega','foo','delta')) + qs = session.query(self.model).where( + "this.vega > this.delta", load_only=("vega", "foo", "delta") + ) qs = yield qs.all() self.assertTrue(qs) for m in qs: - self.assertTrue(m.vega > m.delta) \ No newline at end of file + self.assertTrue(m.vega > m.delta) diff --git a/tests/all/serialize/base.py b/tests/all/serialize/base.py index b62c908..2d2db5a 100644 --- a/tests/all/serialize/base.py +++ b/tests/all/serialize/base.py @@ -2,11 +2,10 @@ import tempfile from stdnet import odm -from stdnet.utils import test, BytesIO, to_bytes +from stdnet.utils import BytesIO, test, to_bytes class Tempfile(object): - def __init__(self, data, text=True): fd, path = tempfile.mkstemp(text=text) self.handler = None @@ -19,7 +18,7 @@ def __enter__(self): def write(self, data): if self.fd: - os.write(self.fd,data) + os.write(self.fd, data) os.close(self.fd) self.fd = None @@ -30,8 +29,8 @@ def close(self): def open(self): if self.handler: - raise 
RuntimeError('File is already opened') - self.handler = open(self.path, 'r') + raise RuntimeError("File is already opened") + self.handler = open(self.path, "r") return self.handler def __exit__(self, type, value, trace): @@ -40,8 +39,8 @@ def __exit__(self, type, value, trace): class BaseSerializerMixin(object): - serializer = 'json' - + serializer = "json" + @classmethod def after_setup(cls): yield cls.data.create(cls) @@ -53,19 +52,18 @@ def get(self, **options): self.assertFalse(s.data) self.assertTrue(s) return s - + def dump(self): models = self.mapper s = self.get() - qs = yield models.instrument.query().sort_by('id').all() + qs = yield models.instrument.query().sort_by("id").all() s.dump(qs) self.assertTrue(s.data) self.assertEqual(len(s.data), 1) yield s - - + + class SerializerMixin(BaseSerializerMixin): - def testMeta(self): self.get() @@ -79,37 +77,36 @@ def test_write(self): class LoadSerializerMixin(BaseSerializerMixin): - def testLoad(self): models = self.mapper s = yield self.dump() - qs = yield models.instrument.query().sort_by('id').all() + qs = yield models.instrument.query().sort_by("id").all() self.assertTrue(qs) data = s.write().getvalue() with Tempfile(data) as tmp: yield models.instrument.flush() yield s.load(models, tmp.open(), self.model) - qs2 = yield models.instrument.query().sort_by('id').all() + qs2 = yield models.instrument.query().sort_by("id").all() self.assertEqual(qs, qs2) class DummySerializer(odm.Serializer): - '''A Serializer for testing registration''' + """A Serializer for testing registration""" + pass class TestMeta(test.TestCase): - def testBadSerializer(self): - self.assertRaises(ValueError, odm.get_serializer, 'djsbvjchvsdjcvsdj') + self.assertRaises(ValueError, odm.get_serializer, "djsbvjchvsdjcvsdj") def testRegisterUnregister(self): - odm.register_serializer('dummy', DummySerializer()) - s = odm.get_serializer('dummy') - self.assertTrue('dummy' in odm.all_serializers()) + odm.register_serializer("dummy", DummySerializer()) + s = odm.get_serializer("dummy") + self.assertTrue("dummy" in odm.all_serializers()) self.assertTrue(isinstance(s, DummySerializer)) self.assertRaises(NotImplementedError, s.dump, None) self.assertRaises(NotImplementedError, s.write) self.assertRaises(NotImplementedError, s.load, None, None) - self.assertTrue(odm.unregister_serializer('dummy')) - self.assertRaises(ValueError, odm.get_serializer, 'dummy') + self.assertTrue(odm.unregister_serializer("dummy")) + self.assertRaises(ValueError, odm.get_serializer, "dummy") diff --git a/tests/all/serialize/csv.py b/tests/all/serialize/csv.py index 1233137..55dc1d2 100644 --- a/tests/all/serialize/csv.py +++ b/tests/all/serialize/csv.py @@ -1,14 +1,14 @@ -'''Test the CSV serializer''' -from stdnet import odm - +"""Test the CSV serializer""" from examples.data import FinanceTest, Fund +from stdnet import odm + from . 
import base class TestFinanceCSV(base.SerializerMixin, FinanceTest): - serializer = 'csv' - + serializer = "csv" + def testTwoModels(self): models = self.mapper s = yield self.dump() @@ -16,11 +16,11 @@ def testTwoModels(self): funds = yield models.fund.all() self.assertRaises(ValueError, s.dump, funds) self.assertEqual(len(s.data), 1) - + def testLoadError(self): s = yield self.dump() - self.assertRaises(ValueError, s.load, self.mapper, 'bla') - - + self.assertRaises(ValueError, s.load, self.mapper, "bla") + + class TestLoadFinanceCSV(base.LoadSerializerMixin, FinanceTest): - serializer = 'csv' + serializer = "csv" diff --git a/tests/all/serialize/json.py b/tests/all/serialize/json.py index 3d8e99d..8cd6a49 100644 --- a/tests/all/serialize/json.py +++ b/tests/all/serialize/json.py @@ -1,25 +1,23 @@ -'''Test the JSON serializer''' -from stdnet import odm - +"""Test the JSON serializer""" from examples.data import FinanceTest, Fund +from stdnet import odm + from . import base class TestFinanceJSON(base.SerializerMixin, FinanceTest): - serializer = 'json' + serializer = "json" def testTwoModels(self): models = self.mapper s = yield self.dump() d = s.data[0] - self.assertEqual(d['model'], str(self.model._meta)) - all = yield models.fund.query().sort_by('id').all() + self.assertEqual(d["model"], str(self.model._meta)) + all = yield models.fund.query().sort_by("id").all() s.dump(all) self.assertEqual(len(s.data), 2) class TestLoadFinanceJSON(base.LoadSerializerMixin, FinanceTest): - serializer = 'json' - - + serializer = "json" diff --git a/tests/all/structures/base.py b/tests/all/structures/base.py index aec7dbb..319ef60 100644 --- a/tests/all/structures/base.py +++ b/tests/all/structures/base.py @@ -1,18 +1,18 @@ __test__ = False -from stdnet import odm, InvalidTransaction +from stdnet import InvalidTransaction, odm class StructMixin(object): - multipledb = 'redis' + multipledb = "redis" structure = None name = None - + def create_one(self): - '''Create a structure and add few elements. Must return an instance -of the :attr:`structure`.''' + """Create a structure and add few elements. 
Must return an instance + of the :attr:`structure`.""" raise NotImplementedError - + def empty(self): models = self.mapper l = models.register(self.structure()) @@ -20,7 +20,7 @@ def empty(self): models.session().add(l) self.assertTrue(l.session is not None) return l - + def not_empty(self): models = self.mapper l = models.register(self.create_one()) @@ -28,36 +28,36 @@ def not_empty(self): yield models.session().add(l) self.assertTrue(l.session is not None) yield l - + def test_no_session(self): l = self.create_one() self.assertFalse(l.session) self.assertTrue(l.id) session = self.mapper.session() self.assertRaises(InvalidTransaction, session.add, l) - + def test_meta(self): models = self.mapper l = models.register(self.create_one()) self.assertTrue(l.id) session = models.session() with session.begin() as t: - t.add(l) # add the structure to the session + t.add(l) # add the structure to the session self.assertEqual(l.session, session) self.assertEqual(l._meta.name, self.name) - self.assertEqual(l._meta.model._model_type, 'structure') - #Structure have always the persistent flag set to True + self.assertEqual(l._meta.model._model_type, "structure") + # Structure have always the persistent flag set to True self.assertTrue(l.get_state().persistent) self.assertTrue(l in session) size = yield l.size() self.assertEqual(size, 0) yield t.on_result yield l - + def test_commit(self): l = yield self.test_meta() yield self.async.assertTrue(l.size()) - + def test_delete(self): models = self.mapper l = models.register(self.create_one()) @@ -68,9 +68,9 @@ def test_delete(self): yield session.delete(l) yield self.async.assertEqual(l.size(), 0) self.assertEqual(l.session, session) - + def test_empty(self): - '''Create an empty structure''' + """Create an empty structure""" models = self.mapper l = models.register(self.structure()) self.assertTrue(l.id) diff --git a/tests/all/structures/hash.py b/tests/all/structures/hash.py index 7e0d626..4f4d501 100644 --- a/tests/all/structures/hash.py +++ b/tests/all/structures/hash.py @@ -4,53 +4,54 @@ from stdnet.utils import test from .base import StructMixin - - + + class TestHash(StructMixin, test.TestCase): structure = odm.HashTable - name = 'hashtable' - + name = "hashtable" + def create_one(self): h = odm.HashTable() - h['bla'] = 'foo' - h['pluto'] = 3 + h["bla"] = "foo" + h["pluto"] = 3 return h - + def test_get_empty(self): - d = self.empty() - result = yield d.get('blaxxx', 3) + d = self.empty() + result = yield d.get("blaxxx", 3) self.assertEqual(result, 3) - + def test_pop(self): models = self.mapper d = models.register(self.create_one()) session = models.session() with session.begin() as t: d = t.add(d) - d['foo'] = 'ciao' + d["foo"] = "ciao" yield t.on_result yield self.async.assertEqual(d.size(), 3) - yield self.async.assertEqual(d['foo'], 'ciao') - yield self.async.assertRaises(KeyError, d.pop, 'blascd') - yield self.async.assertEqual(d.pop('xxx', 56), 56) - self.assertRaises(TypeError, d.pop, 'xxx', 1, 2) - yield self.async.assertEqual(d.pop('foo'), 'ciao') + yield self.async.assertEqual(d["foo"], "ciao") + yield self.async.assertRaises(KeyError, d.pop, "blascd") + yield self.async.assertEqual(d.pop("xxx", 56), 56) + self.assertRaises(TypeError, d.pop, "xxx", 1, 2) + yield self.async.assertEqual(d.pop("foo"), "ciao") yield self.async.assertEqual(d.size(), 2) - + def test_get(self): models = self.mapper d = models.register(self.structure()) session = models.session() with session.begin() as t: d = t.add(d) - d['baba'] = 'foo' - d['bee'] = 3 + d["baba"] 
= "foo" + d["bee"] = 3 self.assertEqual(len(d.cache.toadd), 2) yield t.on_result - result = yield multi_async((d['baba'], d.get('bee'), d.get('ggg'), - d.get('ggg', 1))) - self.assertEqual(result, ['foo', 3, None, 1]) - yield self.async.assertRaises(KeyError, lambda : d['gggggg']) + result = yield multi_async( + (d["baba"], d.get("bee"), d.get("ggg"), d.get("ggg", 1)) + ) + self.assertEqual(result, ["foo", 3, None, 1]) + yield self.async.assertRaises(KeyError, lambda: d["gggggg"]) def test_keys(self): models = self.mapper @@ -58,23 +59,22 @@ def test_keys(self): session = models.session() yield session.add(d) values = yield d.keys() - self.assertEqual(set(('bla', 'pluto')), set(values)) - + self.assertEqual(set(("bla", "pluto")), set(values)) + def test_values(self): models = self.mapper d = models.register(self.create_one()) session = models.session() yield session.add(d) values = yield d.values() - self.assertEqual(set(('foo', 3)), set(values)) - + self.assertEqual(set(("foo", 3)), set(values)) + def test_items(self): models = self.mapper d = models.register(self.create_one()) session = models.session() yield session.add(d) values = yield d.items() - data = {'bla': 'foo', 'pluto': 3} + data = {"bla": "foo", "pluto": 3} self.assertNotEqual(data, values) self.assertEqual(data, dict(values)) - \ No newline at end of file diff --git a/tests/all/structures/list.py b/tests/all/structures/list.py index 564bd98..7852442 100644 --- a/tests/all/structures/list.py +++ b/tests/all/structures/list.py @@ -1,12 +1,12 @@ from stdnet import odm -from stdnet.utils import test, encoders +from stdnet.utils import encoders, test from .base import StructMixin class TestList(StructMixin, test.TestCase): structure = odm.List - name = 'list' + name = "list" def create_one(self): l = odm.List() @@ -17,10 +17,10 @@ def create_one(self): def test_items(self): l = yield self.test_meta() self.assertFalse(l.session.transaction) - yield l.push_back('save') - yield l.push_back({'test': 1}) + yield l.push_back("save") + yield l.push_back({"test": 1}) yield self.async.assertEqual(l.size(), 4) - result = [3,5.6,'save',"{'test': 1}"] + result = [3, 5.6, "save", "{'test': 1}"] yield self.async.assertEqual(l.items(), result) def test_json_list(self): @@ -33,13 +33,13 @@ def test_json_list(self): t.add(l) l.push_back(3) l.push_back(5.6) - l.push_back('save') - l.push_back({'test': 1}) - l.push_back({'test': 2}) + l.push_back("save") + l.push_back({"test": 1}) + l.push_back({"test": 2}) self.assertEqual(len(l.cache.back), 5) yield t.on_result yield self.async.assertEqual(l.size(), 5) - result = [3, 5.6, 'save', {'test': 1}, {'test': 2}] + result = [3, 5.6, "save", {"test": 1}, {"test": 2}] yield self.async.assertEqual(l.items(), result) self.assertEqual(list(l), result) diff --git a/tests/all/structures/numarray.py b/tests/all/structures/numarray.py index 3e73526..d1dad1d 100644 --- a/tests/all/structures/numarray.py +++ b/tests/all/structures/numarray.py @@ -1,20 +1,21 @@ import os from datetime import date -from stdnet import odm, InvalidTransaction -from stdnet.utils import test, encoders, zip +from stdnet import InvalidTransaction, odm +from stdnet.utils import encoders, test, zip from stdnet.utils.populate import populate from .base import StructMixin - + + class TestNumberArray(StructMixin, test.TestCase): structure = odm.NumberArray - name = 'numberarray' - + name = "numberarray" + def create_one(self): a = self.structure() return a.push_back(56).push_back(-78.6) - + def testSizeResize(self): a = yield self.not_empty() 
yield self.async.assertEqual(a.size(), 2) @@ -25,10 +26,10 @@ def testSizeResize(self): self.assertAlmostEqual(data[0], 56.0) self.assertAlmostEqual(data[1], -78.6) for v in data[2:]: - self.assertNotEqual(v,v) - + self.assertNotEqual(v, v) + def testSetGet(self): a = yield self.not_empty() yield self.async.assertEqual(a.size(), 2) value = yield a[1] - self.assertAlmostEqual(value, -78.6) \ No newline at end of file + self.assertAlmostEqual(value, -78.6) diff --git a/tests/all/structures/set.py b/tests/all/structures/set.py index 9281830..19bf746 100644 --- a/tests/all/structures/set.py +++ b/tests/all/structures/set.py @@ -6,34 +6,33 @@ class TestSet(StructMixin, test.TestCase): structure = odm.Set - name = 'set' - + name = "set" + def create_one(self): s = self.structure() - s.update((1,2,3,4,5,5)) + s.update((1, 2, 3, 4, 5, 5)) return s - + def test_update(self): # Typical usage. Add a set to a session s = self.empty() s.session.add(s) yield s.add(8) yield self.async.assertEqual(s.size(), 1) - yield s.update((1,2,3,4,5,5)) + yield s.update((1, 2, 3, 4, 5, 5)) yield self.async.assertEqual(s.size(), 6) - + def test_update_delete(self): s = self.empty() with s.session.begin() as t: t.add(s) - s.update((1,2,3,4,5,5)) + s.update((1, 2, 3, 4, 5, 5)) s.discard(2) s.discard(67) s.remove(4) s.remove(46) - s.difference_update((1,56,89)) + s.difference_update((1, 56, 89)) yield t.on_result yield self.async.assertEqual(s.size(), 2) - yield s.difference_update((3,5,6,7)) + yield s.difference_update((3, 5, 6, 7)) yield self.async.assertEqual(s.size(), 0) - diff --git a/tests/all/structures/string.py b/tests/all/structures/string.py index 9faa984..0c3b7b0 100644 --- a/tests/all/structures/string.py +++ b/tests/all/structures/string.py @@ -1,22 +1,22 @@ import os from datetime import date -from stdnet import odm, InvalidTransaction -from stdnet.utils import test, encoders, zip +from stdnet import InvalidTransaction, odm +from stdnet.utils import encoders, test, zip from stdnet.utils.populate import populate from .base import StructMixin - + class TestString(StructMixin, test.TestCase): structure = odm.String - name = 'string' - + name = "string" + def create_one(self): a = self.structure() - a.push_back('this is a test') + a.push_back("this is a test") return a - + def test_incr(self): a = self.empty() a.session.add(a) @@ -24,5 +24,3 @@ def test_incr(self): yield self.async.assertEqual(a.incr(), 2) yield self.async.assertEqual(a.incr(3), 5) yield self.async.assertEqual(a.incr(-7), -2) - - \ No newline at end of file diff --git a/tests/all/structures/ts.py b/tests/all/structures/ts.py index ca1e57c..14fe4b5 100644 --- a/tests/all/structures/ts.py +++ b/tests/all/structures/ts.py @@ -2,8 +2,7 @@ from datetime import date from stdnet import odm -from stdnet.utils import test, encoders, zip - +from stdnet.utils import encoders, test, zip from tests.all.multifields.timeseries import TsData from .base import StructMixin @@ -12,7 +11,7 @@ class TestTS(StructMixin, test.TestCase): structure = odm.TS data_cls = TsData - name = 'ts' + name = "ts" def create_one(self): ts = self.structure() @@ -40,8 +39,8 @@ def test_range(self): range = yield ts.range(all_dates[start], all_dates[end]) self.assertTrue(range) for time, val in range: - self.assertTrue(time>=front[0]) - self.assertTrue(time<=back[0]) + self.assertTrue(time >= front[0]) + self.assertTrue(time <= back[0]) def test_get(self): ts = yield self.not_empty() @@ -49,9 +48,9 @@ def test_get(self): val1 = yield ts[dt1] self.assertTrue(val1) yield 
self.async.assertEqual(ts.get(dt1), val1) - yield self.async.assertEqual(ts.get(date(1990,1,1)),None) - yield self.async.assertEqual(ts.get(date(1990,1,1),1),1) - yield self.async.assertRaises(KeyError, lambda : ts[date(1990,1,1)]) + yield self.async.assertEqual(ts.get(date(1990, 1, 1)), None) + yield self.async.assertEqual(ts.get(date(1990, 1, 1), 1), 1) + yield self.async.assertRaises(KeyError, lambda: ts[date(1990, 1, 1)]) def test_pop(self): ts = yield self.not_empty() @@ -61,7 +60,7 @@ def test_pop(self): self.assertTrue(v) yield self.async.assertFalse(dt in ts) yield self.async.assertRaises(KeyError, ts.pop, dt) - yield self.async.assertEqual(ts.pop(dt,'bla'), 'bla') + yield self.async.assertEqual(ts.pop(dt, "bla"), "bla") def test_rank_ipop(self): ts = yield self.not_empty() @@ -80,12 +79,12 @@ def test_pop_range(self): N = len(all_dates) start = N // 4 end = 3 * N // 4 - range = yield ts.range(all_dates[start],all_dates[end]) + range = yield ts.range(all_dates[start], all_dates[end]) self.assertTrue(range) range2 = yield ts.pop_range(all_dates[start], all_dates[end]) self.assertEqual(range, range2) all_dates = yield ts.itimes() all_dates = set(all_dates) self.assertTrue(all_dates) - for dt,_ in range: + for dt, _ in range: self.assertFalse(dt in all_dates) diff --git a/tests/all/structures/zset.py b/tests/all/structures/zset.py index b272218..60197ed 100644 --- a/tests/all/structures/zset.py +++ b/tests/all/structures/zset.py @@ -2,45 +2,51 @@ from datetime import date from stdnet import odm -from stdnet.utils import test, encoders, zip +from stdnet.utils import encoders, test, zip from stdnet.utils.populate import populate from .base import StructMixin -dates = list(set(populate('date',100,start=date(2009,6,1),end=date(2010,6,6)))) -values = populate('float',len(dates),start=0,end=1000) +dates = list(set(populate("date", 100, start=date(2009, 6, 1), end=date(2010, 6, 6)))) +values = populate("float", len(dates), start=0, end=1000) class TestZset(StructMixin, test.TestCase): structure = odm.Zset - name = 'zset' - result = [(0.0022,'pluto'), - (0.06,'mercury'), - (0.11,'mars'), - (0.82,'venus'), - (1,'earth'), - (14.6,'uranus'), - (17.2,'neptune'), - (95.2,'saturn'), - (317.8,'juppiter')] - + name = "zset" + result = [ + (0.0022, "pluto"), + (0.06, "mercury"), + (0.11, "mars"), + (0.82, "venus"), + (1, "earth"), + (14.6, "uranus"), + (17.2, "neptune"), + (95.2, "saturn"), + (317.8, "juppiter"), + ] + def create_one(self): l = self.structure() - l.add(1,'earth') - l.add(0.06,'mercury') - l.add(317.8,'juppiter') - l.update(((95.2,'saturn'),\ - (0.82,'venus'),\ - (14.6,'uranus'),\ - (0.11,'mars'), - (17.2,'neptune'), - (0.0022,'pluto'))) + l.add(1, "earth") + l.add(0.06, "mercury") + l.add(317.8, "juppiter") + l.update( + ( + (95.2, "saturn"), + (0.82, "venus"), + (14.6, "uranus"), + (0.11, "mars"), + (17.2, "neptune"), + (0.0022, "pluto"), + ) + ) self.assertEqual(len(l.cache.toadd), 9) self.assertFalse(l.cache.cache) self.assertTrue(l.cache.toadd) self.assertFalse(l.cache.toremove) return l - + def planets(self): models = self.mapper l = models.register(self.create_one()) @@ -55,44 +61,43 @@ def planets(self): size = yield l.size() self.assertEqual(size, 9) yield l - + def test_irange(self): l = yield self.planets() # Get the whole range without the scores r = yield l.irange(withscores=False) self.assertEqual(r, [v[1] for v in self.result]) - + def test_irange_withscores(self): l = yield self.planets() # Get the whole range r = yield l.irange() self.assertEqual(r, self.result) - + 
def test_range(self): l = yield self.planets() r = yield l.range(0.5, 20, withscores=False) - self.assertEqual(r, ['venus', 'earth', 'uranus', 'neptune']) - + self.assertEqual(r, ["venus", "earth", "uranus", "neptune"]) + def test_range_withscores(self): l = yield self.planets() - r = yield l.range(0.5,20) + r = yield l.range(0.5, 20) self.assertTrue(r) k1 = 0.5 for k, v in r: - self.assertTrue(k>=k1) - self.assertTrue(k<=20) + self.assertTrue(k >= k1) + self.assertTrue(k <= 20) k1 = k - + def test_iter(self): - '''test a very simple zset with integer''' + """test a very simple zset with integer""" l = yield self.planets() r = list(l) v = [t[1] for t in self.result] - self.assertEqual(r,v) - + self.assertEqual(r, v) + def test_items(self): - '''test a very simple zset with integer''' + """test a very simple zset with integer""" l = yield self.planets() r = list(l.items()) self.assertEqual(r, [r[1] for r in self.result]) - diff --git a/tests/all/topics/finance.py b/tests/all/topics/finance.py index 1a2ea11..2decd48 100755 --- a/tests/all/topics/finance.py +++ b/tests/all/topics/finance.py @@ -2,20 +2,19 @@ import logging from random import randint +from examples.data import CCYS_TYPES, INSTS_TYPES, finance_data +from examples.models import Fund, Instrument, PortfolioView, Position, UserDefaultView + from stdnet import QuerySetError from stdnet.utils import test -from examples.models import Instrument, Fund, Position, PortfolioView,\ - UserDefaultView -from examples.data import finance_data, INSTS_TYPES, CCYS_TYPES - class TestFinanceApplication(test.TestWrite): data_cls = finance_data models = (Instrument, Fund, Position) def testGetObject(self): - '''Test get method for id and unique field''' + """Test get method for id and unique field""" session = yield self.data.create(self) query = session.query(Instrument) obj = yield query.get(id=2) @@ -25,17 +24,17 @@ def testGetObject(self): self.assertEqual(obj, obj2) def testLen(self): - '''Simply test len of objects greater than zero''' + """Simply test len of objects greater than zero""" session = yield self.data.create(self) objs = yield session.query(Instrument).all() self.assertTrue(len(objs) > 0) def testFilter(self): - '''Test filtering on a model without foreign keys''' + """Test filtering on a model without foreign keys""" yield self.data.create(self) session = self.session() query = session.query(Instrument) - self.async.assertRaises(QuerySetError, query.get, type='equity') + self.async.assertRaises(QuerySetError, query.get, type="equity") tot = 0 for t in INSTS_TYPES: fs = query.filter(type=t) @@ -44,7 +43,7 @@ def testFilter(self): for f in all: count[f.ccy] = count.get(f.ccy, 0) + 1 for c in CCYS_TYPES: - x = count.get(c,0) + x = count.get(c, 0) objs = yield fs.filter(ccy=c).all() y = 0 for obj in objs: @@ -52,21 +51,21 @@ def testFilter(self): tot += 1 self.assertEqual(obj.type, t) self.assertEqual(obj.ccy, c) - self.assertEqual(x,y) + self.assertEqual(x, y) all = query.all() self.assertEqual(tot, len(all)) def testValidation(self): pos = Position(size=10) self.assertFalse(pos.is_valid()) - self.assertEqual(len(pos._dbdata['errors']),3) - self.assertEqual(len(pos._dbdata['cleaned_data']),1) - self.assertTrue('size' in pos._dbdata['cleaned_data']) + self.assertEqual(len(pos._dbdata["errors"]), 3) + self.assertEqual(len(pos._dbdata["cleaned_data"]), 1) + self.assertTrue("size" in pos._dbdata["cleaned_data"]) def testForeignKey(self): - '''Test filtering with foreignkeys''' + """Test filtering with foreignkeys""" session = yield 
self.data.makePositions(self) - query = session.query(Position).load_related('instrument').load_related('fund') + query = session.query(Position).load_related("instrument").load_related("fund") # positions = yield query.all() self.assertTrue(positions) @@ -87,7 +86,7 @@ def testForeignKey(self): instruments = yield session.query(Instrument).all() # for instrument in instruments: - multi.append(instrument.positions.query().load_related('instrument').all()) + multi.append(instrument.positions.query().load_related("instrument").all()) multi = yield self.multi_async(multi) # for instrument, pos in zip(instruments, multi): @@ -107,16 +106,16 @@ def testRelatedManagerFilter(self): flist = [] for pos in positions: fund = pos.fund - n = funds.get(fund.id,0) + 1 + n = funds.get(fund.id, 0) + 1 funds[fund.id] = n if n == 1: flist.append(fund) for fund in flist: - positions = instrument.positions.filter(fund = fund) - self.assertEqual(len(positions),funds[fund.id]) + positions = instrument.positions.filter(fund=fund) + self.assertEqual(len(positions), funds[fund.id]) def testDeleteSimple(self): - '''Test delete on models without related models''' + """Test delete on models without related models""" session = yield self.data.create(self) instruments = session.query(Instrument) funds = session.query(Fund) @@ -128,7 +127,7 @@ def testDeleteSimple(self): self.assertFalse(session.query(Fund).count()) def testDelete(self): - '''Test delete on models with related models''' + """Test delete on models with related models""" # Create Positions which hold foreign keys to Instruments session = yield self.data.makePositions(self) instruments = session.query(Instrument) @@ -138,4 +137,3 @@ def testDelete(self): instruments.delete() self.assertFalse(session.query(Instrument).count()) self.assertFalse(session.query(Position).count()) - diff --git a/tests/all/topics/observer.py b/tests/all/topics/observer.py index 0fef3f6..dd9816f 100644 --- a/tests/all/topics/observer.py +++ b/tests/all/topics/observer.py @@ -1,24 +1,26 @@ from random import randint from time import time -from stdnet.utils import test +from examples.observer import Observable, Observer, update_observers -from examples.observer import Observer, Observable, update_observers +from stdnet.utils import test class ObserverData(test.DataGenerator): - sizes = {'tiny': (2, 5), # observable, observers - 'small': (5, 20), - 'normal': (10, 80), - 'big': (50, 500), - 'huge': (100, 10000)} + sizes = { + "tiny": (2, 5), # observable, observers + "small": (5, 20), + "normal": (10, 80), + "big": (50, 500), + "huge": (100, 10000), + } def generate(self): self.observables, self.observers = self.size class ObserverTest(test.TestWrite): - multipledb = 'redis' + multipledb = "redis" models = (Observer, Observable) data_cls = ObserverData @@ -48,8 +50,8 @@ def setUp(self): # The first observervable is observed by all observers created.add(observables[0]) observer.underlyings.add(observables[0]) - for i in range(randint(1, N-1)): - o = observables[randint(0, N-1)] + for i in range(randint(1, N - 1)): + o = observables[randint(0, N - 1)] created.add(o) observer.underlyings.add(o) yield t.on_result @@ -79,7 +81,7 @@ def test_created(self): self.assertEqual(created, set(observables)) def test_simple_save(self): - '''Save the first observable and check for updates.''' + """Save the first observable and check for updates.""" models = self.mapper obs = self.observables[0] now = time() diff --git a/tests/all/topics/permissions.py b/tests/all/topics/permissions.py index 
56f9f43..3672f6c 100644 --- a/tests/all/topics/permissions.py +++ b/tests/all/topics/permissions.py @@ -1,5 +1,6 @@ from random import choice -from examples.permissions import User, Group, Role, Permission + +from examples.permissions import Group, Permission, Role, User from stdnet import odm from stdnet.utils import test, zip @@ -10,76 +11,75 @@ update = 30 delete = 40 + class MyModel(odm.StdModel): pass class NamesGenerator(test.DataGenerator): - def generate(self): group_size = self.size // 2 self.usernames = self.populate(min_len=5, max_len=20) self.passwords = self.populate(min_len=7, max_len=20) self.groups = self.populate(size=group_size, min_len=5, max_len=10) - class TestPermissions(test.TestCase): models = (User, Group, Role, Permission, MyModel) data_cls = NamesGenerator - + @classmethod def after_setup(cls): d = cls.data models = cls.mapper groups = [] - groups.append(models.group.create_user(username='stdnet', - can_login=False)) + groups.append(models.group.create_user(username="stdnet", can_login=False)) for username, password in zip(d.usernames, d.passwords): - groups.append(models.group.create_user(username=username, - password=password)) + groups.append( + models.group.create_user(username=username, password=password) + ) yield cls.multi_async(groups) session = models.session() groups = yield session.query(Group).all() with models.session().begin() as t: - for group in groups: - group.create_role('family') # create the group-family role - group.create_role('friends') # create the group-friends role + for group in groups: + group.create_role("family") # create the group-family role + group.create_role("friends") # create the group-friends role yield t.on_result - + def random_group(self, *excludes): if excludes: - name = choice(list(set(self.data.usernames)-set(excludes))) + name = choice(list(set(self.data.usernames) - set(excludes))) else: name = choice(self.data.usernames) return self.mapper.group.get(name=name) - + def test_group_query(self): groups = self.mapper.group - cache = groups._meta.dfields['user'].get_cache_name() + cache = groups._meta.dfields["user"].get_cache_name() groups = yield groups.all() for g in groups: self.assertTrue(hasattr(g, cache)) self.assertEqual(g.user.username, g.name) - + def test_create_role(self, name=None): # Create a new role name = name or self.data.random_string() models = self.mapper group = yield self.random_group() - role = yield group.create_role(name) # add a random role + role = yield group.create_role(name) # add a random role self.assertEqual(role.name, name) self.assertEqual(role.owner, group) permission = yield role.add_permission(MyModel, read) self.assertEqual(permission.model_type, MyModel) - self.assertEqual(permission.object_pk, '') + self.assertEqual(permission.object_pk, "") self.assertEqual(permission.operation, read) # # the role should have only one permission permissions = yield role.permissions.all() - self.assertTrue(len(permissions)>=1) + self.assertTrue(len(permissions) >= 1) yield role - + def test_role_assignto_group(self): role = yield self.test_create_role() group = yield self.random_group(role.owner.name) @@ -92,4 +92,3 @@ def test_role_assignto_group(self): # group has a new role roles = yield group.roles.all() self.assertTrue(role in roles) - \ No newline at end of file diff --git a/tests/all/topics/twitter.py b/tests/all/topics/twitter.py index 4153042..04238b9 100755 --- a/tests/all/topics/twitter.py +++ b/tests/all/topics/twitter.py @@ -1,25 +1,25 @@ from datetime import datetime -from random 
import randint, choice +from random import choice, randint -from stdnet import odm -from stdnet.utils import test, zip, populate +from examples.models import Post, User -from examples.models import User, Post +from stdnet import odm +from stdnet.utils import populate, test, zip class TwitterData(test.DataGenerator): - sizes = {'tiny': (10, 5), - 'small': (30, 10), - 'normal': (100, 30), - 'big': (1000, 100), - 'huge': (100000, 1000)} + sizes = { + "tiny": (10, 5), + "small": (30, 10), + "normal": (100, 30), + "big": (1000, 100), + "huge": (100000, 1000), + } def generate(self): size, _ = self.size - self.usernames = self.populate('string', size=size, min_len=5, - max_len=20) - self.passwords = self.populate('string', size=size, min_len=8, - max_len=20) + self.usernames = self.populate("string", size=size, min_len=5, max_len=20) + self.passwords = self.populate("string", size=size, min_len=8, max_len=20) def followers(self): _, max_size = self.size @@ -33,24 +33,23 @@ class TestTwitter(test.TestWrite): def setUp(self): with self.mapper.session().begin() as t: - for username, password in zip(self.data.usernames, - self.data.passwords): + for username, password in zip(self.data.usernames, self.data.passwords): t.add(User(username=username, password=password)) return t.on_result def testMeta(self): following = User.following followers = User.followers - self.assertEqual(following.formodel,User) - self.assertEqual(following.relmodel,User) - self.assertEqual(followers.formodel,User) - self.assertEqual(followers.relmodel,User) + self.assertEqual(following.formodel, User) + self.assertEqual(following.relmodel, User) + self.assertEqual(followers.formodel, User) + self.assertEqual(followers.relmodel, User) self.assertEqual(following.model, followers.model) - self.assertEqual(len(following.model._meta.dfields),3) - self.assertEqual(following.name_relmodel, 'user') - self.assertEqual(following.name_formodel, 'user2') - self.assertEqual(followers.name_relmodel, 'user2') - self.assertEqual(followers.name_formodel, 'user') + self.assertEqual(len(following.model._meta.dfields), 3) + self.assertEqual(following.name_relmodel, "user") + self.assertEqual(following.name_formodel, "user2") + self.assertEqual(followers.name_relmodel, "user2") + self.assertEqual(followers.name_formodel, "user") def testRelated(self): models = self.mapper @@ -62,14 +61,14 @@ def testRelated(self): self.assertEqual(r.user, user1) self.assertEqual(r.user2, user3) followers = user3.followers.query().all() - self.assertEqual(len(followers),1) - self.assertEqual(followers[0],user1) + self.assertEqual(len(followers), 1) + self.assertEqual(followers[0], user1) user2.following.add(user3) followers = list(user3.followers.query()) - self.assertEqual(len(followers),2) + self.assertEqual(len(followers), 2) def testFollowers(self): - '''Add followers to a user''' + """Add followers to a user""" # unwind queryset here since we are going to use it in a double loop models = self.mapper users = yield models.user.query().all() @@ -79,11 +78,11 @@ def testFollowers(self): for user in users: N = self.data.followers() uset = set() - for tofollow in populate('choice', N, choice_from=users): + for tofollow in populate("choice", N, choice_from=users): uset.add(tofollow) user.following.add(tofollow) count.append(len(uset)) - self.assertTrue(user.following.query().count()>0) + self.assertTrue(user.following.query().count() > 0) # for user, N in zip(users, count): all_following = user.following.query() @@ -92,7 +91,7 @@ def testFollowers(self): 
self.assertTrue(user in following.followers.query()) def testFollowersTransaction(self): - '''Add followers to a user''' + """Add followers to a user""" # unwind queryset here since we are going to use it in a double loop models = self.mapper session = models.session() @@ -104,7 +103,7 @@ def testFollowersTransaction(self): self.assertEqual(user.session, session) N = self.data.followers() following = user.following - for tofollow in populate('choice', N, choice_from=users): + for tofollow in populate("choice", N, choice_from=users): following.add(tofollow) yield t.on_result for user in users: @@ -119,7 +118,6 @@ def testMessages(self): ids = [u.id for u in users] id = choice(ids) user = yield models.user.get(id=id) - yield user.newupdate('this is my first message') - yield user.newupdate('and this is another one') + yield user.newupdate("this is my first message") + yield user.newupdate("and this is another one") yield self.async.assertEqual(user.updates.size(), 2) - diff --git a/tests/all/utils/intervals.py b/tests/all/utils/intervals.py index 83f79ab..aed9a92 100644 --- a/tests/all/utils/intervals.py +++ b/tests/all/utils/intervals.py @@ -1,70 +1,68 @@ -from stdnet.utils import test, Interval, Intervals, pickle +from stdnet.utils import Interval, Intervals, pickle, test class TestInterval(test.TestCase): - def intervals(self): - a = Interval(4,6) - b = Interval(8,10) - intervals = Intervals((b,a)) - self.assertEqual(len(intervals),2) - self.assertEqual(intervals[0],a) - self.assertEqual(intervals[1],b) + a = Interval(4, 6) + b = Interval(8, 10) + intervals = Intervals((b, a)) + self.assertEqual(len(intervals), 2) + self.assertEqual(intervals[0], a) + self.assertEqual(intervals[1], b) return intervals - + def testSimple(self): - a = Interval(4,6) - self.assertEqual(a.start,4) - self.assertEqual(a.end,6) - self.assertEqual(tuple(a),(4,6)) + a = Interval(4, 6) + self.assertEqual(a.start, 4) + self.assertEqual(a.end, 6) + self.assertEqual(tuple(a), (4, 6)) self.assertRaises(ValueError, Interval, 6, 3) - + def testPickle(self): - a = Interval(4,6) + a = Interval(4, 6) s = pickle.dumps(a) b = pickle.loads(s) - self.assertEqual(type(b),tuple) - self.assertEqual(len(b),2) - self.assertEqual(b[0],4) - self.assertEqual(b[1],6) - + self.assertEqual(type(b), tuple) + self.assertEqual(len(b), 2) + self.assertEqual(b[0], 4) + self.assertEqual(b[1], 6) + def testPickleIntervals(self): a = self.intervals() s = pickle.dumps(a) b = pickle.loads(s) - self.assertEqual(type(b),list) - self.assertEqual(len(b),len(a)) - + self.assertEqual(type(b), list) + self.assertEqual(len(b), len(a)) + def testmultiple(self): i = self.intervals() - a = Interval(20,30) + a = Interval(20, 30) i.append(a) - self.assertEqual(len(i),3) - self.assertEqual(i[-1],a) - i.append(Interval(18,21)) - self.assertEqual(len(i),3) - self.assertNotEqual(i[-1],a) - self.assertEqual(i[-1].start,18) - self.assertEqual(i[-1].end,30) - i.append(Interval(8,10)) - self.assertEqual(len(i),3) - self.assertEqual(i[-2].start,8) - self.assertEqual(i[-2].end,10) - i.append(Interval(8,25)) - self.assertEqual(len(i),2) - self.assertEqual(i[-1].start,8) - self.assertEqual(i[-1].end,30) - i.append(Interval(1,40)) - self.assertEqual(len(i),1) - self.assertEqual(i[0].start,1) - self.assertEqual(i[0].end,40) - + self.assertEqual(len(i), 3) + self.assertEqual(i[-1], a) + i.append(Interval(18, 21)) + self.assertEqual(len(i), 3) + self.assertNotEqual(i[-1], a) + self.assertEqual(i[-1].start, 18) + self.assertEqual(i[-1].end, 30) + i.append(Interval(8, 10)) + 
self.assertEqual(len(i), 3) + self.assertEqual(i[-2].start, 8) + self.assertEqual(i[-2].end, 10) + i.append(Interval(8, 25)) + self.assertEqual(len(i), 2) + self.assertEqual(i[-1].start, 8) + self.assertEqual(i[-1].end, 30) + i.append(Interval(1, 40)) + self.assertEqual(len(i), 1) + self.assertEqual(i[0].start, 1) + self.assertEqual(i[0].end, 40) + def testAppendtuple(self): i = self.intervals() - i.append((18,21)) - self.assertEqual(len(i),3) - self.assertEqual(i[-1].start,18) - self.assertEqual(i[-1].end,21) + i.append((18, 21)) + self.assertEqual(len(i), 3) + self.assertEqual(i[-1].start, 18) + self.assertEqual(i[-1].end, 21) self.assertRaises(TypeError, i.append, 3) - self.assertRaises(ValueError, i.append, (8,2)) - \ No newline at end of file + self.assertRaises(ValueError, i.append, (8, 2)) diff --git a/tests/all/utils/tools.py b/tests/all/utils/tools.py index 58b8ce2..6903010 100755 --- a/tests/all/utils/tools.py +++ b/tests/all/utils/tools.py @@ -1,37 +1,43 @@ import time from datetime import date, datetime +from examples.models import Statistics3 + import stdnet from stdnet import odm +from stdnet.utils import ( + _format_int, + addmul_number_dicts, + date2timestamp, + encoders, + grouper, + populate, + test, + timestamp2date, + to_bytes, + to_string, +) from stdnet.utils.version import get_git_changeset -from stdnet.utils import test, encoders, to_bytes, to_string -from stdnet.utils import date2timestamp, timestamp2date,\ - addmul_number_dicts, grouper,\ - _format_int, populate - -from examples.models import Statistics3 class TestUtils(test.TestCase): multipledb = False model = Statistics3 - + def __testNestedJasonValue(self): - data = {'data':1000, - 'folder1':{'folder11':1, - 'folder12':2, - '':'home'}} + data = {"data": 1000, "folder1": {"folder11": 1, "folder12": 2, "": "home"}} session = self.session() with session.begin(): - session.add(self.model(name='foo',data=data)) - obj = session.query(self.model).get(id = 1) - self.assertEqual(\ - nested_json_value(obj,'data__folder1__folder11',odm.JSPLITTER),1) - self.assertEqual(\ - nested_json_value(obj,'data__folder1__folder12',odm.JSPLITTER),2) - self.assertEqual(\ - nested_json_value(obj,'data__folder1',odm.JSPLITTER),'home') - + session.add(self.model(name="foo", data=data)) + obj = session.query(self.model).get(id=1) + self.assertEqual( + nested_json_value(obj, "data__folder1__folder11", odm.JSPLITTER), 1 + ) + self.assertEqual( + nested_json_value(obj, "data__folder1__folder12", odm.JSPLITTER), 2 + ) + self.assertEqual(nested_json_value(obj, "data__folder1", odm.JSPLITTER), "home") + def test_date2timestamp(self): t1 = datetime.now() ts1 = date2timestamp(t1) @@ -39,100 +45,106 @@ def test_date2timestamp(self): t1 = date.today() ts1 = date2timestamp(t1) t = timestamp2date(ts1) - self.assertEqual(t.date(),t1) - self.assertEqual(t.hour,0) - self.assertEqual(t.minute,0) - self.assertEqual(t.second,0) - self.assertEqual(t.microsecond,0) - + self.assertEqual(t.date(), t1) + self.assertEqual(t.hour, 0) + self.assertEqual(t.minute, 0) + self.assertEqual(t.second, 0) + self.assertEqual(t.microsecond, 0) + def test_addmul_number_dicts(self): - d1 = {'bla': 2.5, 'foo': 1.1} - d2 = {'bla': -2, 'foo': -0.3} - r = addmul_number_dicts(((2,d1),(-1,d2))) - self.assertEqual(len(r),2) - self.assertAlmostEqual(r['bla'],7) - self.assertAlmostEqual(r['foo'],2.5) - + d1 = {"bla": 2.5, "foo": 1.1} + d2 = {"bla": -2, "foo": -0.3} + r = addmul_number_dicts(((2, d1), (-1, d2))) + self.assertEqual(len(r), 2) + self.assertAlmostEqual(r["bla"], 7) + 
self.assertAlmostEqual(r["foo"], 2.5) + def test_addmul_number_dicts2(self): - d1 = {'bla': 2.5, 'foo': 1.1} - d2 = {'bla': -2, 'foo': -0.3, 'moon': 8.5} - r = addmul_number_dicts(((2,d1),(-1,d2))) - self.assertEqual(len(r),2) - self.assertEqual(r['bla'],7) - self.assertEqual(r['foo'],2.5) - + d1 = {"bla": 2.5, "foo": 1.1} + d2 = {"bla": -2, "foo": -0.3, "moon": 8.5} + r = addmul_number_dicts(((2, d1), (-1, d2))) + self.assertEqual(len(r), 2) + self.assertEqual(r["bla"], 7) + self.assertEqual(r["foo"], 2.5) + def test_addmul_number_dicts3(self): - series = [(1.0, {'carry1w': 0.08903324115987132, - 'pv': '17.7', - 'carry3m': 1.02, - 'carry6m': 1.9645094151419826, - 'carry1y': 3.7291316215073422, - 'irdelta': '#Err'}), - (1.0, {'carry1w': 0.025649796255470036, - 'pv': 12.1, - 'carry3m': '-0.61', - 'carry6m': 1.77763873433023, - 'carry1y': 5.566080890214712, - 'irdelta': '#Err'}), - (-1.0, {'carry1w': '#Err', - 'pv': 18.1, - 'carry3m': -0.04, - 'irdelta': 1})] + series = [ + ( + 1.0, + { + "carry1w": 0.08903324115987132, + "pv": "17.7", + "carry3m": 1.02, + "carry6m": 1.9645094151419826, + "carry1y": 3.7291316215073422, + "irdelta": "#Err", + }, + ), + ( + 1.0, + { + "carry1w": 0.025649796255470036, + "pv": 12.1, + "carry3m": "-0.61", + "carry6m": 1.77763873433023, + "carry1y": 5.566080890214712, + "irdelta": "#Err", + }, + ), + (-1.0, {"carry1w": "#Err", "pv": 18.1, "carry3m": -0.04, "irdelta": 1}), + ] r = addmul_number_dicts(series) self.assertEqual(len(r), 2) - self.assertAlmostEqual(r['pv'], 11.7) - self.assertAlmostEqual(r['carry3m'], 0.45) - + self.assertAlmostEqual(r["pv"], 11.7) + self.assertAlmostEqual(r["carry3m"], 0.45) + def test_addmul_nested_dicts(self): - d1 = {'bla': {'bla1': 2.5}, 'foo': 1.1} - d2 = {'bla': {'bla1': -2}, 'foo': -0.3, 'moon': 8.5} - r = addmul_number_dicts(((2,d1),(-1,d2))) - self.assertEqual(len(r),2) - self.assertEqual(r['bla']['bla1'],7) - self.assertEqual(r['foo'],2.5) - - + d1 = {"bla": {"bla1": 2.5}, "foo": 1.1} + d2 = {"bla": {"bla1": -2}, "foo": -0.3, "moon": 8.5} + r = addmul_number_dicts(((2, d1), (-1, d2))) + self.assertEqual(len(r), 2) + self.assertEqual(r["bla"]["bla1"], 7) + self.assertEqual(r["foo"], 2.5) + + class testFunctions(test.TestCase): - def testGrouper(self): - r = grouper(2,[1,2,3,4,5,6,7]) - self.assertFalse(hasattr(r,'__len__')) - self.assertEqual(list(r),[(1,2),(3,4),(5,6),(7,None)]) - r = grouper(3,'abcdefg','x') - self.assertFalse(hasattr(r,'__len__')) - self.assertEqual(list(r),[('a','b','c'),('d','e','f'),('g','x','x')]) - + r = grouper(2, [1, 2, 3, 4, 5, 6, 7]) + self.assertFalse(hasattr(r, "__len__")) + self.assertEqual(list(r), [(1, 2), (3, 4), (5, 6), (7, None)]) + r = grouper(3, "abcdefg", "x") + self.assertFalse(hasattr(r, "__len__")) + self.assertEqual(list(r), [("a", "b", "c"), ("d", "e", "f"), ("g", "x", "x")]) + def testFormatInt(self): - self.assertEqual(_format_int(4500),'4,500') - self.assertEqual(_format_int(4500780),'4,500,780') - self.assertEqual(_format_int(500),'500') - self.assertEqual(_format_int(-780),'-780') - self.assertEqual(_format_int(-4500780),'-4,500,780') - + self.assertEqual(_format_int(4500), "4,500") + self.assertEqual(_format_int(4500780), "4,500,780") + self.assertEqual(_format_int(500), "500") + self.assertEqual(_format_int(-780), "-780") + self.assertEqual(_format_int(-4500780), "-4,500,780") + def testPopulateIntegers(self): - data = populate('integer', size = 33) - self.assertEqual(len(data),33) + data = populate("integer", size=33) + self.assertEqual(len(data), 33) for d in data: - 
self.assertTrue(isinstance(d,int))
-
+            self.assertTrue(isinstance(d, int))
+
     def testAbstarctEncoder(self):
         e = encoders.Encoder()
-        self.assertRaises(NotImplementedError , e.dumps, 'bla')
-        self.assertRaises(NotImplementedError , e.loads, 'bla')
-
+        self.assertRaises(NotImplementedError, e.dumps, "bla")
+        self.assertRaises(NotImplementedError, e.loads, "bla")
+
     def test_to_bytes(self):
-        self.assertEqual(to_bytes(b'ciao'),b'ciao')
-        b = b'perch\xc3\xa9'
-        u = b.decode('utf-8')
-        l = u.encode('latin')
-        self.assertEqual(to_bytes(b,'latin'),l)
-        self.assertEqual(to_string(l,'latin'),u)
-        self.assertEqual(to_bytes(1), b'1')
-
+        self.assertEqual(to_bytes(b"ciao"), b"ciao")
+        b = b"perch\xc3\xa9"
+        u = b.decode("utf-8")
+        l = u.encode("latin")
+        self.assertEqual(to_bytes(b, "latin"), l)
+        self.assertEqual(to_string(l, "latin"), u)
+        self.assertEqual(to_bytes(1), b"1")
+
     def test_git_version(self):
         g = get_git_changeset()
         # In travis this is None.
         # TODO: better test on this
-        #self.assertTrue(g)
-
-
\ No newline at end of file
+        # self.assertTrue(g)
diff --git a/tests/all/utils/zset.py b/tests/all/utils/zset.py
index bb75720..b3586fa 100644
--- a/tests/all/utils/zset.py
+++ b/tests/all/utils/zset.py
@@ -1,55 +1,54 @@
 from random import randint
+
 from stdnet.utils import test
 from stdnet.utils.zset import zset
+

 class TestPythonZset(test.TestCase):
-
     def test_add(self):
         s = zset()
-        s.add(3, 'ciao')
-        s.add(4, 'bla')
+        s.add(3, "ciao")
+        s.add(4, "bla")
         self.assertEqual(len(s), 2)
-        s.add(-1, 'bla')
+        s.add(-1, "bla")
         self.assertEqual(len(s), 2)
         data = list(s)
-        self.assertEqual(data, ['bla', 'ciao'])
-
+        self.assertEqual(data, ["bla", "ciao"])
+
     def test_rank(self):
         s = zset()
-        s.add(3, 'ciao')
-        s.add(4, 'bla')
-        s.add(2, 'foo')
-        s.add(20, 'pippo')
-        s.add(-1, 'bla')
+        s.add(3, "ciao")
+        s.add(4, "bla")
+        s.add(2, "foo")
+        s.add(20, "pippo")
+        s.add(-1, "bla")
         self.assertEqual(len(s), 4)
-        self.assertEqual(s.rank('bla'), 0)
-        self.assertEqual(s.rank('foo'), 1)
-        self.assertEqual(s.rank('ciao'), 2)
-        self.assertEqual(s.rank('pippo'), 3)
-        self.assertEqual(s.rank('xxxx'), None)
-
+        self.assertEqual(s.rank("bla"), 0)
+        self.assertEqual(s.rank("foo"), 1)
+        self.assertEqual(s.rank("ciao"), 2)
+        self.assertEqual(s.rank("pippo"), 3)
+        self.assertEqual(s.rank("xxxx"), None)
+
     def test_update(self):
-        string = test.populate('string', size=100)
-        values = test.populate('float', size=100)
+        string = test.populate("string", size=100)
+        values = test.populate("float", size=100)
         s = zset()
-        s.update(zip(values,string))
+        s.update(zip(values, string))
         self.assertTrue(s)
         prev = None
         for score, _ in s.items():
             if prev is not None:
-                self.assertTrue(score>=prev)
+                self.assertTrue(score >= prev)
             prev = score
         return s
-
+
     def test_remove(self):
         s = self.test_update()
         values = list(s)
         while values:
-            index = randint(0, len(values)-1)
+            index = randint(0, len(values) - 1)
             val = values.pop(index)
             self.assertTrue(val in s)
             self.assertNotEqual(s.remove(val), None)
             self.assertFalse(val in s)
         self.assertFalse(s)
-
-
\ No newline at end of file

From 220fa7b4fe230a1118e5590d218c9d2f88ef33bf Mon Sep 17 00:00:00 2001
From: Luca
Date: Thu, 3 Dec 2020 21:23:31 +0000
Subject: [PATCH 2/5] remove old travis stuff

---
 .github/workflows/build.yml |  8 ++++----
 Makefile                    | 11 ++++++++++-
 README.rst                  | 24 +-----------------------
 3 files changed, 15 insertions(+), 28 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index 3796015..eebb95d 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -17,12 +17,12 @@ jobs:
       COVERALLS_REPO_TOKEN: ${{ secrets.COVERALLS_REPO_TOKEN }}
     strategy:
       matrix:
-        python-version: [3.6, 3.7, 3.8, 3.9]
+        python-version: [3.8, 3.9]

     steps:
       - uses: actions/checkout@v2
-      - name: run postgres
-        run: make postgresql
+      - name: run redis
+        run: make redis
       - name: Set up Python ${{ matrix.python-version }}
         uses: actions/setup-python@v2
         with:
@@ -30,7 +30,7 @@
       - name: Install dependencies
         run: make install
       - name: run lint
-        run: make test-lint
+        run: make lint-check
       - name: run tests
         run: make test
       - name: upload coverage
diff --git a/Makefile b/Makefile
index 2bc16be..87706e3 100644
--- a/Makefile
+++ b/Makefile
@@ -21,7 +21,7 @@ clean: ## remove python cache files


 cloc: ## Count lines of code
-	cloc --exclude-dir=tools,venv,node-modules,build,.pytest_cache,.mypy_cache,target .
+	cloc --exclude-dir=build,venv,.venv,.pytest_cache,.mypy_cache .


 install: ## install python dependencies in venv
@@ -34,3 +34,12 @@ lint: ## run linters
 	isort .
 	./dev/run-black.sh
 	flake8
+
+lint-check: ## run linters in check mode
+	flake8
+	isort . --check
+	./dev/run-black.sh --check
+
+
+redis: ## run redis for testing
+	docker run --rm --network=host --name=stdnet -d redis:6
diff --git a/README.rst b/README.rst
index 191711e..62bedbc 100755
--- a/README.rst
+++ b/README.rst
@@ -7,25 +7,6 @@ and instances of those classes with **items** in their corresponding collections
 Collections and items are different for different backend databases but are
 treated in the same way in the python language domain.

-:Master CI: |master-build|_ |coverage|
-:Dev CI: |dev-build|_ |coverage-dev|
-:Documentation: http://pythonhosted.org/python-stdnet/
-:Dowloads: http://pypi.python.org/pypi/python-stdnet/
-:Source: https://github.com/lsbardel/python-stdnet
-:Platforms: Linux, OS X, Windows. Python 2.6, 2.7, 3.2, 3.3, pypy_
-:Mailing List: https://groups.google.com/group/python-stdnet
-:Keywords: server, database, redis, odm
-
-
-.. |master-build| image:: https://secure.travis-ci.org/lsbardel/python-stdnet.png?branch=master
-.. _master-build: http://travis-ci.org/lsbardel/python-stdnet
-.. |dev-build| image:: https://secure.travis-ci.org/lsbardel/python-stdnet.png?branch=dev
-.. _dev-build: http://travis-ci.org/lsbardel/python-stdnet
-.. |coverage| image:: https://coveralls.io/repos/lsbardel/python-stdnet/badge.png?branch=master
-   :target: https://coveralls.io/r/lsbardel/python-stdnet?branch=master
-.. |coverage-dev| image:: https://coveralls.io/repos/lsbardel/python-stdnet/badge.png?branch=dev
-   :target: https://coveralls.io/r/lsbardel/python-stdnet?branch=dev
-
 Contents
 ~~~~~~~~~~~~~~~
@@ -49,10 +30,7 @@ Features

 Requirements
 =================
-* Python 2.6, 2.7, 3.2, 3.3 and pypy_. Single code-base.
-* redis-py_ for redis backend.
-* Optional pulsar_ when using the asynchronous connections or the test suite.
-* You need access to a Redis_ server version 2.6 or above.
+* Python 3.6 and up Philosophy From 5b11514aaf0f65ec4840c84b2218d868bdd915cc Mon Sep 17 00:00:00 2001 From: Luca Date: Thu, 3 Dec 2020 21:42:11 +0000 Subject: [PATCH 3/5] lots to do --- .gitignore | 2 + setup.cfg | 2 +- stdnet/backends/__init__.py | 20 ++--- stdnet/utils/__init__.py | 31 +------ stdnet/utils/fallbacks/__init__.py | 0 stdnet/utils/fallbacks/_collections.py | 109 ------------------------- stdnet/utils/fallbacks/_importlib.py | 36 -------- stdnet/utils/fallbacks/py2/__init__.py | 9 -- 8 files changed, 14 insertions(+), 195 deletions(-) delete mode 100755 stdnet/utils/fallbacks/__init__.py delete mode 100755 stdnet/utils/fallbacks/_collections.py delete mode 100755 stdnet/utils/fallbacks/_importlib.py delete mode 100644 stdnet/utils/fallbacks/py2/__init__.py diff --git a/.gitignore b/.gitignore index 7ef1d48..6e97785 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,8 @@ *.def dist venv +.venv +.vscode __pycache__ extensions/src/cparser.cpp build diff --git a/setup.cfg b/setup.cfg index 8f6e244..e0ce8a2 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [flake8] -exclude = __pycache__,.eggs,venv,build,dist,docs,dev +exclude = __pycache__,.eggs,venv,build,dist,docs,dev,stdnet/apps/searchengine/processors/metaphone.py max-line-length = 88 ignore = A001,A002,A003,C815,C812,W503,E203 diff --git a/stdnet/backends/__init__.py b/stdnet/backends/__init__.py index 0ed395c..084beee 100755 --- a/stdnet/backends/__init__.py +++ b/stdnet/backends/__init__.py @@ -2,14 +2,6 @@ from collections import namedtuple from inspect import isgenerator -try: - from pulsar import maybe_async as async -except ImportError: # pragma noproxy - - def async(gen): - raise NotImplementedError - - from stdnet.utils import ( int_or_float, iteritems, @@ -32,7 +24,6 @@ def async(gen): "range_lookups", "getdb", "settings", - "async", ] @@ -46,8 +37,13 @@ def async(gen): session_data = namedtuple("session_data", "meta dirty deletes queries structures") session_result = namedtuple("session_result", "meta results") -pass_through = lambda x: x -str_lower_case = lambda x: to_string(x).lower() + +def pass_through(x): + return x + + +def str_lower_case(x: str): + return to_string(x).lower() range_lookups = { @@ -285,7 +281,7 @@ def structure(self, instance, client=None): def execute(self, result, callback=None): if self.is_async(): - result = async(result) + # result = async(result) if callback: return result.add_callback(callback) else: diff --git a/stdnet/utils/__init__.py b/stdnet/utils/__init__.py index b429472..2574ea4 100755 --- a/stdnet/utils/__init__.py +++ b/stdnet/utils/__init__.py @@ -1,32 +1,7 @@ from collections import Mapping -from inspect import istraceback -from itertools import chain +from itertools import chain, zip_longest from uuid import uuid4 -from .py2py3 import * - -if ispy3k: # pragma: no cover - import pickle - - unichr = chr - - def raise_error_trace(err, traceback): - if istraceback(traceback): - raise err.with_traceback(traceback) - else: - raise err - - -else: # pragma: no cover - import cPickle as pickle - - unichr = unichr - from .fallbacks.py2 import raise_error_trace - -from .dates import * -from .jsontools import * -from .populate import populate - def gen_unique_id(short=True): id = str(uuid4()) @@ -37,7 +12,7 @@ def gen_unique_id(short=True): def iterpair(iterable): if isinstance(iterable, Mapping): - return iteritems(iterable) + return iterable.items() else: return iterable @@ -93,7 +68,7 @@ def flat2d(iterable): def _flatzsetdict(kwargs): - for k, v in iteritems(kwargs): + for 
k, v in kwargs.items(): yield v yield k diff --git a/stdnet/utils/fallbacks/__init__.py b/stdnet/utils/fallbacks/__init__.py deleted file mode 100755 index e69de29..0000000 diff --git a/stdnet/utils/fallbacks/_collections.py b/stdnet/utils/fallbacks/_collections.py deleted file mode 100755 index 636ac1e..0000000 --- a/stdnet/utils/fallbacks/_collections.py +++ /dev/null @@ -1,109 +0,0 @@ -from UserDict import DictMixin - -__all__ = ["OrderedDict"] - - -class OrderedDict(dict, DictMixin): - """Drop-in substitute for Py2.7's new collections.OrderedDict. - The recipe has big-oh performance that matches regular dictionaries - (amortized O(1) insertion/deletion/lookup and O(n) - iteration/repr/copy/equality_testing). - - From http://code.activestate.com/recipes/576693/""" - - def __init__(self, *args, **kwds): - if len(args) > 1: - raise TypeError("expected at most 1 arguments, got %d" % len(args)) - try: - self.__end - except AttributeError: - self.clear() - self.update(*args, **kwds) - - def clear(self): - self.__end = end = [] - end += [None, end, end] # sentinel node for doubly linked list - self.__map = {} # key --> [key, prev, next] - dict.clear(self) - - def __setitem__(self, key, value): - if key not in self: - end = self.__end - curr = end[1] - curr[2] = end[1] = self.__map[key] = [key, curr, end] - dict.__setitem__(self, key, value) - - def __delitem__(self, key): - dict.__delitem__(self, key) - key, prev, next = self.__map.pop(key) - prev[2] = next - next[1] = prev - - def __iter__(self): - end = self.__end - curr = end[2] - while curr is not end: - yield curr[0] - curr = curr[2] - - def __reversed__(self): - end = self.__end - curr = end[1] - while curr is not end: - yield curr[0] - curr = curr[1] - - def popitem(self, last=True): - if not self: - raise KeyError("dictionary is empty") - if last: - key = reversed(self).next() - else: - key = iter(self).next() - value = self.pop(key) - return key, value - - def __reduce__(self): - items = [[k, self[k]] for k in self] - tmp = self.__map, self.__end - del self.__map, self.__end - inst_dict = vars(self).copy() - self.__map, self.__end = tmp - if inst_dict: - return (self.__class__, (items,), inst_dict) - return self.__class__, (items,) - - def keys(self): - return list(self) - - setdefault = DictMixin.setdefault - update = DictMixin.update - pop = DictMixin.pop - values = DictMixin.values - items = DictMixin.items - iterkeys = DictMixin.iterkeys - itervalues = DictMixin.itervalues - iteritems = DictMixin.iteritems - - def __repr__(self): - if not self: - return "%s()" % (self.__class__.__name__,) - return "%s(%r)" % (self.__class__.__name__, self.items()) - - def copy(self): - return self.__class__(self) - - @classmethod - def fromkeys(cls, iterable, value=None): - d = cls() - for key in iterable: - d[key] = value - return d - - def __eq__(self, other): - if isinstance(other, OrderedDict): - return len(self) == len(other) and self.items() == other.items() - return dict.__eq__(self, other) - - def __ne__(self, other): - return not self == other diff --git a/stdnet/utils/fallbacks/_importlib.py b/stdnet/utils/fallbacks/_importlib.py deleted file mode 100755 index fc92331..0000000 --- a/stdnet/utils/fallbacks/_importlib.py +++ /dev/null @@ -1,36 +0,0 @@ -# Taken from Python 2.7 -import sys - - -def _resolve_name(name, package, level): - """Return the absolute name of the module to be imported.""" - if not hasattr(package, "rindex"): - raise ValueError("'package' not set to a string") - dot = len(package) - for x in xrange(level, 1, -1): - 
diff --git a/stdnet/utils/fallbacks/_importlib.py b/stdnet/utils/fallbacks/_importlib.py
deleted file mode 100755
index fc92331..0000000
--- a/stdnet/utils/fallbacks/_importlib.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Taken from Python 2.7
-import sys
-
-
-def _resolve_name(name, package, level):
-    """Return the absolute name of the module to be imported."""
-    if not hasattr(package, "rindex"):
-        raise ValueError("'package' not set to a string")
-    dot = len(package)
-    for x in xrange(level, 1, -1):
-        try:
-            dot = package.rindex(".", 0, dot)
-        except ValueError:
-            raise ValueError("attempted relative import beyond top-level " "package")
-    return "%s.%s" % (package[:dot], name)
-
-
-def import_module(name, package=None):
-    """Import a module.
-
-    The 'package' argument is required when performing a relative import. It
-    specifies the package to use as the anchor point from which to resolve the
-    relative import to an absolute import.
-
-    """
-    if name.startswith("."):
-        if not package:
-            raise TypeError("relative imports require the 'package' argument")
-        level = 0
-        for character in name:
-            if character != ".":
-                break
-            level += 1
-        name = _resolve_name(name[level:], package, level)
-    __import__(name)
-    return sys.modules[name]
diff --git a/stdnet/utils/fallbacks/py2/__init__.py b/stdnet/utils/fallbacks/py2/__init__.py
deleted file mode 100644
index d1266b1..0000000
--- a/stdnet/utils/fallbacks/py2/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-# This is for python 2.x
-from inspect import istraceback
-
-
-def raise_error_trace(err, traceback):
-    if istraceback(traceback):
-        raise err.__class__, err, traceback
-    else:
-        raise err.__class__, err, None

From 1b841260a32c6b59ed5654bb9b745f4c8d40941b Mon Sep 17 00:00:00 2001
From: Luca
Date: Thu, 3 Dec 2020 22:00:50 +0000
Subject: [PATCH 4/5] x1

---
 clean.py    | 42 ------------------------------
 covrun.py   | 10 --------
 runtests.py | 74 -----------------------------------------------------
 setup.cfg   |  2 +-
 setup.py    | 35 +-------------------------
 5 files changed, 2 insertions(+), 161 deletions(-)
 delete mode 100644 clean.py
 delete mode 100644 covrun.py
 delete mode 100755 runtests.py

diff --git a/clean.py b/clean.py
deleted file mode 100644
index fc981fe..0000000
--- a/clean.py
+++ /dev/null
@@ -1,42 +0,0 @@
-import os
-import shutil
-
-
-def rmgeneric(path, __func__):
-    try:
-        __func__(path)
-        # print 'Removed ', path
-        return 1
-    except OSError as e:
-        print("Could not remove {0}, {1}".format(path, e))
-        return 0
-
-
-def rmfiles(path, ext=None, rmcache=True):
-    if not os.path.isdir(path):
-        return 0
-    trem = 0
-    tall = 0
-    files = os.listdir(path)
-    for f in files:
-        fullpath = os.path.join(path, f)
-        if os.path.isfile(fullpath):
-            sf = f.split(".")
-            if len(sf) == 2:
-                if ext == None or sf[1] == ext:
-                    tall += 1
-                    trem += rmgeneric(fullpath, os.remove)
-        elif f == "__pycache__" and rmcache:
-            shutil.rmtree(fullpath)
-            tall += 1
-        elif os.path.isdir(fullpath):
-            r, ra = rmfiles(fullpath, ext)
-            trem += r
-            tall += ra
-    return trem, tall
-
-
-if __name__ == "__main__":
-    path = os.curdir
-    removed, allfiles = rmfiles(path, "pyc")
-    print("removed {0} pyc files out of {1}".format(removed, allfiles))
diff --git a/covrun.py b/covrun.py
deleted file mode 100644
index 515903e..0000000
--- a/covrun.py
+++ /dev/null
@@ -1,10 +0,0 @@
-import os
-import sys
-
-from runtests import run
-
-if __name__ == "__main__":
-    if sys.version_info > (3, 3):
-        run(coverage=True, coveralls=True)
-    else:
-        run()
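
Editor's note: clean.py and covrun.py disappear without a direct replacement in this patch; presumably the Makefile added earlier in the series takes over these chores. For reference, a pathlib sketch of what clean.py did (hypothetical, not something this series adds):

    # Hypothetical modern equivalent of the deleted clean.py: remove
    # stray *.pyc files and __pycache__ directories under the cwd.
    import shutil
    from pathlib import Path

    for pyc in Path(".").rglob("*.pyc"):
        pyc.unlink()
    for cache in Path(".").rglob("__pycache__"):
        shutil.rmtree(cache, ignore_errors=True)
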
diff --git a/runtests.py b/runtests.py
deleted file mode 100755
index 910891d..0000000
--- a/runtests.py
+++ /dev/null
@@ -1,74 +0,0 @@
-#!/usr/bin/env python
-"""Stdnet asynchronous test suite. Requires pulsar."""
-import os
-import sys
-from multiprocessing import current_process
-
-## This is for dev environment with pulsar and dynts.
-## If not available, some tests won't run
-p = os.path
-dir = p.dirname(p.dirname(p.abspath(__file__)))
-try:
-    import pulsar
-except ImportError:
-    pdir = p.join(dir, "pulsar")
-    if os.path.isdir(pdir):
-        sys.path.append(pdir)
-    import pulsar
-
-from pulsar.apps.test import TestSuite
-from pulsar.apps.test.plugins import bench, profile
-from pulsar.utils.path import Path
-
-#
-try:
-    import dynts
-except ImportError:
-    pdir = p.join(dir, "dynts")
-    if os.path.isdir(pdir):
-        sys.path.append(pdir)
-    try:
-        import dynts
-    except ImportError:
-        pass
-
-
-def run(**params):
-    args = params.get("argv", sys.argv)
-    if "--coverage" in args or params.get("coverage"):
-        import coverage
-
-        p = current_process()
-        p._coverage = coverage.coverage(data_suffix=True)
-        p._coverage.start()
-    runtests(**params)
-
-
-def runtests(**params):
-    import stdnet
-    from stdnet.utils import test
-
-    #
-    strip_dirs = [Path(stdnet.__file__).parent.parent, os.getcwd()]
-    #
-    suite = TestSuite(
-        description="Stdnet Asynchronous test suite",
-        modules=("tests.all",),
-        plugins=(test.StdnetPlugin(), bench.BenchMark(), profile.Profile()),
-        **params
-    )
-    suite.bind_event("tests", test.create_tests)
-    suite.start()
-    #
-    if suite.cfg.coveralls:
-        from pulsar.utils.cov import coveralls
-
-        coveralls(
-            strip_dirs=strip_dirs,
-            stream=suite.stream,
-            repo_token="ZQinNe5XNbzQ44xYGTljP8R89jrQ5xTKB",
-        )
-
-
-if __name__ == "__main__":
-    run()
diff --git a/setup.cfg b/setup.cfg
index e0ce8a2..2563a3d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,5 +1,5 @@
 [flake8]
-exclude = __pycache__,.eggs,venv,build,dist,docs,dev,stdnet/apps/searchengine/processors/metaphone.py
+exclude = __pycache__,.eggs,venv,build,dist,docs,dev,tests,stdnet/apps/searchengine/processors/metaphone.py
 max-line-length = 88
 ignore = A001,A002,A003,C815,C812,W503,E203
diff --git a/setup.py b/setup.py
index f24bfac..dbab5bf 100644
--- a/setup.py
+++ b/setup.py
@@ -5,9 +5,6 @@
 
 from setuptools import setup
 
-if sys.version_info < (2, 6):
-    raise Exception("stdnet requires Python 2.6 or higher.")
-
 package_name = "stdnet"
 package_fullname = "python-%s" % package_name
 root_dir = os.path.split(os.path.abspath(__file__))[0]
@@ -22,17 +19,6 @@ def get_module():
 
 mod = get_module()
 
-# Try to import lib build
-# try:
-#     from extensions.setup import libparams, BuildFailed
-# except ImportError:
-#     libparams = None
-libparams = False
-
-
-def read(fname):
-    return open(os.path.join(root_dir, fname)).read()
-
 
 def requirements():
     req = read("requirements.txt").replace("\r", "").split("\n")
@@ -141,24 +127,5 @@ def status_msgs(*msgs):
     print("*" * 75)
 
 
-if libparams is False:
-    run_setup()
-elif libparams is None:
-    status_msgs(
-        "WARNING: C extensions could not be compiled, " "Cython is not installed."
-    )
+if __name__ == "__main__":
     run_setup()
-    status_msgs("Plain-Python build succeeded.")
-else:
-    try:
-        run_setup(libparams)
-    except BuildFailed as exc:
-        status_msgs(
-            exc.msg,
-            "WARNING: C extensions could not be compiled, "
-            + "speedups are not enabled.",
-            "Failure information, if any, is above.",
-            "Retrying the build without C extensions now.",
-        )
-        run_setup()
-        status_msgs("Plain-Python build succeeded.")
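
Editor's note: one thing to watch in the setup.py hunk above: it deletes the `read()` helper, but `requirements()` (kept as context) still calls it, so building would raise a NameError unless `read` is defined elsewhere in the file. A minimal sketch of the helper should it need restoring, with the same behaviour plus an explicit close:

    # Hypothetical restoration of the deleted helper used by requirements().
    import os

    root_dir = os.path.split(os.path.abspath(__file__))[0]

    def read(fname):
        with open(os.path.join(root_dir, fname)) as fp:
            return fp.read()
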
From 4bab3b040db42216937d0ab181ebd9f52db8dc7f Mon Sep 17 00:00:00 2001
From: Luca
Date: Thu, 3 Dec 2020 22:31:10 +0000
Subject: [PATCH 5/5] tests

---
 .github/workflows/build.yml                         | 4 ++--
 tests/all/fields/{fk.py => test_fk.py}              | 0
 tests/all/utils/{intervals.py => test_intervals.py} | 0
 tests/all/utils/{tools.py => test_tools.py}         | 2 --
 tests/all/utils/{zset.py => test_zset.py}           | 0
 5 files changed, 2 insertions(+), 4 deletions(-)
 rename tests/all/fields/{fk.py => test_fk.py} (100%)
 rename tests/all/utils/{intervals.py => test_intervals.py} (100%)
 rename tests/all/utils/{tools.py => test_tools.py} (96%)
 rename tests/all/utils/{zset.py => test_zset.py} (100%)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index eebb95d..871f30b 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -29,8 +29,8 @@ jobs:
         python-version: ${{ matrix.python-version }}
     - name: Install dependencies
       run: make install
-    - name: run lint
-      run: make lint-check
+    #- name: run lint
+    #  run: make lint-check
     - name: run tests
       run: make test
     - name: upload coverage
diff --git a/tests/all/fields/fk.py b/tests/all/fields/test_fk.py
similarity index 100%
rename from tests/all/fields/fk.py
rename to tests/all/fields/test_fk.py
diff --git a/tests/all/utils/intervals.py b/tests/all/utils/test_intervals.py
similarity index 100%
rename from tests/all/utils/intervals.py
rename to tests/all/utils/test_intervals.py
diff --git a/tests/all/utils/tools.py b/tests/all/utils/test_tools.py
similarity index 96%
rename from tests/all/utils/tools.py
rename to tests/all/utils/test_tools.py
index 6903010..f361859 100755
--- a/tests/all/utils/tools.py
+++ b/tests/all/utils/test_tools.py
@@ -1,9 +1,7 @@
-import time
 from datetime import date, datetime
 
 from examples.models import Statistics3
 
-import stdnet
 from stdnet import odm
 from stdnet.utils import (
     _format_int,
diff --git a/tests/all/utils/zset.py b/tests/all/utils/test_zset.py
similarity index 100%
rename from tests/all/utils/zset.py
rename to tests/all/utils/test_zset.py
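
Editor's note: the renames adopt the conventional `test_*` prefix, which both unittest and pytest discover by default. A quick way to confirm the renamed modules are picked up, assuming the test directories are importable as packages (the Makefile's `make test` target may of course invoke something else):

    # Hypothetical discovery check: default patterns match test_*.py,
    # so the renamed modules are collected without extra configuration.
    import unittest

    suite = unittest.defaultTestLoader.discover("tests/all/utils", pattern="test_*.py")
    print(suite.countTestCases())
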