Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke")
agent.query(input="test query")
mocks.assert_has_calls(
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
[
mock.call.invoke.invoke(
input={"input": "test query", "messages": [("user", "test query")]},
config=None,
)
]
)

def test_stream_query(self, langchain_dump_mock):
Expand All @@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
agent._tmpl_attrs["runnable"].stream.return_value = []
list(agent.stream_query(input="test stream query"))
agent._tmpl_attrs["runnable"].stream.assert_called_once_with(
input={"input": "test stream query"},
input={
"input": "test stream query",
"messages": [("user", "test stream query")],
},
config=None,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
agent.query(input="test query")
mocks.assert_has_calls(
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
[
mock.call.invoke.invoke(
input={"input": "test query", "messages": [("user", "test query")]},
config=None,
)
]
)

def test_stream_query(self, langchain_dump_mock):
Expand All @@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
agent._runnable.stream.return_value = []
list(agent.stream_query(input="test stream query"))
agent._runnable.stream.assert_called_once_with(
input={"input": "test stream query"},
input={
"input": "test stream query",
"messages": [("user", "test stream query")],
},
config=None,
)

Expand Down
59 changes: 41 additions & 18 deletions vertexai/agent_engines/templates/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,16 @@
BaseLanguageModel = Any

try:
from langchain_google_vertexai.functions_utils import _ToolsType
from langchain_google_genai.functions_utils import _ToolsType

_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType

_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any

try:
from opentelemetry.sdk import trace
Expand Down Expand Up @@ -87,17 +92,29 @@ def _default_model_builder(
Returns:
BaseLanguageModel: The language model.
"""
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
try:
from langchain_google_genai import ChatGoogleGenerativeAI

model = ChatGoogleGenerativeAI(
model=model_name,
project=project,
location=location,
vertexai=True,
**model_kwargs,
)
return model
except ImportError:
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model


def _default_runnable_builder(
Expand Down Expand Up @@ -554,13 +571,16 @@ def query(
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load.dump import dumpd

if isinstance(input, str):
input = {"input": input}
input = {"input": input, "messages": [("user", input)]}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
return langchain_load_dump.dumpd(
return dumpd(
self._tmpl_attrs.get("runnable").invoke(
input=input, config=config, **kwargs
)
Expand All @@ -587,18 +607,21 @@ def stream_query(
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load.dump import dumpd

if isinstance(input, str):
input = {"input": input}
input = {"input": input, "messages": [("user", input)]}
if not self._tmpl_attrs.get("runnable"):
self.set_up()
for chunk in self._tmpl_attrs.get("runnable").stream(
input=input,
config=config,
**kwargs,
):
yield langchain_load_dump.dumpd(chunk)
yield dumpd(chunk)

def get_state_history(
self,
Expand Down
61 changes: 41 additions & 20 deletions vertexai/preview/reasoning_engines/templates/langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,16 @@
RunnableSerializable = Any

try:
from langchain_google_vertexai.functions_utils import _ToolsType
from langchain_google_genai.functions_utils import _ToolsType

_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any
try:
from langchain_google_vertexai.functions_utils import _ToolsType

_ToolLike = _ToolsType
except ImportError:
_ToolLike = Any

try:
from opentelemetry.sdk import trace
Expand Down Expand Up @@ -95,17 +100,29 @@ def _default_model_builder(
Returns:
BaseLanguageModel: The language model.
"""
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

model_kwargs = model_kwargs or {}
current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model
try:
from langchain_google_genai import ChatGoogleGenerativeAI

model = ChatGoogleGenerativeAI(
model=model_name,
project=project,
location=location,
vertexai=True,
**model_kwargs,
)
return model
except ImportError:
import vertexai
from google.cloud.aiplatform import initializer
from langchain_google_vertexai import ChatVertexAI

current_project = initializer.global_config.project
current_location = initializer.global_config.location
vertexai.init(project=project, location=location)
model = ChatVertexAI(model_name=model_name, **model_kwargs)
vertexai.init(project=current_project, location=current_location)
return model


def _default_runnable_builder(
Expand Down Expand Up @@ -541,15 +558,16 @@ def query(
Returns:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load.dump import dumpd

if isinstance(input, str):
input = {"input": input}
input = {"input": input, "messages": [("user", input)]}
if not self._runnable:
self.set_up()
return langchain_load_dump.dumpd(
self._runnable.invoke(input=input, config=config, **kwargs)
)
return dumpd(self._runnable.invoke(input=input, config=config, **kwargs))

def stream_query(
self,
Expand All @@ -572,14 +590,17 @@ def stream_query(
Yields:
The output of querying the Agent with the given input and config.
"""
from langchain.load import dump as langchain_load_dump
try:
from langchain_core.load import dumpd
except ImportError:
from langchain.load.dump import dumpd

if isinstance(input, str):
input = {"input": input}
input = {"input": input, "messages": [("user", input)]}
if not self._runnable:
self.set_up()
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
yield langchain_load_dump.dumpd(chunk)
yield dumpd(chunk)

def get_state_history(
self,
Expand Down
Loading