|
34 | 34 | "outputs": [], |
35 | 35 | "source": [ |
36 | 36 | "#| export\n", |
| 37 | + "from contextlib import contextmanager\n", |
37 | 38 | "from datetime import datetime\n", |
38 | | - "from itertools import accumulate\n", |
39 | 39 | "from fastcore.script import *\n", |
40 | 40 | "from fastcore.tools import *\n", |
41 | 41 | "from fastcore.utils import *\n", |
|
89 | 89 | "source": [ |
90 | 90 | "#| export\n", |
console = Console()
print = console.print  # rebind print to Rich's console.print so markup/Markdown render
# Shared streaming state, coordinated between `main`, `get_res`, and `_pause_live`:
_live = None  # the active rich Live display while `main` streams; None outside it
_res = ""     # text accumulated so far by `get_res`; flushed and reset by `_pause_live`
_md = None    # Markdown factory (a partial) built in `main` from the user's theme opts
93 | 96 | ] |
94 | 97 | }, |
95 | 98 | { |
|
508 | 511 | "opts = get_opts(model=None, log=None, api_base=None, api_key=''); opts" |
509 | 512 | ] |
510 | 513 | }, |
| 514 | + { |
| 515 | + "cell_type": "markdown", |
| 516 | + "id": "ae529a86", |
| 517 | + "metadata": {}, |
| 518 | + "source": [ |
| 519 | + "Rich's `Live` display and `input()` can't coexist — `Live` manages the terminal cursor, and `input()` needs it too. When `Live.stop()` is called, it prints its current renderable as static text, which causes duplication since `get_res` accumulates the full response. The fix: clear Live's renderable before stopping, flush the accumulated response as static markdown, and reset the buffer. This way streaming resumes fresh after the tool interaction with no overlap." |
| 520 | + ] |
| 521 | + }, |
511 | 522 | { |
512 | 523 | "cell_type": "code", |
513 | 524 | "execution_count": null, |
514 | | - "id": "5b3ae69b", |
| 525 | + "id": "7eba8e55", |
515 | 526 | "metadata": {}, |
516 | 527 | "outputs": [], |
517 | 528 | "source": [ |
518 | 529 | "#| export\n", |
_always_allow = set()  # Session-level tracking of auto-approved tools

@contextmanager
def _pause_live():
    """Temporarily suspend the Rich `Live` display so `input()` can own the terminal.

    `Live` manages the terminal cursor, and `Live.stop()` prints its current
    renderable as static text — which would duplicate output, since `get_res`
    accumulates the full response in `_res`. So before stopping we clear the
    renderable, flush the accumulated response once as static markdown, and
    reset the buffer; streaming then resumes fresh after the prompt.
    """
    global _res
    if _live is None:  # no active Live display (tool invoked outside `main`) — nothing to pause
        yield
        return
    _live.update('', refresh=True)  # clear renderable so stop() doesn't echo it
    if _res: print(_md(_res))       # flush accumulated response as static markdown
    _res = ''
    _live.stop()
    try: yield
    finally: _live.start()          # restart streaming even if the prompt raised
| 541 | + "\n", |
521 | 542 | "def with_permission(action_desc):\n", |
522 | 543 | " def decorator(func):\n", |
523 | 544 | " @wraps(func)\n", |
524 | 545 | " def wrapper(*args, **kwargs):\n", |
| 546 | + " global _res\n", |
525 | 547 | " if IN_NOTEBOOK or func.__name__ in _always_allow: return func(*args, **kwargs)\n", |
526 | 548 | " limit = 50\n", |
527 | 549 | " details_dict = {\n", |
528 | 550 | " \"args\": [str(arg)[:limit] + (\"...\" if len(str(arg)) > limit else \"\") for arg in args],\n", |
529 | 551 | " \"kwargs\": {k: str(v)[:limit] + (\"...\" if len(str(v)) > limit else \"\") for k, v in kwargs.items()}\n", |
530 | 552 | " }\n", |
531 | | - " print(f\"About to {action_desc} with the following arguments:\")\n", |
532 | | - " print(details_dict if args else kwargs)\n", |
533 | | - " res = input(\"Execute this? (y/n/a=always/suggestion): \").lower().strip()\n", |
534 | | - "\n", |
| 553 | + " with _pause_live():\n", |
| 554 | + " print(f\"About to {action_desc} with the following arguments:\", str(details_dict) if args else str(kwargs))\n", |
| 555 | + " res = input(\"Execute this? (y/n/a=always/suggestion): \").lower().strip()\n", |
| 556 | + " print()\n", |
| 557 | + " \n", |
535 | 558 | " if res == 'a': _always_allow.add(func.__name__)\n", |
536 | 559 | " if res in ('y','a'): return func(*args, **kwargs)\n", |
537 | 560 | " elif res == 'n': return \"[Command cancelled by user]\"\n", |
|
666 | 689 | "source": [ |
667 | 690 | "#| export\n", |
def get_res(sage, q, opts):
    "Stream the sage's answer, yielding the running text while mirroring it into the global `_res` buffer."
    global _res
    from litellm.types.utils import ModelResponseStream  # lazy load
    _res = ""
    # stream=True is required to receive search citations
    stream = sage(q, max_steps=10, stream=True, api_base=opts.api_base,
                  api_key=opts.api_key, think=opts.think)
    for chunk in stream:
        if not isinstance(chunk, ModelResponseStream): continue
        _res += chunk.choices[0].delta.content or ""
        yield _res
673 | 700 | ] |
674 | 701 | }, |
675 | 702 | { |
|
864 | 891 | " res=\"\"\n", |
865 | 892 | " try:\n", |
866 | 893 | " with Live(Spinner(\"dots\", text=\"Connecting...\"), auto_refresh=False) as live:\n", |
| 894 | + " global _live, _md\n", |
| 895 | + " _live = live\n", |
867 | 896 | " if mode not in ['default', 'sassy']:\n", |
868 | 897 | " raise Exception(f\"{mode} is not valid. Must be one of the following: ['default', 'sassy']\")\n", |
869 | 898 | " \n", |
870 | | - " md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n", |
| 899 | + " _md = partial(Markdown, code_theme=opts.code_theme, inline_code_lexer=opts.code_lexer,\n", |
871 | 900 | " inline_code_theme=opts.code_theme)\n", |
872 | 901 | " query = ' '.join(query)\n", |
873 | 902 | " ctxt = '' if skip_system else _sys_info()\n", |
|
889 | 918 | " query = f'{ctxt}\\n<query>\\n{query}\\n</query>'\n", |
890 | 919 | "\n", |
891 | 920 | " sage = get_sage(opts.model, mode, search=opts.search, use_safecmd=opts.safecmd)\n", |
892 | | - " for res in get_res(sage, query, opts): live.update(md(res), refresh=True)\n", |
| 921 | + " for res in get_res(sage, query, opts): live.update(_md(res), refresh=True)\n", |
893 | 922 | " \n", |
894 | 923 | " # Handle logging if the log flag is set\n", |
895 | 924 | " if opts.log:\n", |
|
0 commit comments