Commit a052c44
Only have contiguous calls after attention blocks (#7763)
Towards #7227.

### Description

There were lots of `.contiguous()` calls scattered through the `DiffusionModelUNet`. It turns out these are only necessary after attention blocks, as the einops operations there sometimes produce non-contiguous tensors that can cause errors downstream. The code has been tidied up so that `.contiguous()` is called only on the outputs of attention blocks.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Mark Graham <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
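To make the failure mode concrete, here is a minimal sketch (illustrative, not taken from the MONAI source) of how the einops reshape back to image layout returns a non-contiguous view, and why a `.contiguous()` call right after the attention block fixes downstream errors:

```python
import torch
from einops import rearrange

b, c, h, w = 1, 8, 4, 4
x = torch.randn(b, c, h, w)

# Project to a sequence layout for attention: b c h w -> b (h w) c.
tokens = rearrange(x, "b c h w -> b (h w) c")

# ... standard transformer action would run here ...

# Reshape back to image layout. Under the hood this is a transpose, so the
# result is a view with permuted strides, not a contiguous tensor.
img = rearrange(tokens, "b (h w) c -> b c h w", h=h, w=w)
print(img.is_contiguous())  # False

# Stride-sensitive operations such as .view() raise a RuntimeError here:
try:
    img.view(b, -1)
except RuntimeError as err:
    print(f"view failed: {err}")

# .contiguous() copies the data into standard layout, after which the same
# call succeeds.
img = img.contiguous()
flat = img.view(b, -1)
print(img.is_contiguous(), flat.shape)  # True torch.Size([1, 128])
```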
1 parent c54bf3c commit a052c44

2 files changed, +10 −20 lines changed

monai/networks/nets/diffusion_model_unet.py

Lines changed: 7 additions & 14 deletions
```diff
@@ -115,10 +115,6 @@ def forward(self, x: torch.Tensor, context: torch.Tensor | None = None) -> torch
 
 class SpatialTransformer(nn.Module):
     """
-    NOTE This is a private block that we plan to merge with existing MONAI blocks in the future. Please do not make
-    use of this block as support is not guaranteed. For more information see:
-    https://github.com/Project-MONAI/MONAI/issues/7227
-
     Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
     standard transformer action. Finally, reshape to image.
 
```
```diff
@@ -396,14 +392,11 @@ def __init__(
         )
 
     def forward(self, x: torch.Tensor, emb: torch.Tensor) -> torch.Tensor:
-        h = x.contiguous()
+        h = x
         h = self.norm1(h)
         h = self.nonlinearity(h)
 
         if self.upsample is not None:
-            if h.shape[0] >= 64:
-                x = x.contiguous()
-                h = h.contiguous()
             x = self.upsample(x)
             h = self.upsample(h)
         elif self.downsample is not None:
```
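Dropping the defensive `.contiguous()` calls here is cheap to reason about because `.contiguous()` is a no-op on tensors that are already contiguous: PyTorch returns the input tensor itself and only copies when the layout actually needs normalising. A small sketch demonstrating this:

```python
import torch

x = torch.randn(2, 8, 16, 16)
assert x.is_contiguous()
assert x.contiguous() is x  # no copy for an already-contiguous tensor

y = x.permute(0, 2, 3, 1)   # non-contiguous view over the same storage
z = y.contiguous()          # copies the data into standard layout
assert z.is_contiguous() and z.data_ptr() != x.data_ptr()
```

Per the commit description, the only producers of non-contiguous tensors in this network are the attention blocks, so once their outputs are normalised at the source, the inputs reaching this upsample path are already contiguous and the per-call guard is redundant.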
```diff
@@ -609,7 +602,7 @@ def forward(
 
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
```
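The same one-line pattern repeats in every block below: normalise the layout once at the producer (the attention call) rather than guarding each consumer. A distilled version of the pattern (class and attribute names here are hypothetical, not MONAI's):

```python
import torch
from torch import nn

class ResnetAttnStage(nn.Module):
    """Hypothetical stage mirroring this commit's pattern: run a resnet block,
    then an attention block, and make the attention output contiguous before
    anything downstream (skip connections, downsamplers) consumes it."""

    def __init__(self, resnets: nn.ModuleList, attentions: nn.ModuleList):
        super().__init__()
        self.resnets = resnets
        self.attentions = attentions

    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor) -> list[torch.Tensor]:
        output_states = []
        for resnet, attn in zip(self.resnets, self.attentions):
            hidden_states = resnet(hidden_states, temb)
            # Attention may hand back a non-contiguous view; fix it here once.
            hidden_states = attn(hidden_states).contiguous()
            output_states.append(hidden_states)
        return output_states
```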
```diff
@@ -726,7 +719,7 @@ def forward(
 
         for resnet, attn in zip(self.resnets, self.attentions):
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=context)
+            hidden_states = attn(hidden_states, context=context).contiguous()
             output_states.append(hidden_states)
 
         if self.downsampler is not None:
```
```diff
@@ -790,7 +783,7 @@ def forward(
     ) -> torch.Tensor:
         del context
         hidden_states = self.resnet_1(hidden_states, temb)
-        hidden_states = self.attention(hidden_states)
+        hidden_states = self.attention(hidden_states).contiguous()
         hidden_states = self.resnet_2(hidden_states, temb)
 
         return hidden_states
```
```diff
@@ -1091,7 +1084,7 @@ def forward(
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
             hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
```
```diff
@@ -1669,7 +1662,7 @@ def forward(
         down_block_res_samples = new_down_block_res_samples
 
         # 5. mid
-        h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
 
         # Additional residual connections for Controlnets
         if mid_block_additional_residual is not None:
```
```diff
@@ -1682,7 +1675,7 @@ def forward(
             h = upsample_block(hidden_states=h, res_hidden_states_list=res_samples, temb=emb, context=context)
 
         # 7. output block
-        output: torch.Tensor = self.out(h.contiguous())
+        output: torch.Tensor = self.out(h)
 
         return output
```
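A quick way to sanity-check the refactor end to end is a forward pass through the network. A hedged smoke-test sketch follows; the constructor arguments are assumptions based on the MONAI docs and may need adjusting for your installed MONAI version:

```python
import torch
from monai.networks.nets import DiffusionModelUNet

# Assumed argument names; verify against your MONAI version's docstrings.
model = DiffusionModelUNet(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_res_blocks=1,
    num_head_channels=32,
)

x = torch.randn(2, 1, 32, 32)
timesteps = torch.randint(0, 1000, (2,)).long()
out = model(x, timesteps=timesteps)  # should run without stride errors
print(out.shape)  # expected: torch.Size([2, 1, 32, 32])
```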

monai/networks/nets/spade_diffusion_model_unet.py

Lines changed: 3 additions & 6 deletions
```diff
@@ -170,9 +170,6 @@ def forward(self, x: torch.Tensor, emb: torch.Tensor, seg: torch.Tensor) -> torc
         h = self.nonlinearity(h)
 
         if self.upsample is not None:
-            if h.shape[0] >= 64:
-                x = x.contiguous()
-                h = h.contiguous()
             x = self.upsample(x)
             h = self.upsample(h)
         elif self.downsample is not None:
```
```diff
@@ -430,7 +427,7 @@ def forward(
             res_hidden_states_list = res_hidden_states_list[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             hidden_states = resnet(hidden_states, temb, seg)
-            hidden_states = attn(hidden_states)
+            hidden_states = attn(hidden_states).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
```
```diff
@@ -568,7 +565,7 @@ def forward(
             res_hidden_states_list = res_hidden_states_list[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
             hidden_states = resnet(hidden_states, temb, seg)
-            hidden_states = attn(hidden_states, context=context)
+            hidden_states = attn(hidden_states, context=context).contiguous()
 
         if self.upsampler is not None:
             hidden_states = self.upsampler(hidden_states, temb)
```
```diff
@@ -919,7 +916,7 @@ def forward(
         down_block_res_samples = new_down_block_res_samples
 
         # 5. mid
-        h = self.middle_block(hidden_states=h.contiguous(), temb=emb, context=context)
+        h = self.middle_block(hidden_states=h, temb=emb, context=context)
 
         # Additional residual connections for Controlnets
         if mid_block_additional_residual is not None:
```
