Skip to content

Feat/add prediction artifact and upload method #292

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 36 commits into from
Jun 26, 2025
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
b0b118e
added predictions artifact and upload method
j279li May 29, 2025
3f76ffc
fix import
j279li Jun 3, 2025
637361d
fix import again
j279li Jun 3, 2025
36a7993
fix typo
j279li Jun 3, 2025
8c7a4e9
uploaded correct prediction file
j279li Jun 3, 2025
92f39b7
small fixes to client and init.py for evaluate
j279li Jun 3, 2025
9271f69
remove duplicate keyword
j279li Jun 3, 2025
db77f00
remove metadata consolidation
j279li Jun 4, 2025
7d0abc0
dataset subgroup change
j279li Jun 5, 2025
6d6dfcc
refactor of prediction files + update
j279li Jun 5, 2025
10d94a9
update tests
j279li Jun 5, 2025
7692997
added cached directory and updated payload json
j279li Jun 6, 2025
d78cc28
ruff format
j279li Jun 6, 2025
5b6349f
small changes
j279li Jun 7, 2025
be4dee7
zarr upload + store
j279li Jun 9, 2025
8a73ad2
ruff
j279li Jun 11, 2025
fee2898
small fix
j279li Jun 11, 2025
048df0a
small fixes
j279li Jun 17, 2025
ace3587
move zarr to utils + prediction artifact method changes
j279li Jun 19, 2025
7773baa
small fixes
j279li Jun 19, 2025
fc6c343
formatting
j279li Jun 19, 2025
169c751
ruff
j279li Jun 19, 2025
c7eaaf4
moved predictions upload method to benchmarkv2
j279li Jun 19, 2025
05d98a0
circular import + test fix
j279li Jun 19, 2025
090dc51
update error message
j279li Jun 19, 2025
c5d320c
type fix
j279li Jun 19, 2025
3f2e54d
fix to zarr creation
j279li Jun 20, 2025
fbb5add
formatting
j279li Jun 20, 2025
fa45532
added some tests for predictions
j279li Jun 20, 2025
4203bf5
updated predictionsv2
j279li Jun 23, 2025
77df89c
update validator + to_zarr
j279li Jun 23, 2025
7548ab7
fixed circular imports + added some more focused tests
j279li Jun 25, 2025
7470cee
format
j279li Jun 25, 2025
0e316e7
updated docstring
j279li Jun 25, 2025
3f185a2
update to prediction tests
j279li Jun 25, 2025
5c0cbf4
updated tests and refactored predictions
j279li Jun 25, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions polaris/evaluate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,5 @@
"evaluate_benchmark",
"CompetitionPredictions",
"BenchmarkPredictions",
"Predictions",
]
177 changes: 177 additions & 0 deletions polaris/evaluate/_predictions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from collections import defaultdict
import logging
import re
from typing import Any

import numpy as np
import zarr
from pydantic import (
BaseModel,
ConfigDict,
Expand All @@ -12,7 +16,9 @@
)
from typing_extensions import Self

from polaris._artifact import BaseArtifactModel
from polaris.evaluate import ResultsMetadataV1
from polaris.model import Model
from polaris.utils.misc import convert_lists_to_arrays
from polaris.utils.types import (
HttpUrlString,
Expand All @@ -21,6 +27,11 @@
PredictionsType,
SlugCompatibleStringType,
)
from polaris.dataset.zarr._manifest import generate_zarr_manifest, calculate_file_md5

logger = logging.getLogger(__name__)

_INDEX_ARRAY_KEY = "__index__"


class BenchmarkPredictions(BaseModel):
Expand Down Expand Up @@ -275,3 +286,169 @@ def __repr__(self):

def __str__(self):
return self.__repr__()


