diff --git a/planet/clients/base.py b/planet/clients/base.py new file mode 100644 index 00000000..23f224b8 --- /dev/null +++ b/planet/clients/base.py @@ -0,0 +1,27 @@ +from typing import Any, AsyncIterator, Coroutine, Iterator, TypeVar +from planet.http import Session + +T = TypeVar("T") + + +class _BaseClient: + + def __init__(self, session: Session, base_url: str): + """ + Parameters: + session: Open session connected to server. + base_url: The base URL to use. Defaults to production data API + base url. + """ + self._session = session + + self._base_url = base_url + if self._base_url.endswith('/'): + self._base_url = self._base_url[:-1] + + def _call_sync(self, f: Coroutine[Any, Any, T]) -> T: + """block on an async function call, using the call_sync method of the session""" + return self._session._call_sync(f) + + def _aiter_to_iter(self, aiter: AsyncIterator[T]) -> Iterator[T]: + return self._session._aiter_to_iter(aiter) diff --git a/planet/clients/data.py b/planet/clients/data.py index acdd9023..83552cb5 100644 --- a/planet/clients/data.py +++ b/planet/clients/data.py @@ -17,9 +17,11 @@ import logging from pathlib import Path import time -from typing import Any, AsyncIterator, Awaitable, Callable, Dict, List, Optional, TypeVar, Union +from typing import Any, AsyncIterator, Callable, Dict, List, Optional, TypeVar, Union import uuid +from planet.clients.base import _BaseClient + from ..data_filter import empty_filter from .. import exceptions from ..constants import PLANET_BASE_URL @@ -67,7 +69,7 @@ class Searches(Paged): ITEMS_KEY = 'searches' -class DataClient: +class DataClient(_BaseClient): """High-level asynchronous access to Planet's data API. Example: @@ -92,15 +94,7 @@ def __init__(self, session: Session, base_url: Optional[str] = None): base_url: The base URL to use. Defaults to production data API base url. """ - self._session = session - - self._base_url = base_url or BASE_URL - if self._base_url.endswith('/'): - self._base_url = self._base_url[:-1] - - def _call_sync(self, f: Awaitable[T]) -> T: - """block on an async function call, using the call_sync method of the session""" - return self._session._call_sync(f) + super().__init__(session, base_url or BASE_URL) @staticmethod def _check_search_id(sid): diff --git a/planet/clients/features.py b/planet/clients/features.py index e11a7e9d..c0a353d2 100644 --- a/planet/clients/features.py +++ b/planet/clients/features.py @@ -13,8 +13,9 @@ # the License. import logging -from typing import Any, AsyncIterator, Awaitable, Optional, Union, TypeVar +from typing import Any, AsyncIterator, Optional, Union, TypeVar +from planet.clients.base import _BaseClient from planet.http import Session from planet.models import Feature, GeoInterface, Paged from planet.constants import PLANET_BASE_URL @@ -26,7 +27,7 @@ LOGGER = logging.getLogger() -class FeaturesClient: +class FeaturesClient(_BaseClient): """Asyncronous Features API client For more information about the Features API, see the documentation at @@ -55,15 +56,7 @@ def __init__(self, base_url: The base URL to use. Defaults to the Features API base url at api.planet.com. """ - self._session = session - - self._base_url = base_url or BASE_URL - if self._base_url.endswith('/'): - self._base_url = self._base_url[:-1] - - def _call_sync(self, f: Awaitable[T]) -> T: - """block on an async function call, using the call_sync method of the session""" - return self._session._call_sync(f) + super().__init__(session, base_url or BASE_URL) async def list_collections(self, limit: int = 0) -> AsyncIterator[dict]: """ diff --git a/planet/clients/orders.py b/planet/clients/orders.py index 98ea9b90..feb9e468 100644 --- a/planet/clients/orders.py +++ b/planet/clients/orders.py @@ -16,12 +16,14 @@ import asyncio import logging import time -from typing import AsyncIterator, Awaitable, Callable, Dict, List, Optional, Sequence, TypeVar, Union +from typing import AsyncIterator, Callable, Dict, List, Optional, Sequence, TypeVar, Union import uuid import json import hashlib from pathlib import Path + +from planet.clients.base import _BaseClient from .. import exceptions from ..constants import PLANET_BASE_URL from ..http import Session @@ -68,7 +70,7 @@ def is_final(cls, test): return cls.passed('running', test) -class OrdersClient: +class OrdersClient(_BaseClient): """High-level asynchronous access to Planet's orders API. Example: @@ -93,15 +95,7 @@ def __init__(self, session: Session, base_url: Optional[str] = None): base_url: The base URL to use. Defaults to production orders API base url. """ - self._session = session - - self._base_url = base_url or BASE_URL - if self._base_url.endswith('/'): - self._base_url = self._base_url[:-1] - - def _call_sync(self, f: Awaitable[T]) -> T: - """block on an async function call, using the call_sync method of the session""" - return self._session._call_sync(f) + super().__init__(session, base_url or BASE_URL) @staticmethod def _check_order_id(oid): diff --git a/planet/clients/subscriptions.py b/planet/clients/subscriptions.py index 55fe07a9..1856569d 100644 --- a/planet/clients/subscriptions.py +++ b/planet/clients/subscriptions.py @@ -1,10 +1,11 @@ """Planet Subscriptions API Python client.""" import logging -from typing import Any, AsyncIterator, Awaitable, Dict, Optional, Sequence, TypeVar, List +from typing import Any, AsyncIterator, Dict, Optional, Sequence, TypeVar, List from typing_extensions import Literal +from planet.clients.base import _BaseClient from planet.exceptions import APIError, ClientError from planet.http import Session from planet.models import Paged @@ -17,7 +18,7 @@ T = TypeVar("T") -class SubscriptionsClient: +class SubscriptionsClient(_BaseClient): """A Planet Subscriptions Service API 1.0.0 client. The methods of this class forward request parameters to the @@ -55,15 +56,7 @@ def __init__(self, base_url: The base URL to use. Defaults to production subscriptions API base url. """ - self._session = session - - self._base_url = base_url or BASE_URL - if self._base_url.endswith('/'): - self._base_url = self._base_url[:-1] - - def _call_sync(self, f: Awaitable[T]) -> T: - """block on an async function call, using the call_sync method of the session""" - return self._session._call_sync(f) + super().__init__(session, base_url or BASE_URL) async def list_subscriptions(self, status: Optional[Sequence[str]] = None, diff --git a/planet/http.py b/planet/http.py index e02bc604..53f626d8 100644 --- a/planet/http.py +++ b/planet/http.py @@ -22,7 +22,7 @@ import random import threading import time -from typing import AsyncGenerator, Awaitable, Optional, TypeVar +from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Coroutine, Iterator, Optional, TypeVar import httpx from typing_extensions import Literal @@ -299,10 +299,26 @@ def _start_background_loop(loop): daemon=True) self._loop_thread.start() - def _call_sync(self, f: Awaitable[T]) -> T: + def _call_sync(self, f: Coroutine[Any, Any, T]) -> T: self._init_loop() return asyncio.run_coroutine_threadsafe(f, self._loop).result() + def _aiter_to_iter(self, aiter: AsyncIterator[T]) -> Iterator[T]: + self._init_loop() + + # this turns an awaitable into a coroutine - works around typing + # check on run_coroutine_threadsafe (which actually does check if + # the argument is a coroutine) + async def coro(a: Awaitable[T]) -> T: + return await a + + try: + while True: + yield asyncio.run_coroutine_threadsafe(coro(aiter.__anext__()), + self._loop).result() + except StopAsyncIteration: + pass + @classmethod async def _raise_for_status(cls, response): if response.is_error: diff --git a/planet/sync/data.py b/planet/sync/data.py index 4a4c9620..dfec714d 100644 --- a/planet/sync/data.py +++ b/planet/sync/data.py @@ -13,7 +13,7 @@ # the License. """Functionality for interacting with the data api""" from pathlib import Path -from typing import Any, Callable, Dict, Iterator, List, Optional, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, TypeVar, Union from planet.models import GeojsonLike @@ -27,6 +27,8 @@ WAIT_DELAY = 5 WAIT_MAX_ATTEMPTS = 200 +T = TypeVar("T") + class DataAPI: """Data API client""" @@ -75,18 +77,13 @@ def search( references """ - results = self._client.search(item_types, - search_filter, - name, - sort, - limit, - geometry) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.search(item_types, + search_filter, + name, + sort, + limit, + geometry)) def create_search( self, @@ -183,13 +180,9 @@ def list_searches(self, planet.exceptions.ClientError: If sort or search_type are not valid. """ - results = self._client.list_searches(sort, search_type, limit) - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.list_searches(sort, search_type, limit)) def delete_search(self, search_id: str): """Delete an existing saved search. @@ -242,13 +235,8 @@ def run_search(self, planet.exceptions.ClientError: If search_id or sort is not valid. """ - results = self._client.run_search(search_id, sort, limit) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.run_search(search_id, sort, limit)) def get_stats(self, item_types: List[str], diff --git a/planet/sync/features.py b/planet/sync/features.py index 9f0ef08b..fec47519 100644 --- a/planet/sync/features.py +++ b/planet/sync/features.py @@ -44,13 +44,8 @@ def list_collections(self, limit: int = 0) -> Iterator[dict]: print(collection) ``` """ - collections = self._client.list_collections(limit=limit) - - try: - while True: - yield self._client._call_sync(collections.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.list_collections(limit=limit)) def get_collection(self, collection_id: str) -> dict: """ @@ -109,13 +104,8 @@ def list_items(self, results = pl.data.search(["PSScene"], geometry=feature]) ``` """ - results = self._client.list_items(collection_id, limit=limit) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.list_items(collection_id, limit=limit)) def get_item(self, collection_id: str, feature_id: str) -> Feature: """ diff --git a/planet/sync/orders.py b/planet/sync/orders.py index 247b0ce3..f928e173 100644 --- a/planet/sync/orders.py +++ b/planet/sync/orders.py @@ -307,18 +307,13 @@ def list_orders(self, planet.exceptions.APIError: On API error. planet.exceptions.ClientError: If state is not valid. """ - results = self._client.list_orders(state, - limit, - source_type, - name, - name__contains, - created_on, - last_modified, - hosting, - sort_by) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.list_orders(state, + limit, + source_type, + name, + name__contains, + created_on, + last_modified, + hosting, + sort_by)) diff --git a/planet/sync/subscriptions.py b/planet/sync/subscriptions.py index f22f899a..abe42962 100644 --- a/planet/sync/subscriptions.py +++ b/planet/sync/subscriptions.py @@ -101,24 +101,19 @@ def list_subscriptions(self, ClientError: on a client error. """ - results = self._client.list_subscriptions(status, - limit, - created, - end_time, - hosting, - name__contains, - name, - source_type, - start_time, - sort_by, - updated, - page_size) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.list_subscriptions(status, + limit, + created, + end_time, + hosting, + name__contains, + name, + source_type, + start_time, + sort_by, + updated, + page_size)) def create_subscription(self, request: Dict) -> Dict: """Create a Subscription. @@ -253,13 +248,8 @@ def get_results( APIError: on an API server error. ClientError: on a client error. """ - results = self._client.get_results(subscription_id, status, limit) - - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.get_results(subscription_id, status, limit)) def get_results_csv( self, @@ -285,17 +275,13 @@ def get_results_csv( APIError: on an API server error. ClientError: on a client error. """ - results = self._client.get_results_csv(subscription_id, status) # Note: retries are not implemented yet. This project has # retry logic for HTTP requests, but does not handle errors # during streaming. We may want to consider a retry decorator # for this entire method a la stamina: # https://github.com/hynek/stamina. - try: - while True: - yield self._client._call_sync(results.__anext__()) - except StopAsyncIteration: - pass + return self._client._aiter_to_iter( + self._client.get_results_csv(subscription_id, status)) def get_summary(self) -> Dict[str, Any]: """Summarize the status of all subscriptions via GET.