refactor sampling layer setup #1912
base: main
@@ -1131,6 +1134,7 @@ void LlamaBatch<T>::InitializeSampling(const GenerationState& g)
}
}
outputs_ = std::move(outputs);
sync_check_cuda_error();
Why should we call `cudaDeviceSynchronize` here?
It only synchronizes when `TM_DEBUG_LEVEL=DEBUG`. I added it here to measure the timing more accurately.
ok
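The reply above says the check only synchronizes when `TM_DEBUG_LEVEL=DEBUG`. A minimal sketch of that kind of environment-gated synchronization follows. It is not the project's `sync_check_cuda_error`: `debug_sync_enabled` and `maybe_sync` are hypothetical names, and the `device_sync` callback stands in for `cudaDeviceSynchronize()` plus an error check so the sketch compiles without CUDA.

```cpp
#include <cstdlib>
#include <cstring>
#include <functional>

// Hypothetical helper: true only when TM_DEBUG_LEVEL is set to exactly "DEBUG".
bool debug_sync_enabled() {
    const char* level = std::getenv("TM_DEBUG_LEVEL");
    return level != nullptr && std::strcmp(level, "DEBUG") == 0;
}

// Calls `device_sync` (a stand-in for cudaDeviceSynchronize() followed by an
// error check) only in debug mode, so normal runs keep kernel launches
// asynchronous. Returns whether a synchronization actually happened.
bool maybe_sync(const std::function<void()>& device_sync) {
    if (!debug_sync_enabled()) {
        return false;
    }
    device_sync();
    return true;
}
```

Gating on the debug level keeps the synchronization (and its stall) out of normal runs, while debug runs get a sync point that makes per-stage timing meaningful.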
6 unit tests have failed and may need to be fixed.
[ FAILED ] 6 tests, listed below:
[ FAILED ] SamplingDecodeTest/0.TopP, where TypeParam = float
[ FAILED ] SamplingDecodeTest/0.BatchTopP, where TypeParam = float
[ FAILED ] SamplingDecodeTest/0.InvalidArgsZeroTopP, where TypeParam = float
[ FAILED ] SamplingDecodeTest/0.InvalidArgsBatchTopPContainZero, where TypeParam = float
[ FAILED ] SamplingDecodeTest/0.InvalidArgsTopKBatchTopPContainZero, where TypeParam = float
[ FAILED ] SamplingDecodeTest/0.LocalBatchBatchTopP, where TypeParam = float
6 FAILED TESTS
uint min_top_k = runtime_top_k_size > 0 ? runtime_top_k.min<uint>() : 0;
skip_all_ = false;

if (runtime_top_p_size == 0 || min_top_k > 0) {
Why skip when `min_top_k > 0`?
Ref: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/text
> The temperature is used for sampling during response generation, which occurs when topP and topK are applied. Temperature controls the degree of randomness in token selection. Lower temperatures are good for prompts that require a less open-ended or creative response, while higher temperatures can lead to more diverse or creative results. A temperature of 0 means that the highest probability tokens are always selected. In this case, responses for a given prompt are mostly deterministic, but a small amount of variation is still possible.
> If the model returns a response that's too generic, too short, or the model gives a fallback response, try increasing the temperature.
> Top-K changes how the model selects tokens for output. A top-K of 1 means the next selected token is the most probable among all tokens in the model's vocabulary (also called greedy decoding), while a top-K of 3 means that the next token is selected from among the three most probable tokens by using temperature.
> For each token selection step, the top-K tokens with the highest probabilities are sampled. Then tokens are further filtered based on top-P with the final token selected using temperature sampling.
> Specify a lower value for less random responses and a higher value for more random responses.
> Top-P changes how the model selects tokens for output. Tokens are selected from the most (see top-K) to least probable until the sum of their probabilities equals the top-P value. For example, if tokens A, B, and C have a probability of 0.3, 0.2, and 0.1 and the top-P value is 0.5, then the model will select either A or B as the next token by using temperature and excludes C as a candidate.
> Specify a lower value for less random responses and a higher value for more random responses.
The key sentence: "Then tokens are further filtered based on top-P with the final token selected using temperature sampling."
So I think that when `min_top_k > 0`, we can still apply top-p sampling.
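To make the quoted order concrete (top-k first, then top-p over the surviving tokens), here is a minimal C++ sketch that computes only the candidate set; the final temperature-scaled draw is omitted. `topk_topp_candidates` is a hypothetical helper, not the fused kernel discussed in this PR, and it assumes `probs` already holds probabilities (after any temperature scaling).

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns the indices of the tokens that survive top-k and then top-p
// filtering, ordered from most to least probable.
// k == 0 disables the top-k cut; p is the cumulative-probability threshold.
std::vector<std::size_t> topk_topp_candidates(const std::vector<double>& probs,
                                              std::size_t k, double p) {
    std::vector<std::size_t> idx(probs.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});
    // Sort token indices by descending probability (stable, so ties keep index order).
    std::stable_sort(idx.begin(), idx.end(),
                     [&](std::size_t a, std::size_t b) { return probs[a] > probs[b]; });

    // Top-k: keep only the k most probable tokens.
    if (k > 0 && k < idx.size()) {
        idx.resize(k);
    }

    // Top-p: walk from the most probable token and stop as soon as the running
    // sum reaches p; the token that reaches the threshold is kept.
    std::vector<std::size_t> kept;
    double cumulative = 0.0;
    for (std::size_t i : idx) {
        kept.push_back(i);
        cumulative += probs[i];
        if (cumulative >= p) {
            break;
        }
    }
    return kept;
}
```

With the quoted example (A = 0.3, B = 0.2, C = 0.1, top-p = 0.5) and no top-k cut, the candidates are A and B and C is excluded; adding `k = 1` leaves only A, which is the greedy case.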
The top-k kernel is fused and also performs top-p sampling.
If `top_k > 0`, the top-k sampling layer is chosen; otherwise, the top-p sampling layer is chosen.
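The rule above can be sketched as a batch-level decision that reproduces the two skip conditions visible in the diffs. This is a sketch under assumptions: `decide_skips` and `LayerSkips` are hypothetical, the two vectors stand in for `runtime_top_k` / `runtime_top_p` (an empty vector meaning size 0), and it assumes that the `if` bodies in the diffs set `skip_all_ = true` for their layer.

```cpp
#include <algorithm>
#include <vector>

// Hypothetical result: whether each layer can skip its setup and forward for the whole batch.
struct LayerSkips {
    bool skip_topk;  // corresponds to the top-k layer's skip_all_
    bool skip_topp;  // corresponds to the top-p layer's skip_all_
};

LayerSkips decide_skips(const std::vector<unsigned>& top_ks, const std::vector<float>& top_ps) {
    const unsigned min_top_k = top_ks.empty() ? 0u : *std::min_element(top_ks.begin(), top_ks.end());
    const unsigned max_top_k = top_ks.empty() ? 0u : *std::max_element(top_ks.begin(), top_ks.end());
    const float min_top_p = top_ps.empty() ? 0.0f : *std::min_element(top_ps.begin(), top_ps.end());

    LayerSkips s{};
    // Top-p layer: nothing to do if no request sets top_p, or if every request
    // has top_k > 0 (the fused top-k kernel also applies top-p for those).
    s.skip_topp = top_ps.empty() || min_top_k > 0;
    // Top-k layer: nothing to do if every top_k is 0 and every top_p is non-zero,
    // i.e. all requests are pure top-p and are handled by the top-p layer.
    s.skip_topk = max_top_k == 0 && min_top_p != 0.0f;
    return s;
}
```

For a mixed batch (some requests with `top_k > 0`, some without), neither layer is skipped as a whole; any per-request filtering inside the layers is outside this sketch.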
ok
LGTM
if (!skip_init_sampling) {
g.max_init_ctx_len = max_context_len;
g.step = max_context_len;
InitializeSampling(g);
Is it necessary to move `InitializeSampling` from `Initialize` to `Forward`?
float min_top_p = runtime_top_p_size > 0 ? runtime_top_p.min<float>() : 0.0f;
skip_all_ = false;

if (max_top_k == 0 && min_top_p != 0.0f) {
What does `max_top_k == 0 && min_top_p == 0.0f` indicate?
I found the answer in `setup_topk_runtime_args`.
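A sketch of the kind of per-request normalization that would give `max_top_k == 0 && min_top_p == 0.0f` its meaning. This assumes `setup_topk_runtime_args` follows the FasterTransformer-style convention (both unset means greedy, expressed as `top_k = 1`); that convention is an assumption not confirmed from this PR, and `normalize_topk_args` / `TopKArgs` are hypothetical, not the kernel itself.

```cpp
// Hypothetical per-request arguments after normalization.
struct TopKArgs {
    unsigned k;
    float p;
    bool skip_decode;  // true: the top-k layer leaves this request to the top-p layer
};

// Assumed convention (not confirmed from this PR):
//  - top_k == 0 and top_p == 0: greedy decoding, expressed as top_k = 1;
//  - top_k > 0 and top_p == 0: plain top-k, expressed as top_p = 1 (no top-p cut);
//  - top_k == 0 and top_p > 0: pure top-p, skipped by the top-k layer.
TopKArgs normalize_topk_args(unsigned k, float p) {
    if (k == 0 && p == 0.0f) {
        k = 1;  // greedy: keep only the most probable token
    }
    if (k > 0 && p == 0.0f) {
        p = 1.0f;  // top-k only: the top-p threshold keeps the full mass
    }
    return TopKArgs{k, p, k == 0};
}
```

If this convention holds, it explains the skip condition in the diff: a single request with both `top_k` and `top_p` at zero still needs the top-k path for greedy decoding, so the top-k layer can only be skipped when every `top_k` is zero and every `top_p` is non-zero.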
What's the relationship between `top_k` and `n_logprobs`? Is `n_logprobs` guaranteed not to exceed `top_k`? I'm wondering whether `skip_all` affects returning `logits` or `logprobs`.
uint top_k = runtime_top_k.max<uint>();
float top_p = runtime_top_p_size == 0 ? 0.0f : runtime_top_p.getVal<float>();
// skip topk setup & forward if all top_k is zero and all top_p is not zero
Should we move this comment to just before `if (max_top_k == 0 && min_top_p != 0.0f)`?
LGTM.
I've written some suggestions in #1966.
Motivation
Set up the sampling layer after the forward pass to increase parallelism.
Tested with llama-2-7b-chat and 1000 prompts (`num_prompts`).
[Benchmark results: before / after]