diff --git a/src/maxtext/models/deepseek.py b/src/maxtext/models/deepseek.py index d3a72b31bf..3ec5c32f1f 100644 --- a/src/maxtext/models/deepseek.py +++ b/src/maxtext/models/deepseek.py @@ -303,9 +303,9 @@ def self_attention_with_norm_op( return hidden_states, intermediate_inputs def engram_op(self, x, decoder_input_tokens): - normed_x = self.engram_layer_norm(x) + normed_x = self.engram_layer_norm(x) # pyrefly: ignore[not-callable] hash_ids = self.ngram_hash_mapping(decoder_input_tokens)[self.layer_idx] - return self.engram(normed_x, hash_ids) + return self.engram(normed_x, hash_ids) # pyrefly: ignore[not-callable] class DeepSeekDenseLayer(DeepSeekGenericLayer): diff --git a/src/maxtext/models/gpt3.py b/src/maxtext/models/gpt3.py index 7cf8fbd773..b09523e453 100644 --- a/src/maxtext/models/gpt3.py +++ b/src/maxtext/models/gpt3.py @@ -330,7 +330,7 @@ def init_kv_caches(self, inputs_kv_shape: tuple[int, ...]): ) def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous_chunk): - prefill_kv_cache, ar_kv_cache = self.KVCache_0( + prefill_kv_cache, ar_kv_cache = self.KVCache_0( # pyrefly: ignore[not-callable] key=key, value=value, decoder_segment_ids=decoder_segment_ids, diff --git a/src/maxtext/models/llama2.py b/src/maxtext/models/llama2.py index eeb54da934..fa0f8b85e5 100644 --- a/src/maxtext/models/llama2.py +++ b/src/maxtext/models/llama2.py @@ -218,8 +218,8 @@ def update_cache(cache, val): return cache.at[layer_idx].set(val) return cache - stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) - return (layer_output, stacked_kv_cache, layer_idx + 1), None + stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name] + return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name] elif cfg.scan_layers: return layer_output, None else: diff --git a/src/maxtext/models/llama4.py b/src/maxtext/models/llama4.py index 26fd4d322d..65a22c41e2 100644 --- a/src/maxtext/models/llama4.py +++ b/src/maxtext/models/llama4.py @@ -50,7 +50,7 @@ class Llama4UnfoldConvolution(nnx.Module): config: Config containing model parameters """ - def __init__(self, config: Config, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.rngs = rngs self.vit_unfold_linear = linears.DenseGeneral( @@ -123,7 +123,7 @@ class Llama4VisionMLP(nnx.Module): config: Config containing model parameters """ - def __init__(self, config: Config, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.rngs = rngs self.vit_encoder_layer_mlp_fc1 = linears.DenseGeneral( @@ -157,7 +157,7 @@ class Llama4VisionMLP2(nnx.Module): config: Config containing model parameters """ - def __init__(self, config: Config, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.rngs = rngs self.vit_pixel_shuffle_mlp_fc1 = linears.DenseGeneral( @@ -196,7 +196,7 @@ class Llama4VisionPixelShuffleMLP(nnx.Module): config: Config containing model parameters """ - def __init__(self, config: Config, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.rngs = rngs self.pixel_shuffle_ratio = self.config.pixel_shuffle_ratio_for_vit @@ -221,7 +221,7 @@ class Llama4MultiModalProjector(nnx.Module): config: Config containing model parameters """ - def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.mesh = mesh self.rngs = rngs @@ -515,8 +515,8 @@ def update_cache(cache, val): return cache.at[layer_idx].set(val) return cache - stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) - return (layer_output, stacked_kv_cache, layer_idx + 1), None + stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name] + return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name] elif cfg.scan_layers: return layer_output, None else: @@ -623,7 +623,7 @@ def __call__( class Llama4VisionEncoderLayer(nnx.Module): """Transformer encoder layer for Llama4 vision model.""" - def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.mesh = mesh self.rngs = rngs @@ -701,7 +701,7 @@ class Llama4VisionEncoder(nnx.Module): mesh: Mesh, JAX device mesh (used for sharding) """ - def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.mesh = mesh self.rngs = rngs @@ -733,7 +733,7 @@ class Llama4VisionModel(nnx.Module): mesh: Mesh, JAX device mesh (used for sharding) """ - def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): + def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition] self.config = config self.mesh = mesh self.rngs = rngs diff --git a/src/maxtext/models/mistral.py b/src/maxtext/models/mistral.py index fa7f0956d4..78fbcfa1d7 100644 --- a/src/maxtext/models/mistral.py +++ b/src/maxtext/models/mistral.py @@ -192,8 +192,8 @@ def update_cache(cache, val): return cache.at[layer_idx].set(val) return cache - stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) - return (layer_output, stacked_kv_cache, layer_idx + 1), None + stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name] + return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name] elif cfg.scan_layers: return layer_output, None else: diff --git a/src/maxtext/models/models.py b/src/maxtext/models/models.py index ac908c0f96..881a1687b5 100644 --- a/src/maxtext/models/models.py +++ b/src/maxtext/models/models.py @@ -102,9 +102,9 @@ def setup(self): self.mtp_block = multi_token_prediction_block_as_linen( config=self.config, mesh=self.mesh, - transformer_layer_module=mtp_layer_nnx, - decoder=self.decoder, - rngs=self.make_rng("mtp_block"), + transformer_layer_module=mtp_layer_nnx, # pyrefly: ignore[bad-argument-type] + decoder=self.decoder, # pyrefly: ignore[bad-argument-type] + rngs=self.make_rng("mtp_block"), # pyrefly: ignore[bad-argument-type] ) def logits_from_hidden_states_for_vocab_tiling(self, hidden_states, deterministic, model_mode): @@ -163,7 +163,7 @@ def __call__( deepstack_visual_embeds = None if self.config.use_multimodal and encoder_images is not None: - image_embeddings, deepstack_visual_embeds = self.vision_encoder( + image_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable] input_images=encoder_images, deterministic=not enable_dropout ) bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision( @@ -171,7 +171,7 @@ def __call__( ) if self.config.use_multimodal and encoder_videos is not None: - video_embeddings, deepstack_visual_embeds = self.vision_encoder( + video_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable] input_images=encoder_videos, deterministic=not enable_dropout ) bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision( @@ -384,7 +384,7 @@ def __init__( dummy_attention_metadata = None if not cfg.pure_nnx_decoder: - self.decoder.lazy_init( + self.decoder.lazy_init( # pyrefly: ignore[missing-attribute] shared_embedding=self.token_embedder, decoder_input_tokens=dummy_decoder_input_tokens, decoder_positions=dummy_decoder_positions, @@ -490,7 +490,7 @@ def __call__( audio_embeddings = None deepstack_visual_embeds = None if self.config.use_multimodal and encoder_images is not None: - image_embeddings, deepstack_visual_embeds = self.vision_encoder( + image_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable] input_images=encoder_images, deterministic=not enable_dropout ) bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision( @@ -498,7 +498,7 @@ def __call__( ) if self.config.use_multimodal and encoder_videos is not None: - video_embeddings, deepstack_visual_embeds = self.vision_encoder( + video_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable] input_images=encoder_videos, deterministic=not enable_dropout ) bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision( @@ -563,7 +563,7 @@ def __call__( kv_caches=kv_caches, attention_metadata=attention_metadata, deepstack_visual_embeds=deepstack_visual_embeds, - mutable=mutable_collections, + mutable=mutable_collections, # pyrefly: ignore[unexpected-keyword] ) # pytype: disable=wrong-keyword-args # If we are initializing the model AND MTP is enabled, we must create diff --git a/src/maxtext/utils/gradient_accumulation.py b/src/maxtext/utils/gradient_accumulation.py index 46c37354bb..35fdf65503 100644 --- a/src/maxtext/utils/gradient_accumulation.py +++ b/src/maxtext/utils/gradient_accumulation.py @@ -157,7 +157,7 @@ def reshape_to_microbatch_accumulations(batch_arr): "ga_params": ga_params, } if is_nnx: - init_grad_and_loss["rest_state"] = rest + init_grad_and_loss["rest_state"] = rest # pyrefly: ignore[unbound-name] grad_and_loss, aux = jax.lax.scan( accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps diff --git a/src/maxtext/utils/max_utils.py b/src/maxtext/utils/max_utils.py index 4ee35a3844..99a49383cd 100644 --- a/src/maxtext/utils/max_utils.py +++ b/src/maxtext/utils/max_utils.py @@ -336,8 +336,8 @@ def initialize_jax_for_gpu(raw_keys): jax.distributed.initialize( coordinator_address=f"{coordinator_ip}:{coordinator_port}", - num_processes=int(os.getenv("NNODES")), - process_id=int(os.getenv("NODE_RANK")), + num_processes=int(os.getenv("NNODES")), # pyrefly: ignore[bad-argument-type] + process_id=int(os.getenv("NODE_RANK")), # pyrefly: ignore[bad-argument-type] initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], local_device_ids=devices, ) @@ -349,16 +349,16 @@ def initialize_jax_for_cpu(raw_keys): coordinator_ip_address = get_coordinator_ip_address() coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK # Env variables to be set in XPK or otherwise - job_index = int(os.environ.get("JOB_INDEX")) - job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) - processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) + job_index = int(os.environ.get("JOB_INDEX")) # pyrefly: ignore[bad-argument-type] + job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) # pyrefly: ignore[bad-argument-type] + processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) # pyrefly: ignore[bad-argument-type] pid = job_index * processes_in_job + job_completion_index max_logging.log(f" Jax process id is {pid} ") # Explicit initialize is needed only for CPUs jax.distributed.initialize( coordinator_address=coordinator_address, process_id=pid, - num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), + num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), # pyrefly: ignore[bad-argument-type] initialization_timeout=raw_keys["jax_distributed_initialization_timeout"], ) @@ -444,7 +444,7 @@ def get_coordinator_ip_address(): max_coordinator_lookups = 50 while not coordinator_found and lookup_attempt <= max_coordinator_lookups: try: - coordinator_ip_address = socket.gethostbyname(coordinator_address) + coordinator_ip_address = socket.gethostbyname(coordinator_address) # pyrefly: ignore[bad-argument-type] coordinator_found = True except socket.gaierror: max_logging.log( @@ -676,7 +676,7 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_ log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1) total_z_loss = z_loss * jax.lax.square(log_z) loss += total_z_loss - return (loss, total_z_loss), ( + return (loss, total_z_loss), ( # pyrefly: ignore[bad-return] logits, targets, z_loss, @@ -698,11 +698,11 @@ def _cross_entropy_with_logits_bwd( g: tuple[jnp.ndarray, jnp.ndarray], ) -> tuple[jnp.ndarray, None, None]: """Backward-mode of `cross_entropy_with_logits`.""" - g = g[0] # Ignore z_loss component as that is only used for logging. + g = g[0] # Ignore z_loss component as that is only used for logging. # pyrefly: ignore[bad-assignment] logits, targets, z_loss, exp_shifted, sum_exp, log_z = res # z-loss term adds the (2 * z_loss * log_z) factor. deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets - g_logits = jnp.expand_dims(g, axis=-1) * deriv + g_logits = jnp.expand_dims(g, axis=-1) * deriv # pyrefly: ignore[bad-argument-type] return ( jnp.asarray(g_logits, logits.dtype), @@ -1214,7 +1214,7 @@ def transformer_engine_context(): tp_resource="tensor", # tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM fsdp_resource="fsdp", - pp_resource=None, + pp_resource=None, # pyrefly: ignore[bad-argument-type] cp_resource="context", ) with global_shard_guard(mesh_resource): diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 47c2c7e587..930ff90540 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -95,7 +95,7 @@ def get_functional_train_with_signature( ): """Get the shardings (both state and data) for `train_step`.""" functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings) - functional_train.__name__ = "train_step" + functional_train.__name__ = "train_step" # pyrefly: ignore[missing-attribute] if config.pure_nnx: in_shardings = (state_mesh_shardings, data_sharding) # State, batch else: @@ -109,7 +109,7 @@ def get_functional_train_with_signature( def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shardings, model, config): """Get the shardings (both state and data) for `eval_step`.""" functional_eval = functools.partial(eval_step, model, config) - functional_eval.__name__ = "eval_step" + functional_eval.__name__ = "eval_step" # pyrefly: ignore[missing-attribute] if config.pure_nnx: in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng) else: @@ -1392,7 +1392,7 @@ def get_abstract_param(model, config): {"params": key, "dropout": key, "aqt": key}, np.ones(input_shape, dtype=jnp.int32), np.ones(input_shape, dtype=jnp.int32), - encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, + encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, # pyrefly: ignore[no-matching-overload] encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None, ) return abstract_vars @@ -2022,7 +2022,7 @@ def to_abstract(x): # Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects # to completely bypass remote Array objects and proxy tracing overhead. abstract_inputs = jax.tree.map(to_abstract, train_step_inputs) - p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs) + p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs) # pyrefly: ignore[no-matching-overload] local_filename = "train_step.jaxpr" local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename) diff --git a/src/maxtext/utils/maxtext_utils_nnx.py b/src/maxtext/utils/maxtext_utils_nnx.py index 20494dbc71..7b265ab3bd 100644 --- a/src/maxtext/utils/maxtext_utils_nnx.py +++ b/src/maxtext/utils/maxtext_utils_nnx.py @@ -164,7 +164,7 @@ def create_nnx_sharded_model( named_sharding = nnx_extract_named_sharding(abstract_state) if mesh is None: - mesh = abstract_model.mesh + mesh = abstract_model.mesh # pyrefly: ignore[missing-attribute] # JIT a function that creates the model state with proper sharding from the start. # By providing out_shardings, we instruct JAX to produce sharded output directly, diff --git a/src/maxtext/utils/qk_clip_utils.py b/src/maxtext/utils/qk_clip_utils.py index 87c3688fcf..7c26f130d4 100644 --- a/src/maxtext/utils/qk_clip_utils.py +++ b/src/maxtext/utils/qk_clip_utils.py @@ -153,7 +153,7 @@ def apply_qk_clip_nnx(state, intermediate_outputs, config): tau = float(config.qk_clip_threshold) _, params_state, _ = nnx.split(state.model, nnx.Param, ...) - params_dict = params_state.to_pure_dict() + params_dict = params_state.to_pure_dict() # pyrefly: ignore[missing-attribute] def clip_mla_weights(path, param): if len(path) < 2: diff --git a/src/maxtext/utils/sharding.py b/src/maxtext/utils/sharding.py index dc2f70ae14..01e961e2a4 100644 --- a/src/maxtext/utils/sharding.py +++ b/src/maxtext/utils/sharding.py @@ -540,7 +540,7 @@ def add_data_to_sharding(mesh, path, aval, sharding): partition = (partition,) if size % mesh.shape["data"] == 0 and (partition is None or "tensor" not in partition): - added_component = ("data",) + partition + added_component = ("data",) + partition # pyrefly: ignore[unsupported-operation] new_pspec = jax.sharding.PartitionSpec(*(pspec[:idx] + (added_component,) + pspec[idx + 1 :])) new_sharding = jax.sharding.NamedSharding(sharding.mesh, new_pspec) return new_sharding @@ -581,7 +581,7 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings): # When quantization=fp8 is enabled the sharded_fp32_params # are not wrapped in `params`. Here we wrap them back. sharded_fp32_params = {"params": sharded_fp32_params} - state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params)) + state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params)) # pyrefly: ignore[bad-unpacking] return prev_params_shardings, state_mesh_shardings diff --git a/src/maxtext/utils/train_utils.py b/src/maxtext/utils/train_utils.py index 73e5d06b05..83451a5948 100644 --- a/src/maxtext/utils/train_utils.py +++ b/src/maxtext/utils/train_utils.py @@ -331,11 +331,11 @@ def create_train_state_fn(): inner_state_shardings, # Match the outer params' pure-dict structure (build_diloco_state stores # outer_params via to_pure_dict), so the sharding tree matches the state tree. - state_mesh_shardings_params.to_pure_dict() + state_mesh_shardings_params.to_pure_dict() # pyrefly: ignore[missing-attribute] if config.pure_nnx else state_mesh_shardings_params, outer_opt_state_sharding, - jax.sharding.NamedSharding( + jax.sharding.NamedSharding( # pyrefly: ignore[bad-argument-type] mesh=step_mesh, spec=jax.sharding.PartitionSpec() ), ) @@ -355,7 +355,7 @@ def create_train_state_fn(): logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn) logical_annotations_params = logical_annotations.params - max_utils.print_non_trivial_mesh_axis(model.mesh) + max_utils.print_non_trivial_mesh_axis(model.mesh) # pyrefly: ignore[missing-attribute] maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params) if config.pure_nnx: @@ -363,9 +363,9 @@ def create_train_state_fn(): # Don't merge the DiLoCoTrainState into the plain-model graphdef. The inner # train step needs that graphdef as jit_model; the wrapper passes through as state. train_state = state - model = state_graphdef + model = state_graphdef # pyrefly: ignore[unbound-name] else: - train_state = nnx.merge(state_graphdef, state) + train_state = nnx.merge(state_graphdef, state) # pyrefly: ignore[unbound-name] model = train_state.model else: train_state = state