Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions dataflow/gemma/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
# This uses Ubuntu with Python 3.11
# You can check the Python version for a given tensorflow
# container at https://hub.docker.com/r/tensorflow/tensorflow/tags
ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.16.1-gpu
ARG SERVING_BUILD_IMAGE=tensorflow/tensorflow:2.20.0-gpu

FROM ${SERVING_BUILD_IMAGE}

Expand All @@ -29,7 +29,7 @@ RUN pip install --upgrade --no-cache-dir pip \
&& pip install --no-cache-dir -r requirements.txt

# Copy files from official SDK image, including script/dependencies.
COPY --from=apache/beam_python3.14_sdk:2.73.0 /opt/apache/beam /opt/apache/beam
COPY --from=apache/beam_python3.11_sdk:2.74.0 /opt/apache/beam /opt/apache/beam

# Copy the model directory downloaded from Kaggle and the pipeline code.
COPY gemma_2b gemma_2B
Expand Down
22 changes: 13 additions & 9 deletions dataflow/gemma/custom_model_gemma.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
self,
model_name: str = "gemma_2B",
):
""" Implementation of the ModelHandler interface for Gemma using text as input.
"""Implementation of the ModelHandler interface for Gemma using text as input.

Example Usage::

Expand All @@ -48,7 +48,7 @@ def __init__(
self._env_vars = {}

def share_model_across_processes(self) -> bool:
""" Indicates if the model should be loaded once-per-VM rather than
"""Indicates if the model should be loaded once-per-VM rather than
once-per-worker-process on a VM. Because Gemma is a large language model,
this will always return True to avoid OOM errors.
"""
Expand All @@ -62,7 +62,7 @@ def run_inference(
self,
batch: Sequence[str],
model: GemmaCausalLM,
inference_args: Optional[dict[str, Any]] = None
inference_args: Optional[dict[str, Any]] = None,
) -> Iterable[PredictionResult]:
"""Runs inferences on a batch of text strings.

Expand All @@ -85,7 +85,8 @@ def run_inference(
class FormatOutput(beam.DoFn):
def process(self, element, *args, **kwargs):
yield "Input: {input}, Output: {output}".format(
input=element.example, output=element.inference)
input=element.example, output=element.inference
)


if __name__ == "__main__":
Expand Down Expand Up @@ -119,13 +120,16 @@ def process(self, element, *args, **kwargs):

pipeline = beam.Pipeline(options=beam_options)
_ = (
pipeline | "Read Topic" >>
beam.io.ReadFromPubSub(subscription=args.messages_subscription)
pipeline
| "Read Topic"
>> beam.io.ReadFromPubSub(subscription=args.messages_subscription)
| "Parse" >> beam.Map(lambda x: x.decode("utf-8"))
| "RunInference-Gemma" >> RunInference(
| "RunInference-Gemma"
>> RunInference(
GemmaModelHandler(args.model_path)
) # Send the prompts to the model and get responses.
| "Format Output" >> beam.ParDo(FormatOutput()) # Format the output.
| "Publish Result" >>
beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic))
| "Publish Result"
>> beam.io.gcp.pubsub.WriteStringsToPubSub(topic=args.responses_topic)
)
pipeline.run()
41 changes: 21 additions & 20 deletions dataflow/gemma/e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
NOTE: For the tests to find the conftest in the testing infrastructure,
add the PYTHONPATH to the "env" in your noxfile_config.py file.
"""

from collections.abc import Callable, Iterator

import conftest # python-docs-samples/dataflow/conftest.py
Expand Down Expand Up @@ -70,8 +71,9 @@ def messages_topic(pubsub_topic: Callable[[str], str]) -> str:


@pytest.fixture(scope="session")
def messages_subscription(pubsub_subscription: Callable[[str, str], str],
messages_topic: str) -> str:
def messages_subscription(
pubsub_subscription: Callable[[str, str], str], messages_topic: str
) -> str:
return pubsub_subscription("messages", messages_topic)


Expand All @@ -81,20 +83,21 @@ def responses_topic(pubsub_topic: Callable[[str], str]) -> str:


@pytest.fixture(scope="session")
def responses_subscription(pubsub_subscription: Callable[[str, str], str],
responses_topic: str) -> str:
def responses_subscription(
pubsub_subscription: Callable[[str, str], str], responses_topic: str
) -> str:
return pubsub_subscription("responses", responses_topic)


@pytest.fixture(scope="session")
def dataflow_job(
project: str,
bucket_name: str,
location: str,
unique_name: str,
container_image: str,
messages_subscription: str,
responses_topic: str,
project: str,
bucket_name: str,
location: str,
unique_name: str,
container_image: str,
messages_subscription: str,
responses_topic: str,
) -> Iterator[str]:
# Launch the streaming Dataflow pipeline.
conftest.run_cmd(
Expand Down Expand Up @@ -127,20 +130,18 @@ def dataflow_job(

@pytest.mark.timeout(3600)
def test_pipeline_dataflow(
project: str,
location: str,
dataflow_job: str,
messages_topic: str,
responses_subscription: str,
project: str,
location: str,
dataflow_job: str,
messages_topic: str,
responses_subscription: str,
) -> None:
print(f"Waiting for the Dataflow workers to start: {dataflow_job}")
conftest.wait_until(
lambda: conftest.dataflow_num_workers(project, location, dataflow_job)
> 0,
lambda: conftest.dataflow_num_workers(project, location, dataflow_job) > 0,
"workers are running",
)
num_workers = conftest.dataflow_num_workers(project, location,
dataflow_job)
num_workers = conftest.dataflow_num_workers(project, location, dataflow_job)
print(f"Dataflow job num_workers: {num_workers}")

messages = ["This is a test for a Python sample."]
Expand Down
8 changes: 3 additions & 5 deletions dataflow/gemma/noxfile_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
# You can opt out from the test for specific Python versions.
# The Python version used is defined by the Dockerfile and the job
# submission enviornment must match.
# Note: Docker-based sample, testing only against version specified in Dockerfile (3.14)
"ignored_versions": ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"],
"envs": {
"PYTHONPATH": ".."
},
# Note: Docker-based sample, testing only against version specified in Dockerfile (3.11)
"ignored_versions": ["3.8", "3.9", "3.10"],

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The comment on line 21 states that this is a Docker-based sample, testing only against the Python version specified in the Dockerfile (3.11). However, the ignored_versions list only ignores 3.8, 3.9, and 3.10, which means tests will still run on 3.12, 3.13, and 3.14. To align with the comment and avoid redundant test runs on other Python versions, please ignore all versions except 3.11.

Suggested change
"ignored_versions": ["3.8", "3.9", "3.10"],
"ignored_versions": ["3.8", "3.9", "3.10", "3.12", "3.13", "3.14"],
References
  1. When upgrading test dependencies that require a minimum Python version, ensure that test configurations (e.g., ignored_versions in noxfile_config.py) are updated to ignore incompatible Python versions.

"envs": {"PYTHONPATH": ".."},
}
10 changes: 5 additions & 5 deletions dataflow/gemma/requirements-test.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
google-cloud-aiplatform==1.49.0
google-cloud-dataflow-client==0.8.10
google-cloud-storage==2.16.0
pytest==9.0.3; python_version >= "3.10"
pytest-timeout==2.3.1
google-cloud-aiplatform==1.157.0
google-cloud-dataflow-client==0.14.0
google-cloud-storage==3.12.0
pytest==9.0.3
pytest-timeout==2.4.0
9 changes: 5 additions & 4 deletions dataflow/gemma/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
apache_beam[gcp]==2.54.0
protobuf==4.25.0
keras_nlp==0.8.2
keras==3.0.5
protobuf==6.33.6
apache_beam[gcp]==2.74.0
keras==3.14.1
keras_nlp==0.29.1
pyOpenSSL==25.3.0
Loading