diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 00000000..b94b4507 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,73 @@ +name: "Test" + +on: + push: + paths-ignore: + - "docs/**" + pull_request: + paths-ignore: + - "docs/**" + schedule: + - cron: '40 1 * * 3' + + +jobs: + test: + name: test-python${{ matrix.python-version }}-sa${{ matrix.sqlalchemy-version }}-${{ matrix.db-engine }} + strategy: + matrix: + python-version: +# - "2.7" +# - "3.4" +# - "3.5" +# - "3.6" +# - "3.7" + - "3.8" +# - "3.9" +# - "3.10" +# - "pypy-3.7" + sqlalchemy-version: + - "<1.4" + - ">=1.4" + db-engine: + - sqlite + - postgres + - postgres-native + - mysql + runs-on: ubuntu-latest + services: + mysql: + image: mysql + ports: + - 3306:3306 + env: + MYSQL_DATABASE: sqlalchemy_continuum_test + MYSQL_ALLOW_EMPTY_PASSWORD: yes + options: >- + --health-cmd "mysqladmin ping" + --health-interval 5s + --health-timeout 2s + --health-retries 3 + postgres: + image: postgres + ports: + - 5432:5432 + env: + POSTGRES_PASSWORD: postgres + POSTGRES_DB: sqlalchemy_continuum_test + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 2s + --health-retries 3 + steps: + - uses: actions/checkout@v1 + - name: Install sqlalchemy + run: pip3 install 'sqlalchemy${{ matrix.sqlalchemy-version }}' + - name: Build + run: pip3 install -e '.[test]' + - name: Run tests + run: pytest + env: + DB: ${{ matrix.db-engine }} + diff --git a/.gitignore b/.gitignore index a015a2d6..fe795a50 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,9 @@ nosetests.xml .mr.developer.cfg .project .pydevproject + +# mypy +.mypy_cache/ + +# Unit test / coverage reports +.cache diff --git a/.travis.yml b/.travis.yml index 6ee80dba..035bf151 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,5 +1,9 @@ -addons: - postgresql: 9.3 +services: + - mysql + - postgresql + +dist: xenial +sudo: true env: - DB=mysql @@ -15,9 +19,10 @@ before_script: language: python python: - 2.7 - - 3.3 - 3.4 - 3.5 + - 3.6 + - 3.7 install: - pip install -e ".[test]" script: diff --git a/CHANGES.rst b/CHANGES.rst index 646c99a3..39a33ef8 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -4,10 +4,71 @@ Changelog Here you can see the full list of changes between each SQLAlchemy-Continuum release. -1.3.12 (2019-03-18) +1.3.12 (2022-01-18) +^^^^^^^^^^^^^^^^^^^ + +- Support SA 1.4 + +1.3.11 (2020-05-24) +^^^^^^^^^^^^^^^^^^^ + +- Made ModelBuilder create column aliases in version models (#246, courtesy of killthekitten) + + +1.3.10 (2020-05-10) +^^^^^^^^^^^^^^^^^^^ + +- Added explicit "pseudo-backref" relationships for version/parent (#240, courtesy of lgedgar) +- Fixed m2m Bug when an unrelated change is made to a model (#242, courtesy of Andrew-Dickinson) + + +1.3.9 (2019-03-19) ^^^^^^^^^^^^^^^^^^ - Added SA 1.3 support +- Reverted trigger creation from 1.3.7 + + +1.3.8 (2019-02-27) +^^^^^^^^^^^^^^^^^^ + +- Fixed revert to ignore non-columns (#197, courtesy of mauler) + + +1.3.7 (2019-01-13) +^^^^^^^^^^^^^^^^^^ + +- Fix trigger creation during alembic migrations (#209, courtesy of lyndsysimon) + + +1.3.6 (2018-07-30) +^^^^^^^^^^^^^^^^^^ + +- Fixed ResourceClosedErrors from connections leaking when using an external transaction (#196, courtesy of vault) + + +1.3.5 (2018-06-03) +^^^^^^^^^^^^^^^^^^ + +- Track cloned connections (#167, courtesy of netcriptus) + + +1.3.4 (2018-03-07) +^^^^^^^^^^^^^^^^^^ + +- Exclude many-to-many properties from versioning if they are added in exclude parameter (#169, courtesy of fuhrysteve) + + +1.3.3 (2017-11-05) +^^^^^^^^^^^^^^^^^^ + +- Fixed changeset when updating object in same transaction as inserting it (#141, courtesy of oinopion) + + +1.3.2 (2017-10-12) +^^^^^^^^^^^^^^^^^^ + +- Fixed multiple schema handling (#132, courtesy of vault) 1.3.1 (2017-06-28) diff --git a/LICENSE b/LICENSE index d604ce84..cccc1d8f 100644 --- a/LICENSE +++ b/LICENSE @@ -12,8 +12,9 @@ modification, are permitted provided that the following conditions are met: this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. -* The names of the contributors may not be used to endorse or promote products - derived from this software without specific prior written permission. +* Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED diff --git a/benchmark.py b/benchmark.py index 9ee0250f..f4b11aab 100644 --- a/benchmark.py +++ b/benchmark.py @@ -50,7 +50,7 @@ def test_versioning( make_versioned(options=options) - dns = 'postgres://postgres@localhost/sqlalchemy_continuum_test' + dns = 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test' versioning_manager.plugins = plugins versioning_manager.transaction_cls = transaction_cls versioning_manager.user_cls = user_cls diff --git a/docs/alembic.rst b/docs/alembic.rst index 562ebb34..d46a4700 100644 --- a/docs/alembic.rst +++ b/docs/alembic.rst @@ -1,6 +1,11 @@ Alembic migrations ================== -Each time you make changes to database structure you should also change the associated history tables. When you make changes to your models SQLAlchemy-Continuum automatically alters the history model definitions, hence you can use `alembic revision --autogenerate` just like before. You just need to make sure `make_versioned` function gets called before alembic gathers all your models. +Each time you make changes to database structure you should also change the associated history tables. When you make changes to your models SQLAlchemy-Continuum automatically alters the history model definitions, hence you can use `alembic revision --autogenerate` just like before. You just need to make sure `make_versioned` function gets called before alembic gathers all your models and `configure_mappers` is called afterwards. Pay close attention when dropping or moving data from parent tables and reflecting these changes to history tables. + +Troubleshooting +############### + +If alembic didn't detect any changes or generates reversed migration (tries to remove `*_version` tables from database instead of creating), make sure that `configure_mappers` was called by alembic command. diff --git a/docs/intro.rst b/docs/intro.rst index e6002b87..8b276994 100644 --- a/docs/intro.rst +++ b/docs/intro.rst @@ -5,7 +5,7 @@ Introduction Why? ^^^^ -SQLAlchemy already has versioning extension. This extension however is very limited. It does not support versioning entire transactions. +SQLAlchemy `already has a versioning extension `_. This extension however is very limited. It does not support versioning entire transactions. Hibernate for Java has Envers, which had nice features but lacks a nice API. Ruby on Rails has papertrail_, which has very nice API but lacks the efficiency and feature set of Envers. @@ -54,7 +54,7 @@ In order to make your models versioned you need two things: from sqlalchemy_continuum import make_versioned - make_versioned() + make_versioned(user_cls=None) class Article(Base): diff --git a/docs/native_versioning.rst b/docs/native_versioning.rst index 76431d83..7b91045f 100644 --- a/docs/native_versioning.rst +++ b/docs/native_versioning.rst @@ -25,9 +25,13 @@ Schema migrations When making schema migrations (for example adding new columns to version tables) you need to remember to call sync_trigger in order to keep the version trigger up-to-date. :: - from sqlalchemy_continuum import versioning_manager # or import your custom one, if you have one - sync_trigger(conn, - 'article_version', - versioning_manager=versioning_manager) + from sqlalchemy_continuum.dialects.postgresql import sync_trigger + sync_trigger(conn, 'article_version') + +If you don't use `PropertyModTrackerPlugin`, then you have to disable it: + +:: + + sync_trigger(conn, 'article_version', use_property_mod_tracking=False) diff --git a/docs/plugins.rst b/docs/plugins.rst index 5f88b6fe..09c1e261 100644 --- a/docs/plugins.rst +++ b/docs/plugins.rst @@ -7,7 +7,7 @@ Using plugins :: - from sqlalchemy.continuum.plugins import PropertyModTrackerPlugin + from sqlalchemy_continuum.plugins import PropertyModTrackerPlugin versioning_manager.plugins.append(PropertyModTrackerPlugin()) diff --git a/docs/version_objects.rst b/docs/version_objects.rst index ecbb71ce..54f0b774 100644 --- a/docs/version_objects.rst +++ b/docs/version_objects.rst @@ -102,7 +102,7 @@ you can easily check the changeset of given object in current transaction. article = Article(name=u'Some article') changeset(article) - # {'name': [u'Some article', None]} + # {'name': [None, u'Some article']} Version relationships diff --git a/setup.py b/setup.py index 28b80328..fda24ebd 100644 --- a/setup.py +++ b/setup.py @@ -28,15 +28,14 @@ def get_version(): 'pytest>=2.3.5', 'flexmock>=0.9.7', 'psycopg2>=2.4.6', - 'PyMySQL==0.6.1', + 'PyMySQL>=0.8.0', 'six>=1.4.0' ], - 'anyjson': ['anyjson>=0.3.3'], 'flask': ['Flask>=0.9'], 'flask-login': ['Flask-Login>=0.2.9'], 'flask-sqlalchemy': ['Flask-SQLAlchemy>=1.0'], 'flexmock': ['flexmock>=0.9.7'], - 'i18n': ['SQLAlchemy-i18n>=0.8.4'], + 'i18n': ['SQLAlchemy-i18n>=0.8.4,!=1.1.0'], } @@ -77,9 +76,10 @@ def get_version(): 'Programming Language :: Python :: 2', 'Programming Language :: Python :: 2.7', 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: 3.3', 'Programming Language :: Python :: 3.4', 'Programming Language :: Python :: 3.5', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', 'Topic :: Internet :: WWW/HTTP :: Dynamic Content', 'Topic :: Software Development :: Libraries :: Python Modules' ] diff --git a/sqlalchemy_continuum/__init__.py b/sqlalchemy_continuum/__init__.py index c48e7fe1..ba4c043c 100644 --- a/sqlalchemy_continuum/__init__.py +++ b/sqlalchemy_continuum/__init__.py @@ -20,7 +20,7 @@ ) -__version__ = '1.4.3' +__version__ = '1.5.0' versioning_manager = VersioningManager() @@ -72,6 +72,18 @@ def make_versioned( manager.track_association_operations ) + sa.event.listen( + sa.engine.Engine, + 'rollback', + manager.clear_connection + ) + + sa.event.listen( + sa.engine.Engine, + 'set_connection_execution_options', + manager.track_cloned_connections + ) + def remove_versioning( mapper=sa.orm.mapper, @@ -98,3 +110,15 @@ def remove_versioning( 'before_cursor_execute', manager.track_association_operations ) + + sa.event.remove( + sa.engine.Engine, + 'rollback', + manager.clear_connection + ) + + sa.event.remove( + sa.engine.Engine, + 'set_connection_execution_options', + manager.track_cloned_connections + ) diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py index 47bbef35..53134029 100644 --- a/sqlalchemy_continuum/builder.py +++ b/sqlalchemy_continuum/builder.py @@ -1,5 +1,6 @@ from copy import copy from inspect import getmro +from functools import wraps import sqlalchemy as sa from sqlalchemy_utils.functions import get_declarative_base @@ -10,6 +11,18 @@ from .table_builder import TableBuilder +def prevent_reentry(handler): + in_handler = False + @wraps(handler) + def check_reentry(*args, **kwargs): + nonlocal in_handler + if in_handler: + return + in_handler = True + handler(*args, **kwargs) + in_handler = False + return check_reentry + class Builder(object): def build_triggers(self): """ @@ -141,17 +154,20 @@ def build_transaction_class(self): self.manager.create_transaction_model() self.manager.plugins.after_build_tx_class(self.manager) + @prevent_reentry def configure_versioned_classes(self): """ Configures all versioned classes that were collected during - instrumentation process. The configuration has 4 steps: + instrumentation process. The configuration has 6 steps: 1. Build tables for version models. 2. Build the actual version model declarative classes. 3. Build relationships between these models. 4. Empty pending_classes list so that consecutive mapper configuration does not create multiple version classes - 5. Assign all versioned attributes to use active history. + 5. Build aliases for columns. + 6. Assign all versioned attributes to use active history. + """ if not self.manager.options['versioning']: return @@ -168,11 +184,39 @@ def configure_versioned_classes(self): # Create copy of all pending versioned classes so that we can inspect # them later when creating relationships. - pending_copy = copy(self.manager.pending_classes) + pending_classes_copies = copy(self.manager.pending_classes) self.manager.pending_classes = [] - self.build_relationships(pending_copy) + self.build_relationships(pending_classes_copies) + self.enable_active_history(pending_classes_copies) + self.create_column_aliases(pending_classes_copies) - for cls in pending_copy: - # set the "active_history" flag + def enable_active_history(self, version_classes): + """ + Assign all versioned attributes to use active history. + """ + for cls in version_classes: for prop in sa.inspect(cls).iterate_properties: getattr(cls, prop.key).impl.active_history = True + + def create_column_aliases(self, version_classes): + """ + Create aliases for the columns from the original model. + + This, for example, imitates the behavior of @declared_attr columns. + """ + for cls in version_classes: + model_mapper = sa.inspect(cls) + version_class = self.manager.version_class_map.get(cls) + if not version_class: + continue + + version_class_mapper = sa.inspect(version_class) + + for key, column in model_mapper.columns.items(): + if key != column.key: + version_class_column = version_class.__table__.c.get(column.key) + + if version_class_column is None: + continue + + version_class_mapper.add_property(key, sa.orm.column_property(version_class_column)) diff --git a/sqlalchemy_continuum/dialects/postgresql.py b/sqlalchemy_continuum/dialects/postgresql.py index da68359d..40d2735f 100644 --- a/sqlalchemy_continuum/dialects/postgresql.py +++ b/sqlalchemy_continuum/dialects/postgresql.py @@ -489,10 +489,7 @@ def reverse_table_name_format(version_table_name_format): return '^' + version_table_name_format.replace('%s', '(.*)') + '$' DEFAULT_VERSION_TABLE_NAME_FORMAT = '%s_version' -def sync_trigger(conn, - table_name, - versioning_manager, - schema=None): +def sync_trigger(conn, table_name, **kwargs): """ Synchronizes versioning trigger for given table with given connection. @@ -504,29 +501,19 @@ def sync_trigger(conn, :param conn: SQLAlchemy connection object :param table_name: Name of the table to synchronize versioning trigger for - :param versioning_manager: (Optional) the versioning manager + :params **kwargs: kwargs to pass to create_trigger .. versionadded: 1.1.0 """ - custom_version_table_name_format = versioning_manager.options.get('table_name') if versioning_manager else None - version_table_name_format = custom_version_table_name_format or DEFAULT_VERSION_TABLE_NAME_FORMAT - parent_table_name_regex = reverse_table_name_format(version_table_name_format) - - meta = sa.MetaData(schema=schema) + meta = sa.MetaData() version_table = sa.Table( table_name, meta, autoload=True, autoload_with=conn ) - - try: - parent_table_name = re.findall(parent_table_name_regex, table_name)[0] - except IndexError: - raise ValueError('The version table name %s that was provided to sync_trigger does not conform to the format %s' % (table_name, version_table_name_format)) - parent_table = sa.Table( - parent_table_name, + table_name[0:-len('_version')], meta, autoload=True, autoload_with=conn @@ -535,11 +522,8 @@ def sync_trigger(conn, set(c.name for c in parent_table.c) - set(c.name for c in version_table.c if not c.name.endswith('_mod')) ) - drop_trigger(conn, parent_table.name, parent_table.schema) - create_trigger(conn, - table=parent_table, - versioning_manager=versioning_manager, - excluded_columns=excluded_columns) + drop_trigger(conn, parent_table.name) + create_trigger(conn, table=parent_table, excluded_columns=excluded_columns, **kwargs) def create_trigger( diff --git a/sqlalchemy_continuum/factory.py b/sqlalchemy_continuum/factory.py index 5e36dc81..9951f67a 100644 --- a/sqlalchemy_continuum/factory.py +++ b/sqlalchemy_continuum/factory.py @@ -6,7 +6,11 @@ def __call__(self, manager): Create model class but only if it doesn't already exist in declarative model registry. """ - registry = manager.declarative_base._decl_class_registry + Base = manager.declarative_base + try: + registry = Base.registry._class_registry + except AttributeError: # SQLAlchemy < 1.4 + registry = Base._decl_class_registry if self.model_name not in registry: return self.create_class(manager) return registry[self.model_name] diff --git a/sqlalchemy_continuum/fetcher.py b/sqlalchemy_continuum/fetcher.py index 1ac1a175..0f262b6f 100644 --- a/sqlalchemy_continuum/fetcher.py +++ b/sqlalchemy_continuum/fetcher.py @@ -59,7 +59,7 @@ def _transaction_id_subquery(self, obj, next_or_prev='next', alias=None): func = sa.func.max if alias is None: - alias = sa.orm.aliased(obj) + alias = sa.orm.aliased(obj.__class__) table = alias.__table__ if hasattr(alias, 'c'): attrs = alias.c @@ -117,7 +117,7 @@ def _index_query(self, obj): Returns the query needed for fetching the index of this record relative to version history. """ - alias = sa.orm.aliased(obj) + alias = sa.orm.aliased(obj.__class__) subquery = ( sa.select([sa.func.count('1')], from_obj=[alias.__table__]) diff --git a/sqlalchemy_continuum/manager.py b/sqlalchemy_continuum/manager.py index 99ace2c1..218eadce 100644 --- a/sqlalchemy_continuum/manager.py +++ b/sqlalchemy_continuum/manager.py @@ -25,7 +25,15 @@ def wrapper(self, mapper, connection, target): try: uow = self.units_of_work[conn] except KeyError: - uow = self.units_of_work.get(conn.engine, None) + try: + uow = self.units_of_work[conn.engine] + except KeyError: + for connection in self.units_of_work.keys(): + if not connection.closed and connection.connection is conn.connection: + uow = self.unit_of_work(session) + break # The ConnectionFairy is the same, this connection is a clone + else: + raise return func(self, uow, target) return wrapper @@ -409,11 +417,39 @@ def clear(self, session): if session.transaction.nested: return conn = self.session_connection_map.pop(session, None) + if conn is None: + return + if conn in self.units_of_work: uow = self.units_of_work[conn] uow.reset(session) del self.units_of_work[conn] + for connection in dict(self.units_of_work).keys(): + if connection.closed or conn.connection is connection.connection: + uow = self.units_of_work[connection] + uow.reset(session) + del self.units_of_work[connection] + + def clear_connection(self, conn): + if conn in self.units_of_work: + uow = self.units_of_work[conn] + uow.reset() + del self.units_of_work[conn] + + + for session, connection in dict(self.session_connection_map).items(): + if connection is conn: + del self.session_connection_map[session] + + + for connection in dict(self.units_of_work).keys(): + if connection.closed or conn.connection is connection.connection: + uow = self.units_of_work[connection] + uow.reset() + del self.units_of_work[connection] + + def append_association_operation(self, conn, table_name, params, op): """ Append history association operation to pending_statements list. @@ -440,9 +476,23 @@ def append_association_operation(self, conn, table_name, params, op): try: uow = self.units_of_work[conn.engine] except KeyError: - return + for connection in self.units_of_work.keys(): + if not connection.closed and connection.connection is conn.connection: + uow = self.unit_of_work(conn.session) + break # The ConnectionFairy is the same, this connection is a clone + else: + raise uow.pending_statements.append(stmt) + def track_cloned_connections(self, c, opt): + """ + Track cloned connections from association tables. + """ + if c not in self.units_of_work.keys(): + for connection, uow in dict(self.units_of_work).items(): + if not connection.closed and connection.connection is c.connection: # ConnectionFairy is the same - this is a clone + self.units_of_work[c] = uow + def track_association_operations( self, conn, cursor, statement, parameters, context, executemany ): diff --git a/sqlalchemy_continuum/model_builder.py b/sqlalchemy_continuum/model_builder.py index 6987a575..d62904bc 100644 --- a/sqlalchemy_continuum/model_builder.py +++ b/sqlalchemy_continuum/model_builder.py @@ -107,6 +107,7 @@ class represents). """ conditions = [] foreign_keys = [] + model_keys = [] for key, column in sa.inspect(self.model).columns.items(): if column.primary_key: conditions.append( @@ -117,6 +118,9 @@ class represents). foreign_keys.append( getattr(self.version_class, key) ) + model_keys.append( + getattr(self.model, key) + ) # We need to check if versions relation was already set for parent # class. @@ -130,12 +134,19 @@ class represents). option(self.model, 'transaction_column_name') ), lazy='dynamic', - backref=sa.orm.backref( - 'version_parent' - ), viewonly=True, sync_backref=False ) + # We must explicitly declare this relationship, instead of + # specifying as a backref to the one above, since they are + # viewonly=True and SQLAlchemy will warn if using backref. + self.version_class.version_parent = sa.orm.relationship( + self.model, + primaryjoin=sa.and_(*conditions), + foreign_keys=model_keys, + viewonly=True, + uselist=False, + ) def build_transaction_relationship(self, tx_class): """ @@ -263,6 +274,7 @@ def mapper_args(cls): name = '%sVersion' % (self.model.__name__,) return type(name, self.base_classes(), args) + def __call__(self, table, tx_class): """ Build history model and relationships to parent model, transaction diff --git a/sqlalchemy_continuum/plugins/activity.py b/sqlalchemy_continuum/plugins/activity.py index 8905079a..10b85d3f 100644 --- a/sqlalchemy_continuum/plugins/activity.py +++ b/sqlalchemy_continuum/plugins/activity.py @@ -191,6 +191,7 @@ import sqlalchemy as sa from sqlalchemy.ext.hybrid import hybrid_property +from sqlalchemy.inspection import inspect from sqlalchemy_utils import JSONType, generic_relationship from .base import Plugin @@ -254,11 +255,13 @@ def _calculate_tx_id(self, obj): if object_version: return object_version.transaction_id - version_cls = version_class(obj.__class__) + model = obj.__class__ + version_cls = version_class(model) + primary_key = inspect(model).primary_key[0].name return session.query( sa.func.max(version_cls.transaction_id) ).filter( - version_cls.id == obj.id + getattr(version_cls, primary_key) == getattr(obj, primary_key) ).scalar() def calculate_object_tx_id(self): @@ -314,6 +317,8 @@ def target_version_type(cls): class ActivityPlugin(Plugin): + activity_cls = None + def after_build_models(self, manager): self.activity_cls = ActivityFactory()(manager) manager.activity_cls = self.activity_cls diff --git a/sqlalchemy_continuum/plugins/flask.py b/sqlalchemy_continuum/plugins/flask.py index ad9ac19d..62a021a2 100644 --- a/sqlalchemy_continuum/plugins/flask.py +++ b/sqlalchemy_continuum/plugins/flask.py @@ -36,7 +36,7 @@ def fetch_current_user_id(): if _app_ctx_stack.top is None or _request_ctx_stack.top is None: return try: - return current_user.id + return current_user.get_id() except AttributeError: return diff --git a/sqlalchemy_continuum/relationship_builder.py b/sqlalchemy_continuum/relationship_builder.py index 17cce3b6..d48b60ff 100644 --- a/sqlalchemy_continuum/relationship_builder.py +++ b/sqlalchemy_continuum/relationship_builder.py @@ -249,6 +249,7 @@ def association_subquery(self, obj): FROM article_tag_version as article_tag_version2 WHERE article_tag_version2.tag_id = article_tag_version.tag_id AND article_tag_version2.tx_id <=5 + AND article_tag_version2.article_id = 3 GROUP BY article_tag_version2.tag_id HAVING MAX(article_tag_version2.tx_id) = @@ -260,6 +261,8 @@ def association_subquery(self, obj): """ tx_column = option(obj, 'transaction_column_name') + join_column = self.property.primaryjoin.right.name + object_join_column = self.property.primaryjoin.left.name reflector = VersionExpressionReflector(obj, self.property) association_table_alias = self.association_version_table.alias() @@ -276,6 +279,7 @@ def association_subquery(self, obj): sa.and_( association_table_alias.c[tx_column] <= getattr(obj, tx_column), + association_table_alias.c[join_column] == getattr(obj, object_join_column), *[association_col == self.association_version_table.c[association_col.name] for association_col @@ -317,9 +321,13 @@ def build_association_version_tables(self): self.model ) metadata = column.table.metadata - table_schema = apply_table_schema(self.manager.option(self.model, 'table_schema'), - column.table.schema or metadata.schema) - table_name = ((table_schema + '.') if table_schema else '') + builder.table_name + + if builder.parent_table.schema: + table_name = builder.parent_table.schema + '.' + builder.table_name + elif metadata.schema: + table_name = metadata.schema + '.' + builder.table_name + else: + table_name = builder.table_name if table_name not in metadata.tables: self.association_version_table = table = builder() @@ -344,7 +352,10 @@ def __call__(self): except ClassNotVersioned: self.remote_cls = self.property.mapper.class_ - if self.property.secondary is not None and not self.property.viewonly: + if (self.property.secondary is not None and + not self.property.viewonly and + not self.manager.is_excluded_property( + self.model, self.property.key)): self.build_association_version_tables() # store remote cls to association table column pairs diff --git a/sqlalchemy_continuum/table_builder.py b/sqlalchemy_continuum/table_builder.py index 5909d092..4056e697 100644 --- a/sqlalchemy_continuum/table_builder.py +++ b/sqlalchemy_continuum/table_builder.py @@ -152,6 +152,6 @@ def __call__(self, extends=None): extends.name if extends is not None else self.table_name, self.parent_table.metadata, *columns, - extend_existing=extends is not None, - schema=apply_table_schema(self.option('table_schema'), self.parent_table.schema) + schema=self.parent_table.schema, + extend_existing=extends is not None ) diff --git a/sqlalchemy_continuum/transaction.py b/sqlalchemy_continuum/transaction.py index 85149437..96d7847f 100644 --- a/sqlalchemy_continuum/transaction.py +++ b/sqlalchemy_continuum/transaction.py @@ -1,4 +1,5 @@ from datetime import datetime +from functools import partial try: from collections import OrderedDict @@ -22,6 +23,10 @@ def compile_big_integer(element, compiler, **kw): return 'INTEGER' +class NoChangesAttribute(Exception): + pass + + class TransactionBase(object): issued_at = sa.Column(sa.DateTime, default=datetime.utcnow) @@ -29,8 +34,13 @@ class TransactionBase(object): def entity_names(self): """ Return a list of entity names that changed during this transaction. + Raises a NoChangesAttribute exception if the 'changes' column does + not exist, most likely because TransactionChangesPlugin is not enabled. """ - return [changes.entity_name for changes in self.changes] + if hasattr(self, 'changes'): + return [changes.entity_name for changes in self.changes] + else: + raise NoChangesAttribute() @property def changed_entities(self): @@ -47,8 +57,11 @@ def changed_entities(self): session = sa.orm.object_session(self) for class_, version_class in tuples: - if class_.__name__ not in self.entity_names: - continue + try: + if class_.__name__ not in self.entity_names: + continue + except NoChangesAttribute: + pass tx_column = manager.option(class_, 'transaction_column_name') @@ -131,7 +144,11 @@ class Transaction( if manager.user_cls: user_cls = manager.user_cls - registry = manager.declarative_base._decl_class_registry + Base = manager.declarative_base + try: + registry = Base.registry._class_registry + except AttributeError: # SQLAlchemy < 1.4 + registry = Base._decl_class_registry if isinstance(user_cls, six.string_types): try: diff --git a/sqlalchemy_continuum/unit_of_work.py b/sqlalchemy_continuum/unit_of_work.py index 5f91b13d..50ab768b 100644 --- a/sqlalchemy_continuum/unit_of_work.py +++ b/sqlalchemy_continuum/unit_of_work.py @@ -253,6 +253,11 @@ def update_version_validity(self, parent, version_obj): version_obj, alias=sa.orm.aliased(class_.__table__) ) + try: + subquery = subquery.scalar_subquery() + except AttributeError: # SQLAlchemy < 1.4 + subquery = subquery.as_scalar() + query = ( session.query(class_.__table__) .filter( diff --git a/sqlalchemy_continuum/utils.py b/sqlalchemy_continuum/utils.py index 90515733..bd150ed5 100644 --- a/sqlalchemy_continuum/utils.py +++ b/sqlalchemy_continuum/utils.py @@ -157,21 +157,17 @@ def version_table(model, table): :param model: SQLAlchemy declarative model class :param table: SQLAlchemy Table object """ - table_name = option(model, 'table_name') - table_schema = option(model, 'table_schema') - if table.schema: return table.metadata.tables[ - table_name % (apply_table_schema(table_schema, table.schema) + '.' + table.name) + table.schema + '.' + table.name + '_version' ] elif table.metadata.schema: return table.metadata.tables[ - table_name % (apply_table_schema(table_schema, table.metadata.schema) + '.' + table.name) + table.metadata.schema + '.' + table.name + '_version' ] else: - schema = apply_table_schema(table_schema, None) return table.metadata.tables[ - table_name % (((schema + '.') if schema else '') + table.name) + table.name + '_version' ] @@ -230,7 +226,11 @@ def versioned_column_properties(obj_or_class): cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__ mapper = sa.inspect(cls) - for key in mapper.columns.keys(): + for key, column in mapper.columns.items(): + # Ignores non table columns + if not is_table_column(column): + continue + if not manager.is_excluded_property(obj_or_class, key): yield getattr(mapper.attrs, key) @@ -247,7 +247,7 @@ def versioned_relationships(obj, versioned_column_keys): yield prop -def vacuum(session, model): +def vacuum(session, model, yield_per=1000): """ When making structural changes to version tables (for example dropping columns) there are sometimes situations where some old version records @@ -268,6 +268,7 @@ def vacuum(session, model): :param session: SQLAlchemy session object :param model: SQLAlchemy declarative model class + :param yield_per: how many rows to process at a time """ version_cls = version_class(model) versions = defaultdict(list) @@ -275,15 +276,28 @@ def vacuum(session, model): query = ( session.query(version_cls) .order_by(option(version_cls, 'transaction_column_name')) - ) + ).yield_per(yield_per) + + primary_key_col = sa.inspection.inspect(model).primary_key[0].name for version in query: - if versions[version.id]: - prev_version = versions[version.id][-1] + version_id = getattr(version, primary_key_col) + if versions[version_id]: + prev_version = versions[version_id][-1] if naturally_equivalent(prev_version, version): session.delete(version) else: - versions[version.id].append(version) + versions[version_id].append(version) + + +def is_table_column(column): + """ + Return wheter of not give field is a column over the database table. + + :param column: SQLAclhemy model field. + :rtype: bool + """ + return isinstance(column, sa.Column) def is_internal_column(model, column_name): @@ -429,7 +443,10 @@ def changeset(obj): data = {} session = sa.orm.object_session(obj) if session and obj in session.deleted: - for column in sa.inspect(obj.__class__).columns.values(): + columns = [c for c in sa.inspect(obj.__class__).columns.values() + if is_table_column(c)] + + for column in columns: if not column.primary_key: value = getattr(obj, column.key) if value is not None: diff --git a/sqlalchemy_continuum/version.py b/sqlalchemy_continuum/version.py index 5c3c1ed2..d71e745d 100644 --- a/sqlalchemy_continuum/version.py +++ b/sqlalchemy_continuum/version.py @@ -1,4 +1,5 @@ import sqlalchemy as sa + from .reverter import Reverter from .utils import get_versioning_manager, is_internal_column, parent_class @@ -49,9 +50,6 @@ def changeset(self): and second list value as the new value. """ previous_version = self.previous - if not previous_version and self.operation_type != 0: - return {} - data = {} for key in sa.inspect(self.__class__).columns.keys(): diff --git a/tests/__init__.py b/tests/__init__.py index 4bc17e01..4d5a50af 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,4 @@ + from copy import copy import inspect import itertools as it @@ -6,7 +7,7 @@ import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import sessionmaker, column_property from sqlalchemy_continuum import ( ClassNotVersioned, version_class, @@ -42,9 +43,9 @@ def log_sql( def get_dns_from_driver(driver): if driver == 'postgres': - return 'postgres://postgres@localhost/sqlalchemy_continuum_test' + return 'postgresql://postgres:postgres@localhost/sqlalchemy_continuum_test' elif driver == 'mysql': - return 'mysql+pymysql://travis@localhost/sqlalchemy_continuum_test' + return 'mysql+pymysql://root@localhost/sqlalchemy_continuum_test' elif driver == 'sqlite': return 'sqlite:///:memory:' else: @@ -162,6 +163,9 @@ class Article(self.Model): content = sa.Column(sa.UnicodeText) description = sa.Column(sa.UnicodeText) + # Dynamic column cotaining all text content data + fulltext_content = column_property(name + content + description) + class Tag(self.Model): __tablename__ = 'tag' __versioned__ = copy(self.options) diff --git a/tests/builders/test_table_builder.py b/tests/builders/test_table_builder.py index 16323d02..a2255c83 100644 --- a/tests/builders/test_table_builder.py +++ b/tests/builders/test_table_builder.py @@ -3,6 +3,7 @@ import sqlalchemy as sa from sqlalchemy_continuum import version_class from tests import TestCase +from pytest import mark class TestTableBuilder(TestCase): @@ -69,3 +70,31 @@ class Article(self.Model): def test_takes_out_onupdate_triggers(self): table = version_class(self.Article).__table__ assert table.c.last_update.onupdate is None + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestTableBuilderInOtherSchema(TestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = copy(self.options) + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + last_update = sa.Column( + sa.DateTime, + default=datetime.utcnow, + onupdate=datetime.utcnow, + nullable=False + ) + self.Article = Article + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestCase.create_tables(self) + + def test_created_tables_retain_schema(self): + table = version_class(self.Article).__table__ + assert table.schema is not None + assert table.schema == self.Article.__table__.schema + diff --git a/tests/inheritance/test_single_table_inheritance.py b/tests/inheritance/test_single_table_inheritance.py index 9b723c15..73295bea 100644 --- a/tests/inheritance/test_single_table_inheritance.py +++ b/tests/inheritance/test_single_table_inheritance.py @@ -1,4 +1,5 @@ import sqlalchemy as sa +from sqlalchemy.ext.declarative import declared_attr from sqlalchemy_continuum import versioning_manager, version_class from tests import TestCase, create_test_cases @@ -18,6 +19,7 @@ class TextItem(self.Model): __mapper_args__ = { 'polymorphic_on': discriminator, + 'polymorphic_identity': u'base', 'with_polymorphic': '*' } @@ -25,6 +27,10 @@ class Article(TextItem): __mapper_args__ = {'polymorphic_identity': u'article'} name = sa.Column(sa.Unicode(255)) + @sa.ext.declarative.declared_attr + def status(cls): + return sa.Column("_status", sa.Unicode(255)) + class BlogPost(TextItem): __mapper_args__ = {'polymorphic_identity': u'blog_post'} title = sa.Column(sa.Unicode(255)) @@ -79,5 +85,8 @@ def test_transaction_changed_entities(self): assert transaction.entity_names == [u'Article'] assert transaction.changed_entities + def test_declared_attr_inheritance(self): + assert self.ArticleVersion.status + create_test_cases(SingleTableInheritanceTestCase) diff --git a/tests/plugins/test_activity.py b/tests/plugins/test_activity.py index 812eb542..4d0ab532 100644 --- a/tests/plugins/test_activity.py +++ b/tests/plugins/test_activity.py @@ -36,6 +36,34 @@ def create_activity(self, object=None, target=None): return activity +class TestActivityNotId(ActivityTestCase): + + def create_models(self): + TestCase.create_models(self) + + class NotIdModel(self.Model): + __tablename__ = 'not_id' + __versioned__ = { + 'base_classes': (self.Model, ) + } + + pk = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255), nullable=False) + self.NotIdModel = NotIdModel + + def test_create_activity_with_pk(self): + not_id_model = self.NotIdModel(name=u'Some model without id PK') + self.session.add(not_id_model) + self.session.commit() + self.create_activity(not_id_model) + self.session.commit() + activity = self.session.query(versioning_manager.activity_cls).first() + assert activity + assert activity.transaction_id + assert activity.object == not_id_model + assert activity.object_version == not_id_model.versions[-1] + + class TestActivity(ActivityTestCase): def test_creates_activity_class(self): assert versioning_manager.activity_cls.__name__ == 'Activity' diff --git a/tests/plugins/test_flask.py b/tests/plugins/test_flask.py index 6cd1c6a4..27bceeb3 100644 --- a/tests/plugins/test_flask.py +++ b/tests/plugins/test_flask.py @@ -1,7 +1,7 @@ import os from flask import Flask, url_for -from flask_login import LoginManager +from flask_login import LoginManager, UserMixin from flask_sqlalchemy import SQLAlchemy, _SessionSignalEvents from flexmock import flexmock @@ -66,17 +66,17 @@ def login(self, user): :returns: the logged in user """ with self.client.session_transaction() as s: - s['user_id'] = user.id + s['_user_id'] = user.id return user def logout(self, user=None): with self.client.session_transaction() as s: - s['user_id'] = None + s['_user_id'] = None def create_models(self): TestCase.create_models(self) - class User(self.Model): + class User(self.Model, UserMixin): __tablename__ = 'user' __versioned__ = { 'base_classes': (self.Model, ) diff --git a/tests/relationships/test_association_table_relations.py b/tests/relationships/test_association_table_relations.py new file mode 100644 index 00000000..b9d515c9 --- /dev/null +++ b/tests/relationships/test_association_table_relations.py @@ -0,0 +1,61 @@ +import sqlalchemy as sa +from sqlalchemy import PrimaryKeyConstraint +from sqlalchemy.orm import relationship +from tests import TestCase, create_test_cases + + +class AssociationTableRelationshipsTestCase(TestCase): + def create_models(self): + super(AssociationTableRelationshipsTestCase, self).create_models() + + class PublishedArticle(self.Model): + __tablename__ = 'published_article' + __table_args__ = ( + PrimaryKeyConstraint("article_id", "author_id"), + {'keep_existing': True} + ) + + article_id = sa.Column(sa.Integer, sa.ForeignKey('article.id')) + author_id = sa.Column(sa.Integer, sa.ForeignKey('author.id')) + author = relationship('Author') + article = relationship('Article') + + self.PublishedArticle = PublishedArticle + + published_articles_table = sa.Table(PublishedArticle.__tablename__, + PublishedArticle.metadata, + extend_existing=True) + + class Author(self.Model): + __tablename__ = 'author' + __versioned__ = { + 'base_classes': (self.Model, ) + } + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + articles = relationship('Article', secondary=published_articles_table) + + self.Author = Author + + def test_version_relations(self): + article = self.Article() + name = u'Some article' + article.name = name + article.content = u'Some content' + self.session.add(article) + self.session.commit() + assert article.versions[0].name == name + + au = self.Author(name=u'Some author') + self.session.add(au) + self.session.commit() + + pa = self.PublishedArticle(article_id=article.id, author_id=au.id) + self.session.add(pa) + + self.session.commit() + + + +create_test_cases(AssociationTableRelationshipsTestCase) diff --git a/tests/relationships/test_many_to_many_relations.py b/tests/relationships/test_many_to_many_relations.py index 053deef4..b068fa6b 100644 --- a/tests/relationships/test_many_to_many_relations.py +++ b/tests/relationships/test_many_to_many_relations.py @@ -1,4 +1,5 @@ import pytest +from pytest import mark import sqlalchemy as sa from sqlalchemy_continuum import versioning_manager @@ -69,6 +70,33 @@ def test_single_insert(self): self.session.commit() assert len(article.versions[0].tags) == 1 + def test_unrelated_change(self): + tag1 = self.Tag(name=u'some tag') + tag2 = self.Tag(name=u'some tag2') + + self.session.add(tag1) + self.session.add(tag2) + self.session.commit() + + article1 = self.Article(name="Some article", ) + article1.name = u'Some article' + article1.tags.append(tag1) + + self.session.add(article1) + self.session.commit() + + article2 = self.Article() + article2.name = u'Some article2' + article2.tags.append(tag1) + + self.session.add(article2) + self.session.commit() + + article1.name = u'Some other name' + self.session.commit() + + assert len(article1.versions[1].tags) == 1 + def test_multi_insert(self): article = self.Article() article.name = u'Some article' @@ -341,3 +369,108 @@ def test_multiple_inserts_over_multiple_transactions(self): assert len(reference1.versions[2].cited_by) == 1 assert article.versions[2] in reference1.versions[2].cited_by + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestManyToManySelfReferentialInOtherSchema(TestManyToManySelfReferential): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = {} + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_references = sa.Table( + 'article_references', + self.Model.metadata, + sa.Column( + 'referring_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True, + ), + sa.Column( + 'referred_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True + ), + schema='other' + ) + + Article.references = sa.orm.relationship( + Article, + secondary=article_references, + primaryjoin=Article.id == article_references.c.referring_id, + secondaryjoin=Article.id == article_references.c.referred_id, + backref='cited_by' + ) + + self.Article = Article + self.referenced_articles_table = article_references + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestManyToManySelfReferential.create_tables(self) + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class ManyToManyRelationshipsInOtherSchemaTestCase(ManyToManyRelationshipsTestCase): + def create_models(self): + class Article(self.Model): + __tablename__ = 'article' + __versioned__ = { + 'base_classes': (self.Model, ) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + article_tag = sa.Table( + 'article_tag', + self.Model.metadata, + sa.Column( + 'article_id', + sa.Integer, + sa.ForeignKey('other.article.id'), + primary_key=True, + ), + sa.Column( + 'tag_id', + sa.Integer, + sa.ForeignKey('other.tag.id'), + primary_key=True + ), + schema='other' + ) + + class Tag(self.Model): + __tablename__ = 'tag' + __versioned__ = { + 'base_classes': (self.Model, ) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + + Tag.articles = sa.orm.relationship( + Article, + secondary=article_tag, + backref='tags' + ) + + self.Article = Article + self.Tag = Tag + + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + ManyToManyRelationshipsTestCase.create_tables(self) + +create_test_cases(ManyToManyRelationshipsInOtherSchemaTestCase) diff --git a/tests/test_changeset.py b/tests/test_changeset.py index c479c85f..5604744c 100644 --- a/tests/test_changeset.py +++ b/tests/test_changeset.py @@ -56,7 +56,11 @@ def test_changeset_for_history_that_does_not_have_first_insert(self): self.transaction_column_name, tx_log.id) ) - assert self.session.query(self.ArticleVersion).first().changeset == {} + assert self.session.query(self.ArticleVersion).first().changeset == { + 'content': [None, 'some content'], + 'id': [None, 1], + 'name': [None, 'something'] + } class TestChangeSetWithValidityStrategy(ChangeSetTestCase): @@ -72,7 +76,7 @@ def create_models(self): class Article(self.Model): __tablename__ = 'article' __versioned__ = { - 'base_classes': (self.Model, ) + 'base_classes': (self.Model,) } id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) @@ -83,7 +87,7 @@ class Article(self.Model): class Tag(self.Model): __tablename__ = 'tag' __versioned__ = { - 'base_classes': (self.Model, ) + 'base_classes': (self.Model,) } id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) diff --git a/tests/test_column_inclusion_and_exclusion.py b/tests/test_column_inclusion_and_exclusion.py index e8530d2d..e916b383 100644 --- a/tests/test_column_inclusion_and_exclusion.py +++ b/tests/test_column_inclusion_and_exclusion.py @@ -53,3 +53,52 @@ class TextItem(self.Model): content = sa.Column('_content', sa.UnicodeText) self.TextItem = TextItem + + +class TestColumnExclusionWithRelationship(TestCase): + def create_models(self): + + class Word(self.Model): + __tablename__ = 'word' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + word = sa.Column(sa.Unicode(255)) + + class TextItemWord(self.Model): + __tablename__ = 'text_item_word' + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + text_item_id = sa.Column(sa.Integer, sa.ForeignKey('text_item.id'), nullable=False) + word_id = sa.Column(sa.Integer, sa.ForeignKey('word.id'), nullable=False) + + class TextItem(self.Model): + __tablename__ = 'text_item' + __versioned__ = { + 'exclude': ['content'] + } + + id = sa.Column(sa.Integer, autoincrement=True, primary_key=True) + name = sa.Column(sa.Unicode(255)) + content = sa.orm.relationship(Word, secondary='text_item_word') + + self.TextItem = TextItem + self.Word = Word + + def test_excluded_columns_not_included_in_version_class(self): + cls = version_class(self.TextItem) + manager = cls._sa_class_manager + assert 'content' not in manager.keys() + + def test_versioning_with_column_exclusion(self): + item = self.TextItem(name=u'Some textitem', + content=[self.Word(word=u'bird')]) + self.session.add(item) + self.session.commit() + + assert item.versions[0].name == u'Some textitem' + + def test_does_not_create_record_if_only_excluded_column_updated(self): + item = self.TextItem(name=u'Some textitem') + self.session.add(item) + self.session.commit() + item.content.append(self.Word(word=u'Some content')) + self.session.commit() + assert item.versions.count() == 1 diff --git a/tests/test_mapper_args.py b/tests/test_mapper_args.py index 356f85f1..bb14ffb5 100644 --- a/tests/test_mapper_args.py +++ b/tests/test_mapper_args.py @@ -1,3 +1,6 @@ +from pytest import mark +from packaging import version + import sqlalchemy as sa from sqlalchemy_continuum import version_class from tests import TestCase @@ -29,6 +32,7 @@ def test_supports_column_prefix(self): assert self.TextItem._id +@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')") class TestOrderByWithStringArg(TestCase): def create_models(self): class TextItem(self.Model): @@ -55,6 +59,7 @@ def test_reflects_order_by(self): assert self.TextItemVersion.__mapper_args__['order_by'] == 'id' +@mark.skipif("version.parse(sa.__version__) >= version.parse('1.4')") class TestOrderByWithInstrumentedAttribute(TestCase): def create_models(self): class TextItem(self.Model): diff --git a/tests/test_sessions.py b/tests/test_sessions.py index f5780f89..6a1fbfb0 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -52,3 +52,19 @@ class TestUnitOfWork(TestCase): def test_with_session_arg(self): uow = versioning_manager.unit_of_work(self.session) assert isinstance(uow, UnitOfWork) + + +class TestExternalTransactionSession(TestCase): + + def test_session_with_external_transaction(self): + conn = self.engine.connect() + t = conn.begin() + session = Session(bind=conn) + + article = self.Article(name=u'My Session Article') + session.add(article) + session.flush() + + session.close() + t.rollback() + conn.close() diff --git a/tests/test_transaction.py b/tests/test_transaction.py index f647130d..2b3e7b56 100644 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -1,6 +1,9 @@ import sqlalchemy as sa from sqlalchemy_continuum import versioning_manager from tests import TestCase +from pytest import mark +from sqlalchemy_continuum.plugins import TransactionMetaPlugin + class TestTransaction(TestCase): @@ -37,6 +40,19 @@ def test_repr(self): repr(transaction) ) + def test_changed_entities(self): + article_v0 = self.article.versions[0] + transaction = article_v0.transaction + assert transaction.changed_entities == { + self.ArticleVersion: [article_v0], + self.TagVersion: [self.article.tags[0].versions[0]], + } + + +# Check that the tests pass without TransactionChangesPlugin +class TestTransactionWithoutChangesPlugin(TestTransaction): + plugins = [TransactionMetaPlugin()] + class TestAssigningUserClass(TestCase): user_cls = 'User' @@ -56,3 +72,31 @@ class User(self.Model): def test_copies_primary_key_type_from_user_class(self): attr = versioning_manager.transaction_cls.user_id assert isinstance(attr.property.columns[0].type, sa.Unicode) + + +@mark.skipif("os.environ.get('DB') == 'sqlite'") +class TestAssigningUserClassInOtherSchema(TestCase): + user_cls = 'User' + + def create_models(self): + class User(self.Model): + __tablename__ = 'user' + __versioned__ = { + 'base_classes': (self.Model,) + } + __table_args__ = {'schema': 'other'} + + id = sa.Column(sa.Unicode(255), primary_key=True) + name = sa.Column(sa.Unicode(255), nullable=False) + + self.User = User + + def create_tables(self): + self.connection.execute('DROP SCHEMA IF EXISTS other') + self.connection.execute('CREATE SCHEMA other') + TestCase.create_tables(self) + + def test_can_build_transaction_model(self): + # If create_models didn't crash this should be good + pass + diff --git a/tox.ini b/tox.ini index 0f2033d4..2d9b2fd2 100644 --- a/tox.ini +++ b/tox.ini @@ -1,5 +1,5 @@ [tox] -envlist = py27, py33, py34, py35 +envlist = py27, py33, py34, py35, py36, py37 [testenv] commands = pip install -e ".[test]"