Feat/add prediction artifact and upload method #292

Merged · 36 commits · Jun 26, 2025
Commits (36)
b0b118e
added predictions artifact and upload method
j279li May 29, 2025
3f76ffc
fix import
j279li Jun 3, 2025
637361d
fix import again
j279li Jun 3, 2025
36a7993
fix typo
j279li Jun 3, 2025
8c7a4e9
uploaded correct prediction file
j279li Jun 3, 2025
92f39b7
small fixes to client and init.py for evaluate
j279li Jun 3, 2025
9271f69
remove duplicate keyword
j279li Jun 3, 2025
db77f00
remove metadata consolidation
j279li Jun 4, 2025
7d0abc0
dataset subgroup change
j279li Jun 5, 2025
6d6dfcc
refactor of prediction files + update
j279li Jun 5, 2025
10d94a9
update tests
j279li Jun 5, 2025
7692997
added cached directory and updated payload json
j279li Jun 6, 2025
d78cc28
ruff format
j279li Jun 6, 2025
5b6349f
small changes
j279li Jun 7, 2025
be4dee7
zarr upload + store
j279li Jun 9, 2025
8a73ad2
ruff
j279li Jun 11, 2025
fee2898
small fix
j279li Jun 11, 2025
048df0a
small fixes
j279li Jun 17, 2025
ace3587
move zarr to utils + prediction artifact method changes
j279li Jun 19, 2025
7773baa
small fixes
j279li Jun 19, 2025
fc6c343
formatting
j279li Jun 19, 2025
169c751
ruff
j279li Jun 19, 2025
c7eaaf4
moved predictions upload method to benchmarkv2
j279li Jun 19, 2025
05d98a0
circular import + test fix
j279li Jun 19, 2025
090dc51
update error message
j279li Jun 19, 2025
c5d320c
type fix
j279li Jun 19, 2025
3f2e54d
fix to zarr creation
j279li Jun 20, 2025
fbb5add
formatting
j279li Jun 20, 2025
fa45532
added some tests for predictions
j279li Jun 20, 2025
4203bf5
updated predictionsv2
j279li Jun 23, 2025
77df89c
update validator + to_zarr
j279li Jun 23, 2025
7548ab7
fixed circular imports + added some more focused tests
j279li Jun 25, 2025
7470cee
format
j279li Jun 25, 2025
0e316e7
updated docstring
j279li Jun 25, 2025
3f185a2
update to prediction tests
j279li Jun 25, 2025
5c0cbf4
updated tests and refactored predictions
j279li Jun 25, 2025
5 changes: 0 additions & 5 deletions docs/api/dataset.md
@@ -16,8 +16,3 @@

---

::: polaris.dataset.zarr
options:
filters: ["!^_"]

---
56 changes: 55 additions & 1 deletion polaris/benchmark/_benchmark_v2.py
@@ -5,13 +5,19 @@

from polaris.benchmark import BenchmarkSpecification
from polaris.evaluate.utils import evaluate_benchmark
from polaris.utils.types import IncomingPredictionsType
from polaris.utils.types import (
IncomingPredictionsType,
HubOwner,
SlugCompatibleStringType,
HubUser,
)

from polaris.evaluate import BenchmarkResultsV2
from polaris.benchmark._split_v2 import SplitSpecificationV2Mixin
from polaris.dataset import DatasetV2, Subset
from polaris.utils.errors import InvalidBenchmarkError
from polaris.utils.types import ColumnName
from polaris.model import Model


class BenchmarkV2Specification(SplitSpecificationV2Mixin, BenchmarkSpecification[BenchmarkResultsV2]):
@@ -154,3 +160,51 @@ def evaluate(
)

return BenchmarkResultsV2(results=scores, benchmark_artifact_id=self.artifact_id)

def submit_predictions(
self,
predictions: IncomingPredictionsType,
prediction_name: SlugCompatibleStringType,
prediction_owner: str,
contributors: list[HubUser] | None = None,
model: Model | None = None,
description: str = "",
tags: list[str] | None = None,
user_attributes: dict[str, str] | None = None,
) -> None:
"""
Convenience wrapper around the
[`PolarisHubClient.submit_benchmark_predictions`][polaris.hub.client.PolarisHubClient.submit_benchmark_predictions] method.
It automatically creates the standardized Predictions object expected by the Hub.

Args:
predictions: The predictions for each test set defined in the benchmark.
prediction_name: The name of the prediction.
prediction_owner: The slug of the user/organization which owns the prediction.
contributors: The users credited with generating these predictions.
model: (Optional) The Model artifact used to generate these predictions.
description: An optional and short description of the predictions.
tags: An optional list of tags to categorize the prediction by.
user_attributes: An optional dict with additional, textual user attributes.
"""
from polaris.hub.client import PolarisHubClient
from polaris.prediction import BenchmarkPredictionsV2

standardized_predictions = BenchmarkPredictionsV2(
name=prediction_name,
owner=HubOwner(slug=prediction_owner),
dataset_zarr_root=self.dataset.zarr_root,
benchmark_artifact_id=self.artifact_id,
predictions=predictions,
target_labels=list(self.target_cols),
test_set_labels=self.test_set_labels,
test_set_sizes=self.test_set_sizes,
contributors=contributors or [],
model=model,
description=description,
tags=tags or [],
user_attributes=user_attributes or {},
)

with PolarisHubClient() as client:
client.submit_benchmark_predictions(prediction=standardized_predictions, owner=prediction_owner)
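
For context, a minimal sketch of how a user might call this wrapper once the PR lands. The benchmark slug, test set label, target column, and values below are placeholders; the exact nesting of `predictions` must match the benchmark's test sets and target columns:

```python
import polaris as po

# Placeholder slug; substitute a real BenchmarkV2 artifact.
benchmark = po.load_benchmark("my-org/my-benchmark")

# Keyed by test set label, then target column -- both must match the
# benchmark's test_set_labels and target_cols.
predictions = {
    "test": {"LOG_SOLUBILITY": [0.12, 0.85, 0.42]},
}

benchmark.submit_predictions(
    predictions=predictions,
    prediction_name="baseline-predictions",
    prediction_owner="my-org",
    description="Predictions from a simple baseline model.",
    tags=["baseline"],
)
```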
2 changes: 1 addition & 1 deletion polaris/dataset/__init__.py
@@ -4,7 +4,7 @@
from polaris.dataset._dataset_v2 import DatasetV2
from polaris.dataset._factory import DatasetFactory, create_dataset_from_file, create_dataset_from_files
from polaris.dataset._subset import Subset
from polaris.dataset.zarr import codecs
from polaris.utils.zarr import codecs

__all__ = [
"create_dataset_from_file",
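
For downstream users, the module move means a one-line import change; a sketch using the names touched in this PR:

```python
# Before this PR:
# from polaris.dataset.zarr import codecs, MemoryMappedDirectoryStore

# After this PR:
from polaris.utils.zarr import codecs, MemoryMappedDirectoryStore
```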
4 changes: 2 additions & 2 deletions polaris/dataset/_base.py
@@ -22,8 +22,8 @@
from polaris._artifact import BaseArtifactModel
from polaris.dataset._adapters import Adapter
from polaris.dataset._column import ColumnAnnotation
from polaris.dataset.zarr import MemoryMappedDirectoryStore
from polaris.dataset.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
from polaris.utils.zarr import MemoryMappedDirectoryStore
from polaris.utils.zarr._utils import check_zarr_codecs, load_zarr_group_to_memory
from polaris.utils.constants import DEFAULT_CACHE_DIR
from polaris.utils.context import track_progress
from polaris.utils.dict2html import dict2html
2 changes: 1 addition & 1 deletion polaris/dataset/_dataset.py
@@ -14,7 +14,7 @@

from polaris.dataset._adapters import Adapter
from polaris.dataset._base import BaseDataset
from polaris.dataset.zarr import ZarrFileChecksum, compute_zarr_checksum
from polaris.utils.zarr import ZarrFileChecksum, compute_zarr_checksum
from polaris.mixins._checksum import ChecksumMixin
from polaris.utils.errors import InvalidDatasetError
from polaris.utils.types import (
30 changes: 8 additions & 22 deletions polaris/dataset/_dataset_v2.py
@@ -13,7 +13,7 @@

from polaris.dataset._adapters import Adapter
from polaris.dataset._base import BaseDataset
from polaris.dataset.zarr._manifest import calculate_file_md5, generate_zarr_manifest
from polaris.utils.zarr._manifest import calculate_file_md5, generate_zarr_manifest
from polaris.utils.errors import InvalidDatasetError
from polaris.utils.types import ChecksumStrategy, HubOwner, ZarrConflictResolution

@@ -53,31 +53,17 @@ class DatasetV2(BaseDataset):
def _validate_v2_dataset_model(self) -> Self:
"""Verifies some dependencies between properties"""

# Since the keys for subgroups are not ordered, we have no easy way to index these groups.
# Any subgroup should therefore have a special array that defines the index for that group.
for group in self.zarr_root.group_keys():
if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
raise InvalidDatasetError(f"Group {group} does not have an index array.")

index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
if len(index_arr) != len(self.zarr_root[group]) - 1:
raise InvalidDatasetError(
f"Length of index array for group {group} does not match the size of the group."
)
if any(x not in self.zarr_root[group] for x in index_arr):
raise InvalidDatasetError(
f"Keys of index array for group {group} does not match the group members."
)

# Check the structure of the Zarr archive
# All arrays or groups in the root should have the same length.
group_keys = list(self.zarr_root.group_keys())
if len(group_keys) > 0:
raise InvalidDatasetError(
f"The Zarr archive of a Dataset can't have any subgroups. Found {group_keys}."
)
# Check all arrays at root have the same length
lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
if len(lengths) > 1:
raise InvalidDatasetError(
f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
f"All arrays at root should have the same length, found the following lengths: {lengths}"
)

return self

@property
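
The rewritten validator is stricter than before: the Zarr root may contain only arrays, and all of them must have the same length. A minimal sketch of a layout that passes, with the two new failure modes shown as comments (assuming plain `zarr` and `numpy`):

```python
import numpy as np
import zarr

# Arrays only at the root, all with the same length: passes the new check.
root = zarr.open_group("example.zarr", mode="w")
root.array("inputs", np.random.random(100))
root.array("targets", np.random.random(100))

# Under the new validator, either of these would raise InvalidDatasetError:
# root.create_group("subgroup")                  # subgroups are no longer allowed
# root.array("too_short", np.random.random(50))  # mismatched length
```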
4 changes: 2 additions & 2 deletions polaris/dataset/converters/_pdb.py
@@ -11,7 +11,7 @@
from polaris.dataset import ColumnAnnotation, Modality
from polaris.dataset._adapters import Adapter
from polaris.dataset.converters._base import Converter, FactoryProduct
from polaris.dataset.zarr._utils import load_zarr_group_to_memory
from polaris.utils.zarr._utils import load_zarr_group_to_memory

if TYPE_CHECKING:
from polaris.dataset import DatasetFactory
@@ -77,7 +77,7 @@ def zarr_to_pdb(atom_dict: zarr.Group):
return struc.array(atom_array)


@deprecated("Please use the custom codecs in `polaris.dataset.zarr.codecs` instead.")
@deprecated("Please use the custom codecs in `polaris.utils.zarr.codecs` instead.")
class PDBConverter(Converter):
"""
Converts PDB files into a Polaris dataset based on fastpdb.
2 changes: 1 addition & 1 deletion polaris/dataset/converters/_sdf.py
@@ -14,7 +14,7 @@
from polaris.dataset import DatasetFactory


@deprecated("Please use the custom codecs in `polaris.dataset.zarr.codecs` instead.")
@deprecated("Please use the custom codecs in `polaris.utils.zarr.codecs` instead.")
class SDFConverter(Converter):
"""
Converts a SDF file into a Polaris dataset.
2 changes: 1 addition & 1 deletion polaris/dataset/converters/_zarr.py
@@ -13,7 +13,7 @@
from polaris.dataset import DatasetFactory


@deprecated("Please use the custom codecs in `polaris.dataset.zarr.codecs` instead.")
@deprecated("Please use the custom codecs in `polaris.utils.zarr.codecs` instead.")
class ZarrConverter(Converter):
"""Parse a [.zarr](https://zarr.readthedocs.io/en/stable/index.html) archive into a Polaris `Dataset`.

65 changes: 65 additions & 0 deletions polaris/hub/client.py
@@ -22,9 +22,11 @@
from polaris.model import Model
from polaris.dataset import DatasetV1, DatasetV2
from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions
from polaris.prediction._predictions_v2 import BenchmarkPredictionsV2
from polaris.hub.external_client import ExternalAuthClient
from polaris.hub.oauth import CachedTokenAuth
from polaris.hub.settings import PolarisHubSettings
from polaris.hub.storage import StorageSession
from polaris.utils.context import track_progress
from polaris.utils.errors import (
PolarisCreateArtifactError,
@@ -714,3 +716,66 @@ def upload_model(
progress.log(
f"[green]Your model has been successfully uploaded to the Hub. View it here: {model_url}"
)

def submit_benchmark_predictions(
self,
prediction: BenchmarkPredictionsV2,
timeout: TimeoutTypes = (10, 200),
owner: HubOwner | str | None = None,
if_exists: ZarrConflictResolution = "replace",
):
"""Submit predictions for a benchmark to the Polaris Hub.

This method handles uploading predictions for a benchmark to the Hub. The predictions must be
provided as a BenchmarkPredictionsV2 object, which ensures proper validation and formatting.

Info: Owner
The owner of the predictions will automatically be inferred from the prediction object.
You can override this by passing an explicit owner to this method.

Args:
prediction: A BenchmarkPredictionsV2 instance containing the predictions and metadata.
owner: Which Hub user or organization owns the artifact.
timeout: Request timeout values. Users can increase these when uploading large datasets.
if_exists: Action for handling existing files in the Zarr archive. Options are 'raise' to throw
an error, 'replace' to overwrite, or 'skip' to proceed without altering the existing files.
"""
# Ensure Zarr archive is created
prediction.to_zarr()

# Set owner
prediction.owner = HubOwner.normalize(owner or prediction.owner)
prediction_json = prediction.model_dump(by_alias=True, exclude_none=True)

# Step 1: Upload metadata to Hub
with track_progress(description="Uploading prediction metadata", total=1) as (progress, task):
response = self._base_request_to_hub(
url=f"/v2/prediction/{prediction.artifact_id}",
method="PUT",
withhold_token=False,
json={
"zarrManifestFileContent": {
"md5Sum": prediction.zarr_manifest_md5sum,
},
**prediction_json,
},
timeout=timeout,
)
inserted = response.json()
prediction.slug = inserted["slug"]

# Step 2: Upload manifest file
with StorageSession(self, "write", prediction.urn) as storage:
with track_progress(description="Copying manifest file", total=1):
with open(prediction.zarr_manifest_path, "rb") as manifest_file:
storage.set_file("manifest", manifest_file.read())

# Step 3: Upload Zarr archive
with track_progress(description="Copying Zarr archive", total=1) as (progress_zarr, task_zarr):
progress_zarr.log("[yellow]This may take a while.")
destination = storage.store("root")
destination.copy_from_source(
prediction.zarr_root.store, if_exists=if_exists, log=progress_zarr.log
)

progress.log("[green]Your prediction has been successfully uploaded to the Hub.")
11 changes: 10 additions & 1 deletion polaris/hub/oauth.py
@@ -87,6 +87,8 @@ class DatasetV1Paths(ArtifactPaths):


class DatasetV2Paths(ArtifactPaths):
# Discriminator field used to identify this as a dataset-v2 type when deserializing paths
type: Literal["dataset-v2"] = "dataset-v2"
root: AnyUrlString = Field(json_schema_extra={"store": True})
manifest: AnyUrlString = Field(json_schema_extra={"file": True})

Expand All @@ -97,11 +99,18 @@ class BenchmarkV2Paths(ArtifactPaths):
test_2: int = 0


class PredictionPaths(ArtifactPaths):
# Discriminator field used to identify this as a prediction type when deserializing paths
type: Literal["prediction"] = "prediction"
root: AnyUrlString = Field(json_schema_extra={"store": True})
manifest: AnyUrlString = Field(json_schema_extra={"file": True})


class StorageTokenData(BaseModel):
key: str
secret: str
endpoint: HttpUrlString
paths: DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths = Field(union_mode="smart")
paths: DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths | PredictionPaths = Field(union_mode="smart")


class HubOAuth2Token(BaseModel):
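
A sketch of how the new variant slots into the storage token payload; the keys and URLs below are placeholders, and the `type` discriminator defaults from the `Literal` field:

```python
from polaris.hub.oauth import PredictionPaths, StorageTokenData

# Placeholder values; only the shape matters for this sketch.
paths = PredictionPaths(
    root="https://storage.example.com/prediction/root",
    manifest="https://storage.example.com/prediction/manifest",
)

token_data = StorageTokenData(
    key="access-key",
    secret="secret-key",
    endpoint="https://storage.example.com",
    paths=paths,
)
assert token_data.paths.type == "prediction"
```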
30 changes: 28 additions & 2 deletions polaris/hub/storage.py
@@ -19,7 +19,13 @@
from zarr.storage import Store
from zarr.util import buffer_size

from polaris.hub.oauth import BenchmarkV2Paths, DatasetV1Paths, DatasetV2Paths, HubStorageOAuth2Token
from polaris.hub.oauth import (
BenchmarkV2Paths,
DatasetV1Paths,
DatasetV2Paths,
HubStorageOAuth2Token,
PredictionPaths,
)
from polaris.utils.context import track_progress
from polaris.utils.errors import PolarisHubError
from polaris.utils.types import ArtifactUrn, ZarrConflictResolution
@@ -549,7 +555,7 @@ def ensure_active_token(self, token: OAuth2Token | None = None) -> bool:
return True

@property
def paths(self) -> DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths:
def paths(self) -> DatasetV1Paths | DatasetV2Paths | BenchmarkV2Paths | PredictionPaths:
return self.token.extra_data.paths

def _relative_path(self, path: str) -> PurePath:
Expand Down Expand Up @@ -583,3 +589,23 @@ def set_file(self, path: str, value: bytes | bytearray):
)

store[relative_path.name] = value

def store(self, path: str) -> S3Store:
"""
Create an S3Store for the specified path.
"""
if path not in self.paths.stores:
raise NotImplementedError(
f"{type(self.paths).__name__} only supports these stores: {self.paths.stores}."
)

relative_path = self._relative_path(getattr(self.paths, path))

storage_data = self.token.extra_data
return S3Store(
path=relative_path,
access_key=storage_data.key,
secret_key=storage_data.secret,
token=f"jwt/{self.token.access_token}",
endpoint_url=storage_data.endpoint,
)
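
The new `store` helper complements the existing `set_file`: it resolves one of the token's named store paths into a writable `S3Store`. A sketch of how the prediction upload uses it, assuming an authenticated `client` and a built `prediction` artifact:

```python
from polaris.hub.storage import StorageSession

# `client` is a PolarisHubClient and `prediction` a BenchmarkPredictionsV2
# (see submit_benchmark_predictions above); both are assumed to exist here.
with StorageSession(client, "write", prediction.urn) as storage:
    destination = storage.store("root")  # S3Store for the token's "root" path
    destination.copy_from_source(
        prediction.zarr_root.store, if_exists="replace", log=print
    )
```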
3 changes: 3 additions & 0 deletions polaris/prediction/__init__.py
@@ -0,0 +1,3 @@
from ._predictions_v2 import BenchmarkPredictionsV2

__all__ = ["BenchmarkPredictionsV2"]