Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,43 @@ def download_file_from_gcs(
source_blob.download_to_filename(filename=destination_file_path)


def _is_path_subdirectory(destination_path: str, blob_path: str) -> bool:
"""Checks that the path provided by the blob would produce a subdirectory of the destination path.

Args:
destination_path (str):
Required. The destination directory's path to use as a foundation.
blob_path (str):
Required. The file sub path provided by the blob.

Returns:
bool: True if the blob path is a subdirectory of the destination path, False otherwise.
"""
blob_path_obj = pathlib.Path(blob_path)
destination_path_obj = pathlib.Path(destination_path)
if any((part == ".." or ":" in part) for part in blob_path_obj.parts):
_logger.warning(
"The specified prefix '%s' contains '..' or ':', "
"which is not supported. Skipping blob '%s'.",
destination_path,
blob_path,
)
return False
resolved_path = (destination_path_obj / blob_path_obj).resolve()
resolved_destination_path = destination_path_obj.resolve()
if os.path.commonpath([str(resolved_destination_path), str(resolved_path)]) != str(
resolved_destination_path
):
_logger.warning(
"The specified prefix '%s' is not a subdirectory "
"of the destination path '%s'.",
blob_path,
destination_path,
)
return False
return True


def download_from_gcs(
source_uri: str,
destination_path: str,
Expand Down Expand Up @@ -452,13 +489,14 @@ def download_from_gcs(
# These files ends with '/', and we'll skip them.
if not blob.name.endswith("/"):
rel_path = os.path.relpath(blob.name, prefix)
filename = (
destination_path
if rel_path == "."
else os.path.join(destination_path, rel_path)
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
blob.download_to_filename(filename=filename)
if _is_path_subdirectory(destination_path, rel_path):
filename = (
destination_path
if rel_path == "."
else os.path.join(destination_path, rel_path)
)
os.makedirs(os.path.dirname(filename), exist_ok=True)
blob.download_to_filename(filename=filename)


def _upload_pandas_df_to_gcs(
Expand Down
14 changes: 14 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,20 @@ def test_download_from_gcs_invalid_source_uri(self):
):
gcs_utils.download_from_gcs(source_uri, destination_path)

def test_is_path_subdirectory(self):
assert gcs_utils._is_path_subdirectory(
destination_path="test_dir", blob_path="test_dir/test_file"
)
assert gcs_utils._is_path_subdirectory(
destination_path="test_dir", blob_path="test_file"
)
assert not gcs_utils._is_path_subdirectory(
destination_path="test_dir", blob_path="test_dir/../test_file"
)
assert not gcs_utils._is_path_subdirectory(
destination_path="test_dir", blob_path="test_dir/c:file"
)

def test_validate_gcs_path(self):
test_valid_path = "gs://test_valid_path"
gcs_utils.validate_gcs_path(test_valid_path)
Expand Down
Loading