From 9f4ff20747df9f92cd647107ef53113080f40893 Mon Sep 17 00:00:00 2001 From: A Vertex SDK engineer Date: Tue, 16 Jun 2026 17:59:14 -0700 Subject: [PATCH] fix: support GetSessionConfig in AdkApp templates and forward it in Runner.run_async FUTURE_COPYBARA_INTEGRATE_REVIEW=https://github.com/googleapis/python-aiplatform/pull/6876 from googleapis:release-please--branches--main a0975040f2c40acd693f2e9d997c5ab44c516c5d PiperOrigin-RevId: 933400638 --- agentplatform/agent_engines/templates/adk.py | 1016 +++++++++-------- .../test_agent_engine_templates_adk.py | 35 + vertexai/agent_engines/templates/adk.py | 924 +++++++-------- .../reasoning_engines/templates/adk.py | 1014 ++++++++-------- 4 files changed, 1518 insertions(+), 1471 deletions(-) diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index 91737e688a..11c37bf5d0 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -683,11 +683,11 @@ def _use_client_cert_effective() -> bool: class AdkApp: - """An ADK Application.""" + """An ADK Application.""" - agent_framework = "google-adk" + agent_framework = "google-adk" - def __init__( + def __init__( self, *, app: "App" = None, @@ -703,7 +703,7 @@ def __init__( ] = None, instrumentor_builder: Optional[Callable[..., Any]] = None, ): - """An ADK Application. + """An ADK Application. See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/adk for details on how to develop ADK applications on Agent Engine. @@ -743,34 +743,34 @@ def __init__( If not provided, a default instrumentor builder will be used. This parameter is ignored if `enable_tracing` is False. """ - import os - from google.cloud.aiplatform import initializer + import os + from google.cloud.aiplatform import initializer - adk_version = get_adk_version() - if not is_version_sufficient("1.5.0"): - msg = ( + adk_version = get_adk_version() + if not is_version_sufficient("1.5.0"): + msg = ( f"Unsupported google-adk version: {adk_version}, please use " "google-adk>=1.5.0 for AdkApp deployment on Agent Engine." ) - raise ValueError(msg) + raise ValueError(msg) - if not agent and not app: - raise ValueError("One of `agent` or `app` must be provided.") - if app: - if app_name: - raise ValueError( + if not agent and not app: + raise ValueError("One of `agent` or `app` must be provided.") + if app: + if app_name: + raise ValueError( "When app is provided, app_name should not be provided, " "since it will be derived from app.name." ) - if agent: - raise ValueError("When app is provided, agent should not be provided.") - if plugins: - raise ValueError( + if agent: + raise ValueError("When app is provided, agent should not be provided.") + if plugins: + raise ValueError( "When app is provided, plugins should not be provided and" " should be provided in the app instead." ) - self._tmpl_attrs: Dict[str, Any] = { + self._tmpl_attrs: Dict[str, Any] = { "project": initializer.global_config.project, "location": initializer.global_config.location, "agent": agent, @@ -788,108 +788,108 @@ def __init__( ), } - def _serialize(self, obj: Any) -> Any: - """Serializes an object to be JSON compatible.""" - if hasattr(obj, "model_dump"): - return obj.model_dump(mode="json") - elif hasattr(obj, "dict"): - return self._serialize(obj.dict()) - elif isinstance(obj, dict): - return {k: self._serialize(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._serialize(v) for v in obj] - return obj - - def _app_name(self) -> str: - """Returns the app name.""" - app = self._tmpl_attrs.get("app") - return app.name if app else self._tmpl_attrs.get("app_name") - - async def _init_session( + def _serialize(self, obj: Any) -> Any: + """Serializes an object to be JSON compatible.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + elif hasattr(obj, "dict"): + return self._serialize(obj.dict()) + elif isinstance(obj, dict): + return {k: self._serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._serialize(v) for v in obj] + return obj + + def _app_name(self) -> str: + """Returns the app name.""" + app = self._tmpl_attrs.get("app") + return app.name if app else self._tmpl_attrs.get("app_name") + + async def _init_session( self, session_service: "BaseSessionService", artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Initializes the session, and returns the session id.""" - from google.adk.events.event import Event + """Initializes the session, and returns the session id.""" + from google.adk.events.event import Event - session_state = None - if request.authorizations: - session_state = {} - for auth_id, auth in request.authorizations.items(): - auth = _Authorization(**auth) - session_state[auth_id] = auth.access_token + session_state = None + if request.authorizations: + session_state = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + session_state[auth_id] = auth.access_token - session = await session_service.create_session( + session = await session_service.create_session( app_name=self._app_name(), user_id=request.user_id, state=session_state, ) - if not session: - raise RuntimeError("Create session failed.") - if request.events: - for event in request.events: - await session_service.append_event(session, Event(**event)) - if request.artifacts: - await self._save_artifacts(session.id, artifact_service, request) - return session - - async def _save_artifacts( + if not session: + raise RuntimeError("Create session failed.") + if request.events: + for event in request.events: + await session_service.append_event(session, Event(**event)) + if request.artifacts: + await self._save_artifacts(session.id, artifact_service, request) + return session + + async def _save_artifacts( self, session_id: str, artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Saves the artifacts.""" - if request.artifacts: - for artifact in request.artifacts: - artifact = _Artifact(**artifact) - for version_data in sorted( + """Saves the artifacts.""" + if request.artifacts: + for artifact in request.artifacts: + artifact = _Artifact(**artifact) + for version_data in sorted( artifact.versions, key=lambda x: x["version"] ): - version_data = _ArtifactVersion(**version_data) - saved_version = await artifact_service.save_artifact( + version_data = _ArtifactVersion(**version_data) + saved_version = await artifact_service.save_artifact( app_name=self._app_name(), user_id=request.user_id, session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) - if saved_version != version_data.version: - from google.cloud.aiplatform import base + if saved_version != version_data.version: + from google.cloud.aiplatform import base - _LOGGER = base.Logger(__name__) - _LOGGER.debug( + _LOGGER = base.Logger(__name__) + _LOGGER.debug( "Artifact '%s' saved at version %s instead of %s", artifact.file_name, saved_version, version_data.version, ) - async def _convert_response_events( + async def _convert_response_events( self, user_id: str, session_id: str, events: List["Event"], artifact_service: Optional["BaseArtifactService"], ) -> _StreamingRunResponse: - """Converts the events to the streaming run response object.""" - import collections + """Converts the events to the streaming run response object.""" + import collections - result = _StreamingRunResponse( + result = _StreamingRunResponse( events=events, artifacts=[], session_id=session_id ) - # Save the generated artifacts into the result object. - artifact_versions = collections.defaultdict(list) - for event in events: - if event.actions and event.actions.artifact_delta: - for key, version in event.actions.artifact_delta.items(): - artifact_versions[key].append(version) + # Save the generated artifacts into the result object. + artifact_versions = collections.defaultdict(list) + for event in events: + if event.actions and event.actions.artifact_delta: + for key, version in event.actions.artifact_delta.items(): + artifact_versions[key].append(version) - for key, versions in artifact_versions.items(): - result.artifacts.append( + for key, versions in artifact_versions.items(): + result.artifacts.append( _Artifact( file_name=key, versions=[ @@ -908,13 +908,13 @@ async def _convert_response_events( ) ) - return result.dump() + return result.dump() - def clone(self): - """Returns a clone of the ADK application.""" - import copy + def clone(self): + """Returns a clone of the ADK application.""" + import copy - return self.__class__( + return self.__class__( app=copy.deepcopy(self._tmpl_attrs.get("app")), enable_tracing=self._tmpl_attrs.get("enable_tracing"), agent=( @@ -941,59 +941,59 @@ def clone(self): instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), ) - def set_up(self): - """Sets up the ADK application.""" - import os - from google.adk.runners import Runner - from google.adk.sessions.in_memory_session_service import InMemorySessionService - from google.adk.artifacts.in_memory_artifact_service import ( + def set_up(self): + """Sets up the ADK application.""" + import os + from google.adk.runners import Runner + from google.adk.sessions.in_memory_session_service import InMemorySessionService + from google.adk.artifacts.in_memory_artifact_service import ( InMemoryArtifactService, ) - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - from google.adk.auth.credential_service.in_memory_credential_service import ( + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + from google.adk.auth.credential_service.in_memory_credential_service import ( InMemoryCredentialService, ) - os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" - project = self._tmpl_attrs.get("project") - if project: - os.environ["GOOGLE_CLOUD_PROJECT"] = project - location = self._tmpl_attrs.get("location") - if location: - if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location - if "GOOGLE_CLOUD_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_LOCATION"] = location - agent_engine_location = os.environ.get( + os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" + project = self._tmpl_attrs.get("project") + if project: + os.environ["GOOGLE_CLOUD_PROJECT"] = project + location = self._tmpl_attrs.get("location") + if location: + if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location + if "GOOGLE_CLOUD_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_LOCATION"] = location + agent_engine_location = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", # the runtime env var (if set) location, # the location set in the AdkApp template ) - express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key") - if express_mode_api_key and not project: - os.environ["GOOGLE_API_KEY"] = express_mode_api_key - # Clear location and project env vars if express mode api key is provided. - os.environ.pop("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", None) - os.environ.pop("GOOGLE_CLOUD_LOCATION", None) - os.environ.pop("GOOGLE_CLOUD_PROJECT", None) - location = None - - # Disable content capture in custom ADK spans unless user enabled - # tracing explicitly with the old flag - # (this is to preserve compatibility with old behavior). - if self._tmpl_attrs.get("enable_tracing"): - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" - else: - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" + express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key") + if express_mode_api_key and not project: + os.environ["GOOGLE_API_KEY"] = express_mode_api_key + # Clear location and project env vars if express mode api key is provided. + os.environ.pop("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_PROJECT", None) + location = None + + # Disable content capture in custom ADK spans unless user enabled + # tracing explicitly with the old flag + # (this is to preserve compatibility with old behavior). + if self._tmpl_attrs.get("enable_tracing"): + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" + else: + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" - enable_logging = bool(self._telemetry_enabled()) + enable_logging = bool(self._telemetry_enabled()) - custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder") + custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder") - if self._tmpl_attrs.get("enable_tracing"): - _warn_if_telemetry_api_disabled() + if self._tmpl_attrs.get("enable_tracing"): + _warn_if_telemetry_api_disabled() - if self._tmpl_attrs.get("enable_tracing") is False: - _warn( + if self._tmpl_attrs.get("enable_tracing") is False: + _warn( ( "Your 'enable_tracing=False' setting is being deprecated " "and will be removed in a future release.\n" @@ -1018,102 +1018,102 @@ def set_up(self): ), ) - if custom_instrumentor and self._tracing_enabled(): - self._tmpl_attrs["instrumentor"] = custom_instrumentor(self.project_id()) + if custom_instrumentor and self._tracing_enabled(): + self._tmpl_attrs["instrumentor"] = custom_instrumentor(self.project_id()) - if not custom_instrumentor: - self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( + if not custom_instrumentor: + self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( self.project_id(), enable_tracing=self._tracing_enabled(), enable_logging=enable_logging, ) - if not self._tmpl_attrs.get("app_name"): - if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - self._tmpl_attrs["app_name"] = os.environ.get( + if not self._tmpl_attrs.get("app_name"): + if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + self._tmpl_attrs["app_name"] = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_ID", ) - else: - self._tmpl_attrs["app_name"] = _DEFAULT_APP_NAME + else: + self._tmpl_attrs["app_name"] = _DEFAULT_APP_NAME - artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") - if artifact_service_builder: - self._tmpl_attrs["artifact_service"] = artifact_service_builder() - else: - self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() - - session_service_builder = self._tmpl_attrs.get("session_service_builder") - if session_service_builder: - self._tmpl_attrs["session_service"] = session_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - try: - from google.adk.sessions.vertex_ai_session_service import ( + artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") + if artifact_service_builder: + self._tmpl_attrs["artifact_service"] = artifact_service_builder() + else: + self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() + + session_service_builder = self._tmpl_attrs.get("session_service_builder") + if session_service_builder: + self._tmpl_attrs["session_service"] = session_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + try: + from google.adk.sessions.vertex_ai_session_service import ( VertexAiSessionService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the session service. - self._tmpl_attrs["session_service"] = VertexAiSessionService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - except (ImportError, AttributeError): - from google.adk.sessions.vertex_ai_session_service_g3 import ( + except (ImportError, AttributeError): + from google.adk.sessions.vertex_ai_session_service_g3 import ( VertexAiSessionService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the session service. - self._tmpl_attrs["session_service"] = VertexAiSessionService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["session_service"] = InMemorySessionService() + else: + self._tmpl_attrs["session_service"] = InMemorySessionService() - memory_service_builder = self._tmpl_attrs.get("memory_service_builder") - if memory_service_builder: - self._tmpl_attrs["memory_service"] = memory_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( + memory_service_builder = self._tmpl_attrs.get("memory_service_builder") + if memory_service_builder: + self._tmpl_attrs["memory_service"] = memory_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( "1.5.0" ): - try: - from google.adk.memory.vertex_ai_memory_bank_service import ( + try: + from google.adk.memory.vertex_ai_memory_bank_service import ( VertexAiMemoryBankService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the memory service. - self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - except (ImportError, AttributeError): - from google.adk.memory.vertex_ai_memory_bank_service_g3 import ( + except (ImportError, AttributeError): + from google.adk.memory.vertex_ai_memory_bank_service_g3 import ( VertexAiMemoryBankService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the memory service. - self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + else: + self._tmpl_attrs["memory_service"] = InMemoryMemoryService() - credential_service_builder = self._tmpl_attrs.get("credential_service_builder") - if credential_service_builder: - self._tmpl_attrs["credential_service"] = credential_service_builder() - else: - self._tmpl_attrs["credential_service"] = InMemoryCredentialService() + credential_service_builder = self._tmpl_attrs.get("credential_service_builder") + if credential_service_builder: + self._tmpl_attrs["credential_service"] = credential_service_builder() + else: + self._tmpl_attrs["credential_service"] = InMemoryCredentialService() - self._tmpl_attrs["runner"] = Runner( + self._tmpl_attrs["runner"] = Runner( app=self._tmpl_attrs.get("app"), agent=( None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") @@ -1131,10 +1131,10 @@ def set_up(self): memory_service=self._tmpl_attrs.get("memory_service"), credential_service=self._tmpl_attrs.get("credential_service"), ) - self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() - self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() - self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() - self._tmpl_attrs["in_memory_runner"] = Runner( + self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() + self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() + self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() + self._tmpl_attrs["in_memory_runner"] = Runner( app=self._tmpl_attrs.get("app"), app_name=( None @@ -1152,7 +1152,7 @@ def set_up(self): memory_service=self._tmpl_attrs.get("in_memory_memory_service"), ) - async def async_stream_query( + async def async_stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -1162,7 +1162,7 @@ async def async_stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ) -> AsyncIterable[Dict[str, Any]]: - """Streams responses asynchronously from the ADK application. + """Streams responses asynchronously from the ADK application. Args: message (str): @@ -1192,75 +1192,75 @@ async def async_stream_query( a Content object. ValueError: If both session_id and session_events are specified. """ - from agentplatform._genai import _agent_engines_utils - from google.genai import types + from agentplatform._genai import _agent_engines_utils + from google.genai import types - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if session_id and session_events: - raise ValueError( + if not self._tmpl_attrs.get("runner"): + self.set_up() + if session_id and session_events: + raise ValueError( "Only one of session_id and session_events should be specified." ) - if not session_id: - session = await self.async_create_session(user_id=user_id) - session_id = session["id"] - if session_events is not None: - # We allow for session_events to be an empty list. - from google.adk.events.event import Event - - session_service = self._tmpl_attrs.get("session_service") - session_obj = await session_service.get_session( + if not session_id: + session = await self.async_create_session(user_id=user_id) + session_id = session["id"] + if session_events is not None: + # We allow for session_events to be an empty list. + from google.adk.events.event import Event + + session_service = self._tmpl_attrs.get("session_service") + session_obj = await session_service.get_session( app_name=self._app_name(), user_id=user_id, session_id=session_id, ) - for event in session_events: - if not isinstance(event, Event): - event = Event.model_validate(event) - await session_service.append_event( + for event in session_events: + if not isinstance(event, Event): + event = Event.model_validate(event) + await session_service.append_event( session=session_obj, event=event, ) - run_config = _validate_run_config(run_config) - if run_config: - events_async = self._tmpl_attrs.get("runner").run_async( + run_config = _validate_run_config(run_config) + if run_config: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ) - else: - events_async = self._tmpl_attrs.get("runner").run_async( + else: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ) - try: - async for event in events_async: - # Yield the event data as a dictionary - yield _agent_engines_utils.dump_event_for_json(event) - finally: - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + try: + async for event in events_async: + # Yield the event data as a dictionary + yield _agent_engines_utils.dump_event_for_json(event) + finally: + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - def stream_query( + def stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -1269,7 +1269,7 @@ def stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ): - """Deprecated. Use async_stream_query instead. + """Deprecated. Use async_stream_query instead. Streams responses from the ADK application in response to a message. @@ -1292,7 +1292,7 @@ def stream_query( Yields: The output of querying the ADK application. """ - warnings.warn( + warnings.warn( ( "AdkApp.stream_query(...) is deprecated. " "Use AdkApp.async_stream_query(...) instead. See " @@ -1302,45 +1302,45 @@ def stream_query( DeprecationWarning, stacklevel=2, ) - from agentplatform._genai import _agent_engines_utils - from google.genai import types + from agentplatform._genai import _agent_engines_utils + from google.genai import types - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if not session_id: - session = self.create_session(user_id=user_id) - session_id = session["id"] - run_config = _validate_run_config(run_config) - if run_config: - for event in self._tmpl_attrs.get("runner").run( + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = self.create_session(user_id=user_id) + session_id = session["id"] + run_config = _validate_run_config(run_config) + if run_config: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ): - yield _agent_engines_utils.dump_event_for_json(event) - else: - for event in self._tmpl_attrs.get("runner").run( + yield _agent_engines_utils.dump_event_for_json(event) + else: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ): - yield _agent_engines_utils.dump_event_for_json(event) + yield _agent_engines_utils.dump_event_for_json(event) - async def streaming_agent_run_with_events(self, request_json: str): - """Streams responses asynchronously from the ADK application. + async def streaming_agent_run_with_events(self, request_json: str): + """Streams responses asynchronously from the ADK application. In general, you should use `async_stream_query` instead, as it has a more structured API and works with the respective ADK services that @@ -1352,12 +1352,12 @@ async def streaming_agent_run_with_events(self, request_json: str): Required. The request to stream responses for. """ - import json - from google.genai import types - from google.genai.errors import ClientError + import json + from google.genai import types + from google.genai.errors import ClientError - request = _StreamRunRequest(**json.loads(request_json)) - if not any( + request = _StreamRunRequest(**json.loads(request_json)) + if not any( self._tmpl_attrs.get(service) for service in ( "in_memory_runner", @@ -1370,93 +1370,93 @@ async def streaming_agent_run_with_events(self, request_json: str): "memory_service", ) ): - self.set_up() - - # Try to get the session, if it doesn't exist, create a new one. - state_delta = None - if request.session_id: - session_service = self._tmpl_attrs.get("session_service") - artifact_service = self._tmpl_attrs.get("artifact_service") - runner = self._tmpl_attrs.get("runner") - session = None - try: - session = await session_service.get_session( + self.set_up() + + # Try to get the session, if it doesn't exist, create a new one. + state_delta = None + if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") + session = None + try: + session = await session_service.get_session( app_name=self._app_name(), user_id=request.user_id, session_id=request.session_id, ) - if session: - await self._save_artifacts( + if session: + await self._save_artifacts( session_id=request.session_id, artifact_service=artifact_service, request=request, ) - if request.authorizations: - state_delta = {} - for auth_id, auth in request.authorizations.items(): - auth = _Authorization(**auth) - state_delta[auth_id] = auth.access_token - except ClientError: - pass - if not session: - # Fall back to create session if the session is not found. - # Specifying session_id on creation is not supported, - # so session id will be regenerated. - session = await self._init_session( + if request.authorizations: + state_delta = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + state_delta[auth_id] = auth.access_token + except ClientError: + pass + if not session: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - else: - # Not providing a session ID will create a new in-memory session. - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - runner = self._tmpl_attrs.get("in_memory_runner") - session = await self._init_session( + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - if not session: - raise RuntimeError("Session initialization failed.") + if not session: + raise RuntimeError("Session initialization failed.") - # Run the agent - message_for_agent = types.Content(**request.message) - try: - async for event in runner.run_async( + # Run the agent + message_for_agent = types.Content(**request.message) + try: + async for event in runner.run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, state_delta=state_delta, ): - converted_event = await self._convert_response_events( + converted_event = await self._convert_response_events( user_id=request.user_id, session_id=session.id, events=[event], artifact_service=artifact_service, ) - yield converted_event - finally: - if session and not request.session_id: - await session_service.delete_session( + yield converted_event + finally: + if session and not request.session_id: + await session_service.delete_session( app_name=self._app_name(), user_id=request.user_id, session_id=session.id, ) - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - async def async_get_session( + async def async_get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Get a session for the given user. + """Get a session for the given user. Args: user_id (str): @@ -1474,32 +1474,36 @@ async def async_get_session( Raises: RuntimeError: If the session is not found. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").get_session( - app_name=self._app_name(), - user_id=user_id, - session_id=session_id, - **kwargs, - ) - if not session: - raise RuntimeError( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + if "config" in kwargs and isinstance(kwargs["config"], dict): + from google.adk.sessions.base_session_service import GetSessionConfig + + kwargs["config"] = GetSessionConfig(**kwargs["config"]) + session = await self._tmpl_attrs.get("session_service").get_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + if not session: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) - return session + return session - def get_session( + def get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deprecated. Use async_get_session instead. + """Deprecated. Use async_get_session instead. Get a session for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.get_session(...) is deprecated. " "Use AdkApp.async_get_session(...) instead. See " @@ -1509,37 +1513,37 @@ def get_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_get_session(): - return await self.async_get_session( + async def _invoke_async_get_session(): + return await self.async_get_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_get_session()) - event_queue.put(result) - except Exception as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_get_session()) + event_queue.put(result) + except Exception as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() - # Wait for the thread to finish - thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError( + # Wait for the thread to finish + thread.join() + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome - async def async_list_sessions(self, *, user_id: str, **kwargs): - """List sessions for the given user. + async def async_list_sessions(self, *, user_id: str, **kwargs): + """List sessions for the given user. Args: user_id (str): @@ -1551,20 +1555,20 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): Returns: ListSessionsResponse: The list of sessions. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - return await self._tmpl_attrs.get("session_service").list_sessions( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + return await self._tmpl_attrs.get("session_service").list_sessions( app_name=self._app_name(), user_id=user_id, **kwargs, ) - def list_sessions(self, *, user_id: str, **kwargs): - """Deprecated. Use async_list_sessions instead. + def list_sessions(self, *, user_id: str, **kwargs): + """Deprecated. Use async_list_sessions instead. List sessions for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.list_sessions(...) is deprecated. " "Use AdkApp.async_list_sessions(...) instead. See " @@ -1574,31 +1578,31 @@ def list_sessions(self, *, user_id: str, **kwargs): DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue() - - async def _invoke_async_list_sessions(): - try: - response = await self.async_list_sessions(user_id=user_id, **kwargs) - event_queue.put(response) - except RuntimeError as e: - event_queue.put(e) - - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_list_sessions()) - finally: - event_queue.put(None) - - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() - try: - return event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to list sessions.") from None + event_queue = queue.Queue() + + async def _invoke_async_list_sessions(): + try: + response = await self.async_list_sessions(user_id=user_id, **kwargs) + event_queue.put(response) + except RuntimeError as e: + event_queue.put(e) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_list_sessions()) + finally: + event_queue.put(None) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + try: + return event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to list sessions.") from None - async def async_create_session( + async def async_create_session( self, *, user_id: str, @@ -1606,7 +1610,7 @@ async def async_create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Creates a new session. + """Creates a new session. Args: user_id (str): @@ -1623,18 +1627,18 @@ async def async_create_session( Returns: Session: The newly created session instance. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").create_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + session = await self._tmpl_attrs.get("session_service").create_session( app_name=self._app_name(), user_id=user_id, session_id=session_id, state=state, **kwargs, ) - return self._serialize(session) + return self._serialize(session) - def create_session( + def create_session( self, *, user_id: str, @@ -1642,11 +1646,11 @@ def create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Deprecated. Use async_create_session instead. + """Deprecated. Use async_create_session instead. Creates a new session. """ - warnings.warn( + warnings.warn( ( "AdkApp.create_session(...) is deprecated. " "Use AdkApp.async_create_session(...) instead. See " @@ -1656,44 +1660,44 @@ def create_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_create_session(): - return await self.async_create_session( + async def _invoke_async_create_session(): + return await self.async_create_session( user_id=user_id, session_id=session_id, state=state, **kwargs, ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_create_session()) - event_queue.put(result) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_create_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to create session.") from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome - - async def async_delete_session( + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to create session.") from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + + async def async_delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deletes a session for the given user. + """Deletes a session for the given user. Args: user_id (str): @@ -1704,27 +1708,27 @@ async def async_delete_session( Optional. Additional keyword arguments to pass to the session service. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - await self._tmpl_attrs.get("session_service").delete_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + await self._tmpl_attrs.get("session_service").delete_session( app_name=self._app_name(), user_id=user_id, session_id=session_id, **kwargs, ) - def delete_session( + def delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deprecated. Use async_delete_session instead. + """Deprecated. Use async_delete_session instead. Deletes a session for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.delete_session(...) is deprecated. " "Use AdkApp.async_delete_session(...) instead. See " @@ -1734,31 +1738,31 @@ def delete_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_delete_session(): - await self.async_delete_session( + async def _invoke_async_delete_session(): + await self.async_delete_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_delete_session()) - event_queue.put(None) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_delete_session()) + event_queue.put(None) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - outcome = event_queue.get(timeout=10) - if isinstance(outcome, RuntimeError): - raise outcome from None + outcome = event_queue.get(timeout=10) + if isinstance(outcome, RuntimeError): + raise outcome from None - async def async_add_session_to_memory(self, *, session: Dict[str, Any]): - """Generates memories. + async def async_add_session_to_memory(self, *, session: Dict[str, Any]): + """Generates memories. Args: session (Dict[str, Any]): @@ -1766,26 +1770,26 @@ async def async_add_session_to_memory(self, *, session: Dict[str, Any]): be a dictionary representing an ADK Session object, e.g. session.model_dump(mode="json"). """ - from google.adk.sessions.session import Session - - if isinstance(session, Dict): - session = Session.model_validate(session) - elif not isinstance(session, Session): - raise TypeError("session must be a Session object.") - if not session.events: - # Get the latest version of the session in case it was updated. - session = await self.async_get_session( + from google.adk.sessions.session import Session + + if isinstance(session, Dict): + session = Session.model_validate(session) + elif not isinstance(session, Session): + raise TypeError("session must be a Session object.") + if not session.events: + # Get the latest version of the session in case it was updated. + session = await self.async_get_session( user_id=session.user_id, session_id=session.id, ) - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").add_session_to_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").add_session_to_memory( session=session, ) - async def async_search_memory(self, *, user_id: str, query: str): - """Searches memories for the given user. + async def async_search_memory(self, *, user_id: str, query: str): + """Searches memories for the given user. Args: user_id: The id of the user. @@ -1794,15 +1798,15 @@ async def async_search_memory(self, *, user_id: str, query: str): Returns: A SearchMemoryResponse containing the matching memories. """ - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").search_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").search_memory( app_name=self._app_name(), user_id=user_id, query=query, ) - async def async_save_artifact( + async def async_save_artifact( self, *, user_id: str, @@ -1812,7 +1816,7 @@ async def async_save_artifact( custom_metadata: Optional[Dict[str, Any]] = None, **kwargs, ): - """Saves an artifact to the artifact service storage. + """Saves an artifact to the artifact service storage. Args: user_id (str): @@ -1832,19 +1836,19 @@ async def async_save_artifact( Returns: int: The revision ID. """ - if isinstance(artifact, str): - try: - from google.genai import types - except ImportError: - raise ImportError( + if isinstance(artifact, str): + try: + from google.genai import types + except ImportError: + raise ImportError( "The `google-genai` package is required to use AdkApp. " "Please install it with `pip install google-genai`." ) - artifact = types.Part(text=artifact) + artifact = types.Part(text=artifact) - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").save_artifact( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").save_artifact( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -1854,7 +1858,7 @@ async def async_save_artifact( **kwargs, ) - async def async_load_artifact( + async def async_load_artifact( self, *, user_id: str, @@ -1863,7 +1867,7 @@ async def async_load_artifact( version: Optional[int] = None, **kwargs, ): - """Gets an artifact from the artifact service storage. + """Gets an artifact from the artifact service storage. Args: user_id (str): @@ -1881,9 +1885,9 @@ async def async_load_artifact( Returns: Optional[types.Part]: The artifact or None if not found. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").load_artifact( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").load_artifact( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -1892,14 +1896,14 @@ async def async_load_artifact( **kwargs, ) - async def async_list_artifact_keys( + async def async_list_artifact_keys( self, *, user_id: str, session_id: Optional[str] = None, **kwargs, ): - """Lists all the artifact filenames within a session. + """Lists all the artifact filenames within a session. Args: user_id (str): @@ -1913,16 +1917,16 @@ async def async_list_artifact_keys( Returns: list[str]: A list of artifact filenames. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_keys( app_name=self._app_name(), user_id=user_id, session_id=session_id, **kwargs, ) - async def async_delete_artifact( + async def async_delete_artifact( self, *, user_id: str, @@ -1930,7 +1934,7 @@ async def async_delete_artifact( session_id: Optional[str] = None, **kwargs, ): - """Deletes an artifact. + """Deletes an artifact. Args: user_id (str): @@ -1943,9 +1947,9 @@ async def async_delete_artifact( Optional. Additional keyword arguments to pass to the artifact service. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - await self._tmpl_attrs.get("artifact_service").delete_artifact( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + await self._tmpl_attrs.get("artifact_service").delete_artifact( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -1953,7 +1957,7 @@ async def async_delete_artifact( **kwargs, ) - async def async_list_versions( + async def async_list_versions( self, *, user_id: str, @@ -1961,7 +1965,7 @@ async def async_list_versions( session_id: Optional[str] = None, **kwargs, ): - """Lists all versions of an artifact. + """Lists all versions of an artifact. Args: user_id (str): @@ -1977,9 +1981,9 @@ async def async_list_versions( Returns: list[int]: A list of all available versions of the artifact. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").list_versions( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_versions( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -1987,7 +1991,7 @@ async def async_list_versions( **kwargs, ) - async def async_list_artifact_versions( + async def async_list_artifact_versions( self, *, user_id: str, @@ -1995,7 +1999,7 @@ async def async_list_artifact_versions( session_id: Optional[str] = None, **kwargs, ): - """Lists all versions and their metadata for a specific artifact. + """Lists all versions and their metadata for a specific artifact. Args: user_id (str): @@ -2011,9 +2015,9 @@ async def async_list_artifact_versions( Returns: list[ArtifactVersion]: A list of ArtifactVersion objects. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").list_artifact_versions( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -2021,7 +2025,7 @@ async def async_list_artifact_versions( **kwargs, ) - async def async_get_artifact_version( + async def async_get_artifact_version( self, *, user_id: str, @@ -2030,7 +2034,7 @@ async def async_get_artifact_version( version: Optional[int] = None, **kwargs, ): - """Gets the metadata for a specific version of an artifact. + """Gets the metadata for a specific version of an artifact. Args: user_id (str): @@ -2048,9 +2052,9 @@ async def async_get_artifact_version( Returns: Optional[ArtifactVersion]: An ArtifactVersion object or None. """ - if not self._tmpl_attrs.get("artifact_service"): - self.set_up() - return await self._tmpl_attrs.get("artifact_service").get_artifact_version( + if not self._tmpl_attrs.get("artifact_service"): + self.set_up() + return await self._tmpl_attrs.get("artifact_service").get_artifact_version( app_name=self._app_name(), user_id=user_id, filename=filename, @@ -2059,9 +2063,9 @@ async def async_get_artifact_version( **kwargs, ) - def register_operations(self) -> Dict[str, List[str]]: - """Registers the operations of the ADK application.""" - return { + def register_operations(self) -> Dict[str, List[str]]: + """Registers the operations of the ADK application.""" + return { "": [ "get_session", "list_sessions", @@ -2090,8 +2094,8 @@ def register_operations(self) -> Dict[str, List[str]]: ], } - def _telemetry_enabled(self) -> Optional[bool]: - """Return status of telemetry enablement depending on enablement env variable. + def _telemetry_enabled(self) -> Optional[bool]: + """Return status of telemetry enablement depending on enablement env variable. In detail: - Logging is always enabled when telemetry is enabled. @@ -2101,25 +2105,25 @@ def _telemetry_enabled(self) -> Optional[bool]: True if telemetry is enabled, False if telemetry is disabled, or None if telemetry enablement is not set (i.e. old deployments which don't support this env variable). """ - import os + import os - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" ) - env_value = os.getenv( + env_value = os.getenv( GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "unspecified" ).lower() - if env_value in ("true", "1"): - return True - if env_value in ("false", "0"): - return False - return None + if env_value in ("true", "1"): + return True + if env_value in ("false", "0"): + return False + return None - # Tracing enablement follows truth table: - def _tracing_enabled(self) -> bool: - """Tracing enablement follows true table: + # Tracing enablement follows truth table: + def _tracing_enabled(self) -> bool: + """Tracing enablement follows true table: | enable_tracing | enable_telemetry(env) | tracing_actually_enabled | |----------------|-----------------------|--------------------------| @@ -2133,26 +2137,26 @@ def _tracing_enabled(self) -> bool: | None(default) | true | adk_version >= 1.17 | | None(default) | None | false | """ - enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") - enable_telemetry: Optional[bool] = self._telemetry_enabled() + enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") + enable_telemetry: Optional[bool] = self._telemetry_enabled() - return (enable_tracing is True and enable_telemetry is not False) or ( + return (enable_tracing is True and enable_telemetry is not False) or ( enable_tracing is None and enable_telemetry is True and is_version_sufficient("1.17.0") ) - def project_id(self) -> Optional[str]: - if project := self._tmpl_attrs.get("project"): - try: - from google.cloud.aiplatform.utils import ( + def project_id(self) -> Optional[str]: + if project := self._tmpl_attrs.get("project"): + try: + from google.cloud.aiplatform.utils import ( resource_manager_utils, ) - from google.api_core import exceptions + from google.api_core import exceptions - return resource_manager_utils.get_project_id(project) - # Fail open as temporary workaround for identity_type config parameter - except (exceptions.PermissionDenied, exceptions.Unauthenticated): - return project + return resource_manager_utils.get_project_id(project) + # Fail open as temporary workaround for identity_type config parameter + except (exceptions.PermissionDenied, exceptions.Unauthenticated): + return project - return None + return None diff --git a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py index 2e80315044..5879a9d647 100644 --- a/tests/unit/vertex_adk/test_agent_engine_templates_adk.py +++ b/tests/unit/vertex_adk/test_agent_engine_templates_adk.py @@ -631,6 +631,41 @@ async def test_async_get_session(self, get_project_id_mock: mock.Mock): assert session2.user_id == _TEST_USER_ID assert session1["id"] == session2.id + @pytest.mark.asyncio + async def test_async_get_session_with_config(self, get_project_id_mock: mock.Mock): + from google.adk.events import event + from google.adk.sessions.session import Session + + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + session_service = app._tmpl_attrs.get("session_service") + + session1 = await session_service.create_session( + app_name=app._app_name(), + user_id=_TEST_USER_ID, + ) + + e1 = event.Event(id="e1", timestamp=1.0) + e2 = event.Event(id="e2", timestamp=2.0) + await session_service.append_event(session=session1, event=e1) + await session_service.append_event(session=session1, event=e2) + + # Get session without config + session_all = await app.async_get_session( + user_id=_TEST_USER_ID, + session_id=session1.id, + ) + assert len(session_all.events) == 2 + + # Get session with config + session_filtered = await app.async_get_session( + user_id=_TEST_USER_ID, + session_id=session1.id, + config={"num_recent_events": 1}, + ) + assert len(session_filtered.events) == 1 + assert session_filtered.events[0].id == "e2" + @pytest.mark.asyncio async def test_async_list_sessions(self, get_project_id_mock: mock.Mock): app = agent_engines.AdkApp(agent=_TEST_AGENT) diff --git a/vertexai/agent_engines/templates/adk.py b/vertexai/agent_engines/templates/adk.py index 58c3025285..953b8c7cc7 100644 --- a/vertexai/agent_engines/templates/adk.py +++ b/vertexai/agent_engines/templates/adk.py @@ -667,11 +667,11 @@ def _use_client_cert_effective() -> bool: class AdkApp: - """An ADK Application.""" + """An ADK Application.""" - agent_framework = "google-adk" + agent_framework = "google-adk" - def __init__( + def __init__( self, *, app: "App" = None, @@ -684,7 +684,7 @@ def __init__( memory_service_builder: Optional[Callable[..., "BaseMemoryService"]] = None, instrumentor_builder: Optional[Callable[..., Any]] = None, ): - """An ADK Application. + """An ADK Application. See https://cloud.google.com/vertex-ai/generative-ai/docs/agent-engine/develop/adk for details on how to develop ADK applications on Agent Engine. @@ -721,34 +721,34 @@ def __init__( If not provided, a default instrumentor builder will be used. This parameter is ignored if `enable_tracing` is False. """ - import os - from google.cloud.aiplatform import initializer + import os + from google.cloud.aiplatform import initializer - adk_version = get_adk_version() - if not is_version_sufficient("1.5.0"): - msg = ( + adk_version = get_adk_version() + if not is_version_sufficient("1.5.0"): + msg = ( f"Unsupported google-adk version: {adk_version}, please use " "google-adk>=1.5.0 for AdkApp deployment on Agent Engine." ) - raise ValueError(msg) + raise ValueError(msg) - if not agent and not app: - raise ValueError("One of `agent` or `app` must be provided.") - if app: - if app_name: - raise ValueError( + if not agent and not app: + raise ValueError("One of `agent` or `app` must be provided.") + if app: + if app_name: + raise ValueError( "When app is provided, app_name should not be provided, " "since it will be derived from app.name." ) - if agent: - raise ValueError("When app is provided, agent should not be provided.") - if plugins: - raise ValueError( + if agent: + raise ValueError("When app is provided, agent should not be provided.") + if plugins: + raise ValueError( "When app is provided, plugins should not be provided and" " should be provided in the app instead." ) - self._tmpl_attrs: Dict[str, Any] = { + self._tmpl_attrs: Dict[str, Any] = { "project": initializer.global_config.project, "location": initializer.global_config.location, "agent": agent, @@ -765,108 +765,108 @@ def __init__( ), } - def _serialize(self, obj: Any) -> Any: - """Serializes an object to be JSON compatible.""" - if hasattr(obj, "model_dump"): - return obj.model_dump(mode="json") - elif hasattr(obj, "dict"): - return self._serialize(obj.dict()) - elif isinstance(obj, dict): - return {k: self._serialize(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._serialize(v) for v in obj] - return obj - - def _app_name(self) -> str: - """Returns the app name.""" - app = self._tmpl_attrs.get("app") - return app.name if app else self._tmpl_attrs.get("app_name") - - async def _init_session( + def _serialize(self, obj: Any) -> Any: + """Serializes an object to be JSON compatible.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + elif hasattr(obj, "dict"): + return self._serialize(obj.dict()) + elif isinstance(obj, dict): + return {k: self._serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._serialize(v) for v in obj] + return obj + + def _app_name(self) -> str: + """Returns the app name.""" + app = self._tmpl_attrs.get("app") + return app.name if app else self._tmpl_attrs.get("app_name") + + async def _init_session( self, session_service: "BaseSessionService", artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Initializes the session, and returns the session id.""" - from google.adk.events.event import Event + """Initializes the session, and returns the session id.""" + from google.adk.events.event import Event - session_state = None - if request.authorizations: - session_state = {} - for auth_id, auth in request.authorizations.items(): - auth = _Authorization(**auth) - session_state[auth_id] = auth.access_token + session_state = None + if request.authorizations: + session_state = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + session_state[auth_id] = auth.access_token - session = await session_service.create_session( + session = await session_service.create_session( app_name=self._app_name(), user_id=request.user_id, state=session_state, ) - if not session: - raise RuntimeError("Create session failed.") - if request.events: - for event in request.events: - await session_service.append_event(session, Event(**event)) - if request.artifacts: - await self._save_artifacts(session.id, artifact_service, request) - return session - - async def _save_artifacts( + if not session: + raise RuntimeError("Create session failed.") + if request.events: + for event in request.events: + await session_service.append_event(session, Event(**event)) + if request.artifacts: + await self._save_artifacts(session.id, artifact_service, request) + return session + + async def _save_artifacts( self, session_id: str, artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Saves the artifacts.""" - if request.artifacts: - for artifact in request.artifacts: - artifact = _Artifact(**artifact) - for version_data in sorted( + """Saves the artifacts.""" + if request.artifacts: + for artifact in request.artifacts: + artifact = _Artifact(**artifact) + for version_data in sorted( artifact.versions, key=lambda x: x["version"] ): - version_data = _ArtifactVersion(**version_data) - saved_version = await artifact_service.save_artifact( + version_data = _ArtifactVersion(**version_data) + saved_version = await artifact_service.save_artifact( app_name=self._app_name(), user_id=request.user_id, session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) - if saved_version != version_data.version: - from google.cloud.aiplatform import base + if saved_version != version_data.version: + from google.cloud.aiplatform import base - _LOGGER = base.Logger(__name__) - _LOGGER.debug( + _LOGGER = base.Logger(__name__) + _LOGGER.debug( "Artifact '%s' saved at version %s instead of %s", artifact.file_name, saved_version, version_data.version, ) - async def _convert_response_events( + async def _convert_response_events( self, user_id: str, session_id: str, events: List["Event"], artifact_service: Optional["BaseArtifactService"], ) -> _StreamingRunResponse: - """Converts the events to the streaming run response object.""" - import collections + """Converts the events to the streaming run response object.""" + import collections - result = _StreamingRunResponse( + result = _StreamingRunResponse( events=events, artifacts=[], session_id=session_id ) - # Save the generated artifacts into the result object. - artifact_versions = collections.defaultdict(list) - for event in events: - if event.actions and event.actions.artifact_delta: - for key, version in event.actions.artifact_delta.items(): - artifact_versions[key].append(version) + # Save the generated artifacts into the result object. + artifact_versions = collections.defaultdict(list) + for event in events: + if event.actions and event.actions.artifact_delta: + for key, version in event.actions.artifact_delta.items(): + artifact_versions[key].append(version) - for key, versions in artifact_versions.items(): - result.artifacts.append( + for key, versions in artifact_versions.items(): + result.artifacts.append( _Artifact( file_name=key, versions=[ @@ -885,13 +885,13 @@ async def _convert_response_events( ) ) - return result.dump() + return result.dump() - def clone(self): - """Returns a clone of the ADK application.""" - import copy + def clone(self): + """Returns a clone of the ADK application.""" + import copy - return self.__class__( + return self.__class__( app=copy.deepcopy(self._tmpl_attrs.get("app")), enable_tracing=self._tmpl_attrs.get("enable_tracing"), agent=( @@ -915,57 +915,57 @@ def clone(self): instrumentor_builder=self._tmpl_attrs.get("instrumentor_builder"), ) - def set_up(self): - """Sets up the ADK application.""" - import os - from google.adk.runners import Runner - from google.adk.sessions.in_memory_session_service import InMemorySessionService - from google.adk.artifacts.in_memory_artifact_service import ( + def set_up(self): + """Sets up the ADK application.""" + import os + from google.adk.runners import Runner + from google.adk.sessions.in_memory_session_service import InMemorySessionService + from google.adk.artifacts.in_memory_artifact_service import ( InMemoryArtifactService, ) - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - - os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" - os.environ["GOOGLE_GENAI_USE_ENTERPRISE"] = "1" - project = self._tmpl_attrs.get("project") - if project: - os.environ["GOOGLE_CLOUD_PROJECT"] = project - location = self._tmpl_attrs.get("location") - if location: - if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location - if "GOOGLE_CLOUD_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_LOCATION"] = location - agent_engine_location = os.environ.get( + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" + os.environ["GOOGLE_GENAI_USE_ENTERPRISE"] = "1" + project = self._tmpl_attrs.get("project") + if project: + os.environ["GOOGLE_CLOUD_PROJECT"] = project + location = self._tmpl_attrs.get("location") + if location: + if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location + if "GOOGLE_CLOUD_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_LOCATION"] = location + agent_engine_location = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", # the runtime env var (if set) location, # the location set in the AdkApp template ) - express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key") - if express_mode_api_key and not project: - os.environ["GOOGLE_API_KEY"] = express_mode_api_key - # Clear location and project env vars if express mode api key is provided. - os.environ.pop("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", None) - os.environ.pop("GOOGLE_CLOUD_LOCATION", None) - os.environ.pop("GOOGLE_CLOUD_PROJECT", None) - location = None - - # Disable content capture in custom ADK spans unless user enabled - # tracing explicitly with the old flag - # (this is to preserve compatibility with old behavior). - if self._tmpl_attrs.get("enable_tracing"): - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" - else: - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" + express_mode_api_key = self._tmpl_attrs.get("express_mode_api_key") + if express_mode_api_key and not project: + os.environ["GOOGLE_API_KEY"] = express_mode_api_key + # Clear location and project env vars if express mode api key is provided. + os.environ.pop("GOOGLE_CLOUD_AGENT_ENGINE_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_LOCATION", None) + os.environ.pop("GOOGLE_CLOUD_PROJECT", None) + location = None + + # Disable content capture in custom ADK spans unless user enabled + # tracing explicitly with the old flag + # (this is to preserve compatibility with old behavior). + if self._tmpl_attrs.get("enable_tracing"): + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" + else: + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" - enable_logging = bool(self._telemetry_enabled()) + enable_logging = bool(self._telemetry_enabled()) - custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder") + custom_instrumentor = self._tmpl_attrs.get("instrumentor_builder") - if self._tmpl_attrs.get("enable_tracing"): - _warn_if_telemetry_api_disabled() + if self._tmpl_attrs.get("enable_tracing"): + _warn_if_telemetry_api_disabled() - if self._tmpl_attrs.get("enable_tracing") is False: - _warn( + if self._tmpl_attrs.get("enable_tracing") is False: + _warn( ( "Your 'enable_tracing=False' setting is being deprecated " "and will be removed in a future release.\n" @@ -990,96 +990,96 @@ def set_up(self): ), ) - if custom_instrumentor and self._tracing_enabled(): - self._tmpl_attrs["instrumentor"] = custom_instrumentor(self.project_id()) + if custom_instrumentor and self._tracing_enabled(): + self._tmpl_attrs["instrumentor"] = custom_instrumentor(self.project_id()) - if not custom_instrumentor: - self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( + if not custom_instrumentor: + self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( self.project_id(), enable_tracing=self._tracing_enabled(), enable_logging=enable_logging, ) - if not self._tmpl_attrs.get("app_name"): - if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - self._tmpl_attrs["app_name"] = os.environ.get( + if not self._tmpl_attrs.get("app_name"): + if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + self._tmpl_attrs["app_name"] = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_ID", ) - else: - self._tmpl_attrs["app_name"] = _DEFAULT_APP_NAME + else: + self._tmpl_attrs["app_name"] = _DEFAULT_APP_NAME - artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") - if artifact_service_builder: - self._tmpl_attrs["artifact_service"] = artifact_service_builder() - else: - self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() - - session_service_builder = self._tmpl_attrs.get("session_service_builder") - if session_service_builder: - self._tmpl_attrs["session_service"] = session_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - try: - from google.adk.sessions.vertex_ai_session_service import ( + artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") + if artifact_service_builder: + self._tmpl_attrs["artifact_service"] = artifact_service_builder() + else: + self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() + + session_service_builder = self._tmpl_attrs.get("session_service_builder") + if session_service_builder: + self._tmpl_attrs["session_service"] = session_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + try: + from google.adk.sessions.vertex_ai_session_service import ( VertexAiSessionService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the session service. - self._tmpl_attrs["session_service"] = VertexAiSessionService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - except (ImportError, AttributeError): - from google.adk.sessions.vertex_ai_session_service_g3 import ( + except (ImportError, AttributeError): + from google.adk.sessions.vertex_ai_session_service_g3 import ( VertexAiSessionService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the session service. - self._tmpl_attrs["session_service"] = VertexAiSessionService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the session service. + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["session_service"] = InMemorySessionService() + else: + self._tmpl_attrs["session_service"] = InMemorySessionService() - memory_service_builder = self._tmpl_attrs.get("memory_service_builder") - if memory_service_builder: - self._tmpl_attrs["memory_service"] = memory_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( + memory_service_builder = self._tmpl_attrs.get("memory_service_builder") + if memory_service_builder: + self._tmpl_attrs["memory_service"] = memory_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( "1.5.0" ): - try: - from google.adk.memory.vertex_ai_memory_bank_service import ( + try: + from google.adk.memory.vertex_ai_memory_bank_service import ( VertexAiMemoryBankService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the memory service. - self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - except (ImportError, AttributeError): - from google.adk.memory.vertex_ai_memory_bank_service_g3 import ( + except (ImportError, AttributeError): + from google.adk.memory.vertex_ai_memory_bank_service_g3 import ( VertexAiMemoryBankService, ) - # If the express mode api key is set, it will be read from the - # environment variable when initializing the memory service. - self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + # If the express mode api key is set, it will be read from the + # environment variable when initializing the memory service. + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=agent_engine_location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + else: + self._tmpl_attrs["memory_service"] = InMemoryMemoryService() - self._tmpl_attrs["runner"] = Runner( + self._tmpl_attrs["runner"] = Runner( app=self._tmpl_attrs.get("app"), agent=( None if self._tmpl_attrs.get("app") else self._tmpl_attrs.get("agent") @@ -1096,10 +1096,10 @@ def set_up(self): artifact_service=self._tmpl_attrs.get("artifact_service"), memory_service=self._tmpl_attrs.get("memory_service"), ) - self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() - self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() - self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() - self._tmpl_attrs["in_memory_runner"] = Runner( + self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() + self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() + self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() + self._tmpl_attrs["in_memory_runner"] = Runner( app=self._tmpl_attrs.get("app"), app_name=( None @@ -1117,7 +1117,7 @@ def set_up(self): memory_service=self._tmpl_attrs.get("in_memory_memory_service"), ) - async def async_stream_query( + async def async_stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -1127,7 +1127,7 @@ async def async_stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ) -> AsyncIterable[Dict[str, Any]]: - """Streams responses asynchronously from the ADK application. + """Streams responses asynchronously from the ADK application. Args: message (str): @@ -1157,70 +1157,70 @@ async def async_stream_query( a Content object. ValueError: If both session_id and session_events are specified. """ - from vertexai.agent_engines import _utils - from google.genai import types + from vertexai.agent_engines import _utils + from google.genai import types - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if session_id and session_events: - raise ValueError( + if not self._tmpl_attrs.get("runner"): + self.set_up() + if session_id and session_events: + raise ValueError( "Only one of session_id and session_events should be specified." ) - if not session_id: - session = await self.async_create_session(user_id=user_id) - session_id = session["id"] - if session_events is not None: - # We allow for session_events to be an empty list. - from google.adk.events.event import Event - - session_service = self._tmpl_attrs.get("session_service") - for event in session_events: - if not isinstance(event, Event): - event = Event.model_validate(event) - await session_service.append_event( + if not session_id: + session = await self.async_create_session(user_id=user_id) + session_id = session["id"] + if session_events is not None: + # We allow for session_events to be an empty list. + from google.adk.events.event import Event + + session_service = self._tmpl_attrs.get("session_service") + for event in session_events: + if not isinstance(event, Event): + event = Event.model_validate(event) + await session_service.append_event( session=session, event=event, ) - run_config = _validate_run_config(run_config) - if run_config: - events_async = self._tmpl_attrs.get("runner").run_async( + run_config = _validate_run_config(run_config) + if run_config: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ) - else: - events_async = self._tmpl_attrs.get("runner").run_async( + else: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ) - try: - async for event in events_async: - # Yield the event data as a dictionary - yield _utils.dump_event_for_json(event) - finally: - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + try: + async for event in events_async: + # Yield the event data as a dictionary + yield _utils.dump_event_for_json(event) + finally: + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - def stream_query( + def stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -1229,7 +1229,7 @@ def stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ): - """Deprecated. Use async_stream_query instead. + """Deprecated. Use async_stream_query instead. Streams responses from the ADK application in response to a message. @@ -1252,7 +1252,7 @@ def stream_query( Yields: The output of querying the ADK application. """ - warnings.warn( + warnings.warn( ( "AdkApp.stream_query(...) is deprecated. " "Use AdkApp.async_stream_query(...) instead. See " @@ -1262,45 +1262,45 @@ def stream_query( DeprecationWarning, stacklevel=2, ) - from vertexai.agent_engines import _utils - from google.genai import types + from vertexai.agent_engines import _utils + from google.genai import types - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if not session_id: - session = self.create_session(user_id=user_id) - session_id = session["id"] - run_config = _validate_run_config(run_config) - if run_config: - for event in self._tmpl_attrs.get("runner").run( + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = self.create_session(user_id=user_id) + session_id = session["id"] + run_config = _validate_run_config(run_config) + if run_config: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ): - yield _utils.dump_event_for_json(event) - else: - for event in self._tmpl_attrs.get("runner").run( + yield _utils.dump_event_for_json(event) + else: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ): - yield _utils.dump_event_for_json(event) + yield _utils.dump_event_for_json(event) - async def streaming_agent_run_with_events(self, request_json: str): - """Streams responses asynchronously from the ADK application. + async def streaming_agent_run_with_events(self, request_json: str): + """Streams responses asynchronously from the ADK application. In general, you should use `async_stream_query` instead, as it has a more structured API and works with the respective ADK services that @@ -1312,12 +1312,12 @@ async def streaming_agent_run_with_events(self, request_json: str): Required. The request to stream responses for. """ - import json - from google.genai import types - from google.genai.errors import ClientError + import json + from google.genai import types + from google.genai.errors import ClientError - request = _StreamRunRequest(**json.loads(request_json)) - if not any( + request = _StreamRunRequest(**json.loads(request_json)) + if not any( self._tmpl_attrs.get(service) for service in ( "in_memory_runner", @@ -1330,93 +1330,93 @@ async def streaming_agent_run_with_events(self, request_json: str): "memory_service", ) ): - self.set_up() - - # Try to get the session, if it doesn't exist, create a new one. - state_delta = None - if request.session_id: - session_service = self._tmpl_attrs.get("session_service") - artifact_service = self._tmpl_attrs.get("artifact_service") - runner = self._tmpl_attrs.get("runner") - session = None - try: - session = await session_service.get_session( + self.set_up() + + # Try to get the session, if it doesn't exist, create a new one. + state_delta = None + if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") + session = None + try: + session = await session_service.get_session( app_name=self._app_name(), user_id=request.user_id, session_id=request.session_id, ) - if session: - await self._save_artifacts( + if session: + await self._save_artifacts( session_id=request.session_id, artifact_service=artifact_service, request=request, ) - if request.authorizations: - state_delta = {} - for auth_id, auth in request.authorizations.items(): - auth = _Authorization(**auth) - state_delta[auth_id] = auth.access_token - except ClientError: - pass - if not session: - # Fall back to create session if the session is not found. - # Specifying session_id on creation is not supported, - # so session id will be regenerated. - session = await self._init_session( + if request.authorizations: + state_delta = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + state_delta[auth_id] = auth.access_token + except ClientError: + pass + if not session: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - else: - # Not providing a session ID will create a new in-memory session. - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - runner = self._tmpl_attrs.get("in_memory_runner") - session = await self._init_session( + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - if not session: - raise RuntimeError("Session initialization failed.") + if not session: + raise RuntimeError("Session initialization failed.") - # Run the agent - message_for_agent = types.Content(**request.message) - try: - async for event in runner.run_async( + # Run the agent + message_for_agent = types.Content(**request.message) + try: + async for event in runner.run_async( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, state_delta=state_delta, ): - converted_event = await self._convert_response_events( + converted_event = await self._convert_response_events( user_id=request.user_id, session_id=session.id, events=[event], artifact_service=artifact_service, ) - yield converted_event - finally: - if session and not request.session_id: - await session_service.delete_session( + yield converted_event + finally: + if session and not request.session_id: + await session_service.delete_session( app_name=self._app_name(), user_id=request.user_id, session_id=session.id, ) - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - async def async_get_session( + async def async_get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Get a session for the given user. + """Get a session for the given user. Args: user_id (str): @@ -1434,32 +1434,36 @@ async def async_get_session( Raises: RuntimeError: If the session is not found. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").get_session( - app_name=self._app_name(), - user_id=user_id, - session_id=session_id, - **kwargs, - ) - if not session: - raise RuntimeError( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + if "config" in kwargs and isinstance(kwargs["config"], dict): + from google.adk.sessions.base_session_service import GetSessionConfig + + kwargs["config"] = GetSessionConfig(**kwargs["config"]) + session = await self._tmpl_attrs.get("session_service").get_session( + app_name=self._app_name(), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + if not session: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) - return session + return session - def get_session( + def get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deprecated. Use async_get_session instead. + """Deprecated. Use async_get_session instead. Get a session for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.get_session(...) is deprecated. " "Use AdkApp.async_get_session(...) instead. See " @@ -1469,37 +1473,37 @@ def get_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_get_session(): - return await self.async_get_session( + async def _invoke_async_get_session(): + return await self.async_get_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_get_session()) - event_queue.put(result) - except Exception as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_get_session()) + event_queue.put(result) + except Exception as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() - # Wait for the thread to finish - thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError( + # Wait for the thread to finish + thread.join() + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome - async def async_list_sessions(self, *, user_id: str, **kwargs): - """List sessions for the given user. + async def async_list_sessions(self, *, user_id: str, **kwargs): + """List sessions for the given user. Args: user_id (str): @@ -1511,20 +1515,20 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): Returns: ListSessionsResponse: The list of sessions. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - return await self._tmpl_attrs.get("session_service").list_sessions( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + return await self._tmpl_attrs.get("session_service").list_sessions( app_name=self._app_name(), user_id=user_id, **kwargs, ) - def list_sessions(self, *, user_id: str, **kwargs): - """Deprecated. Use async_list_sessions instead. + def list_sessions(self, *, user_id: str, **kwargs): + """Deprecated. Use async_list_sessions instead. List sessions for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.list_sessions(...) is deprecated. " "Use AdkApp.async_list_sessions(...) instead. See " @@ -1534,31 +1538,31 @@ def list_sessions(self, *, user_id: str, **kwargs): DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue() - - async def _invoke_async_list_sessions(): - try: - response = await self.async_list_sessions(user_id=user_id, **kwargs) - event_queue.put(response) - except RuntimeError as e: - event_queue.put(e) - - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_list_sessions()) - finally: - event_queue.put(None) - - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() - try: - return event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to list sessions.") from None + event_queue = queue.Queue() + + async def _invoke_async_list_sessions(): + try: + response = await self.async_list_sessions(user_id=user_id, **kwargs) + event_queue.put(response) + except RuntimeError as e: + event_queue.put(e) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_list_sessions()) + finally: + event_queue.put(None) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + try: + return event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to list sessions.") from None - async def async_create_session( + async def async_create_session( self, *, user_id: str, @@ -1566,7 +1570,7 @@ async def async_create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Creates a new session. + """Creates a new session. Args: user_id (str): @@ -1583,18 +1587,18 @@ async def async_create_session( Returns: Session: The newly created session instance. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").create_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + session = await self._tmpl_attrs.get("session_service").create_session( app_name=self._app_name(), user_id=user_id, session_id=session_id, state=state, **kwargs, ) - return self._serialize(session) + return self._serialize(session) - def create_session( + def create_session( self, *, user_id: str, @@ -1602,11 +1606,11 @@ def create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Deprecated. Use async_create_session instead. + """Deprecated. Use async_create_session instead. Creates a new session. """ - warnings.warn( + warnings.warn( ( "AdkApp.create_session(...) is deprecated. " "Use AdkApp.async_create_session(...) instead. See " @@ -1616,44 +1620,44 @@ def create_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_create_session(): - return await self.async_create_session( + async def _invoke_async_create_session(): + return await self.async_create_session( user_id=user_id, session_id=session_id, state=state, **kwargs, ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_create_session()) - event_queue.put(result) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_create_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to create session.") from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome - - async def async_delete_session( + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to create session.") from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + + async def async_delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deletes a session for the given user. + """Deletes a session for the given user. Args: user_id (str): @@ -1664,27 +1668,27 @@ async def async_delete_session( Optional. Additional keyword arguments to pass to the session service. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - await self._tmpl_attrs.get("session_service").delete_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + await self._tmpl_attrs.get("session_service").delete_session( app_name=self._app_name(), user_id=user_id, session_id=session_id, **kwargs, ) - def delete_session( + def delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deprecated. Use async_delete_session instead. + """Deprecated. Use async_delete_session instead. Deletes a session for the given user. """ - warnings.warn( + warnings.warn( ( "AdkApp.delete_session(...) is deprecated. " "Use AdkApp.async_delete_session(...) instead. See " @@ -1694,31 +1698,31 @@ def delete_session( DeprecationWarning, stacklevel=2, ) - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_delete_session(): - await self.async_delete_session( + async def _invoke_async_delete_session(): + await self.async_delete_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_delete_session()) - event_queue.put(None) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_delete_session()) + event_queue.put(None) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - outcome = event_queue.get(timeout=10) - if isinstance(outcome, RuntimeError): - raise outcome from None + outcome = event_queue.get(timeout=10) + if isinstance(outcome, RuntimeError): + raise outcome from None - async def async_add_session_to_memory(self, *, session: Dict[str, Any]): - """Generates memories. + async def async_add_session_to_memory(self, *, session: Dict[str, Any]): + """Generates memories. Args: session (Dict[str, Any]): @@ -1726,26 +1730,26 @@ async def async_add_session_to_memory(self, *, session: Dict[str, Any]): be a dictionary representing an ADK Session object, e.g. session.model_dump(mode="json"). """ - from google.adk.sessions.session import Session - - if isinstance(session, Dict): - session = Session.model_validate(session) - elif not isinstance(session, Session): - raise TypeError("session must be a Session object.") - if not session.events: - # Get the latest version of the session in case it was updated. - session = await self.async_get_session( + from google.adk.sessions.session import Session + + if isinstance(session, Dict): + session = Session.model_validate(session) + elif not isinstance(session, Session): + raise TypeError("session must be a Session object.") + if not session.events: + # Get the latest version of the session in case it was updated. + session = await self.async_get_session( user_id=session.user_id, session_id=session.id, ) - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").add_session_to_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").add_session_to_memory( session=session, ) - async def async_search_memory(self, *, user_id: str, query: str): - """Searches memories for the given user. + async def async_search_memory(self, *, user_id: str, query: str): + """Searches memories for the given user. Args: user_id: The id of the user. @@ -1754,17 +1758,17 @@ async def async_search_memory(self, *, user_id: str, query: str): Returns: A SearchMemoryResponse containing the matching memories. """ - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").search_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").search_memory( app_name=self._app_name(), user_id=user_id, query=query, ) - def register_operations(self) -> Dict[str, List[str]]: - """Registers the operations of the ADK application.""" - return { + def register_operations(self) -> Dict[str, List[str]]: + """Registers the operations of the ADK application.""" + return { "": [ "get_session", "list_sessions", @@ -1786,8 +1790,8 @@ def register_operations(self) -> Dict[str, List[str]]: ], } - def _telemetry_enabled(self) -> Optional[bool]: - """Return status of telemetry enablement depending on enablement env variable. + def _telemetry_enabled(self) -> Optional[bool]: + """Return status of telemetry enablement depending on enablement env variable. In detail: - Logging is always enabled when telemetry is enabled. @@ -1797,25 +1801,25 @@ def _telemetry_enabled(self) -> Optional[bool]: True if telemetry is enabled, False if telemetry is disabled, or None if telemetry enablement is not set (i.e. old deployments which don't support this env variable). """ - import os + import os - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" ) - env_value = os.getenv( + env_value = os.getenv( GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "unspecified" ).lower() - if env_value in ("true", "1"): - return True - if env_value in ("false", "0"): - return False - return None + if env_value in ("true", "1"): + return True + if env_value in ("false", "0"): + return False + return None - # Tracing enablement follows truth table: - def _tracing_enabled(self) -> bool: - """Tracing enablement follows true table: + # Tracing enablement follows truth table: + def _tracing_enabled(self) -> bool: + """Tracing enablement follows true table: | enable_tracing | enable_telemetry(env) | tracing_actually_enabled | |----------------|-----------------------|--------------------------| @@ -1829,26 +1833,26 @@ def _tracing_enabled(self) -> bool: | None(default) | true | adk_version >= 1.17 | | None(default) | None | false | """ - enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") - enable_telemetry: Optional[bool] = self._telemetry_enabled() + enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") + enable_telemetry: Optional[bool] = self._telemetry_enabled() - return (enable_tracing is True and enable_telemetry is not False) or ( + return (enable_tracing is True and enable_telemetry is not False) or ( enable_tracing is None and enable_telemetry is True and is_version_sufficient("1.17.0") ) - def project_id(self) -> Optional[str]: - if project := self._tmpl_attrs.get("project"): - try: - from google.cloud.aiplatform.utils import ( + def project_id(self) -> Optional[str]: + if project := self._tmpl_attrs.get("project"): + try: + from google.cloud.aiplatform.utils import ( resource_manager_utils, ) - from google.api_core import exceptions + from google.api_core import exceptions - return resource_manager_utils.get_project_id(project) - # Fail open as temporary workaround for identity_type config parameter - except (exceptions.PermissionDenied, exceptions.Unauthenticated): - return project + return resource_manager_utils.get_project_id(project) + # Fail open as temporary workaround for identity_type config parameter + except (exceptions.PermissionDenied, exceptions.Unauthenticated): + return project - return None + return None diff --git a/vertexai/preview/reasoning_engines/templates/adk.py b/vertexai/preview/reasoning_engines/templates/adk.py index 469ecf71cc..6f6450e8cf 100644 --- a/vertexai/preview/reasoning_engines/templates/adk.py +++ b/vertexai/preview/reasoning_engines/templates/adk.py @@ -547,11 +547,11 @@ def _validate_run_config(run_config: Optional[Dict[str, Any]]): class AdkApp: - """An ADK Application.""" + """An ADK Application.""" - agent_framework = "google-adk" + agent_framework = "google-adk" - def __init__( + def __init__( self, *, agent: "BaseAgent", @@ -565,18 +565,18 @@ def __init__( ] = None, env_vars: Optional[Dict[str, str]] = None, ): - """An ADK Application.""" - from google.cloud.aiplatform import initializer + """An ADK Application.""" + from google.cloud.aiplatform import initializer - adk_version = get_adk_version() - if not is_version_sufficient("1.0.0"): - msg = ( + adk_version = get_adk_version() + if not is_version_sufficient("1.0.0"): + msg = ( f"Unsupported google-adk version: {adk_version}, " "please use google-adk>=1.0.0 for AdkApp deployment." ) - raise ValueError(msg) + raise ValueError(msg) - self._tmpl_attrs: Dict[str, Any] = { + self._tmpl_attrs: Dict[str, Any] = { "project": initializer.global_config.project, "location": initializer.global_config.location, "agent": agent, @@ -590,109 +590,109 @@ def __init__( "env_vars": env_vars or {}, } - def _serialize(self, obj: Any) -> Any: - """Serializes an object to be JSON compatible.""" - if hasattr(obj, "model_dump"): - return obj.model_dump(mode="json") - elif hasattr(obj, "dict"): - return self._serialize(obj.dict()) - elif isinstance(obj, dict): - return {k: self._serialize(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._serialize(v) for v in obj] - return obj - - async def _init_session( + def _serialize(self, obj: Any) -> Any: + """Serializes an object to be JSON compatible.""" + if hasattr(obj, "model_dump"): + return obj.model_dump(mode="json") + elif hasattr(obj, "dict"): + return self._serialize(obj.dict()) + elif isinstance(obj, dict): + return {k: self._serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._serialize(v) for v in obj] + return obj + + async def _init_session( self, session_service: "BaseSessionService", artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Initializes the session, and returns the session id.""" - from google.adk.events.event import Event + """Initializes the session, and returns the session id.""" + from google.adk.events.event import Event - session_state = None - if request.authorizations: - session_state = {} - for auth_id, auth in request.authorizations.items(): - auth = _Authorization(**auth) - session_state[auth_id] = auth.access_token + session_state = None + if request.authorizations: + session_state = {} + for auth_id, auth in request.authorizations.items(): + auth = _Authorization(**auth) + session_state[auth_id] = auth.access_token - session = await session_service.create_session( + session = await session_service.create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, state=session_state, ) - if not session: - raise RuntimeError("Create session failed.") - if request.events: - for event in request.events: - await session_service.append_event(session, Event(**event)) - if request.artifacts: - await self._save_artifacts( + if not session: + raise RuntimeError("Create session failed.") + if request.events: + for event in request.events: + await session_service.append_event(session, Event(**event)) + if request.artifacts: + await self._save_artifacts( session_id=session.id, artifact_service=artifact_service, request=request, ) - return session + return session - async def _save_artifacts( + async def _save_artifacts( self, session_id: str, artifact_service: "BaseArtifactService", request: _StreamRunRequest, ): - """Saves the artifacts.""" - app = self._tmpl_attrs.get("app") - if request.artifacts: - for artifact in request.artifacts: - artifact = _Artifact(**artifact) - for version_data in sorted( + """Saves the artifacts.""" + app = self._tmpl_attrs.get("app") + if request.artifacts: + for artifact in request.artifacts: + artifact = _Artifact(**artifact) + for version_data in sorted( artifact.versions, key=lambda x: x["version"] ): - version_data = _ArtifactVersion(**version_data) - saved_version = await artifact_service.save_artifact( + version_data = _ArtifactVersion(**version_data) + saved_version = await artifact_service.save_artifact( app_name=app.name if app else self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=session_id, filename=artifact.file_name, artifact=version_data.data, ) - if saved_version != version_data.version: - from google.cloud.aiplatform import base + if saved_version != version_data.version: + from google.cloud.aiplatform import base - _LOGGER = base.Logger(__name__) - _LOGGER.debug( + _LOGGER = base.Logger(__name__) + _LOGGER.debug( "Artifact '%s' saved at version %s instead of %s", artifact.file_name, saved_version, version_data.version, ) - async def _convert_response_events( + async def _convert_response_events( self, user_id: str, session_id: str, events: List["Event"], artifact_service: Optional["BaseArtifactService"], ) -> _StreamingRunResponse: - """Converts the events to the streaming run response object.""" - import collections + """Converts the events to the streaming run response object.""" + import collections - result = _StreamingRunResponse( + result = _StreamingRunResponse( events=events, artifacts=[], session_id=session_id ) - # Save the generated artifacts into the result object. - artifact_versions = collections.defaultdict(list) - for event in events: - if event.actions and event.actions.artifact_delta: - for key, version in event.actions.artifact_delta.items(): - artifact_versions[key].append(version) + # Save the generated artifacts into the result object. + artifact_versions = collections.defaultdict(list) + for event in events: + if event.actions and event.actions.artifact_delta: + for key, version in event.actions.artifact_delta.items(): + artifact_versions[key].append(version) - for key, versions in artifact_versions.items(): - result.artifacts.append( + for key, versions in artifact_versions.items(): + result.artifacts.append( _Artifact( file_name=key, versions=[ @@ -711,13 +711,13 @@ async def _convert_response_events( ) ) - return result.dump() + return result.dump() - def clone(self): - """Returns a clone of the ADK application.""" - import copy + def clone(self): + """Returns a clone of the ADK application.""" + import copy - return AdkApp( + return AdkApp( agent=copy.deepcopy(self._tmpl_attrs.get("agent")), enable_tracing=self._tmpl_attrs.get("enable_tracing"), session_service_builder=self._tmpl_attrs.get("session_service_builder"), @@ -726,39 +726,39 @@ def clone(self): env_vars=self._tmpl_attrs.get("env_vars"), ) - def set_up(self): - """Sets up the ADK application.""" - import os - from google.adk.runners import Runner - from google.adk.sessions.in_memory_session_service import InMemorySessionService - from google.adk.artifacts.in_memory_artifact_service import ( + def set_up(self): + """Sets up the ADK application.""" + import os + from google.adk.runners import Runner + from google.adk.sessions.in_memory_session_service import InMemorySessionService + from google.adk.artifacts.in_memory_artifact_service import ( InMemoryArtifactService, ) - from google.adk.memory.in_memory_memory_service import InMemoryMemoryService - - os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" - project = self._tmpl_attrs.get("project") - os.environ["GOOGLE_CLOUD_PROJECT"] = project - location = self._tmpl_attrs.get("location") - if location: - if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location - if "GOOGLE_CLOUD_LOCATION" not in os.environ: - os.environ["GOOGLE_CLOUD_LOCATION"] = location - - # Disable content capture in custom ADK spans unless user enabled - # tracing explicitly with the old flag - # (this is to preserve compatibility with old behavior). - if self._tmpl_attrs.get("enable_tracing"): - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" - else: - os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" - - if self._tmpl_attrs.get("enable_tracing"): - self._warn_if_telemetry_api_disabled() - - if self._tmpl_attrs.get("enable_tracing") is False: - _warn( + from google.adk.memory.in_memory_memory_service import InMemoryMemoryService + + os.environ["GOOGLE_GENAI_USE_VERTEXAI"] = "1" + project = self._tmpl_attrs.get("project") + os.environ["GOOGLE_CLOUD_PROJECT"] = project + location = self._tmpl_attrs.get("location") + if location: + if "GOOGLE_CLOUD_AGENT_ENGINE_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_LOCATION"] = location + if "GOOGLE_CLOUD_LOCATION" not in os.environ: + os.environ["GOOGLE_CLOUD_LOCATION"] = location + + # Disable content capture in custom ADK spans unless user enabled + # tracing explicitly with the old flag + # (this is to preserve compatibility with old behavior). + if self._tmpl_attrs.get("enable_tracing"): + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "true" + else: + os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] = "false" + + if self._tmpl_attrs.get("enable_tracing"): + self._warn_if_telemetry_api_disabled() + + if self._tmpl_attrs.get("enable_tracing") is False: + _warn( ( "Your 'enable_tracing=False' setting is being deprecated " "and will be removed in a future release.\n" @@ -783,95 +783,95 @@ def set_up(self): ), ) - enable_logging = bool(self._telemetry_enabled()) + enable_logging = bool(self._telemetry_enabled()) - self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( + self._tmpl_attrs["instrumentor"] = _default_instrumentor_builder( self.project_id(), enable_tracing=self._tracing_enabled(), enable_logging=enable_logging, ) - for key, value in self._tmpl_attrs.get("env_vars").items(): - os.environ[key] = value - if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - self._tmpl_attrs["app_name"] = os.environ.get( + for key, value in self._tmpl_attrs.get("env_vars").items(): + os.environ[key] = value + if "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + self._tmpl_attrs["app_name"] = os.environ.get( "GOOGLE_CLOUD_AGENT_ENGINE_ID", self._tmpl_attrs.get("app_name"), ) - artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") - if artifact_service_builder: - self._tmpl_attrs["artifact_service"] = artifact_service_builder() - else: - self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() - - session_service_builder = self._tmpl_attrs.get("session_service_builder") - if session_service_builder: - self._tmpl_attrs["session_service"] = session_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: - try: - from google.adk.sessions.vertex_ai_session_service import ( + artifact_service_builder = self._tmpl_attrs.get("artifact_service_builder") + if artifact_service_builder: + self._tmpl_attrs["artifact_service"] = artifact_service_builder() + else: + self._tmpl_attrs["artifact_service"] = InMemoryArtifactService() + + session_service_builder = self._tmpl_attrs.get("session_service_builder") + if session_service_builder: + self._tmpl_attrs["session_service"] = session_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ: + try: + from google.adk.sessions.vertex_ai_session_service import ( VertexAiSessionService, ) - if is_version_sufficient("1.5.0"): - self._tmpl_attrs["session_service"] = VertexAiSessionService( + if is_version_sufficient("1.5.0"): + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["session_service"] = VertexAiSessionService( + else: + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=location, ) - except ImportError: - from google.adk.sessions.vertex_ai_session_service_g3 import ( + except ImportError: + from google.adk.sessions.vertex_ai_session_service_g3 import ( VertexAiSessionService, ) - self._tmpl_attrs["session_service"] = VertexAiSessionService( + self._tmpl_attrs["session_service"] = VertexAiSessionService( project=project, location=location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - else: - self._tmpl_attrs["session_service"] = InMemorySessionService() + else: + self._tmpl_attrs["session_service"] = InMemorySessionService() - memory_service_builder = self._tmpl_attrs.get("memory_service_builder") - if memory_service_builder: - self._tmpl_attrs["memory_service"] = memory_service_builder() - elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( + memory_service_builder = self._tmpl_attrs.get("memory_service_builder") + if memory_service_builder: + self._tmpl_attrs["memory_service"] = memory_service_builder() + elif "GOOGLE_CLOUD_AGENT_ENGINE_ID" in os.environ and is_version_sufficient( "1.5.0" ): - try: - from google.adk.memory.vertex_ai_memory_bank_service import ( + try: + from google.adk.memory.vertex_ai_memory_bank_service import ( VertexAiMemoryBankService, ) - self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( + self._tmpl_attrs["memory_service"] = VertexAiMemoryBankService( project=project, location=location, agent_engine_id=os.environ.get("GOOGLE_CLOUD_AGENT_ENGINE_ID"), ) - except ImportError: - # TODO(ysian): Handle this via _g3 import for google3. - pass - else: - self._tmpl_attrs["memory_service"] = InMemoryMemoryService() - - credential_service_builder = self._tmpl_attrs.get("credential_service_builder") - if credential_service_builder: - self._tmpl_attrs["credential_service"] = credential_service_builder() - else: - from google.adk.auth.credential_service.in_memory_credential_service import ( + except ImportError: + # TODO(ysian): Handle this via _g3 import for google3. + pass + else: + self._tmpl_attrs["memory_service"] = InMemoryMemoryService() + + credential_service_builder = self._tmpl_attrs.get("credential_service_builder") + if credential_service_builder: + self._tmpl_attrs["credential_service"] = credential_service_builder() + else: + from google.adk.auth.credential_service.in_memory_credential_service import ( InMemoryCredentialService, ) - self._tmpl_attrs["credential_service"] = InMemoryCredentialService() + self._tmpl_attrs["credential_service"] = InMemoryCredentialService() - self._tmpl_attrs["runner"] = Runner( + self._tmpl_attrs["runner"] = Runner( agent=self._tmpl_attrs.get("agent"), plugins=self._tmpl_attrs.get("plugins"), session_service=self._tmpl_attrs.get("session_service"), @@ -879,10 +879,10 @@ def set_up(self): memory_service=self._tmpl_attrs.get("memory_service"), app_name=self._tmpl_attrs.get("app_name"), ) - self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() - self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() - self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() - self._tmpl_attrs["in_memory_runner"] = Runner( + self._tmpl_attrs["in_memory_session_service"] = InMemorySessionService() + self._tmpl_attrs["in_memory_artifact_service"] = InMemoryArtifactService() + self._tmpl_attrs["in_memory_memory_service"] = InMemoryMemoryService() + self._tmpl_attrs["in_memory_runner"] = Runner( agent=self._tmpl_attrs.get("agent"), plugins=self._tmpl_attrs.get("plugins"), session_service=self._tmpl_attrs.get("in_memory_session_service"), @@ -892,7 +892,7 @@ def set_up(self): app_name=self._tmpl_attrs.get("app_name"), ) - def stream_query( + def stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -901,7 +901,7 @@ def stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ): - """Streams responses from the ADK application in response to a message. + """Streams responses from the ADK application in response to a message. Args: message (Union[str, Dict[str, Any]]): @@ -922,44 +922,44 @@ def stream_query( Yields: The output of querying the ADK application. """ - from vertexai.agent_engines import _utils - from google.genai import types - - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + from vertexai.agent_engines import _utils + from google.genai import types + + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if not session_id: - session = self.create_session(user_id=user_id) - session_id = session["id"] - run_config = _validate_run_config(run_config) - if run_config: - for event in self._tmpl_attrs.get("runner").run( + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = self.create_session(user_id=user_id) + session_id = session["id"] + run_config = _validate_run_config(run_config) + if run_config: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ): - yield _utils.dump_event_for_json(event) - else: - for event in self._tmpl_attrs.get("runner").run( + yield _utils.dump_event_for_json(event) + else: + for event in self._tmpl_attrs.get("runner").run( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ): - yield _utils.dump_event_for_json(event) + yield _utils.dump_event_for_json(event) - async def async_stream_query( + async def async_stream_query( self, *, message: Union[str, Dict[str, Any]], @@ -968,7 +968,7 @@ async def async_stream_query( run_config: Optional[Dict[str, Any]] = None, **kwargs, ) -> AsyncIterable[Dict[str, Any]]: - """Streams responses asynchronously from the ADK application. + """Streams responses asynchronously from the ADK application. Args: message (str): @@ -989,63 +989,63 @@ async def async_stream_query( Yields: Event dictionaries asynchronously. """ - from vertexai.agent_engines import _utils - from google.genai import types - - if isinstance(message, Dict): - content = types.Content.model_validate(message) - elif isinstance(message, str): - content = types.Content(role="user", parts=[types.Part(text=message)]) - else: - raise TypeError( + from vertexai.agent_engines import _utils + from google.genai import types + + if isinstance(message, Dict): + content = types.Content.model_validate(message) + elif isinstance(message, str): + content = types.Content(role="user", parts=[types.Part(text=message)]) + else: + raise TypeError( "message must be a string or a dictionary representing" " a Content object." ) - if not self._tmpl_attrs.get("runner"): - self.set_up() - if not session_id: - session = await self.async_create_session(user_id=user_id) - session_id = session["id"] + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + session = await self.async_create_session(user_id=user_id) + session_id = session["id"] - run_config = _validate_run_config(run_config) - if run_config: - events_async = self._tmpl_attrs.get("runner").run_async( + run_config = _validate_run_config(run_config) + if run_config: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, run_config=run_config, **kwargs, ) - else: - events_async = self._tmpl_attrs.get("runner").run_async( + else: + events_async = self._tmpl_attrs.get("runner").run_async( user_id=user_id, session_id=session_id, new_message=content, **kwargs, ) - try: - async for event in events_async: - # Yield the event data as a dictionary - yield _utils.dump_event_for_json(event) - finally: - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + try: + async for event in events_async: + # Yield the event data as a dictionary + yield _utils.dump_event_for_json(event) + finally: + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - def streaming_agent_run_with_events(self, request_json: str): - import json - from google.genai import types - from google.genai.errors import ClientError + def streaming_agent_run_with_events(self, request_json: str): + import json + from google.genai import types + from google.genai.errors import ClientError - event_queue = queue.Queue(maxsize=1) + event_queue = queue.Queue(maxsize=1) - async def _invoke_agent_async(): - request = _StreamRunRequest(**json.loads(request_json)) - if not any( + async def _invoke_agent_async(): + request = _StreamRunRequest(**json.loads(request_json)) + if not any( self._tmpl_attrs.get(service) for service in ( "in_memory_runner", @@ -1058,104 +1058,104 @@ async def _invoke_agent_async(): "memory_service", ) ): - self.set_up() - # Try to get the session, if it doesn't exist, create a new one. - if request.session_id: - session_service = self._tmpl_attrs.get("session_service") - artifact_service = self._tmpl_attrs.get("artifact_service") - runner = self._tmpl_attrs.get("runner") - session = None - try: - session = await session_service.get_session( + self.set_up() + # Try to get the session, if it doesn't exist, create a new one. + if request.session_id: + session_service = self._tmpl_attrs.get("session_service") + artifact_service = self._tmpl_attrs.get("artifact_service") + runner = self._tmpl_attrs.get("runner") + session = None + try: + session = await session_service.get_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=request.session_id, ) - if session: - await self._save_artifacts( + if session: + await self._save_artifacts( session_id=request.session_id, artifact_service=artifact_service, request=request, ) - except ClientError: - pass - if not session: - # Fall back to create session if the session is not found. - # Specifying session_id on creation is not supported, - # so session id will be regenerated. - session = await self._init_session( + except ClientError: + pass + if not session: + # Fall back to create session if the session is not found. + # Specifying session_id on creation is not supported, + # so session id will be regenerated. + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - else: - # Not providing a session ID will create a new in-memory session. - session_service = self._tmpl_attrs.get("in_memory_session_service") - artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") - runner = self._tmpl_attrs.get("in_memory_runner") - session = await self._init_session( + else: + # Not providing a session ID will create a new in-memory session. + session_service = self._tmpl_attrs.get("in_memory_session_service") + artifact_service = self._tmpl_attrs.get("in_memory_artifact_service") + runner = self._tmpl_attrs.get("in_memory_runner") + session = await self._init_session( session_service=session_service, artifact_service=artifact_service, request=request, ) - if not session: - raise RuntimeError("Session initialization failed.") - # Run the agent. - message_for_agent = types.Content(**request.message) - try: - for event in runner.run( + if not session: + raise RuntimeError("Session initialization failed.") + # Run the agent. + message_for_agent = types.Content(**request.message) + try: + for event in runner.run( user_id=request.user_id, session_id=session.id, new_message=message_for_agent, ): - converted_event = await self._convert_response_events( + converted_event = await self._convert_response_events( user_id=request.user_id, session_id=session.id, events=[event], artifact_service=artifact_service, ) - event_queue.put(converted_event) - finally: - if session and not request.session_id: - await session_service.delete_session( + event_queue.put(converted_event) + finally: + if session and not request.session_id: + await session_service.delete_session( app_name=self._tmpl_attrs.get("app_name"), user_id=request.user_id, session_id=session.id, ) - # Avoid telemetry data loss having to do with CPU throttling on instance turndown - _ = await _force_flush_otel( + # Avoid telemetry data loss having to do with CPU throttling on instance turndown + _ = await _force_flush_otel( tracing_enabled=self._tracing_enabled(), logging_enabled=bool(self._telemetry_enabled()), ) - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_agent_async()) - except RuntimeError as e: - event_queue.put(e) - finally: - # Use None as a sentinel to stop the main thread. - event_queue.put(None) + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_agent_async()) + except RuntimeError as e: + event_queue.put(e) + finally: + # Use None as a sentinel to stop the main thread. + event_queue.put(None) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() - try: - while True: - event = event_queue.get() - if event is None: - break - if isinstance(event, RuntimeError): - raise event - yield event - finally: - thread.join() - - async def bidi_stream_query( + try: + while True: + event = event_queue.get() + if event is None: + break + if isinstance(event, RuntimeError): + raise event + yield event + finally: + thread.join() + + async def bidi_stream_query( self, request_queue: Any, ) -> AsyncIterable[Any]: - """Bidi streaming query the ADK application. + """Bidi streaming query the ADK application. Args: request_queue: @@ -1170,81 +1170,81 @@ async def bidi_stream_query( Yields: The stream responses of querying the ADK application. """ - from google.adk.agents.live_request_queue import LiveRequest - from google.adk.agents.live_request_queue import LiveRequestQueue - from vertexai.agent_engines import _utils - - # Manual type check needed as Pydantic doesn't support asyncio.Queue. - if not isinstance(request_queue, asyncio.Queue): - raise TypeError("request_queue must be an asyncio.Queue instance.") - - first_request = await request_queue.get() - user_id = first_request.get("user_id") - if not user_id: - raise ValueError("The first request must have a user_id.") - - session_id = first_request.get("session_id") - run_config = first_request.get("run_config") - first_live_request = first_request.get("live_request") - - if not self._tmpl_attrs.get("runner"): - self.set_up() - if not session_id: - state = first_request.get("state") - session = await self.async_create_session(user_id=user_id, state=state) - session_id = session["id"] if isinstance(session, dict) else session.id - run_config = _validate_run_config(run_config) - - live_request_queue = LiveRequestQueue() - - if first_live_request and isinstance(first_live_request, Dict): - live_request_queue.send(LiveRequest.model_validate(first_live_request)) - - # Forwards live requests to the agent. - async def _forward_requests(): - while True: - request = await request_queue.get() - live_request = LiveRequest.model_validate(request) - live_request_queue.send(live_request) - - # Forwards events to the client. - async def _forward_events(): - if run_config: - events_async = self._tmpl_attrs.get("runner").run_live( + from google.adk.agents.live_request_queue import LiveRequest + from google.adk.agents.live_request_queue import LiveRequestQueue + from vertexai.agent_engines import _utils + + # Manual type check needed as Pydantic doesn't support asyncio.Queue. + if not isinstance(request_queue, asyncio.Queue): + raise TypeError("request_queue must be an asyncio.Queue instance.") + + first_request = await request_queue.get() + user_id = first_request.get("user_id") + if not user_id: + raise ValueError("The first request must have a user_id.") + + session_id = first_request.get("session_id") + run_config = first_request.get("run_config") + first_live_request = first_request.get("live_request") + + if not self._tmpl_attrs.get("runner"): + self.set_up() + if not session_id: + state = first_request.get("state") + session = await self.async_create_session(user_id=user_id, state=state) + session_id = session["id"] if isinstance(session, dict) else session.id + run_config = _validate_run_config(run_config) + + live_request_queue = LiveRequestQueue() + + if first_live_request and isinstance(first_live_request, Dict): + live_request_queue.send(LiveRequest.model_validate(first_live_request)) + + # Forwards live requests to the agent. + async def _forward_requests(): + while True: + request = await request_queue.get() + live_request = LiveRequest.model_validate(request) + live_request_queue.send(live_request) + + # Forwards events to the client. + async def _forward_events(): + if run_config: + events_async = self._tmpl_attrs.get("runner").run_live( user_id=user_id, session_id=session_id, live_request_queue=live_request_queue, run_config=run_config, ) - else: - events_async = self._tmpl_attrs.get("runner").run_live( + else: + events_async = self._tmpl_attrs.get("runner").run_live( user_id=user_id, session_id=session_id, live_request_queue=live_request_queue, ) - async for event in events_async: - yield _utils.dump_event_for_json(event) + async for event in events_async: + yield _utils.dump_event_for_json(event) - requests_task = asyncio.create_task(_forward_requests()) + requests_task = asyncio.create_task(_forward_requests()) - try: - async for event in _forward_events(): - yield event - finally: - requests_task.cancel() - try: - await requests_task - except asyncio.CancelledError: - pass - - async def async_get_session( + try: + async for event in _forward_events(): + yield event + finally: + requests_task.cancel() + try: + await requests_task + except asyncio.CancelledError: + pass + + async def async_get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Get a session for the given user. + """Get a session for the given user. Args: user_id (str): @@ -1262,59 +1262,63 @@ async def async_get_session( Raises: RuntimeError: If the session is not found. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").get_session( - app_name=self._tmpl_attrs.get("app_name"), - user_id=user_id, - session_id=session_id, - **kwargs, - ) - if not session: - raise RuntimeError( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + if "config" in kwargs and isinstance(kwargs["config"], dict): + from google.adk.sessions.base_session_service import GetSessionConfig + + kwargs["config"] = GetSessionConfig(**kwargs["config"]) + session = await self._tmpl_attrs.get("session_service").get_session( + app_name=self._tmpl_attrs.get("app_name"), + user_id=user_id, + session_id=session_id, + **kwargs, + ) + if not session: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) - return session + return session - def get_session( + def get_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Get a session for the given user.""" - event_queue = queue.Queue(maxsize=1) + """Get a session for the given user.""" + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_get_session(): - return await self.async_get_session( + async def _invoke_async_get_session(): + return await self.async_get_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_get_session()) - event_queue.put(result) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_get_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() - # Wait for the thread to finish - thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError( + # Wait for the thread to finish + thread.join() + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError( "Session not found. Please create it using .create_session()" ) from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome - async def async_list_sessions(self, *, user_id: str, **kwargs): - """List sessions for the given user. + async def async_list_sessions(self, *, user_id: str, **kwargs): + """List sessions for the given user. Args: user_id (str): @@ -1326,41 +1330,41 @@ async def async_list_sessions(self, *, user_id: str, **kwargs): Returns: ListSessionsResponse: The list of sessions. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - return await self._tmpl_attrs.get("session_service").list_sessions( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + return await self._tmpl_attrs.get("session_service").list_sessions( app_name=self._tmpl_attrs.get("app_name"), user_id=user_id, **kwargs, ) - def list_sessions(self, *, user_id: str, **kwargs): - """List sessions for the given user.""" - event_queue = queue.Queue() - - async def _invoke_async_list_sessions(): - try: - response = await self.async_list_sessions(user_id=user_id, **kwargs) - event_queue.put(response) - except RuntimeError as e: - event_queue.put(e) - - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_list_sessions()) - finally: - event_queue.put(None) - - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() - try: - return event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to list sessions.") from None + def list_sessions(self, *, user_id: str, **kwargs): + """List sessions for the given user.""" + event_queue = queue.Queue() + + async def _invoke_async_list_sessions(): + try: + response = await self.async_list_sessions(user_id=user_id, **kwargs) + event_queue.put(response) + except RuntimeError as e: + event_queue.put(e) + + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_list_sessions()) + finally: + event_queue.put(None) + + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() + try: + return event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to list sessions.") from None - async def async_create_session( + async def async_create_session( self, *, user_id: str, @@ -1368,7 +1372,7 @@ async def async_create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Creates a new session. + """Creates a new session. Args: user_id (str): @@ -1385,18 +1389,18 @@ async def async_create_session( Returns: Session: The newly created session instance. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - session = await self._tmpl_attrs.get("session_service").create_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + session = await self._tmpl_attrs.get("session_service").create_session( app_name=self._tmpl_attrs.get("app_name"), user_id=user_id, session_id=session_id, state=state, **kwargs, ) - return self._serialize(session) + return self._serialize(session) - def create_session( + def create_session( self, *, user_id: str, @@ -1404,45 +1408,45 @@ def create_session( state: Optional[Dict[str, Any]] = None, **kwargs, ): - """Creates a new session.""" - event_queue = queue.Queue(maxsize=1) + """Creates a new session.""" + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_create_session(): - return await self.async_create_session( + async def _invoke_async_create_session(): + return await self.async_create_session( user_id=user_id, session_id=session_id, state=state, **kwargs, ) - def _asyncio_thread_main(): - try: - result = asyncio.run(_invoke_async_create_session()) - event_queue.put(result) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + result = asyncio.run(_invoke_async_create_session()) + event_queue.put(result) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - try: - outcome = event_queue.get(timeout=10) - except queue.Empty: - raise RuntimeError("Failed to create session.") from None - if isinstance(outcome, RuntimeError): - raise outcome from None - return outcome - - async def async_delete_session( + try: + outcome = event_queue.get(timeout=10) + except queue.Empty: + raise RuntimeError("Failed to create session.") from None + if isinstance(outcome, RuntimeError): + raise outcome from None + return outcome + + async def async_delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deletes a session for the given user. + """Deletes a session for the given user. Args: user_id (str): @@ -1453,48 +1457,48 @@ async def async_delete_session( Optional. Additional keyword arguments to pass to the session service. """ - if not self._tmpl_attrs.get("session_service"): - self.set_up() - await self._tmpl_attrs.get("session_service").delete_session( + if not self._tmpl_attrs.get("session_service"): + self.set_up() + await self._tmpl_attrs.get("session_service").delete_session( app_name=self._tmpl_attrs.get("app_name"), user_id=user_id, session_id=session_id, **kwargs, ) - def delete_session( + def delete_session( self, *, user_id: str, session_id: str, **kwargs, ): - """Deletes a session for the given user.""" - event_queue = queue.Queue(maxsize=1) + """Deletes a session for the given user.""" + event_queue = queue.Queue(maxsize=1) - async def _invoke_async_delete_session(): - await self.async_delete_session( + async def _invoke_async_delete_session(): + await self.async_delete_session( user_id=user_id, session_id=session_id, **kwargs ) - def _asyncio_thread_main(): - try: - asyncio.run(_invoke_async_delete_session()) - event_queue.put(None) - except RuntimeError as e: - event_queue.put(e) + def _asyncio_thread_main(): + try: + asyncio.run(_invoke_async_delete_session()) + event_queue.put(None) + except RuntimeError as e: + event_queue.put(e) - thread = threading.Thread(target=_asyncio_thread_main) - thread.start() - # Wait for the thread to finish - thread.join() + thread = threading.Thread(target=_asyncio_thread_main) + thread.start() + # Wait for the thread to finish + thread.join() - outcome = event_queue.get(timeout=10) - if isinstance(outcome, RuntimeError): - raise outcome from None + outcome = event_queue.get(timeout=10) + if isinstance(outcome, RuntimeError): + raise outcome from None - async def async_add_session_to_memory(self, *, session: Dict[str, Any]): - """Generates memories. + async def async_add_session_to_memory(self, *, session: Dict[str, Any]): + """Generates memories. Args: session (Dict[str, Any]): @@ -1502,26 +1506,26 @@ async def async_add_session_to_memory(self, *, session: Dict[str, Any]): be a dictionary representing an ADK Session object, e.g. session.model_dump(mode="json"). """ - from google.adk.sessions.session import Session - - if isinstance(session, Dict): - session = Session.model_validate(session) - elif not isinstance(session, Session): - raise TypeError("session must be a Session object.") - if not session.events: - # Get the latest version of the session in case it was updated. - session = await self.async_get_session( + from google.adk.sessions.session import Session + + if isinstance(session, Dict): + session = Session.model_validate(session) + elif not isinstance(session, Session): + raise TypeError("session must be a Session object.") + if not session.events: + # Get the latest version of the session in case it was updated. + session = await self.async_get_session( user_id=session.user_id, session_id=session.id, ) - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").add_session_to_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").add_session_to_memory( session=session, ) - async def async_search_memory(self, *, user_id: str, query: str): - """Searches memories for the given user. + async def async_search_memory(self, *, user_id: str, query: str): + """Searches memories for the given user. Args: user_id: The id of the user. @@ -1530,17 +1534,17 @@ async def async_search_memory(self, *, user_id: str, query: str): Returns: A SearchMemoryResponse containing the matching memories. """ - if not self._tmpl_attrs.get("memory_service"): - self.set_up() - return await self._tmpl_attrs.get("memory_service").search_memory( + if not self._tmpl_attrs.get("memory_service"): + self.set_up() + return await self._tmpl_attrs.get("memory_service").search_memory( app_name=self._tmpl_attrs.get("app_name"), user_id=user_id, query=query, ) - def register_operations(self) -> Dict[str, List[str]]: - """Registers the operations of the ADK application.""" - return { + def register_operations(self) -> Dict[str, List[str]]: + """Registers the operations of the ADK application.""" + return { "": [ "get_session", "list_sessions", @@ -1560,8 +1564,8 @@ def register_operations(self) -> Dict[str, List[str]]: "bidi_stream": ["bidi_stream_query"], } - def _telemetry_enabled(self) -> Optional[bool]: - """Return status of telemetry enablement depending on enablement env variable. + def _telemetry_enabled(self) -> Optional[bool]: + """Return status of telemetry enablement depending on enablement env variable. In detail: - Logging is always enabled when telemetry is enabled. @@ -1571,25 +1575,25 @@ def _telemetry_enabled(self) -> Optional[bool]: True if telemetry is enabled, False if telemetry is disabled, or None if telemetry enablement is not set (i.e. old deployments which don't support this env variable). """ - import os + import os - GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" ) - env_value = os.getenv( + env_value = os.getenv( GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY, "unspecified" ).lower() - if env_value in ("true", "1"): - return True - if env_value in ("false", "0"): - return False - return None + if env_value in ("true", "1"): + return True + if env_value in ("false", "0"): + return False + return None - # Tracing enablement follows truth table: - def _tracing_enabled(self) -> bool: - """Tracing enablement follows true table: + # Tracing enablement follows truth table: + def _tracing_enabled(self) -> bool: + """Tracing enablement follows true table: | enable_tracing | enable_telemetry(env) | tracing_actually_enabled | |----------------|-----------------------|--------------------------| @@ -1603,41 +1607,41 @@ def _tracing_enabled(self) -> bool: | None(default) | true | adk_version >= 1.17 | | None(default) | None | false | """ - enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") - enable_telemetry: Optional[bool] = self._telemetry_enabled() + enable_tracing: Optional[bool] = self._tmpl_attrs.get("enable_tracing") + enable_telemetry: Optional[bool] = self._telemetry_enabled() - return (enable_tracing is True and enable_telemetry is not False) or ( + return (enable_tracing is True and enable_telemetry is not False) or ( enable_tracing is None and enable_telemetry is True and is_version_sufficient("1.17.0") ) - def _warn_if_telemetry_api_disabled(self): - """Warn if telemetry API is disabled.""" - try: - import google.auth.transport.requests - import google.auth - except (ImportError, AttributeError): - return - credentials, project = google.auth.default() - session = google.auth.transport.requests.AuthorizedSession( + def _warn_if_telemetry_api_disabled(self): + """Warn if telemetry API is disabled.""" + try: + import google.auth.transport.requests + import google.auth + except (ImportError, AttributeError): + return + credentials, project = google.auth.default() + session = google.auth.transport.requests.AuthorizedSession( credentials=credentials ) - r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) - if "Telemetry API has not been used in project" in r.text: - _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) - - def project_id(self) -> Optional[str]: - if project := self._tmpl_attrs.get("project"): - try: - from google.cloud.aiplatform.utils import ( + r = session.post("https://telemetry.googleapis.com/v1/traces", data=None) + if "Telemetry API has not been used in project" in r.text: + _warn(_TELEMETRY_API_DISABLED_WARNING % (project, project)) + + def project_id(self) -> Optional[str]: + if project := self._tmpl_attrs.get("project"): + try: + from google.cloud.aiplatform.utils import ( resource_manager_utils, ) - from google.api_core import exceptions + from google.api_core import exceptions - return resource_manager_utils.get_project_id(project) - # Fail open as temporary workaround for identity_type config parameter - except (exceptions.PermissionDenied, exceptions.Unauthenticated): - return project + return resource_manager_utils.get_project_id(project) + # Fail open as temporary workaround for identity_type config parameter + except (exceptions.PermissionDenied, exceptions.Unauthenticated): + return project - return None + return None