Description
Currently, some model tests are failing on Linux GPU on GHA.
Error observations:
Here is a sample of the errors from a run on 17 January 2023:
FAILED test/test_models.py::test_classification_model[cuda-resnet101] - AssertionError: Tensor-likes are not close!
Mismatched elements: 14 / 50 (28.0%)
Greatest absolute difference: 9.2578125 at index (0, 29) (up to 0.001 allowed)
Greatest relative difference: 0.16600049957503943 at index (0, 22) (up to 0.001 allowed)
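For context, the comparison behind this assertion follows the usual elementwise tolerance check used by torch.testing.assert_close: values are considered close when |actual - expected| <= atol + rtol * |expected|. The snippet below is only a minimal sketch of that tolerance semantics with made-up tensors; it is not the actual test code:

import torch

expected = torch.tensor([1000.0, 2.0, 3.0])
actual = torch.tensor([1001.5, 2.0, 3.0])  # element 0 is off by 1.5 (relative diff 1.5e-3)

# Passes: 1.5 <= 0.001 + 0.01 * 1000.0
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-2)

# Raises "Tensor-likes are not close!": 1.5 > 0.001 + 0.001 * 1000.0
torch.testing.assert_close(actual, expected, atol=1e-3, rtol=1e-3)

This is why models with large output magnitudes can report large absolute differences even when the relative differences are only slightly above 0.001.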
After tracing back, it seems the problem started around 8 or 9 December 2022. We notice that on 8 December 2022 the run succeeded, however it skipped the GPU tests and only ran the CPU tests (example of 8 December 2022 run):
test/test_models.py::test_classification_model[cpu-wide_resnet50_2] PASSED [ 78%]
test/test_models.py::test_classification_model[cuda-alexnet] SKIPPED [ 78%]
And on 9 December 2022, we notice that the run included both the CPU and GPU tests, and the GPU tests failed with results that differ from their CPU counterparts (example run on 9 December 2022; notice that the failure on resnet101 has a different relative difference from the one on 17 January 2023):
FAILED test/test_models.py::test_classification_model[cuda-resnet152] - AssertionError: Tensor-likes are not close!
Mismatched elements: 9 / 50 (18.0%)
Greatest absolute difference: 10275.0 at index (0, 23) (up to 0.001 allowed)
Greatest relative difference: 0.042592364818058275 at index (0, 15) (up to 0.001 allowed)
Another observation: on 9 December 2022, looking at PR #6919, we can see that although the GHA Linux GPU job failed due to the precision problem, the CircleCI GPU test succeeded.
There was no change in the model (resnet34) or the test, and the CPU tests always succeeded between 8 December 2022 and 17 January 2023.
Possible problems
- From [proto][ci] Try add GPU ci for prototype transforms #6919, it seems that the precision error might be caused by moving to GHA (possibly a different GPU or configuration; see the diagnostic sketch after this list)
- We also notice that the precision error changed between 9 December 2022 and 17 January 2023 and became larger, hence there might be another change that made the precision error bigger; this might be due to some changes in pytorch core (we confirm this by running the script below)
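One way to investigate the "different GPU or configs" hypothesis is to dump the GPU model and the float32 precision flags on both runners and compare them. The sketch below only prints standard torch / torch.backends attributes; whether these actually differ between the CircleCI and GHA runners is an assumption to verify, not something confirmed here:

import torch

# Diagnostic dump to compare CI runners (hardware + float32 precision settings).
print("device:", torch.cuda.get_device_name(0))
print("capability:", torch.cuda.get_device_capability(0))
print("torch:", torch.__version__, "cuda:", torch.version.cuda, "cudnn:", torch.backends.cudnn.version())
# TF32-related flags matter on Ampere GPUs (e.g. the A100 in the environment below),
# since they change matmul/convolution precision in float32.
print("matmul allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
print("cudnn allow_tf32:", torch.backends.cudnn.allow_tf32)
print("float32 matmul precision:", torch.get_float32_matmul_precision())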
Script to reproduce the problem
Here is a small script that is able to reproduce the problem:
import torch
import torchvision
import random


def get_cpu_gpu_model_output_maxdiff(model_fn, seed):
    torch.manual_seed(seed)
    random.seed(seed)

    # Build the model on CPU, then copy the same weights into a GPU instance.
    m_cpu = model_fn(num_classes=50).eval()
    m_gpu = model_fn(num_classes=50)
    m_gpu.load_state_dict(m_cpu.state_dict())
    m_gpu = m_gpu.to("cuda").eval()

    # Feed the same input to both models.
    input_shape = (1, 3, 224, 224)
    x_cpu = torch.rand(input_shape)
    x_gpu = x_cpu.clone().to("cuda")
    y_cpu = m_cpu(x_cpu).squeeze(0)
    y_gpu = m_gpu(x_gpu).to("cpu").squeeze(0)

    # Compare the outputs with the same tolerance as the test (0.001).
    abs_diff = torch.abs(y_gpu - y_cpu)
    max_abs_diff = torch.max(abs_diff)
    max_abs_idx = torch.argmax(abs_diff)
    max_rel_diff = torch.abs(max_abs_diff / y_cpu[max_abs_idx])
    max_val_gpu = torch.max(torch.abs(y_gpu))
    mean_val_gpu = torch.mean(torch.abs(y_gpu))
    prec = 1e-3
    pass_test = torch.allclose(y_gpu, y_cpu, atol=prec, rtol=prec)
    print(f" [{seed}]max_abs_diff: {max_abs_diff},\tmax_rel_diff: {max_rel_diff},\tmax_val_gpu: {max_val_gpu},\tmean_val_gpu: {mean_val_gpu},\tpass_test: {pass_test}")


for model_fn in [torchvision.models.resnet.resnet34, torchvision.models.resnet.resnet101, torchvision.models.efficientnet.efficientnet_b0]:
    print(f"model_fn: {model_fn.__name__}")
    for seed in range(5):
        get_cpu_gpu_model_output_maxdiff(model_fn, seed)
When I ran this script on an AWS cluster with CUDA 11.6 and Python 3.8 (the result of collect_env.py is provided at the end of this section), I got the following output log:
model_fn: resnet34
[0]max_abs_diff: 0.012034416198730469, max_rel_diff: 0.0012628681724891067, max_val_gpu: 35.31159210205078, mean_val_gpu: 10.925275802612305, pass_test: False
[1]max_abs_diff: 0.01442718505859375, max_rel_diff: 0.0007660656701773405, max_val_gpu: 37.38212585449219, mean_val_gpu: 12.549612998962402, pass_test: False
[2]max_abs_diff: 0.029125213623046875, max_rel_diff: 0.001757694175466895, max_val_gpu: 130.0574188232422, mean_val_gpu: 29.450868606567383, pass_test: False
[3]max_abs_diff: 0.014329195022583008, max_rel_diff: 0.004036294762045145, max_val_gpu: 38.02964401245117, mean_val_gpu: 12.15386962890625, pass_test: False
[4]max_abs_diff: 0.017838478088378906, max_rel_diff: 0.001571571920067072, max_val_gpu: 43.16202163696289, mean_val_gpu: 14.939404487609863, pass_test: False
model_fn: resnet101
[0]max_abs_diff: 9.53857421875, max_rel_diff: 0.0014522294513881207, max_val_gpu: 27715.107421875, mean_val_gpu: 9278.046875, pass_test: False
[1]max_abs_diff: 30.28759765625, max_rel_diff: 0.006908989977091551, max_val_gpu: 47344.68359375, mean_val_gpu: 12955.2421875, pass_test: False
[2]max_abs_diff: 19.2783203125, max_rel_diff: 0.0016507417894899845, max_val_gpu: 46184.32421875, mean_val_gpu: 20209.998046875, pass_test: False
[3]max_abs_diff: 16.796875, max_rel_diff: 0.0008035682258196175, max_val_gpu: 32151.07421875, mean_val_gpu: 13038.66015625, pass_test: False
[4]max_abs_diff: 17.41796875, max_rel_diff: 0.0015470916405320168, max_val_gpu: 28275.6484375, mean_val_gpu: 11823.5322265625, pass_test: False
model_fn: efficientnet_b0
[0]max_abs_diff: 2.8128270112420806e-16, max_rel_diff: 0.0031724595464766026, max_val_gpu: 2.260063203208748e-13, mean_val_gpu: 9.397712642800204e-14, pass_test: True
[1]max_abs_diff: 3.2067989756690007e-16, max_rel_diff: 0.0029657522682100534, max_val_gpu: 3.657568497204833e-13, mean_val_gpu: 1.1394578585798704e-13, pass_test: True
[2]max_abs_diff: 1.7404748296886638e-16, max_rel_diff: 0.013579211197793484, max_val_gpu: 1.936249012686464e-13, mean_val_gpu: 7.33500039136123e-14, pass_test: True
[3]max_abs_diff: 1.8396200361647796e-16, max_rel_diff: 0.0022520553320646286, max_val_gpu: 3.433499957874314e-13, mean_val_gpu: 9.2998471337008e-14, pass_test: True
[4]max_abs_diff: 2.201540273867597e-16, max_rel_diff: 0.002613184042274952, max_val_gpu: 3.3239963516049076e-13, mean_val_gpu: 1.0326688912624601e-13, pass_test: True
Our tests have a tolerance of 0.001, and these results are consistently larger than that tolerance for the resnet models, hence they are unexpected. There seems to be no change associated with the resnet models in torchvision, hence most likely some change in pytorch core causes these differences.
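As a possible follow-up check (a sketch only, not something run for this report), the same comparison can be repeated in float64: if the mismatch largely disappears, that points to a float32 precision/accumulation effect on the GPU path rather than a functional change in the model code. This reuses the same random-weight setup as the script above, with resnet101 as an example:

import torch
import torchvision

torch.manual_seed(0)

# Same CPU-vs-GPU comparison as above, but in float64.
m_cpu = torchvision.models.resnet101(num_classes=50).double().eval()
m_gpu = torchvision.models.resnet101(num_classes=50).double()
m_gpu.load_state_dict(m_cpu.state_dict())
m_gpu = m_gpu.to("cuda").eval()

x_cpu = torch.rand(1, 3, 224, 224, dtype=torch.float64)
with torch.no_grad():
    y_cpu = m_cpu(x_cpu)
    y_gpu = m_gpu(x_cpu.to("cuda")).to("cpu")

print("max_abs_diff (float64):", (y_gpu - y_cpu).abs().max().item())
print("pass_test at 1e-3:", torch.allclose(y_gpu, y_cpu, atol=1e-3, rtol=1e-3))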
The environment I used for this reproduction (output of collect_env.py):
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.3
Libc version: glibc-2.27
Python version: 3.8.15 (default, Nov 24 2022, 15:19:38) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: 11.6.112
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] mypy==0.991
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.23.4
[pip3] torch==2.0.0.dev20221221
[pip3] torchaudio==2.0.0.dev20221221
[pip3] torchvision==0.15.0a0+dca6617
[conda] blas 1.0 mkl
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.23.4 py38h14f4228_0
[conda] numpy-base 1.23.4 py38h31eccc5_0
[conda] pytorch 2.0.0.dev20221221 py3.8_cuda11.6_cudnn8.3.2_0 pytorch-nightly
[conda] pytorch-cuda 11.6 h867d48c_2 pytorch-nightly
[conda] pytorch-mutex 1.0 cuda pytorch-nightly
[conda] torchaudio 2.0.0.dev20221221 py38_cu116 pytorch-nightly
[conda] torchtriton 2.0.0+0d7e753227 py38 pytorch-nightly
[conda] torchvision 0.15.0a0+dca6617 dev_0 <develop>