class Predictions(BaseArtifactModel):
    """
    Prediction artifact for uploading predictions to a Benchmark V2.

    Stores predictions as a Zarr archive, with a manifest and metadata for
    reproducibility and integrity.

    Attributes:
        benchmark_artifact_id: The artifact ID of the associated benchmark.
        model: (Optional) The Model artifact used to generate these predictions.
        zarr_root_path: Path to the Zarr archive containing the predictions.
        column_types: (Optional) A dictionary mapping column names to expected types.
    """

    _artifact_type = "prediction"

    benchmark_artifact_id: str
    model: Model | None = None
    zarr_root_path: str
    # Lazily computed by the zarr_manifest_path / zarr_manifest_md5sum properties.
    _zarr_manifest_path: str | None = None
    _zarr_manifest_md5sum: str | None = None
    column_types: dict[str, type] = {}

    @model_validator(mode="after")
    def _validate_predictions_zarr_structure(self):
        """
        Ensures the Zarr archive for predictions is well-formed:
        - Each group at the root has an __index__ array, its length matches the
          group size (excluding the index itself), and all keys listed in the
          index exist in the group.
        - All arrays/groups at the root describe the same number of rows.

        Raises:
            ValueError: If any of the structural invariants above is violated.
        """
        # NOTE(review): `zarr_root` is not defined in this class — presumably
        # provided by BaseArtifactModel or a mixin; confirm it opens zarr_root_path.
        for group in self.zarr_root.group_keys():
            if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
                raise ValueError(f"Group {group} does not have an index array (__index__).")
            index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
            # The group contains the index array itself, hence the `- 1`.
            if len(index_arr) != len(self.zarr_root[group]) - 1:
                raise ValueError(
                    f"Length of index array for group {group} does not match the size of the group (excluding index)."
                )
            if any(x not in self.zarr_root[group] for x in index_arr):
                raise ValueError(f"Keys of index array for group {group} do not match the group members.")

        # All root-level columns must agree on the row count. For groups, the
        # row count is the length of the __index__ array.
        lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
        lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
        if len(lengths) > 1:
            raise ValueError(
                f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
            )
        return self

    @property
    def columns(self):
        """All column names (root-level arrays and groups) in the archive."""
        return list(self.zarr_root.keys())

    @property
    def dtypes(self):
        """Map each column to its dtype; group columns are reported as `object`."""
        dtypes = {}
        for arr in self.zarr_root.array_keys():
            dtypes[arr] = self.zarr_root[arr].dtype
        for group in self.zarr_root.group_keys():
            dtypes[group] = object
        return dtypes

    @property
    def n_rows(self):
        """
        Number of prediction rows, derived from the first root column.

        For a group column the row count is the length of its __index__ array:
        the group also contains the index array itself, so `len(group)` would
        over-count by one (see the validator's `len(group) - 1` invariant).
        """
        cols = self.columns
        if not cols:
            raise ValueError("No columns found in predictions archive.")
        example = self.zarr_root[cols[0]]
        if isinstance(example, zarr.Group):
            return len(example[_INDEX_ARRAY_KEY])
        return len(example)

    @property
    def rows(self):
        """Iterable of row indices, 0 .. n_rows - 1."""
        return range(self.n_rows)

    @property
    def zarr_manifest_path(self):
        """Path to the Zarr manifest file, generating it on first access."""
        if self._zarr_manifest_path is None:
            # NOTE(review): `_cache_dir` is not declared on this class — presumably
            # set by BaseArtifactModel; verify, otherwise the manifest lands in `None`.
            zarr_manifest_path = generate_zarr_manifest(
                self.zarr_root_path, getattr(self, "_cache_dir", None)
            )
            self._zarr_manifest_path = zarr_manifest_path
        return self._zarr_manifest_path

    @property
    def zarr_manifest_md5sum(self):
        """MD5 checksum of the manifest file, computed lazily and cached."""
        if not self.has_zarr_manifest_md5sum:
            logger.info("Computing the checksum. This can be slow for large predictions archives.")
            self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path)
        return self._zarr_manifest_md5sum

    @zarr_manifest_md5sum.setter
    def zarr_manifest_md5sum(self, value: str):
        """Set the checksum after validating it is a 32-char lowercase hexdigest."""
        # fullmatch already anchors the pattern, so no ^/$ needed.
        if not re.fullmatch(r"[a-f0-9]{32}", value):
            raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.")
        self._zarr_manifest_md5sum = value

    @property
    def has_zarr_manifest_md5sum(self):
        """Whether the manifest checksum has already been computed or set."""
        return self._zarr_manifest_md5sum is not None

    def set_prediction_value(self, column: str, row: int, value: Any):
        """
        Set a prediction value for a given column and row in the Zarr archive.

        Validates the column exists and, if `column_types` is set for it,
        checks the value type.

        Raises:
            KeyError: If the column is not present in the archive.
            TypeError: If the value does not match the declared column type.
        """
        if column not in self.columns:
            raise KeyError(f"Column '{column}' not defined in predictions.")
        expected_type = self.column_types.get(column)
        if expected_type and not isinstance(value, expected_type):
            raise TypeError(f"Value for column '{column}' must be of type {expected_type}, got {type(value)}")
        self.zarr_root[column][row] = value

    def get_prediction_value(self, column: str, row: int):
        """
        Get a prediction value for a given column and row from the Zarr archive.

        Raises:
            KeyError: If the column is not present in the archive.
        """
        if column not in self.columns:
            raise KeyError(f"Column '{column}' not defined in predictions.")
        return self.zarr_root[column][row]

    @classmethod
    def create_zarr_from_dict(
        cls, path: str, data: dict[str, Any], column_types: dict[str, type] | None = None, **zarr_kwargs
    ):
        """
        Create a Zarr archive at the given path from a dict mapping column
        names to arrays or lists. Uses dtype=object for columns with complex
        declared types. Returns the path to the created Zarr archive.
        """
        store = zarr.DirectoryStore(path)
        root = zarr.group(store=store)
        column_types = column_types or {}
        for col, arr in data.items():
            if column_types.get(col) not in (float, int, str, bool, bytes, None):
                dtype = object
            else:
                # Infer from the data: prefer an explicit numpy-style dtype,
                # else the type of the first element; empty columns fall back
                # to object instead of raising IndexError.
                dtype = getattr(arr, "dtype", None)
                if dtype is None:
                    dtype = type(arr[0]) if len(arr) else object
            root.create_dataset(col, data=arr, dtype=dtype, **zarr_kwargs)
        zarr.consolidate_metadata(store)
        return path

    def upload_to_hub(
        self,
        owner: HubOwner | str | None = None,
        parent_artifact_id: str | None = None,
        model: Model | None = None,
    ):
        """
        Uploads the predictions artifact to the Polaris Hub.
        Optionally sets or overrides the model before upload.
        """
        # Imported locally to avoid a circular import with the hub client.
        from polaris.hub.client import PolarisHubClient

        if model is not None:
            self.model = model
        with PolarisHubClient() as client:
            client.upload_prediction(self, owner=owner, parent_artifact_id=parent_artifact_id)

    def __repr__(self):
        return self.model_dump_json(by_alias=True, indent=2)

    def __str__(self):
        return self.__repr__()
