
Add masked loads in SDD to support K that is not a multiple of BLOCK_K #12

Open · wants to merge 1 commit into base: main
Conversation

@sashaDoubov (Contributor) commented Jul 3, 2024

Adds load masking (inspired by the Triton guide) to address non-determinism when K % BLOCK_K != 0.

This change sets the loaded A and B values to 0 wherever the indices fall outside the K dimension.
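For illustration, this is roughly the masked-load pattern from the Triton matmul tutorial; the names below (`a_ptrs`, `b_ptrs`, `offs_k`, `stride_ak`, `stride_bk`, `acc`) are placeholders, not the actual identifiers in the kernel:

```python
# Sketch of the masked-load pattern, as in the Triton matmul tutorial.
# All variable names here are illustrative placeholders, not the kernel's own.
for k in range(0, tl.cdiv(K, BLOCK_K)):
    k_remaining = K - k * BLOCK_K
    # Positions past the end of K are loaded as 0.0 instead of out-of-bounds memory.
    a = tl.load(a_ptrs, mask=offs_k[None, :] < k_remaining, other=0.0)
    b = tl.load(b_ptrs, mask=offs_k[:, None] < k_remaining, other=0.0)
    acc += tl.dot(a, b)
    a_ptrs += BLOCK_K * stride_ak
    b_ptrs += BLOCK_K * stride_bk
```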

I believe this is the issue hit in #11, since the K dimension in that repro is 8, which is not a multiple of BLOCK_K = 32.

@sashaDoubov (Contributor, Author) commented Jul 3, 2024

@tgale96 I'm hoping for some guidance on how best to test this change: the current tests seem to assume K is a multiple of BLOCK_K, and they also cover the various supported Triton kernels (SDD, DSD, etc.).

@sashaDoubov changed the title from "Add masked loads in SDD implementation to support K that is not a multiple of BLOCK_K" to "Add masked loads in SDD to support K that is not a multiple of BLOCK_K" on Jul 3, 2024
@tgale96 (Collaborator) commented Jul 3, 2024

Is this a feature we need to support? I think we always want K % 128 == 0 for these kernels, since in the backward pass what was the contraction dimension becomes one of the non-contracting dims, where we have that constraint anyway.

For less stringent shape constraints, I recommend using the grouped code path in MegaBlocks.

@sashaDoubov (Contributor, Author)

@tgale96 that makes sense. The issue in #11 was initially encountered in databricks/megablocks#115. I assume the underlying fix should then be adding warnings for hidden_size/ffn_hidden_size in the megablocks repo, rather than adding support in stk? Or is that already being handled in megablocks? (It seems like it isn't, given the min repro.)

@tgale96 (Collaborator) commented Jul 3, 2024

Yes, if this non-divisibility is actually the root cause, then we are missing an assertion. From a quick glance it wasn't obvious to me that that issue was caused by this.

@sashaDoubov (Contributor, Author)

While I'm not 100% sure this is the underlying issue, the experiments I've run seem to point to it (see the next comment). To clarify: we do pad as part of creating the topology matrix, and that is divisible by 128; however, my understanding is that this K dimension is not covered by the topology matrix assertion.

K is passed into the Triton kernel from the dense operand's shape:

`M, K = lhs.shape`

rather than coming from the padded topology matrix created in megablocks, so if x or w have a hidden_size not divisible by 32, I believe we hit non-determinism.
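A toy illustration of that mismatch (shapes invented for this example, mirroring the #11 repro):

```python
import torch

# Toy illustration only; the shapes mirror the #11 repro rather than any real model.
# The topology is blocked in 128x128 tiles, so its own dimensions are padded and
# 128-divisible, but K for the SDD kernel is read straight off the dense operand:
x = torch.randn(256, 8)   # hidden_size = 8
M, K = x.shape            # K = 8, not a multiple of BLOCK_K = 32
# Without masking, a 32-wide K tile reads 24 values past the end of each row,
# which is where the nondeterministic results can come from.
```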

@sashaDoubov (Contributor, Author)

I extended the min-repro from the original megablocks issue, adding test cases for hidden_size in [32, 64, 8, 48]. I see determinism for hidden_size in [32, 64] and non-determinism for [8, 48].

```python
import io
import os
import sys
import types

import stk
import torch
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import Arguments
from megablocks.layers.dmoe import dMoE
from megablocks.layers.mlp import resolve_dtensor
from torch import distributed as dist


def nonempty(t: torch.Tensor, show_features: int = 8) -> torch.Tensor:
    """Treat the last dim as features, gather all non-empty features to a 2D tensor.

    Args:
        t: Tensor of shape (..., feature_count).
        show_features: The number of features to show. If None, show all.

    Returns:
        2D tensor of shape (nonempty_count, show_features).
    """
    if not isinstance(t, torch.Tensor):
        t = t.data
    t = t.reshape(-1, t.shape[-1])  # Reshape to (-1, d)
    t = t[(t != 0).any(dim=1)]  # Remove all rows that are all 0
    return t[..., :show_features]


def are_matrices_equal(a: stk.Matrix, b: stk.Matrix) -> bool:
    """Check if two matrices are equal."""
    return (
        (a.row_indices == b.row_indices).all()
        and (a.data == b.data).all()
        and (a.column_indices == b.column_indices).all()
        and (a.offsets == b.offsets).all()
        and (a.column_indices_t == b.column_indices_t).all()
        and (a.offsets_t == b.offsets_t).all()
        and (a.block_offsets_t == b.block_offsets_t).all()
    )


def glu_forward(self, x, topo):
    self.act_dict = {}
    if self.args.memory_optimized_mlp:
        raise NotImplementedError(
            "Memory optimized implementation not yet supported with GLU with sparse kernels."
        )

    w1, v1, w2 = (
        self.scale_grad(self.w1),
        self.scale_grad(self.v1),
        self.scale_grad(self.w2),
    )
    w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2)

    # Compute the GLU.
    self.act_dict["x"] = x
    self.act_dict["w1_resolved"] = w1
    x1 = stk.ops.sdd(x, w1.t(), topo)
    x2 = stk.ops.sdd(x, v1.t(), topo)

    activation_fn_out = act_fn(x1, self.args.activation_fn)
    x1 = stk.ops.mul(activation_fn_out, x2)

    output = stk.ops.dsd(x1, w2)
    return output


def try_sdd(dim: int) -> tuple[bool, str]:
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12362"
    if not dist.is_initialized():
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

    megablocks_args = Arguments(
        hidden_size=dim,
        ffn_hidden_size=128,
        bias=False,
        return_bias=False,
        activation_fn=torch.nn.functional.silu,
        moe_num_experts=2,
        moe_top_k=1,
        moe_loss_weight=0.05,
        moe_normalize_expert_weights=1.0,
        moe_jitter_eps=0.0,
        mlp_type="glu",
        mlp_impl="sparse",
        moe_expert_model_parallelism=False,
        expert_parallel_group=None,
        fp16=False,
        bf16=True,
        device=torch.device("cuda"),
    )
    dmoe_ = dMoE(megablocks_args)
    dmoe_.experts.mlp.forward = types.MethodType(glu_forward, dmoe_.experts.mlp)

    input_ = torch.randn([1, 2, dim], dtype=torch.bfloat16, device=torch.device("cuda"))
    dmoe_(input_)

    topo = stk.Matrix(
        (256, 256),
        torch.empty((2, 128, 128), device="meta", dtype=torch.bfloat16),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1, 2], device="cuda:0", dtype=torch.int32),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int16),
        torch.tensor([0, 1, 2], device="cuda:0", dtype=torch.int32),
        torch.tensor([0, 1], device="cuda:0", dtype=torch.int32),
    )

    x = dmoe_.experts.mlp.act_dict["x"]
    w = dmoe_.experts.mlp.act_dict["w1_resolved"]
    x1 = stk.ops.sdd(x, w.t(), topo)

    x_clone = x.clone()
    x1_clone = stk.ops.sdd(x_clone, w.t(), topo)
    equal = are_matrices_equal(x1_clone, x1)

    with io.StringIO() as output:
        sys.stdout = output

        print("Input X is the same:", (x_clone == x).all())
        print("SDD output is the same:", equal)
        print()

        print("Breakdown of SDD output:")
        print("Shape:", x1_clone.shape == x1.shape)
        print("Data:", x1_clone.data.allclose(x1.data))
        print("Row indices:", x1_clone.row_indices == x1.row_indices)
        print("Column indices:", x1_clone.column_indices == x1.column_indices)
        print("Offsets:", x1_clone.offsets == x1.offsets)
        print("Block offsets:", x1_clone.block_offsets_t == x1.block_offsets_t)
        print()

        print("Breakdown of SDD output data:")
        print("Nonempty elements:", nonempty(x1).shape, nonempty(x1_clone).shape)
        print("Per-row mean:", nonempty(x1).mean(dim=1), nonempty(x1_clone).mean(dim=1))
        print(
            "Cross-equality:",
            (nonempty(x1)[None, :, :] - nonempty(x1_clone)[:, None, :]).sum(-1),
        )

        sys.stdout = sys.__stdout__
        output_str = output.getvalue()

    return equal, output_str


def main() -> None:
    repetition_count = 10
    for dim in [32, 64, 8, 48]:
        for trial_id in range(repetition_count):
            equal, output_str = try_sdd(dim)
            if not equal:
                print(f"Trial {trial_id} with dim {dim} failed.")
                print(output_str)
                break
        else:
            print(f"All {repetition_count} repetitions with dim {dim} passed.")


if __name__ == "__main__":
    main()
```

@tgale96 (Collaborator) commented Jul 3, 2024

Ah, ya - I didn't realize that the matrices in #11 were (256, 8). That's going to be a problem and we should be asserting on it, probably in STK.
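For reference, a minimal sketch of what such a check could look like; the function name, location, and block size are hypothetical, not existing STK code:

```python
import torch

# Hypothetical shape check, not existing STK code; the function name and the
# block size default are placeholders for wherever the real assertion would live.
def assert_k_divisible(lhs: torch.Tensor, block_k: int = 128) -> None:
    _, k = lhs.shape
    assert k % block_k == 0, (
        f"Inner dimension K={k} must be a multiple of {block_k} for these kernels"
    )
```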

@sashaDoubov (Contributor, Author)

Makes sense. I'm a bit less familiar with STK in general; would this need to be an assert for both SDD and DDS? Taking a quick look, it seems like K is taken from the shape of the dense matrix in those kernels, so both could run into this issue.

As an aside, my proposed solution was the masked load, which masks out anything beyond the K boundary within a BLOCK_K tile. Is that behaviour you would want to support, or is an assertion the preferred way to resolve this?

@sashaDoubov (Contributor, Author)

Hey @tgale96, should I close this PR in favour of adding assertions, or would it make sense to support masking for K % BLOCK_K != 0?

@tgale96 (Collaborator) commented Jul 16, 2024

Hey! Sorry for the delay. I think we want the assertion, unless you have a use case for supporting k % BLOCK_K != 0 (e.g., you only care about running that case for the forward pass).
