Skip to content

Commit a85ec0e

Browse files
committed
Refactor: Update llm type annotations to use BaseLanguageModel
1 parent cf7a11b commit a85ec0e

10 files changed

Lines changed: 27 additions & 24 deletions

File tree

src/sherpa_ai/actions/arxiv_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import Any
2-
1+
from typing import Any, Optional
2+
from langchain_core.language_models.base import BaseLanguageModel
33
from sherpa_ai.actions.base import BaseRetrievalAction
44
from sherpa_ai.tools import SearchArxivTool
55

@@ -46,7 +46,7 @@ class ArxivSearch(BaseRetrievalAction):
4646
"""
4747
role_description: str
4848
task: str
49-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
49+
llm: Optional[BaseLanguageModel] = None
5050
description: str = SEARCH_SUMMARY_DESCRIPTION
5151
_search_tool: Any = None
5252

src/sherpa_ai/actions/context_search.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
from loguru import logger
4+
from langchain_core.language_models.base import BaseLanguageModel
45

56
from sherpa_ai.actions.base import BaseRetrievalAction
67
from sherpa_ai.connectors.vectorstores import get_vectordb
@@ -55,7 +56,7 @@ class ContextSearch(BaseRetrievalAction):
5556

5657
role_description: str
5758
task: str
58-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
59+
llm: Optional[BaseLanguageModel] = None
5960
description: str = SEARCH_SUMMARY_DESCRIPTION
6061
_context: Any = None
6162

src/sherpa_ai/actions/deliberation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from typing import Any
2-
1+
from typing import Any, Optional
2+
from langchain_core.language_models.base import BaseLanguageModel
33
from sherpa_ai.actions.base import BaseAction
44

55

@@ -50,7 +50,7 @@ class Deliberation(BaseAction):
5050
"""
5151
# TODO: Make a version of Deliberation action that considers the context
5252
role_description: str
53-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
53+
llm: Optional[BaseLanguageModel] = None
5454
description: str = DELIBERATION_DESCRIPTION
5555

5656
# Override the name and args from BaseAction

src/sherpa_ai/actions/google_search.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

33
from loguru import logger
4-
4+
from langchain_core.language_models.base import BaseLanguageModel
55
from sherpa_ai.actions.base import BaseRetrievalAction
66
from sherpa_ai.config.task_config import AgentConfig
77
from sherpa_ai.tools import SearchTool
@@ -65,7 +65,7 @@ class GoogleSearch(BaseRetrievalAction):
6565
"""
6666
role_description: str
6767
task: str
68-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
68+
llm: Optional[BaseLanguageModel] = None
6969
description: str = SEARCH_SUMMARY_DESCRIPTION
7070
config: AgentConfig = AgentConfig()
7171
_search_tool: Any = None

src/sherpa_ai/actions/planning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from typing import Any, Optional
22

33
from loguru import logger
4-
4+
from langchain_core.language_models.base import BaseLanguageModel
55
from sherpa_ai.actions.base import BaseAction
66

77

@@ -216,7 +216,7 @@ class TaskPlanning(BaseAction):
216216
Task: Summarize the findings about quantum computing
217217
"""
218218

219-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
219+
llm: Optional[BaseLanguageModel] = None
220220
num_steps: int = 5
221221
prompt: str = PLANNING_PROMPT
222222
revision_prompt: str = REVISION_PROMPT

src/sherpa_ai/actions/synthesize.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from typing import Any
1+
from typing import Any, Optional
22

3-
from langchain_core.language_models import BaseLanguageModel
3+
from langchain_core.language_models.base import BaseLanguageModel
44
from loguru import logger
55

66
from sherpa_ai.actions.base import BaseAction
@@ -42,7 +42,7 @@ class SynthesizeOutput(BaseAction):
4242
"""
4343

4444
role_description: str
45-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
45+
llm: Optional[BaseLanguageModel] = None
4646
description: str = None
4747
add_citation: bool = False
4848

src/sherpa_ai/actions/utils/refinement.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
"""
44

55
from abc import ABC, abstractmethod
6-
from typing import Any, Callable
6+
from typing import Any, Callable, Optional
77

88
import numpy as np
99
from nltk import tokenize
1010
from numpy.typing import ArrayLike
1111
from pydantic import BaseModel
12+
from langchain_core.language_models.base import BaseLanguageModel
1213

1314

1415
SEARCH_SUMMARY_DESCRIPTION = """Question:{question}
@@ -65,7 +66,7 @@ class RefinementByQuery(BaseRefinement):
6566
>>> refinement = RefinementByQuery(llm=my_llm)
6667
>>> results = refinement.refinement(documents, query)
6768
"""
68-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
69+
llm: Optional[BaseLanguageModel] = None
6970
description: str = SEARCH_SUMMARY_DESCRIPTION
7071
k: int = 3
7172

@@ -108,7 +109,7 @@ class RefinementBySentence(BaseRefinement):
108109
>>> refinement = RefinementBySentence(llm=my_llm)
109110
>>> results = refinement.refinement(documents, query)
110111
"""
111-
llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet
112+
llm: Optional[BaseLanguageModel] = None
112113
description: str = SEARCH_SUMMARY_DESCRIPTION_SENT
113114

114115
def refinement(self, documents: list[str], query: str) -> list[str]:

src/sherpa_ai/agents/base.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
from loguru import logger
1010
from pydantic import BaseModel, ConfigDict
11+
from langchain_core.language_models.base import BaseLanguageModel
1112

1213
from sherpa_ai.actions.base import BaseAction
1314
from sherpa_ai.actions.exceptions import (
@@ -73,7 +74,7 @@ class BaseAgent(ABC, BaseModel):
7374
validations: List[BaseOutputProcessor] = []
7475
feedback_agent_name: str = "critic"
7576
global_regen_max: int = 12
76-
llm: Any = None
77+
llm: Optional[BaseLanguageModel] = None
7778
prompt_template: PromptTemplate = None
7879

7980
# Checks whether the execution of the agent should be stopped

src/sherpa_ai/policies/react_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from loguru import logger
1414
from pydantic import BaseModel, ConfigDict
15+
from langchain_core.language_models.base import BaseLanguageModel
1516

1617
from sherpa_ai.policies.base import BasePolicy, PolicyOutput
1718
from sherpa_ai.policies.exceptions import SherpaPolicyException
@@ -55,8 +56,7 @@ class ReactPolicy(BasePolicy):
5556

5657
role_description: str
5758
output_instruction: str
58-
# Cannot use langchain's BaseLanguageModel due to they are using Pydantic v1
59-
llm: Any = None
59+
llm: Optional[BaseLanguageModel] = None
6060
prompt_template: PromptTemplate = PromptTemplate("./sherpa_ai/prompts/prompts.json")
6161
response_format: dict = {
6262
"command": {

src/sherpa_ai/policies/react_sm_policy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from loguru import logger
1515
from pydantic import BaseModel, ConfigDict
16+
from langchain_core.language_models.base import BaseLanguageModel
1617

1718
from sherpa_ai.actions.base import BaseAction
1819
from sherpa_ai.policies.base import BasePolicy, PolicyOutput
@@ -55,8 +56,7 @@ class ReactStateMachinePolicy(BasePolicy):
5556
)
5657
role_description: str
5758
output_instruction: str
58-
# Cannot use langchain's BaseLanguageModel due to they are using Pydantic v1
59-
llm: Any = None
59+
llm: Optional[BaseLanguageModel] = None
6060
prompt_template: PromptTemplate = PromptTemplate("./sherpa_ai/prompts/prompts.json")
6161

6262
response_format: dict = {

0 commit comments

Comments
 (0)