64 changes: 64 additions & 0 deletions polaris/hub/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from httpx import ConnectError, HTTPStatusError, Response
from typing_extensions import Self
import fsspec
import zarr

from polaris.benchmark import (
BenchmarkV1Specification,
Expand All @@ -22,9 +23,11 @@
from polaris.model import Model
from polaris.dataset import DatasetV1, DatasetV2
from polaris.evaluate import BenchmarkResultsV1, BenchmarkResultsV2, CompetitionPredictions
from polaris.evaluate._predictions import Predictions
from polaris.hub.external_client import ExternalAuthClient
from polaris.hub.oauth import CachedTokenAuth
from polaris.hub.settings import PolarisHubSettings
from polaris.hub.storage import StorageSession
from polaris.utils.context import track_progress
from polaris.utils.errors import (
PolarisCreateArtifactError,
Expand Down Expand Up @@ -714,3 +717,64 @@ def upload_model(
progress.log(
f"[green]Your model has been successfully uploaded to the Hub. View it here: {model_url}"
)

def upload_prediction(
self,
prediction: Predictions,
timeout: TimeoutTypes = (10, 200),
owner: HubOwner | str | None = None,
if_exists: ZarrConflictResolution = "replace",
parent_artifact_id: str | None = None,
):
"""
Upload a Predictions artifact (with Zarr archive) to the Polaris Hub.
"""

logger = logging.getLogger(__name__)

# Set owner
prediction.owner = HubOwner.normalize(owner or prediction.owner)
prediction_json = prediction.model_dump(by_alias=True, exclude_none=True)

# Step 1: Upload metadata to Hub
url = f"/v2/prediction/{prediction.artifact_id}"
with track_progress(description="Uploading prediction metadata", total=1) as (progress, task):
response = self._base_request_to_hub(
url=url,
method="PUT",
json={
"zarrManifestFileContent": {
"md5Sum": prediction.zarr_manifest_md5sum,
},
"parentArtifactId": parent_artifact_id,
**prediction_json,
},
timeout=timeout,
)
inserted = response.json()
prediction.slug = inserted["slug"]
prediction_url = urljoin(self.settings.hub_url, response.headers.get("Content-Location"))
progress.log(f"[green]Prediction metadata uploaded. View it here: {prediction_url}")

# Step 2: Upload manifest file
with StorageSession(self, "write", prediction.urn) as storage:
with track_progress(description="Copying manifest file", total=1):
with open(prediction.zarr_manifest_path, "rb") as manifest_file:
storage.set_file("manifest", manifest_file.read())

# Step 3: Upload Zarr archive
with track_progress(description="Copying Zarr archive", total=1) as (progress_zarr, task_zarr):
progress_zarr.log("[yellow]This may take a while.")
destination = storage.store("root")
# Consolidate Zarr metadata and upload .zmetadata
zarr.consolidate_metadata(prediction.zarr_root.store.store)
zmetadata_content = prediction.zarr_root.store.store[".zmetadata"]
destination[".zmetadata"] = zmetadata_content
# Copy the Zarr archive
destination.copy_from_source(
prediction.zarr_root.store.store, if_exists=if_exists, log=logger.info
)

print(
f"[green]Your prediction has been successfully uploaded to the Hub. View it here: {prediction_url}"
)