Description
Bug report
I wanted to try MaxText for inference of open-weight models on TPUs, but I get a strange shape error as soon as I try to start the inference server for a Gemma-1 7B model on a v4-8 TPU VM.
Logs/Output
2025-07-20 10:35:34.556142: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1753007734.571556 408795 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1753007734.576301 408795 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1753007734.589191 408795 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753007734.589212 408795 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753007734.589225 408795 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1753007734.589230 408795 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
Updating keys from env and command line: ['model_name', 'load_parameters_path', 'weight_dtype', 'scan_layers', 'ici_fsdp_parallelism', 'ici_tensor_parallelism', 'ici_autoregressive_parallelism', 'tokenizer_path', 'per_device_batch_size', 'max_target_length', 'max_prefill_predict_length']
Running Model: gemma-7b
Updating following parameters in config
base_emb_dim: 3072
base_num_query_heads: 16
base_num_kv_heads: 16
base_mlp_dim: 24576
base_num_decoder_layers: 28
head_dim: 256
mlp_activations: ['gelu', 'linear']
vocab_size: 256128
decoder_block: gemma
normalization_layer_epsilon: 1e-06
logits_via_embedding: True
Updating keys from model: ['base_emb_dim', 'base_num_query_heads', 'base_num_kv_heads', 'base_mlp_dim', 'base_num_decoder_layers', 'head_dim', 'mlp_activations', 'vocab_size', 'decoder_block', 'normalization_layer_epsilon', 'logits_via_embedding']
Attempting to initialize the jax distributed system...
2025-07-20 10:35:42.196725: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:38] gRPC insecure server credentials are used.
2025-07-20 10:35:42.196828: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:225] Initializing CoordinationService
2025-07-20 10:35:42.199230: I external/xla/xla/pjrt/distributed/service.cc:75] Coordination service is enabled.
2025-07-20 10:35:42.199509: I external/xla/xla/pjrt/distributed/service.cc:107] Jax service listening on [::]:8476
2025-07-20 10:35:42.199533: I external/xla/xla/tsl/platform/default/grpc_credentials.cc:30] gRPC insecure client credentials are used.
2025-07-20 10:35:42.200638: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1413] Barrier([Init]Wait_for_all_tasks_to_register::0) has passed with status: OK
2025-07-20 10:35:42.200680: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.h:523] /job:jax_worker/replica:0/task:0 has connected to coordination service. Incarnation: 3536328435953780183
2025-07-20 10:35:42.203024: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc:344] Coordination agent has successfully connected.
2025-07-20 10:35:42.203170: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc:416] Polling for error from coordination service. This is a long-running RPC that will return only if an error is encountered or cancelled (e.g. due to shutdown).
2025-07-20 10:35:42.203322: I external/xla/xla/pjrt/distributed/client.cc:121] Connected to distributed JAX controller
Jax distributed system initialized!
2025-07-20 10:35:42.250851: I external/xla/xla/pjrt/pjrt_api.cc:115] GetPjrtApi was found for tpu at /home/jan/googel_infrence/env/lib/python3.10/site-packages/libtpu/libtpu.so
2025-07-20 10:35:42.250888: I external/xla/xla/pjrt/pjrt_api.cc:93] PJRT_Api is set for device type tpu
2025-07-20 10:35:42.250934: I external/xla/xla/pjrt/pjrt_api.cc:161] The PJRT plugin has PJRT API version 0.69. The framework PJRT API version is 0.70.
2025-07-20 10:35:44.714067: I external/xla/xla/pjrt/pjrt_c_api_client.cc:130] PjRtCApiClient created.
Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period, use_replicator_service and replicator_backup_interval_minutes
dataset_type set to tfds, will use keys['dataset_path']='' and keys['dataset_name']='c4/en:3.0.1'
Config param activations_in_float32: False
Config param adam_b1: 0.9
Config param adam_b2: 0.95
Config param adam_eps: 1e-08
Config param adam_eps_root: 0.0
Config param adam_weight_decay: 0.1
Config param add_bos: True
Config param add_eos: True
Config param allow_split_physical_axes: False
Config param ar_cache_axis_order: 1,2,0,3
Config param async_checkpointing: True
Config param attention: autoselected
Config param attention_type: global
Config param attn_logits_soft_cap: None
Config param autoregressive_decode_assert:
Config param base_emb_dim: 3072
Config param base_mlp_dim: 24576
Config param base_moe_mlp_dim: 7168
Config param base_num_decoder_layers: 28
Config param base_num_kv_heads: 16
Config param base_num_query_heads: 16
Config param base_output_directory:
Config param beta_fast: 32
Config param beta_slow: 1
Config param capacity_factor: -1.0
Config param cast_logits_to_fp32: True
Config param checkpoint_is_quantized: False
Config param checkpoint_period: 10000
Config param checkpoint_storage_concurrent_gb: 96
Config param checkpoint_storage_target_data_file_size_bytes: 2147483648
Config param checkpoint_storage_use_ocdbt: True
Config param checkpoint_storage_use_zarr3: True
Config param chunk_attn_window_size: 0
Config param collect_stack_trace: False
Config param colocated_python_data_input: False
Config param compile_topology:
Config param compile_topology_num_slices: -1
Config param compiled_trainstep_file:
Config param compute_axis_order: 0,1,2,3
Config param constant_bound_config: []
Config param context: remat
Config param context_parallel_load_balance: True
Config param cosine_learning_rate_final_fraction: 0.1
Config param custom_mesh:
Config param data_sharding: (('data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive'),)
Config param data_shuffle_seed: 0
Config param dataset_name: c4/en:3.0.1
Config param dataset_path:
Config param dataset_type: tfds
Config param dcn_autoregressive_parallelism: 1
Config param dcn_context_autoregressive_parallelism: 1
Config param dcn_context_parallelism: 1
Config param dcn_data_parallelism: -1
Config param dcn_expert_parallelism: 1
Config param dcn_fsdp_parallelism: 1
Config param dcn_fsdp_transpose_parallelism: 1
Config param dcn_parallelism: [-1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
Config param dcn_pipeline_parallelism: 1
Config param dcn_sequence_parallelism: 1
Config param dcn_tensor_parallelism: 1
Config param dcn_tensor_sequence_parallelism: 1
Config param dcn_tensor_transpose_parallelism: 1
Config param decode_sampling_nucleus_p: -1
Config param decode_sampling_strategy: greedy
Config param decode_sampling_temperature: 1.0
Config param decode_sampling_top_k: 0
Config param decoder_block: DecoderBlockType.GEMMA
Config param decoder_layer_input: device
Config param dpo_beta: 0.1
Config param dpo_label_smoothing: 0.0
Config param dropout_rate: 0.0
Config param dtype: bfloat16
Config param dtype_mm: float32
Config param dump_hlo: False
Config param dump_hlo_delete_local_after: True
Config param dump_hlo_gcs_dir:
Config param dump_hlo_local_dir: /tmp/xla_dump/
Config param dump_hlo_module_name: jit_train_step
Config param dump_hlo_upload_all: False
Config param dump_hlo_xla_flags:
Config param dump_step: -1
Config param emb_dim: 3072
Config param enable_checkpoint_cloud_logger: False
Config param enable_checkpointing: True
Config param enable_data_shuffling: True
Config param enable_dropout: True
Config param enable_emergency_checkpoint: False
Config param enable_gcp_goodput_metrics: True
Config param enable_gcp_step_deviation_metrics: True
Config param enable_goodput_recording: False
Config param enable_jax_profiler: False
Config param enable_llm_inference_pool: False
Config param enable_model_warmup: False
Config param enable_padding_causal_mask: True
Config param enable_pathways_goodput: False
Config param enable_prefix_caching: False
Config param enable_single_controller: False
Config param enable_single_replica_ckpt_restoring: False
Config param enable_tensorboard: True
Config param eval_data_columns: ['text']
Config param eval_dataset_name: c4/en:3.0.1
Config param eval_image_column: image
Config param eval_interval: -1
Config param eval_per_device_batch_size: 11.0
Config param eval_split: validation
Config param eval_steps: -1
Config param expansion_factor_real_data: -1
Config param final_logits_soft_cap: None
Config param first_num_dense_layers: 0
Config param float32_logits: False
Config param float32_qk_product: False
Config param force_unroll: False
Config param freeze_vision_encoder_params: True
Config param fused_mlp: False
Config param fused_qkv: False
Config param gcs_metrics: False
Config param generate_slice: v5e-16
Config param global_batch_size_to_eval_on: 44
Config param global_batch_size_to_load: 44
Config param global_batch_size_to_load_eval: 44
Config param global_batch_size_to_train_on: 44
Config param global_parameter_scale: 1
Config param goodput_upload_interval_seconds: 30
Config param gradient_accumulation_steps: 1
Config param gradient_clipping_threshold: 1.0
Config param grain_eval_files:
Config param grain_file_type: arrayrecord
Config param grain_train_files:
Config param grain_worker_count: 1
Config param grain_worker_count_eval: 1
Config param hardware: tpu
Config param head_dim: 256
Config param heartbeat_reporting_interval_in_seconds: 5
Config param hf_data_dir:
Config param hf_eval_files:
Config param hf_eval_split:
Config param hf_path:
Config param hf_train_files:
Config param hidden_size_for_vit: 1408
Config param ici_autoregressive_parallelism: 1
Config param ici_context_autoregressive_parallelism: 1
Config param ici_context_parallelism: 1
Config param ici_data_parallelism: 1
Config param ici_expert_parallelism: 1
Config param ici_fsdp_parallelism: 1
Config param ici_fsdp_transpose_parallelism: 1
Config param ici_parallelism: [1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1]
Config param ici_pipeline_parallelism: 1
Config param ici_sequence_parallelism: 1
Config param ici_tensor_parallelism: 4
Config param ici_tensor_sequence_parallelism: 1
Config param ici_tensor_transpose_parallelism: 1
Config param image_path:
Config param image_placeholder: <|image|>
Config param image_size_for_vit: 896
Config param inference_benchmark_test: False
Config param inference_metadata_file:
Config param inference_microbenchmark_log_file_path:
Config param inference_microbenchmark_loop_iters: 10
Config param inference_microbenchmark_num_samples: [1, 2, 3, 4, 5]
Config param inference_microbenchmark_prefill_lengths: 64,128,256,512,1024
Config param inference_microbenchmark_stages: prefill,generate
Config param inference_server: MaxtextInterleavedServer
Config param inhomogeneous_layer_cycle_interval: 1
Config param init_weights_seed: 0
Config param input_data_sharding_logical_axes: ['activation_embed_and_logits_batch', 'activation_norm_length']
Config param interleave_moe_layer_step: 1
Config param intermediate_size_for_vit: 5632
Config param jax_cache_dir: ~/jax_cache
Config param jax_debug_log_modules:
Config param jax_distributed_initialization_timeout: 300
Config param jax_profiler_port: 9999
Config param key_proj: remat
Config param kv_lora_rank: 512
Config param kv_quant_axis: heads_and_dkv
Config param kv_quant_dtype: int8
Config param learning_rate: 3e-05
Config param learning_rate_schedule_steps: 150001
Config param load_balance_loss_weight: 0.01
Config param load_from_prefill_dir: False
Config param load_full_state_path:
Config param load_parameters_path: /home/jan/googel_infrence/gemma_checkpoints/7b-it/
Config param local_checkpoint_directory:
Config param local_checkpoint_period: 0
Config param local_rope_max_timescale: -1
Config param log_config: True
Config param log_period: 100
Config param logical_axis_rules: (('activation_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_batch_no_exp', ('data', 'fsdp', 'fsdp_transpose')), ('activation_embed_and_logits_batch', ('data', 'stage', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence', 'autoregressive')), ('activation_kv_heads', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_length', ('sequence', 'context')), ('prefill_activation_length', ('sequence', 'context')), ('activation_length', ('context',)), ('activation_norm_length', ('tensor_sequence', 'context', 'sequence')), ('prefill_activation_norm_length', ('tensor_sequence', 'context', 'sequence')), ('activation_q_length', ('context',)), ('activation_kv_length', ()), ('activation_embed', ('tensor', 'tensor_transpose')), ('activation_mlp', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_kv', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_prefill_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('activation_kv_head_dim', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose', 'sequence', 'tensor_sequence')), ('activation_vocab', ('tensor', 'tensor_transpose')), ('activation_vocab', 'tensor_sequence'), ('activation_vocab', ('sequence', 'context')), ('activation_stage', 'stage'), ('activation_exp', ('expert',)), ('decode_batch', ('data', 'fsdp', 'fsdp_transpose', 'expert')), ('decode_length', ('sequence',)), ('mlp', ('fsdp_transpose', 'tensor', 'tensor_sequence', 'autoregressive')), ('vocab', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('q_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('kv_heads', ('tensor', 'tensor_transpose', 'tensor_sequence', 'autoregressive')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context', 'expert')), ('embed', ('fsdp', 'sequence', 'tensor_transpose', 'context', 'expert')), ('embed', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('embed', ('fsdp', 'sequence', 'context', 'expert')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'tensor_transpose', 'context')), ('embed_no_exp', ('fsdp', 'fsdp_transpose', 'sequence', 'context')), ('embed_no_exp', ('fsdp', 'sequence', 'context')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('q_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('q_lora', ('fsdp', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'tensor_transpose', 'expert')), ('kv_lora', ('fsdp', 'fsdp_transpose', 'sequence', 'context', 'expert')), ('kv_lora', ('fsdp', 'sequence', 'context', 'expert')), ('norm', ('tensor', 'tensor_transpose', 'tensor_sequence')), ('layers', 'stage'), ('kv', ()), ('kv_head_dim', ()), ('cache_batch_prefill', ()), ('cache_batch', ()), ('cache_heads_none', ()), ('cache_heads', ('autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence')), ('cache_heads', ('autoregressive', 'tensor', 'tensor_sequence')), ('cache_kv', ()), ('cache_sequence', ()), ('exp', 'expert'), ('paged_kv_heads', ('tensor',)), ('num_pages', ()), ('tokens_per_page', ()), ('paged_kv_head_dim_size', ()))
Config param logits_dot_in_fp32: False
Config param logits_via_embedding: True
Config param lora_input_adapters_path:
Config param matmul_precision: default
Config param max_checkify: False
Config param max_corpus_chars: 10000000
Config param max_position_embeddings: 163840
Config param max_prefill_predict_length: 1024
Config param max_target_length: 2048
Config param megablox: True
Config param mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'context', 'context_autoregressive', 'tensor', 'tensor_transpose', 'tensor_sequence', 'expert', 'autoregressive']
Config param metrics_file:
Config param micro_batch_size_to_eval_on: 44
Config param micro_batch_size_to_train_on: 44
Config param mla_naive_kvcache: True
Config param mlp_activations: ['gelu', 'linear']
Config param mlp_dim: 24576
Config param mlpwi: remat
Config param mlpwi_0: remat
Config param mlpwi_1: remat
Config param mlpwo: remat
Config param model_call_mode:
Config param model_fsdp_ag_once: False
Config param model_name: gemma-7b
Config param moe_mlp_dim: 7168
Config param monitor_goodput: False
Config param monitor_step_time_deviation: True
Config param mscale: 1.0
Config param mtp_eval_target_module: 0
Config param mtp_loss_scaling_factor: 0.1
Config param mtp_num_layers: 0
Config param mu_dtype: bfloat16
Config param multi_sampling: False
Config param n_routing_groups: -1
Config param nope_layer_interval: -1
Config param normalization_layer_epsilon: 1e-06
Config param normalize_embedding_logits: True
Config param num_attention_heads_for_vit: 16
Config param num_channels_for_vit: 3
Config param num_decoder_layers: 28
Config param num_epoch: 1
Config param num_experts: 1
Config param num_experts_per_tok: 1
Config param num_hidden_layers_for_vit: 34
Config param num_kv_heads: 16
Config param num_layers_per_pipeline_stage: 1
Config param num_pipeline_microbatches: -1
Config param num_pipeline_repeats: -1
Config param num_query_heads: 16
Config param num_slices: 1
Config param opt_type: adamw
Config param optimize_mesh_for_tpu_v6e: False
Config param optimizer_memory_host_offload: False
Config param original_max_position_embeddings: 4096
Config param out_proj: remat
Config param override_model_config: False
Config param packing: True
Config param pagedattn_head_dim_alignment: 128
Config param pagedattn_max_pages_per_group: 64
Config param pagedattn_num_pages: 64
Config param pagedattn_pages_per_compute_block: 4
Config param pagedattn_tokens_per_page: 32
Config param param_scan_axis: 1
Config param parameter_memory_host_offload: False
Config param patch_size_for_vit: 14
Config param per_device_batch_size: 11.0
Config param pipeline_delay_activation_forwarding: False
Config param pipeline_fsdp_ag_once: False
Config param pipeline_parallel_layers: -1
Config param pixel_shuffle_ratio_for_vit: 0.5
Config param prefill_cache_axis_order: 1,2,0,3
Config param prefill_cache_dir:
Config param prefill_chunk_size: 256
Config param prefill_slice: v5e-16
Config param prefix_caching_dram_byte: 100000000000
Config param prefix_caching_hbm_byte: 10000000000
Config param profile_cleanly: True
Config param profile_periodically_period: -1
Config param profiler:
Config param profiler_steps: 5
Config param projector_dropout_for_vit: 0.0
Config param projector_input_dim_for_vit: 4096
Config param projector_output_dim_for_vit: 4096
Config param prometheus_port: 0
Config param prompt: I love to
Config param q_lora_rank: 0
Config param qk_nope_head_dim: 128
Config param qk_rope_head_dim: 64
Config param qkv_proj: remat
Config param quant_cfg_path:
Config param quantization:
Config param quantization_local_shard_count: 1
Config param quantize_kvcache: False
Config param query_proj: remat
Config param ragged_block_size: 256
Config param record_internal_nn_metrics: 0
Config param remat_policy: full
Config param remat_policy_for_vit: minimal
Config param replicate_quant_scale: False
Config param replicator_backup_interval_minutes: 0
Config param report_heartbeat_metric_for_gcp_monitoring: False
Config param report_performance_metric_for_gcp_monitoring: False
Config param reshape_q: False
Config param return_log_prob: False
Config param reuse_example_batch: 0
Config param rope_factor: 40
Config param rope_max_timescale: 10000
Config param rope_min_timescale: 1
Config param rope_theta_for_vit: 10000
Config param rope_type: default
Config param rope_use_scale: True
Config param routed_bias: False
Config param routed_scaling_factor: 1.0
Config param routed_score_func:
Config param run_name: None
Config param sa_block_kv: 512
Config param sa_block_kv_compute: 512
Config param sa_block_kv_dkv: 512
Config param sa_block_kv_dkv_compute: 512
Config param sa_block_kv_dq: 512
Config param sa_block_q: 512
Config param sa_block_q_dkv: 512
Config param sa_block_q_dq: 512
Config param sa_k_layout: HEAD_DIM_MINOR
Config param sa_q_layout: HEAD_DIM_MINOR
Config param sa_use_fused_bwd_kernel: False
Config param sa_v_layout: HEAD_DIM_MINOR
Config param save_config_to_gcs: False
Config param save_quantized_params_path:
Config param scan_layers: False
Config param scan_layers_per_stage: False
Config param scan_pipeline_iterations: True
Config param set_remat_policy_on_layers_per_stage: False
Config param set_remat_policy_on_pipeline_iterations: True
Config param sft_train_on_completion_only: False
Config param sharding_tolerance: 0.02
Config param shared_experts: 1
Config param skip_first_n_steps_for_profiler: 1
Config param skip_jax_distributed_system: False
Config param sliding_window_size: 0
Config param sparse_matmul: True
Config param stack_prefill_result_cache: False
Config param stack_trace_interval_seconds: 600
Config param stack_trace_to_cloud: False
Config param step_deviation_interval_seconds: 30
Config param steps: 150001
Config param subslice_shape:
Config param target_eval_loss: 0.0
Config param temperature_tuning: False
Config param tile_activation_dim: 1024
Config param tile_batch_seq: 512
Config param tile_size_for_vit: 336
Config param tile_weight_dim: 1024
Config param tokenize_eval_data: True
Config param tokenize_train_data: True
Config param tokenizer_path: /home/jan/googel_infrence/maxtext/assets/tokenizer.gemma
Config param tokenizer_type: sentencepiece
Config param topk_routing_group: -1
Config param train_data_columns: ['text']
Config param train_image_column: image
Config param train_split: train
Config param trainable_position_size: -1
Config param upload_all_profiler_results: False
Config param use_chat_template: False
Config param use_chunked_prefill: False
Config param use_dpo: False
Config param use_iota_embed: False
Config param use_multimodal: False
Config param use_post_attn_norm: False
Config param use_post_ffw_norm: False
Config param use_qk_norm: False
Config param use_ragged_attention: False
Config param use_random_routing: False
Config param use_replicator_service: False
Config param use_sft: False
Config param use_untrainable_positional_embedding: False
Config param use_vertex_tensorboard: False
Config param using_pipeline_parallelism: False
Config param v_head_dim: 128
Config param value_proj: remat
Config param vertex_tensorboard_project:
Config param vertex_tensorboard_region:
Config param vision_output_dim_for_vit: 4096
Config param vocab_size: 256128
Config param warmup_steps_fraction: 0.1
Config param weight_dtype: bfloat16
2025-07-20 10:35:44,792 - jetstream.core.server_lib - INFO - Using devices: 4
2025-07-20 10:35:44,792 - jetstream.core.server_lib - INFO - Kicking off gRPC server.
2025-07-20 10:35:44,793 - jetstream.core.server_lib - INFO - Not starting Prometheus server: --prometheus_port flag not set
Num_devices: 4, shape (1, 1, 1, 1, 1, 1, 1, 4, 1, 1, 1, 1)
Loading decode params from /home/jan/googel_infrence/gemma_checkpoints/7b-it/
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
q_seq_len: 2048
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
query: (44, 4, 2048, 256)
key: (44, 4, 2048, 256)
value: (44, 4, 2048, 256)
restoring params from /home/jan/googel_infrence/gemma_checkpoints/7b-it/
Creating checkpoint manager with ocdbt=True and zarr3=True
WARNING:absl:The transformations API will eventually be replaced by an upgraded design. The current API will not be removed until this point, but it will no longer be actively worked on.
query: (44, 4, 1024, 256)
key: (44, 4, 1024, 256)
value: (44, 4, 1024, 256)
q_seq_len: 1024
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/home/jan/googel_infrence/maxtext/MaxText/maxengine_server.py", line 88, in <module>
main(cfg)
File "/home/jan/googel_infrence/maxtext/MaxText/maxengine_server.py", line 68, in main
jetstream_server = server_lib.run(
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jetstream/core/server_lib.py", line 325, in run
driver = create_driver(
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jetstream/core/server_lib.py", line 162, in create_driver
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jetstream/core/server_lib.py", line 162, in <listcomp>
shared_params = [ie.load_params() for ie in engines.interleaved_engines]
File "/home/jan/googel_infrence/maxtext/MaxText/maxengine.py", line 249, in load_params
self.prefill_kv_cache_annotations = maxtext_utils.get_prefill_kv_cache_annotations(
File "/home/jan/googel_infrence/maxtext/MaxText/maxtext_utils.py", line 837, in get_prefill_kv_cache_annotations
abstract_state = jax.eval_shape(init_kv_cache_partial)
File "/home/jan/googel_infrence/maxtext/MaxText/maxtext_utils.py", line 824, in init_kv_cache
model_vars = model.init(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/models.py", line 117, in __call__
logits, hidden_state = self.decoder(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/decoders.py", line 742, in __call__
y = layer(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/gemma.py", line 90, in __call__
attention_lnx = attention_layer(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/attentions.py", line 1803, in __call__
out = self.attention_op(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/nnx_wrappers.py", line 396, in __call__
out = method_fn(module, *args, **kwargs)
File "/home/jan/googel_infrence/maxtext/MaxText/layers/attentions.py", line 1267, in __call__
prefill_unnormalized_output, prefill_exponentials_max, prefill_exponentials_sum = self.apply_attention(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/attentions.py", line 568, in apply_attention
return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
File "/home/jan/googel_infrence/maxtext/MaxText/layers/attentions.py", line 895, in tpu_flash_attention
x = wrap_flash_attention(
File "/home/jan/googel_infrence/maxtext/MaxText/layers/attentions.py", line 891, in wrap_flash_attention
attention_output = jax.vmap(splash_kernel)(query, key, value, segment_ids=decoder_segment_ids_tuple)
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2349, in __call__
return _splash_attention(
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 2314, in _splash_attention
return _splash_attention_custom(
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 1194, in _splash_attention_custom
return _splash_attention_forward( # pytype: disable=wrong-arg-types
File "/home/jan/googel_infrence/env/lib/python3.10/site-packages/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 1047, in _splash_attention_forward
q_sequence = jax.lax.broadcast_in_dim(
TypeError: broadcast_in_dim operand dimension sizes must either be 1, or be equal to their corresponding dimensions in the target broadcast shape; got operand of shape (2048,), target broadcast shape (1024, 128), broadcast_dimensions (0,)
2025-07-20 10:35:47.576639: I external/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:204] Shutting down PreemptionSyncManager...
2025-07-20 10:35:47.576826: I external/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:212] PreemptionSyncManager shut down.
2025-07-20 10:35:47.576834: I external/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:199] PreemptionSyncManager already shut down
2025-07-20 10:35:47.576868: I external/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:144] Cancelled call to retrieve preemption notice. This is expected upon program shutdown.
2025-07-20 10:35:47.576904: I external/xla/xla/tsl/distributed_runtime/preemption/preemption_sync_manager.cc:117] Preemption sync protocol cancelled by notifier: CANCELLED: Preemption notifier is being deleted.. This is expected during program shutdown.
2025-07-20 10:35:47.576981: I external/xla/xla/pjrt/distributed/client.cc:139] Distributed task shutdown initiated.
2025-07-20 10:35:47.576989: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc:614] Coordination agent has initiated Shutdown().
2025-07-20 10:35:47.577229: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1413] Barrier(Shutdown::14881467394488074014::0) has passed with status: OK
2025-07-20 10:35:47.577396: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:1795] Shutdown barrier in coordination service has passed.
2025-07-20 10:35:47.577413: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc:635] Coordination agent has successfully shut down.
2025-07-20 10:35:47.577455: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service.cc:761] /job:jax_worker/replica:0/task:0 has disconnected from coordination service.
2025-07-20 10:35:47.577497: I external/xla/xla/pjrt/distributed/client.cc:141] Distributed task shutdown result: OK
2025-07-20 10:35:47.577509: I external/xla/xla/pjrt/distributed/service.cc:123] Jax service shutting down
2025-07-20 10:35:47.577528: I external/xla/xla/tsl/distributed_runtime/coordination/coordination_service_agent.cc:423] Cancelling error polling because the service or the agent is shutting down.
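The final `TypeError` is raised by `jax.lax.broadcast_in_dim`, which can only broadcast an operand dimension whose size is 1 or exactly equal to the corresponding target dimension. The standalone sketch below reproduces the same message on CPU (no TPU or MaxText needed) using the shapes from the traceback. It assumes, without having checked the kernel source, that the `(2048,)` operand is a per-position vector sized to `max_target_length` and that the `(1024, 128)` target is the prefill length times 128 (presumably the TPU lane width the kernel tiles over). The variable names are hypothetical stand-ins.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the shapes in the traceback:
# a per-position vector built for max_target_length (2048) ...
positions_2048 = jnp.arange(2048, dtype=jnp.int32)
# ... and one built for max_prefill_predict_length (1024).
positions_1024 = jnp.arange(1024, dtype=jnp.int32)

TARGET_SHAPE = (1024, 128)  # (prefill sequence length, lanes)


def try_broadcast(operand):
    """Broadcast `operand` along dim 0 of TARGET_SHAPE; return the array or the error text."""
    try:
        return jax.lax.broadcast_in_dim(operand, TARGET_SHAPE, broadcast_dimensions=(0,))
    except TypeError as err:
        return str(err)


# 2048 != 1024, so this fails with the same message as in the log above.
mismatch = try_broadcast(positions_2048)
print(mismatch)

# Sizes agree, so every row i of the result is filled with the value i.
ok = try_broadcast(positions_1024)
print(ok.shape)
```

In other words, the kernel receives a query with 1024 positions but some per-position metadata built for 2048, which matches the mismatch described below.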
Environment Information
I am using the latest MaxText master, and I downloaded the Gemma checkpoint in MaxText format directly from Kaggle.
Start command:
python3 -m MaxText.maxengine_server \
MaxText/configs/base.yml \
tokenizer_path=~/googel_infrence/maxtext/assets/tokenizer.gemma \
load_parameters_path=~/googel_infrence/gemma_checkpoints/7b-it/ \
max_prefill_predict_length=1024 \
max_target_length=2048 \
model_name=gemma-7b \
ici_fsdp_parallelism=1 \
ici_autoregressive_parallelism=1 \
ici_tensor_parallelism=4 \
scan_layers=false \
weight_dtype=bfloat16 \
per_device_batch_size=11
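For reference, the shapes in the debug prints follow directly from these flags. Here is a minimal sketch of the arithmetic, assuming the prints were taken inside the sharded attention call, so the 16 query heads are already split across the 4 tensor-parallel devices while the batch is not split (all data/FSDP axes are 1 in the logged mesh). `expected_qkv_shape` is a hypothetical helper, not a MaxText function.

```python
# Values taken from the start command and the logged model config.
PER_DEVICE_BATCH_SIZE = 11
NUM_DEVICES = 4            # "Using devices: 4" on the v4-8
NUM_QUERY_HEADS = 16       # gemma-7b base_num_query_heads
HEAD_DIM = 256
ICI_TENSOR_PARALLELISM = 4
MAX_TARGET_LENGTH = 2048
MAX_PREFILL_PREDICT_LENGTH = 1024


def expected_qkv_shape(seq_len: int) -> tuple[int, int, int, int]:
    """Hypothetical helper: (global batch, heads per tensor shard, sequence, head_dim)."""
    global_batch = PER_DEVICE_BATCH_SIZE * NUM_DEVICES           # 11 * 4 = 44
    heads_per_shard = NUM_QUERY_HEADS // ICI_TENSOR_PARALLELISM  # 16 // 4 = 4
    return (global_batch, heads_per_shard, seq_len, HEAD_DIM)


# Full-length pass that succeeds (sized by max_target_length).
print(expected_qkv_shape(MAX_TARGET_LENGTH))           # (44, 4, 2048, 256)
# Prefill pass that fails (sized by max_prefill_predict_length).
print(expected_qkv_shape(MAX_PREFILL_PREDICT_LENGTH))  # (44, 4, 1024, 256)
```

Only the sequence dimension differs between the two passes, which is consistent with an attention mask built for 2048 positions being applied to a 1024-position prefill query.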
Additional Context
I added a few rather primitive debugging prints to the code to better understand what is going on, and there seems to be a size mismatch between the attention mask and the sequence dimension of the QKV inputs. Also note that the model seems to complete an initial forward pass successfully (the prints show q_seq_len 2048) before reaching the Pallas splash_attention_kernel, which then fails with the shape mismatch error (q_seq_len 1024).
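To make the suspected mismatch concrete, here is a small NumPy sketch of a shape guard that could be dropped in right before the kernel call while debugging. It assumes a dense causal mask of shape `(q_seq_len, kv_seq_len)` and queries laid out as `(batch, heads, seq, head_dim)`, as in the debug prints. `causal_mask` and `check_mask_matches_query` are hypothetical helpers, not MaxText or Pallas APIs, and the splash kernel uses its own mask objects rather than a dense array, so this only illustrates the idea.

```python
import numpy as np


def causal_mask(q_len: int, kv_len: int) -> np.ndarray:
    """Dense lower-triangular mask: query position i may attend to key positions <= i."""
    return np.tril(np.ones((q_len, kv_len), dtype=bool))


def check_mask_matches_query(mask: np.ndarray, query_shape: tuple) -> None:
    """Raise early, with a readable message, if mask and query sequence lengths disagree."""
    q_seq_len = query_shape[2]  # (batch, heads, seq, head_dim)
    if mask.shape[0] != q_seq_len:
        raise ValueError(
            f"mask built for q_seq_len={mask.shape[0]}, but query has q_seq_len={q_seq_len}"
        )


# Mask sized from max_target_length, which is what the traceback suggests happens.
mask_2048 = causal_mask(2048, 2048)
prefill_query_shape = (44, 4, 1024, 256)

mismatch_message = None
try:
    check_mask_matches_query(mask_2048, prefill_query_shape)
except ValueError as err:
    mismatch_message = str(err)
    print(mismatch_message)

# Building the mask from the query's own sequence length avoids the mismatch.
mask_prefill = causal_mask(prefill_query_shape[2], prefill_query_shape[2])
check_mask_matches_query(mask_prefill, prefill_query_shape)  # passes silently
```

A guard like this turns the opaque `broadcast_in_dim` error into a message naming both lengths, which makes it easier to find where the 2048-length mask is created for the prefill path.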
Thanks in advance for the help!