|
11 | 11 | from einops import rearrange, repeat |
12 | 12 | from torch import Tensor |
13 | 13 | from torch.profiler import ProfilerActivity, profile, record_function |
14 | | -from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer |
| 14 | +from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer |
15 | 15 |
|
16 | 16 |
|
17 | 17 | @dataclass |
@@ -146,7 +146,7 @@ def decode( |
146 | 146 | max_length: int |
147 | 147 | teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the |
148 | 148 | logits, the next token is taken from the teacher_outputs. Useful for testing. |
149 | | - Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields: |
| 149 | + Returns: GenerateDecoderOnlyOutput, with the following fields: |
150 | 150 | sequences: (batch, max_length) |
151 | 151 | scores: tuple of tensors, each of shape (batch, vocab_size) |
152 | 152 | """ |
@@ -240,8 +240,7 @@ def should_stop(current_token, inference_params): |
240 | 240 | end.record() |
241 | 241 | torch.cuda.synchronize() |
242 | 242 | print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms") |
243 | | - output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput |
244 | | - return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) |
| 243 | + return GenerateDecoderOnlyOutput(sequences=torch.cat(sequences, dim=1), scores=tuple(scores)) |
245 | 244 |
|
246 | 245 |
|
247 | 246 | class GenerationMixin: |
|
0 commit comments