From 60b0f8401b7669d55faae2ebee29c1a88caa2a75 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Tue, 16 Jun 2026 09:05:16 -0700 Subject: [PATCH] fix: GenAI client - Make vertex_rag_store a required arg for client.rag.retrieve_contexts and make the tools arg in ask_contexts part of config PiperOrigin-RevId: 933119741 --- agentplatform/_genai/rag.py | 30 +++++++----- agentplatform/_genai/types/common.py | 8 ++-- .../genai/replays/test_rag_ask_contexts.py | 48 ++++++++++--------- 3 files changed, 49 insertions(+), 37 deletions(-) diff --git a/agentplatform/_genai/rag.py b/agentplatform/_genai/rag.py index ecc0b93a88..d4910ea37a 100644 --- a/agentplatform/_genai/rag.py +++ b/agentplatform/_genai/rag.py @@ -22,7 +22,6 @@ from google.genai import _api_module from google.genai import _common -from google.genai import types as genai_types from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv @@ -32,6 +31,18 @@ logger = logging.getLogger("agentplatform_genai.rag") +def _AskContextsConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + + if getv(from_object, ["tools"]) is not None: + setv(parent_object, ["tools"], getv(from_object, ["tools"])) + + return to_object + + def _AskContextsRequestParameters_to_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -41,10 +52,11 @@ def _AskContextsRequestParameters_to_vertex( setv(to_object, ["query"], getv(from_object, ["query"])) if getv(from_object, ["config"]) is not None: - setv(to_object, ["config"], getv(from_object, ["config"])) - - if getv(from_object, ["tools"]) is not None: - setv(to_object, ["tools"], getv(from_object, ["tools"])) + setv( + to_object, + ["config"], + _AskContextsConfig_to_vertex(getv(from_object, ["config"]), to_object), + ) return to_object @@ -1386,7 +1398,6 @@ def ask_contexts( *, query: types.RagQueryOrDict, config: Optional[types.AskContextsConfigOrDict] = None, - tools: Optional[list[genai_types.ToolOrDict]] = None, ) -> types.AskContextsResponse: """ Asks a RAG Contexts. @@ -1395,7 +1406,6 @@ def ask_contexts( parameter_model = types._AskContextsRequestParameters( query=query, config=config, - tools=tools, ) request_url_dict: Optional[dict[str, str]] @@ -2240,7 +2250,7 @@ def _update_config( def retrieve_contexts( self, *, - vertex_rag_store: Optional[types.VertexRagStoreOrDict] = None, + vertex_rag_store: types.VertexRagStoreOrDict, query: types.RagQueryOrDict, config: Optional[types.RetrieveContextsConfigOrDict] = None, ) -> types.RetrieveContextsResponse: @@ -2522,7 +2532,6 @@ async def ask_contexts( *, query: types.RagQueryOrDict, config: Optional[types.AskContextsConfigOrDict] = None, - tools: Optional[list[genai_types.ToolOrDict]] = None, ) -> types.AskContextsResponse: """ Asks a RAG Contexts. @@ -2531,7 +2540,6 @@ async def ask_contexts( parameter_model = types._AskContextsRequestParameters( query=query, config=config, - tools=tools, ) request_url_dict: Optional[dict[str, str]] @@ -3400,7 +3408,7 @@ async def _update_config( async def retrieve_contexts( self, *, - vertex_rag_store: Optional[types.VertexRagStoreOrDict] = None, + vertex_rag_store: types.VertexRagStoreOrDict, query: types.RagQueryOrDict, config: Optional[types.RetrieveContextsConfigOrDict] = None, ) -> types.RetrieveContextsResponse: diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index ec37a5c3f8..a361b2f77a 100644 --- a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -11942,6 +11942,7 @@ class AskContextsConfig(_common.BaseModel): http_options: Optional[genai_types.HttpOptions] = Field( default=None, description="""Used to override HTTP request options.""" ) + tools: Optional[list[genai_types.Tool]] = Field(default=None, description="""""") class AskContextsConfigDict(TypedDict, total=False): @@ -11950,6 +11951,9 @@ class AskContextsConfigDict(TypedDict, total=False): http_options: Optional[genai_types.HttpOptionsDict] """Used to override HTTP request options.""" + tools: Optional[list[genai_types.ToolDict]] + """""" + AskContextsConfigOrDict = Union[AskContextsConfig, AskContextsConfigDict] @@ -12175,7 +12179,6 @@ class _AskContextsRequestParameters(_common.BaseModel): query: Optional[RagQuery] = Field(default=None, description="""""") config: Optional[AskContextsConfig] = Field(default=None, description="""""") - tools: Optional[list[genai_types.Tool]] = Field(default=None, description="""""") class _AskContextsRequestParametersDict(TypedDict, total=False): @@ -12187,9 +12190,6 @@ class _AskContextsRequestParametersDict(TypedDict, total=False): config: Optional[AskContextsConfigDict] """""" - tools: Optional[list[genai_types.ToolDict]] - """""" - _AskContextsRequestParametersOrDict = Union[ _AskContextsRequestParameters, _AskContextsRequestParametersDict diff --git a/tests/unit/agentplatform/genai/replays/test_rag_ask_contexts.py b/tests/unit/agentplatform/genai/replays/test_rag_ask_contexts.py index aafa110b52..18748d04b0 100644 --- a/tests/unit/agentplatform/genai/replays/test_rag_ask_contexts.py +++ b/tests/unit/agentplatform/genai/replays/test_rag_ask_contexts.py @@ -34,19 +34,21 @@ def test_ask_contexts(client): text="earnings", similarity_top_k=5, ), - tools=[ - genai_types.Tool( - retrieval=genai_types.Retrieval( - vertex_rag_store=genai_types.VertexRagStore( - rag_resources=[ - genai_types.VertexRagStoreRagResource( - rag_corpus="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/2305843009213693952" - ) - ] + config=types.AskContextsConfig( + tools=[ + genai_types.Tool( + retrieval=genai_types.Retrieval( + vertex_rag_store=genai_types.VertexRagStore( + rag_resources=[ + genai_types.VertexRagStoreRagResource( + rag_corpus="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/2305843009213693952" + ) + ] + ) ) ) - ) - ], + ], + ), ) assert isinstance(rag_contexts, types.AskContextsResponse) @@ -63,19 +65,21 @@ async def test_ask_contexts_async(client): text="Grounding query", similarity_top_k=5, ), - tools=[ - genai_types.Tool( - retrieval=genai_types.Retrieval( - vertex_rag_store=genai_types.VertexRagStore( - rag_resources=[ - genai_types.VertexRagStoreRagResource( - rag_corpus="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/2305843009213693952" - ) - ] + config=types.AskContextsConfig( + tools=[ + genai_types.Tool( + retrieval=genai_types.Retrieval( + vertex_rag_store=genai_types.VertexRagStore( + rag_resources=[ + genai_types.VertexRagStoreRagResource( + rag_corpus="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/2305843009213693952" + ) + ] + ) ) ) - ) - ], + ], + ), ) assert isinstance(rag_contexts, types.AskContextsResponse)