-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathlangchain.py
More file actions
163 lines (139 loc) · 5.68 KB
/
langchain.py
File metadata and controls
163 lines (139 loc) · 5.68 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import contextvars
import logging
from typing import Any
from uuid import UUID
import braintrust
_logger = logging.getLogger("braintrust.wrappers.langchain")
try:
# Modern langchain
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.documents.base import Document
from langchain_core.messages.base import BaseMessage
from langchain_core.outputs.llm_result import LLMResult
except ImportError:
try:
# after the release of langchain v1, these submodules were also migrated to langchain_classic
from langchain_classic.callbacks.base import BaseCallbackHandler
from langchain_classic.schema import Document
from langchain_classic.schema.messages import BaseMessage
from langchain_classic.schema.output import LLMResult
except ImportError:
try:
# Old langchain from before the v1 version and the creation of langchain_classic
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import Document
from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult
except ImportError:
raise ImportError(
"Could not import langchain callbacks and schema submodules. "
"Install one of: langchain_classic, langchain-core, or langchain<1."
)
langchain_parent = contextvars.ContextVar("langchain_current_span", default=None)
class BraintrustTracer(BaseCallbackHandler):
    """LangChain callback handler that mirrors chain/LLM runs into Braintrust spans.

    Each LangChain ``run_id`` opens one Braintrust span; the span's parent is
    resolved, in priority order, from (1) the span of ``parent_run_id``, (2) the
    ``langchain_parent`` context variable, (3) the explicit ``logger`` passed to
    ``__init__``, or (4) the global ``braintrust`` module.

    Deprecated: install and use the ``braintrust-langchain`` package instead.
    """

    def __init__(self, logger=None):
        """Create the tracer.

        Args:
            logger: Optional Braintrust logger/experiment used as the fallback
                parent for root spans. When ``None``, the global ``braintrust``
                module is used instead.
        """
        _logger.warning("BraintrustTracer is deprecated, use `pip install braintrust-langchain` instead")
        self.logger = logger
        # Maps LangChain run_id -> currently-open Braintrust span.
        self.spans = {}

    def _start_span(self, parent_run_id, run_id, name: str | None, **kwargs: Any) -> Any:
        """Open a span for ``run_id``, parented per the class docstring's priority order."""
        assert run_id not in self.spans, f"Span already exists for run_id {run_id} (this is likely a bug)"

        current_parent = langchain_parent.get()
        if parent_run_id in self.spans:
            parent_span = self.spans[parent_run_id]
        elif current_parent is not None:
            parent_span = current_parent
        elif self.logger is not None:
            parent_span = self.logger
        else:
            # Fall back to the module-level API (braintrust.start_span).
            parent_span = braintrust
        span = parent_span.start_span(name=name, **kwargs)
        langchain_parent.set(span)
        self.spans[run_id] = span
        return span

    def _end_span(self, run_id, **kwargs: Any) -> Any:
        """Log final fields onto the span for ``run_id`` and close it."""
        assert run_id in self.spans, f"No span exists for run_id {run_id} (this is likely a bug)"
        span = self.spans.pop(run_id)
        span.log(**kwargs)
        # NOTE(review): this resets the context to None rather than restoring the
        # previous parent, so an outer span set before this run is lost for any
        # subsequent sibling runs. Preserved as-is to avoid behavior changes;
        # consider ContextVar tokens if nested restoration is needed.
        if langchain_parent.get() == span:
            langchain_parent.set(None)
        span.end()

    def on_chain_start(
        self,
        serialized: dict[str, Any],
        inputs: dict[str, Any],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Open a "Chain" span recording the chain's inputs and tags."""
        self._start_span(parent_run_id, run_id, "Chain", input=inputs, metadata={"tags": tags})

    def on_chain_end(
        self, outputs: dict[str, Any], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        """Close the chain's span, logging its outputs."""
        self._end_span(run_id, output=outputs)

    def on_llm_start(
        self,
        serialized: dict[str, Any],
        prompts: list[str],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Open an "LLM" span recording the raw prompts and invocation params."""
        # invocation_params is usually supplied by the callback manager, but is
        # not guaranteed — treat it as optional rather than raising KeyError.
        invocation_params = kwargs.get("invocation_params") or {}
        self._start_span(
            parent_run_id,
            run_id,
            "LLM",
            input=prompts,
            metadata={"tags": tags, **invocation_params},
        )

    def on_chat_model_start(
        self,
        serialized: dict[str, Any],
        messages: list[list[BaseMessage]],
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Open a "Chat Model" span recording the message batches and invocation params."""
        invocation_params = kwargs.get("invocation_params") or {}
        self._start_span(
            parent_run_id,
            run_id,
            "Chat Model",
            # .dict() is the pydantic-v1-style serializer; still supported by
            # langchain message types, though newer versions prefer model_dump().
            input=[[m.dict() for m in batch] for batch in messages],
            metadata={"tags": tags, **invocation_params},
        )

    def on_llm_end(
        self, response: LLMResult, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        """Close the LLM/chat span, logging generations and token-usage metrics.

        ``LLMResult.llm_output`` is provider-specific and may be ``None`` (e.g.
        during streaming or for providers that omit it), so both it and its
        ``token_usage`` entry are treated as optional.
        """
        token_usage = (response.llm_output or {}).get("token_usage") or {}
        metrics = {}
        if "total_tokens" in token_usage:
            metrics["tokens"] = token_usage["total_tokens"]
        if "prompt_tokens" in token_usage:
            metrics["prompt_tokens"] = token_usage["prompt_tokens"]
        if "completion_tokens" in token_usage:
            metrics["completion_tokens"] = token_usage["completion_tokens"]
        self._end_span(run_id, output=[[m.dict() for m in batch] for batch in response.generations], metrics=metrics)

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Tool runs are not traced; warn so the gap is visible in logs."""
        _logger.warning("Starting tool, but it will not be traced in braintrust (unsupported)")

    def on_tool_end(self, output: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
        """No-op: tool runs are not traced (see on_tool_start)."""
        pass

    def on_retriever_start(self, query: str, *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any) -> Any:
        """Retriever runs are not traced; warn so the gap is visible in logs."""
        _logger.warning("Starting retriever, but it will not be traced in braintrust (unsupported)")

    def on_retriever_end(
        self, response: list[Document], *, run_id: UUID, parent_run_id: UUID | None = None, **kwargs: Any
    ) -> Any:
        """No-op: retriever runs are not traced (see on_retriever_start)."""
        pass