Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/maxtext/models/deepseek.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,9 +303,9 @@ def self_attention_with_norm_op(
return hidden_states, intermediate_inputs

def engram_op(self, x, decoder_input_tokens):
normed_x = self.engram_layer_norm(x)
normed_x = self.engram_layer_norm(x) # pyrefly: ignore[not-callable]
hash_ids = self.ngram_hash_mapping(decoder_input_tokens)[self.layer_idx]
return self.engram(normed_x, hash_ids)
return self.engram(normed_x, hash_ids) # pyrefly: ignore[not-callable]


class DeepSeekDenseLayer(DeepSeekGenericLayer):
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/models/gpt3.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def init_kv_caches(self, inputs_kv_shape: tuple[int, ...]):
)

def update_kv_caches(self, key, value, decoder_segment_ids, model_mode, previous_chunk):
prefill_kv_cache, ar_kv_cache = self.KVCache_0(
prefill_kv_cache, ar_kv_cache = self.KVCache_0( # pyrefly: ignore[not-callable]
key=key,
value=value,
decoder_segment_ids=decoder_segment_ids,
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/llama2.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def update_cache(cache, val):
return cache.at[layer_idx].set(val)
return cache

stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
return (layer_output, stacked_kv_cache, layer_idx + 1), None
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name]
return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name]
elif cfg.scan_layers:
return layer_output, None
else:
Expand Down
20 changes: 10 additions & 10 deletions src/maxtext/models/llama4.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Llama4UnfoldConvolution(nnx.Module):
config: Config containing model parameters
"""

def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.rngs = rngs
self.vit_unfold_linear = linears.DenseGeneral(
Expand Down Expand Up @@ -123,7 +123,7 @@ class Llama4VisionMLP(nnx.Module):
config: Config containing model parameters
"""

