In this second part of my series on building LLMs from scratch with the JAX ecosystem, I’ll go over my understanding of parallel training and how I implemented it in JAX. Most of my theoretical understanding came from the incredibly helpful tutorial How To Scale Your Model (the scaling book), which I highly recommend to anyone. For someone like me whose day job doesn’t involve such low-level ML systems engineering and optimization, there is something new to learn every time I read it.
Theory
The reason for needing parallel training (and parallel model deployment) across multiple devices is basically two-fold.
- Most modern, realistic LLMs are too large to be loaded onto a single device. Even when the model weights fit into VRAM, actually running training takes a lot more memory because of optimizer states, activations, and the intermediate results saved for computing gradients later (even with gradient checkpointing, this part can still be significant).
- We want to efficiently use the compute (FLOPS) from multiple devices to speed up the training.
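To put rough numbers on the first point, here is a back-of-the-envelope sketch. The 7B parameter count, bf16 weights and gradients, two fp32 Adam moments, and the 8-device split are all illustrative assumptions, not measurements from this series. Activations are left out because they depend on batch size, sequence length, and checkpointing.

```python
GIB = 2**30


def training_state_bytes(num_params, weight_bytes=2, grad_bytes=2, optimizer_bytes=8):
    """Bytes needed for weights, gradients, and optimizer state.

    Defaults assume bf16 weights and gradients (2 bytes each) and two fp32
    Adam moments (2 * 4 = 8 bytes) per parameter. Activations are excluded.
    """
    return num_params * (weight_bytes + grad_bytes + optimizer_bytes)


num_params = 7_000_000_000  # assumed model size, for illustration only

weights_gib = num_params * 2 / GIB
state_gib = training_state_bytes(num_params) / GIB
# If weights, gradients, and optimizer state were all sharded evenly over 8 devices.
per_device_gib = state_gib / 8

print(f"weights only:            {weights_gib:.1f} GiB")
print(f"weights + grads + Adam:  {state_gib:.1f} GiB")
print(f"per device (8 shards):   {per_device_gib:.1f} GiB")
```

Under these assumptions the training state is six times the size of the weights alone, before counting any activations.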
For training LLMs, five parallelisms are commonly used.
- Data parallelism. This shards the hidden states across the batch dimension. It can be further improved by sharding certain model weights along the same mesh axis (devices are organized into a mesh with one or more dimensions). This is what’s called Fully-Sharded Data Parallelism (FSDP) or ZeRO-3. Because the model weights are sharded, during the forward pass we need to all-gather the weights just in time to compute the activations. However, with careful design, it is possible to keep the training in the compute bound regime and scale up efficiently with more devices.
- Tensor/model/Megatron parallelism. This shards the hidden states across the activation dimension, and some model weight dimensions across the same mesh axis. In the forward pass, we then need to all-gather/reduce-scatter activations (as opposed to weights in FSDP). This puts a different type of constraint on the sharding strategy in order to stay compute bound.
- Sequence/context parallelism. This shards the hidden states across the sequence length dimension. For the dense/MLP layers, it is relatively easy to understand. This is pretty much the same as the data parallelism. But for the self-attention layer, we need to consider cross-token interactions, so it is a lot harder to design and implement.
- Pipeline parallelism. This shards the model and hidden states across the layer dimension. It also requires fairly sophisticated orchestration design to avoid idle times (bubbles) in the full forward and backward pass. I’m not sure if this is commonly used in academia or the open source community. It seems a lot of engineering and compute resources are required to make good use of it. One good source of learning how frontier labs might be doing this is the DeepSeek-V3 paper.
- Expert parallelism. This shards the experts in Mixture-of-Experts models across multiple devices. I didn’t fully dig into this because of the time and compute resources it would take, but the DeepSeek-V3 paper is once again an incredible source on how this could be done.
Since the latter three cost a lot more than the first two, for my practice in this series I only worked out the details of FSDP and TP and implemented those.
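To make the first two strategies concrete, the sketch below computes what each device actually holds under a given sharding. The helper `local_shape`, the 4 x 2 mesh, and the array sizes are hypothetical choices for illustration; the axis names "data" and "model" match the ones used in the sharding config later in this post.

```python
def local_shape(global_shape, spec, mesh):
    """Per-device shape of an array sharded by `spec` over `mesh`.

    `spec` has one entry per array dimension: a mesh axis name, or None for
    "not sharded". `mesh` maps mesh axis names to their sizes.
    """
    if len(spec) != len(global_shape):
        raise ValueError("spec needs one entry per array dimension")
    shape = []
    for dim, axis in zip(global_shape, spec):
        parts = 1 if axis is None else mesh[axis]
        if dim % parts:
            raise ValueError(f"dimension {dim} is not divisible by {parts}")
        shape.append(dim // parts)
    return tuple(shape)


mesh = {"data": 4, "model": 2}     # 8 devices arranged as a 4 x 2 mesh
B, T, D, F = 32, 128, 512, 2048    # batch, sequence, model dim, FFN dim

# Plain data parallelism: hidden states split over the batch, weights replicated.
dp_hidden = local_shape((B, T, D), ("data", None, None), mesh)
dp_weight = local_shape((D, F), (None, None), mesh)

# FSDP: the weight is additionally split over the same "data" axis.
fsdp_weight = local_shape((D, F), ("data", None), mesh)

# FSDP + TP: hidden states and weights are also split over the "model" axis.
tp_hidden = local_shape((B, T, D), ("data", None, "model"), mesh)
tp_weight = local_shape((D, F), ("data", "model"), mesh)

for name, shape in [("DP hidden", dp_hidden), ("DP weight", dp_weight),
                    ("FSDP weight", fsdp_weight), ("FSDP+TP hidden", tp_hidden),
                    ("FSDP+TP weight", tp_weight)]:
    print(f"{name:>15}: {shape}")
```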
FSDP
For FSDP, because we need to all-gather model weights (and reduce-scatter gradients), this puts a constraint on the batch size in order to stay in the compute bound regime. For the detailed derivation, see this section in the scaling book. Here I collect the sharding scheme of the MLP layer and the attention layer in the forward pass, using the same notation as in the scaling book. The backward pass can be derived from there, and automatic differentiation in JAX should in principle handle that automatically.
Since the hidden states don’t shard on the activation dimension, there is no need to shard the embedding matrix, or the weights of the RMSNorm layers. The RoPE matrix is also fairly small and there is no need to shard it.
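The batch size constraint mentioned above can be sketched with a simple roofline argument for one MLP block, following my reading of the scaling book. The function names and the hardware numbers below are placeholders rather than measured values, and the batch is counted in tokens.

```python
def fsdp_min_tokens_per_device(peak_flops, link_bandwidth):
    """Per-device token count above which FSDP stays compute bound.

    Roofline sketch for one MLP block with bf16 weights W_in[D, F], W_out[F, D],
    a token batch B, and X FSDP shards:
      compute time = 4*B*D*F / (X * peak_flops)   (two matmuls, 2*B*D*F FLOPs each)
      comms time   = 4*D*F / link_bandwidth       (all-gather two 2*D*F-byte weights)
    Compute dominates when B / X > peak_flops / link_bandwidth.
    """
    return peak_flops / link_bandwidth


def fsdp_is_compute_bound(global_batch_tokens, num_fsdp_shards, peak_flops, link_bandwidth):
    """True if the per-device token count exceeds the roofline threshold."""
    per_device = global_batch_tokens / num_fsdp_shards
    return per_device > fsdp_min_tokens_per_device(peak_flops, link_bandwidth)


# Placeholder hardware numbers: substitute your accelerator's peak bf16 FLOP/s
# and its interconnect bandwidth in bytes/s.
PEAK_FLOPS = 4.59e14
LINK_BANDWIDTH = 1.8e11

threshold = fsdp_min_tokens_per_device(PEAK_FLOPS, LINK_BANDWIDTH)
print(f"need more than {threshold:.0f} tokens per device")
print(fsdp_is_compute_bound(2**20, 256, PEAK_FLOPS, LINK_BANDWIDTH))   # 4096 tokens per device
print(fsdp_is_compute_bound(2**20, 1024, PEAK_FLOPS, LINK_BANDWIDTH))  # 1024 tokens per device
```

Note that the model dimensions D and F cancel out: in this simplified picture only the per-device batch and the hardware ratio matter.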
MLP/SwiGLU
The MLP layer can in theory be abstracted as two matrix multiplications:
More specifically, the following operations are performed:
All these operations can be supported by explicit sharding in JAX out of the box as explained later. I just have to define the weight matrix shardings and the matmul output shardings.
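For reference, here is a reconstruction of the FSDP forward pass of the MLP layer in the scaling book's notation. It is a sketch under a few assumptions: X is the FSDP (data) mesh axis, B is the batch, D the model dimension, F the feed-forward dimension, a subscript marks the mesh axis a dimension is sharded over, and SwiGLU's gate and up projections are collapsed into a single input matrix.

```latex
\begin{aligned}
&\text{In}[B_X, D] \cdot_D W_\text{in}[D_X, F] \cdot_F W_\text{out}[F, D_X]
  \rightarrow \text{Out}[B_X, D] \\[4pt]
&1.\quad W_\text{in}[D, F] = \operatorname{AllGather}_X\big(W_\text{in}[D_X, F]\big) \\
&2.\quad \text{Tmp}[B_X, F] = \text{In}[B_X, D] \cdot_D W_\text{in}[D, F] \\
&3.\quad W_\text{out}[F, D] = \operatorname{AllGather}_X\big(W_\text{out}[F, D_X]\big) \\
&4.\quad \text{Out}[B_X, D] = \text{Tmp}[B_X, F] \cdot_F W_\text{out}[F, D]
\end{aligned}
```

The only communication is the two weight all-gathers; the activations never leave their device.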
Attention
Similarly, the operations and shardings of the self-attention layer can be broken down into the following steps.
- In-projection into QKV: .
- Split into Q, K, V.
- ), where is the number of heads and is the head dimension.
- Dot product attention.
- Output projection.
So as we can see, FSDP can also be implemented by explicit sharding relatively easily.
FSDP + TP
When using TP, because we need to all-gather hidden state activations, this puts a constraint on the max number of TP shards in order to stay in the compute bound regime. For the detailed derivation, see this section in the scaling book. In my exercise, I didn’t do separate TP runs. Instead I combined FSDP and TP directly to maximize the cost effectiveness of my hobby training fund since the goal is to learn how to implement and optimize these parallel training algorithms.
Similar to FSDP, there is no need to shard the RMSNorm layer or the RoPE matrix, but the embedding matrix can now be sharded along the activation dimension, same as the hidden states.
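The limit on the number of TP shards can be sketched with the same kind of roofline argument, here for TP on its own and for a single MLP block, following my reading of the scaling book. The function name, the FFN dimension, and the hardware numbers are placeholders for illustration.

```python
def tp_max_shards(ffn_dim, peak_flops, link_bandwidth):
    """Upper bound on the TP axis size Y that keeps one MLP block compute bound.

    Roofline sketch with bf16 activations, token batch B, model dim D, FFN dim F:
      compute time = 4*B*D*F / (Y * peak_flops)   (two matmuls, 2*B*D*F FLOPs each)
      comms time   = 4*B*D / link_bandwidth       (all-gather + reduce-scatter of
                                                   2*B*D-byte activations)
    Compute dominates when Y < F * link_bandwidth / peak_flops.
    """
    return ffn_dim * link_bandwidth / peak_flops


# Placeholder hardware numbers: substitute your accelerator's peak bf16 FLOP/s
# and its interconnect bandwidth in bytes/s.
PEAK_FLOPS = 4.59e14
LINK_BANDWIDTH = 1.8e11

limit = tp_max_shards(ffn_dim=16384, peak_flops=PEAK_FLOPS, link_bandwidth=LINK_BANDWIDTH)
print(f"TP axis size should stay below about {limit:.1f}")
```

Unlike the FSDP case, the batch size cancels out here and the bound depends on the feed-forward dimension instead, which is why wider models tolerate more tensor parallelism.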
MLP/SwiGLU
Similar to FSDP, here are the operations and shardings of the MLP layer in the forward pass.
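As a reference, here is a reconstruction of the combined forward pass in the scaling book's notation. It assumes X is the FSDP (data) mesh axis and Y is the TP (model) mesh axis, with hidden states sharded as [B_X, D_Y] and weights as W_in[D_X, F_Y] and W_out[F_Y, D_X], which is consistent with the sharding config shown later in this post. The marker {U_Y} denotes a result that is still unreduced (a partial sum) over Y.

```latex
\begin{aligned}
&\text{In}[B_X, D_Y] \cdot_D W_\text{in}[D_X, F_Y] \cdot_F W_\text{out}[F_Y, D_X]
  \rightarrow \text{Out}[B_X, D_Y] \\[4pt]
&1.\quad \text{In}[B_X, D] = \operatorname{AllGather}_Y\big(\text{In}[B_X, D_Y]\big) \\
&2.\quad W_\text{in}[D, F_Y] = \operatorname{AllGather}_X\big(W_\text{in}[D_X, F_Y]\big) \\
&3.\quad \text{Tmp}[B_X, F_Y] = \text{In}[B_X, D] \cdot_D W_\text{in}[D, F_Y] \\
&4.\quad W_\text{out}[F_Y, D] = \operatorname{AllGather}_X\big(W_\text{out}[F_Y, D_X]\big) \\
&5.\quad \text{Out}[B_X, D]\{U_Y\} = \text{Tmp}[B_X, F_Y] \cdot_F W_\text{out}[F_Y, D] \\
&6.\quad \text{Out}[B_X, D_Y] = \operatorname{ReduceScatter}_Y\big(\text{Out}[B_X, D]\{U_Y\}\big)
\end{aligned}
```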
Attention
Here are the specific steps and shardings for the attention layer in the forward pass.
- In-projection into QKV: .
- , where
- Split into Q, K, V.
- )
- Dot product attention.
- Output projection.
This is also relatively straightforward to implement in JAX. One just needs to be careful in step 2: the size of the Y axis should be divisible by 3, so that the sharded dimension can be easily moved to the number-of-heads dimension.
Implementation
JAX has fairly strong support for sharding. There are three modes available:
- Auto sharding. This basically lets the XLA compiler take the reins. The JAX code mostly doesn’t specify the sharding, except for user-provided constraints via with_sharding_constraint.
- Explicit sharding. In this case, data shardings are explicitly defined as part of the data’s JAX types. When creating a mesh, its axes also need to be of the explicit type (although the latest versions, such as 0.9 and later, have that by default when using jax.make_mesh to create meshes).
- Manual sharding. This gives maximal flexibility in deciding how all the sharding and communication collectives are done.
For more details, see this official tutorial.
In my exercise, I used explicit sharding. Most things just work out of the box. I grouped all the sharding-related configurations (weight sharding, output sharding, etc.) into custom classes that are associated with the layers (nnx.Module’s). The sharding strategies are passed to the layer constructors and later used in __call__, so that the signatures and implementations of __call__ can mostly remain the same as in the unsharded version.
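The pattern can be sketched roughly as below. This is a simplified illustration, not the actual classes from the repository: the field names `weight` and `out` match the config used in this post, but the `Linear` class is a plain-Python stand-in for an nnx.Module, and `output_spec` is a hypothetical helper that only reports which spec the layer would request.

```python
from dataclasses import dataclass

try:
    from jax.sharding import PartitionSpec as P
except ImportError:
    # Fallback so the sketch still runs where JAX is not installed.
    def P(*axes):
        return tuple(axes)


@dataclass(frozen=True)
class LinearSharding:
    """Sharding config for a linear layer: one spec for the weight, one for the output."""
    weight: object
    out: object


class Linear:
    """Stand-in for an nnx.Module; only the sharding plumbing is shown."""

    def __init__(self, in_dim, out_dim, sharding=None):
        self.in_dim = in_dim
        self.out_dim = out_dim
        # Stored once at construction, so __call__ keeps its unsharded signature.
        self.sharding = sharding

    def output_spec(self):
        # The spec a real __call__ would request for its matmul output.
        # None means the layer runs unsharded.
        return None if self.sharding is None else self.sharding.out


up_projection = Linear(
    512, 2048,
    sharding=LinearSharding(weight=P("data", "model"), out=P("data", None, "model")),
)
unsharded = Linear(512, 2048)

print(up_projection.output_spec())
print(unsharded.output_spec())
```

Keeping the sharding in a separate config object means the same layer code can be reused across strategies by swapping the config.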
With this simple design in place, the FSDP + TP sharding can be implemented as the following:
    FSDP_TP_SHARDING = TransformerLmSharding(
        token_embeddings=EmbeddingSharding(
            embedding_matrix=P(None, "model"), out=P("data", None, "model")
        ),
        transformer_blocks=TransformerBlockSharding(
            rms_norm_pre_attn=RMSNormSharding(weight=P(None)),
            attn=MultiHeadSelfAttentionSharding(
                combined_in_projection=LinearSharding(
                    weight=P("data", "model"), out=P("data", None, "model")
                ),
                out_projection=LinearSharding(
                    weight=P("model", "data"), out=P("data", None, "model")
                ),
            ),
            rms_norm_pre_ff=RMSNormSharding(weight=P(None)),
            ffn=SwiGLUSharding(
                up_projection=LinearSharding(
                    weight=P("data", "model"), out=P("data", None, "model")
                ),
                down_projection=LinearSharding(
                    weight=P("model", "data"), out=P("data", None, "model")
                ),
            ),
        ),
        ln_final=RMSNormSharding(weight=P(None)),
        lm_head=LinearSharding(weight=P("model", None), out=P("data", None, None)),
    )

Simulate multi-device training on CPU
Another interesting thing to note is how to simulate multi-device training runs on a CPU-only machine. This is very useful, for example, for verifying that the sharding and other aspects of the training pipeline actually work without running on expensive multi-GPU pods. It can be done by setting an XLA flag on the command line before starting the pipeline, such as the following:
    …
    SHARDING_STRATEGY="fsdp_tp"
    XLA_FLAGS="--xla_force_host_platform_device_count=8"
    TRAIN_CMD="uv run llm_with_jax_practice/train_main.py"

    XLA_FLAGS="${XLA_FLAGS}" ${TRAIN_CMD} \
      --checkpoint_dir="${CHECKPOINT_DIR}" \
      …
      --sharding_strategy="${SHARDING_STRATEGY}"

One can also set the flag in Python, but that has to be done before importing JAX, which basically means the flag setting has to be hardcoded (since command-line flag parsing typically happens in main). That would require separate pipeline implementations for different use cases, which looks very bad.
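For quick experiments in a notebook or a test file, the Python route can still be handy. Below is a minimal sketch; `with_host_device_count` is a hypothetical helper, not part of JAX, that sets the device count flag while keeping any other XLA flags that are already present.

```python
import os

DEVICE_COUNT_FLAG = "--xla_force_host_platform_device_count"


def with_host_device_count(xla_flags, num_devices):
    """Return `xla_flags` with the host device count flag set to `num_devices`.

    An existing occurrence of the flag is replaced; all other flags are kept.
    """
    kept = [
        flag for flag in xla_flags.split()
        if not flag.startswith(DEVICE_COUNT_FLAG + "=")
    ]
    kept.append(f"{DEVICE_COUNT_FLAG}={num_devices}")
    return " ".join(kept)


# This must run before JAX is imported anywhere in the process.
os.environ["XLA_FLAGS"] = with_host_device_count(os.environ.get("XLA_FLAGS", ""), 8)
print(os.environ["XLA_FLAGS"])
```

After this, importing JAX on a CPU-only machine should report 8 devices from jax.devices().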
Sharp bit: black box operators erasing sharding info
Normally the sharding info of arrays is kept and transformed in reasonable ways throughout the training step. But I found that when an operator is a black box to JAX, such as FlashAttention implemented in CUDA, the output of such a “black box operator” (in the eyes of JAX) can come back with its sharding info erased. This is what happens with jax.nn.dot_product_attention when implementation is cudnn, which is an implementation of FlashAttention.
This causes trouble in downstream operations, namely the out projection in the self-attention block. With the CUDNN implementation and FSDP+TP sharding, the out projection fails with this error message:
`jax._src.source_info_util.JaxStackTraceBeforeTransformation: jax._src.core.ShardingTypeError: dot_general requires contracting dimensions to have consistent sharding, got ('model',) and ('data',).`
The fix is also relatively easy: explicitly set the sharding of the CUDNN implementation's output using jax.sharding.reshard, as is done here.
Closing
With the implementation above, I can verify it actually works by running FSDP+TP locally on my laptop (integration test) and show quick loss curves like this (from this run)

Note that I’ve glossed over a great deal of detail in this post, in particular how to compute the constraints on the batch size and the tensor-parallel axis size needed to stay in the compute bound regime. For those details, I highly recommend the training section of the scaling book.
Next, we can move to muP and scaling law ✌️