@@ -66,7 +66,9 @@ def __call__(
66
66
logprobs = torch .nn .functional .log_softmax (next_token_logits , dim = - 1 )
67
67
next_token_ids = torch .argmax (logprobs , dim = - 1 , keepdim = True )
68
68
69
- ancestors = torch .arange (next_token_logits .shape [0 ])
69
+ ancestors = torch .arange (
70
+ next_token_logits .shape [0 ], device = next_token_logits .device
71
+ )
70
72
weights = sequence_weights + torch .gather (logprobs , 1 , next_token_ids ).squeeze ()
71
73
72
74
return next_token_ids , ancestors , weights
@@ -144,7 +146,9 @@ def __call__(
144
146
next_token_ids = torch .multinomial (probs , num_samples = 1 , generator = rng )
145
147
146
148
logprobs = torch .nn .functional .log_softmax (altered_next_token_logits , dim = - 1 )
147
- ancestors = torch .arange (altered_next_token_logits .shape [0 ])
149
+ ancestors = torch .arange (
150
+ altered_next_token_logits .shape [0 ], device = next_token_logits .device
151
+ )
148
152
weights = sequence_weights + torch .gather (logprobs , 1 , next_token_ids ).squeeze ()
149
153
150
154
return next_token_ids , ancestors , weights
@@ -292,7 +296,7 @@ def __call__(
292
296
293
297
# Re-shape the weights, next_token_ids and ancestors to (n_batch * n_samples, 1)
294
298
first_batch_idx = torch .arange (
295
- 0 , batch_size * self .samples , self .samples
299
+ 0 , batch_size * self .samples , self .samples , device = next_token_logits . device
296
300
).unsqueeze (1 )
297
301
ancestors = ancestors + first_batch_idx
298
302
0 commit comments