Skip to content

Commit 35e927b

Browse files
authored
Updated decode output cls (state-spaces#807)
1 parent 10b5d63 commit 35e927b

1 file changed

Lines changed: 3 additions & 4 deletions

File tree

mamba_ssm/utils/generation.py

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,7 @@
1111
from einops import rearrange, repeat
1212
from torch import Tensor
1313
from torch.profiler import ProfilerActivity, profile, record_function
14-
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer
14+
from transformers.generation import GenerateDecoderOnlyOutput, TextStreamer
1515

1616

1717
@dataclass
@@ -146,7 +146,7 @@ def decode(
146146
max_length: int
147147
teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
148148
logits, the next token is taken from the teacher_outputs. Useful for testing.
149-
Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
149+
Returns: GenerateDecoderOnlyOutput, with the following fields:
150150
sequences: (batch, max_length)
151151
scores: tuples of (batch, vocab_size)
152152
"""
@@ -240,8 +240,7 @@ def should_stop(current_token, inference_params):
240240
end.record()
241241
torch.cuda.synchronize()
242242
print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
243-
output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
244-
return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
243+
return GenerateDecoderOnlyOutput(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
245244

246245

247246
class GenerationMixin:

0 commit comments

Comments (0)