From 2f5c8f81c272215951a35ca3334728427263e77b Mon Sep 17 00:00:00 2001 From: Jeff Scudder Date: Wed, 17 Jun 2026 18:59:26 -0700 Subject: [PATCH] fix: added safety checks in download_from_gcs PiperOrigin-RevId: 934043631 --- google/cloud/aiplatform/utils/gcs_utils.py | 52 +++++++++++++++++++--- tests/unit/aiplatform/test_utils.py | 14 ++++++ 2 files changed, 59 insertions(+), 7 deletions(-) diff --git a/google/cloud/aiplatform/utils/gcs_utils.py b/google/cloud/aiplatform/utils/gcs_utils.py index 5bebd9ee01..a6727adf2d 100644 --- a/google/cloud/aiplatform/utils/gcs_utils.py +++ b/google/cloud/aiplatform/utils/gcs_utils.py @@ -414,6 +414,43 @@ def download_file_from_gcs( source_blob.download_to_filename(filename=destination_file_path) +def _is_path_subdirectory(destination_path: str, blob_path: str) -> bool: + """Checks that the path provided by the blob would produce a subdirectory of the destination path. + + Args: + destination_path (str): + Required. The destination directory's path to use as a foundation. + blob_path (str): + Required. The file sub path provided by the blob. + + Returns: + bool: True if the blob path is a subdirectory of the destination path, False otherwise. + """ + blob_path_obj = pathlib.Path(blob_path) + destination_path_obj = pathlib.Path(destination_path) + if any((part == ".." or ":" in part) for part in blob_path_obj.parts): + _logger.warning( + "The specified prefix '%s' contains '..' or ':', " + "which is not supported. Skipping blob '%s'.", + destination_path, + blob_path, + ) + return False + resolved_path = (destination_path_obj / blob_path_obj).resolve() + resolved_destination_path = destination_path_obj.resolve() + if os.path.commonpath([str(resolved_destination_path), str(resolved_path)]) != str( + resolved_destination_path + ): + _logger.warning( + "The specified prefix '%s' is not a subdirectory " + "of the destination path '%s'.", + blob_path, + destination_path, + ) + return False + return True + + def download_from_gcs( source_uri: str, destination_path: str, @@ -452,13 +489,14 @@ def download_from_gcs( # These files ends with '/', and we'll skip them. if not blob.name.endswith("/"): rel_path = os.path.relpath(blob.name, prefix) - filename = ( - destination_path - if rel_path == "." - else os.path.join(destination_path, rel_path) - ) - os.makedirs(os.path.dirname(filename), exist_ok=True) - blob.download_to_filename(filename=filename) + if _is_path_subdirectory(destination_path, rel_path): + filename = ( + destination_path + if rel_path == "." + else os.path.join(destination_path, rel_path) + ) + os.makedirs(os.path.dirname(filename), exist_ok=True) + blob.download_to_filename(filename=filename) def _upload_pandas_df_to_gcs( diff --git a/tests/unit/aiplatform/test_utils.py b/tests/unit/aiplatform/test_utils.py index 0116fd1815..ca19c793d7 100644 --- a/tests/unit/aiplatform/test_utils.py +++ b/tests/unit/aiplatform/test_utils.py @@ -658,6 +658,20 @@ def test_download_from_gcs_invalid_source_uri(self): ): gcs_utils.download_from_gcs(source_uri, destination_path) + def test_is_path_subdirectory(self): + assert gcs_utils._is_path_subdirectory( + destination_path="test_dir", blob_path="test_dir/test_file" + ) + assert gcs_utils._is_path_subdirectory( + destination_path="test_dir", blob_path="test_file" + ) + assert not gcs_utils._is_path_subdirectory( + destination_path="test_dir", blob_path="test_dir/../test_file" + ) + assert not gcs_utils._is_path_subdirectory( + destination_path="test_dir", blob_path="test_dir/c:file" + ) + def test_validate_gcs_path(self): test_valid_path = "gs://test_valid_path" gcs_utils.validate_gcs_path(test_valid_path)