Description
Overview
The current way of handling CUDA streams is that every time we execute an engine, we run it on a separate stream and perform stream synchronization on every run. If there is a graph break, this synchronization happens in every submodule, which is unnecessary. We are devising a new way to handle the CUDA stream.
We register two operators in torch that guard the stream:
- torch.ops.tensorrt.enter_compute_stream_guard
  - initialize a new stream if needed
  - wait on the default stream if needed
  - set the current stream to this stream
- torch.ops.tensorrt.exit_compute_stream_guard
  - sync with the main stream
  - set the stream back to the default torch stream
Detailed Explanation
1. Current stream is the default stream
compiled_gm(x)
In this situation, we have to create a separate compute stream so that the TRT compute stream is not the default torch stream.
We should modify the graph to do the following:
compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.default_stream())
        torch.cuda.set_stream(stream)
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        stream = torch.cuda.current_stream()
        torch.cuda.default_stream().wait_stream(stream)
        torch.cuda.set_stream(torch.cuda.default_stream())
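For reference, the same sequence can be written directly with the eager CUDA stream API. This is only an illustrative sketch: run_on_side_stream and the torch.relu call are placeholders, not part of the proposal.

import torch

def run_on_side_stream(fn, *args):
    # Enter guard, in eager terms: create a side stream that waits for work
    # already queued on the default stream, then run on it.
    default = torch.cuda.default_stream()
    side = torch.cuda.Stream()
    side.wait_stream(default)
    with torch.cuda.stream(side):
        out = fn(*args)  # the TRT engine would execute here
    # Exit guard, in eager terms: the default stream waits for the side stream
    # before the results are consumed downstream.
    default.wait_stream(side)
    return out

x = torch.randn(32, 64, device="cuda")
y = run_on_side_stream(torch.relu, x)  # torch.relu stands in for the compiled engine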
2. Current stream is not the default stream
with torch.cuda.Stream() as s1:
    compiled_gm(x)
In this situation, we can do the operation on the current stream and no other manipulation is needed:
compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        Do Nothing
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        Do Nothing
3. Input tensors come from different streams
In this situation, the stream that runs the graph must first wait on both producer streams so that j and k are ready before the engine consumes them:
with torch.cuda.Stream() as s1:
    j = a(i)
with torch.cuda.Stream() as s2:
    k = b(i)
wait_stream(s1, s2)
w = graph(j, k)
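A runnable version of this pattern with the standard stream API might look like the sketch below; torch.relu and torch.sigmoid stand in for a(i) and b(i), and the final add stands in for graph(j, k).

import torch

i = torch.randn(16, 16, device="cuda")
s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()

# Each producer stream first waits until i is ready on the current stream.
s1.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s1):
    j = torch.relu(i)

s2.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s2):
    k = torch.sigmoid(i)

# The stream that runs the graph must wait on both producer streams.
torch.cuda.current_stream().wait_stream(s1)
torch.cuda.current_stream().wait_stream(s2)
w = j + k  # stand-in for graph(j, k)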
Implementation detail
Graph Modification
The original graph looks like:
compiled_gm(x):
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    return %2
The graph after inserting the ops looks like:
compiled_gm(x):
    %0: List[tensor] = prim::packlist(x)
    %1: bool = torch.ops.tensorrt.enter_compute_stream_guard(%0)
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    %3: List[tensor] = torch.ops.tensorrt.exit_compute_stream_guard(%2, %1)
    %4: tensor = prim::unpacklist(%3)
    return %4
graph(List[tensor]) -> List[tensor]
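One possible way to perform this rewrite is a small torch.fx pass over the compiled GraphModule. The sketch below is an assumption about how such a pass could look: it presumes the guard ops are already registered, matches execute_engine nodes by target name, and omits the packlist/unpacklist handling shown in the IR above.

import torch
from torch.fx import GraphModule

def insert_stream_guards(gm: GraphModule) -> GraphModule:
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and "tensorrt.execute_engine" in str(node.target):
            # Enter guard right before the engine call, fed with the engine inputs.
            with gm.graph.inserting_before(node):
                enter = gm.graph.call_function(
                    torch.ops.tensorrt.enter_compute_stream_guard,
                    args=(list(node.args),),
                )
            # Exit guard right after the engine call, consuming its output and the flag.
            with gm.graph.inserting_after(node):
                exit_guard = gm.graph.call_function(
                    torch.ops.tensorrt.exit_compute_stream_guard,
                    args=(node, enter),
                )
            # Rewire all original consumers of the engine output to the exit guard,
            # while the exit guard itself keeps reading from the engine node.
            node.replace_all_uses_with(
                exit_guard, delete_user_cb=lambda user: user is not exit_guard
            )
    gm.graph.lint()
    gm.recompile()
    return gm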
Ops registration
The ops are registered in the C++ runtime by default. But if TORCHTRT_RUNTIME is not available (Python-only build), we switch to Python registration.
The implementation looks like:
# C++
# Similar to the Python code below, written in C++ syntax.
# Python
from typing import List

import torch
from torch_tensorrt._features import EnabledFeatures

if not EnabledFeatures.TORCHTRT_RUNTIME:
    # Sketch using torch.library.define/impl; the final registration API may differ.
    torch.library.define("tensorrt::enter_compute_stream_guard", "(Tensor[] x) -> bool")

    @torch.library.impl("tensorrt::enter_compute_stream_guard", "cuda")
    def enter_compute_stream_guard(x: List[torch.Tensor]) -> bool:
        # If we are on the default stream, switch to a dedicated compute stream.
        stream = torch.cuda.current_stream()
        if stream == torch.cuda.default_stream():
            new_stream = torch.cuda.Stream()
            new_stream.wait_stream(torch.cuda.default_stream())
            torch.cuda.set_stream(new_stream)
            return True
        return False

    ...

    torch.library.define(
        "tensorrt::exit_compute_stream_guard",
        "(Tensor[] x, bool return_to_default) -> Tensor[]",
    )

    @torch.library.impl("tensorrt::exit_compute_stream_guard", "cuda")
    def exit_compute_stream_guard(x: List[torch.Tensor], return_to_default: bool) -> List[torch.Tensor]:
        if return_to_default:
            # Sync the default stream with the compute stream, then switch back to it.
            torch.cuda.default_stream().wait_stream(torch.cuda.current_stream())
            torch.cuda.set_stream(torch.cuda.default_stream())
        return x

    @torch.library.register_fake("tensorrt::enter_compute_stream_guard")
    def enter_compute_stream_guard_fake(x):
        ...

    @torch.library.register_fake("tensorrt::exit_compute_stream_guard")
    def exit_compute_stream_guard_fake(x, return_to_default):
        ...
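Once the Python registrations above are in place (and a CUDA device is available), the ops can be smoke-tested directly in eager mode. This snippet is only a sanity check of the sketch, not part of the design.

import torch

x = [torch.randn(8, 8, device="cuda")]

# We start on the default stream, so the enter guard should switch us off it.
switched = torch.ops.tensorrt.enter_compute_stream_guard(x)
assert switched
assert torch.cuda.current_stream() != torch.cuda.default_stream()

y = [t * 2 for t in x]  # stand-in for engine execution on the compute stream

out = torch.ops.tensorrt.exit_compute_stream_guard(y, switched)
# The exit guard syncs with the default stream and switches back to it.
assert torch.cuda.current_stream() == torch.cuda.default_stream()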