diff --git a/pageindex/config.yaml b/pageindex/config.yaml index fd73e3a2c..1e2798cb9 100644 --- a/pageindex/config.yaml +++ b/pageindex/config.yaml @@ -1,4 +1,5 @@ model: "gpt-4o-2024-11-20" +# model: "anthropic/claude-haiku-4-5-20251001" toc_check_page_num: 20 max_page_num_each_node: 10 max_token_num_each_node: 20000 diff --git a/pageindex/page_index.py b/pageindex/page_index.py index d646bb9d2..2f0f91aac 100644 --- a/pageindex/page_index.py +++ b/pageindex/page_index.py @@ -36,7 +36,7 @@ async def check_title_appearance(item, page_list, start_index=1, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = await ChatGPT_API_async(model=model, prompt=prompt) + response = await allm_complete(model=model, prompt=prompt) response = extract_json(response) if 'answer' in response: answer = response['answer'] @@ -64,7 +64,7 @@ async def check_title_appearance_in_start(title, page_text, model=None, logger=N }} Directly return the final JSON structure. Do not output anything else.""" - response = await ChatGPT_API_async(model=model, prompt=prompt) + response = await allm_complete(model=model, prompt=prompt) response = extract_json(response) if logger: logger.info(f"Response: {response}") @@ -116,7 +116,7 @@ def toc_detector_single_page(content, model=None): Directly return the final JSON structure. Do not output anything else. Please note: abstract,summary, notation list, figure list, table list, etc. are not table of contents.""" - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) # print('response', response) json_content = extract_json(response) return json_content['toc_detected'] @@ -135,7 +135,7 @@ def check_if_toc_extraction_is_complete(content, toc, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Document:\n' + content + '\n Table of contents:\n' + toc - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -153,7 +153,7 @@ def check_if_toc_transformation_is_complete(content, toc, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\n Raw Table of contents:\n' + content + '\n Cleaned Table of contents:\n' + toc - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) json_content = extract_json(response) return json_content['completed'] @@ -165,7 +165,7 @@ def extract_toc_content(content, model=None): Directly return the full table of contents content. Do not output anything else.""" - response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) + response, finish_reason = llm_complete(model=model, prompt=prompt, return_finish_reason=True) if_complete = check_if_toc_transformation_is_complete(content, response, model) if if_complete == "yes" and finish_reason == "finished": @@ -176,7 +176,7 @@ def extract_toc_content(content, model=None): {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) + new_response, finish_reason = llm_complete(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) @@ -193,7 +193,7 @@ def extract_toc_content(content, model=None): {"role": "assistant", "content": response}, ] prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure""" - new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history) + new_response, finish_reason = llm_complete(model=model, prompt=prompt, chat_history=chat_history, return_finish_reason=True) response = response + new_response if_complete = check_if_toc_transformation_is_complete(content, response, model) @@ -215,7 +215,7 @@ def detect_page_index(toc_content, model=None): }} Directly return the final JSON structure. Do not output anything else.""" - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) json_content = extract_json(response) return json_content['page_index_given_in_toc'] @@ -264,7 +264,7 @@ def toc_index_extractor(toc, content, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = toc_extractor_prompt + '\nTable of contents:\n' + str(toc) + '\nDocument pages:\n' + content - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) json_content = extract_json(response) return json_content @@ -292,7 +292,7 @@ def toc_transformer(toc_content, model=None): Directly return the final JSON structure, do not output anything else. """ prompt = init_prompt + '\n Given table of contents\n:' + toc_content - last_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) + last_complete, finish_reason = llm_complete(model=model, prompt=prompt, return_finish_reason=True) if_complete = check_if_toc_transformation_is_complete(toc_content, last_complete, model) if if_complete == "yes" and finish_reason == "finished": last_complete = extract_json(last_complete) @@ -316,7 +316,7 @@ def toc_transformer(toc_content, model=None): Please continue the json structure, directly output the remaining part of the json structure.""" - new_complete, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) + new_complete, finish_reason = llm_complete(model=model, prompt=prompt, return_finish_reason=True) if new_complete.startswith('```json'): new_complete = get_json_content(new_complete) @@ -477,7 +477,7 @@ def add_page_number_to_toc(part, structure, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = fill_prompt_seq + f"\n\nCurrent Partial Document:\n{part}\n\nGiven Structure\n{json.dumps(structure, indent=2)}\n" - current_json_raw = ChatGPT_API(model=model, prompt=prompt) + current_json_raw = llm_complete(model=model, prompt=prompt) json_result = extract_json(current_json_raw) for item in json_result: @@ -527,7 +527,7 @@ def generate_toc_continue(toc_content, part, model="gpt-4o-2024-11-20"): Directly return the additional part of the final JSON structure. Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part + '\nPrevious tree structure\n:' + json.dumps(toc_content, indent=2) - response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) + response, finish_reason = llm_complete(model=model, prompt=prompt, return_finish_reason=True) if finish_reason == 'finished': return extract_json(response) else: @@ -561,7 +561,7 @@ def generate_toc_init(part, model=None): Directly return the final JSON structure. Do not output anything else.""" prompt = prompt + '\nGiven text\n:' + part - response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt) + response, finish_reason = llm_complete(model=model, prompt=prompt, return_finish_reason=True) if finish_reason == 'finished': return extract_json(response) @@ -746,7 +746,7 @@ def single_toc_item_index_fixer(section_title, content, model="gpt-4o-2024-11-20 Directly return the final JSON structure. Do not output anything else.""" prompt = toc_extractor_prompt + '\nSection Title:\n' + str(section_title) + '\nDocument pages:\n' + content - response = ChatGPT_API(model=model, prompt=prompt) + response = llm_complete(model=model, prompt=prompt) json_content = extract_json(response) return convert_physical_index_to_int(json_content['physical_index']) diff --git a/pageindex/utils.py b/pageindex/utils.py index 3517ab80c..622fb1ca6 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -1,5 +1,4 @@ -import tiktoken -import openai +import litellm import logging import os from datetime import datetime @@ -17,96 +16,78 @@ from pathlib import Path from types import SimpleNamespace as config -CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") +# Backward compatibility: support CHATGPT_API_KEY as alias for OPENAI_API_KEY +if not os.getenv("OPENAI_API_KEY") and os.getenv("CHATGPT_API_KEY"): + os.environ["OPENAI_API_KEY"] = os.getenv("CHATGPT_API_KEY") + +litellm.drop_params = True + def count_tokens(text, model=None): if not text: return 0 - enc = tiktoken.encoding_for_model(model) - tokens = enc.encode(text) - return len(tokens) + return litellm.token_counter(model=model, text=text) + -def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +def llm_complete(model, prompt, chat_history=None, return_finish_reason=False): max_retries = 10 - client = openai.OpenAI(api_key=api_key) + messages = list(chat_history) + [{"role": "user", "content": prompt}] if chat_history else [{"role": "user", "content": prompt}] for i in range(max_retries): try: - if chat_history: - messages = chat_history - messages.append({"role": "user", "content": prompt}) - else: - messages = [{"role": "user", "content": prompt}] - - response = client.chat.completions.create( + response = litellm.completion( model=model, messages=messages, temperature=0, ) - if response.choices[0].finish_reason == "length": - return response.choices[0].message.content, "max_output_reached" - else: - return response.choices[0].message.content, "finished" - + content = response.choices[0].message.content + if return_finish_reason: + finish_reason = "max_output_reached" if response.choices[0].finish_reason == "length" else "finished" + return content, finish_reason + return content except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") if i < max_retries - 1: - time.sleep(1) # Wait for 1秒 before retrying + time.sleep(1) else: logging.error('Max retries reached for prompt: ' + prompt) return "", "error" +def llm_complete_stream(model, prompt): + """Return a generator that yields token chunks (str) one at a time.""" + response = litellm.completion( + model=model, + messages=[{"role": "user", "content": prompt}], + temperature=0, + stream=True, + ) + for chunk in response: + delta = chunk.choices[0].delta.content + if delta: + yield delta + -def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): +async def allm_complete(model, prompt): max_retries = 10 - client = openai.OpenAI(api_key=api_key) + messages = [{"role": "user", "content": prompt}] for i in range(max_retries): try: - if chat_history: - messages = chat_history - messages.append({"role": "user", "content": prompt}) - else: - messages = [{"role": "user", "content": prompt}] - - response = client.chat.completions.create( + response = await litellm.acompletion( model=model, messages=messages, temperature=0, ) - return response.choices[0].message.content except Exception as e: print('************* Retrying *************') logging.error(f"Error: {e}") if i < max_retries - 1: - time.sleep(1) # Wait for 1秒 before retrying + await asyncio.sleep(1) else: logging.error('Max retries reached for prompt: ' + prompt) return "Error" - -async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY): - max_retries = 10 - messages = [{"role": "user", "content": prompt}] - for i in range(max_retries): - try: - async with openai.AsyncOpenAI(api_key=api_key) as client: - response = await client.chat.completions.create( - model=model, - messages=messages, - temperature=0, - ) - return response.choices[0].message.content - except Exception as e: - print('************* Retrying *************') - logging.error(f"Error: {e}") - if i < max_retries - 1: - await asyncio.sleep(1) # Wait for 1s before retrying - else: - logging.error('Max retries reached for prompt: ' + prompt) - return "Error" - def get_json_content(response): start_idx = response.find("```json") @@ -411,14 +392,13 @@ def add_preface_if_needed(data): def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"): - enc = tiktoken.encoding_for_model(model) if pdf_parser == "PyPDF2": pdf_reader = PyPDF2.PdfReader(pdf_path) page_list = [] for page_num in range(len(pdf_reader.pages)): page = pdf_reader.pages[page_num] page_text = page.extract_text() - token_length = len(enc.encode(page_text)) + token_length = litellm.token_counter(model=model, text=page_text) page_list.append((page_text, token_length)) return page_list elif pdf_parser == "PyMuPDF": @@ -430,7 +410,7 @@ def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"): page_list = [] for page in doc: page_text = page.get_text() - token_length = len(enc.encode(page_text)) + token_length = litellm.token_counter(model=model, text=page_text) page_list.append((page_text, token_length)) return page_list else: @@ -609,7 +589,7 @@ async def generate_node_summary(node, model=None): Directly return the description, do not include any other text. """ - response = await ChatGPT_API_async(model, prompt) + response = await allm_complete(model, prompt) return response @@ -654,7 +634,7 @@ def generate_doc_description(structure, model=None): Directly return the description, do not include any other text. """ - response = ChatGPT_API(model, prompt) + response = llm_complete(model, prompt) return response diff --git a/requirements.txt b/requirements.txt index 463db58f1..31e4f164b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ openai==1.101.0 +litellm pymupdf==1.26.4 PyPDF2==3.0.1 python-dotenv==1.1.0 -tiktoken==0.11.0 pyyaml==6.0.2