Feat/add prediction artifact and upload method #292


Draft · wants to merge 15 commits into main
Conversation

@j279li (Contributor) commented Jun 3, 2025

Summary of Changes

This PR introduces a new Predictions artifact and an associated upload method.

New Features

  • Prediction Artifact (Predictions)

    • Added a new Predictions model to represent prediction artifacts as Zarr archives.
    • Includes validators to ensure Zarr structure consistency.
    • Provides convenient properties (columns, dtypes, n_rows, etc.) for quick data exploration.
    • Methods for setting/getting prediction values, creating Zarr archives from dicts, and uploading to the Polaris Hub.
  • Upload API (upload_prediction)

    • Added a new upload_prediction method in PolarisHubClient for uploading prediction artifacts.
    • Handles metadata, manifest, and archive uploads with progress tracking.

@j279li j279li self-assigned this Jun 3, 2025
@j279li j279li added the feature Annotates any PR that adds new features; Used in the release process label Jun 3, 2025
@jstlaurent (Contributor) left a comment

Left some comments about the model structure. We can chat about it some more if you'd like. 😄

@@ -29,4 +29,5 @@
"evaluate_benchmark",
"CompetitionPredictions",
"BenchmarkPredictions",
"Predictions",
Comment (Contributor): For this to be re-exported, you need to import it at the top of the file as well.
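The re-export rule the reviewer is pointing at can be sketched as follows. This is a self-contained demo with a hypothetical package name (`demo_pkg`), not Polaris code: a name listed in `__all__` of an `__init__.py` is only importable from the package if `__init__.py` also imports it.

```python
import importlib
import os
import sys
import tempfile

# Build a throwaway package on disk (demo_pkg is a hypothetical name).
pkg_root = tempfile.mkdtemp()
os.makedirs(os.path.join(pkg_root, "demo_pkg"))
with open(os.path.join(pkg_root, "demo_pkg", "_predictions.py"), "w") as f:
    f.write("class Predictions:\n    pass\n")
with open(os.path.join(pkg_root, "demo_pkg", "__init__.py"), "w") as f:
    # Both lines are needed: the import makes the name exist in the
    # package namespace; __all__ only declares it as part of the public API.
    f.write(
        "from demo_pkg._predictions import Predictions\n"
        '__all__ = ["Predictions"]\n'
    )

sys.path.insert(0, pkg_root)
demo_pkg = importlib.import_module("demo_pkg")
print(demo_pkg.Predictions.__name__)
```

Without the `from demo_pkg._predictions import Predictions` line, `__all__` alone would not make `demo_pkg.Predictions` resolve.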

Upload a Predictions artifact (with Zarr archive) to the Polaris Hub.
"""

logger = logging.getLogger(__name__)
Comment (Contributor): There's already a logger defined at the top of the module.

Comment on lines 778 to 780
print(
    f"[green]Your prediction has been successfully uploaded to the Hub. View it here: {prediction_url}"
)
Comment (Contributor): The [green] prefix will only work with progress.log. It won't be recognized as text formatting by print.
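The `[green]` tag is Rich console markup, which only Rich-aware sinks such as `progress.log` interpret; the builtin `print` emits it as literal text. A small sketch of the failure mode (the URL value is a hypothetical placeholder), without requiring Rich itself:

```python
import contextlib
import io

prediction_url = "https://example.invalid/predictions/123"  # hypothetical placeholder

# Capture stdout to show that builtin print leaks the markup verbatim
# instead of colouring the line.
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    print(f"[green]Your prediction has been successfully uploaded to the Hub. View it here: {prediction_url}")

print(buffer.getvalue().startswith("[green]"))
```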

destination[".zmetadata"] = zmetadata_content
# Copy the Zarr archive
destination.copy_from_source(
    prediction.zarr_root.store.store, if_exists=if_exists, log=logger.info
Suggestion (Contributor): Pass in the log from the progress tracker, to have the printout be nicer.

Suggested change:
-    prediction.zarr_root.store.store, if_exists=if_exists, log=logger.info
+    prediction.zarr_root.store.store, if_exists=if_exists, log=progress_zarr.log

"zarrManifestFileContent": {
"md5Sum": prediction.zarr_manifest_md5sum,
},
"parentArtifactId": parent_artifact_id,
Change requested (Contributor): Predictions don't have parent artifacts.

Suggested change:
-    "parentArtifactId": parent_artifact_id,

def zarr_manifest_path(self):
    if self._zarr_manifest_path is None:
        zarr_manifest_path = generate_zarr_manifest(
            self.zarr_root_path, getattr(self, "_cache_dir", None)
Comment (Contributor): _cache_dir is never set on the class, so this will always be None.

_zarr_manifest_path: str | None = None
_zarr_manifest_md5sum: str | None = None
column_types: dict[str, type] = {}

Comment (Contributor): What you need in this class is to instantiate a Zarr file (probably as a Pydantic private attribute), in which you'll store the data the user sets.

That Zarr file can be initialized with a root group, and an array for each target column. These arrays can use codecs (from polaris/dataset/zarr/codecs.py) to handle more complex data types, matching the types defined on the underlying dataset of the benchmark.
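The shape of that suggestion can be sketched with Pydantic's `PrivateAttr`. Everything below is hypothetical: a plain dict of NumPy arrays stands in for the Zarr root group to keep the sketch self-contained, whereas the real class would hold a `zarr.Group` with per-column codecs matching the benchmark's dataset.

```python
import numpy as np
from pydantic import BaseModel, PrivateAttr

class PredictionsSketch(BaseModel):
    # Private attribute backing the user-set data. In the real class this
    # would be a zarr.Group initialized with one array per target column;
    # a dict is used here only to keep the example dependency-light.
    _root: dict = PrivateAttr(default_factory=dict)

    def set_column(self, column: str, values) -> None:
        # Store the user's data under the target column name.
        self._root[column] = np.asarray(values)

    @property
    def columns(self) -> list:
        return list(self._root)

p = PredictionsSketch()
p.set_column("target_column_a", [0.1, 0.2, 0.3])
print(p.columns)
```

Keeping the store private means validation and encoding stay under the model's control rather than the caller's.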

column_types: dict[str, type] = {}

@model_validator(mode="after")
def _validate_predictions_zarr_structure(self):
Comment (Contributor): Since the Zarr archive won't be filled when the Predictions instance is created, this validation will not work. Also, it's not needed.

Comment on lines 342 to 349
@property
def dtypes(self):
    dtypes = {}
    for arr in self.zarr_root.array_keys():
        dtypes[arr] = self.zarr_root[arr].dtype
    for group in self.zarr_root.group_keys():
        dtypes[group] = object
    return dtypes
Question (Contributor): This matches the API from the dataset, but I don't think we need it here.

def has_zarr_manifest_md5sum(self):
    return self._zarr_manifest_md5sum is not None

def set_prediction_value(self, column: str, row: int, value: Any):
Comment (Contributor): I think this is broadly correct, but the users might be looking to set the values as arrays, rather than individually. @cwognum can probably let us know what API would be more useful.

@j279li (Author) replied Jun 3, 2025:

That makes sense. What about having both, just in case we want some more modularity?

Reply (Collaborator):

> what about having both

I would prefer we stick with one. Even if it may not seem like it now, I could see this becoming one of those things where we end up maintaining two diverging code paths. We can always add it if users ask for it!

I think it's a safe assumption that predictions will always fit in memory and would thus provide an API that lets you set all of the predictions at once.

Ideally, setting them directly through the constructor would be nicest, I think:

Predictions(
    predictions={
        'target_column_a': [0, 1, 2],
        'target_column_b': [3, 4, 5],
    }
)

For more complex objects, like rdkit.Chem.Molecule or fastpdb.AtomArray: @jstlaurent, I can remember us discussing this before and opting for dedicated set_*() methods, but my memory is failing me. Am I remembering that correctly?
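The constructor-based API suggested above, including the all-columns-same-length check discussed later in this thread, can be sketched with Pydantic. The class and field names are hypothetical, not the PR's actual implementation:

```python
from pydantic import BaseModel, model_validator

class PredictionsCtorSketch(BaseModel):
    # All predictions are passed at once as a column -> values mapping,
    # on the assumption that predictions always fit in memory.
    predictions: dict

    @model_validator(mode="after")
    def _check_equal_lengths(self):
        # Every target column must contribute the same number of rows.
        lengths = {len(v) for v in self.predictions.values()}
        if len(lengths) > 1:
            raise ValueError(f"All columns must have the same length, found: {lengths}")
        return self

p = PredictionsCtorSketch(
    predictions={
        "target_column_a": [0, 1, 2],
        "target_column_b": [3, 4, 5],
    }
)
```

A mismatched input such as `{"a": [1], "b": [1, 2]}` would raise a validation error at construction time, which is the main advantage of the constructor API over per-value setters.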

Reply (Contributor):

> For more complex objects, like rdkit.Chem.Molecule or fastpdb.AtomArray. @jstlaurent I can remember us discussing this before and opting for dedicated set_*() methods, but my memory is failing me. Am I remembering that correctly?

If the prediction Zarr arrays are initialized with the same codecs as the dataset columns, then the array elements passed in could be complex types and Zarr should encode them correctly.

We did have a discussion about how desirable it is to tightly couple the codecs classes defined in polaris with the Zarr archive definition, which we left up in the air.

@cwognum (Collaborator) left a comment

Nice work, @j279li !

Comment on lines 319 to 336
for group in self.zarr_root.group_keys():
    if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
        raise ValueError(f"Group {group} does not have an index array (__index__).")
    index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
    if len(index_arr) != len(self.zarr_root[group]) - 1:
        raise ValueError(
            f"Length of index array for group {group} does not match the size of the group (excluding index)."
        )
    if any(x not in self.zarr_root[group] for x in index_arr):
        raise ValueError(f"Keys of index array for group {group} do not match the group members.")
# Check all arrays/groups at root have the same length
lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
if len(lengths) > 1:
    raise ValueError(
        f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
    )
return self

Comment: It's ok to leave this for a later, separate PR, but:

The support for subgroups with this index array is one of my brilliant ideas, but I don't think we need it anymore. We used this to support PDB structures (i.e. fastpdb.AtomArray), which is a collection of arrays, but have realized since then that we could just as well use MsgPack to encode that entire subgroup into a single, logical entry in an array.

This is all we need:

        if len(list(self.zarr_root.group_keys())) > 0:
            raise InvalidPredictionsError("Predictions can't have subgroups")
        # Check all arrays/groups at root have the same length
        lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
        if len(lengths) > 1:
            raise InvalidPredictionsError(
                f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
            )
        return self

If we update this here, we should also update it for the V2 dataset.

Reply (Author):

I can do this here at the same time, unless we explicitly want another PR for this.

Comment on lines 374 to 389
@property
def zarr_manifest_md5sum(self):
    if not self.has_zarr_manifest_md5sum:
        logger.info("Computing the checksum. This can be slow for large predictions archives.")
        self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path)
    return self._zarr_manifest_md5sum

@zarr_manifest_md5sum.setter
def zarr_manifest_md5sum(self, value: str):
    if not re.fullmatch(r"^[a-f0-9]{32}$", value):
        raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.")
    self._zarr_manifest_md5sum = value

@property
def has_zarr_manifest_md5sum(self):
    return self._zarr_manifest_md5sum is not None
Question: We have a ChecksumMixin class that provides these methods and we may be able to use it here.

I can't remember why we didn't use it for the DatasetV2, though. The only thing that comes to mind is that _zarr_manifest_md5sum would be renamed to _md5sum, which is less descriptive. @Andrewq11 , do you recall?
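The setter above insists on a 32-character hex digest. A minimal sketch of how such a checksum is typically computed over a manifest file, streaming in chunks so large manifests never need to fit in memory (`file_md5` is a hypothetical stand-in, not the actual `calculate_file_md5` from Polaris):

```python
import hashlib
import tempfile

def file_md5(path: str, chunk_size: int = 1 << 20) -> str:
    # Hypothetical stand-in for calculate_file_md5: read the file in
    # fixed-size chunks and fold each into the running MD5 state.
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()

# Write a small stand-in manifest file and checksum it.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
    tmp.write(b"hello")
    manifest_path = tmp.name

checksum = file_md5(manifest_path)
print(checksum)  # 32 lowercase hex characters, matching the setter's regex
```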

    if column_types.get(col) not in (float, int, str, bool, bytes, None)
    else getattr(arr, "dtype", None) or type(arr[0])
)
root.create_dataset(col, data=arr, dtype=dtype, **zarr_kwargs)
Comment (Collaborator): Minor, but we've been using create_array. create_dataset is an alias to make the API compatible with h5py; arrays are known as "datasets" in HDF5 terminology.

def upload_to_hub(
    self,
    owner: HubOwner | str | None = None,
    parent_artifact_id: str | None = None,
Question (Collaborator): I don't think we need to version predictions, and so we won't need to specify the parent artifact.

@@ -714,3 +717,55 @@ def upload_model(
        progress.log(
            f"[green]Your model has been successfully uploaded to the Hub. View it here: {model_url}"
        )

    def upload_prediction(
Minor (Collaborator): For consistency?

Suggested change:
-    def upload_prediction(
+    def upload_predictions(

Comment on lines 445 to 446
if model is not None:
    self.model = model
Comment (Collaborator): This feels like a deviation from the API we have so far. We should keep this consistent with how it's done for results.

_artifact_type = "prediction"

benchmark_artifact_id: str
model: Model | None = None
Change requested (Collaborator): We don't want to actually upload the model itself when we upload the predictions, just its artifact ID. We can exclude the field here, like we do for results:

model: Model | None = Field(None, exclude=True)

And instead add a computed_field with the model artifact ID:

    @computed_field
    @property
    def model_artifact_id(self) -> str | None:
        return self.model.artifact_id if self.model else None

@j279li force-pushed the feat/predictions-submission branch from fe10e34 to db77f00 on June 5, 2025 at 15:23
@j279li (Author) commented Jun 5, 2025

@jstlaurent I've encountered circular import issues with BenchmarkV2Specification, since the benchmark package already imports from the evaluate package. To work around this, I moved the Predictions class into its own package for now. Would that be alright, or do you have other ideas?
