Skip to content

Commit a33692e

Browse files
lapp0Andrew Lapp
andauthored
Put ancestors on same device as next_token_logits (#651)
Fixes #649 --------- Co-authored-by: Andrew Lapp <[email protected]>
1 parent 29bd1fe commit a33692e

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

outlines/samplers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def __call__(
6666
logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
6767
next_token_ids = torch.argmax(logprobs, dim=-1, keepdim=True)
6868

69-
ancestors = torch.arange(next_token_logits.shape[0])
69+
ancestors = torch.arange(
70+
next_token_logits.shape[0], device=next_token_logits.device
71+
)
7072
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
7173

7274
return next_token_ids, ancestors, weights
@@ -144,7 +146,9 @@ def __call__(
144146
next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
145147

146148
logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
147-
ancestors = torch.arange(altered_next_token_logits.shape[0])
149+
ancestors = torch.arange(
150+
altered_next_token_logits.shape[0], device=next_token_logits.device
151+
)
148152
weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
149153

150154
return next_token_ids, ancestors, weights
@@ -292,7 +296,7 @@ def __call__(
292296

293297
# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
294298
first_batch_idx = torch.arange(
295-
0, batch_size * self.samples, self.samples
299+
0, batch_size * self.samples, self.samples, device=next_token_logits.device
296300
).unsqueeze(1)
297301
ancestors = ancestors + first_batch_idx
298302

0 commit comments

Comments
 (0)