Feat/add prediction artifact and upload method #292

Draft · wants to merge 15 commits into main
25 changes: 4 additions & 21 deletions polaris/dataset/_dataset_v2.py
@@ -53,31 +53,14 @@ class DatasetV2(BaseDataset):
     def _validate_v2_dataset_model(self) -> Self:
         """Verifies some dependencies between properties"""
 
-        # Since the keys for subgroups are not ordered, we have no easy way to index these groups.
-        # Any subgroup should therefore have a special array that defines the index for that group.
-        for group in self.zarr_root.group_keys():
-            if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
-                raise InvalidDatasetError(f"Group {group} does not have an index array.")
-
-            index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
-            if len(index_arr) != len(self.zarr_root[group]) - 1:
-                raise InvalidDatasetError(
-                    f"Length of index array for group {group} does not match the size of the group."
-                )
-            if any(x not in self.zarr_root[group] for x in index_arr):
-                raise InvalidDatasetError(
-                    f"Keys of index array for group {group} does not match the group members."
-                )
-
         # Check the structure of the Zarr archive
-        # All arrays or groups in the root should have the same length.
+        if len(list(self.zarr_root.group_keys())) > 0:
+            raise InvalidDatasetError("Datasets can't have subgroups")
+        # Check all arrays at root have the same length
         lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
-        lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
         if len(lengths) > 1:
             raise InvalidDatasetError(
-                f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
+                f"All arrays at root should have the same length, found the following lengths: {lengths}"
             )
 
         return self

@property
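The rewritten check reduces the DatasetV2 contract to a flat Zarr layout: no subgroups, and every root-level array of equal length. A minimal sketch of that contract against an in-memory Zarr group; the commented-out lines are assumptions about what the new `InvalidDatasetError` paths would reject:

```python
import numpy as np
import zarr

root = zarr.group()  # in-memory group, standing in for a dataset's zarr_root

# Valid: only arrays at the root, all with the same length.
root.array("smiles", np.arange(100))
root.array("label", np.arange(100))

# Invalid under the new check: any subgroup is rejected outright.
# root.create_group("conformers")  # -> InvalidDatasetError("Datasets can't have subgroups")

# Invalid: a root array whose length differs from the others.
# root.array("extra", np.arange(50))  # -> InvalidDatasetError("All arrays at root should have the same length, ...")
```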
3 changes: 3 additions & 0 deletions polaris/evaluate/_predictions.py
@@ -1,4 +1,5 @@
 from collections import defaultdict
+import logging
 
 import numpy as np
 from pydantic import (
@@ -22,6 +23,8 @@
     SlugCompatibleStringType,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class BenchmarkPredictions(BaseModel):
     """
58 changes: 58 additions & 0 deletions polaris/hub/client.py
@@ -22,9 +22,11 @@
 from polaris.model import Model
 from polaris.dataset import DatasetV1, DatasetV2
 from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions
+from polaris.prediction._predictions_v2 import Predictions
 from polaris.hub.external_client import ExternalAuthClient
 from polaris.hub.oauth import CachedTokenAuth
 from polaris.hub.settings import PolarisHubSettings
+from polaris.hub.storage import StorageSession
 from polaris.utils.context import track_progress
 from polaris.utils.errors import (
     PolarisCreateArtifactError,
@@ -714,3 +716,59 @@ def upload_model(
         progress.log(
             f"[green]Your model has been successfully uploaded to the Hub. View it here: {model_url}"
         )
+
+    def upload_predictions(
+        self,
+        prediction: Predictions,
+        timeout: TimeoutTypes = (10, 200),
+        owner: HubOwner | str | None = None,
+        if_exists: ZarrConflictResolution = "replace",
+    ):
+        """
+        Upload a Predictions artifact (with Zarr archive) to the Polaris Hub.
+        """
+        # Set owner
+        prediction.owner = HubOwner.normalize(owner or prediction.owner)
+        prediction_json = prediction.model_dump(by_alias=True, exclude_none=True)
+
+        # Only include modelArtifactId if there's actually a model
+        if prediction.model_artifact_id:
+            prediction_json["modelArtifactId"] = prediction.model_artifact_id
+        prediction_json["benchmarkArtifactId"] = prediction.benchmark_artifact_id
+
+        # Step 1: Upload metadata to Hub
+        with track_progress(description="Uploading prediction metadata", total=1) as (progress, task):
+            response = self._base_request_to_hub(
+                url=f"/v2/prediction/{prediction.artifact_id}",
+                method="PUT",
+                withhold_token=False,
+                json={
+                    "zarrManifestFileContent": {
+                        "md5Sum": prediction.zarr_manifest_md5sum,
+                    },
+                    **prediction_json,
+                },
+                timeout=timeout,
+            )
+            inserted = response.json()
+            prediction.slug = inserted["slug"]
+            prediction_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location"))
+            progress.log(f"[green]Prediction metadata uploaded. View it here: {prediction_url}")
+
+        # Step 2: Upload manifest file
+        with StorageSession(self, "write", prediction.urn) as storage:
+            with track_progress(description="Copying manifest file", total=1):
+                with open(prediction.zarr_manifest_path, "rb") as manifest_file:
+                    storage.set_file("manifest", manifest_file.read())
+
+            # Step 3: Upload Zarr archive
+            with track_progress(description="Copying Zarr archive", total=1) as (progress_zarr, task_zarr):
+                progress_zarr.log("[yellow]This may take a while.")
+                destination = storage.store("root")
+                destination.copy_from_source(
+                    prediction.zarr_root.store, if_exists=if_exists, log=progress_zarr.log
+                )
+
+        progress.log(
+            f"[green]Your prediction has been successfully uploaded to the Hub. View it here: {prediction_url}"
+        )
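A hedged usage sketch of the new upload path: `PolarisHubClient` is the surrounding client class, but the `Predictions` construction is a placeholder, since that class is defined outside this diff:

```python
from polaris.hub.client import PolarisHubClient
from polaris.prediction import Predictions

with PolarisHubClient() as client:
    # Hypothetical: a Predictions artifact built elsewhere, e.g. from model outputs.
    predictions = Predictions(...)  # construction not shown in this PR
    client.upload_predictions(
        predictions,
        owner="my-org",       # normalized via HubOwner.normalize(...)
        if_exists="replace",  # ZarrConflictResolution for the Zarr copy
    )
```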
8 changes: 7 additions & 1 deletion polaris/hub/oauth.py
@@ -87,6 +87,7 @@ class DatasetV1Paths(ArtifactPaths):
 
 
 class DatasetV2Paths(ArtifactPaths):
+    type: Literal["dataset-v2"] = "dataset-v2"
     root: AnyUrlString = Field(json_schema_extra={"store": True})
     manifest: AnyUrlString = Field(json_schema_extra={"file": True})
 
@@ -97,11 +98,16 @@ class BenchmarkV2Paths(ArtifactPaths):
     test_2: int = 0
 
 
+class PredictionPaths(ArtifactPaths):
+    type: Literal["prediction"] = "prediction"
+    root: AnyUrlString = Field(json_schema_extra={"store": True})
+    manifest: AnyUrlString = Field(json_schema_extra={"file": True})
+
 class StorageTokenData(BaseModel):
     key: str
     secret: str
     endpoint: HttpUrlString
-    paths: DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths = Field(union_mode="smart")
+    paths: DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths | PredictionPaths = Field(union_mode="smart")
 
 
 class HubOAuth2Token(BaseModel):
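With `type` literals on the variants, Pydantic's smart union can resolve `paths` by discriminating on that field. A small sketch of the intent, with made-up payload values, assuming the base `ArtifactPaths` adds no extra required fields and that the other variants fail to validate against this payload:

```python
from polaris.hub.oauth import StorageTokenData

token_data = StorageTokenData.model_validate(
    {
        "key": "access-key",
        "secret": "access-secret",
        "endpoint": "https://storage.example.com",
        "paths": {
            "type": "prediction",  # the new Literal steers smart-union resolution
            "root": "s3://bucket/prediction/root",
            "manifest": "s3://bucket/prediction/manifest",
        },
    }
)
assert type(token_data.paths).__name__ == "PredictionPaths"
```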
25 changes: 23 additions & 2 deletions polaris/hub/storage.py
@@ -19,7 +19,7 @@
 from zarr.storage import Store
 from zarr.util import buffer_size
 
-from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token
+from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token, PredictionPaths
 from polaris.utils.context import track_progress
 from polaris.utils.errors import PolarisHubError
 from polaris.utils.types import ArtifactUrn, ZarrConflictResolution
@@ -549,7 +549,7 @@ def ensure_active_token(self, token: OAuth2Token | None = None) -> bool:
         return True
 
     @property
-    def paths(self) -> DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths:
+    def paths(self) -> DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths | PredictionPaths:
         return self.token.extra_data.paths
 
     def _relative_path(self, path: str) -> PurePath:
@@ -583,3 +583,24 @@ def set_file(self, path: str, value: bytes | bytearray):
         )
 
         store[relative_path.name] = value
+
+    def store(self, path: str) -> S3Store:
+        """
+        Create an S3Store for the specified path.
+        """
+        if path not in self.paths.stores:
+            raise NotImplementedError(
+                f"{type(self.paths).__name__} only supports these stores: {self.paths.stores}."
+            )
+
+        relative_path = self._relative_path(getattr(self.paths, path))
+
+        storage_data = self.token.extra_data
+        return S3Store(
+            path=relative_path,
+            access_key=storage_data.key,
+            secret_key=storage_data.secret,
+            token=f"jwt/{self.token.access_token}",
+            endpoint_url=storage_data.endpoint,
+        )
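The guard in `store()` relies on a `stores` attribute on the paths models that is not shown in this diff. One plausible reading, assuming it is derived from the `json_schema_extra={"store": True}` annotations via Pydantic v2 field metadata (illustrative names only, not the real implementation):

```python
from pydantic import BaseModel, Field


class ArtifactPathsSketch(BaseModel):
    """Illustrative stand-in for ArtifactPaths; not the real class."""

    root: str = Field(json_schema_extra={"store": True})
    manifest: str = Field(json_schema_extra={"file": True})

    @property
    def stores(self) -> list[str]:
        # Collect the field names flagged as stores in their schema metadata.
        return [
            name
            for name, field in type(self).model_fields.items()
            if (field.json_schema_extra or {}).get("store")
        ]


assert ArtifactPathsSketch(root="s3://r", manifest="s3://m").stores == ["root"]
```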
3 changes: 3 additions & 0 deletions polaris/prediction/__init__.py
@@ -0,0 +1,3 @@
+from ._predictions_v2 import Predictions
+
+__all__ = ["Predictions"]