Need MPMD support on GPU to enable pipeline-parallel training of large-scale models #62736
Comments
Hi @MoFHeka , TensorFlow currently supports only data parallelism; model parallelism is not yet fully supported for training. However, with the help of DTensor we can achieve both data and model parallelism, as per the attached tutorial. As I understand it, pipeline parallelism is a hybrid of data and model parallelism.
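A minimal sketch (not from this thread) of what combined data and model parallelism with DTensor can look like, assuming 4 GPUs; the mesh shape and axis names ("batch", "model") are illustrative only:

```python
# Hypothetical sketch: data + model parallelism with DTensor on a 2x2 GPU mesh.
import tensorflow as tf
from tensorflow.experimental import dtensor

dtensor.initialize_accelerator_system("GPU")

# Logical 2x2 mesh: one axis for sharding the batch, one for sharding weights.
mesh = dtensor.create_mesh([("batch", 2), ("model", 2)], device_type="GPU")

# Activations: sharded along the batch dimension, replicated along "model".
act_layout = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)
# Weights: replicated along "batch", sharded along the model dimension.
w_layout = dtensor.Layout([dtensor.UNSHARDED, "model"], mesh)

x = dtensor.relayout(tf.random.normal([8, 128]), act_layout)
w = dtensor.DVariable(dtensor.relayout(tf.random.normal([128, 256]), w_layout))

# The matmul runs SPMD-style across the whole mesh; there is no explicit
# per-stage send/recv schedule here, which is why this alone does not give
# pipeline parallelism.
y = tf.matmul(x, w)
```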
@SuryanarayanaY But apparently DTensor can't achieve pipeline parallelism. Pipeline parallelism is based on send and receive collective operators.
@MoFHeka I have always felt that XLA and TensorFlow/JAX actually support many hidden features but never mention them or write documentation for them :)).
@dathudeptrai I really agree with that. Large projects often lead to difficulties in project management. |
Hi @MoFHeka , I doubt it; correct me if I am wrong. From the TF 2.14 code in auto_sharding.cc I can see that XLA supports SPMD, which covers the data parallelism already supported by TF. Can you point to exactly which part makes you think XLA supports model parallelism or MPMD? This would help us escalate the issue to an SME and get confirmation. Thanks!
@SuryanarayanaY Please check these comments. They say "This can result in a strategy similar to ZeRO stage 3. NOTE: The combination of this branch with pipeline parallel is not tested." tensorflow/tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc Line 3249 in 99d80a9
And please check here: TPU has supported MPMD for a long time. It says "If any of the inputs/outputs have maximal sharding, then fallback to MPMD." tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/tpu_sharding_identification_pass.cc Line 580 in d032157
@MoFHeka I suggest you use JAX instead. Things like this, flash-attention, FP8 training, int8 training, etc. are all available in JAX with native support from XLA. TF also uses XLA but is rather hard to customize.
@dathudeptrai Unfortunately, JAX also doesn't support many features, such as sequence parallelism. And at the XLA level, even with JAX, many features are actually only supported on TPU.
@dathudeptrai This really surprised me. I always thought it was difficult to shard along the sequence dimension in JAX's sharding process.
@MoFHeka Yeah. Generally speaking, coding in JAX is harder than PyTorch and a bit easier than TF. For low-level customization, I think JAX is better. Performance-wise, my experiments showed that JAX is better than PyTorch :). Even DeepSpeed + Flash-Attention-2 + PyTorch is still not as good as JAX :)). You can refer to some open-source projects to see how you can customize parallel training in JAX, for example https://github.com/alpa-projects/alpa.
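For reference, a minimal sketch (not from the projects linked above) of how parallelism is typically customized in JAX with explicit sharding annotations; the mesh shape and axis names ("data", "model") are illustrative assumptions:

```python
# Hypothetical sketch: explicit tensor sharding in JAX with jit + NamedSharding.
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Arrange 4 devices as a 2x2 mesh: one axis for data, one for model parallelism.
mesh = Mesh(np.array(jax.devices()[:4]).reshape(2, 2), axis_names=("data", "model"))

# Shard the batch over "data" and the weight's output dimension over "model".
x_sharding = NamedSharding(mesh, P("data", None))
w_sharding = NamedSharding(mesh, P(None, "model"))

x = jax.device_put(jnp.ones((8, 128)), x_sharding)
w = jax.device_put(jnp.ones((128, 256)), w_sharding)

@jax.jit
def layer(x, w):
    # XLA's SPMD partitioner inserts the required collectives automatically.
    return jnp.dot(x, w)

y = layer(x, w)
```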
@dathudeptrai Unfortunately, due to the lack of sequence parallelism, the compute utilization of Alpa is lower than that of Megatron with the same tensor-parallelism optimization, because Alpa uses too much device memory with TP, which forces a smaller batch size. Besides, CTR models really can't be trained with JAX: the JAX ecosystem for online serving, data processing, and other components (such as Keras) is way too far behind TF.
@MoFHeka Why not use both TF and JAX at the same time :)? Nowadays you can call JAX code from TF code. Also please check out some advanced attention techniques recently introduced in JAX (https://github.com/lhao499/large-sequence-modeling/tree/main). I personally think the biggest problem with both TF and JAX is documentation :)).
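As an illustration of calling JAX from TF, a minimal sketch using jax2tf; the function itself is a toy example, not from this issue:

```python
# Hypothetical sketch: wrapping a JAX function so it is callable from TensorFlow.
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def jax_fn(x):
    return jnp.tanh(x) * 2.0

# Convert the JAX function into a TF-callable; it can be wrapped in tf.function
# and compiled with XLA like any other TF computation.
tf_fn = tf.function(jax2tf.convert(jax_fn), jit_compile=True)

y = tf_fn(tf.constant([0.1, 0.2, 0.3]))
```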
@dathudeptrai Yes, that's right, JAX kernels can be used in TF code, although there's no big difference between JAX kernels and Keras layers with XLA. Even in recent updates, DTensor is used to support tensor parallelism. But for training CTR models, one of the biggest usage scenarios of TF, what is needed most is pipeline parallelism.
@SuryanarayanaY Hi~ Is there any way to implement pipeline-parallel training in TensorFlow?
@SuryanarayanaY Any progress? TensorFlow seems to be way behind in its competition with PyTorch...
Issue type
Feature Request
Have you reproduced the bug with TensorFlow Nightly?
Yes
Source
binary
TensorFlow version
tf 2.15
Custom code
No
OS platform and distribution
No response
Mobile device
No response
Python version
No response
Bazel version
No response
GCC/compiler version
No response
CUDA/cuDNN version
No response
GPU model and memory
No response
Current behavior?
Pipeline parallelism has been implemented in PyTorch for a long time now. It is very useful for training a CTR model with an embedding pipeline, or for training a large language model across two machines.
Standalone code to reproduce the issue
Maybe an easy send/recv construct like TPU with XLA? A rough illustration of the idea is sketched below.
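A minimal sketch of the send/recv-style pipelining idea, written in JAX only because its ppermute (which lowers to XLA CollectivePermute) makes the point-to-point transfer explicit; stage shapes, weights, and the round-robin schedule are illustrative assumptions, not a proposed TF API:

```python
# Hypothetical sketch: each device holds one pipeline stage and forwards its
# activations to the next device with a point-to-point collective.
import jax
import jax.numpy as jnp

n = jax.local_device_count()  # one pipeline stage per device

def stage(x, w):
    # Each stage applies its own weights, then sends the result downstream.
    y = jnp.tanh(x @ w)
    return jax.lax.ppermute(y, axis_name="stage",
                            perm=[(i, (i + 1) % n) for i in range(n)])

xs = jnp.ones((n, 4, 8))        # per-stage input activations
ws = jnp.ones((n, 8, 8)) * 0.1  # per-stage weights

out = jax.pmap(stage, axis_name="stage")(xs, ws)
```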
Relevant log output
No response