|
3 | 3 | """ |
4 | 4 |
|
5 | 5 | from abc import ABC, abstractmethod |
6 | | -from typing import Any, Callable |
| 6 | +from typing import Any, Callable, Optional |
7 | 7 |
|
8 | 8 | import numpy as np |
9 | 9 | from nltk import tokenize |
10 | 10 | from numpy.typing import ArrayLike |
11 | 11 | from pydantic import BaseModel |
| 12 | +from langchain_core.language_models.base import BaseLanguageModel |
12 | 13 |
|
13 | 14 |
|
14 | 15 | SEARCH_SUMMARY_DESCRIPTION = """Question:{question} |
@@ -65,7 +66,7 @@ class RefinementByQuery(BaseRefinement): |
65 | 66 | >>> refinement = RefinementByQuery(llm=my_llm) |
66 | 67 | >>> results = refinement.refinement(documents, query) |
67 | 68 | """ |
68 | | - llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet |
| 69 | + llm: Optional[BaseLanguageModel] = None |
69 | 70 | description: str = SEARCH_SUMMARY_DESCRIPTION |
70 | 71 | k: int = 3 |
71 | 72 |
|
@@ -108,7 +109,7 @@ class RefinementBySentence(BaseRefinement): |
108 | 109 | >>> refinement = RefinementBySentence(llm=my_llm) |
109 | 110 | >>> results = refinement.refinement(documents, query) |
110 | 111 | """ |
111 | | - llm: Any = None # The BaseLanguageModel from LangChain is not compatible with Pydantic 2 yet |
| 112 | + llm: Optional[BaseLanguageModel] = None |
112 | 113 | description: str = SEARCH_SUMMARY_DESCRIPTION_SENT |
113 | 114 |
|
114 | 115 | def refinement(self, documents: list[str], query: str) -> list[str]: |
|
0 commit comments