diff --git a/crates/axum-server/src/prompt.rs b/crates/axum-server/src/prompt.rs
index 13a650d20..ebac2c01b 100644
--- a/crates/axum-server/src/prompt.rs
+++ b/crates/axum-server/src/prompt.rs
@@ -4,6 +4,9 @@ use db::queries::{chats, prompts};
 use db::{Chat, DatasetConnection, Transaction};
 use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};
 
+// If we are getting called from the API we'll possibly have a bunch of chat messages,
+// which is why the chat is a Vec.
+// For the UI there'll be just one.
 pub async fn execute_prompt(
     transaction: &Transaction<'_>,
     prompt_id: i32,
@@ -98,70 +101,77 @@ async fn generate_prompt(
         }
     };
 
+    // This is the space we have to fill
+    let size_allowed = model_context_size - max_tokens;
+    let mut size_so_far = 0;
+
+    // Add a system message if we have one
     if let Some(system_prompt) = &system_prompt {
-        messages.push(ChatCompletionRequestMessage {
-            role: "user".to_string(),
-            content: Some(system_prompt.clone()),
-            name: None,
-            function_call: None,
-        });
+        size_so_far = add_message(
+            &mut messages,
+            "system".to_string(),
+            system_prompt.clone(),
+            size_so_far,
+            size_allowed,
+        );
     }
 
+    // Add the messages that have come from the UI or the API.
+    // This may already overflow the context!!
     for message in question.into_iter() {
-        messages.push(ChatCompletionRequestMessage {
-            role: message.role,
-            content: Some(message.content),
-            name: None,
-            function_call: None,
-        });
+        size_so_far = add_message(
+            &mut messages,
+            message.role,
+            message.content,
+            size_so_far,
+            size_allowed,
+        );
     }
 
-    // This is the space we have to fill
-    let context_size = model_context_size - max_tokens;
-    let mut size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
-
     let mut related_context: Vec<&String> = related_context.iter().rev().collect();
 
     let mut context_so_far: String = Default::default();
 
     // Keep adding history and context until we meet the requirements of the prompt
-    while size_so_far < context_size {
+    while size_so_far < size_allowed {
         // Add some relevant context
         if let Some(rel_context) = related_context.pop() {
-            context_so_far.push_str(rel_context);
-            context_so_far += "\n";
-            if let Some(prompt) = &system_prompt {
-                let replaced = prompt.replace("{context_str}", &context_so_far);
-                messages[0].content = Some(replaced);
+            let size_rel_context = size_context(rel_context.to_string());
+
+            if size_so_far + size_rel_context < size_allowed {
+                context_so_far.push_str(rel_context);
+                context_so_far += "\n";
+                if let Some(prompt) = &system_prompt {
+                    let replaced = prompt.replace("{context_str}", &context_so_far);
+                    messages[0].content = Some(replaced);
+                }
+                size_so_far += size_rel_context;
             }
         }
 
-        size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
-
-        if size_so_far >= context_size {
-            break;
-        }
-
        // Add some history
        if let Some(hist) = history.pop() {
            // Add the history in before the last message
            if let Some(top_message) = messages.pop() {
-                messages.push(ChatCompletionRequestMessage {
-                    role: "user".to_string(),
-                    content: Some(hist.user_request),
-                    name: None,
-                    function_call: None,
-                });
-                messages.push(ChatCompletionRequestMessage {
-                    role: "assistant".to_string(),
-                    content: hist.response,
-                    name: None,
-                    function_call: None,
-                });
+                size_so_far = add_message(
+                    &mut messages,
+                    "user".to_string(),
+                    hist.user_request,
+                    size_so_far,
+                    size_allowed,
+                );
+                if let Some(response) = hist.response {
+                    size_so_far = add_message(
+                        &mut messages,
+                        "assistant".to_string(),
+                        response,
+                        size_so_far,
+                        size_allowed,
+                    );
+                }
                 messages.push(top_message);
             }
         }
 
-        size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
-
         if history.is_empty() && related_context.is_empty() {
             break;
         }
@@ -170,6 +180,41 @@ async fn generate_prompt(
     messages
 }
 
+fn size_context(context: String) -> usize {
+    let request = ChatCompletionRequestMessage {
+        role: "".to_string(),
+        content: Some(context),
+        name: None,
+        function_call: None,
+    };
+    num_tokens_from_messages("gpt-4", &[request]).unwrap()
+}
+
+// Only add a message if it doesn't overflow the allowed context
+fn add_message(
+    messages: &mut Vec<ChatCompletionRequestMessage>,
+    role: String,
+    content: String,
+    size_so_far: usize,
+    size_allowed: usize,
+) -> usize {
+    let request = ChatCompletionRequestMessage {
+        role,
+        content: Some(content),
+        name: None,
+        function_call: None,
+    };
+
+    let size = num_tokens_from_messages("gpt-4", &[request.clone()]).unwrap();
+
+    if (size + size_so_far) < size_allowed {
+        messages.push(request);
+        return size_so_far + size;
+    }
+
+    size_so_far
+}
+
 // Query the vector database using a similarity search.
 // The prompt decides how we use the datasets
 async fn get_related_context(
@@ -332,12 +377,38 @@ mod tests {
         assert!(messages.len() == 4);
 
-        dbg!(&messages);
-
         assert!(messages[0].content == Some("You are a helpful asistant\n\nContext information is below.\n--------------------\nThis might help\n\n--------------------".to_string()));
 
         assert!(messages[3].content == Some("How are you today?".to_string()));
     }
 
+    #[tokio::test]
+    async fn test_with_lots_of_context() {
+        let messages = generate_prompt(
+            2048,
+            1024,
+            Some("You are a helpful asistant".to_string()),
+            vec![create_prompt(
+                "What time is it?".to_string(),
+                "I don't know".repeat(400),
+            )],
+            vec![Message {
+                role: "user".to_string(),
+                content: "How are you today?".to_string(),
+            }],
+            vec![
+                "This might help".to_string(),
+                "word ".repeat(100),
+                "test ".repeat(100),
+                "name ".repeat(1000),
+            ],
+        )
+        .await;
+
+        let size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
+
+        assert!(size_so_far < 1024);
+    }
+
     fn create_prompt(question: String, answer: String) -> Chat {
         Chat {
             id: 0,
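
For reference, here is a standalone sketch of the budgeting behaviour this patch introduces. It is not part of the diff: add_message is copied as added above, while the main harness and its 64-token budget are made up for illustration (in the patch the budget is model_context_size - max_tokens).

use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

// Copied from the patch: append a message only if it fits the budget,
// returning the updated running token count either way.
fn add_message(
    messages: &mut Vec<ChatCompletionRequestMessage>,
    role: String,
    content: String,
    size_so_far: usize,
    size_allowed: usize,
) -> usize {
    let request = ChatCompletionRequestMessage {
        role,
        content: Some(content),
        name: None,
        function_call: None,
    };

    let size = num_tokens_from_messages("gpt-4", &[request.clone()]).unwrap();

    if (size + size_so_far) < size_allowed {
        messages.push(request);
        return size_so_far + size;
    }

    size_so_far
}

fn main() {
    // Illustrative budget only.
    let size_allowed = 64;
    let mut messages = Vec::new();

    // A short message fits and is appended.
    let size_so_far =
        add_message(&mut messages, "user".to_string(), "Hello".to_string(), 0, size_allowed);
    assert_eq!(messages.len(), 1);

    // A message far bigger than the budget is dropped, not truncated,
    // and the running total is unchanged.
    let size_so_far = add_message(
        &mut messages,
        "user".to_string(),
        "word ".repeat(500),
        size_so_far,
        size_allowed,
    );
    assert_eq!(messages.len(), 1);
    assert!(size_so_far < size_allowed);
}

Note the trade-off this makes explicit: anything that would overflow the budget is silently dropped rather than truncated, so an oversized question arriving from the API (the "may already overflow" case flagged in the comments) never reaches the model at all.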