diff --git a/planet/cli/data.py b/planet/cli/data.py index 5f1385035..ab461d617 100644 --- a/planet/cli/data.py +++ b/planet/cli/data.py @@ -19,6 +19,7 @@ from planet import data_filter, DataClient from planet.clients.data import SEARCH_SORT, SEARCH_SORT_DEFAULT, STATS_INTERVAL +from planet.specs import get_item_types from . import types from .cmds import coro, translate_exceptions @@ -26,6 +27,9 @@ from .options import limit, pretty from .session import CliSession +ALL_ITEM_TYPES = get_item_types() +valid_item_string = "Valid entries for ITEM_TYPES: " + "|".join(ALL_ITEM_TYPES) + @asynccontextmanager async def data_client(ctx): @@ -58,6 +62,17 @@ def assets_to_filter(ctx, param, assets: List[str]) -> Optional[dict]: return data_filter.asset_filter(assets) if assets else None +def check_item_types(ctx, param, item_types) -> Optional[List[dict]]: + # Set difference between given item types and all item types + set_diff = set([item.lower() for item in item_types]) - set( + [a.lower() for a in ALL_ITEM_TYPES]) + if set_diff: + raise click.BadParameter( + f'{item_types} should be one of {ALL_ITEM_TYPES}') + else: + return item_types + + def date_range_to_filter(ctx, param, values) -> Optional[List[dict]]: def _func(obj): @@ -226,11 +241,13 @@ def filter(ctx, echo_json(filt, pretty) -@data.command() +@data.command(epilog=valid_item_string) @click.pass_context @translate_exceptions @coro -@click.argument("item_types", type=types.CommaSeparatedString()) +@click.argument("item_types", + type=types.CommaSeparatedString(), + callback=check_item_types) @click.argument("filter", type=types.JSON(), default="-", required=False) @limit @click.option('--name', type=str, help='Name of the saved search.') @@ -264,12 +281,14 @@ async def search(ctx, item_types, filter, limit, name, sort, pretty): echo_json(item, pretty) -@data.command() +@data.command(epilog=valid_item_string) @click.pass_context @translate_exceptions @coro @click.argument('name') -@click.argument("item_types", type=types.CommaSeparatedString()) +@click.argument("item_types", + type=types.CommaSeparatedString(), + callback=check_item_types) @click.argument("filter", type=types.JSON(), default="-", required=False) @click.option('--daily-email', is_flag=True, @@ -296,11 +315,13 @@ async def search_create(ctx, name, item_types, filter, daily_email, pretty): echo_json(items, pretty) -@data.command() +@data.command(epilog=valid_item_string) @click.pass_context @translate_exceptions @coro -@click.argument("item_types", type=types.CommaSeparatedString()) +@click.argument("item_types", + type=types.CommaSeparatedString(), + callback=check_item_types) @click.argument('interval', type=click.Choice(STATS_INTERVAL)) @click.argument("filter", type=types.JSON(), default="-", required=False) async def stats(ctx, item_types, interval, filter): diff --git a/planet/specs.py b/planet/specs.py index f3971ab6a..0f8c1550b 100644 --- a/planet/specs.py +++ b/planet/specs.py @@ -93,10 +93,10 @@ def get_match(test_entry, spec_entries): is hard to remember but must be exact otherwise the API throws an exception.''' try: - match = next(t for t in spec_entries - if t.lower() == test_entry.lower()) + match = next(e for e in spec_entries + if e.lower() == test_entry.lower()) except (StopIteration): - raise NoMatchException + raise NoMatchException('{test_entry} should be one of {spec_entries}') return match @@ -107,10 +107,19 @@ def get_product_bundles(): return spec['bundles'].keys() -def get_item_types(product_bundle): - '''Get item types supported by Orders API for the given product bundle.''' +def get_item_types(product_bundle=None): + '''If given product bundle, get specific item types supported by Orders + API. Otherwise, get all item types supported by Orders API.''' spec = _get_product_bundle_spec() - return spec['bundles'][product_bundle]['assets'].keys() + if product_bundle: + item_types = spec['bundles'][product_bundle]['assets'].keys() + else: + product_bundle = get_product_bundles() + all_item_types = [] + for bundle in product_bundle: + all_item_types += [*spec['bundles'][bundle]['assets'].keys()] + item_types = set(all_item_types) + return item_types def _get_product_bundle_spec(): diff --git a/tests/integration/test_data_cli.py b/tests/integration/test_data_cli.py index 60a1a6439..f146cf0fb 100644 --- a/tests/integration/test_data_cli.py +++ b/tests/integration/test_data_cli.py @@ -23,6 +23,7 @@ import pytest from planet.cli import cli +from planet.specs import get_item_types LOGGER = logging.getLogger(__name__) @@ -55,6 +56,23 @@ def test_data_command_registered(invoke): # Add other sub-commands here. +def test_data_search_command_registered(invoke): + """planet-data search command prints help and usage message.""" + runner = CliRunner() + result = invoke(["search", "--help"], runner=runner) + all_item_types = [a for a in get_item_types()] + assert result.exit_code == 0 + assert "Usage" in result.output + assert "limit" in result.output + assert "name" in result.output + assert "sort" in result.output + assert "pretty" in result.output + assert "help" in result.output + for a in all_item_types: + assert a in result.output.replace('\n', '').replace(' ', '') + # Add other sub-commands here. + + PERMISSION_FILTER = {"type": "PermissionFilter", "config": ["assets:download"]} STD_QUALITY_FILTER = { "type": "StringInFilter", @@ -358,8 +376,8 @@ def test_data_filter_update(invoke, assert_and_filters_equal, default_filters): @respx.mock @pytest.mark.asyncio @pytest.mark.parametrize("filter", ['{1:1}', '{"foo"}']) -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) def test_data_search_cmd_filter_invalid_json(invoke, item_types, filter): """Test for planet data search_quick. Test with multiple item_types. Test should fail as filter does not contain valid JSON.""" @@ -375,8 +393,8 @@ def test_data_search_cmd_filter_invalid_json(invoke, item_types, filter): @respx.mock -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) def test_data_search_cmd_filter_success(invoke, item_types): """Test for planet data search_quick. Test with multiple item_types. Test should succeed as filter contains valid JSON.""" @@ -495,8 +513,8 @@ def test_data_search_cmd_limit(invoke, @respx.mock @pytest.mark.asyncio @pytest.mark.parametrize("filter", ['{1:1}', '{"foo"}']) -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) def test_data_search_create_filter_invalid_json(invoke, item_types, filter): """Test for planet data search_create. Test with multiple item_types. Test should fail as filter does not contain valid JSON.""" @@ -514,8 +532,8 @@ def test_data_search_create_filter_invalid_json(invoke, item_types, filter): @respx.mock -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) def test_data_search_create_filter_success(invoke, item_types): """Test for planet data search_create. Test with multiple item_types. Test should succeed as filter contains valid JSON.""" @@ -601,8 +619,8 @@ def test_data_stats_invalid_filter(invoke, filter): @respx.mock -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) @pytest.mark.parametrize("interval, exit_code", [(None, 1), ('hou', 2), ('hour', 0)]) def test_data_stats_invalid_interval(invoke, item_types, interval, exit_code): @@ -630,8 +648,8 @@ def test_data_stats_invalid_interval(invoke, item_types, interval, exit_code): @respx.mock -@pytest.mark.parametrize( - "item_types", ['PSScene', 'SkySatScene', ('PSScene', 'SkySatScene')]) +@pytest.mark.parametrize("item_types", + ['PSScene', 'SkySatScene', 'PSScene, SkySatScene']) @pytest.mark.parametrize("interval", ['hour', 'day', 'week', 'month', 'year']) def test_data_stats_success(invoke, item_types, interval): """Test for planet data stats. Test with multiple item_types. diff --git a/tests/unit/test_data_item_type.py b/tests/unit/test_data_item_type.py new file mode 100644 index 000000000..41d7b78d5 --- /dev/null +++ b/tests/unit/test_data_item_type.py @@ -0,0 +1,55 @@ +# Copyright 2022 Planet Labs PBC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import pytest +import click +from planet.cli.data import check_item_types + +LOGGER = logging.getLogger(__name__) + + +class MockContext: + + def __init__(self): + self.obj = {} + + +@pytest.mark.parametrize("item_types", + [ + 'PSScene3Band', + 'MOD09GQ', + 'MYD09GA', + 'REOrthoTile', + 'SkySatCollect', + 'SkySatScene', + 'MYD09GQ', + 'Landsat8L1G', + 'Sentinel2L1C', + 'MOD09GA', + 'Sentinel1', + 'PSScene', + 'PSOrthoTile', + 'PSScene4Band', + 'REScene' + ]) +def test_item_type_success(item_types): + ctx = MockContext() + result = check_item_types(ctx, 'item_types', [item_types]) + assert result == [item_types] + + +def test_item_type_fail(): + ctx = MockContext() + with pytest.raises(click.BadParameter): + check_item_types(ctx, 'item_type', "bad_item_type") diff --git a/tests/unit/test_specs.py b/tests/unit/test_specs.py index 6f721e9de..21b4e5717 100644 --- a/tests/unit/test_specs.py +++ b/tests/unit/test_specs.py @@ -23,6 +23,23 @@ TEST_PRODUCT_BUNDLE = 'visual' # must be a valid item type for TEST_PRODUCT_BUNDLE TEST_ITEM_TYPE = 'PSScene' +ALL_ITEM_TYPES = [ + 'PSOrthoTile', + 'Sentinel1', + 'REOrthoTile', + 'PSScene', + 'PSScene4Band', + 'Landsat8L1G', + 'PSScene3Band', + 'REScene', + 'MOD09GA', + 'MYD09GA', + 'MOD09GQ', + 'SkySatCollect', + 'Sentinel2L1C', + 'MYD09GQ', + 'SkySatScene' +] def test_get_type_match(): @@ -90,6 +107,12 @@ def test_get_product_bundles(): assert TEST_PRODUCT_BUNDLE in bundles -def test_get_item_types(): - item_types = specs.get_item_types(TEST_PRODUCT_BUNDLE) +def test_get_item_types_with_bundle(): + item_types = specs.get_item_types(product_bundle=TEST_PRODUCT_BUNDLE) assert TEST_ITEM_TYPE in item_types + + +def test_get_item_types_without_bundle(): + item_types = specs.get_item_types() + for item in item_types: + assert item in ALL_ITEM_TYPES