Skip to content

Commit 6b131d1

Browse files
vertex-sdk-botcopybara-github
authored andcommitted
feat: Add google.genai.types.Content as an allowed message type to ADK's stream_query method
PiperOrigin-RevId: 754073499
1 parent 1cbe028 commit 6b131d1

File tree

2 files changed

+42
-3
lines changed

2 files changed

+42
-3
lines changed

tests/unit/vertex_adk/test_reasoning_engine_templates_adk.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from unittest import mock
1818

1919
from google import auth
20+
from google.genai import types
2021
import vertexai
2122
from google.cloud.aiplatform import initializer
2223
from vertexai.preview import reasoning_engines
@@ -172,6 +173,28 @@ def test_stream_query(self):
172173
)
173174
assert len(events) == 1
174175

176+
def test_stream_query_with_content(self):
177+
app = reasoning_engines.AdkApp(
178+
agent=Agent(name="test_agent", model=_TEST_MODEL)
179+
)
180+
assert app._tmpl_attrs.get("runner") is None
181+
app.set_up()
182+
app._tmpl_attrs["runner"] = _MockRunner()
183+
events = list(
184+
app.stream_query(
185+
user_id="test_user_id",
186+
message=types.Content(
187+
role="user",
188+
parts=[
189+
types.Part(
190+
text="test message with content",
191+
)
192+
],
193+
).model_dump(),
194+
)
195+
)
196+
assert len(events) == 1
197+
175198
def test_streaming_agent_run_with_events(self):
176199
app = reasoning_engines.AdkApp(
177200
agent=Agent(name="test_agent", model=_TEST_MODEL)

vertexai/preview/reasoning_engines/templates/adk.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,17 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
#
16-
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
16+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
17+
1718

1819
if TYPE_CHECKING:
20+
try:
21+
from google.genai import types
22+
23+
ContentDict = types.Content
24+
except (ImportError, AttributeError):
25+
ContentDict = Dict
26+
1927
try:
2028
from google.adk.events.event import Event
2129

@@ -442,7 +450,7 @@ def set_up(self):
442450
def stream_query(
443451
self,
444452
*,
445-
message: str,
453+
message: Union[str, "ContentDict"],
446454
user_id: str,
447455
session_id: Optional[str] = None,
448456
**kwargs,
@@ -466,7 +474,15 @@ def stream_query(
466474
"""
467475
from google.genai import types
468476

469-
content = types.Content(role="user", parts=[types.Part(text=message)])
477+
if isinstance(message, Dict):
478+
content = types.Content.model_validate(message)
479+
elif isinstance(message, str):
480+
content = types.Content(role="user", parts=[types.Part(text=message)])
481+
else:
482+
raise TypeError(
483+
"message must be a string or a dictionary representing a Content object."
484+
)
485+
470486
if not self._tmpl_attrs.get("runner"):
471487
self.set_up()
472488
if not session_id:

0 commit comments

Comments
 (0)