@@ -111,27 +111,39 @@ def generate_command(args):
111111 history_ids = set (input_ids [0 ].tolist ())
112112
113113 print ("-" * 50 )
114- print (f"Prompt: { prompt } " )
115- print ("Generated Continuation:" )
116-
117- for _ in range (max_new_tokens ):
118- # Check if we should use autocast (skip if model uses float32)
119- use_autocast = True
120- if config .torch_dtype == torch .float32 :
121- use_autocast = False
122-
123- if use_autocast :
124- with torch .amp .autocast ('cuda' if device .type == 'cuda' else 'cpu' , dtype = model .config .torch_dtype ):
114+ print (f"Prompt: { prompt } " )
115+ print ("Generated Continuation:" )
116+
117+ for step in range (max_new_tokens ):
118+ # Check if we should use autocast (skip if model uses float32)
119+ use_autocast = True
120+ if config .torch_dtype == torch .float32 :
121+ use_autocast = False
122+
123+ if use_autocast :
124+ with torch .amp .autocast ('cuda' if device .type == 'cuda' else 'cpu' , dtype = model .config .torch_dtype ):
125+ outputs = model (generated_ids )
126+ logits = outputs ['logits' ]
127+ next_token_logits = logits [:, - 1 , :]
128+ else :
125129 outputs = model (generated_ids )
126130 logits = outputs ['logits' ]
127131 next_token_logits = logits [:, - 1 , :]
128- else :
129- outputs = model (generated_ids )
130- logits = outputs ['logits' ]
131- next_token_logits = logits [:, - 1 , :]
132-
133- # Repetition penalty
134- for token_id in history_ids :
132+
133+ # --- DEBUG: Print Top Predictions for First Step ---
134+ if step == 0 :
135+ probs = F .softmax (next_token_logits , dim = - 1 )
136+ top_probs , top_indices = torch .topk (probs , 5 )
137+ print ("\n [DEBUG] Step 0 Top-5 Predictions:" )
138+ for i in range (5 ):
139+ token_idx = top_indices [0 , i ].item ()
140+ prob = top_probs [0 , i ].item ()
141+ token_str = tokenizer .decode ([token_idx ])
142+ print (f" { i + 1 } . '{ token_str } ' ({ prob :.4f} )" )
143+ print ("-----------------------------------" )
144+ # ---------------------------------------------------
145+
146+ # Repetition penalty
147+ for token_id in history_ids :
135147 if token_id < next_token_logits .size (- 1 ):
136148 logit = next_token_logits [0 , token_id ].item ()
137149 if logit > 0 :
0 commit comments