From ff1eb2532338aa1711675728cd23fdb8fa42d4a3 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 16 Jun 2026 17:51:13 -0700 Subject: [PATCH] fix: implement bidi_stream_query in AdkApp register_operations The bidi_stream_query method was documented but not implemented in the GA version of AdkApp. This change ports the bidi_stream_query method from the preview reasoning_engines template to the GA agent_engines template in both vertexai and agentplatform namespaces, and registers it in register_operations. It also adds the method schema to the deployment class methods in ADK. Close #5611 FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/6876 from googleapis:release-please--branches--main a0975040f2c40acd693f2e9d997c5ab44c516c5d PiperOrigin-RevId: 933396959 --- agentplatform/agent_engines/templates/adk.py | 87 ++++++++++++++++++++ vertexai/agent_engines/templates/adk.py | 87 ++++++++++++++++++++ 2 files changed, 174 insertions(+) diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index 91737e688a..e455d471fe 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -1449,6 +1449,92 @@ async def streaming_agent_run_with_events(self, request_json: str): logging_enabled=bool(self._telemetry_enabled()), ) + async def bidi_stream_query( + self, + request_queue: Any, + ) -> AsyncIterable[Any]: + """Bidi streaming query the ADK application. + + Args: + request_queue: + The queue of requests to stream responses for, with the type of + asyncio.Queue[Any]. + + Raises: + TypeError: If the request_queue is not an asyncio.Queue instance. + ValueError: If the first request does not have a user_id. + ValidationError: If failed to convert to LiveRequest. + + Yields: + The stream responses of querying the ADK application. + """ + from google.adk.agents.live_request_queue import LiveRequest + from google.adk.agents.live_request_queue import LiveRequestQueue + from agentplatform._genai import _agent_engines_utils + + # Manual type check needed as Pydantic doesn't support asyncio.Queue. + if not isinstance(request_queue, asyncio.Queue): + raise TypeError("request_queue must be an asyncio.Queue instance.") + + first_request = await request_queue.get() + user_id = first_request.get("user_id") + if not user_id: + raise ValueError("The first request must have a user_id.") + + session_id = first_request.get("session_id") + run_config = first_request.get("run_config") + first_live_request = first_request.get("live_request") + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + state = first_request.get("state") + session = await self.async_create_session(user_id=user_id, state=state) + session_id = session["id"] if isinstance(session, dict) else session.id + run_config = _validate_run_config(run_config) + + live_request_queue = LiveRequestQueue() + + if first_live_request and isinstance(first_live_request, Dict): + live_request_queue.send(LiveRequest.model_validate(first_live_request)) + + # Forwards live requests to the agent. + async def _forward_requests(): + while True: + request = await request_queue.get() + live_request = LiveRequest.model_validate(request) + live_request_queue.send(live_request) + + # Forwards events to the client. + async def _forward_events(): + if run_config: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + run_config=run_config, + ) + else: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + ) + async for event in events_async: + yield _agent_engines_utils.dump_event_for_json(event) + + requests_task = asyncio.create_task(_forward_requests()) + + try: + async for event in _forward_events(): + yield event + finally: + requests_task.cancel() + try: + await requests_task + except asyncio.CancelledError: + pass + async def async_get_session( self, *, @@ -2088,6 +2174,7 @@ def register_operations(self) -> Dict[str, List[str]]: "async_stream_query", "streaming_agent_run_with_events", ], + "bidi_stream": ["bidi_stream_query"], } def _telemetry_enabled(self) -> Optional[bool]: diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 58c3025285..054db4b2da 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -1409,6 +1409,92 @@ async def streaming_agent_run_with_events(self, request_json: str): logging_enabled=bool(self._telemetry_enabled()), ) + async def bidi_stream_query( + self, + request_queue: Any, + ) -> AsyncIterable[Any]: + """Bidi streaming query the ADK application. + + Args: + request_queue: + The queue of requests to stream responses for, with the type of + asyncio.Queue[Any]. + + Raises: + TypeError: If the request_queue is not an asyncio.Queue instance. + ValueError: If the first request does not have a user_id. + ValidationError: If failed to convert to LiveRequest. + + Yields: + The stream responses of querying the ADK application. + """ + from google.adk.agents.live_request_queue import LiveRequest + from google.adk.agents.live_request_queue import LiveRequestQueue + from vertexai.agent_engines import _utils + + # Manual type check needed as Pydantic doesn't support asyncio.Queue. + if not isinstance(request_queue, asyncio.Queue): + raise TypeError("request_queue must be an asyncio.Queue instance.") + + first_request = await request_queue.get() + user_id = first_request.get("user_id") + if not user_id: + raise ValueError("The first request must have a user_id.") + + session_id = first_request.get("session_id") + run_config = first_request.get("run_config") + first_live_request = first_request.get("live_request") + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + state = first_request.get("state") + session = await self.async_create_session(user_id=user_id, state=state) + session_id = session["id"] if isinstance(session, dict) else session.id + run_config = _validate_run_config(run_config) + + live_request_queue = LiveRequestQueue() + + if first_live_request and isinstance(first_live_request, Dict): + live_request_queue.send(LiveRequest.model_validate(first_live_request)) + + # Forwards live requests to the agent. + async def _forward_requests(): + while True: + request = await request_queue.get() + live_request = LiveRequest.model_validate(request) + live_request_queue.send(live_request) + + # Forwards events to the client. + async def _forward_events(): + if run_config: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + run_config=run_config, + ) + else: + events_async = self._tmpl_attrs.get("runner").run_live( + user_id=user_id, + session_id=session_id, + live_request_queue=live_request_queue, + ) + async for event in events_async: + yield _utils.dump_event_for_json(event) + + requests_task = asyncio.create_task(_forward_requests()) + + try: + async for event in _forward_events(): + yield event + finally: + requests_task.cancel() + try: + await requests_task + except asyncio.CancelledError: + pass + async def async_get_session( self, *, @@ -1784,6 +1870,7 @@ def register_operations(self) -> Dict[str, List[str]]: "async_stream_query", "streaming_agent_run_with_events", ], + "bidi_stream": ["bidi_stream_query"], } def _telemetry_enabled(self) -> Optional[bool]: