
Need MPMD support on GPU for pipeline-parallel training of large-scale models #62736

Open
MoFHeka opened this issue Jan 4, 2024 · 16 comments
Assignees
Labels
comp:dist-strat Distribution Strategy related issues stat:awaiting tensorflower Status - Awaiting response from tensorflower TF 2.15 For issues related to 2.15.x type:feature Feature requests

Comments

@MoFHeka
MoFHeka commented Jan 4, 2024

Issue type

Feature Request

Have you reproduced the bug with TensorFlow Nightly?

Yes

Source

binary

TensorFlow version

tf 2.15

Custom code

No

OS platform and distribution

No response

Mobile device

No response

Python version

No response

Bazel version

No response

GCC/compiler version

No response

CUDA/cuDNN version

No response

GPU model and memory

No response

Current behavior?

Pipeline parallelism has been available in PyTorch for a long time. It is very useful for training a CTR model with an embedding pipeline, or for training a large language model across two machines.

Standalone code to reproduce the issue

Maybe a simple send/recv construct, like the one TPUs use with XLA?
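For illustration, a minimal sketch of the kind of two-stage GPU pipeline this request is about, written with plain device placement in today's TF so the cross-device copy plays the role of send/recv (the device names and layer sizes are assumptions):

```python
import tensorflow as tf

# Hedged sketch of a two-stage pipeline split across two GPUs.
# The cross-device copy of `h` stands in for an explicit send/recv edge.
# "/GPU:0", "/GPU:1" and the layer sizes are assumptions for illustration.
stage0 = tf.keras.Sequential([tf.keras.layers.Dense(1024, activation="relu")])
stage1 = tf.keras.Sequential([tf.keras.layers.Dense(10)])

@tf.function
def pipeline_step(x):
    with tf.device("/GPU:0"):
        h = stage0(x)      # stage 0 forward pass on the first device
    with tf.device("/GPU:1"):
        y = stage1(h)      # `h` is transferred GPU:0 -> GPU:1 here
    return y

y = pipeline_step(tf.random.normal([32, 512]))
```

A real MPMD pipeline would additionally need micro-batching and a schedule (e.g. GPipe or 1F1B), which is exactly what is missing today.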

Relevant log output

No response

@google-ml-butler google-ml-butler bot added the type:feature Feature requests label Jan 4, 2024
@SuryanarayanaY SuryanarayanaY added TF 2.15 For issues related to 2.15.x comp:dist-strat Distribution Strategy related issues labels Jan 5, 2024
@SuryanarayanaY
Collaborator

Hi @MoFHeka ,

TensorFlow currently supports data parallelism only; model parallelism is not yet fully supported for training. But with the help of DTensor we can achieve both data and model parallelism, as per the attached tutorial.

As per my understanding, pipeline parallelism is a hybrid approach combining data and model parallelism.
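For reference, a minimal sketch of the DTensor approach from that tutorial, sharding data over a 1-D "batch" mesh (the two logical CPU devices and the tensor shape are just illustrative assumptions):

```python
import tensorflow as tf
from tensorflow.experimental import dtensor

# Split the single physical CPU into 2 logical devices so the sketch runs on one host.
cpu = tf.config.list_physical_devices("CPU")[0]
tf.config.set_logical_device_configuration(
    cpu, [tf.config.LogicalDeviceConfiguration()] * 2)

# 1-D mesh with a "batch" dimension of size 2 -> data parallelism.
mesh = dtensor.create_mesh([("batch", 2)], devices=["CPU:0", "CPU:1"])

# Shard dimension 0 of the tensor over "batch", replicate dimension 1.
layout = dtensor.Layout(["batch", dtensor.UNSHARDED], mesh)
x = dtensor.call_with_layout(tf.zeros, layout, shape=(8, 16))
print(dtensor.fetch_layout(x))
```

Sharding a weight dimension instead of the batch dimension gives model (tensor) parallelism, but there is no pipeline schedule here.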

@MoFHeka
Author
MoFHeka commented Jan 5, 2024

@SuryanarayanaY But apparently DTensor cannot provide pipeline parallelism. Pipeline parallelism is based on send and receive collective operators.
As tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc shows, it seems that ZeRO stage 3 and pipeline parallelism are already supported by XLA.

@dathudeptrai

@MoFHeka I always felt that XLA and TensorFlow/JAX actually support many hidden features but never mention them or write documentation for them :)).

@MoFHeka
Author
MoFHeka commented Jan 6, 2024

@dathudeptrai I really agree with that. Large projects often lead to difficulties in project management.

@SuryanarayanaY
Collaborator

As tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc shows, it seems that ZeRO stage 3 and pipeline parallelism are already supported by XLA.

Hi @MoFHeka ,

I doubt it; correct me if I am wrong. From the TF 2.14 code in auto_sharding.cc I can see that XLA supports SPMD, which covers the data parallelism that TF already supports. Can you point out exactly which part you are referring to that makes you think XLA supports model parallelism or MPMD? This may help us escalate the issue to an SME and get confirmation. Thanks!

@SuryanarayanaY SuryanarayanaY added the stat:awaiting tensorflower Status - Awaiting response from tensorflower label Jan 8, 2024
@MoFHeka
Author
MoFHeka commented Jan 8, 2024

As tensorflow/compiler/xla/hlo/experimental/auto_sharding/auto_sharding.cc shows, it seems that ZeRO stage 3 and pipeline parallelism are already supported by XLA.

Hi @MoFHeka ,

I doubt it; correct me if I am wrong. From the TF 2.14 code in auto_sharding.cc I can see that XLA supports SPMD, which covers the data parallelism that TF already supports. Can you point out exactly which part you are referring to that makes you think XLA supports model parallelism or MPMD? This may help us escalate the issue to an SME and get confirmation. Thanks!

@SuryanarayanaY Please check these comments. They say "This can result in a strategy similar to ZeRO stage 3. NOTE: The combination of this branch with pipeline parallel is not tested."

// NOTE: The combination of this branch with pipeline parallel is not

And please check here: TPUs have supported MPMD for a long time. It says "If any of the inputs/outputs have maximal sharding, then fallback to MPMD."

// maximal sharding, then fallback to MPMD. Also fall back if any of the

@dathudeptrai
dathudeptrai commented Jan 15, 2024

@MoFHeka I suggest you use JAX instead. Something like this, or flash-attention, FP8 training, int8 training, ... are all available in JAX with native support from XLA. TF also uses XLA but it is kind of hard to customize.

@MoFHeka
Author
MoFHeka commented Jan 15, 2024

@MoFHeka I suggest you use JAX instead. Something like this, or flash-attention, FP8 training, int8 training, ... are all available in JAX with native support from XLA. TF also uses XLA but it is kind of hard to customize.

@dathudeptrai Unfortunately, JAX also doesn't support many features, such as sequence parallelism. And at the XLA level, even with JAX, many features are actually only supported on TPU.
One more important thing: I can't train CTR models with JAX, which lacks too many things.

@dathudeptrai

@MoFHeka NVIDIA/TransformerEngine#602

@MoFHeka
Author
MoFHeka commented Jan 15, 2024

@MoFHeka NVIDIA/TransformerEngine#602

@dathudeptrai This really surprised me. I always thought it was difficult to shard along the sequence dimension in JAX's sharding process.
But there is one thing I am not sure about: if I am going to train an LLM with pipeline parallelism in JAX, should I use a Ray-based engine like Alpa or a JAX-native one? JAX doesn't seem to have a good software library that supports all of these accelerations right now.

@dathudeptrai
dathudeptrai commented Jan 15, 2024

@MoFHeka Yeah. Generally speaking, coding in JAX is harder than PyTorch and a bit easier than TF. For low-level customization, I think JAX is better. Performance-wise, my experiments showed that JAX is better than PyTorch :). Even DeepSpeed + Flash-Attention-2 + PyTorch is still not as good as JAX :)).

You can refer to some open-source projects to see how you can customize parallel training in JAX, e.g. https://github.com/alpa-projects/alpa, which I call DeepSpeed for JAX :).
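To show the kind of manual control JAX gives (independent of Alpa itself), here is a hedged sketch of a hand-rolled two-stage pipeline using explicit device placement; the shapes and the assumption of at least two local devices are illustrative only:

```python
import jax
import jax.numpy as jnp

# Hedged sketch: put each stage's weights and compute on a different device and
# move the activation between stages by hand. With a single local device both
# stages simply land on the same device; shapes are made up for illustration.
devs = jax.local_devices()
w0 = jax.device_put(jnp.ones((16, 32)), devs[0])
w1 = jax.device_put(jnp.ones((32, 4)), devs[-1])

def stage0(x):
    return jnp.tanh(x @ w0)

def stage1(h):
    return h @ w1

x = jax.device_put(jnp.ones((8, 16)), devs[0])
h = stage0(x)                    # runs on devs[0]
h = jax.device_put(h, devs[-1])  # hand-rolled "send/recv" between stages
y = stage1(h)                    # runs on devs[-1]
```

Alpa automates this kind of inter-operator partitioning plus the micro-batch schedule.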

@MoFHeka
Author
MoFHeka commented Jan 16, 2024

@dathudeptrai Unfortunately, due to the lack of sequence parallelism, the compute utilization of Alpa is lower than that of Megatron with the same tensor parallelism optimization, because Alpa uses too much device memory with TP, which forces a smaller batch size.

Besides, CTR models really can't be trained with JAX. The JAX ecosystem of online serving, data processing, and other components (such as Keras) is way too far behind TF.

@dathudeptrai
dathudeptrai commented Jan 16, 2024

@MoFHeka Why not use both TF and JAX at the same time :)? You can call JAX code from TF code nowadays. Also please check out some advanced attention techniques recently introduced in JAX (https://github.com/lhao499/large-sequence-modeling/tree/main).

I personally think the biggest problem with both TF and JAX is documentation :)).
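One concrete way to call JAX from TF today is jax2tf; a minimal sketch (the function and shapes are made up for illustration):

```python
import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

# A JAX function wrapped so it can be called from TF code (e.g. inside a Keras layer).
def jax_gelu(x):
    return 0.5 * x * (1.0 + jnp.tanh(0.79788456 * (x + 0.044715 * x ** 3)))

tf_gelu = jax2tf.convert(jax_gelu)       # now callable with TF tensors
y = tf_gelu(tf.random.normal([4, 8]))    # usable inside tf.function / Keras layers
```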

@MoFHeka
Author
MoFHeka commented Jan 16, 2024

@MoFHeka Why not use both TF and JAX at the same time :)? You can call JAX code from TF code nowadays. Also please check out some advanced attention techniques recently introduced in JAX (https://github.com/lhao499/large-sequence-modeling/tree/main).

@dathudeptrai Yes, that's right, JAX kernels can be used in TF code, although there's no big difference between JAX kernels and Keras layers compiled with XLA.
But the problem is that the pipeline parallelism capabilities of JAX cannot be used from TF. TF currently lacks pipeline parallelism components.

Even in recent updates, DTensor is only used to support tensor parallelism. But for training CTR models, one of the biggest usage scenarios of TF, what is needed most is pipeline parallelism.

@MoFHeka
Author
MoFHeka commented Jan 19, 2024

@SuryanarayanaY Hi~ Is there any way to implement pipeline training in TensorFlow?
tf.distribute.experimental.rpc.Server with server.register looks like a good choice, but I'm not sure.
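A hedged sketch of that idea, registering one pipeline stage as an RPC method that another worker could call; the address, stage function, and shapes are assumptions, and this only shows the plumbing, not a full pipeline schedule:

```python
import tensorflow as tf

# Worker hosting pipeline stage 1 behind an RPC endpoint (address is an assumption).
server = tf.distribute.experimental.rpc.Server.create("grpc", "localhost:50051")

@tf.function(input_signature=[tf.TensorSpec([None, 128], tf.float32)])
def stage1_forward(h):
    # Placeholder for the second pipeline stage's computation.
    return tf.nn.relu(h)

server.register("stage1_forward", stage1_forward)
server.start()

# On the stage-0 worker: send the local stage's output to the remote stage.
client = tf.distribute.experimental.rpc.Client.create("grpc", "localhost:50051")
result = client.call(
    "stage1_forward",
    args=[tf.random.normal([32, 128])],
    output_specs=tf.TensorSpec([None, 128], tf.float32),
)
if result.is_ok():
    activations = result.get_value()
```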

@MoFHeka
Author
MoFHeka commented Jun 19, 2024

@SuryanarayanaY Any progress? TensorFlow seems to be falling way behind in its competition with PyTorch...
