
Commit 5782756

Refactor device selection (#864)
* added new functions for determining the best available device
* added device test and integrated new function into device selection
* fixed reference
* made sure embeds are still on the same device
* added log point
* added more log points
* fixed typo
* set W_E to be on the same device
* set RMS norm to correct device
* set device to grouped query attention
* Revert "set device to grouped query attention" (this reverts commit 788d355)
* Revert "set rms norm to correct device" (this reverts commit e7018c4)
* Revert "set W_E to be on the same device" (this reverts commit 33f4436)
* added debug points
* reverted most calls to new function
* reverted device list
* reverted block loop
* reverted cache call
* updated move model function to use calculations
* fixed remaining device identification issues
* restored if
* made sure RMS norm is on the same device before calculations
* added device check before linear attn
* checked b_
* moved device selection
* moved rotary to device
* changed device move
* reassigned prop
* made sure all abstract attention tensors are on the same device
* removed assignment
* updated prop setting
* put resid pre on device
* removed some log points
* ran format
* resolved test
* removed reassignment
* changed device selection point
* ensured gated MLP items are on the same device
* changed device direction
1 parent 53dee84 commit 5782756

File tree

9 files changed: +195 additions, -39 deletions


tests/acceptance/test_multi_gpu.py

Lines changed: 2 additions & 32 deletions
@@ -4,7 +4,7 @@
 import torch

 from transformer_lens.HookedTransformer import HookedTransformer
-from transformer_lens.utilities.devices import get_device_for_block_index
+from transformer_lens.utilities.devices import get_best_available_device


 @pytest.fixture
@@ -19,36 +19,6 @@ def gpt2_medium_on_4_devices():
     return model


-@pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
-def test_get_device_for_block_index(gpt2_medium_on_4_devices):
-    config = gpt2_medium_on_4_devices.cfg
-    n_layers = config.n_layers
-    n_devices = config.n_devices
-    layers_per_device = n_layers // n_devices
-    config_device = torch.device(config.device)
-
-    # Test with default device (config.device)
-    for i in range(n_layers):
-        expected_device = torch.device(config_device.type, i // layers_per_device)
-        assert get_device_for_block_index(i, config) == expected_device
-
-    # Test with explicit device
-    device_override = "cuda"
-    for i in range(n_layers):
-        expected_device = torch.device(device_override, i // layers_per_device)
-        assert get_device_for_block_index(i, config, device_override) == expected_device
-
-    # Test with explicit torch.device object
-    device_override_obj = torch.device("cuda")
-    for i in range(n_layers):
-        expected_device = torch.device(device_override_obj.type, i // layers_per_device)
-        assert get_device_for_block_index(i, config, device_override_obj) == expected_device
-
-    # Test when index is out of bounds
-    # with pytest.raises(IndexError):
-    #     get_device_for_block_index(n_layers, config)
-
-
 @pytest.mark.skipif(torch.cuda.device_count() < 4, reason="Requires at least 4 CUDA devices")
 @pytest.mark.parametrize("n_devices", [1, 2, 3, 4])
 def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):
@@ -85,7 +55,7 @@ def test_device_separation_and_cache(gpt2_medium_on_1_device, n_devices):

     # Make sure the tensors in cache remain on their respective devices
     for i in range(model_n_devices.cfg.n_layers):
-        expected_device = get_device_for_block_index(i, cfg=model_n_devices.cfg)
+        expected_device = get_best_available_device(model_n_devices.cfg.device)
         cache_device = gpt2_cache_n_devices[f"blocks.{i}.mlp.hook_post"].device
         assert cache_device == expected_device

tests/unit/utilities/test_devices.py

Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
+from unittest.mock import Mock
+
+import torch
+
+from transformer_lens.utilities.devices import (
+    calculate_available_device_cuda_memory,
+    determine_available_memory_for_available_devices,
+    sort_devices_based_on_available_memory,
+)
+
+
+def mock_available_devices(memory_stats: list[tuple[int, int]]):
+    torch.cuda.device_count = Mock(return_value=len(memory_stats))
+
+    def device_props_return(*args, **kwargs):
+        total_memory = memory_stats[args[0]][0]
+        device_props = Mock()
+        device_props.total_memory = total_memory
+        return device_props
+
+    def memory_allocated_return(*args, **kwargs):
+        return memory_stats[args[0]][1]
+
+    torch.cuda.get_device_properties = Mock(side_effect=device_props_return)
+    torch.cuda.memory_allocated = Mock(side_effect=memory_allocated_return)
+
+
+def test_calculate_available_device_cuda_memory():
+    mock_available_devices([(80, 40)])
+
+    result = calculate_available_device_cuda_memory(0)
+    assert result == 40
+
+
+def test_determine_available_memory_for_available_devices():
+    mock_available_devices(
+        [
+            (80, 60),
+            (80, 15),
+            (80, 40),
+        ]
+    )
+
+    result = determine_available_memory_for_available_devices(3)
+
+    assert result == [
+        (0, 20),
+        (1, 65),
+        (2, 40),
+    ]
+
+
+def test_sort_devices_based_on_available_memory():
+    devices = [
+        (0, 20),
+        (1, 65),
+        (2, 40),
+    ]
+
+    result = sort_devices_based_on_available_memory(devices)
+
+    assert result == [
+        (1, 65),
+        (2, 40),
+        (0, 20),
+    ]

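The mocked stats above are (total_memory, memory_allocated) pairs, and available memory is simply their difference. A minimal standalone sketch of that arithmetic (plain Python, no torch required; the variable names are illustrative):

# (total, allocated) pairs matching the mocked fixtures above
memory_stats = [(80, 60), (80, 15), (80, 40)]

# Available memory per device is total minus allocated, keyed by device index.
available = [(i, total - allocated) for i, (total, allocated) in enumerate(memory_stats)]
assert available == [(0, 20), (1, 65), (2, 40)]

# Ranking by available memory (descending) puts device 1 first, as the sort test asserts.
assert sorted(available, key=lambda pair: pair[1], reverse=True)[0] == (1, 65)
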
transformer_lens/HookedTransformer.py

Lines changed: 7 additions & 7 deletions
@@ -1091,17 +1091,17 @@ def mps(self):
         return self.to("mps")

     def move_model_modules_to_device(self):
-        self.embed.to(devices.get_device_for_block_index(0, self.cfg))
-        self.hook_embed.to(devices.get_device_for_block_index(0, self.cfg))
+        self.embed.to(devices.get_best_available_device(self.cfg))
+        self.hook_embed.to(devices.get_best_available_device(self.cfg))
         if self.cfg.positional_embedding_type != "rotary":
-            self.pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
-            self.hook_pos_embed.to(devices.get_device_for_block_index(0, self.cfg))
+            self.pos_embed.to(devices.get_best_available_device(self.cfg))
+            self.hook_pos_embed.to(devices.get_best_available_device(self.cfg))

         if hasattr(self, "ln_final"):
-            self.ln_final.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
-        self.unembed.to(devices.get_device_for_block_index(self.cfg.n_layers - 1, self.cfg))
+            self.ln_final.to(devices.get_best_available_device(self.cfg))
+        self.unembed.to(devices.get_best_available_device(self.cfg))
         for i, block in enumerate(self.blocks):
-            block.to(devices.get_device_for_block_index(i, self.cfg))
+            block.to(devices.get_best_available_device(self.cfg))

     @classmethod
     def from_pretrained(

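For context, the change above means module placement is no longer derived from the layer index; every module goes to whichever CUDA device currently reports the most free memory. A hedged usage sketch (the model name and keyword arguments are illustrative assumptions, not taken from this commit):

import torch
from transformer_lens import HookedTransformer

if torch.cuda.is_available():
    # Placement of embeddings, blocks, and unembed now flows through get_best_available_device.
    model = HookedTransformer.from_pretrained("gpt2-medium", n_devices=torch.cuda.device_count())
    # Inspect where a few parameters landed; placement depends on free memory at load time.
    for name, param in list(model.named_parameters())[:3]:
        print(name, param.device)
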
transformer_lens/components/abstract_attention.py

Lines changed: 10 additions & 0 deletions
@@ -279,6 +279,12 @@ def forward(
             w = einops.rearrange(
                 self.W_O, "head_index d_head d_model -> d_model (head_index d_head)"
             )
+
+            if self.b_O.device != w.device:
+                w = w.to(self.b_O.device)
+            if self.b_O.device != z.device:
+                z = z.to(self.b_O.device)
+
             out = F.linear(
                 z.reshape(z.shape[0], z.shape[1], self.cfg.d_head * self.cfg.n_heads),
                 w,
@@ -552,6 +558,10 @@ def apply_rotary(
         attention_mask: Optional[Int[torch.Tensor, "batch offset_pos"]] = None,
     ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
         # Only apply rotary to first rotary_dim dimensions (eg, if rotary_dim=64 and d_head=256, only apply to first 1/4 of dimensions)
+
+        if x.device != self.rotary_sin.device:
+            x = x.to(self.rotary_sin.device)
+
         x_pos = x.size(1)
         x_rot = x[..., : self.cfg.rotary_dim]
         x_pass = x[..., self.cfg.rotary_dim :]

transformer_lens/components/mlps/gated_mlp.py

Lines changed: 2 additions & 0 deletions
@@ -50,6 +50,8 @@ def forward(
         self, x: Float[torch.Tensor, "batch pos d_model"]
     ) -> Float[torch.Tensor, "batch pos d_model"]:
         # Technically, all these einsums could be done with a single matmul, but this is more readable.
+        if self.W_gate.device != x.device:
+            x = x.to(self.W_gate.device)
         pre_act = self.hook_pre(
             torch.matmul(x, self.W_gate)  # batch pos d_model, d_model d_mlp -> batch pos d_mlp
         )  # [batch, pos, d_mlp]

transformer_lens/components/rms_norm.py

Lines changed: 4 additions & 0 deletions
@@ -42,4 +42,8 @@ def forward(
             (x.pow(2).mean(-1, keepdim=True) + self.eps).sqrt()
         )
         x = self.hook_normalized(x / scale).to(self.cfg.dtype)  # [batch, pos, length]
+
+        if x.device != self.w.device:
+            self.to(x.device)
+
         return x * self.w

transformer_lens/components/transformer_block.py

Lines changed: 4 additions & 0 deletions
@@ -173,6 +173,10 @@ def forward(
             # is added to the residual stream"
             attn_out = self.ln1_post(attn_out)
         attn_out = self.hook_attn_out(attn_out)
+
+        if resid_pre.device != attn_out.device:
+            resid_pre = resid_pre.to(attn_out.device)
+
         if not self.cfg.attn_only and not self.cfg.parallel_attn_mlp:
             resid_mid = self.hook_resid_mid(resid_pre + attn_out)  # [batch, pos, d_model]
             mlp_in = (

transformer_lens/utilities/attention.py

Lines changed: 7 additions & 0 deletions
@@ -15,8 +15,15 @@ def simple_attn_linear(
     b: Float[torch.Tensor, "head_index d_head"],
 ) -> Float[torch.Tensor, "batch pos head_index d_head"]:
     """Linear layer for attention calculation."""
+
+    if input.device != w.device:
+        w = w.to(input.device)
+    if input.device != b.device:
+        b = b.to(input.device)
+
     w = einops.rearrange(w, "head_index d_model d_head -> (head_index d_head) d_model")
     b_ = einops.rearrange(b, "head_index d_head -> (head_index d_head)")
+
     return F.linear(input, w, b_).reshape(input.shape[0], input.shape[1], b.shape[0], b.shape[1])

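The hunk above applies the same device-alignment pattern used throughout this commit: move operands onto a common device before the matmul so that modules split across GPUs do not trip cross-device errors. A minimal standalone illustration of the pattern (the helper name is hypothetical, not part of the library):

import torch
import torch.nn.functional as F


def aligned_linear(x: torch.Tensor, w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    # Move the weight and bias onto the input's device if they live elsewhere.
    if w.device != x.device:
        w = w.to(x.device)
    if b.device != x.device:
        b = b.to(x.device)
    return F.linear(x, w, b)


# Runs regardless of where the parameters were originally placed.
out = aligned_linear(torch.randn(2, 4), torch.randn(3, 4), torch.randn(3))
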
transformer_lens/utilities/devices.py

Lines changed: 93 additions & 0 deletions
@@ -13,6 +13,93 @@

 import transformer_lens

+AvailableDeviceMemory = list[tuple[int, int]]
+"""
+This type is passed around between different CUDA memory operations.
+The first entry of each tuple will be the device index.
+The second entry will be how much memory is currently available.
+"""
+
+
+def calculate_available_device_cuda_memory(i: int) -> int:
+    """Calculates how much memory is available at this moment for the device at the indicated index.
+
+    Args:
+        i (int): The index of the device we are looking at
+
+    Returns:
+        int: How much memory is available
+    """
+    total = torch.cuda.get_device_properties(i).total_memory
+    allocated = torch.cuda.memory_allocated(i)
+    return total - allocated
+
+
+def determine_available_memory_for_available_devices(max_devices: int) -> AvailableDeviceMemory:
+    """Gets all available CUDA devices with their current memory calculated.
+
+    Returns:
+        AvailableDeviceMemory: The list of all available devices with memory precalculated
+    """
+    devices = []
+    for i in range(max_devices):
+        devices.append((i, calculate_available_device_cuda_memory(i)))
+
+    return devices
+
+
+def sort_devices_based_on_available_memory(devices: AvailableDeviceMemory) -> AvailableDeviceMemory:
+    """Sorts all available devices so that those with the most available memory come first.
+
+    Args:
+        devices (AvailableDeviceMemory): All available devices with memory calculated
+
+    Returns:
+        AvailableDeviceMemory: The same list of devices, sorted with the devices that have the
+        most available memory first
+    """
+    return sorted(devices, key=lambda x: x[1], reverse=True)
+
+
+def get_best_available_cuda_device(max_devices: Optional[int] = None) -> torch.device:
+    """Gets whichever CUDA device currently has the most available memory.
+
+    Raises:
+        EnvironmentError: If there are no available devices, this will error out
+
+    Returns:
+        torch.device: The specific device that should be used
+    """
+    max_devices = max_devices if max_devices is not None else torch.cuda.device_count()
+    devices = determine_available_memory_for_available_devices(max_devices)
+
+    if len(devices) <= 0:
+        raise EnvironmentError(
+            "TransformerLens has been configured to use CUDA, but no available devices are present"
+        )
+
+    sorted_devices = sort_devices_based_on_available_memory(devices=devices)
+
+    return torch.device("cuda", sorted_devices[0][0])
+
+
+def get_best_available_device(cfg: "transformer_lens.HookedTransformerConfig") -> torch.device:
+    """Gets the best available device to be used based on the passed in configuration.
+
+    Args:
+        cfg (HookedTransformerConfig): The model configuration, whose device field identifies the device type
+
+    Returns:
+        torch.device: The best available device
+    """
+    assert cfg.device is not None
+    device = torch.device(cfg.device)
+
+    if device.type == "cuda":
+        return get_best_available_cuda_device(cfg.n_devices)
+    else:
+        return device
+

 def get_device_for_block_index(
     index: int,
@@ -25,6 +112,7 @@ def get_device_for_block_index(
     This function assists in distributing model layers across multiple devices. The distribution
     is based on the configuration's number of layers (cfg.n_layers) and devices (cfg.n_devices).

+
     Args:
         index (int): Model layer index.
         cfg (HookedTransformerConfig): Model and device configuration.
@@ -33,6 +121,11 @@ def get_device_for_block_index(

     Returns:
         torch.device: The device for the specified layer index.
+
+    Deprecated:
+        This function did not take into account a few factors for multi-GPU support. You should
+        now use get_best_available_device in order to properly run models on multiple devices.
+        This will be removed in 3.0.
     """
     assert cfg.device is not None
     layers_per_device = cfg.n_layers // cfg.n_devices

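For orientation, the new helpers compose as follows: get_best_available_device reads cfg.device, and when that is CUDA it delegates to get_best_available_cuda_device, which measures free memory per device and picks the largest. A minimal sketch, assuming at least one visible CUDA device:

import torch

from transformer_lens.utilities.devices import (
    determine_available_memory_for_available_devices,
    get_best_available_cuda_device,
)

if torch.cuda.is_available():
    # (device_index, free_memory_in_bytes) for each visible device
    per_device = determine_available_memory_for_available_devices(torch.cuda.device_count())
    print(per_device)
    # The device with the most free memory at this moment
    print(get_best_available_cuda_device())
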
0 commit comments
