fix: Tighten up on context size (bionic-gpt#156)
9876691 committed Dec 7, 2023
1 parent faaba88 commit 2a4cdc3
Showing 1 changed file with 114 additions and 43 deletions.
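
In essence, the change replaces the old "push every message, then count tokens" flow with a running token budget: the allowance is the model's context window minus the tokens reserved for the completion, and each candidate message is measured with tiktoken_rs before it is accepted. Below is a minimal sketch of that guard, assuming the tiktoken_rs API exactly as it appears in the diff; the helper name push_within_budget and the literal sizes are hypothetical, not part of the commit.

    use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

    // Hypothetical helper mirroring the commit's add_message: append a message
    // only if it keeps the running total inside the budget.
    fn push_within_budget(
        messages: &mut Vec<ChatCompletionRequestMessage>,
        role: &str,
        content: &str,
        size_so_far: usize,
        size_allowed: usize,
    ) -> usize {
        let candidate = ChatCompletionRequestMessage {
            role: role.to_string(),
            content: Some(content.to_string()),
            name: None,
            function_call: None,
        };
        // Count the candidate's tokens the same way the final request is counted.
        let size = num_tokens_from_messages("gpt-4", &[candidate.clone()]).unwrap();
        if size_so_far + size < size_allowed {
            messages.push(candidate);
            size_so_far + size
        } else {
            size_so_far // over budget: drop the message, leave the total unchanged
        }
    }

    fn main() {
        // Budget = context window minus the tokens reserved for the reply.
        let (model_context_size, max_tokens) = (2048usize, 1024usize);
        let size_allowed = model_context_size - max_tokens;

        let mut messages = Vec::new();
        let mut size_so_far = 0;
        size_so_far = push_within_budget(&mut messages, "system", "You are helpful.", size_so_far, size_allowed);
        size_so_far = push_within_budget(&mut messages, "user", "How are you?", size_so_far, size_allowed);
        assert!(size_so_far < size_allowed && messages.len() == 2);
    }

The commit applies this same pattern, as add_message, to the system prompt, the incoming question, and each pair of history messages in turn.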
157 changes: 114 additions & 43 deletions crates/axum-server/src/prompt.rs
@@ -4,6 +4,9 @@ use db::queries::{chats, prompts};
use db::{Chat, DatasetConnection, Transaction};
use tiktoken_rs::{num_tokens_from_messages, ChatCompletionRequestMessage};

+// If we are called from the API we may have a bunch of chat messages,
+// which is why chat is a Vec<Message>.
+// For the UI there will be just one.
pub async fn execute_prompt(
transaction: &Transaction<'_>,
prompt_id: i32,
@@ -98,70 +101,77 @@ async fn generate_prompt(
}
};

+// This is the space we have to fill
+let size_allowed = model_context_size - max_tokens;
+let mut size_so_far = 0;

// Add a system message if we have one
if let Some(system_prompt) = &system_prompt {
-messages.push(ChatCompletionRequestMessage {
-role: "user".to_string(),
-content: Some(system_prompt.clone()),
-name: None,
-function_call: None,
-});
+size_so_far = add_message(
+&mut messages,
+"system".to_string(),
+system_prompt.clone(),
+size_so_far,
+size_allowed,
+);
}

// Add the messages that have come from the UI or the API
// This may already overflow the context!!
for message in question.into_iter() {
-messages.push(ChatCompletionRequestMessage {
-role: message.role,
-content: Some(message.content),
-name: None,
-function_call: None,
-});
+size_so_far = add_message(
+&mut messages,
+message.role,
+message.content,
+size_so_far,
+size_allowed,
+);
}

-// This is the space we have to fill
-let context_size = model_context_size - max_tokens;
-let mut size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
let mut related_context: Vec<&String> = related_context.iter().rev().collect();
let mut context_so_far: String = Default::default();

// Keep adding history and context until we meet the size requirements of the prompt
-while size_so_far < context_size {
+while size_so_far < size_allowed {
// Add some relevant context
if let Some(rel_context) = related_context.pop() {
-context_so_far.push_str(rel_context);
-context_so_far += "\n";
-if let Some(prompt) = &system_prompt {
-let replaced = prompt.replace("{context_str}", &context_so_far);
-messages[0].content = Some(replaced);
+let size_rel_context = size_context(rel_context.to_string());
+
+if size_so_far + size_rel_context < size_allowed {
+context_so_far.push_str(rel_context);
+context_so_far += "\n";
+if let Some(prompt) = &system_prompt {
+let replaced = prompt.replace("{context_str}", &context_so_far);
+messages[0].content = Some(replaced);
+}
+size_so_far += size_rel_context;
+}
}

-size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
-
-if size_so_far >= context_size {
-break;
-}

// Add some history
if let Some(hist) = history.pop() {
// Add the history in before the last message
if let Some(top_message) = messages.pop() {
-messages.push(ChatCompletionRequestMessage {
-role: "user".to_string(),
-content: Some(hist.user_request),
-name: None,
-function_call: None,
-});
-messages.push(ChatCompletionRequestMessage {
-role: "assistant".to_string(),
-content: hist.response,
-name: None,
-function_call: None,
-});
+size_so_far = add_message(
+&mut messages,
+"user".to_string(),
+hist.user_request,
+size_so_far,
+size_allowed,
+);
+if let Some(response) = hist.response {
+size_so_far = add_message(
+&mut messages,
+"assistant".to_string(),
+response,
+size_so_far,
+size_allowed,
+);
+}
messages.push(top_message);
}
}

-size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();

if history.is_empty() && related_context.is_empty() {
break;
}
@@ -170,6 +180,41 @@ async fn generate_prompt(
messages
}

+fn size_context(context: String) -> usize {
+let request = ChatCompletionRequestMessage {
+role: "".to_string(),
+content: Some(context),
+name: None,
+function_call: None,
+};
+num_tokens_from_messages("gpt-4", &[request.clone()]).unwrap()
+}
+
+// Only add a message if the context doesn't overflow
+fn add_message(
+messages: &mut Vec<ChatCompletionRequestMessage>,
+role: String,
+content: String,
+size_so_far: usize,
+size_allowed: usize,
+) -> usize {
+let request = ChatCompletionRequestMessage {
+role,
+content: Some(content),
+name: None,
+function_call: None,
+};
+
+let size = num_tokens_from_messages("gpt-4", &[request.clone()]).unwrap();
+
+if (size + size_so_far) < size_allowed {
+messages.push(request);
+return size_so_far + size;
+}
+
+size_so_far
+}

// Query the vector database using a similarity search.
// The prompt decides how we use the datasets
async fn get_related_context(
Expand Down Expand Up @@ -332,12 +377,38 @@ mod tests {

assert!(messages.len() == 4);

+dbg!(&messages);

assert!(messages[0].content == Some("You are a helpful asistant\n\nContext information is below.\n--------------------\nThis might help\n\n--------------------".to_string()));
assert!(messages[3].content == Some("How are you today?".to_string()));
}

+#[tokio::test]
+async fn test_with_lots_of_context() {
+let messages = generate_prompt(
+2048,
+1024,
+Some("You are a helpful asistant".to_string()),
+vec![create_prompt(
+"What time is it?".to_string(),
+"I don't know".repeat(400),
+)],
+vec![Message {
+role: "user".to_string(),
+content: "How are you today?".to_string(),
+}],
+vec![
+"This might help".to_string(),
+"word ".repeat(100),
+"test ".repeat(100),
+"name ".repeat(1000),
+],
+)
+.await;
+
+let size_so_far = num_tokens_from_messages("gpt-4", &messages).unwrap();
+
+assert!(size_so_far < 1024);
+}

fn create_prompt(question: String, answer: String) -> Chat {
Chat {
id: 0,

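A final note on the fill loop in generate_prompt: the system prompt is treated as a template containing a {context_str} placeholder, and messages[0] is rewritten in place each time another chunk of retrieved context fits inside the budget. A small self-contained sketch of that substitution, assuming the template syntax shown in the diff (the template text and chunk strings here are made up):

    fn main() {
        // Hypothetical template; the "{context_str}" placeholder comes from the diff.
        let system_prompt =
            "You are a helpful assistant.\n\nContext information is below.\n{context_str}".to_string();

        let mut context_so_far = String::new();
        for chunk in ["This might help", "Another relevant chunk"] {
            // Mirrors the loop: accumulate accepted chunks, newline-separated.
            context_so_far.push_str(chunk);
            context_so_far += "\n";
        }

        // The first message (the system prompt) is replaced wholesale each round.
        let replaced = system_prompt.replace("{context_str}", &context_so_far);
        assert!(replaced.contains("Another relevant chunk"));
    }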