diff --git a/test/float8/test_base.py b/test/float8/test_base.py index 2a875c44d6..db66a206ef 100644 --- a/test/float8/test_base.py +++ b/test/float8/test_base.py @@ -53,7 +53,7 @@ is_cuda_8_9 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9) def bitwise_identical(a: Float8Tensor, b: Float8Tensor) -> bool: - assert torch.all(a._data == b._data).item(), "scales are not identical" + assert torch.all(a._scale == b._scale).item(), "scales are not identical" assert torch.all(a._data == b._data).item(), "data is not identical" return True