diff --git a/sqlalchemydiff/comparer.py b/sqlalchemydiff/comparer.py index 2fbc506..870d031 100644 --- a/sqlalchemydiff/comparer.py +++ b/sqlalchemydiff/comparer.py @@ -93,6 +93,12 @@ def compare(left_uri, right_uri, ignores=None, ignores_sep=None): tables_info.common, left_inspector, right_inspector, ignore_manager ) + info['enums'] = _get_enums_info( + left_inspector, + right_inspector, + ignore_manager.get('*', 'enum'), + ) + errors = _compile_errors(info) result = _make_result(info, errors) @@ -161,6 +167,7 @@ def _get_info_dict(left_uri, right_uri, tables_info): 'common': tables_info.common, }, 'tables_data': {}, + 'enums': {}, } return info @@ -214,6 +221,13 @@ def _get_table_data( ignore_manager.get(table_name, 'col') ) + table_data['constraints'] = _get_constraints_info( + left_inspector, + right_inspector, + table_name, + ignore_manager.get(table_name, 'cons') + ) + return table_data @@ -335,6 +349,56 @@ def _get_columns(inspector, table_name): return inspector.get_columns(table_name) +def _get_constraints_info(left_inspector, right_inspector, + table_name, ignores): + left_constraints_list = _get_constraints_data(left_inspector, table_name) + right_constraints_list = _get_constraints_data(right_inspector, table_name) + + left_constraints_list = _discard_ignores_by_name(left_constraints_list, + ignores) + right_constraints_list = _discard_ignores_by_name(right_constraints_list, + ignores) + + # process into dict + left_constraints = dict((elem['name'], elem) + for elem in left_constraints_list) + right_constraints = dict((elem['name'], elem) + for elem in right_constraints_list) + + return _diff_dicts(left_constraints, right_constraints) + + +def _get_constraints_data(inspector, table_name): + try: + return inspector.get_check_constraints(table_name) + except (AttributeError, NotImplementedError): # pragma: no cover + # sqlalchemy < 1.1.0 + # or a dialect that doesn't implement get_check_constraints + return [] + + +def _get_enums_info(left_inspector, right_inspector, ignores): + left_enums_list = _get_enums_data(left_inspector) + right_enums_list = _get_enums_data(right_inspector) + + left_enums_list = _discard_ignores_by_name(left_enums_list, ignores) + right_enums_list = _discard_ignores_by_name(right_enums_list, ignores) + + # process into dict + left_enums = dict((elem['name'], elem) for elem in left_enums_list) + right_enums = dict((elem['name'], elem) for elem in right_enums_list) + + return _diff_dicts(left_enums, right_enums) + + +def _get_enums_data(inspector): + try: + # as of 1.2.0, PostgreSQL dialect only; see PGInspector + return inspector.get_enums() + except AttributeError: + return [] + + def _discard_ignores_by_name(items, ignores): return [item for item in items if item['name'] not in ignores] @@ -364,6 +428,7 @@ def _compile_errors(info): errors_template = { 'tables': {}, 'tables_data': {}, + 'enums': {}, } errors = deepcopy(errors_template) @@ -375,7 +440,8 @@ def _compile_errors(info): errors['tables']['right_only'] = info['tables']['right_only'] # then check if there is a discrepancy in the data for each table - keys = ['foreign_keys', 'primary_keys', 'indexes', 'columns'] + keys = ['foreign_keys', 'primary_keys', 'indexes', 'columns', + 'constraints'] subkeys = ['left_only', 'right_only', 'diff'] for table_name in info['tables_data']: @@ -386,6 +452,10 @@ def _compile_errors(info): table_d.setdefault(key, {})[subkey] = info[ 'tables_data'][table_name][key][subkey] + for subkey in subkeys: + if info['enums'][subkey]: + errors['enums'][subkey] = info['enums'][subkey] + if errors != errors_template: errors['uris'] = info['uris'] return errors diff --git a/sqlalchemydiff/util.py b/sqlalchemydiff/util.py index 6b2fc37..755618e 100644 --- a/sqlalchemydiff/util.py +++ b/sqlalchemydiff/util.py @@ -108,7 +108,7 @@ def prepare_schema_from_models(uri, sqlalchemy_base): class IgnoreManager: - allowed_identifiers = ['pk', 'fk', 'idx', 'col'] + allowed_identifiers = ['pk', 'fk', 'idx', 'col', 'cons', 'enum'] def __init__(self, ignores, separator=None): self.separator = separator or '.' diff --git a/test/endtoend/enumadaptor.py b/test/endtoend/enumadaptor.py new file mode 100644 index 0000000..f503cd3 --- /dev/null +++ b/test/endtoend/enumadaptor.py @@ -0,0 +1,20 @@ +""" +Adapt Enum across versions of SQLAlchemy. + +SQLAlchemy supports PEP 435 Enum classes as of 1.1. +Prior versions supported only the values as strings. + +Export a suitable column type for either case. +""" +import enum +import sqlalchemy + + +def Enum(*enums, **kw): + if sqlalchemy.__version__ >= '1.1': + return sqlalchemy.Enum(*enums, **kw) + + if len(enums) == 1 and issubclass(enums[0], enum.Enum): + return sqlalchemy.Enum(*(v.name for v in enums[0]), **kw) + + return sqlalchemy.Enum(*enums, **kw) diff --git a/test/endtoend/models_left.py b/test/endtoend/models_left.py index 0565d9c..c3c03fa 100644 --- a/test/endtoend/models_left.py +++ b/test/endtoend/models_left.py @@ -1,11 +1,20 @@ # -*- coding: utf-8 -*- +import enum + from sqlalchemy import Column, ForeignKey, Integer, String, Unicode from sqlalchemy.ext.declarative import declarative_base +from .enumadaptor import Enum + Base = declarative_base() +class Polarity(enum.Enum): + NEGATIVE = 'NEGATIVE' + POSITIVE = 'POSITIVE' + + class Employee(Base): __tablename__ = "employees" @@ -14,6 +23,8 @@ class Employee(Base): age = Column(Integer, nullable=False, default=21) ssn = Column(Unicode(30), nullable=False) number_of_pets = Column(Integer, default=1, nullable=False) + polarity = Column(Enum(Polarity, native_enum=True)) + spin = Column(Enum('spin_down', 'spin_up', native_enum=False)) company_id = Column( Integer, diff --git a/test/endtoend/models_right.py b/test/endtoend/models_right.py index 8c42493..d61ad14 100644 --- a/test/endtoend/models_right.py +++ b/test/endtoend/models_right.py @@ -1,11 +1,20 @@ # -*- coding: utf-8 -*- +import enum + from sqlalchemy import Column, ForeignKey, Integer, String, Unicode from sqlalchemy.ext.declarative import declarative_base +from .enumadaptor import Enum + Base = declarative_base() +class Polarity(enum.Enum): + NEG = 'NEG' + POS = 'POS' + + class Employee(Base): __tablename__ = "employees" @@ -14,6 +23,8 @@ class Employee(Base): age = Column(Integer, nullable=False, default=21) ssn = Column(Unicode(30), nullable=False) number_of_pets = Column(Integer, default=1, nullable=False) + polarity = Column(Enum(Polarity, native_enum=True)) + spin = Column(Enum('down', 'up', native_enum=False)) company_id = Column( Integer, diff --git a/test/endtoend/test_example.py b/test/endtoend/test_example.py index e4e845f..ea2488f 100644 --- a/test/endtoend/test_example.py +++ b/test/endtoend/test_example.py @@ -2,6 +2,7 @@ import json import pytest +from sqlalchemy import create_engine from sqlalchemydiff.comparer import compare from sqlalchemydiff.util import ( @@ -108,6 +109,39 @@ def test_errors_dict_catches_all_differences(uri_left, uri_right): } }, 'employees': { + 'columns': { + 'diff': [ + { + 'key': 'polarity', + 'left': { + 'default': None, + 'name': 'polarity', + 'nullable': True, + 'type': "ENUM('NEGATIVE','POSITIVE')"}, + 'right': { + 'default': None, + 'name': 'polarity', + 'nullable': True, + 'type': "ENUM('NEG','POS')" + } + }, + { + 'key': 'spin', + 'left': { + 'default': None, + 'name': 'spin', + 'nullable': True, + 'type': 'VARCHAR(9)' + }, + 'right': { + 'default': None, + 'name': 'spin', + 'nullable': True, + 'type': 'VARCHAR(4)' + } + } + ] + }, 'foreign_keys': { 'left_only': [ { @@ -215,12 +249,27 @@ def test_errors_dict_catches_all_differences(uri_left, uri_right): } } }, + 'enums': { + }, 'uris': { 'left': uri_left, 'right': uri_right, } } + engine = create_engine(uri_left) + dialect = engine.dialect + if getattr(dialect, 'supports_comments', False): + # sqlalchemy 1.2.0 adds support for SQL comments + # expect them in the errors when supported + for table in expected_errors['tables_data'].values(): + for column in table['columns']['diff']: + for side in ['left', 'right']: + column[side].update(comment=None) + for side in ['left_only', 'right_only']: + for column in table['columns'].get(side, []): + column.update(comment=None) + assert not result.is_match compare_error_dicts(expected_errors, result.errors) @@ -297,8 +346,11 @@ def test_ignores(uri_left, uri_right): ignores = [ 'mobile_numbers', 'phone_numbers', + '*.enum.polarity', 'companies.col.name', 'companies.idx.name', + 'employees.col.polarity', + 'employees.col.spin', 'employees.fk.fk_employees_companies', 'employees.fk.fk_emp_comp', 'employees.idx.ix_employees_name', @@ -328,8 +380,11 @@ def test_ignores_alternative_sep(uri_left, uri_right): ignores = [ 'mobile_numbers', 'phone_numbers', + '*#enum#polarity', 'companies#col#name', 'companies#idx#name', + 'employees#col#polarity', + 'employees#col#spin', 'employees#fk#fk_employees_companies', 'employees#fk#fk_emp_comp', 'employees#idx#ix_employees_name', @@ -353,6 +408,7 @@ def test_ignores_alternative_sep(uri_left, uri_right): @pytest.mark.parametrize('missing_ignore', [ 'mobile_numbers', 'phone_numbers', + '*.enum.polarity', 'companies.col.name', 'companies.idx.name', 'employees.fk.fk_employees_companies', @@ -375,6 +431,7 @@ def test_ignores_all_needed(uri_left, uri_right, missing_ignore): ignores = [ 'mobile_numbers', 'phone_numbers', + '*.enum.polarity', 'companies.col.name', 'companies.idx.name', 'employees.fk.fk_employees_companies', diff --git a/test/unit/test_comparer.py b/test/unit/test_comparer.py index 4002459..a637be2 100644 --- a/test/unit/test_comparer.py +++ b/test/unit/test_comparer.py @@ -93,9 +93,15 @@ def _get_tables_info_mock(self): ) yield m + @pytest.yield_fixture + def _get_enums_data_mock(self): + with patch('sqlalchemydiff.comparer._get_enums_data') as m: + m.return_value = [] + yield m + def test_compare_calls_chain( self, _get_tables_info_mock, _get_tables_data_mock, - _compile_errors_mock): + _get_enums_data_mock, _compile_errors_mock): """By inspecting `info` and `errors` at the end, we automatically check that the whole process works as expected. What this test leaves out is the verifications about inspectors. @@ -134,6 +140,12 @@ def test_compare_calls_chain( 'data': 'some-data-B', }, }, + 'enums': { + 'left_only': [], + 'right_only': [], + 'common': [], + 'diff': [], + }, } expected_errors = expected_info.copy() @@ -144,7 +156,8 @@ def test_compare_calls_chain( def test__get_tables_info_called_with_correct_inspectors( self, _get_inspectors_mock, _get_tables_info_mock, - _get_tables_data_mock, _compile_errors_mock): + _get_tables_data_mock, _get_enums_data_mock, + _compile_errors_mock): left_inspector, right_inspector = _get_inspectors_mock.return_value compare("left_uri", "right_uri", ignores=['ignore_me']) @@ -219,6 +232,11 @@ def _get_columns_info_mock(self): with patch('sqlalchemydiff.comparer._get_columns_info') as m: yield m + @pytest.yield_fixture + def _get_constraints_info_mock(self): + with patch('sqlalchemydiff.comparer._get_constraints_info') as m: + yield m + # TESTS def test__get_inspectors(self): @@ -302,6 +320,7 @@ def test__get_info_dict(self): 'common': ['C'], }, 'tables_data': {}, + 'enums': {}, } assert expected_info == info @@ -616,7 +635,8 @@ def test_process_type(self): def test__get_table_data( self, _get_foreign_keys_info_mock, _get_primary_keys_info_mock, - _get_indexes_info_mock, _get_columns_info_mock): + _get_indexes_info_mock, _get_columns_info_mock, + _get_constraints_info_mock): left_inspector, right_inspector = Mock(), Mock() _get_foreign_keys_info_mock.return_value = { @@ -631,6 +651,9 @@ def test__get_table_data( _get_columns_info_mock.return_value = { 'left_only': 13, 'right_only': 14, 'common': 15, 'diff': 16 } + _get_constraints_info_mock.return_value = { + 'left_only': 17, 'right_only': 18, 'common': 19, 'diff': 20 + } result = _get_table_data( left_inspector, right_inspector, 'table_A', Mock() @@ -661,6 +684,12 @@ def test__get_table_data( 'common': 15, 'diff': 16, }, + 'constraints': { + 'left_only': 17, + 'right_only': 18, + 'common': 19, + 'diff': 20, + }, } assert expected_result == result @@ -704,6 +733,12 @@ def test__compile_errors_with_errors(self): 'right_only': 14, 'common': 15, 'diff': 16, + }, + 'constraints': { + 'left_only': 17, + 'right_only': 18, + 'common': 19, + 'diff': 20, } }, @@ -731,8 +766,20 @@ def test__compile_errors_with_errors(self): 'right_only': 14, 'common': 15, 'diff': 16, + }, + 'constraints': { + 'left_only': 17, + 'right_only': 18, + 'common': 19, + 'diff': 20, } } + }, + 'enums': { + 'left_only': 21, + 'right_only': 22, + 'common': 23, + 'diff': 24, } } @@ -766,6 +813,11 @@ def test__compile_errors_with_errors(self): 'left_only': 13, 'right_only': 14, 'diff': 16, + }, + 'constraints': { + 'left_only': 17, + 'right_only': 18, + 'diff': 20, } }, @@ -789,8 +841,18 @@ def test__compile_errors_with_errors(self): 'left_only': 13, 'right_only': 14, 'diff': 16, + }, + 'constraints': { + 'left_only': 17, + 'right_only': 18, + 'diff': 20, } } + }, + 'enums': { + 'left_only': 21, + 'right_only': 22, + 'diff': 24, } } @@ -837,7 +899,19 @@ def test__compile_errors_without_errors(self): 'common': 4, 'diff': [], }, + 'constraints': { + 'left_only': [], + 'right_only': [], + 'common': 5, + 'diff': [], + }, } + }, + 'enums': { + 'left_only': [], + 'right_only': [], + 'common': 6, + 'diff': [], } } diff --git a/test/unit/test_util.py b/test/unit/test_util.py index e70f97d..15c26ab 100644 --- a/test/unit/test_util.py +++ b/test/unit/test_util.py @@ -210,12 +210,13 @@ def test_identifier_incorrect(self): IgnoreManager(ignore_data) assert ( - "unknown is invalid. It must be in ['pk', 'fk', 'idx', 'col']", + "unknown is invalid. It must be in " + "['pk', 'fk', 'idx', 'col', 'cons', 'enum']", ) == err.value.args @pytest.mark.parametrize('clause', [ - 'too.few', - 'too.many.definitely.for-sure', + 'too.few', + 'too.many.definitely.for-sure', ]) def test_incorrect_clause(self, clause): ignore_data = [clause] @@ -229,8 +230,8 @@ def test_incorrect_clause(self, clause): ) == err.value.args @pytest.mark.parametrize('clause', [ - '.pk.b', - 'a.pk.', + '.pk.b', + 'a.pk.', ]) def test_incorrect_empty_clause(self, clause): ignore_data = [clause]