Skip to content

Only have contiguous calls after attention blocks #7763

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 7 additions & 14 deletions monai/networks/nets/diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch

class SpatialTransformer(nn.Module):
"""
NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
use of this block as support is not guaranteed. For more information see:
https://github.com/Project-MONAI/MONAI/issues/7227

Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.

Expand Down Expand Up @@ -396,14 +392,11 @@ def __init__(
)

def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
h = x.contiguous()
h = x
h = self.norm1(h)
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
Expand Down Expand Up @@ -609,7 +602,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
Expand Down Expand Up @@ -726,7 +719,7 @@ def forward(

for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()
output_states.append(hidden_states)

if self.downsampler is not None:
Expand Down Expand Up @@ -790,7 +783,7 @@ def forward(
) -> torch.Tensor:
del context
hidden_states = self.resnet_1(hidden_states, temb)
hidden_states = self.attention(hidden_states)
hidden_states = self.attention(hidden_states).contiguous()
hidden_states = self.resnet_2(hidden_states, temb)

return hidden_states
Expand Down Expand Up @@ -1091,7 +1084,7 @@ def forward(
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -1669,7 +1662,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual connections for Controlnets
if mid_block_additional_residual is not None:
Expand All @@ -1682,7 +1675,7 @@ def forward(
h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)

# 7. output block
output: torch.Tensor = self.out(h.contiguous())
output: torch.Tensor = self.out(h)

return output

Expand Down
9 changes: 3 additions & 6 deletions monai/networks/nets/spade_diffusion_model_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc
h = self.nonlinearity(h)

if self.upsample is not None:
if h.shape[0] >= 64:
x = x.contiguous()
h = h.contiguous()
x = self.upsample(x)
h = self.upsample(h)
elif self.downsample is not None:
Expand Down Expand Up @@ -430,7 +427,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states)
hidden_states = attn(hidden_states).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -568,7 +565,7 @@ def forward(
res_hidden_states_list = res_hidden_states_list[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb, seg)
hidden_states = attn(hidden_states, context=context)
hidden_states = attn(hidden_states, context=context).contiguous()

if self.upsampler is not None:
hidden_states = self.upsampler(hidden_states, temb)
Expand Down Expand Up @@ -919,7 +916,7 @@ def forward(
down_block_res_samples = new_down_block_res_samples

# 5. mid
h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
h = self.middle_block(hidden_states=h, temb=emb, context=context)

# Additional residual connections for Controlnets
if mid_block_additional_residual is not None:
Expand Down
Loading