Skip to content

Commit 2db8698

Browse files
vertex-sdk-bot authored and copybara-github committed
chore: migrate legacy langgraph imports to maintain compatibility
PiperOrigin-RevId: 911224637
1 parent f32233c commit 2db8698

4 files changed

Lines changed: 102 additions & 42 deletions

File tree

tests/unit/vertex_langchain/test_agent_engine_templates_langgraph.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[
212+
mock.call.invoke.invoke(
213+
input={"input": "test query", "messages": [("user", "test query")]},
214+
config=None,
215+
)
216+
]
212217
)
213218

214219
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
217222
agent._tmpl_attrs["runnable"].stream.return_value = []
218223
list(agent.stream_query(input="test stream query"))
219224
agent._tmpl_attrs["runnable"].stream.assert_called_once_with(
220-
input={"input": "test stream query"},
225+
input={
226+
"input": "test stream query",
227+
"messages": [("user", "test stream query")],
228+
},
221229
config=None,
222230
)
223231

tests/unit/vertex_langchain/test_reasoning_engine_templates_langgraph.py

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -208,7 +208,12 @@ def test_query(self, langchain_dump_mock):
208208
mocks.attach_mock(mock=agent._runnable, attribute="invoke")
209209
agent.query(input="test query")
210210
mocks.assert_has_calls(
211-
[mock.call.invoke.invoke(input={"input": "test query"}, config=None)]
211+
[
212+
mock.call.invoke.invoke(
213+
input={"input": "test query", "messages": [("user", "test query")]},
214+
config=None,
215+
)
216+
]
212217
)
213218

214219
def test_stream_query(self, langchain_dump_mock):
@@ -217,7 +222,10 @@ def test_stream_query(self, langchain_dump_mock):
217222
agent._runnable.stream.return_value = []
218223
list(agent.stream_query(input="test stream query"))
219224
agent._runnable.stream.assert_called_once_with(
220-
input={"input": "test stream query"},
225+
input={
226+
"input": "test stream query",
227+
"messages": [("user", "test stream query")],
228+
},
221229
config=None,
222230
)
223231

vertexai/agent_engines/templates/langgraph.py

Lines changed: 41 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -34,11 +34,16 @@
3434
BaseLanguageModel = Any
3535

3636
try:
37-
from langchain_google_vertexai.functions_utils import _ToolsType
37+
from langchain_google_genai.functions_utils import _ToolsType
3838

3939
_ToolLike = _ToolsType
4040
except ImportError:
41-
_ToolLike = Any
41+
try:
42+
from langchain_google_vertexai.functions_utils import _ToolsType
43+
44+
_ToolLike = _ToolsType
45+
except ImportError:
46+
_ToolLike = Any
4247

