Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions agentplatform/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1449,6 +1449,92 @@ async def streaming_agent_run_with_events(self, request_json: str):
logging_enabled=bool(self._telemetry_enabled()),
)

async def bidi_stream_query(
self,
request_queue: Any,
) -> AsyncIterable[Any]:
"""Bidi streaming query the ADK application.

Args:
request_queue:
The queue of requests to stream responses for, with the type of
asyncio.Queue[Any].

Raises:
TypeError: If the request_queue is not an asyncio.Queue instance.
ValueError: If the first request does not have a user_id.
ValidationError: If failed to convert to LiveRequest.

Yields:
The stream responses of querying the ADK application.
"""
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.live_request_queue import LiveRequestQueue
from agentplatform._genai import _agent_engines_utils

# Manual type check needed as Pydantic doesn't support asyncio.Queue.
if not isinstance(request_queue, asyncio.Queue):
raise TypeError("request_queue must be an asyncio.Queue instance.")

first_request = await request_queue.get()
user_id = first_request.get("user_id")
if not user_id:
raise ValueError("The first request must have a user_id.")

session_id = first_request.get("session_id")
run_config = first_request.get("run_config")
first_live_request = first_request.get("live_request")

if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
state = first_request.get("state")
session = await self.async_create_session(user_id=user_id, state=state)
session_id = session["id"] if isinstance(session, dict) else session.id
run_config = _validate_run_config(run_config)

live_request_queue = LiveRequestQueue()

if first_live_request and isinstance(first_live_request, Dict):
live_request_queue.send(LiveRequest.model_validate(first_live_request))

# Forwards live requests to the agent.
async def _forward_requests():
while True:
request = await request_queue.get()
live_request = LiveRequest.model_validate(request)
live_request_queue.send(live_request)

# Forwards events to the client.
async def _forward_events():
if run_config:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
run_config=run_config,
)
else:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
)
async for event in events_async:
yield _agent_engines_utils.dump_event_for_json(event)

requests_task = asyncio.create_task(_forward_requests())

try:
async for event in _forward_events():
yield event
finally:
requests_task.cancel()
try:
await requests_task
except asyncio.CancelledError:
pass

async def async_get_session(
self,
*,
Expand Down Expand Up @@ -2088,6 +2174,7 @@ def register_operations(self) -> Dict[str, List[str]]:
"async_stream_query",
"streaming_agent_run_with_events",
],
"bidi_stream": ["bidi_stream_query"],
}

def _telemetry_enabled(self) -> Optional[bool]:
Expand Down
87 changes: 87 additions & 0 deletions vertexai/agent_engines/templates/adk.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,92 @@ async def streaming_agent_run_with_events(self, request_json: str):
logging_enabled=bool(self._telemetry_enabled()),
)

async def bidi_stream_query(
self,
request_queue: Any,
) -> AsyncIterable[Any]:
"""Bidi streaming query the ADK application.

Args:
request_queue:
The queue of requests to stream responses for, with the type of
asyncio.Queue[Any].

Raises:
TypeError: If the request_queue is not an asyncio.Queue instance.
ValueError: If the first request does not have a user_id.
ValidationError: If failed to convert to LiveRequest.

Yields:
The stream responses of querying the ADK application.
"""
from google.adk.agents.live_request_queue import LiveRequest
from google.adk.agents.live_request_queue import LiveRequestQueue
from vertexai.agent_engines import _utils

# Manual type check needed as Pydantic doesn't support asyncio.Queue.
if not isinstance(request_queue, asyncio.Queue):
raise TypeError("request_queue must be an asyncio.Queue instance.")

first_request = await request_queue.get()
user_id = first_request.get("user_id")
if not user_id:
raise ValueError("The first request must have a user_id.")

session_id = first_request.get("session_id")
run_config = first_request.get("run_config")
first_live_request = first_request.get("live_request")

if not self._tmpl_attrs.get("runner"):
self.set_up()
if not session_id:
state = first_request.get("state")
session = await self.async_create_session(user_id=user_id, state=state)
session_id = session["id"] if isinstance(session, dict) else session.id
run_config = _validate_run_config(run_config)

live_request_queue = LiveRequestQueue()

if first_live_request and isinstance(first_live_request, Dict):
live_request_queue.send(LiveRequest.model_validate(first_live_request))

# Forwards live requests to the agent.
async def _forward_requests():
while True:
request = await request_queue.get()
live_request = LiveRequest.model_validate(request)
live_request_queue.send(live_request)

# Forwards events to the client.
async def _forward_events():
if run_config:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
run_config=run_config,
)
else:
events_async = self._tmpl_attrs.get("runner").run_live(
user_id=user_id,
session_id=session_id,
live_request_queue=live_request_queue,
)
async for event in events_async:
yield _utils.dump_event_for_json(event)

requests_task = asyncio.create_task(_forward_requests())

try:
async for event in _forward_events():
yield event
finally:
requests_task.cancel()
try:
await requests_task
except asyncio.CancelledError:
pass

async def async_get_session(
self,
*,
Expand Down Expand Up @@ -1784,6 +1870,7 @@ def register_operations(self) -> Dict[str, List[str]]:
"async_stream_query",
"streaming_agent_run_with_events",
],
"bidi_stream": ["bidi_stream_query"],
}

def _telemetry_enabled(self) -> Optional[bool]:
Expand Down
Loading