Add masked loads in SDD to support K that is not a multiple of BLOCK_K #12
base: main
@tgale96 I'm hoping for some guidance on how to test this change better. The current tests seem to target K values that are multiples of BLOCK_K, and they also cover the various Triton kernels supported (SDD, DSD, etc.).
Is this a feature we need to support? I think we want … For less stringent shape constraints, I recommend using the grouped code path in MegaBlocks.
@tgale96 that makes sense. The issue in #11 was first encountered here: databricks/megablocks#115. I assume, then, that the underlying fix should be adding warnings for hidden_size/ffn_hidden_size in the megablocks repo rather than adding support in stk? Or is that already handled in megablocks? (Given the minimal repro, it seems like it isn't.)
Yes, if this non-divisibility is actually the root cause, then we are missing an assertion. From a quick glance, though, it wasn't obvious to me that this was the cause of that issue.
While I'm not 100% sure that this is the underlying issue, I've run some experiments around it and they seem to point that way (see the next comment). To clarify: we do pad as part of creating the topology matrix, and the result is divisible by 128. However, my understanding is that this K dimension is not covered by that topology-matrix assertion. K is passed into the Triton kernel here: stk/stk/backend/triton_kernels.py, line 318 in e5c47f6.
… x or w have a hidden_size not divisible by 32, I believe we encounter non-determinism.
I extended the minimal repro from the original megablocks issue with test cases for hidden_sizes = [32, 64, 8, 48], and found that the results are deterministic for hidden_size in [32, 64] and non-deterministic for [8, 48].
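The split in the repro lines up with divisibility by BLOCK_K = 32. As a quick arithmetic check (a Python sketch; the helper name is hypothetical and not part of stk), this counts how many elements of the final K tile lie past the end of the row and would therefore be read from out-of-bounds memory without a load mask:

```python
BLOCK_K = 32  # tile width along K, as quoted for the repro in this PR


def out_of_bounds_in_last_tile(k: int, block_k: int = BLOCK_K) -> int:
    """Hypothetical helper: elements of the last K tile that fall past k."""
    num_tiles = -(-k // block_k)  # ceil(k / block_k)
    return num_tiles * block_k - k


for hidden_size in [32, 64, 8, 48]:
    print(hidden_size, hidden_size % BLOCK_K == 0, out_of_bounds_in_last_tile(hidden_size))
# 32 True 0
# 64 True 0
# 8 False 24
# 48 False 16
```

The two sizes that behaved deterministically are exactly the ones with zero overhang.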
Ah, yeah, I didn't realize that the matrices in #11 were …
Makes sense. I'm a bit less familiar with STK in general: would this need to be an assert for both SDD and DDS? Taking a quick look, it seems like K is taken from the shape of the dense matrix in those kernels, so they could run into the same issue. As an aside, the solution I proposed was a masked load that masks out elements past the K boundary in the final BLOCK_K tile. Is that behaviour you would want to support, or is an assertion the preferred way to resolve this?
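For comparison with masking, the assertion route amounts to rejecting a non-divisible K before the kernel is launched. A minimal sketch in Python, assuming a hypothetical helper check_inner_dim (not an existing stk function), hypothetical operand shapes, and BLOCK_K = 32 as in this PR:

```python
# Hypothetical pre-launch check; stk's actual entry points and messages may differ.
def check_inner_dim(k: int, block_k: int, op_name: str) -> None:
    """Reject an inner (K) dimension that an unmasked K loop would overrun."""
    if k % block_k != 0:
        raise ValueError(
            f"{op_name}: inner dimension K={k} must be a multiple of "
            f"BLOCK_K={block_k}"
        )


# Example: an SDD-style product x @ w whose dense operands share K.
x_shape = (128, 48)  # (M, K), hypothetical shapes
w_shape = (48, 128)  # (K, N)
if x_shape[1] != w_shape[0]:
    raise ValueError("x and w must agree on K")

check_inner_dim(64, block_k=32, op_name="sdd")  # passes silently
try:
    check_inner_dim(x_shape[1], block_k=32, op_name="sdd")
except ValueError as err:
    print(err)  # sdd: inner dimension K=48 must be a multiple of BLOCK_K=32
```

Raising ValueError rather than using a bare assert keeps the check active when Python runs with -O, which strips assert statements.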
Hey @tgale96, should I close this PR in favour of adding assertions? Or would it make sense to support masking for …
Hey! Sorry for the delay. I think we want the assertion, unless you have a use case for supporting …
Adds load masking (inspired by the Triton guide) to address non-determinism when K % BLOCK_K != 0. This change sets the loaded values of A and B to 0 wherever the offsets fall outside the K dimension.
I believe this is the issue faced in #11, since K = 8 in that repro, which is not a multiple of BLOCK_K = 32.
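To make the effect of the mask concrete without a GPU, here is a pure-Python emulation of a K loop that walks the inner dimension in BLOCK_K-wide tiles. It is a sketch, not the stk kernel: the helper names are hypothetical, and random values stand in for whatever memory follows the row (real out-of-bounds reads are unspecified rather than literally random). The masked path mirrors the mask/other arguments of Triton's tl.load.

```python
import random

BLOCK_K = 32  # tile width along K, matching the value discussed in this PR


def load_tile(buf, start, block_k, k=None):
    """Emulate loading one BLOCK_K-wide tile along K from a flat buffer.

    k=None reads every offset unconditionally (no mask). Otherwise offsets
    >= k yield 0.0, analogous to tl.load(ptrs, mask=offs < K, other=0.0).
    """
    tile = []
    for offs in range(start, start + block_k):
        if k is not None and offs >= k:
            tile.append(0.0)  # masked lane contributes nothing to the sum
        else:
            tile.append(buf[offs])
    return tile


def tiled_dot(a_buf, b_buf, k, block_k=BLOCK_K, masked=True):
    """Dot product over the first k elements, accumulated tile by tile."""
    acc = 0.0
    num_tiles = -(-k // block_k)  # ceil(k / block_k)
    for t in range(num_tiles):
        mask_k = k if masked else None
        a = load_tile(a_buf, t * block_k, block_k, mask_k)
        b = load_tile(b_buf, t * block_k, block_k, mask_k)
        acc += sum(x * y for x, y in zip(a, b))
    return acc


def with_trailing_memory(values, rng):
    # Stand-in for whatever happens to follow the row in memory.
    return values + [rng.uniform(-1.0, 1.0) for _ in range(BLOCK_K)]


K = 8
a_row = [1.0] * K
b_col = [2.0] * K
rng = random.Random()
for _ in range(3):
    a_buf = with_trailing_memory(a_row, rng)
    b_buf = with_trailing_memory(b_col, rng)
    # Masked result is always 16.0; the unmasked one changes with the trailing values.
    print(tiled_dot(a_buf, b_buf, K), tiled_dot(a_buf, b_buf, K, masked=False))
```

Because the final partial tile contributes zeros for offsets ≥ K, the masked loop matches an exact K-length dot product. When K is a multiple of BLOCK_K the mask never triggers, so both paths agree.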