

Commit

fix redefinition of unused function issue.
TeaPoly committed Jun 27, 2024
1 parent 194a4ce commit 709f545
Showing 1 changed file with 0 additions and 101 deletions.
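Context: layernorm.py appears to have contained a second, duplicate definition of the forward/backward staticmethods inside the same autograd Function class. Python executes a class body top to bottom, so a repeated method definition silently replaces the earlier one; the first copy becomes dead code and triggers linter warnings such as "redefinition of unused function". A minimal, generic illustration of the mechanism (hypothetical names, not code from this repository):

class Example:
    @staticmethod
    def forward(x):
        # This definition is never reachable ...
        return x + 1

    @staticmethod
    def forward(x):
        # ... because this redefinition replaces it when the class body is executed.
        return x + 1

print(Example.forward(1))  # 2 -- only the second definition exists at runtime

The diff below removes the duplicated block, leaving a single definition in place.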
101 changes: 0 additions & 101 deletions mmfreelm/modules/layernorm.py
@@ -725,107 +725,6 @@ def backward(ctx, dout, *args):
            None,
        )

    @staticmethod
    @contiguous
    def forward(
        ctx,
        x,
        norm_weight,
        norm_bias,
        linear_weight,
        linear_bias,
        residual=None,
        eps=1e-6,
        prenorm=False,
        residual_in_fp32=False,
        is_rms_norm=False,
    ):
        x_shape_og = x.shape
        # reshape input data into 2D tensor
        x = x.reshape(-1, x.shape[-1])
        if residual is not None:
            assert residual.shape == x_shape_og
            residual = residual.reshape(-1, residual.shape[-1])
        residual_dtype = (
            residual.dtype
            if residual is not None
            else (torch.float32 if residual_in_fp32 else None)
        )
        y, mean, rstd, residual_out = _layer_norm_fwd(
            x,
            norm_weight,
            norm_bias,
            eps,
            residual,
            out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_gpu_dtype(),
            residual_dtype=residual_dtype,
            is_rms_norm=is_rms_norm,
        )
        y = y.reshape(x_shape_og)
        dtype = torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
        linear_weight = linear_weight.to(dtype)
        linear_bias = linear_bias.to(dtype) if linear_bias is not None else None

        linear_weight_a = weight_quant(linear_weight)
        y = activation_quant(y)
        out = F.linear(y.to(linear_weight.dtype), linear_weight_a, linear_bias)
        ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
        # We don't store y; it will be recomputed in the backward pass to save memory

        ctx.x_shape_og = x_shape_og
        ctx.eps = eps
        ctx.is_rms_norm = is_rms_norm
        ctx.has_residual = residual is not None
        ctx.prenorm = prenorm
        ctx.x_dtype = x.dtype
        ctx.linear_bias_is_none = linear_bias is None
        return out if not prenorm else (out, residual_out.reshape(x_shape_og))

    @staticmethod
    @contiguous
    def backward(ctx, dout, *args):
        x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
        dout = dout.reshape(-1, dout.shape[-1])
        linear_weight = weight_quant(linear_weight.t())
        dy = F.linear(dout, linear_weight)
        dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
        assert dy.shape == x.shape
        if ctx.prenorm:
            dresidual = args[0]
            dresidual = dresidual.reshape(-1, dresidual.shape[-1])
            assert dresidual.shape == x.shape
        else:
            dresidual = None
        dx, dnorm_weight, dnorm_bias, dresidual_in, y = _layer_norm_bwd(
            dy,
            x,
            norm_weight,
            norm_bias,
            ctx.eps,
            mean,
            rstd,
            dresidual,
            ctx.has_residual,
            ctx.is_rms_norm,
            x_dtype=ctx.x_dtype,
            recompute_output=True,
        )
        dlinear_weight = torch.einsum("bo,bi->oi", dout, y)
        return (
            dx.reshape(ctx.x_shape_og),
            dnorm_weight,
            dnorm_bias,
            dlinear_weight,
            dlinear_bias,
            dresidual_in.reshape(ctx.x_shape_og) if ctx.has_residual else None,
            None,
            None,
            None,
            None,
        )
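Side note: the deleted forward/backward pair relies on weight_quant and activation_quant, fake-quantization helpers defined elsewhere in layernorm.py. As a rough sketch of what such BitNet-style helpers typically look like (an assumption about their shape, not the actual definitions in this file):

import torch

def activation_quant(x: torch.Tensor) -> torch.Tensor:
    # Illustrative per-token 8-bit fake quantization (assumed, not from this file):
    # scale each row so its max magnitude maps to 127, round, clamp, rescale back.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp_(min=1e-5)
    return (x * scale).round().clamp_(-128, 127) / scale

def weight_quant(w: torch.Tensor) -> torch.Tensor:
    # Illustrative ternary fake quantization (assumed): scale by the mean magnitude,
    # round to {-1, 0, 1}, then rescale toward the original range.
    scale = 1.0 / w.abs().mean().clamp_(min=1e-5)
    return (w * scale).round().clamp_(-1, 1) / scale

Both return full-precision tensors of the same shape, which is why F.linear in the forward above can operate on the "quantized" values without any dtype changes.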


def layer_norm_linear_fn(
    x,

