diff --git a/planet/specs.py b/planet/specs.py index 0f8c1550b..9c08a0bb8 100644 --- a/planet/specs.py +++ b/planet/specs.py @@ -15,7 +15,7 @@ """Functionality for validating against the Planet API specification.""" import json import logging - +import itertools from .constants import DATA_DIR PRODUCT_BUNDLE_SPEC_NAME = 'orders_product_bundle_2022_02_02.json' @@ -37,19 +37,32 @@ LOGGER = logging.getLogger(__name__) -class SpecificationException(Exception): +class NoMatchException(Exception): '''No match was found''' pass +class SpecificationException(Exception): + '''No match was found''' + + def __init__(self, value, supported, field_name): + self.value = value + self.supported = supported + self.field_name = field_name + self.opts = ', '.join(["'" + s + "'" for s in supported]) + + def __str__(self): + return f'{self.field_name} - \'{self.value}\' is not one of {self.opts}.' + + def validate_bundle(bundle): supported = get_product_bundles() return _validate_field(bundle, supported, 'product_bundle') def validate_item_type(item_type, bundle): - bundle = validate_bundle(bundle) - supported = get_item_types(bundle) + validated_bundle = validate_bundle(bundle) + supported = get_item_types(validated_bundle) return _validate_field(item_type, supported, 'item_type') @@ -73,20 +86,13 @@ def validate_file_format(file_format): def _validate_field(value, supported, field_name): try: - value = get_match(value, supported) + value = get_match(value, supported, field_name) except (NoMatchException): - opts = ', '.join(["'" + s + "'" for s in supported]) - msg = f'{field_name} - \'{value}\' is not one of {opts}.' - raise SpecificationException(msg) + raise SpecificationException(value, supported, field_name) return value -class NoMatchException(Exception): - '''No match was found''' - pass - - -def get_match(test_entry, spec_entries): +def get_match(test_entry, spec_entries, field_name): '''Find and return matching spec entry regardless of capitalization. This is helpful for working with the API spec, where the capitalization @@ -96,7 +102,7 @@ def get_match(test_entry, spec_entries): match = next(e for e in spec_entries if e.lower() == test_entry.lower()) except (StopIteration): - raise NoMatchException('{test_entry} should be one of {spec_entries}') + raise SpecificationException(test_entry, spec_entries, field_name) return match @@ -111,14 +117,15 @@ def get_item_types(product_bundle=None): '''If given product bundle, get specific item types supported by Orders API. Otherwise, get all item types supported by Orders API.''' spec = _get_product_bundle_spec() + if product_bundle: - item_types = spec['bundles'][product_bundle]['assets'].keys() + item_types = set(spec['bundles'][product_bundle]['assets'].keys()) else: - product_bundle = get_product_bundles() - all_item_types = [] - for bundle in product_bundle: - all_item_types += [*spec['bundles'][bundle]['assets'].keys()] - item_types = set(all_item_types) + item_types = set( + itertools.chain.from_iterable( + spec['bundles'][bundle]['assets'].keys() + for bundle in get_product_bundles())) + return item_types diff --git a/tests/unit/test_specs.py b/tests/unit/test_specs.py index 21b4e5717..e7c17bf73 100644 --- a/tests/unit/test_specs.py +++ b/tests/unit/test_specs.py @@ -46,10 +46,11 @@ def test_get_type_match(): spec_list = ['Locket', 'drop', 'DEER'] test_entry = 'locket' - assert 'Locket' == specs.get_match(test_entry, spec_list) + field_name = 'field_name' + assert 'Locket' == specs.get_match(test_entry, spec_list, field_name) - with pytest.raises(specs.NoMatchException): - specs.get_match('a', ['b']) + with pytest.raises(specs.SpecificationException): + specs.get_match('a', ['b'], field_name) def test_validate_bundle_supported():