Skip to content

Commit b0bb7aa

Browse files
authored
Merge pull request #76 from AnswerDotAI/live
Fix duplicated output during tool use by pausing Live display
2 parents 9831d5d + 4999ee1 commit b0bb7aa

3 files changed

Lines changed: 76 additions & 25 deletions

File tree

nbs/00_core.ipynb

Lines changed: 42 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@
3434
"outputs": [],
3535
"source": [
3636
"#| export\n",
37+
"from contextlib import contextmanager\n",
3738
"from datetime import datetime\n",
38-
"from itertools import accumulate\n",
3939
"from fastcore.script import *\n",
4040
"from fastcore.tools import *\n",
4141
"from fastcore.utils import *\n",
@@ -89,7 +89,10 @@
8989
"source": [
9090
"#| export\n",
9191
"console = Console()\n",
92-
"print = console.print"
92+
"print = console.print\n",
93+
"_live = None\n",
94+
"_res = \"\"\n",
95+
"_md = None"
9396
]
9497
},
9598
{
@@ -508,30 +511,50 @@
508511
"opts = get_opts(model=None, log=None, api_base=None, api_key=''); opts"
509512
]
510513
},
514+
{
515+
"cell_type": "markdown",
516+
"id": "ae529a86",
517+
"metadata": {},
518+
"source": [
519+
"Rich's `Live` display and `input()` can't coexist — `Live` manages the terminal cursor, and `input()` needs it too. When `Live.stop()` is called, it prints its current renderable as static text, which causes duplication since `get_res` accumulates the full response. The fix: clear Live's renderable before stopping, flush the accumulated response as static markdown, and reset the buffer. This way streaming resumes fresh after the tool interaction with no overlap."
520+
]
521+
},
511522
{
512523
"cell_type": "code",
513524
"execution_count": null,
514-
"id": "5b3ae69b",
525+
"id": "7eba8e55",
515526
"metadata": {},
516527
"outputs": [],
517528
"source": [
518529
"#| export\n",
519530
"_always_allow = set() # Session-level tracking of auto-approved tools\n",
520531
"\n",
532+
"@contextmanager\n",
533+
"def _pause_live():\n",
534+
" global _res\n",
535+
" _live.update('', refresh=True)\n",
536+
" if _res: print(_md(_res))\n",
537+
" _res = ''\n",
538+
" _live.stop()\n",
539+
" try: yield\n",
540+
" finally: _live.start()\n",
541+
"\n",
521542
"def with_permission(action_desc):\n",
522543
" def decorator(func):\n",
523544
" @wraps(func)\n",
524545
" def wrapper(*args, **kwargs):\n",
546+
" global _res\n",
525547
" if IN_NOTEBOOK or func.__name__ in _always_allow: return func(*args, **kwargs)\n",
526548
" limit = 50\n",
527549
" details_dict = {\n",
528550
" \"args\": [str(arg)[:limit] + (\"...\" if len(str(arg)) > limit else \"\") for arg in args],\n",
529551
" \"kwargs\": {k: str(v)[:limit] + (\"...\" if len(str(v)) > limit else \"\") for k, v in kwargs.items()}\n",
530552
" }\n",
531-
" print(f\"About to {action_desc} with the following arguments:\")\n",
532-
" print(details_dict if args else kwargs)\n",
533-
" res = input(\"Execute this? (y/n/a=always/suggestion): \").lower().strip()\n",
534-
"\n",
553+
" with _pause_live():\n",
554+
" print(f\"About to {action_desc} with the following arguments:\", str(details_dict) if args else str(kwargs))\n",
555+
" res = input(\"Execute this? (y/n/a=always/suggestion): \").lower().strip()\n",
556+
" print()\n",
557+
" \n",
535558
" if res == 'a': _always_allow.add(func.__name__)\n",
536559
" if res in ('y','a'): return func(*args, **kwargs)\n",
537560
" elif res == 'n': return \"[Command cancelled by user]\"\n",
@@ -666,10 +689,14 @@
666689
"source": [
667690
"#| export\n",
668691
"def get_res(sage, q, opts):\n",
669-
" from litellm.types.utils import ModelResponseStream # lazy load\n",
670-
" # need to use stream=True to get search citations\n",
671-
" gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key, think=opts.think) \n",
672-
" yield from accumulate(o.choices[0].delta.content or \"\" for o in gen if isinstance(o, ModelResponseStream))"
692+
" global _res\n",
693+
" from litellm.types.utils import ModelResponseStream\n",
694+
" _res = \"\"\n",
695+
" gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key, think=opts.think)\n",
696+
" for o in gen:\n",
697+
" if isinstance(o, ModelResponseStream):\n",
698+
" _res += o.choices[0].delta.content or \"\"\n",
699+
" yield _res"
673700
]
674701
},
675702
{
@@ -864,10 +891,12 @@
864891
" res=\"\"\n",
865892
" try:\n",
866893
" with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n",
894+
" global _live, _md\n",
895+
" _live = live\n",
867896
" if mode not in ['default', 'sassy']:\n",
868897
" raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n",
869898
" \n",
870-
" md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
899+
" _md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n",
871900
" inline_code_theme=opts.code_theme)\n",
872901
" query = ' '.join(query)\n",
873902
" ctxt = '' if skip_system else _sys_info()\n",
@@ -889,7 +918,7 @@
889918
" query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n",
890919
"\n",
891920
" sage = get_sage(opts.model, mode, search=opts.search, use_safecmd=opts.safecmd)\n",
892-
" for res in get_res(sage, query, opts): live.update(md(res), refresh=True)\n",
921+
" for res in get_res(sage, query, opts): live.update(_md(res), refresh=True)\n",
893922
" \n",
894923
" # Handle logging if the log flag is set\n",
895924
" if opts.log:\n",

shell_sage/_modidx.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
'shell_sage/core.py'),
1414
'shell_sage.core.Log': ('core.html#log', 'shell_sage/core.py'),
1515
'shell_sage.core._aliases': ('core.html#_aliases', 'shell_sage/core.py'),
16+
'shell_sage.core._pause_live': ('core.html#_pause_live', 'shell_sage/core.py'),
1617
'shell_sage.core._sys_info': ('core.html#_sys_info', 'shell_sage/core.py'),
1718
'shell_sage.core.extract': ('core.html#extract', 'shell_sage/core.py'),
1819
'shell_sage.core.extract_cf': ('core.html#extract_cf', 'shell_sage/core.py'),

shell_sage/core.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
'main', 'extract_cf', 'extract']
77

