From 7ac67b9e7494bbcc9e916a7c0017abccaa4588ab Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 11 Jan 2022 06:59:07 -0800 Subject: [PATCH 1/6] Return RGB frames as output of GPU decoder --- setup.py | 1 + test/test_video_gpu_decoder.py | 59 ++++++++++++++++++- torchvision/csrc/io/decoder/gpu/decoder.cpp | 48 ++++++--------- .../csrc/io/decoder/gpu/gpu_decoder.cpp | 42 +------------ torchvision/csrc/io/decoder/gpu/gpu_decoder.h | 1 - torchvision/io/__init__.py | 10 ---- 6 files changed, 76 insertions(+), 85 deletions(-) diff --git a/setup.py b/setup.py index 36d95c75bec..a8ed59a9cf6 100644 --- a/setup.py +++ b/setup.py @@ -469,6 +469,7 @@ def get_extensions(): "z", "pthread", "dl", + "nppicc", ], extra_compile_args=extra_compile_args, ) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index 84309e3e217..69decee9651 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -22,6 +22,56 @@ ] +def _yuv420_to_444(mat): + # logic taken from + # https://en.wikipedia.org/wiki/YUV#Y%E2%80%B2UV420p_(and_Y%E2%80%B2V12_or_YV12)_to_RGB888_conversion + width = mat.shape[-1] + height = mat.shape[0] * 2 // 3 + luma = mat[:height] + uv = mat[height:].reshape(2, height // 2, width // 2) + uv2 = torch.nn.functional.interpolate(uv[None], scale_factor=2, mode="nearest")[0] + yuv2 = torch.cat([luma[None], uv2]).permute(1, 2, 0) + return yuv2 + + +def _yuv420_to_rgb(mat, limited_color_range=True, standard="bt709"): + # taken from https://en.wikipedia.org/wiki/YCbCr + if standard == "bt601": + # ITU-R BT.601, as used by decord + # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion + m = torch.tensor( + [ + [1.0000, 0.0000, 1.402], + [1.0000, -(1.772 * 0.114 / 0.587), -(1.402 * 0.299 / 0.587)], + [1.0000, 1.772, 0.0000], + ], + device=mat.device, + ) + elif standard == "bt709": + # ITU-R BT.709 + # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion + m = torch.tensor( + [[1.0000, 0.0000, 1.5748], [1.0000, -0.1873, -0.4681], [1.0000, 1.8556, 0.0000]], device=mat.device + ) + else: + raise ValueError(f"{standard} not supported") + + if limited_color_range: + # also present in https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion + # being mentioned as compensation for the footroom and headroom + m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device) + + m = m.T + + # TODO: maybe this needs to come together with limited_color_range + offset = torch.tensor([16.0, 128.0, 128.0], device=mat.device) + + yuv2 = _yuv420_to_444(mat) + + res = (yuv2 - offset) @ m + return res + + @pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder") class TestVideoGPUDecoder: @pytest.mark.skipif(av is None, reason="PyAV unavailable") @@ -31,10 +81,13 @@ def test_frame_reading(self): decoder = VideoReader(full_path, device="cuda:0") with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): - av_frames = torch.tensor(av_frame.to_ndarray().flatten()) + av_frames_yuv = torch.tensor(av_frame.to_ndarray()) vision_frames = next(decoder)["data"] - mean_delta = torch.mean(torch.abs(av_frames.float() - decoder._reformat(vision_frames).float())) - assert mean_delta < 0.1 + av_frames = _yuv420_to_rgb(av_frames_yuv) + av_frames.clamp_(min=0, max=255) + av_frames = av_frames.round().to(torch.uint8) + mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) + assert mean_delta < 0.7 if __name__ == "__main__": diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index 4471fd6b783..b33c41c0995 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -1,5 +1,6 @@ #include "decoder.h" #include +#include #include #include #include @@ -138,38 +139,25 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) { } auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - torch::Tensor decoded_frame = torch::empty({get_frame_size()}, options); + torch::Tensor decoded_frame = + torch::empty({get_height(), get_width(), 3}, options); uint8_t* frame_ptr = decoded_frame.data_ptr(); + const uint8_t* const source_arr[] = { + (const uint8_t* const)source_frame, + (const uint8_t* const)(source_frame + source_pitch * ((surface_height + 1) & ~1))}; + + auto err = nppiNV12ToRGB_709CSC_8u_P2C3R( + source_arr, + source_pitch, + frame_ptr, + width * 3, + {(int)decoded_frame.size(1), (int)decoded_frame.size(0)}); + + TORCH_CHECK( + err == NPP_NO_ERROR, + "Failed to convert from NV12 to RGB. Error code:", + err); - // Copy luma plane - CUDA_MEMCPY2D m = {0}; - m.srcMemoryType = CU_MEMORYTYPE_DEVICE; - m.srcDevice = source_frame; - m.srcPitch = source_pitch; - m.dstMemoryType = CU_MEMORYTYPE_DEVICE; - m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr); - m.dstPitch = get_width() * bytes_per_pixel; - m.WidthInBytes = get_width() * bytes_per_pixel; - m.Height = luma_height; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); - - // Copy chroma plane - // NVDEC output has luma height aligned by 2. Adjust chroma offset by aligning - // height - m.srcDevice = - (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1)); - m.dstDevice = (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height); - m.Height = chroma_height; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); - - if (num_chroma_planes == 2) { - m.srcDevice = - (CUdeviceptr)((uint8_t*)source_frame + m.srcPitch * ((surface_height + 1) & ~1) * 2); - m.dstDevice = - (CUdeviceptr)(m.dstHost = frame_ptr + m.dstPitch * luma_height * 2); - m.Height = chroma_height; - check_for_cuda_errors(cuMemcpy2DAsync(&m, cuvidStream), __LINE__, __FILE__); - } check_for_cuda_errors(cuStreamSynchronize(cuvidStream), __LINE__, __FILE__); decoded_frames.push(decoded_frame); check_for_cuda_errors(cuCtxPopCurrent(NULL), __LINE__, __FILE__); diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp index e6255aab5aa..5ff8c4924a6 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp @@ -38,48 +38,8 @@ torch::Tensor GPUDecoder::decode() { return frame; } -/* Convert a tensor with data in NV12 format to a tensor with data in YUV420 - * format in-place. - */ -torch::Tensor GPUDecoder::nv12_to_yuv420(torch::Tensor frameTensor) { - int width = decoder.get_width(), height = decoder.get_height(); - int pitch = width; - uint8_t* frame = frameTensor.data_ptr(); - uint8_t* ptr = new uint8_t[((width + 1) / 2) * ((height + 1) / 2)]; - - // sizes of source surface plane - int sizePlaneY = pitch * height; - int sizePlaneU = ((pitch + 1) / 2) * ((height + 1) / 2); - int sizePlaneV = sizePlaneU; - - uint8_t* uv = frame + sizePlaneY; - uint8_t* u = uv; - uint8_t* v = uv + sizePlaneU; - - // split chroma from interleave to planar - for (int y = 0; y < (height + 1) / 2; y++) { - for (int x = 0; x < (width + 1) / 2; x++) { - u[y * ((pitch + 1) / 2) + x] = uv[y * pitch + x * 2]; - ptr[y * ((width + 1) / 2) + x] = uv[y * pitch + x * 2 + 1]; - } - } - if (pitch == width) { - memcpy(v, ptr, sizePlaneV * sizeof(uint8_t)); - } else { - for (int i = 0; i < (height + 1) / 2; i++) { - memcpy( - v + ((pitch + 1) / 2) * i, - ptr + ((width + 1) / 2) * i, - ((width + 1) / 2) * sizeof(uint8_t)); - } - } - delete[] ptr; - return frameTensor; -} - TORCH_LIBRARY(torchvision, m) { m.class_("GPUDecoder") .def(torch::init()) - .def("next", &GPUDecoder::decode) - .def("reformat", &GPUDecoder::nv12_to_yuv420); + .def("next", &GPUDecoder::decode); } diff --git a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h index 02b14fda99e..daea9fc46a2 100644 --- a/torchvision/csrc/io/decoder/gpu/gpu_decoder.h +++ b/torchvision/csrc/io/decoder/gpu/gpu_decoder.h @@ -8,7 +8,6 @@ class GPUDecoder : public torch::CustomClassHolder { GPUDecoder(std::string, int64_t); ~GPUDecoder(); torch::Tensor decode(); - torch::Tensor nv12_to_yuv420(torch::Tensor); private: Demuxer demuxer; diff --git a/torchvision/io/__init__.py b/torchvision/io/__init__.py index eed2a0ecba4..64815b0660f 100644 --- a/torchvision/io/__init__.py +++ b/torchvision/io/__init__.py @@ -210,16 +210,6 @@ def set_current_stream(self, stream: str) -> bool: print("GPU decoding only works with video stream.") return self._c.set_current_stream(stream) - def _reformat(self, tensor, output_format: str = "yuv420"): - supported_formats = [ - "yuv420", - ] - if output_format not in supported_formats: - raise RuntimeError(f"{output_format} not supported, please use one of {', '.join(supported_formats)}") - if not isinstance(tensor, torch.Tensor): - raise RuntimeError("Expected tensor as input parameter!") - return self._c.reformat(tensor.cpu()) - __all__ = [ "write_video", From d5ef8bc376378b2e646e0cdb611f18e5a935668b Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 12 Jan 2022 02:41:43 -0800 Subject: [PATCH 2/6] Move clamp to the conversion function --- test/test_video_gpu_decoder.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index 69decee9651..878a8f50837 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -67,9 +67,9 @@ def _yuv420_to_rgb(mat, limited_color_range=True, standard="bt709"): offset = torch.tensor([16.0, 128.0, 128.0], device=mat.device) yuv2 = _yuv420_to_444(mat) - res = (yuv2 - offset) @ m - return res + res.clamp_(min=0, max=255) + return res.round().to(torch.uint8) @pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder") @@ -84,8 +84,6 @@ def test_frame_reading(self): av_frames_yuv = torch.tensor(av_frame.to_ndarray()) vision_frames = next(decoder)["data"] av_frames = _yuv420_to_rgb(av_frames_yuv) - av_frames.clamp_(min=0, max=255) - av_frames = av_frames.round().to(torch.uint8) mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) assert mean_delta < 0.7 From 7a472d1564afe21dff8c28848f3f139b45ac7dc3 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Wed, 12 Jan 2022 02:48:35 -0800 Subject: [PATCH 3/6] Cleaned up a bit --- test/test_video_gpu_decoder.py | 37 +++++++--------------------------- 1 file changed, 7 insertions(+), 30 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index 878a8f50837..f4bc6029174 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -34,38 +34,15 @@ def _yuv420_to_444(mat): return yuv2 -def _yuv420_to_rgb(mat, limited_color_range=True, standard="bt709"): - # taken from https://en.wikipedia.org/wiki/YCbCr - if standard == "bt601": - # ITU-R BT.601, as used by decord - # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion - m = torch.tensor( - [ - [1.0000, 0.0000, 1.402], - [1.0000, -(1.772 * 0.114 / 0.587), -(1.402 * 0.299 / 0.587)], - [1.0000, 1.772, 0.0000], - ], - device=mat.device, - ) - elif standard == "bt709": - # ITU-R BT.709 - # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion - m = torch.tensor( - [[1.0000, 0.0000, 1.5748], [1.0000, -0.1873, -0.4681], [1.0000, 1.8556, 0.0000]], device=mat.device - ) - else: - raise ValueError(f"{standard} not supported") - - if limited_color_range: - # also present in https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.601_conversion - # being mentioned as compensation for the footroom and headroom - m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device) - +def _yuv420_to_rgb(mat): + # ITU-R BT.709 + # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion + m = torch.tensor( + [[1.0000, 0.0000, 1.5748], [1.0000, -0.1873, -0.4681], [1.0000, 1.8556, 0.0000]], device=mat.device + ) + m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device) m = m.T - - # TODO: maybe this needs to come together with limited_color_range offset = torch.tensor([16.0, 128.0, 128.0], device=mat.device) - yuv2 = _yuv420_to_444(mat) res = (yuv2 - offset) @ m res.clamp_(min=0, max=255) From 73ab184eb8576cfa83eb3b0ffca72b2ccfa31af6 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Mon, 17 Jan 2022 03:57:27 -0800 Subject: [PATCH 4/6] Remove utility functions from test --- test/test_video_gpu_decoder.py | 32 ++------------------------------ 1 file changed, 2 insertions(+), 30 deletions(-) diff --git a/test/test_video_gpu_decoder.py b/test/test_video_gpu_decoder.py index f4bc6029174..7f7c4cf2230 100644 --- a/test/test_video_gpu_decoder.py +++ b/test/test_video_gpu_decoder.py @@ -22,33 +22,6 @@ ] -def _yuv420_to_444(mat): - # logic taken from - # https://en.wikipedia.org/wiki/YUV#Y%E2%80%B2UV420p_(and_Y%E2%80%B2V12_or_YV12)_to_RGB888_conversion - width = mat.shape[-1] - height = mat.shape[0] * 2 // 3 - luma = mat[:height] - uv = mat[height:].reshape(2, height // 2, width // 2) - uv2 = torch.nn.functional.interpolate(uv[None], scale_factor=2, mode="nearest")[0] - yuv2 = torch.cat([luma[None], uv2]).permute(1, 2, 0) - return yuv2 - - -def _yuv420_to_rgb(mat): - # ITU-R BT.709 - # taken from https://en.wikipedia.org/wiki/YCbCr#ITU-R_BT.709_conversion - m = torch.tensor( - [[1.0000, 0.0000, 1.5748], [1.0000, -0.1873, -0.4681], [1.0000, 1.8556, 0.0000]], device=mat.device - ) - m = m * torch.tensor([255 / 219, 255 / 224, 255 / 224], device=mat.device) - m = m.T - offset = torch.tensor([16.0, 128.0, 128.0], device=mat.device) - yuv2 = _yuv420_to_444(mat) - res = (yuv2 - offset) @ m - res.clamp_(min=0, max=255) - return res.round().to(torch.uint8) - - @pytest.mark.skipif(_HAS_VIDEO_DECODER is False, reason="Didn't compile with support for gpu decoder") class TestVideoGPUDecoder: @pytest.mark.skipif(av is None, reason="PyAV unavailable") @@ -58,11 +31,10 @@ def test_frame_reading(self): decoder = VideoReader(full_path, device="cuda:0") with av.open(full_path) as container: for av_frame in container.decode(container.streams.video[0]): - av_frames_yuv = torch.tensor(av_frame.to_ndarray()) + av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray()) vision_frames = next(decoder)["data"] - av_frames = _yuv420_to_rgb(av_frames_yuv) mean_delta = torch.mean(torch.abs(av_frames.float() - vision_frames.cpu().float())) - assert mean_delta < 0.7 + assert mean_delta < 0.75 if __name__ == "__main__": From ff09aac7a4d837347eb28a3aaa75c8d640527d08 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 18 Jan 2022 04:25:08 -0800 Subject: [PATCH 5/6] Use data member width directly --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index b33c41c0995..e4be9de3b94 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -140,7 +140,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) { auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); torch::Tensor decoded_frame = - torch::empty({get_height(), get_width(), 3}, options); + torch::empty({get_height(), width, 3}, options); uint8_t* frame_ptr = decoded_frame.data_ptr(); const uint8_t* const source_arr[] = { (const uint8_t* const)source_frame, From eb63a8a2844b2618ff2bdc987864b7f9efc9a7a2 Mon Sep 17 00:00:00 2001 From: Prabhat Roy Date: Tue, 18 Jan 2022 05:57:53 -0800 Subject: [PATCH 6/6] Fix linter error --- torchvision/csrc/io/decoder/gpu/decoder.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/torchvision/csrc/io/decoder/gpu/decoder.cpp b/torchvision/csrc/io/decoder/gpu/decoder.cpp index e4be9de3b94..0e451298825 100644 --- a/torchvision/csrc/io/decoder/gpu/decoder.cpp +++ b/torchvision/csrc/io/decoder/gpu/decoder.cpp @@ -139,8 +139,7 @@ int Decoder::handle_picture_display(CUVIDPARSERDISPINFO* disp_info) { } auto options = torch::TensorOptions().dtype(torch::kU8).device(torch::kCUDA); - torch::Tensor decoded_frame = - torch::empty({get_height(), width, 3}, options); + torch::Tensor decoded_frame = torch::empty({get_height(), width, 3}, options); uint8_t* frame_ptr = decoded_frame.data_ptr(); const uint8_t* const source_arr[] = { (const uint8_t* const)source_frame,