Skip to content

Simpler extension interface #1243

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Oct 12, 2023
7 changes: 7 additions & 0 deletions pystac/asset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
if TYPE_CHECKING:
from pystac.collection import Collection
from pystac.common_metadata import CommonMetadata
from pystac.extensions.ext import AssetExt
from pystac.item import Item

A = TypeVar("A", bound="Asset")
Expand Down Expand Up @@ -261,6 +262,12 @@ def delete(self) -> None:
href = _absolute_href(self.href, self.owner, "delete")
os.remove(href)

@property
def ext(self) -> AssetExt:
from pystac.extensions.ext import AssetExt

return AssetExt(stac_object=self)


def _absolute_href(
href: str, owner: Optional[Union[Item, Collection]], action: str = "access"
Expand Down
7 changes: 7 additions & 0 deletions pystac/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
)

if TYPE_CHECKING:
from pystac.extensions.ext import CollectionExt
from pystac.item import Item

C = TypeVar("C", bound="Collection")
Expand Down Expand Up @@ -830,3 +831,9 @@ def full_copy(
@classmethod
def matches_object_type(cls, d: Dict[str, Any]) -> bool:
return identify_stac_object_type(d) == STACObjectType.COLLECTION

@property
def ext(self) -> CollectionExt:
from pystac.extensions.ext import CollectionExt

return CollectionExt(stac_object=self)
54 changes: 50 additions & 4 deletions pystac/extensions/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import re
import warnings
from abc import ABC, abstractmethod
from typing import (
TYPE_CHECKING,
Any,
Dict,
Generic,
Expand All @@ -16,6 +19,9 @@

import pystac

if TYPE_CHECKING:
from pystac.extensions.item_assets import AssetDefinition

VERSION_REGEX = re.compile("/v[0-9].[0-9].*/")


Expand Down Expand Up @@ -157,7 +163,39 @@ def has_extension(cls, obj: S) -> bool:
@classmethod
def validate_owner_has_extension(
cls,
asset: pystac.Asset,
asset: Union[pystac.Asset, AssetDefinition],
add_if_missing: bool = False,
) -> None:
"""
DEPRECATED

.. deprecated:: 1.9
Use :meth:`ensure_owner_has_extension` instead.

Given an :class:`~pystac.Asset`, checks if the asset's owner has this
extension's schema URI in its :attr:`~pystac.STACObject.stac_extensions` list.
If ``add_if_missing`` is ``True``, the schema URI will be added to the owner.

Args:
asset : The asset whose owner should be validated.
add_if_missing : Whether to add the schema URI to the owner if it is
not already present. Defaults to False.

Raises:
STACError : If ``add_if_missing`` is ``True`` and ``asset.owner`` is
``None``.
"""
warnings.warn(
"ensure_owner_has_extension is deprecated and will be removed in v1.9. "
"Use ensure_owner_has_extension instead",
DeprecationWarning,
)
return cls.ensure_owner_has_extension(asset, add_if_missing)

@classmethod
def ensure_owner_has_extension(
cls,
asset: Union[pystac.Asset, AssetDefinition],
add_if_missing: bool = False,
) -> None:
"""Given an :class:`~pystac.Asset`, checks if the asset's owner has this
Expand All @@ -176,8 +214,8 @@ def validate_owner_has_extension(
if asset.owner is None:
if add_if_missing:
raise pystac.STACError(
"Attempted to use add_if_missing=True for an Asset with no owner. "
"Use Asset.set_owner or set add_if_missing=False."
"Attempted to use add_if_missing=True for an Asset or ItemAsset "
"with no owner. Use .set_owner or set add_if_missing=False."
)
else:
return
Expand Down Expand Up @@ -223,8 +261,16 @@ def ensure_has_extension(cls, obj: S, add_if_missing: bool = False) -> None:
cls.add_to(obj)

if not cls.has_extension(obj):
name = getattr(cls, "name", cls.__name__)
suggestion = (
f"``obj.ext.add('{name}')"
if hasattr(cls, "name")
else f"``{name}.add_to(obj)``"
)

raise pystac.ExtensionNotImplemented(
f"Could not find extension schema URI {cls.get_schema_uri()} in object."
f"Extension '{name}' is not implemented on object."
f"STAC producers can add the extension using {suggestion}"
)

@classmethod
Expand Down
10 changes: 5 additions & 5 deletions pystac/extensions/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Generic,
Iterable,
List,
Literal,
Optional,
Pattern,
TypeVar,
Expand All @@ -18,7 +19,7 @@
)

import pystac
import pystac.extensions.item_assets as item_assets
from pystac.extensions import item_assets
from pystac.extensions.base import (
ExtensionManagementMixin,
PropertiesExtension,
Expand Down Expand Up @@ -426,6 +427,7 @@ class ClassificationExtension(
method can be used to construct the proper class for you.
"""

name: Literal["classification"] = "classification"
properties: Dict[str, Any]
"""The :class:`~pystac.Asset` fields, including extension properties."""

Expand Down Expand Up @@ -534,12 +536,10 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> ClassificationExtension[T]
cls.ensure_has_extension(obj, add_if_missing)
return cast(ClassificationExtension[T], ItemClassificationExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(ClassificationExtension[T], AssetClassificationExtension(obj))
elif isinstance(obj, item_assets.AssetDefinition):
cls.ensure_has_extension(
cast(Union[pystac.Item, pystac.Collection], obj.owner), add_if_missing
)
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(
ClassificationExtension[T], ItemAssetsClassificationExtension(obj)
)
Expand Down
23 changes: 20 additions & 3 deletions pystac/extensions/datacube.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@
from __future__ import annotations

from abc import ABC
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union, cast
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union, cast

import pystac
from pystac.extensions import item_assets
from pystac.extensions.base import ExtensionManagementMixin, PropertiesExtension
from pystac.extensions.hooks import ExtensionHooks
from pystac.utils import StringEnum, get_required, map_opt

T = TypeVar("T", pystac.Collection, pystac.Item, pystac.Asset)
T = TypeVar(
"T", pystac.Collection, pystac.Item, pystac.Asset, item_assets.AssetDefinition
)

SCHEMA_URI = "https://stac-extensions.github.io/datacube/v2.0.0/schema.json"

Expand Down Expand Up @@ -469,6 +472,8 @@ class DatacubeExtension(
>>> dc_ext = DatacubeExtension.ext(item)
"""

name: Literal["cube"] = "cube"

def apply(
self,
dimensions: Dict[str, Dimension],
Expand Down Expand Up @@ -543,8 +548,11 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> DatacubeExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(DatacubeExtension[T], ItemDatacubeExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(DatacubeExtension[T], AssetDatacubeExtension(obj))
elif isinstance(obj, item_assets.AssetDefinition):
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(DatacubeExtension[T], ItemAssetsDatacubeExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))

Expand Down Expand Up @@ -614,6 +622,15 @@ def __repr__(self) -> str:
return "<AssetDatacubeExtension Item id={}>".format(self.asset_href)


class ItemAssetsDatacubeExtension(DatacubeExtension[item_assets.AssetDefinition]):
properties: Dict[str, Any]
asset_defn: item_assets.AssetDefinition

def __init__(self, item_asset: item_assets.AssetDefinition):
self.asset_defn = item_asset
self.properties = item_asset.properties


class DatacubeExtensionHooks(ExtensionHooks):
schema_uri: str = SCHEMA_URI
prev_extension_ids = {"datacube"}
Expand Down
31 changes: 28 additions & 3 deletions pystac/extensions/eo.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Generic,
Iterable,
List,
Literal,
Optional,
Tuple,
TypeVar,
Expand All @@ -17,7 +18,7 @@
)

import pystac
from pystac.extensions import projection, view
from pystac.extensions import item_assets, projection, view
from pystac.extensions.base import (
ExtensionManagementMixin,
PropertiesExtension,
Expand All @@ -28,7 +29,7 @@
from pystac.summaries import RangeSummary
from pystac.utils import get_required, map_opt

T = TypeVar("T", pystac.Item, pystac.Asset)
T = TypeVar("T", pystac.Item, pystac.Asset, item_assets.AssetDefinition)

SCHEMA_URI: str = "https://stac-extensions.github.io/eo/v1.1.0/schema.json"
SCHEMA_URIS: List[str] = [
Expand Down Expand Up @@ -309,6 +310,8 @@ class EOExtension(
>>> eo_ext = EOExtension.ext(item)
"""

name: Literal["eo"] = "eo"

def apply(
self,
bands: Optional[List[Band]] = None,
Expand Down Expand Up @@ -408,8 +411,11 @@ def ext(cls, obj: T, add_if_missing: bool = False) -> EOExtension[T]:
cls.ensure_has_extension(obj, add_if_missing)
return cast(EOExtension[T], ItemEOExtension(obj))
elif isinstance(obj, pystac.Asset):
cls.validate_owner_has_extension(obj, add_if_missing)
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(EOExtension[T], AssetEOExtension(obj))
elif isinstance(obj, item_assets.AssetDefinition):
cls.ensure_owner_has_extension(obj, add_if_missing)
return cast(EOExtension[T], ItemAssetsEOExtension(obj))
else:
raise pystac.ExtensionTypeError(cls._ext_error_message(obj))

Expand Down Expand Up @@ -534,6 +540,25 @@ def __repr__(self) -> str:
return "<AssetEOExtension Asset href={}>".format(self.asset_href)


class ItemAssetsEOExtension(EOExtension[item_assets.AssetDefinition]):
properties: Dict[str, Any]
asset_defn: item_assets.AssetDefinition

def _get_bands(self) -> Optional[List[Band]]:
if BANDS_PROP not in self.properties:
return None
return list(
map(
lambda band: Band(band),
cast(List[Dict[str, Any]], self.properties.get(BANDS_PROP)),
)
)

def __init__(self, item_asset: item_assets.AssetDefinition):
self.asset_defn = item_asset
self.properties = item_asset.properties


class SummariesEOExtension(SummariesExtension):
"""A concrete implementation of :class:`~SummariesExtension` that extends
the ``summaries`` field of a :class:`~pystac.Collection` to include properties
Expand Down
Loading