88
# %% ../nbs/00_core.ipynb #d7c5634a
9+
from contextlib import contextmanager
910
from datetime import datetime
10-
from itertools import accumulate
1111
from fastcore.script import *
1212
from fastcore.tools import *
1313
from fastcore.utils import *
@@ -34,6 +34,9 @@ def __rich_console__(self:CodeBlock, console, options):
3434
# %% ../nbs/00_core.ipynb #9d52ca34
3535
console = Console()
3636
print = console.print
37+
_live = None
38+
_res = ""
39+
_md = None
3740

3841
# %% ../nbs/00_core.ipynb #977bd215
3942
def Chat(*arg, **kw):
@@ -179,23 +182,35 @@ def get_opts(**opts):
179182
if v is None: opts[k] = cfg.get(k, default_cfg.get(k))
180183
return AttrDict(opts)
181184

182-
# %% ../nbs/00_core.ipynb #5b3ae69b
185+
# %% ../nbs/00_core.ipynb #7eba8e55
183186
_always_allow = set() # Session-level tracking of auto-approved tools
184187

188+
@contextmanager
189+
def _pause_live():
190+
global _res
191+
_live.update('', refresh=True)
192+
if _res: print(_md(_res))
193+
_res = ''
194+
_live.stop()
195+
try: yield
196+
finally: _live.start()
197+
185198
def with_permission(action_desc):
186199
def decorator(func):
187200
@wraps(func)
188201
def wrapper(*args, **kwargs):
202+
global _res
189203
if IN_NOTEBOOK or func.__name__ in _always_allow: return func(*args, **kwargs)
190204
limit = 50
191205
details_dict = {
192206
"args": [str(arg)[:limit] + ("..." if len(str(arg)) > limit else "") for arg in args],
193207
"kwargs": {k: str(v)[:limit] + ("..." if len(str(v)) > limit else "") for k, v in kwargs.items()}
194208
}
195-
print(f"About to {action_desc} with the following arguments:")
196-
print(details_dict if args else kwargs)
197-
res = input("Execute this? (y/n/a=always/suggestion): ").lower().strip()
198-
209+
with _pause_live():
210+
print(f"About to {action_desc} with the following arguments:", str(details_dict) if args else str(kwargs))
211+
res = input("Execute this? (y/n/a=always/suggestion): ").lower().strip()
212+
print()
213+
199214
if res == 'a': _always_allow.add(func.__name__)
200215
if res in ('y','a'): return func(*args, **kwargs)
201216
elif res == 'n': return "[Command cancelled by user]"
@@ -219,10 +234,14 @@ def get_sage(model, mode='default', search=False, use_safecmd=False):
219234

220235
# %% ../nbs/00_core.ipynb #68be9484
221236
def get_res(sage, q, opts):
222-
from litellm.types.utils import ModelResponseStream # lazy load
223-
# need to use stream=True to get search citations
224-
gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key, think=opts.think)
225-
yield from accumulate(o.choices[0].delta.content or "" for o in gen if isinstance(o, ModelResponseStream))
237+
global _res
238+
from litellm.types.utils import ModelResponseStream
239+
_res = ""
240+
gen = sage(q, max_steps=10, stream=True, api_base=opts.api_base, api_key=opts.api_key, think=opts.think)
241+
for o in gen:
242+
if isinstance(o, ModelResponseStream):
243+
_res += o.choices[0].delta.content or ""
244+
yield _res
226245

227246
# %% ../nbs/00_core.ipynb #4e6e4d92
228247
class Log: id:int; timestamp:str; query:str; response:str; model:str; mode:str
@@ -260,10 +279,12 @@ def main(
260279
res=""
261280
try:
262281
with Live(Spinner("dots", text="Connecting..."), auto_refresh=False) as live:
282+
global _live, _md
283+
_live = live
263284
if mode not in ['default', 'sassy']:
264285
raise Exception(f"{mode} is not valid. Must be one of the following: ['default', 'sassy']")
265286

266-
md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,
287+
_md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,
267288
inline_code_theme=opts.code_theme)
268289
query = ' '.join(query)
269290
ctxt = '' if skip_system else _sys_info()
@@ -285,7 +306,7 @@ def main(
285306
query = f'{ctxt}\n<query>\n{query}\n</query>'
286307

287308
sage = get_sage(opts.model, mode, search=opts.search, use_safecmd=opts.safecmd)
288-
for res in get_res(sage, query, opts): live.update(md(res), refresh=True)
309+
for res in get_res(sage, query, opts): live.update(_md(res), refresh=True)
289310

290311
# Handle logging if the log flag is set
291312
if opts.log:

0 commit comments

Comments
 (0)