The `from_pretrained` and `from_local` methods in `zonos/model.py`, both decorated with `@classmethod`, take a `device: str = DEFAULT_DEVICE` argument annotated as `str`. However, `get_device` and `DEFAULT_DEVICE`, defined in `zonos/utils.py`, show that the default value is actually a `torch.device`:
```python
import torch


def get_device() -> torch.device:
    if torch.cuda.is_available():
        return torch.device(torch.cuda.current_device())
    # MPS breaks for whatever reason. Uncomment when it's working.
    # if torch.mps.is_available():
    #     return torch.device("mps")
    return torch.device("cpu")


DEFAULT_DEVICE = get_device()
```
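One possible fix is to widen the annotation so it accepts both forms. The sketch below is only an illustration, not the project's code: `DeviceLike` and `ModelSketch` are hypothetical names, the real methods' other parameters are left out, and the default is `"cpu"` rather than `DEFAULT_DEVICE` so the example runs without PyTorch installed. `torch` is imported only under `TYPE_CHECKING`, so type checkers see the real `torch.device` type while the code itself never needs it at runtime.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:  # only static type checkers (e.g. mypy) import torch here
    import torch

# Hypothetical alias: a device may be passed as a string ("cpu", "cuda:0")
# or as the torch.device that get_device() returns.
DeviceLike = Union[str, "torch.device"]


class ModelSketch:
    """Hypothetical stand-in for the model class in zonos/model.py."""

    def __init__(self, device: str) -> None:
        self.device = device

    @classmethod
    def from_local(cls, device: DeviceLike = "cpu") -> ModelSketch:
        # str() handles both forms, e.g. str(torch.device("cuda", 0)) == "cuda:0"
        return cls(device=str(device))


model = ModelSketch.from_local("cuda:0")
print(model.device)  # cuda:0
```

In the real methods, the signature would keep `DEFAULT_DEVICE` as its default and simply read `device: str | torch.device = DEFAULT_DEVICE`.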