From 13b685bd617cbcce36025625590a3ecb28494247 Mon Sep 17 00:00:00 2001 From: Sara Robinson Date: Wed, 17 Jun 2026 08:07:41 -0700 Subject: [PATCH] feat: GenAI client - Add upload_file method to RAG module PiperOrigin-RevId: 933730856 --- agentplatform/_genai/rag.py | 402 +++++++++++++++++- agentplatform/_genai/types/__init__.py | 26 ++ agentplatform/_genai/types/common.py | 159 +++++++ .../genai/replays/test_rag_upload.py | 65 +++ 4 files changed, 644 insertions(+), 8 deletions(-) create mode 100644 tests/unit/agentplatform/genai/replays/test_rag_upload.py diff --git a/agentplatform/_genai/rag.py b/agentplatform/_genai/rag.py index 46ca4fc19d..4cac2ab7ec 100644 --- a/agentplatform/_genai/rag.py +++ b/agentplatform/_genai/rag.py @@ -17,11 +17,14 @@ import json import logging +import mimetypes +import os from typing import Any, Optional, Union from urllib.parse import urlencode from google.genai import _api_module from google.genai import _common +from google.genai import _extra_utils from google.genai._common import get_value_by_path as getv from google.genai._common import set_value_by_path as setv @@ -1555,6 +1558,67 @@ def _UpdateRagCorpusRequestParameters_to_vertex( return to_object +def _UploadRagFileConfig_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["rag_file_chunking_config"]) is not None: + setv( + to_object, + ["ragFileChunkingConfig"], + getv(from_object, ["rag_file_chunking_config"]), + ) + + if getv(from_object, ["rag_file_metadata_config"]) is not None: + setv( + to_object, + ["ragFileMetadataConfig"], + getv(from_object, ["rag_file_metadata_config"]), + ) + + if getv(from_object, ["rag_file_parsing_config"]) is not None: + setv( + to_object, + ["ragFileParsingConfig"], + _RagFileParsingConfig_to_vertex( + getv(from_object, ["rag_file_parsing_config"]), to_object + ), + ) + + if getv(from_object, ["rag_file_transformation_config"]) is not None: + setv( + to_object, + ["ragFileTransformationConfig"], + getv(from_object, ["rag_file_transformation_config"]), + ) + + return to_object + + +def _UploadRagFileParameters_to_vertex( + from_object: Union[dict[str, Any], object], + parent_object: Optional[dict[str, Any]] = None, +) -> dict[str, Any]: + to_object: dict[str, Any] = {} + if getv(from_object, ["name"]) is not None: + setv(to_object, ["_url", "name"], getv(from_object, ["name"])) + + if getv(from_object, ["rag_file"]) is not None: + setv(to_object, ["ragFile"], getv(from_object, ["rag_file"])) + + if getv(from_object, ["upload_rag_file_config"]) is not None: + setv( + to_object, + ["uploadRagFileConfig"], + _UploadRagFileConfig_to_vertex( + getv(from_object, ["upload_rag_file_config"]), to_object + ), + ) + + return to_object + + def _VertexAiSearchConfig_from_vertex( from_object: Union[dict[str, Any], object], parent_object: Optional[dict[str, Any]] = None, @@ -2719,6 +2783,78 @@ def _get_import_files_operation( self._api_client._verify_response(return_value) return return_value + def _upload_file( + self, + *, + name: str, + rag_file: types.RagFileOrDict, + upload_rag_file_config: Optional[types.UploadRagFileConfigOrDict] = None, + config: Optional[types.UploadRagFileRequestConfigOrDict] = None, + ) -> types.UploadRagFileResponse: + parameter_model = types._UploadRagFileParameters( + name=name, + rag_file=rag_file, + upload_rag_file_config=upload_rag_file_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _UploadRagFileParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/ragFiles:upload".format_map(request_url_dict) + else: + path = "{name}/ragFiles:upload" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request("post", path, request_dict, http_options) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.UploadRagFileResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + def create_corpus( self, *, @@ -2726,7 +2862,7 @@ def create_corpus( config: Optional[types.CreateRagCorpusConfigOrDict] = None, ) -> types.RagCorpus: """ - Creates a new Rag Corpus and waits for completion. + Creates a new RAG Corpus and waits for completion. Args: rag_corpus: The RagCorpus to create. @@ -2754,7 +2890,7 @@ def delete_corpus( config: Optional[types.DeleteRagCorpusConfigOrDict] = None, ) -> None: """ - Deletes a Rag Corpus and waits for the delete operation to complete. + Deletes a RAG Corpus and waits for the delete operation to complete. """ operation = self._delete_corpus(name=name, config=config) @@ -2775,7 +2911,7 @@ def delete_file( config: Optional[types.DeleteRagFileConfigOrDict] = None, ) -> None: """ - Deletes a file from a Rag Corpus and waits for the delete operation to complete. + Deletes a file from a RAG Corpus and waits for the delete operation to complete. """ operation = self._delete_file(name=name, config=config) @@ -2799,7 +2935,7 @@ def update_corpus( config: Optional[types.UpdateRagCorpusConfigOrDict] = None, ) -> types.RagCorpus: """ - Updates a Rag Corpus and waits for completion. + Updates a RAG Corpus and waits for completion. Args: name: The name of the RagCorpus to update, formatted as @@ -2916,6 +3052,94 @@ def import_files( return operation.response + def upload_file( + self, + *, + corpus_name: str, + path: str, + display_name: Optional[str] = None, + upload_rag_file_config: Optional[types.UploadRagFileConfigOrDict] = None, + request_config: Optional[types.UploadRagFileRequestConfigOrDict] = None, + ) -> types.RagFile: + """ + Uploads a file to a RAG Corpus. + + Args: + corpus_name: The name of the RAG Corpus to upload to. + path: The path to the file to upload. + display_name: Optional. The display name for the uploaded file. If not provided, a display name will be generated. + upload_rag_file_config: Optional. The configuration to use for the upload. + request_config: Optional. The configuration to use for the request. + + Returns: + The uploaded RagFile. + """ + + if not display_name: + display_name = f"file_{_common.timestamped_unique_name()}" + + rag_file = types.RagFile(display_name=display_name) + + mime_type, _ = mimetypes.guess_type(path) + + if mime_type is None: + mime_type = "application/octet-stream" + + http_options, size_bytes, mime_type = _extra_utils.prepare_resumable_upload( + path, + user_http_options=request_config.http_options if request_config else None, + user_mime_type=mime_type, + ) + + current_api_version = self._api_client._http_options.api_version or "v1beta1" + upload_api_version = f"upload/{current_api_version}" + + http_options.api_version = upload_api_version + + parameter_model = types._UploadRagFileParameters( + name=corpus_name, + rag_file=rag_file, + upload_rag_file_config=upload_rag_file_config, + ) + request_dict = _UploadRagFileParameters_to_vertex(parameter_model) + + request_dict.pop("_url", None) + request_dict.pop("_query", None) + + request_path = f"{corpus_name}/ragFiles:upload" + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = self._api_client.request( + "post", + request_path, + request_dict, + http_options, + ) + + if response.headers is None or ( + "x-goog-upload-url" not in response.headers + and "X-Goog-Upload-URL" not in response.headers + ): + raise KeyError( + "Failed to create file. Upload URL was not returned from the create file request." + ) + + upload_url = response.headers.get( + "x-goog-upload-url", response.headers.get("X-Goog-Upload-URL") + ) + + fs_path = os.fspath(path) + return_file = self._api_client.upload_file( + fs_path, upload_url, size_bytes, http_options=http_options + ) + + rag_file_payload = return_file.json.get("ragFile") or return_file.json.get( + "rag_file", {} + ) + return types.RagFile(**rag_file_payload) + class AsyncRag(_api_module.BaseModule): @@ -4091,6 +4315,80 @@ async def _get_import_files_operation( self._api_client._verify_response(return_value) return return_value + async def _upload_file( + self, + *, + name: str, + rag_file: types.RagFileOrDict, + upload_rag_file_config: Optional[types.UploadRagFileConfigOrDict] = None, + config: Optional[types.UploadRagFileRequestConfigOrDict] = None, + ) -> types.UploadRagFileResponse: + parameter_model = types._UploadRagFileParameters( + name=name, + rag_file=rag_file, + upload_rag_file_config=upload_rag_file_config, + config=config, + ) + + request_url_dict: Optional[dict[str, str]] + if not self._api_client.vertexai: + raise ValueError( + "This method is only supported in Gemini Enterprise Agent Platform mode, not in Gemini Developer API mode." + ) + else: + request_dict = _UploadRagFileParameters_to_vertex(parameter_model) + request_url_dict = request_dict.get("_url") + if request_url_dict: + path = "{name}/ragFiles:upload".format_map(request_url_dict) + else: + path = "{name}/ragFiles:upload" + + query_params = request_dict.get("_query") + if query_params: + path = f"{path}?{urlencode(query_params)}" + # TODO: remove the hack that pops config. + request_dict.pop("config", None) + + http_options: Optional[types.HttpOptions] = None + if ( + parameter_model.config is not None + and parameter_model.config.http_options is not None + ): + http_options = parameter_model.config.http_options + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", path, request_dict, http_options + ) + + response_dict = {} if not response.body else json.loads(response.body) + + return_value = types.UploadRagFileResponse._from_response( + response=response_dict, + kwargs=( + { + "config": { + "response_schema": getattr( + parameter_model.config, "response_schema", None + ), + "response_json_schema": getattr( + parameter_model.config, "response_json_schema", None + ), + "include_all_fields": getattr( + parameter_model.config, "include_all_fields", None + ), + } + } + if getattr(parameter_model, "config", None) + else {} + ), + ) + + self._api_client._verify_response(return_value) + return return_value + async def create_corpus( self, *, @@ -4098,7 +4396,7 @@ async def create_corpus( config: Optional[types.CreateRagCorpusConfigOrDict] = None, ) -> types.RagCorpus: """ - Creates a new Rag Corpus and waits for completion asynchronously. + Creates a new RAG Corpus and waits for completion asynchronously. Args: rag_corpus: The RagCorpus to create. @@ -4126,7 +4424,7 @@ async def delete_corpus( config: Optional[types.DeleteRagCorpusConfigOrDict] = None, ) -> None: """ - Deletes a Rag Corpus and waits for the delete operation to complete asynchronously. + Deletes a RAG Corpus and waits for the delete operation to complete asynchronously. """ operation = await self._delete_corpus(name=name, config=config) @@ -4147,7 +4445,7 @@ async def delete_file( config: Optional[types.DeleteRagFileConfigOrDict] = None, ) -> None: """ - Deletes a file from a Rag Corpus and waits for the delete operation to complete asynchronously. + Deletes a file from a RAG Corpus and waits for the delete operation to complete asynchronously. """ operation = await self._delete_file(name=name, config=config) @@ -4171,7 +4469,7 @@ async def update_corpus( config: Optional[types.UpdateRagCorpusConfigOrDict] = None, ) -> types.RagCorpus: """ - Updates a Rag Corpus and waits for completion asynchronously. + Updates a RAG Corpus and waits for completion asynchronously. Args: name: The name of the RagCorpus to update, formatted as @@ -4290,3 +4588,91 @@ async def import_files( ) return operation.response + + async def upload_file( + self, + *, + corpus_name: str, + path: str, + display_name: Optional[str] = None, + upload_rag_file_config: Optional[types.UploadRagFileConfigOrDict] = None, + request_config: Optional[types.UploadRagFileRequestConfigOrDict] = None, + ) -> types.RagFile: + """ + Uploads a file to a RAG Corpus. + + Args: + corpus_name: The name of the RAG Corpus to upload to. + path: The path to the file to upload. + display_name: Optional. The display name for the uploaded file. If not provided, a display name will be generated. + upload_rag_file_config: Optional. The configuration to use for the upload. + request_config: Optional. The configuration to use for the request. + + Returns: + The uploaded RagFile. + """ + + if not display_name: + display_name = f"file_{_common.timestamped_unique_name()}" + + rag_file = types.RagFile(display_name=display_name) + + mime_type, _ = mimetypes.guess_type(path) + + if mime_type is None: + mime_type = "application/octet-stream" + + http_options, size_bytes, mime_type = _extra_utils.prepare_resumable_upload( + path, + user_http_options=request_config.http_options if request_config else None, + user_mime_type=mime_type, + ) + + current_api_version = self._api_client._http_options.api_version or "v1beta1" + upload_api_version = f"upload/{current_api_version}" + + http_options.api_version = upload_api_version + + parameter_model = types._UploadRagFileParameters( + name=corpus_name, + rag_file=rag_file, + upload_rag_file_config=upload_rag_file_config, + ) + request_dict = _UploadRagFileParameters_to_vertex(parameter_model) + + request_dict.pop("_url", None) + request_dict.pop("_query", None) + + request_path = f"{corpus_name}/ragFiles:upload" + + request_dict = _common.convert_to_dict(request_dict) + request_dict = _common.encode_unserializable_types(request_dict) + + response = await self._api_client.async_request( + "post", + request_path, + request_dict, + http_options, + ) + + if response.headers is None or ( + "x-goog-upload-url" not in response.headers + and "X-Goog-Upload-URL" not in response.headers + ): + raise KeyError( + "Failed to create file. Upload URL was not returned from the create file request." + ) + + upload_url = response.headers.get( + "x-goog-upload-url", response.headers.get("X-Goog-Upload-URL") + ) + + fs_path = os.fspath(path) + return_file = await self._api_client.async_upload_file( + fs_path, upload_url, size_bytes, http_options=http_options + ) + + rag_file_payload = return_file.json.get("ragFile") or return_file.json.get( + "rag_file", {} + ) + return types.RagFile(**rag_file_payload) diff --git a/agentplatform/_genai/types/__init__.py b/agentplatform/_genai/types/__init__.py index 4b2f9b651c..999aed44fb 100644 --- a/agentplatform/_genai/types/__init__.py +++ b/agentplatform/_genai/types/__init__.py @@ -149,6 +149,7 @@ from .common import _UpdateRagConfigRequestParameters from .common import _UpdateRagCorpusRequestParameters from .common import _UpdateSkillRequestParameters +from .common import _UploadRagFileParameters from .common import A2aTask from .common import A2aTaskDict from .common import A2aTaskOrDict @@ -735,6 +736,9 @@ from .common import GoogleDriveSourceResourceId from .common import GoogleDriveSourceResourceIdDict from .common import GoogleDriveSourceResourceIdOrDict +from .common import GoogleRpcStatus +from .common import GoogleRpcStatusDict +from .common import GoogleRpcStatusOrDict from .common import IdentityType from .common import Importance from .common import ImportRagFilesConfig @@ -1757,6 +1761,15 @@ from .common import UpdateSkillConfig from .common import UpdateSkillConfigDict from .common import UpdateSkillConfigOrDict +from .common import UploadRagFileConfig +from .common import UploadRagFileConfigDict +from .common import UploadRagFileConfigOrDict +from .common import UploadRagFileRequestConfig +from .common import UploadRagFileRequestConfigDict +from .common import UploadRagFileRequestConfigOrDict +from .common import UploadRagFileResponse +from .common import UploadRagFileResponseDict +from .common import UploadRagFileResponseOrDict from .common import VertexAiSearchConfig from .common import VertexAiSearchConfigDict from .common import VertexAiSearchConfigOrDict @@ -2836,6 +2849,18 @@ "GetImportFilesOperationConfig", "GetImportFilesOperationConfigDict", "GetImportFilesOperationConfigOrDict", + "UploadRagFileRequestConfig", + "UploadRagFileRequestConfigDict", + "UploadRagFileRequestConfigOrDict", + "UploadRagFileConfig", + "UploadRagFileConfigDict", + "UploadRagFileConfigOrDict", + "GoogleRpcStatus", + "GoogleRpcStatusDict", + "GoogleRpcStatusOrDict", + "UploadRagFileResponse", + "UploadRagFileResponseDict", + "UploadRagFileResponseOrDict", "GetAgentEngineRuntimeRevisionConfig", "GetAgentEngineRuntimeRevisionConfigDict", "GetAgentEngineRuntimeRevisionConfigOrDict", @@ -3442,6 +3467,7 @@ "_GetRagConfigOperationParameters", "_ImportRagFilesRequestParameters", "_GetImportFilesOperationParameters", + "_UploadRagFileParameters", "_GetAgentEngineRuntimeRevisionRequestParameters", "_ListAgentEngineRuntimeRevisionsRequestParameters", "_DeleteAgentEngineRuntimeRevisionRequestParameters", diff --git a/agentplatform/_genai/types/common.py b/agentplatform/_genai/types/common.py index e04cec58fb..6db6c8249b 100644 --- a/agentplatform/_genai/types/common.py +++ b/agentplatform/_genai/types/common.py @@ -15334,6 +15334,165 @@ class _GetImportFilesOperationParametersDict(TypedDict, total=False): ] +class UploadRagFileRequestConfig(_common.BaseModel): + """Config for the request to upload a Rag File.""" + + http_options: Optional[genai_types.HttpOptions] = Field( + default=None, description="""Used to override HTTP request options.""" + ) + + +class UploadRagFileRequestConfigDict(TypedDict, total=False): + """Config for the request to upload a Rag File.""" + + http_options: Optional[genai_types.HttpOptionsDict] + """Used to override HTTP request options.""" + + +UploadRagFileRequestConfigOrDict = Union[ + UploadRagFileRequestConfig, UploadRagFileRequestConfigDict +] + + +class UploadRagFileConfig(_common.BaseModel): + """Config for uploading RagFile.""" + + rag_file_chunking_config: Optional[RagFileChunkingConfig] = Field( + default=None, + description="""Specifies the size and overlap of chunks after uploading RagFile.""", + ) + rag_file_metadata_config: Optional[RagFileMetadataConfig] = Field( + default=None, + description="""Optional. Specifies the metadata config for RagFiles. Including paths for metadata schema and metadata. Alteratively, inline metadata schema and metadata can be provided. Deprecated: Not in use.""", + ) + rag_file_parsing_config: Optional[RagFileParsingConfig] = Field( + default=None, + description="""Optional. Specifies the parsing config for RagFiles. RAG will use the default parser if this field is not set.""", + ) + rag_file_transformation_config: Optional[RagFileTransformationConfig] = Field( + default=None, + description="""Specifies the transformation config for RagFiles.""", + ) + + +class UploadRagFileConfigDict(TypedDict, total=False): + """Config for uploading RagFile.""" + + rag_file_chunking_config: Optional[RagFileChunkingConfigDict] + """Specifies the size and overlap of chunks after uploading RagFile.""" + + rag_file_metadata_config: Optional[RagFileMetadataConfigDict] + """Optional. Specifies the metadata config for RagFiles. Including paths for metadata schema and metadata. Alteratively, inline metadata schema and metadata can be provided. Deprecated: Not in use.""" + + rag_file_parsing_config: Optional[RagFileParsingConfigDict] + """Optional. Specifies the parsing config for RagFiles. RAG will use the default parser if this field is not set.""" + + rag_file_transformation_config: Optional[RagFileTransformationConfigDict] + """Specifies the transformation config for RagFiles.""" + + +UploadRagFileConfigOrDict = Union[UploadRagFileConfig, UploadRagFileConfigDict] + + +class _UploadRagFileParameters(_common.BaseModel): + """Parameters for uploading a Rag File.""" + + name: Optional[str] = Field( + default=None, + description="""The name of the RagCorpus resource into which to upload the file.""", + ) + rag_file: Optional[RagFile] = Field( + default=None, description="""The RagFile metadata to upload.""" + ) + upload_rag_file_config: Optional[UploadRagFileConfig] = Field( + default=None, + description="""The config for the RagFiles to be uploaded into the RagCorpus.""", + ) + config: Optional[UploadRagFileRequestConfig] = Field( + default=None, description="""Used to override the default configuration.""" + ) + + +class _UploadRagFileParametersDict(TypedDict, total=False): + """Parameters for uploading a Rag File.""" + + name: Optional[str] + """The name of the RagCorpus resource into which to upload the file.""" + + rag_file: Optional[RagFileDict] + """The RagFile metadata to upload.""" + + upload_rag_file_config: Optional[UploadRagFileConfigDict] + """The config for the RagFiles to be uploaded into the RagCorpus.""" + + config: Optional[UploadRagFileRequestConfigDict] + """Used to override the default configuration.""" + + +_UploadRagFileParametersOrDict = Union[ + _UploadRagFileParameters, _UploadRagFileParametersDict +] + + +class GoogleRpcStatus(_common.BaseModel): + """The `Status` type defines a logical error model that is suitable for different programming environments, including REST APIs and RPC APIs. It is used by [gRPC](https://github.com/grpc). Each `Status` message contains three pieces of data: error code, error message, and error details. You can find out more about this error model and how to work with it in the [API Design Guide](https://cloud.google.com/apis/design/errors).""" + + code: Optional[int] = Field( + default=None, + description="""The status code, which should be an enum value of google.rpc.Code.""", + ) + details: Optional[list[dict[str, Any]]] = Field( + default=None, + description="""A list of messages that carry the error details. There is a common set of message types for APIs to use.""", + ) + message: Optional[str] = Field( + default=None, + description="""A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the google.rpc.Status.details field, or localized by the client.""", + ) + + +class GoogleRpcStatusDict(TypedDict, total=False): + """The `Status` type defines a logical error model that is suitable for different programming environments, including REST APIs and RPC APIs. It is used by [gRPC](https://github.com/grpc). Each `Status` message contains three pieces of data: error code, error message, and error details. You can find out more about this error model and how to work with it in the [API Design Guide](https://cloud.google.com/apis/design/errors).""" + + code: Optional[int] + """The status code, which should be an enum value of google.rpc.Code.""" + + details: Optional[list[dict[str, Any]]] + """A list of messages that carry the error details. There is a common set of message types for APIs to use.""" + + message: Optional[str] + """A developer-facing error message, which should be in English. Any user-facing error message should be localized and sent in the google.rpc.Status.details field, or localized by the client.""" + + +GoogleRpcStatusOrDict = Union[GoogleRpcStatus, GoogleRpcStatusDict] + + +class UploadRagFileResponse(_common.BaseModel): + """Response for uploading a Rag File.""" + + error: Optional[GoogleRpcStatus] = Field( + default=None, + description="""The error that occurred while processing the RagFile.""", + ) + rag_file: Optional[RagFile] = Field( + default=None, + description="""The RagFile that had been uploaded into the RagCorpus.""", + ) + + +class UploadRagFileResponseDict(TypedDict, total=False): + """Response for uploading a Rag File.""" + + error: Optional[GoogleRpcStatusDict] + """The error that occurred while processing the RagFile.""" + + rag_file: Optional[RagFileDict] + """The RagFile that had been uploaded into the RagCorpus.""" + + +UploadRagFileResponseOrDict = Union[UploadRagFileResponse, UploadRagFileResponseDict] + + class GetAgentEngineRuntimeRevisionConfig(_common.BaseModel): """Config for getting an Agent Engine Runtime Revision.""" diff --git a/tests/unit/agentplatform/genai/replays/test_rag_upload.py b/tests/unit/agentplatform/genai/replays/test_rag_upload.py new file mode 100644 index 0000000000..fe1e50f062 --- /dev/null +++ b/tests/unit/agentplatform/genai/replays/test_rag_upload.py @@ -0,0 +1,65 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# pylint: disable=protected-access,bad-continuation,missing-function-docstring + +import pytest +from agentplatform._genai import types +from tests.unit.agentplatform.genai.replays import pytest_helper + +pytestmark = pytest_helper.setup( + file=__file__, + globals_for_file=globals(), +) + +file_display_name = "Test rag file upload" + + +def test_upload_rag_file(client, tmp_path, monkeypatch): + file_name = "test_replay_upload.txt" + + file_path = tmp_path / file_name + file_path.write_text("This is a test file for RAG upload.") + + monkeypatch.chdir(tmp_path) + + uploaded_file = client.rag.upload_file( + corpus_name="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/5400941853124067328", + path=file_name, + display_name=file_display_name, + ) + + assert isinstance(uploaded_file, types.RagFile) + assert uploaded_file.display_name == file_display_name + + +pytest_plugins = ("pytest_asyncio",) + + +@pytest.mark.asyncio +async def test_upload_rag_file_async(client, tmp_path, monkeypatch): + file_name = "test_rag_file_upload.txt" + file_path = tmp_path / file_name + file_path.write_text("This is a test file for async RAG upload.") + + monkeypatch.chdir(tmp_path) + + uploaded_file = await client.aio.rag.upload_file( + corpus_name="projects/vertex-sdk-dev/locations/us-central1/ragCorpora/5400941853124067328", + path=file_name, + display_name=file_display_name, + ) + + assert isinstance(uploaded_file, types.RagFile) + assert uploaded_file.display_name == file_display_name