Feat/add prediction artifact and upload method #292
base: main
Conversation
Left some comments about the model structure. We can chat about it some more if you'd like. 😄
polaris/evaluate/__init__.py
Outdated
@@ -29,4 +29,5 @@
"evaluate_benchmark",
"CompetitionPredictions",
"BenchmarkPredictions",
"Predictions",
Comment
For this to be re-exported, you need to import it at the top of the file as well.
polaris/hub/client.py
Outdated
Upload a Predictions artifact (with Zarr archive) to the Polaris Hub.
"""

logger = logging.getLogger(__name__)
Comment
There's already a logger defined at the top of the module.
polaris/hub/client.py
Outdated
print(
    f"[green]Your prediction has been successfully uploaded to the Hub. View it here: {prediction_url}"
)
Comment
The `[green]` prefix will only work with `progress.log`. It won't be recognized as text formatting by `print`.
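A quick illustration of the difference — the builtin `print` has no knowledge of Rich console markup:

```python
# Rich markup like "[green]..." is only interpreted by Rich-aware sinks
# (e.g. a Console, or a Progress instance's .log). The builtin print()
# writes the tag characters verbatim, so the user would literally see
# "[green]" on their terminal.
message = "[green]Your prediction has been successfully uploaded to the Hub."
print(message)  # emits the literal "[green]" prefix
```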
polaris/hub/client.py
Outdated
destination[".zmetadata"] = zmetadata_content
# Copy the Zarr archive
destination.copy_from_source(
    prediction.zarr_root.store.store, if_exists=if_exists, log=logger.info
Suggestion
Pass in the `log` from the progress tracker, to have the printout be nicer.

- prediction.zarr_root.store.store, if_exists=if_exists, log=logger.info
+ prediction.zarr_root.store.store, if_exists=if_exists, log=progress_zarr.log
polaris/hub/client.py
Outdated
"zarrManifestFileContent": {
    "md5Sum": prediction.zarr_manifest_md5sum,
},
"parentArtifactId": parent_artifact_id,
Change requested
Predictions don't have parent artifacts.
- "parentArtifactId": parent_artifact_id,
polaris/evaluate/_predictions.py
Outdated
def zarr_manifest_path(self):
    if self._zarr_manifest_path is None:
        zarr_manifest_path = generate_zarr_manifest(
            self.zarr_root_path, getattr(self, "_cache_dir", None)
Comment
`_cache_dir` is never set on the class, so this will always be `None`.
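The pitfall in miniature — the three-argument `getattr` never raises, so the missing attribute goes unnoticed (the class here is a hypothetical stand-in):

```python
class Predictions:  # minimal stand-in; the real model never assigns _cache_dir
    _zarr_manifest_path = None

p = Predictions()
# getattr with a default silently falls back, so this expression can
# never return a real cache directory: the attribute simply doesn't exist.
cache_dir = getattr(p, "_cache_dir", None)
print(cache_dir)  # None
```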
polaris/evaluate/_predictions.py
Outdated
_zarr_manifest_path: str | None = None
_zarr_manifest_md5sum: str | None = None
column_types: dict[str, type] = {}
Comment
What you need in this class is to instantiate a Zarr file (probably as a Pydantic Private attribute), in which you'll store the data the user sets.
That Zarr file can be initialized with a root group, and an array for each target column. These arrays can use codecs (from `polaris/dataset/zarr/codecs.py`) to handle more complex data types, matching the types defined on the underlying dataset of the benchmark.
polaris/evaluate/_predictions.py
Outdated
column_types: dict[str, type] = {}

@model_validator(mode="after")
def _validate_predictions_zarr_structure(self):
Comment
Since the Zarr archive won't be filled when the `Predictions` instance is created, this validation will not work. Also, it's not needed.
polaris/evaluate/_predictions.py
Outdated
@property
def dtypes(self):
    dtypes = {}
    for arr in self.zarr_root.array_keys():
        dtypes[arr] = self.zarr_root[arr].dtype
    for group in self.zarr_root.group_keys():
        dtypes[group] = object
    return dtypes
Question
This matches the API from the dataset, but I don't think we need it here.
polaris/evaluate/_predictions.py
Outdated
def has_zarr_manifest_md5sum(self):
    return self._zarr_manifest_md5sum is not None

def set_prediction_value(self, column: str, row: int, value: Any):
Comment
I think this is broadly correct, but the users might be looking to set the values as arrays, rather than individually. @cwognum can probably let us know what API would be more useful.
That makes sense, what about having both, just in case we want some more modularity?
> what about having both
I would prefer we stick with one. Even if it may not seem like it now, I could see this become one of those things where we end up maintaining two diverging code paths. We can always add it if users ask for it!
I think it's a safe assumption that predictions will always fit in memory and would thus provide an API that lets you set all of the predictions at once.
Ideally, setting them directly through the constructor would be nicest, I think:
Predictions(
    predictions={
        'target_column_a': [0, 1, 2],
        'target_column_b': [3, 4, 5],
    }
)
For more complex objects, like `rdkit.Chem.Molecule` or `fastpdb.AtomArray`.
@jstlaurent I remember us discussing this before and opting for dedicated `set_*()` methods, but my memory is failing me. Am I remembering that correctly?
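A rough sketch of that constructor-first API, using plain in-memory lists per target column (the class body and its validation are assumptions, not the final design):

```python
class Predictions:
    """Stand-in: accepts all predictions at once, one sequence per target
    column, and rejects ragged inputs up front."""

    def __init__(self, predictions: dict[str, list]):
        # All target columns must have the same number of rows.
        lengths = {col: len(values) for col, values in predictions.items()}
        if len(set(lengths.values())) > 1:
            raise ValueError(f"Columns have differing lengths: {lengths}")
        self.predictions = predictions

preds = Predictions(
    predictions={
        "target_column_a": [0, 1, 2],
        "target_column_b": [3, 4, 5],
    }
)
print(sorted(preds.predictions))  # ['target_column_a', 'target_column_b']
```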
> For more complex objects, like `rdkit.Chem.Molecule` or `fastpdb.AtomArray`. @jstlaurent I remember us discussing this before and opting for dedicated `set_*()` methods, but my memory is failing me. Am I remembering that correctly?
If the prediction Zarr arrays are initialized with the same codecs as the dataset columns, then the array elements passed in could be complex types and Zarr should encode them correctly.
We did have a discussion about how desirable it is to tightly couple the codecs classes defined in `polaris` with the Zarr archive definition, which we left up in the air.
Nice work, @j279li !
polaris/evaluate/_predictions.py
Outdated
for group in self.zarr_root.group_keys():
    if _INDEX_ARRAY_KEY not in self.zarr_root[group].array_keys():
        raise ValueError(f"Group {group} does not have an index array (__index__).")
    index_arr = self.zarr_root[group][_INDEX_ARRAY_KEY]
    if len(index_arr) != len(self.zarr_root[group]) - 1:
        raise ValueError(
            f"Length of index array for group {group} does not match the size of the group (excluding index)."
        )
    if any(x not in self.zarr_root[group] for x in index_arr):
        raise ValueError(f"Keys of index array for group {group} do not match the group members.")
# Check all arrays/groups at root have the same length
lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
lengths.update({len(self.zarr_root[k][_INDEX_ARRAY_KEY]) for k in self.zarr_root.group_keys()})
if len(lengths) > 1:
    raise ValueError(
        f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
    )
return self
Comment: It's ok to leave this for a later, separate PR, but:
The support for subgroups with this index array is one of my brilliant ideas, but I don't think we need it anymore. We used this to support PDB structures (i.e. `fastpdb.AtomArray`), which is a collection of arrays, but have realized since then that we could just as well use MsgPack to encode that entire subgroup into a single, logical entry in an array.
This is all we need:
if len(list(self.zarr_root.group_keys())) > 0:
    raise InvalidPredictionsError("Predictions can't have subgroups")

# Check all arrays at root have the same length
lengths = {len(self.zarr_root[k]) for k in self.zarr_root.array_keys()}
if len(lengths) > 1:
    raise InvalidPredictionsError(
        f"All arrays or groups in the root should have the same length, found the following lengths: {lengths}"
    )
return self
If we update this here, we should also update it for the V2 dataset.
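The simplified check can be exercised with plain mappings standing in for the Zarr root (the function name, and `ValueError` in place of `InvalidPredictionsError`, are assumptions for the sketch):

```python
def validate_flat_root(arrays: dict[str, list], groups: dict[str, dict]) -> None:
    """Mirror of the suggested validator: no subgroups allowed, and every
    root-level array must have the same length."""
    if len(groups) > 0:
        raise ValueError("Predictions can't have subgroups")
    lengths = {len(arr) for arr in arrays.values()}
    if len(lengths) > 1:
        raise ValueError(
            f"All arrays in the root should have the same length, found: {lengths}"
        )

# A flat root with equal-length arrays passes silently:
validate_flat_root({"a": [1, 2, 3], "b": [4, 5, 6]}, groups={})
print("ok")
```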
I can do this here at the same time, unless we explicitly want another PR for this.
polaris/evaluate/_predictions.py
Outdated
@property
def zarr_manifest_md5sum(self):
    if not self.has_zarr_manifest_md5sum:
        logger.info("Computing the checksum. This can be slow for large predictions archives.")
        self.zarr_manifest_md5sum = calculate_file_md5(self.zarr_manifest_path)
    return self._zarr_manifest_md5sum

@zarr_manifest_md5sum.setter
def zarr_manifest_md5sum(self, value: str):
    if not re.fullmatch(r"^[a-f0-9]{32}$", value):
        raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.")
    self._zarr_manifest_md5sum = value

@property
def has_zarr_manifest_md5sum(self):
    return self._zarr_manifest_md5sum is not None
Question: We have a `ChecksumMixin` class that provides these methods and we may be able to use it here.
I can't remember why we didn't use it for the DatasetV2, though. The only thing that comes to mind is that `_zarr_manifest_md5sum` would be renamed to `_md5sum`, which is less descriptive. @Andrewq11, do you recall?
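For reference, the checksum plumbing boils down to something like this stdlib-only sketch (`calculate_file_md5` here is a stand-in for the helper the codebase imports):

```python
import hashlib
import re
import tempfile

def calculate_file_md5(path: str) -> str:
    """Stream a file through MD5 and return the 32-character hex digest."""
    digest = hashlib.md5()
    with open(path, "rb") as fh:
        for chunk in iter(lambda: fh.read(8192), b""):
            digest.update(chunk)
    return digest.hexdigest()

def validate_md5(value: str) -> str:
    # Mirrors the setter's guard (re.fullmatch makes ^/$ anchors redundant).
    if not re.fullmatch(r"[a-f0-9]{32}", value):
        raise ValueError("The checksum should be the 32-character hexdigest of a 128 bit MD5 hash.")
    return value

with tempfile.NamedTemporaryFile(delete=False) as fh:
    fh.write(b"example manifest contents")
    manifest_path = fh.name

checksum = validate_md5(calculate_file_md5(manifest_path))
print(len(checksum))  # 32
```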
polaris/evaluate/_predictions.py
Outdated
    if column_types.get(col) not in (float, int, str, bool, bytes, None)
    else getattr(arr, "dtype", None) or type(arr[0])
)
root.create_dataset(col, data=arr, dtype=dtype, **zarr_kwargs)
Comment: Minor, but we've been using `create_array`. `create_dataset` is an alias to make the API compatible with h5py. Arrays are known as “datasets” in HDF5 terminology.
polaris/evaluate/_predictions.py
Outdated
def upload_to_hub(
    self,
    owner: HubOwner | str | None = None,
    parent_artifact_id: str | None = None,
Question: I don't think we need to version predictions, and so we won't need to specify the parent artifact.
polaris/hub/client.py
Outdated
@@ -714,3 +717,55 @@ def upload_model(
progress.log(
    f"[green]Your model has been successfully uploaded to the Hub. View it here: {model_url}"
)

def upload_prediction(
Minor: For consistency?
- def upload_prediction(
+ def upload_predictions(
polaris/evaluate/_predictions.py
Outdated
if model is not None:
    self.model = model
Comment: This feels like a deviation from the API we have so far. We should keep this consistent with how it's done for results.
polaris/evaluate/_predictions.py
Outdated
_artifact_type = "prediction"

benchmark_artifact_id: str
model: Model | None = None
Change requested: We don't want to actually upload the model itself when we upload the predictions, just its artifact ID. We can exclude the field here, like we do for results:
model: Model | None = Field(None, exclude=True)
And instead add a computed_field
with the model artifact ID:
@computed_field
@property
def model_artifact_id(self) -> str | None:
    return self.model.artifact_id if self.model else None
fe10e34 to db77f00 · Compare
@jstlaurent I've encountered circular import issues with BenchmarkV2Specification, since the benchmark package already imports from the evaluate package. To work around this, I moved the Predictions class into its own package for now. Would that be alright, or do you have other ideas?
Summary of Changes
This PR introduces a new Prediction artifact and an associated upload method.

New Features
- Prediction Artifact (`Predictions`): Introduces a `Predictions` model to represent prediction artifacts as Zarr archives. Exposes metadata properties (`columns`, `dtypes`, `n_rows`, etc.) for quick data exploration.
- Upload API (`upload_prediction`): Adds an `upload_prediction` method in `PolarisHubClient` for uploading prediction artifacts.