
✨[Feature] Torch stream handling for graph break #3977

@cehongwang

Description


Overview

The current way of handling the CUDA stream is that every time we execute an engine, we run it on a separate stream and perform stream synchronization on every run. If there is a graph break, this synchronization happens in every submodule, which is unnecessary. We are devising a new way to handle the CUDA stream.
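For illustration, a minimal sketch of today's per-engine behavior (the helper below is hypothetical, not the actual runtime code); with a graph break, every TRT submodule repeats this wait/synchronize pair even though one pair around the whole region would suffice:

import torch

def run_engine_today(engine, x):
    # every engine execution gets its own stream...
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        out = engine(x)  # stand-in for torch.ops.tensorrt.execute_engine
    # ...and synchronizes back with the caller's stream on every run
    torch.cuda.current_stream().wait_stream(stream)
    return out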

We register two operators with Torch that guard the stream:

  • torch.ops.tensorrt.enter_compute_stream_guard

    • initialize a new stream if needed
    • wait on the default stream if needed
    • set the current stream to this stream
  • torch.ops.tensorrt.exit_compute_stream_guard

    • synchronize with the default stream
    • set the stream back to the default Torch stream

Detailed Explanation

1. Current stream is the default stream

compiled_gm(x)

In this situation, we have to create a separate compute stream so that the TRT compute work does not run on the default Torch stream.

The graph should do the following:

compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        stream = torch.cuda.Stream()
        stream.wait_stream(torch.cuda.default_stream())
        torch.cuda.set_stream(stream)
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        stream = torch.cuda.current_stream()
        torch.cuda.default_stream().wait_stream(stream)
        torch.cuda.set_stream(torch.cuda.default_stream())
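For reference, the same sequence written as plain eager PyTorch (a minimal sketch of the intended semantics, not the guard implementation itself):

import torch

x = torch.randn(8, device="cuda")

# enter: hop off the default stream, ordered after it
compute_stream = torch.cuda.Stream()
compute_stream.wait_stream(torch.cuda.default_stream())
torch.cuda.set_stream(compute_stream)

y = x * 2  # stand-in for torch.ops.tensorrt.execute_engine

# exit: the default stream waits on the compute stream, then becomes current again
torch.cuda.default_stream().wait_stream(compute_stream)
torch.cuda.set_stream(torch.cuda.default_stream())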

2. Current stream is not the default stream

s1 = torch.cuda.Stream()
with torch.cuda.stream(s1):
    compiled_gm(x)

In this situation, we can run the engine on the current stream, and no extra stream manipulation is needed:

compiled_gm(x):
    torch.ops.tensorrt.enter_compute_stream_guard:
        Do Nothing
    torch.ops.tensorrt.execute_engine
    torch.ops.tensorrt.exit_compute_stream_guard:
        Do Nothing

3. Input tensors come from different streams

In this situation, the consuming stream has to wait on both producer streams before the engine runs:

s1, s2 = torch.cuda.Stream(), torch.cuda.Stream()
with torch.cuda.stream(s1):
    j = a(i)
with torch.cuda.stream(s2):
    k = b(i)

# the current (consuming) stream waits on both producers
torch.cuda.current_stream().wait_stream(s1)
torch.cuda.current_stream().wait_stream(s2)
w = graph(j, k)

Implementation details

Graph Modification

The original graph looks like:

compiled_gm(x):
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    return %2 

The graph after inserting the ops looks like:

   
compiled_gm(x):
    %0: List[tensor] = prim::packlist(x)
    %1: bool = torch.ops.tensorrt.enter_compute_stream_guard(%0)
    %2: tensor = torch.ops.tensorrt.execute_engine(x)
    %3: List[tensor] = torch.ops.tensorrt.exit_compute_stream_guard(%2, %1)
    %4: tensor = prim::unpacklist(%3)
    return %4

graph(List[tensor]) -> List[tensor]
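One plausible way to implement this rewrite as a torch.fx pass (a sketch that assumes the guard ops are registered under the names above; the pass itself is hypothetical):

import torch
from torch.fx import GraphModule

def insert_stream_guards(gm: GraphModule) -> GraphModule:
    # wrap every execute_engine call with the enter/exit guard pair
    for node in list(gm.graph.nodes):
        if node.op == "call_function" and node.target is torch.ops.tensorrt.execute_engine:
            with gm.graph.inserting_before(node):
                enter = gm.graph.call_function(
                    torch.ops.tensorrt.enter_compute_stream_guard,
                    args=(list(node.args),),
                )
            with gm.graph.inserting_after(node):
                exit_guard = gm.graph.call_function(
                    torch.ops.tensorrt.exit_compute_stream_guard,
                    args=(node, enter),
                )
            # route downstream users through the exit guard, but keep the
            # exit guard itself consuming the raw engine output
            node.replace_all_uses_with(
                exit_guard, delete_user_cb=lambda user: user is not exit_guard
            )
    gm.graph.lint()
    gm.recompile()
    return gm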

Ops registration

The ops are registered in the C++ runtime by default. If TORCHTRT_RUNTIME is not available (Python-only build), we fall back to Python registration.

The implementation looks like:

# C++
# Similar to the Python code below, in C++ syntax
# Python
from typing import List

import torch
from torch_tensorrt._features import EnabledFeatures

if not EnabledFeatures.TORCHTRT_RUNTIME:
    torch.library.define(
        "tensorrt::enter_compute_stream_guard", "(Tensor[] x) -> bool"
    )
    torch.library.define(
        "tensorrt::exit_compute_stream_guard",
        "(Tensor[] x, bool return_to_default) -> Tensor[]",
    )

    @torch.library.impl("tensorrt::enter_compute_stream_guard", "CUDA")
    def enter_compute_stream_guard(x: List[torch.Tensor]) -> bool:
        # switch to a fresh compute stream only if we are on the default stream
        stream = torch.cuda.current_stream()
        if stream == torch.cuda.default_stream():
            new_stream = torch.cuda.Stream()
            new_stream.wait_stream(torch.cuda.default_stream())
            torch.cuda.set_stream(new_stream)
            return True
        return False

    @torch.library.impl("tensorrt::exit_compute_stream_guard", "CUDA")
    def exit_compute_stream_guard(x: List[torch.Tensor], return_to_default: bool) -> List[torch.Tensor]:
        # synchronize the default stream with the compute stream, then restore it
        if return_to_default:
            stream = torch.cuda.current_stream()
            torch.cuda.default_stream().wait_stream(stream)
            torch.cuda.set_stream(torch.cuda.default_stream())
        return x

@torch.library.register_fake("tensorrt::enter_compute_stream_guard")
def enter_compute_stream_guard_fake(x):
    ...

@torch.library.register_fake("tensorrt::exit_compute_stream_guard")
def exit_compute_stream_guard_fake(x, return_to_default):
    ...
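A quick eager-mode smoke test of the Python registration might look like this (hypothetical, assuming the op names above and a process that starts on the default stream):

x = [torch.randn(4, device="cuda")]

assert torch.cuda.current_stream() == torch.cuda.default_stream()
entered = torch.ops.tensorrt.enter_compute_stream_guard(x)
# case 1: we were on the default stream, so the guard switched streams
assert entered and torch.cuda.current_stream() != torch.cuda.default_stream()

torch.ops.tensorrt.exit_compute_stream_guard(x, entered)
assert torch.cuda.current_stream() == torch.cuda.default_stream()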
