diff --git a/src/maxtext/utils/model_creation_utils.py b/src/maxtext/utils/model_creation_utils.py
index c0248783eb..e65dcd6b8a 100644
--- a/src/maxtext/utils/model_creation_utils.py
+++ b/src/maxtext/utils/model_creation_utils.py
@@ -38,6 +38,7 @@ import sys
 from typing import Callable, overload
 
 from etils import epath
+from huggingface_hub import get_token
 from flax import nnx
 from flax.core.meta import Partitioned
 import flax.linen as nn
@@ -811,8 +812,14 @@ def from_pretrained(
   if config.convert_checkpoint_if_possible and not config.load_parameters_path:
     if not (epath.Path(config.base_output_directory) / "0" / "items").exists():
       # Try to convert checkpoint on the fly
-      if not config.hf_access_token:
-        raise ValueError("hf_access_token must be provided when not providing a pre-existing checkpoint")
+      # Prefer the token from the config; otherwise fall back to whatever
+      # huggingface_hub.get_token() resolves (HF_TOKEN env var or cached login).
+      hf_access_token = config.hf_access_token or get_token()
+      if not hf_access_token:
+        raise ValueError(
+            "hf_access_token must be provided (or authenticate via"
+            " huggingface-cli) when not providing a pre-existing checkpoint"
+        )
 
       # Only process 0 performs the conversion; other processes wait at the barrier below.
       # Otherwise every host would race to download from HF and concurrently write the same
@@ -830,8 +837,8 @@ def from_pretrained(
         conversion_env = os.environ.copy()
         conversion_env["JAX_PLATFORMS"] = "cpu"
         # conversion_env["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={simulated_cpu_devices_count}"
-        if config.hf_access_token:
-          conversion_env["HF_TOKEN"] = config.hf_access_token
+        if hf_access_token:
+          conversion_env["HF_TOKEN"] = hf_access_token
 
         to_maxtext_cmd = [
             sys.executable,
diff --git a/tests/unit/model_creation_utils_test.py b/tests/unit/model_creation_utils_test.py
index 2568547944..e668912778 100644
--- a/tests/unit/model_creation_utils_test.py
+++ b/tests/unit/model_creation_utils_test.py
@@ -775,3 +775,105 @@ def test_returns_linen_train_state_and_annotations(self):
 
-if __name__ == "__main__":
-  unittest.main()
+class TestFromPretrainedAuth(unittest.TestCase):
+  """Tests for Hugging Face authentication in from_pretrained."""
+
+  def _make_nnx_metadata_mock(self):
+    """Returns a mock metadata object whose tree lists a single "decoder" key."""
+    meta = MagicMock()
+    meta.item_metadata.tree.keys.return_value = ["decoder"]
+    meta.item_metadata.tree.get.return_value = {}
+    return meta
+
+  @patch("maxtext.utils.model_creation_utils.ocp")
+  @patch("maxtext.utils.model_creation_utils.subprocess.run")
+  @patch("maxtext.utils.model_creation_utils.get_token")
+  @patch("maxtext.utils.model_creation_utils.epath.Path")
+  def test_auth_success_with_config_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
+    """A token set in the config is passed as HF_TOKEN and get_token is never called."""
+    config = _make_config(
+        convert_checkpoint_if_possible=True,
+        base_output_directory="gs://fake_bucket/fake_run",
+        hf_access_token="config_token",
+    )
+    mesh = _make_mesh(config)
+
+    mock_ckpt_path = MagicMock()
+    mock_ckpt_path.exists.return_value = False
+    mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path
+
+    mock_ckptr = MagicMock()
+    mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock()
+    mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item
+    mock_ocp.Checkpointer.return_value = mock_ckptr
+    mock_ocp.checkpoint_utils.construct_restore_args.return_value = {}
+    mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs
+
+    model = model_creation_utils.from_pretrained(config, mesh)
+    self.assertIsInstance(model, models.Transformer)
+
+    mock_run.assert_called_once()
+    called_env = mock_run.call_args[1].get("env", {})
+    self.assertEqual(called_env.get("HF_TOKEN"), "config_token")
+    mock_get_token.assert_not_called()
+
+  @patch("maxtext.utils.model_creation_utils.ocp")
+  @patch("maxtext.utils.model_creation_utils.subprocess.run")
+  @patch("maxtext.utils.model_creation_utils.get_token")
+  @patch("maxtext.utils.model_creation_utils.epath.Path")
+  def test_auth_success_with_cached_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
+    """With an empty config token, the value returned by get_token is passed as HF_TOKEN."""
+    config = _make_config(
+        convert_checkpoint_if_possible=True,
+        base_output_directory="gs://fake_bucket/fake_run",
+        hf_access_token="",
+    )
+    mesh = _make_mesh(config)
+
+    mock_ckpt_path = MagicMock()
+    mock_ckpt_path.exists.return_value = False
+    mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path
+
+    mock_get_token.return_value = "cached_token"
+
+    mock_ckptr = MagicMock()
+    mock_ckptr.metadata.return_value = self._make_nnx_metadata_mock()
+    mock_ckptr.restore.side_effect = lambda path, item=None, **kw: item
+    mock_ocp.Checkpointer.return_value = mock_ckptr
+    mock_ocp.checkpoint_utils.construct_restore_args.return_value = {}
+    mock_ocp.ArrayRestoreArgs = ocp.ArrayRestoreArgs
+
+    model = model_creation_utils.from_pretrained(config, mesh)
+    self.assertIsInstance(model, models.Transformer)
+
+    mock_get_token.assert_called_once()
+    mock_run.assert_called_once()
+    called_env = mock_run.call_args[1].get("env", {})
+    self.assertEqual(called_env.get("HF_TOKEN"), "cached_token")
+
+  @patch("maxtext.utils.model_creation_utils.ocp")
+  @patch("maxtext.utils.model_creation_utils.subprocess.run")
+  @patch("maxtext.utils.model_creation_utils.get_token")
+  @patch("maxtext.utils.model_creation_utils.epath.Path")
+  def test_auth_failure_no_token(self, mock_path, mock_get_token, mock_run, mock_ocp):
+    """A ValueError is raised and no subprocess is run when neither token source yields a value."""
+    config = _make_config(
+        convert_checkpoint_if_possible=True,
+        base_output_directory="gs://fake_bucket/fake_run",
+        hf_access_token="",
+    )
+    mesh = _make_mesh(config)
+
+    mock_ckpt_path = MagicMock()
+    mock_ckpt_path.exists.return_value = False
+    mock_path.return_value.__truediv__.return_value.__truediv__.return_value = mock_ckpt_path
+
+    mock_get_token.return_value = None
+
+    with self.assertRaisesRegex(ValueError, "hf_access_token must be provided"):
+      model_creation_utils.from_pretrained(config, mesh)
+
+    mock_run.assert_not_called()
+
+
+# Keep the entry point last so every TestCase above is defined before unittest.main() runs.
+if __name__ == "__main__":
+  unittest.main()