4348
try:
4449
from opentelemetry.sdk import trace
@@ -87,17 +92,29 @@ def _default_model_builder(
8792
Returns:
8893
BaseLanguageModel: The language model.
8994
"""
90-
import vertexai
91-
from google.cloud.aiplatform import initializer
92-
from langchain_google_vertexai import ChatVertexAI
93-
9495
model_kwargs = model_kwargs or {}
95-
current_project = initializer.global_config.project
96-
current_location = initializer.global_config.location
97-
vertexai.init(project=project, location=location)
98-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
99-
vertexai.init(project=current_project, location=current_location)
100-
return model
96+
try:
97+
from langchain_google_genai import ChatGoogleGenerativeAI
98+
99+
model = ChatGoogleGenerativeAI(
100+
model=model_name,
101+
project=project,
102+
location=location,
103+
vertexai=True,
104+
**model_kwargs,
105+
)
106+
return model
107+
except ImportError:
108+
import vertexai
109+
from google.cloud.aiplatform import initializer
110+
from langchain_google_vertexai import ChatVertexAI
111+
112+
current_project = initializer.global_config.project
113+
current_location = initializer.global_config.location
114+
vertexai.init(project=project, location=location)
115+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
116+
vertexai.init(project=current_project, location=current_location)
117+
return model
101118

102119

103120
def _default_runnable_builder(
@@ -554,13 +571,16 @@ def query(
554571
Returns:
555572
The output of querying the Agent with the given input and config.
556573
"""
557-
from langchain.load import dump as langchain_load_dump
574+
try:
575+
from langchain_core.load import dumpd
576+
except ImportError:
577+
from langchain.load.dump import dumpd
558578

559579
if isinstance(input, str):
560-
input = {"input": input}
580+
input = {"input": input, "messages": [("user", input)]}
561581
if not self._tmpl_attrs.get("runnable"):
562582
self.set_up()
563-
return langchain_load_dump.dumpd(
583+
return dumpd(
564584
self._tmpl_attrs.get("runnable").invoke(
565585
input=input, config=config, **kwargs
566586
)
@@ -587,18 +607,21 @@ def stream_query(
587607
Yields:
588608
The output of querying the Agent with the given input and config.
589609
"""
590-
from langchain.load import dump as langchain_load_dump
610+
try:
611+
from langchain_core.load import dumpd
612+
except ImportError:
613+
from langchain.load.dump import dumpd
591614

592615
if isinstance(input, str):
593-
input = {"input": input}
616+
input = {"input": input, "messages": [("user", input)]}
594617
if not self._tmpl_attrs.get("runnable"):
595618
self.set_up()
596619
for chunk in self._tmpl_attrs.get("runnable").stream(
597620
input=input,
598621
config=config,
599622
**kwargs,
600623
):
601-
yield langchain_load_dump.dumpd(chunk)
624+
yield dumpd(chunk)
602625

603626
def get_state_history(
604627
self,

vertexai/preview/reasoning_engines/templates/langgraph.py

Lines changed: 41 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -42,11 +42,16 @@
4242
RunnableSerializable = Any
4343

4444
try:
45-
from langchain_google_vertexai.functions_utils import _ToolsType
45+
from langchain_google_genai.functions_utils import _ToolsType
4646

4747
_ToolLike = _ToolsType
4848
except ImportError:
49-
_ToolLike = Any
49+
try:
50+
from langchain_google_vertexai.functions_utils import _ToolsType
51+
52+
_ToolLike = _ToolsType
53+
except ImportError:
54+
_ToolLike = Any
5055

5156
try:
5257
from opentelemetry.sdk import trace
@@ -95,17 +100,29 @@ def _default_model_builder(
95100
Returns:
96101
BaseLanguageModel: The language model.
97102
"""
98-
import vertexai
99-
from google.cloud.aiplatform import initializer
100-
from langchain_google_vertexai import ChatVertexAI
101-
102103
model_kwargs = model_kwargs or {}
103-
current_project = initializer.global_config.project
104-
current_location = initializer.global_config.location
105-
vertexai.init(project=project, location=location)
106-
model = ChatVertexAI(model_name=model_name, **model_kwargs)
107-
vertexai.init(project=current_project, location=current_location)
108-
return model
104+
try:
105+
from langchain_google_genai import ChatGoogleGenerativeAI
106+
107+
model = ChatGoogleGenerativeAI(
108+
model=model_name,
109+
project=project,
110+
location=location,
111+
vertexai=True,
112+
**model_kwargs,
113+
)
114+
return model
115+
except ImportError:
116+
import vertexai
117+
from google.cloud.aiplatform import initializer
118+
from langchain_google_vertexai import ChatVertexAI
119+
120+
current_project = initializer.global_config.project
121+
current_location = initializer.global_config.location
122+
vertexai.init(project=project, location=location)
123+
model = ChatVertexAI(model_name=model_name, **model_kwargs)
124+
vertexai.init(project=current_project, location=current_location)
125+
return model
109126

110127

111128
def _default_runnable_builder(
@@ -541,15 +558,16 @@ def query(
541558
Returns:
542559
The output of querying the Agent with the given input and config.
543560
"""
544-
from langchain.load import dump as langchain_load_dump
561+
try:
562+
from langchain_core.load import dumpd
563+
except ImportError:
564+
from langchain.load.dump import dumpd
545565

546566
if isinstance(input, str):
547-
input = {"input": input}
567+
input = {"input": input, "messages": [("user", input)]}
548568
if not self._runnable:
549569
self.set_up()
550-
return langchain_load_dump.dumpd(
551-
self._runnable.invoke(input=input, config=config, **kwargs)
552-
)
570+
return dumpd(self._runnable.invoke(input=input, config=config, **kwargs))
553571

554572
def stream_query(
555573
self,
@@ -572,14 +590,17 @@ def stream_query(
572590
Yields:
573591
The output of querying the Agent with the given input and config.
574592
"""
575-
from langchain.load import dump as langchain_load_dump
593+
try:
594+
from langchain_core.load import dumpd
595+
except ImportError:
596+
from langchain.load.dump import dumpd
576597

577598
if isinstance(input, str):
578-
input = {"input": input}
599+
input = {"input": input, "messages": [("user", input)]}
579600
if not self._runnable:
580601
self.set_up()
581602
for chunk in self._runnable.stream(input=input, config=config, **kwargs):
582-
yield langchain_load_dump.dumpd(chunk)
603+
yield dumpd(chunk)
583604

584605
def get_state_history(
585606
self,

0 commit comments

Comments (0)