Performance Regression from commit 7dcd870 #22683

Closed · fpgaminer opened this issue Apr 10, 2023 · 7 comments

@fpgaminer
Contributor

System Info

  • transformers version: 4.28.0.dev0 (656e869)
  • Platform: Linux-5.15.0-67-generic-x86_64-with-glibc2.35
  • Python version: 3.10.10
  • Huggingface_hub version: 0.13.4
  • Safetensors version: 0.3.0
  • PyTorch version (GPU?): 2.0.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using GPU in script?: True
  • Using distributed or parallel set-up in script?: False

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

I have a benchmark script that measures the generation speed of different LLaMA models. Before commit 7dcd870 my generation speed averaged around 48 tokens/s in ideal cases on an RTX 3090. After that commit the average speed is 43 tokens/s.

The specific issue seems to be the change to apply_rotary_pos_emb. My guess is that the slowdown comes from replacing a rather simple slicing of two tensors with a scatter-gather.

To test my theory I patched apply_rotary_pos_emb back to its pre-7dcd870 state and minimally modified LlamaAttention accordingly, with no other changes. Speed jumped back to 48 tokens/s.
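
For context, the pre-7dcd870 path amounted to a simple slice of cos/sin at a scalar offset; a rough sketch of that version (a reconstruction, not the exact upstream code):

def old_apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
	# One contiguous slice shared by the whole batch: cheap, but it cannot express
	# per-row positions the way the position_ids gather introduced in 7dcd870 can.
	cos = cos[..., offset : offset + q.shape[-2], :]
	sin = sin[..., offset : offset + q.shape[-2], :]
	q_embed = (q * cos) + (rotate_half(q) * sin)
	k_embed = (k * cos) + (rotate_half(k) * sin)
	return q_embed, k_embed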

The problem should apply generally, but the specific script I'm using is: https://github.com/fpgaminer/GPTQ-triton/blob/99ec4a3adb7fad9de33ff026bbfb64cbb3bab2f8/benchmark_generate.py

Expected behavior

I would not expect a 10% drop in performance.

@sgugger
Collaborator
sgugger commented Apr 10, 2023

cc @gante and @ArthurZucker

@gante
Member
gante commented Apr 14, 2023

@fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)

We optimize for correctness first rather than raw performance. To skip this gathering while remaining correct, .generate() would need to be rewritten to dynamically squeeze padding and evict completed rows, which is something we plan to do in the coming months.
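
To make the padding point concrete, here's a rough sketch (paraphrased, not quoted from the library) of how per-row position ids arise with left padding; a single shared slice of cos/sin can't represent this, hence the gather:

import torch

# Two rows, the second left-padded by two tokens. Positions are derived from the
# attention mask: real tokens count from 0, pad slots get a harmless dummy value.
attention_mask = torch.tensor([[1, 1, 1, 1],
                               [0, 0, 1, 1]])
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
print(position_ids)
# tensor([[0, 1, 2, 3],
#         [1, 1, 0, 1]])
# Each row needs its own cos/sin entries, which the position_ids gather provides.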

Meanwhile, is there anything else we can help you with?

@fpgaminer
Contributor Author

That's fair, though a 10% performance hit is rather painful.

To that end, here's my attempt to optimize apply_rotary_pos_emb:

import torch
import triton  # provides triton.testing.do_bench for the benchmark below

def ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
	gather_indices = position_ids[:, None, :, None]  # [bs, 1, seq_len, 1]
	gather_indices = gather_indices.repeat(1, cos.shape[1], 1, cos.shape[3])
	cos = torch.gather(cos.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
	sin = torch.gather(sin.repeat(gather_indices.shape[0], 1, 1, 1), 2, gather_indices)
	q_embed = (q * cos) + (rotate_half(q) * sin)
	k_embed = (k * cos) + (rotate_half(k) * sin)
	return q_embed, k_embed

def fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids):
	cos = cos.squeeze((0, 1))  # [seq_len, dim]
	sin = sin.squeeze((0, 1))  # [seq_len, dim]
	cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
	sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
	q_embed = (q * cos) + (rotate_half(q) * sin)
	k_embed = (k * cos) + (rotate_half(k) * sin)
	return q_embed, k_embed

def test_foo(B, L):
	cos = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
	sin = torch.randn(1, 1, 2048, 128, dtype=torch.float16, device='cuda')
	position_ids = torch.randint(0, 2048, (B, L), dtype=torch.int64, device='cuda')

	q = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')
	k = torch.randn(B, 32, L, 128, dtype=torch.float16, device='cuda')

	# Verify
	ref = ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
	fast = fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids)
	assert torch.equal(ref[0], fast[0])
	assert torch.equal(ref[1], fast[1])

	# Benchmark
	ref_ms, ref_min_ms, ref_max_ms = triton.testing.do_bench(lambda: ref_apply_rotary_pos_emb(q, k, cos, sin, position_ids))
	fast_ms, fast_min_ms, fast_max_ms = triton.testing.do_bench(lambda: fast_apply_rotary_pos_emb(q, k, cos, sin, position_ids))

	speedup = ref_ms * 100 / fast_ms  # ref time as a percentage of fast time (>100% means fast wins)
	print(f'{B} | {L:3d} |  {ref_ms:.6f} | {fast_ms:.6f} | {speedup:.2f}%')


print('B |  L  |    ref    |   fast   | speedup')
for B in [1, 2, 4, 8]:
	for L in [1, 2, 4, 8, 10, 100]:
		test_foo(B, L)

Output:

B |  L  |    ref    |   fast   | speedup
1 |   1 |  0.043008 | 0.035840 | 120.00%
1 |   2 |  0.044032 | 0.036864 | 119.44%
1 |   4 |  0.047104 | 0.038912 | 121.05%
1 |   8 |  0.046080 | 0.039936 | 115.38%
1 |  10 |  0.048128 | 0.039936 | 120.51%
1 | 100 |  0.058368 | 0.052224 | 111.76%
2 |   1 |  0.047104 | 0.036864 | 127.78%
2 |   2 |  0.049152 | 0.039936 | 123.08%
2 |   4 |  0.050176 | 0.040960 | 122.50%
2 |   8 |  0.050176 | 0.041984 | 119.51%
2 |  10 |  0.050176 | 0.041984 | 119.51%
2 | 100 |  0.079872 | 0.070656 | 113.04%
4 |   1 |  0.051200 | 0.039936 | 128.21%
4 |   2 |  0.053248 | 0.040960 | 130.00%
4 |   4 |  0.054272 | 0.041984 | 129.27%
4 |   8 |  0.057344 | 0.045056 | 127.27%
4 |  10 |  0.057344 | 0.045056 | 127.27%
4 | 100 |  0.130048 | 0.119808 | 108.55%
8 |   1 |  0.057344 | 0.040960 | 140.00%
8 |   2 |  0.059392 | 0.041984 | 141.46%
8 |   4 |  0.062464 | 0.045056 | 138.64%

For reference, the pre-7dcd870 function runs in 0.030 ms on the 1x1 case, so this isn't quite as fast but gets closer.

Would a pull request with this change be welcome? I've done my best to verify its correctness with the above code.

@gante
Member
gante commented Apr 15, 2023

@fpgaminer that is great! Absolutely, a PR would be very welcome 🙌

(We'd be happy to integrate other optimization opportunities if you spot them; we rarely have the bandwidth to optimize our modeling code.)

@aljungberg
Contributor
aljungberg commented Apr 27, 2023

> @fpgaminer commit 7dcd870 fixes generation when there is padding in the input (which is almost always the case for batch_size>1). It's natural that it introduces slowdowns, as the correct behavior implies changing to the tensor gathering you mentioned :)

Maybe there's something I'm not seeing here, but LLaMA uses rotary positional embeddings, so left padding should have no effect on the result?

Sure, the intermediate result from apply_rotary_pos_emb changes if you shift all tokens left or right, but the whole point of using relative embeddings is that the final attention weight is invariant to the absolute position. So you can shift all tokens 50 positions to the right and the attention score between pairs of tokens will be the same, modulo any rounding errors.

Or are you saying there are cases when padding is literally inserted inside of the sequence, therefore changing the relative distances between tokens, @gante?
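
As a quick numerical sanity check of that shift-invariance claim (a standalone sketch, not part of the thread), the post-RoPE attention score depends only on the distance between positions:

import torch

def rotate_half(x):
	x1, x2 = x.chunk(2, dim=-1)
	return torch.cat((-x2, x1), dim=-1)

def rope(x, pos, inv_freq):
	# rotate_half-style rotary embedding applied to a single head-dim vector.
	freqs = pos * inv_freq
	emb = torch.cat((freqs, freqs), dim=-1)
	return x * emb.cos() + rotate_half(x) * emb.sin()

dim = 128
inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2, dtype=torch.float64) / dim))
q = torch.randn(dim, dtype=torch.float64)
k = torch.randn(dim, dtype=torch.float64)

# Score between positions 3 and 7 vs. the same pair shifted right by 50 positions.
score = rope(q, 3.0, inv_freq) @ rope(k, 7.0, inv_freq)
shifted = rope(q, 53.0, inv_freq) @ rope(k, 57.0, inv_freq)
print(torch.allclose(score, shifted))  # True, up to rounding error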

@gante
Member
gante commented May 1, 2023

@aljungberg I agree with everything you wrote: rotary positional embeddings should be invariant to a constant shift of the positions. In practice, the small rounding errors compound over autoregressive text generation, leading greedy decoding (which is normally robust to small fluctuations) to produce different text.

With the right position index, the error becomes much smaller, and the results become more stable regardless of padding. That's why we also added it to our high-performance text generation repo, despite the difference being quite small.

Out of curiosity, this test was failing on GPTNeoX and Llama before we added this change. In theory, it shouldn't have failed at all!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

github-actions bot closed this as completed on Jun 4, 2023