11import logging
22from functools import wraps
3- from typing import Any , Callable , Mapping , Optional , Sequence
3+ from typing import Any , Callable , Mapping , Optional , Sequence , TypeVar
4+ from typing_extensions import ParamSpec
45
56from opentelemetry .trace import Span , Tracer
67from opentelemetry import context as context_api
78import requests
89
910from humanloop .base_client import BaseHumanloop
1011from humanloop .context import get_trace_id , set_trace_id
12+ from humanloop .types .chat_message import ChatMessage
1113from humanloop .utilities .helpers import bind_args
1214from humanloop .eval_utils .types import File
1315from humanloop .otel .constants import (
2123logger = logging .getLogger ("humanloop.sdk" )
2224
2325
26+ P = ParamSpec ("P" )
27+ R = TypeVar ("R" )
28+
29+
2430def flow (
2531 client : "BaseHumanloop" ,
2632 opentelemetry_tracer : Tracer ,
@@ -29,19 +35,19 @@ def flow(
2935):
3036 flow_kernel = {"attributes" : attributes or {}}
3137
32- def decorator (func : Callable ) :
38+ def decorator (func : Callable [ P , R ]) -> Callable [ P , R ] :
3339 decorator_path = path or func .__name__
3440 file_type = "flow"
3541
3642 @wraps (func )
37- def wrapper (* args : Sequence [ Any ] , ** kwargs : Mapping [ str , Any ] ) -> Any :
43+ def wrapper (* args : P . args , ** kwargs : P . kwargs ) -> Optional [ R ] :
3844 span : Span
3945 with opentelemetry_tracer .start_as_current_span ("humanloop.flow" ) as span : # type: ignore
4046 trace_id = get_trace_id ()
4147 args_to_func = bind_args (func , args , kwargs )
4248
4349 # Create the trace ahead so we have a parent ID to reference
44- log_inputs = {
50+ init_log_inputs = {
4551 "inputs" : {k : v for k , v in args_to_func .items () if k != "messages" },
4652 "messages" : args_to_func .get ("messages" ),
4753 "trace_parent_id" : trace_id ,
@@ -53,7 +59,7 @@ def wrapper(*args: Sequence[Any], **kwargs: Mapping[str, Any]) -> Any:
5359 "path" : path ,
5460 "flow" : flow_kernel ,
5561 "log_status" : "incomplete" ,
56- ** log_inputs ,
62+ ** init_log_inputs ,
5763 },
5864 ).json ()
5965 # log = client.flows.log(
@@ -66,34 +72,37 @@ def wrapper(*args: Sequence[Any], **kwargs: Mapping[str, Any]) -> Any:
6672 span .set_attribute (HUMANLOOP_PATH_KEY , decorator_path )
6773 span .set_attribute (HUMANLOOP_FILE_TYPE_KEY , file_type )
6874
69- # Call the decorated function
75+ func_output : Optional [R ]
76+ log_output : str
77+ log_error : Optional [str ]
78+ log_output_message : ChatMessage
7079 try :
71- output = func (* args , ** kwargs )
80+ func_output = func (* args , ** kwargs )
7281 if (
73- isinstance (output , dict )
74- and len (output .keys ()) == 2
75- and "role" in output
76- and "content" in output
82+ isinstance (func_output , dict )
83+ and len (func_output .keys ()) == 2
84+ and "role" in func_output
85+ and "content" in func_output
7786 ):
78- output_message = output
79- output = None
87+ log_output_message = ChatMessage ( ** func_output )
88+ log_output = None
8089 else :
81- output = process_output (func = func , output = output )
82- output_message = None
83- error = None
90+ log_output = process_output (func = func , output = func_output )
91+ log_output_message = None
92+ log_error = None
8493 except Exception as e :
8594 logger .error (f"Error calling { func .__name__ } : { e } " )
8695 output = None
87- output_message = None
88- error = str (e )
96+ log_output_message = None
97+ log_error = str (e )
8998
9099 flow_log = {
91100 "inputs" : {k : v for k , v in args_to_func .items () if k != "messages" },
92101 "messages" : args_to_func .get ("messages" ),
93102 "log_status" : "complete" ,
94- "output" : output ,
95- "error" : error ,
96- "output_message" : output_message ,
103+ "output" : log_output ,
104+ "error" : log_error ,
105+ "output_message" : log_output_message ,
97106 "id" : init_log ["id" ],
98107 }
99108
0 commit comments