Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/maxtext/utils/model_creation_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import sys
from typing import Callable, overload
from etils import epath
from huggingface_hub import get_token
from flax import nnx
from flax.core.meta import Partitioned
import flax.linen as nn
Expand Down Expand Up @@ -811,8 +812,12 @@ def from_pretrained(
if config.convert_checkpoint_if_possible and not config.load_parameters_path:
if not (epath.Path(config.base_output_directory) / "0" / "items").exists():
# Try to convert checkpoint on the fly
if not config.hf_access_token:
raise ValueError("hf_access_token must be provided when not providing a pre-existing checkpoint")
hf_access_token = config.hf_access_token or get_token()
if not hf_access_token:
raise ValueError(
"hf_access_token must be provided (or authenticate via"
" huggingface-cli) when not providing a pre-existing checkpoint"
)

# Only process 0 performs the conversion; other processes wait at the barrier below.
# Otherwise every host would race to download from HF and concurrently write the same
Expand All @@ -830,8 +835,8 @@ def from_pretrained(
conversion_env = os.environ.copy()
conversion_env["JAX_PLATFORMS"] = "cpu"
# conversion_env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={simulated_cpu_devices_count}"
if config.hf_access_token:
conversion_env["HF_TOKEN"] = config.hf_access_token
if hf_access_token:
conversion_env["HF_TOKEN"] = hf_access_token

to_maxtext_cmd = [
sys.executable,
Expand Down
97 changes: 97 additions & 0 deletions tests/unit/model_creation_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,3 +775,100 @@ def test_returns_linen_train_state_and_annotations(self):

if __name__ == "__main__":
unittest.main()


class TestFromPretrainedAuth(unittest.TestCase):
"""Tests for Hugging Face authentication in from_pretrained."""

def _make_nnx_metadata_mock(self):
meta = MagicMock()
meta.item_metadata.tree.keys.return_value = ["decoder"]
meta.item_metadata.tree.get.return_value = {}
return meta

@patch("maxtext.utils.model_creation_utils.ocp")
@patch("maxtext.utils.model_creation_utils.subprocess.run")
@patch("maxtext.utils.model_creation_utils.get_token")
@patch("maxtext.utils.model_creation_utils.epath.Path")
def test_auth_success_with_config_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
config = _make_config(
convert_checkpoint_if_possible=True,
base_output_directory="gs://fake_bucket/fake_run",
hf_access_token="config_token",
)
mesh = _make_mesh(config)

mock_ckpt_path = MagicMock()
mock_ckpt_path.exists.return_value = False
mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path

mock_ckptr = MagicMock()
mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock()
mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item
mock_ocp.Checkpointer.return_value = mock_ckptr
mock_ocp.checkpoint_utils.construct_restore_args.return_value = {}
mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs

model = model_creation_utils.from_pretrained(config, mesh)
self.assertIsInstance(model, models.Transformer)

mock_run.assert_called_once()
called_env = mock_run.call_args[1].get("env", {})
self.assertEqual(called_env.get("HF_TOKEN"), "config_token")
mock_get_token.assert_not_called()

@patch("maxtext.utils.model_creation_utils.ocp")
@patch("maxtext.utils.model_creation_utils.subprocess.run")
@patch("maxtext.utils.model_creation_utils.get_token")
@patch("maxtext.utils.model_creation_utils.epath.Path")
def test_auth_success_with_cached_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
config = _make_config(
convert_checkpoint_if_possible=True,
base_output_directory="gs://fake_bucket/fake_run",
hf_access_token="",
)
mesh = _make_mesh(config)

mock_ckpt_path = MagicMock()
mock_ckpt_path.exists.return_value = False
mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path

mock_get_token.return_value = "cached_token"

mock_ckptr = MagicMock()
mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock()
mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item
mock_ocp.Checkpointer.return_value = mock_ckptr
mock_ocp.checkpoint_utils.construct_restore_args.return_value = {}
mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs

model = model_creation_utils.from_pretrained(config, mesh)
self.assertIsInstance(model, models.Transformer)

mock_get_token.assert_called_once()
mock_run.assert_called_once()
called_env = mock_run.call_args[1].get("env", {})
self.assertEqual(called_env.get("HF_TOKEN"), "cached_token")

@patch("maxtext.utils.model_creation_utils.ocp")
@patch("maxtext.utils.model_creation_utils.subprocess.run")
@patch("maxtext.utils.model_creation_utils.get_token")
@patch("maxtext.utils.model_creation_utils.epath.Path")
def test_auth_failure_no_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
config = _make_config(
convert_checkpoint_if_possible=True,
base_output_directory="gs://fake_bucket/fake_run",
hf_access_token="",
)
mesh = _make_mesh(config)

mock_ckpt_path = MagicMock()
mock_ckpt_path.exists.return_value = False
mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path

mock_get_token.return_value = None

with self.assertRaisesRegex(ValueError, "hf_access_token must be provided"):
model_creation_utils.from_pretrained(config, mesh)

mock_run.assert_not_called()
Loading