diff --git a/array_api_compat/torch/_aliases.py b/array_api_compat/torch/_aliases.py index bfa7610b..e7fc7a81 100644 --- a/array_api_compat/torch/_aliases.py +++ b/array_api_compat/torch/_aliases.py @@ -99,6 +99,8 @@ def _f(x1, x2, /, **kwargs): return _f def _fix_promotion(x1, x2, only_scalar=True): + if not isinstance(x1, torch.Tensor) or not isinstance(x2, torch.Tensor): + return x1, x2 if x1.dtype not in _array_api_dtypes or x2.dtype not in _array_api_dtypes: return x1, x2 # If an argument is 0-D pytorch downcasts the other argument