def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.rngs = rngs
self.vit_encoder_layer_mlp_fc1 = linears.DenseGeneral(
Expand Down Expand Up @@ -157,7 +157,7 @@ class Llama4VisionMLP2(nnx.Module):
config: Config containing model parameters
"""

def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.rngs = rngs
self.vit_pixel_shuffle_mlp_fc1 = linears.DenseGeneral(
Expand Down Expand Up @@ -196,7 +196,7 @@ class Llama4VisionPixelShuffleMLP(nnx.Module):
config: Config containing model parameters
"""

def __init__(self, config: Config, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.rngs = rngs
self.pixel_shuffle_ratio = self.config.pixel_shuffle_ratio_for_vit
Expand All @@ -221,7 +221,7 @@ class Llama4MultiModalProjector(nnx.Module):
config: Config containing model parameters
"""

def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.mesh = mesh
self.rngs = rngs
Expand Down Expand Up @@ -515,8 +515,8 @@ def update_cache(cache, val):
return cache.at[layer_idx].set(val)
return cache

stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
return (layer_output, stacked_kv_cache, layer_idx + 1), None
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name]
return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name]
elif cfg.scan_layers:
return layer_output, None
else:
Expand Down Expand Up @@ -623,7 +623,7 @@ def __call__(
class Llama4VisionEncoderLayer(nnx.Module):
"""Transformer encoder layer for Llama4 vision model."""

def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.mesh = mesh
self.rngs = rngs
Expand Down Expand Up @@ -701,7 +701,7 @@ class Llama4VisionEncoder(nnx.Module):
mesh: Mesh, JAX device mesh (used for sharding)
"""

def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.mesh = mesh
self.rngs = rngs
Expand Down Expand Up @@ -733,7 +733,7 @@ class Llama4VisionModel(nnx.Module):
mesh: Mesh, JAX device mesh (used for sharding)
"""

def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None):
def __init__(self, config: Config, mesh: Mesh, *, rngs: nnx.Rngs = None): # pyrefly: ignore[bad-function-definition]
self.config = config
self.mesh = mesh
self.rngs = rngs
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/models/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,8 +192,8 @@ def update_cache(cache, val):
return cache.at[layer_idx].set(val)
return cache

stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache)
return (layer_output, stacked_kv_cache, layer_idx + 1), None
stacked_kv_cache = jax.tree_util.tree_map(update_cache, stacked_kv_cache, kv_cache) # pyrefly: ignore[unbound-name]
return (layer_output, stacked_kv_cache, layer_idx + 1), None # pyrefly: ignore[unbound-name]
elif cfg.scan_layers:
return layer_output, None
else:
Expand Down
18 changes: 9 additions & 9 deletions src/maxtext/models/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ def setup(self):
self.mtp_block = multi_token_prediction_block_as_linen(
config=self.config,
mesh=self.mesh,
transformer_layer_module=mtp_layer_nnx,
decoder=self.decoder,
rngs=self.make_rng("mtp_block"),
transformer_layer_module=mtp_layer_nnx, # pyrefly: ignore[bad-argument-type]
decoder=self.decoder, # pyrefly: ignore[bad-argument-type]
rngs=self.make_rng("mtp_block"), # pyrefly: ignore[bad-argument-type]
)

def logits_from_hidden_states_for_vocab_tiling(self, hidden_states, deterministic, model_mode):
Expand Down Expand Up @@ -163,15 +163,15 @@ def __call__(
deepstack_visual_embeds = None

if self.config.use_multimodal and encoder_images is not None:
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
image_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable]
input_images=encoder_images, deterministic=not enable_dropout
)
bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision(
self.config, decoder_input_tokens, is_video=False
)

if self.config.use_multimodal and encoder_videos is not None:
video_embeddings, deepstack_visual_embeds = self.vision_encoder(
video_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable]
input_images=encoder_videos, deterministic=not enable_dropout
)
bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision(
Expand Down Expand Up @@ -384,7 +384,7 @@ def __init__(
dummy_attention_metadata = None

if not cfg.pure_nnx_decoder:
self.decoder.lazy_init(
self.decoder.lazy_init( # pyrefly: ignore[missing-attribute]
shared_embedding=self.token_embedder,
decoder_input_tokens=dummy_decoder_input_tokens,
decoder_positions=dummy_decoder_positions,
Expand Down Expand Up @@ -490,15 +490,15 @@ def __call__(
audio_embeddings = None
deepstack_visual_embeds = None
if self.config.use_multimodal and encoder_images is not None:
image_embeddings, deepstack_visual_embeds = self.vision_encoder(
image_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable]
input_images=encoder_images, deterministic=not enable_dropout
)
bidirectional_mask_image = mm_processor.get_bidirectional_mask_vision(
self.config, decoder_input_tokens, is_video=False
)

if self.config.use_multimodal and encoder_videos is not None:
video_embeddings, deepstack_visual_embeds = self.vision_encoder(
video_embeddings, deepstack_visual_embeds = self.vision_encoder( # pyrefly: ignore[not-callable]
input_images=encoder_videos, deterministic=not enable_dropout
)
bidirectional_mask_video = mm_processor.get_bidirectional_mask_vision(
Expand Down Expand Up @@ -563,7 +563,7 @@ def __call__(
kv_caches=kv_caches,
attention_metadata=attention_metadata,
deepstack_visual_embeds=deepstack_visual_embeds,
mutable=mutable_collections,
mutable=mutable_collections, # pyrefly: ignore[unexpected-keyword]
) # pytype: disable=wrong-keyword-args

# If we are initializing the model AND MTP is enabled, we must create
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/utils/gradient_accumulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def reshape_to_microbatch_accumulations(batch_arr):
"ga_params": ga_params,
}
if is_nnx:
init_grad_and_loss["rest_state"] = rest
init_grad_and_loss["rest_state"] = rest # pyrefly: ignore[unbound-name]

grad_and_loss, aux = jax.lax.scan(
accumulate_gradient, init_grad_and_loss, data, length=config.gradient_accumulation_steps
Expand Down
22 changes: 11 additions & 11 deletions src/maxtext/utils/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -336,8 +336,8 @@ def initialize_jax_for_gpu(raw_keys):

jax.distributed.initialize(
coordinator_address=f"{coordinator_ip}:{coordinator_port}",
num_processes=int(os.getenv("NNODES")),
process_id=int(os.getenv("NODE_RANK")),
num_processes=int(os.getenv("NNODES")), # pyrefly: ignore[bad-argument-type]
process_id=int(os.getenv("NODE_RANK")), # pyrefly: ignore[bad-argument-type]
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
local_device_ids=devices,
)
Expand All @@ -349,16 +349,16 @@ def initialize_jax_for_cpu(raw_keys):
coordinator_ip_address = get_coordinator_ip_address()
coordinator_address = coordinator_ip_address + ":1234" # JAX coordinator port used in XPK
# Env variables to be set in XPK or otherwise
job_index = int(os.environ.get("JOB_INDEX"))
job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX"))
processes_in_job = int(os.environ.get("PROCESSES_IN_JOB"))
job_index = int(os.environ.get("JOB_INDEX")) # pyrefly: ignore[bad-argument-type]
job_completion_index = int(os.environ.get("JOB_COMPLETION_INDEX")) # pyrefly: ignore[bad-argument-type]
processes_in_job = int(os.environ.get("PROCESSES_IN_JOB")) # pyrefly: ignore[bad-argument-type]
pid = job_index * processes_in_job + job_completion_index
max_logging.log(f" Jax process id is {pid} ")
# Explicit initialize is needed only for CPUs
jax.distributed.initialize(
coordinator_address=coordinator_address,
process_id=pid,
num_processes=int(os.environ.get("JAX_PROCESS_COUNT")),
num_processes=int(os.environ.get("JAX_PROCESS_COUNT")), # pyrefly: ignore[bad-argument-type]
initialization_timeout=raw_keys["jax_distributed_initialization_timeout"],
)

Expand Down Expand Up @@ -444,7 +444,7 @@ def get_coordinator_ip_address():
max_coordinator_lookups = 50
while not coordinator_found and lookup_attempt <= max_coordinator_lookups:
try:
coordinator_ip_address = socket.gethostbyname(coordinator_address)
coordinator_ip_address = socket.gethostbyname(coordinator_address) # pyrefly: ignore[bad-argument-type]
coordinator_found = True
except socket.gaierror:
max_logging.log(
Expand Down Expand Up @@ -676,7 +676,7 @@ def _cross_entropy_with_logits_fwd(logits: jnp.ndarray, targets: jnp.ndarray, z_
log_z = jnp.squeeze(jnp.log(sum_exp) + max_logit, axis=-1)
total_z_loss = z_loss * jax.lax.square(log_z)
loss += total_z_loss
return (loss, total_z_loss), (
return (loss, total_z_loss), ( # pyrefly: ignore[bad-return]
logits,
targets,
z_loss,
Expand All @@ -698,11 +698,11 @@ def _cross_entropy_with_logits_bwd(
g: tuple[jnp.ndarray, jnp.ndarray],
) -> tuple[jnp.ndarray, None, None]:
"""Backward-mode of `cross_entropy_with_logits`."""
g = g[0] # Ignore z_loss component as that is only used for logging.
g = g[0] # Ignore z_loss component as that is only used for logging. # pyrefly: ignore[bad-assignment]
logits, targets, z_loss, exp_shifted, sum_exp, log_z = res
# z-loss term adds the (2 * z_loss * log_z) factor.
deriv = jnp.expand_dims(1 + 2 * z_loss * log_z, -1) * exp_shifted / sum_exp - targets
g_logits = jnp.expand_dims(g, axis=-1) * deriv
g_logits = jnp.expand_dims(g, axis=-1) * deriv # pyrefly: ignore[bad-argument-type]

return (
jnp.asarray(g_logits, logits.dtype),
Expand Down Expand Up @@ -1214,7 +1214,7 @@ def transformer_engine_context():
tp_resource="tensor",
# tpsp_resource = "tensor_sequence", #TODO(Phuong): add this back when upstreaming CGEMM
fsdp_resource="fsdp",
pp_resource=None,
pp_resource=None, # pyrefly: ignore[bad-argument-type]
cp_resource="context",
)
with global_shard_guard(mesh_resource):
Expand Down
8 changes: 4 additions & 4 deletions src/maxtext/utils/maxtext_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_functional_train_with_signature(
):
"""Get the shardings (both state and data) for `train_step`."""
functional_train = functools.partial(train_step, model, config, state_mesh_shardings, params_shardings)
functional_train.__name__ = "train_step"
functional_train.__name__ = "train_step" # pyrefly: ignore[missing-attribute]
if config.pure_nnx:
in_shardings = (state_mesh_shardings, data_sharding) # State, batch
else:
Expand All @@ -109,7 +109,7 @@ def get_functional_train_with_signature(
def get_functional_eval_with_signature(eval_step, data_sharding, state_mesh_shardings, model, config):
"""Get the shardings (both state and data) for `eval_step`."""
functional_eval = functools.partial(eval_step, model, config)
functional_eval.__name__ = "eval_step"
functional_eval.__name__ = "eval_step" # pyrefly: ignore[missing-attribute]
if config.pure_nnx:
in_shardings = (state_mesh_shardings, data_sharding) # State, batch (NNX: no rng)
else:
Expand Down Expand Up @@ -1392,7 +1392,7 @@ def get_abstract_param(model, config):
{"params": key, "dropout": key, "aqt": key},
np.ones(input_shape, dtype=jnp.int32),
np.ones(input_shape, dtype=jnp.int32),
encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,
encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None, # pyrefly: ignore[no-matching-overload]
encoder_audios=np.ones(audio_shape, dtype=jnp.float32) if config.use_audio else None,
)
return abstract_vars
Expand Down Expand Up @@ -2022,7 +2022,7 @@ def to_abstract(x):
# Convert all input arguments recursively to purely local abstract ShapeDtypeStruct objects
# to completely bypass remote Array objects and proxy tracing overhead.
abstract_inputs = jax.tree.map(to_abstract, train_step_inputs)
p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs)
p_train_jaxpr = jax.make_jaxpr(unwrapped_step)(*abstract_inputs) # pyrefly: ignore[no-matching-overload]

local_filename = "train_step.jaxpr"
local_path = os.path.join(config.dump_jaxpr_local_dir, local_filename)
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/utils/maxtext_utils_nnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def create_nnx_sharded_model(
named_sharding = nnx_extract_named_sharding(abstract_state)

if mesh is None:
mesh = abstract_model.mesh
mesh = abstract_model.mesh # pyrefly: ignore[missing-attribute]

# JIT a function that creates the model state with proper sharding from the start.
# By providing out_shardings, we instruct JAX to produce sharded output directly,
Expand Down
2 changes: 1 addition & 1 deletion src/maxtext/utils/qk_clip_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def apply_qk_clip_nnx(state, intermediate_outputs, config):
tau = float(config.qk_clip_threshold)

_, params_state, _ = nnx.split(state.model, nnx.Param, ...)
params_dict = params_state.to_pure_dict()
params_dict = params_state.to_pure_dict() # pyrefly: ignore[missing-attribute]

def clip_mla_weights(path, param):
if len(path) < 2:
Expand Down
4 changes: 2 additions & 2 deletions src/maxtext/utils/sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def add_data_to_sharding(mesh, path, aval, sharding):
partition = (partition,)

if size % mesh.shape["data"] == 0 and (partition is None or "tensor" not in partition):
added_component = ("data",) + partition
added_component = ("data",) + partition # pyrefly: ignore[unsupported-operation]
new_pspec = jax.sharding.PartitionSpec(*(pspec[:idx] + (added_component,) + pspec[idx + 1 :]))
new_sharding = jax.sharding.NamedSharding(sharding.mesh, new_pspec)
return new_sharding
Expand Down Expand Up @@ -581,7 +581,7 @@ def maybe_update_params_sharding_with_opt(config, state_mesh_shardings):
# When quantization=fp8 is enabled the sharded_fp32_params
# are not wrapped in `params`. Here we wrap them back.
sharded_fp32_params = {"params": sharded_fp32_params}
state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params))
state_mesh_shardings = state_mesh_shardings.replace(params=dict(prev_params_shardings, **sharded_fp32_params)) # pyrefly: ignore[bad-unpacking]
return prev_params_shardings, state_mesh_shardings


Expand Down
10 changes: 5 additions & 5 deletions src/maxtext/utils/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,11 +331,11 @@ def create_train_state_fn():
inner_state_shardings,
# Match the outer params' pure-dict structure (build_diloco_state stores
# outer_params via to_pure_dict), so the sharding tree matches the state tree.
state_mesh_shardings_params.to_pure_dict()
state_mesh_shardings_params.to_pure_dict() # pyrefly: ignore[missing-attribute]
if config.pure_nnx
else state_mesh_shardings_params,
outer_opt_state_sharding,
jax.sharding.NamedSharding(
jax.sharding.NamedSharding( # pyrefly: ignore[bad-argument-type]
mesh=step_mesh, spec=jax.sharding.PartitionSpec()
),
)
Expand All @@ -355,17 +355,17 @@ def create_train_state_fn():
logical_annotations = maxtext_utils.get_logical_annotations(config, mesh, init_state_fn)
logical_annotations_params = logical_annotations.params

max_utils.print_non_trivial_mesh_axis(model.mesh)
max_utils.print_non_trivial_mesh_axis(model.mesh) # pyrefly: ignore[missing-attribute]
maxtext_utils.print_shardings_params(state_params, state_mesh_shardings_params, mesh, logical_annotations_params)

if config.pure_nnx:
if config.enable_diloco:
# Don't merge the DiLoCoTrainState into the plain-model graphdef. The inner
# train step needs that graphdef as jit_model; the wrapper passes through as state.
train_state = state
model = state_graphdef
model = state_graphdef # pyrefly: ignore[unbound-name]
else:
train_state = nnx.merge(state_graphdef, state)
train_state = nnx.merge(state_graphdef, state) # pyrefly: ignore[unbound-name]
model = train_state.model
else:
train_state = state
Expand Down
Loading