move fsdp handling to accelerate (huggingface#23158)
* mixed precision support via accelerate

* fix issues

* fix for the sharded ddp case

* fix flax and tf failing tests

* refactor the place to create `Accelerator` object

* move ddp prep to accelerate

* fix 😅

* resolving comments

* move fsdp handling to accelerate

* fixes

* fix saving
pacman100 authored and sheonhan committed Jun 1, 2023
1 parent 8a81959 commit 61c92a9
Showing 2 changed files with 105 additions and 116 deletions.
193 changes: 78 additions & 115 deletions src/transformers/trainer.py
@@ -343,6 +343,12 @@ def __init__(
# create accelerator object
self.accelerator = Accelerator()

# post accelerator creation setup
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
fsdp_plugin = self.accelerator.state.fsdp_plugin
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get("limit_all_gathers", False)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", False)

# memory metrics - must set up as early as possible
self._memory_tracker = TrainerMemoryTracker(self.args.skip_memory_metrics)
self._memory_tracker.start()
@@ -464,7 +470,7 @@ def __init__(
self.fsdp = ShardingStrategy.NO_SHARD

self.backward_prefetch = BackwardPrefetch.BACKWARD_PRE
if "backward_prefetch" in self.args.fsdp_config and "backward_pos" in self.args.fsdp_config.get(
if "backward_prefetch" in self.args.fsdp_config and "backward_post" in self.args.fsdp_config.get(
"backward_prefetch", []
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST
@@ -1479,114 +1485,58 @@ def _wrap_model(self, model, training=True, dataloader=None):
cpu_offload=cpu_offload,
).to(self.args.device)
# Distributed training using PyTorch FSDP
elif self.fsdp is not None:
if not self.args.fsdp_config["xla"]:
# PyTorch FSDP!
from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload, MixedPrecision
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

if FSDPOption.OFFLOAD in self.args.fsdp:
cpu_offload = CPUOffload(offload_params=True)
else:
cpu_offload = CPUOffload(offload_params=False)

auto_wrap_policy = None

if FSDPOption.AUTO_WRAP in self.args.fsdp:
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
mixed_precision_policy = None
dtype = None
if self.args.fp16:
dtype = torch.float16
elif self.args.bf16:
dtype = torch.bfloat16
if dtype is not None:
mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)
if type(model) != FSDP:
# XXX: Breaking the self.model convention but I see no way around it for now.
signature = inspect.signature(FSDP.__init__).parameters.keys()
kwargs = {}
for arg in ["limit_all_gathers", "forward_prefetch", "backward_prefetch"]:
if arg in signature:
kwargs[arg] = getattr(self, arg)
self.model = model = FSDP(
model,
sharding_strategy=self.fsdp,
cpu_offload=cpu_offload,
auto_wrap_policy=auto_wrap_policy,
mixed_precision=mixed_precision_policy,
device_id=self.args.device,
**kwargs,
)
else:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
elif self.fsdp is not None and self.args.fsdp_config["xla"]:
try:
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.wrap import (
size_based_auto_wrap_policy,
transformer_auto_wrap_policy,
)
except ImportError:
raise ImportError("Missing XLA FSDP related module; please make sure to use torch-xla >= 2.0.")
auto_wrap_policy = None
auto_wrapper_callable = None
if self.args.fsdp_config["fsdp_min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
)
elif self.args.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
transformer_cls_to_wrap = set()
for layer_class in self.args.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]:
transformer_cls = get_module_class_from_name(model, layer_class)
if transformer_cls is None:
raise Exception("Could not find the transformer layer class to wrap in the model.")
else:
transformer_cls_to_wrap.add(transformer_cls)
auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
# Transformer layer class to wrap
transformer_layer_cls=transformer_cls_to_wrap,
)
fsdp_kwargs = self.args.xla_fsdp_config
if self.args.fsdp_config["xla_fsdp_grad_ckpt"]:
# Apply gradient checkpointing to auto-wrapped sub-modules if specified
def auto_wrapper_callable(m, *args, **kwargs):
return FSDP(checkpoint_module(m), *args, **kwargs)

# Wrap the base model with an outer FSDP wrapper
self.model = model = FSDP(
model,
auto_wrap_policy=auto_wrap_policy,
auto_wrapper_callable=auto_wrapper_callable,
**fsdp_kwargs,
)

# Patch `xm.optimizer_step` should not reduce gradients in this case,
# as FSDP does not need gradient reduction over sharded parameters.
def patched_optimizer_step(optimizer, barrier=False, optimizer_args={}):
loss = optimizer.step(**optimizer_args)
if barrier:
xm.mark_step()
return loss

xm.optimizer_step = patched_optimizer_step
elif is_sagemaker_dp_enabled():
model = nn.parallel.DistributedDataParallel(
model, device_ids=[int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))]
@@ -1796,17 +1746,26 @@ def _inner_training_loop(
if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model
# as the model is wrapped, don't use `accelerator.prepare`
# this is for unhandled cases such as
# Fairscale Sharded DDP, FSDP-XLA, SageMaker MP/DP, DataParallel, IPEX
use_accelerator_prepare = True if model is self.model else False

if delay_optimizer_creation:
self.create_optimizer_and_scheduler(num_training_steps=max_steps)

# prepare using `accelerator` prepare
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)
if use_accelerator_prepare:
model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
self.model, self.optimizer, self.lr_scheduler
)

if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.model = model

# for the rest of this function `model` is the outside model, whether it was wrapped or not
if model is not self.model:
self.model_wrapped = model

# Check if saved optimizer or scheduler states exist
self._load_optimizer_and_scheduler(resume_from_checkpoint)
@@ -2894,11 +2853,15 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
ShardedDDPOption.ZERO_DP_2 in self.args.sharded_ddp
or ShardedDDPOption.ZERO_DP_3 in self.args.sharded_ddp
or self.fsdp is not None
or getattr(self.accelerator.state, "fsdp_plugin", None) is not None
):
state_dict = self.model.state_dict()
if getattr(self.accelerator.state, "fsdp_plugin", None) is not None:
self.accelerator.state.fsdp_plugin.save_model(self.accelerator, self.model, output_dir)
else:
state_dict = self.model.state_dict()

if self.args.should_save:
self._save(output_dir, state_dict=state_dict)
elif self.deepspeed:
# this takes care of everything as long as we aren't under zero3
if self.args.should_save:
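To make the new control flow easier to follow, here is a minimal sketch (not the Trainer code itself) of the Accelerate-side pattern the changes above rely on: `Accelerator` picks up FSDP settings from the environment, exposes them as `accelerator.state.fsdp_plugin`, and `accelerator.prepare` does the wrapping. The model, optimizer, and plugin overrides are illustrative; FSDP only engages under a multi-process launch (`accelerate launch`/`torchrun`), and in a single process the plugin is absent and the guard is skipped.

```python
import os

import torch
from accelerate import Accelerator

# ACCELERATE_USE_FSDP is normally set by TrainingArguments.__post_init__
# (see the training_args.py hunk below); it is set by hand here only for illustration.
os.environ.setdefault("ACCELERATE_USE_FSDP", "true")

accelerator = Accelerator()

# Mirrors the post-creation setup added to Trainer.__init__ above.
# Without a distributed launch, state.fsdp_plugin does not exist and this block is skipped.
fsdp_plugin = getattr(accelerator.state, "fsdp_plugin", None)
if fsdp_plugin is not None:
    fsdp_plugin.limit_all_gathers = False  # illustrative; the Trainer reads args.fsdp_config
    fsdp_plugin.use_orig_params = False

# Stand-ins for the real model and optimizer.
model = torch.nn.Linear(16, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Mirrors the `use_accelerator_prepare` branch in `_inner_training_loop`:
# Accelerate now performs the FSDP wrapping instead of `_wrap_model`.
model, optimizer = accelerator.prepare(model, optimizer)
```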
28 changes: 27 additions & 1 deletion src/transformers/training_args.py
@@ -442,7 +442,7 @@ class TrainingArguments:
- `"backward_pre"` : Prefetches the next set of parameters before the current set of parameter's
gradient
computation.
- `"backward_pos"` : This prefetches the next set of parameters after the current set of
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameter’s
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
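For orientation, a minimal sketch of how the options documented above might be passed through `fsdp_config`; the layer class name and option values are illustrative, not taken from this PR:

```python
from transformers import TrainingArguments

# Illustrative values only; "BertLayer" is a hypothetical wrap target.
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",  # space-separated FSDP options
    fsdp_config={
        "fsdp_transformer_layer_cls_to_wrap": ["BertLayer"],
        "fsdp_backward_prefetch": "backward_post",  # the policy described above, read by __post_init__ below
        "fsdp_forward_prefetch": False,
        "xla": False,
    },
)
```

The `__post_init__` hunk below translates these settings into the `ACCELERATE_USE_FSDP` and `FSDP_*` environment variables that Accelerate reads when it builds its FSDP plugin.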
@@ -1504,6 +1504,32 @@ def __post_init__(self):
if self.fsdp_config["xla_fsdp_grad_ckpt"]:
warnings.warn("`--xla_fsdp_grad_ckpt` is useful only when `--xla` is set to true.")

# accelerate integration for FSDP
if len(self.fsdp) > 0 and not self.fsdp_config["xla"]:
os.environ["ACCELERATE_USE_FSDP"] = "true"
from accelerate.utils.constants import (
FSDP_AUTO_WRAP_POLICY,
FSDP_SHARDING_STRATEGY,
)

for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
)
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()

if self.tpu_metrics_debug:
warnings.warn(
"using `--tpu_metrics_debug` is deprecated and will be removed in version 5 of 🤗 Transformers. Use"
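Continuing the sketch above, the mapping in this hunk should leave the environment looking roughly as follows. The expected values in the comments assume Accelerate's `FSDP_SHARDING_STRATEGY` lists `FULL_SHARD` first and `FSDP_AUTO_WRAP_POLICY` lists `TRANSFORMER_BASED_WRAP` first, as the indexing above implies.

```python
import os

# Expected values (per the mapping above) for the illustrative arguments shown earlier.
for key in (
    "ACCELERATE_USE_FSDP",           # "true"
    "FSDP_SHARDING_STRATEGY",        # "1": FULL_SHARD, 1-indexed into FSDP_SHARDING_STRATEGY
    "FSDP_AUTO_WRAP_POLICY",         # "TRANSFORMER_BASED_WRAP"
    "FSDP_TRANSFORMER_CLS_TO_WRAP",  # "BertLayer"
    "FSDP_BACKWARD_PREFETCH",        # "BACKWARD_POST"
):
    print(f"{key}={os.environ.get(key)}")
```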
