@@ -855,3 +855,114 @@ async def test_activity_tool_supports_complex_inputs_via_adk(client: Client):
855855 ),
856856 "annotate_trip" : "SFO->LAX:3" ,
857857 }
858+
859+
def litellm_agent(model_name: str) -> Agent:
    """Build a minimal ADK agent whose model calls are routed through TemporalModel."""
    temporal_model = TemporalModel(model_name)
    return Agent(name="litellm_test_agent", model=temporal_model)
865+
866+
@workflow.defn
class LiteLlmWorkflow:
    """Temporal workflow that drives a litellm-backed ADK agent to completion."""

    @workflow.run
    async def run(self, prompt: str, model_name: str) -> Event | None:
        """Run the agent on *prompt* and return the last event it emitted, if any."""
        runner = InMemoryRunner(
            agent=litellm_agent(model_name),
            app_name="litellm_test_app",
        )
        session = await runner.session_service.create_session(
            app_name="litellm_test_app", user_id="test"
        )

        user_message = types.Content(role="user", parts=[types.Part(text=prompt)])
        final_event: Event | None = None
        # Aclosing guarantees the agent's async generator is closed even if
        # iteration is interrupted mid-stream.
        async with Aclosing(
            runner.run_async(
                user_id="test",
                session_id=session.id,
                new_message=user_message,
            )
        ) as event_stream:
            async for event in event_stream:
                final_event = event

        return final_event
894+
895+
@pytest.mark.asyncio
async def test_litellm_model(client: Client):
    """Test that a litellm-backed model works with TemporalModel through a full Temporal workflow.

    A fake litellm provider is registered under the "fake" prefix so the test
    exercises the full agent/workflow pipeline without any network calls, and
    all global registry state is restored afterwards.
    """
    import litellm as litellm_module
    from google.adk.models.lite_llm import LiteLlm
    from google.adk.models.registry import _llm_registry_dict
    from litellm import ModelResponse
    from litellm.llms.custom_llm import CustomLLM

    class FakeLiteLlmProvider(CustomLLM):
        """A fake litellm provider that returns canned responses locally."""

        def _make_response(self, model: str) -> ModelResponse:
            # Minimal OpenAI-style completion payload so LiteLlm can parse it.
            return ModelResponse(
                choices=[
                    {
                        "message": {
                            "content": "hello from litellm",
                            "role": "assistant",
                        },
                        "index": 0,
                        "finish_reason": "stop",
                    }
                ],
                model=model,
            )

        def completion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            model = args[0] if args else kwargs.get("model", "unknown")
            return self._make_response(model)

        async def acompletion(self, *args: Any, **kwargs: Any) -> ModelResponse:
            model = args[0] if args else kwargs.get("model", "unknown")
            return self._make_response(model)

    # Snapshot the existing provider map so cleanup restores it rather than
    # clobbering providers registered by other tests/fixtures.
    original_provider_map = litellm_module.custom_provider_map

    try:
        # Register our fake provider with litellm. This happens INSIDE the try
        # so any failure below still triggers the restore in `finally`.
        litellm_module.custom_provider_map = [
            {"provider": "fake", "custom_handler": FakeLiteLlmProvider()}
        ]
        # Directly register LiteLlm for "fake/.*" model names in ADK's registry backing dict
        _llm_registry_dict[r"fake/.*"] = LiteLlm
        # Clear the resolve LRU cache so the new pattern is picked up
        LLMRegistry.resolve.cache_clear()

        new_config = client.config()
        new_config["plugins"] = [GoogleAdkPlugin()]
        client = Client(**new_config)

        async with Worker(
            client,
            task_queue="adk-task-queue-litellm",
            workflows=[LiteLlmWorkflow],
            max_cached_workflows=0,
        ):
            handle = await client.start_workflow(
                LiteLlmWorkflow.run,
                args=["Say hello", "fake/test-model"],
                id=f"litellm-agent-workflow-{uuid.uuid4()}",
                task_queue="adk-task-queue-litellm",
                execution_timeout=timedelta(seconds=60),
            )
            result = await handle.result()

        assert result is not None
        assert result.content is not None
        assert result.content.parts is not None
        assert result.content.parts[0].text == "hello from litellm"
    finally:
        # Clean up registry state
        _llm_registry_dict.pop(r"fake/.*", None)
        LLMRegistry.resolve.cache_clear()
        litellm_module.custom_provider_map = original_provider_map
0 commit comments