From b36edc10fa46bddf366486728d94cf58d8a1a089 Mon Sep 17 00:00:00 2001 From: 9876691 <36966+9876691@users.noreply.github.com> Date: Wed, 21 Feb 2024 15:58:44 +0100 Subject: [PATCH] fix: Use embeddings model attached to dataset. (#378) --- .devcontainer/.bash_aliases | 2 +- crates/asset-pipeline/input.css | 2 +- crates/axum-server/src/prompt.rs | 37 +++--- crates/axum-server/src/prompts/form.rs | 31 ++--- crates/axum-server/src/team/new_team.rs | 4 +- crates/daisy-rsx/check_box.rs | 105 ++++++++++++++++ crates/daisy-rsx/lib.rs | 2 + crates/db/authz.rs | 4 +- crates/db/lib.rs | 3 +- .../20240221131717_refactor_prompts.sql | 14 +++ crates/db/queries/chunks.sql | 18 ++- crates/db/queries/datasets.sql | 4 + crates/db/queries/prompts.sql | 31 ++--- crates/db/vector_search.rs | 83 +++--------- crates/embeddings-api/src/lib.rs | 24 ++-- crates/pipeline-job/src/main.rs | 7 +- crates/ui-pages/prompts/dataset_connection.rs | 29 ----- crates/ui-pages/prompts/form.rs | 119 ++++++++---------- crates/ui-pages/prompts/index.rs | 14 +-- crates/ui-pages/prompts/mod.rs | 18 --- 20 files changed, 283 insertions(+), 268 deletions(-) create mode 100644 crates/daisy-rsx/check_box.rs create mode 100644 crates/db/migrations/20240221131717_refactor_prompts.sql delete mode 100644 crates/ui-pages/prompts/dataset_connection.rs diff --git a/.devcontainer/.bash_aliases b/.devcontainer/.bash_aliases index 0a6c3795e..ce55c1c53 100644 --- a/.devcontainer/.bash_aliases +++ b/.devcontainer/.bash_aliases @@ -27,7 +27,7 @@ alias watch-app='mold -run cargo watch --workdir /workspace/ -w crates/daisy-rsx alias wa=watch-app alias watch-pipeline='npm install --prefix /workspace/crates/asset-pipeline && npm run start --prefix /workspace/crates/asset-pipeline' alias wp=watch-pipeline -alias watch-embeddings='mold -run cargo watch --workdir /workspace/ -w crates/open-api -w crates/pipeline-job --no-gitignore -x "run --bin pipeline-job"' +alias watch-embeddings='mold -run cargo watch --workdir /workspace/ -w crates/embeddings-api -w crates/pipeline-job --no-gitignore -x "run --bin pipeline-job"' alias we=watch-embeddings alias watch-tailwind='cd /workspace/crates/asset-pipeline && npx tailwindcss -i ./input.css -o ./dist/output.css --watch' alias wt=watch-tailwind diff --git a/crates/asset-pipeline/input.css b/crates/asset-pipeline/input.css index b86b40a4d..1effb2086 100644 --- a/crates/asset-pipeline/input.css +++ b/crates/asset-pipeline/input.css @@ -41,7 +41,7 @@ a { } /** In the chat the first p has margin, we need to remove it **/ -response-formatter > p:first-child, streaming-chat > p:first-child { +.response-formatter > p:first-child, streaming-chat > p:first-child { margin-top: 0; } diff --git a/crates/axum-server/src/prompt.rs b/crates/axum-server/src/prompt.rs index be92be89d..d5467ce9c 100644 --- a/crates/axum-server/src/prompt.rs +++ b/crates/axum-server/src/prompt.rs @@ -27,22 +27,27 @@ pub async fn execute_prompt( }; // Turn the users message into something the vector database can use - let embeddings = embeddings_api::get_embeddings(&question) - .await - .map_err(|e| CustomError::ExternalApi(e.to_string()))?; - - tracing::info!(prompt.name); - // Get related context - let related_context = db::get_related_context( - transaction, - prompt.dataset_connection, - prompt_id, - team_id, - prompt.max_chunks, - embeddings, - ) - .await?; - tracing::info!("Retrieved {} chunks", related_context.len()); + let mut related_context = Default::default(); + if let (Some(embeddings_base_url), Some(embeddings_model)) = + (prompt.embeddings_base_url, prompt.embeddings_model) + { + let embeddings = + embeddings_api::get_embeddings(&question, &embeddings_base_url, &embeddings_model) + .await + .map_err(|e| CustomError::ExternalApi(e.to_string()))?; + + tracing::info!(prompt.name); + // Get related context + related_context = db::get_related_context( + transaction, + prompt_id, + team_id, + prompt.max_chunks, + embeddings, + ) + .await?; + tracing::info!("Retrieved {} chunks", related_context.len()); + } // Get the maximum required amount of chat history let chat_history = if let Some(conversation_id) = conversation_id { diff --git a/crates/axum-server/src/prompts/form.rs b/crates/axum-server/src/prompts/form.rs index d9232d152..2efd74a5b 100644 --- a/crates/axum-server/src/prompts/form.rs +++ b/crates/axum-server/src/prompts/form.rs @@ -7,7 +7,7 @@ use db::authz; use db::Pool; use db::{queries, Transaction}; use serde::Deserialize; -use ui_pages::{prompts::string_to_dataset_connection, string_to_visibility}; +use ui_pages::string_to_visibility; use validator::Validate; #[derive(Deserialize, Validate, Default, Debug)] @@ -16,15 +16,14 @@ pub struct NewPromptTemplate { #[validate(length(min = 1, message = "The name is mandatory"))] pub name: String, pub system_prompt: String, - pub dataset_connection: String, pub model_id: i32, - pub datasets: Option, + #[serde(default)] + pub datasets: Vec, pub max_history_items: i32, pub max_chunks: i32, pub max_tokens: i32, pub trim_ratio: i32, pub temperature: f32, - pub top_p: f32, pub visibility: String, } @@ -57,14 +56,12 @@ pub async fn upsert( &new_prompt_template.model_id, &new_prompt_template.name, &visibility, - &string_to_dataset_connection(&new_prompt_template.dataset_connection), &system_prompt, &new_prompt_template.max_history_items, &new_prompt_template.max_chunks, &new_prompt_template.max_tokens, &new_prompt_template.trim_ratio, &new_prompt_template.temperature, - &new_prompt_template.top_p, &id, ) .await?; @@ -92,14 +89,12 @@ pub async fn upsert( &new_prompt_template.model_id, &new_prompt_template.name, &visibility, - &string_to_dataset_connection(&new_prompt_template.dataset_connection), &system_prompt, &new_prompt_template.max_history_items, &new_prompt_template.max_chunks, &new_prompt_template.max_tokens, &new_prompt_template.trim_ratio, &new_prompt_template.temperature, - &new_prompt_template.top_p, ) .one() .await?; @@ -126,23 +121,13 @@ pub async fn upsert( async fn insert_datasets( transaction: &Transaction<'_>, prompt_id: i32, - datasets: Option, + datasets: Vec, ) -> Result<(), CustomError> { // Create the connections to any datasets - if let Some(datasets) = datasets { - // The environments we have selected for the ser come in as a comma - // separated list of ids. - let datasets: Vec = datasets - .split(',') - .map(|e| e.parse::().unwrap_or(-1)) - .filter(|e| *e != -1) - .collect(); - - for dataset in datasets { - queries::prompts::insert_prompt_dataset() - .bind(transaction, &prompt_id, &dataset) - .await?; - } + for dataset in datasets { + queries::prompts::insert_prompt_dataset() + .bind(transaction, &prompt_id, &dataset) + .await?; } Ok(()) diff --git a/crates/axum-server/src/team/new_team.rs b/crates/axum-server/src/team/new_team.rs index 11687ed98..0f14004ed 100644 --- a/crates/axum-server/src/team/new_team.rs +++ b/crates/axum-server/src/team/new_team.rs @@ -7,7 +7,7 @@ use axum::{ use db::authz; use db::types; use db::Pool; -use db::{queries, DatasetConnection, Visibility}; +use db::{queries, Visibility}; use serde::Deserialize; use validator::Validate; @@ -59,14 +59,12 @@ pub async fn new_team( &model.id, &"Default (Exclude All Datasets)", &Visibility::Private, - &DatasetConnection::None, &system_prompt, &3, &10, &1024, &80, &0.7, - &0.1, ) .one() .await?; diff --git a/crates/daisy-rsx/check_box.rs b/crates/daisy-rsx/check_box.rs new file mode 100644 index 000000000..e30725936 --- /dev/null +++ b/crates/daisy-rsx/check_box.rs @@ -0,0 +1,105 @@ +#![allow(non_snake_case)] +use dioxus::prelude::*; + +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] +pub enum CheckBoxScheme { + #[default] + Default, + Primary, + Outline, + Danger, +} + +impl CheckBoxScheme { + pub fn to_string(&self) -> &'static str { + match self { + CheckBoxScheme::Default => "checkbox-default", + CheckBoxScheme::Primary => "checkbox-primary", + CheckBoxScheme::Outline => "checkbox-outline", + CheckBoxScheme::Danger => "checkbox-warning", + } + } +} + +#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)] +pub enum CheckBoxSize { + #[default] + Default, + Small, + ExtraSmall, + Large, + Medium, +} + +impl CheckBoxSize { + pub fn to_string(&self) -> &'static str { + match self { + CheckBoxSize::Default => "checkbox-sm", + CheckBoxSize::ExtraSmall => "checkbox-xs", + CheckBoxSize::Small => "checkbox-sm", + CheckBoxSize::Medium => "checkbox-md", + CheckBoxSize::Large => "checkbox-lg", + } + } +} + +#[derive(Props)] +pub struct CheckBoxProps<'a> { + children: Element<'a>, + id: Option<&'a str>, + checked: Option, + class: Option<&'a str>, + name: &'a str, + value: &'a str, + checkbox_size: Option, + checkbox_scheme: Option, +} + +pub fn CheckBox<'a>(cx: Scope<'a, CheckBoxProps<'a>>) -> Element { + let checkbox_scheme = if cx.props.checkbox_scheme.is_some() { + cx.props.checkbox_scheme.unwrap() + } else { + Default::default() + }; + + let checkbox_size = if cx.props.checkbox_size.is_some() { + cx.props.checkbox_size.unwrap() + } else { + Default::default() + }; + + let class = if let Some(class) = cx.props.class { + class + } else { + "" + }; + + let checked = if let Some(checked) = cx.props.checked { + if checked { + Some("checked") + } else { + None + } + } else { + None + }; + + let class = format!( + "checkbox {} {} {}", + class, + checkbox_scheme.to_string(), + checkbox_size.to_string() + ); + + cx.render(rsx!( + input { + "type": "checkbox", + class: "{class}", + id: cx.props.id, + name: cx.props.name, + value: cx.props.value, + checked: checked, + &cx.props.children, + } + )) +} diff --git a/crates/daisy-rsx/lib.rs b/crates/daisy-rsx/lib.rs index f07aba351..98012ab0b 100644 --- a/crates/daisy-rsx/lib.rs +++ b/crates/daisy-rsx/lib.rs @@ -4,6 +4,7 @@ pub mod avatar; pub mod blank_slate; pub mod button; pub mod card; +pub mod check_box; pub mod drawer; pub mod drop_down; pub mod input; @@ -24,6 +25,7 @@ pub use avatar::{Avatar, AvatarSize, AvatarType}; pub use blank_slate::BlankSlate; pub use button::{Button, ButtonScheme, ButtonSize, ButtonType}; pub use card::{Box, BoxBody, BoxHeader}; +pub use check_box::{CheckBox, CheckBoxScheme, CheckBoxSize}; pub use drawer::{Drawer, DrawerBody, DrawerFooter}; pub use drop_down::{Direction, DropDown, DropDownLink}; pub use input::{Input, InputSize, InputType}; diff --git a/crates/db/authz.rs b/crates/db/authz.rs index cfeab6d95..a5cc4d099 100644 --- a/crates/db/authz.rs +++ b/crates/db/authz.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use crate::queries; -use crate::{types, DatasetConnection, Permission, Transaction, Visibility}; +use crate::{types, Permission, Transaction, Visibility}; #[derive(Serialize, Deserialize, Debug)] pub struct Authentication { @@ -114,14 +114,12 @@ pub async fn setup_user_if_not_already_registered( &model.id, &"Default (Exclude All Datasets)", &Visibility::Private, - &DatasetConnection::None, &system_prompt, &3, &10, &1024, &100, &0.7, - &0.1, ) .one() .await?; diff --git a/crates/db/lib.rs b/crates/db/lib.rs index 468ec096f..da0b3e133 100644 --- a/crates/db/lib.rs +++ b/crates/db/lib.rs @@ -24,8 +24,7 @@ pub use queries::teams::GetUsers as Member; pub use queries::teams::{Team, TeamOwner}; pub use queries::users::User; pub use types::public::{ - AuditAccessType, AuditAction, ChatStatus, DatasetConnection, ModelType, Permission, Role, - Visibility, + AuditAccessType, AuditAction, ChatStatus, ModelType, Permission, Role, Visibility, }; pub use vector_search::{get_related_context, RelatedContext}; diff --git a/crates/db/migrations/20240221131717_refactor_prompts.sql b/crates/db/migrations/20240221131717_refactor_prompts.sql new file mode 100644 index 000000000..58758c761 --- /dev/null +++ b/crates/db/migrations/20240221131717_refactor_prompts.sql @@ -0,0 +1,14 @@ +-- migrate:up +ALTER TABLE prompts DROP COLUMN dataset_connection; +ALTER TABLE prompts DROP COLUMN top_p; +DROP TYPE dataset_connection; + +-- migrate:down +CREATE TYPE dataset_connection AS ENUM ( + 'All', + 'None', + 'Selected' +); +COMMENT ON TYPE dataset_connection IS 'A prompt can use all datasets, no datasets or selected datasets.'; +ALTER TABLE prompts ADD COLUMN dataset_connection dataset_connection NOT NULL DEFAULT 'None'; +ALTER TABLE prompts ADD COLUMN top_p REAL CHECK (top_p >= 0 AND top_p <= 1); \ No newline at end of file diff --git a/crates/db/queries/chunks.sql b/crates/db/queries/chunks.sql index babe96699..3374ba939 100644 --- a/crates/db/queries/chunks.sql +++ b/crates/db/queries/chunks.sql @@ -1,7 +1,23 @@ --! unprocessed_chunks : Chunk() SELECT id, - text + text, + (SELECT + base_url + FROM + models + WHERE + id IN (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN + (SELECT dataset_id FROM documents d WHERE d.id = document_id)) + ) as base_url, + (SELECT + name + FROM + models + WHERE + id IN (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN + (SELECT dataset_id FROM documents d WHERE d.id = document_id)) + ) as model FROM chunks WHERE diff --git a/crates/db/queries/datasets.sql b/crates/db/queries/datasets.sql index 71b996182..6032cdbb3 100644 --- a/crates/db/queries/datasets.sql +++ b/crates/db/queries/datasets.sql @@ -9,6 +9,7 @@ SELECT new_after_n_chars, multipage_sections, (SELECT COUNT(id) FROM documents WHERE dataset_id = d.id) as count, + (SELECT name FROM models WHERE id = d.embeddings_model_id) as embeddings_model_name, created_at, updated_at FROM @@ -30,6 +31,7 @@ SELECT new_after_n_chars, multipage_sections, (SELECT COUNT(id) FROM documents WHERE dataset_id = d.id) as count, + (SELECT name FROM models WHERE id = d.embeddings_model_id) as embeddings_model_name, created_at, updated_at FROM @@ -50,6 +52,7 @@ SELECT new_after_n_chars, multipage_sections, (SELECT COUNT(id) FROM documents WHERE dataset_id = d.id) as count, + (SELECT name FROM models WHERE id = d.embeddings_model_id) as embeddings_model_name, created_at, updated_at FROM @@ -74,6 +77,7 @@ SELECT new_after_n_chars, multipage_sections, (SELECT COUNT(id) FROM documents WHERE dataset_id = d.id) as count, + (SELECT name FROM models WHERE id = d.embeddings_model_id) as embeddings_model_name, created_at, updated_at FROM diff --git a/crates/db/queries/prompts.sql b/crates/db/queries/prompts.sql index 06a538c1b..7e6c0fa1a 100644 --- a/crates/db/queries/prompts.sql +++ b/crates/db/queries/prompts.sql @@ -1,4 +1,5 @@ ---: Prompt(temperature?, top_p?, system_prompt?) +--: Prompt(temperature?, system_prompt?) +--: SinglePrompt(temperature?, system_prompt?, embeddings_base_url?, embeddings_model?) --! prompts : Prompt SELECT @@ -10,7 +11,6 @@ SELECT p.model_id, p.name, p.visibility, - p.dataset_connection, -- Creata a string showing the datsets connected to this prompt ( SELECT @@ -32,7 +32,6 @@ SELECT p.max_tokens, p.trim_ratio, p.temperature, - p.top_p, -- Convert times to ISO 8601 string. trim(both '"' from to_json(p.created_at)::text) as created_at, trim(both '"' from to_json(p.updated_at)::text) as updated_at @@ -50,17 +49,22 @@ WHERE OR p.visibility='Company' ORDER BY updated_at; ---! prompt : Prompt +--! prompt : SinglePrompt SELECT p.id, (SELECT name FROM models WHERE id = p.model_id) as model_name, (SELECT base_url FROM models WHERE id = p.model_id) as base_url, (SELECT context_size FROM models WHERE id = p.model_id) as model_context_size, - (SELECT team_id FROM models WHERE id = p.model_id) as team_id, + (SELECT team_id FROM models WHERE id = p.model_id) as team_id, + (SELECT base_url FROM models WHERE id IN + (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN + (SELECT dataset_id FROM prompt_dataset WHERE prompt_id = p.id LIMIT 1))) as embeddings_base_url, + (SELECT name FROM models WHERE id IN + (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN + (SELECT dataset_id FROM prompt_dataset WHERE prompt_id = p.id LIMIT 1))) as embeddings_model, p.model_id, p.name, p.visibility, - p.dataset_connection, -- Creata a string showing the datsets connected to this prompt ( SELECT @@ -82,7 +86,6 @@ SELECT p.max_tokens, p.trim_ratio, p.temperature, - p.top_p, -- Convert times to ISO 8601 string. trim(both '"' from to_json(p.created_at)::text) as created_at, trim(both '"' from to_json(p.updated_at)::text) as updated_at @@ -112,7 +115,6 @@ SELECT p.model_id, p.name, p.visibility, - p.dataset_connection, -- Creata a string showing the datsets connected to this prompt ( SELECT @@ -134,7 +136,6 @@ SELECT p.max_tokens, p.trim_ratio, p.temperature, - p.top_p, -- Convert times to ISO 8601 string. trim(both '"' from to_json(p.created_at)::text) as created_at, trim(both '"' from to_json(p.updated_at)::text) as updated_at @@ -197,28 +198,24 @@ INSERT INTO prompts ( model_id, name, visibility, - dataset_connection, system_prompt, max_history_items, max_chunks, max_tokens, trim_ratio, - temperature, - top_p + temperature ) VALUES( :team_id, :model_id, :name, :visibility, - :dataset_connection, :system_prompt, :max_history_items, :max_chunks, :max_tokens, :trim_ratio, - :temperature, - :top_p + :temperature ) RETURNING id; @@ -229,14 +226,12 @@ SET model_id = :model_id, name = :name, visibility = :visibility, - dataset_connection = :dataset_connection, system_prompt = :system_prompt, max_history_items = :max_history_items, max_chunks = :max_chunks, max_tokens = :max_tokens, trim_ratio = :trim_ratio, - temperature = :temperature, - top_p = :top_p + temperature = :temperature WHERE id = :id AND diff --git a/crates/db/vector_search.rs b/crates/db/vector_search.rs index 4eb82d21d..52ae0d113 100644 --- a/crates/db/vector_search.rs +++ b/crates/db/vector_search.rs @@ -1,6 +1,6 @@ use crate::queries::prompts; use crate::TokioPostgresError; -use crate::{DatasetConnection, Transaction}; +use crate::Transaction; pub struct RelatedContext { pub chunk_id: i32, @@ -11,16 +11,11 @@ pub struct RelatedContext { // The prompt decides how we use the datasets pub async fn get_related_context( transaction: &Transaction<'_>, - dataset_connection: DatasetConnection, prompt_id: i32, team_id: i32, limit: i32, embeddings: Vec, ) -> Result, TokioPostgresError> { - if dataset_connection == DatasetConnection::None { - return Ok(Default::default()); - } - // Which datasets does the prompt use let datasets = prompts::prompt_datasets() .bind(transaction, &prompt_id) @@ -32,51 +27,10 @@ pub async fn get_related_context( // Format the embeddings in PGVector format let embedding_data = pgvector::Vector::from(embeddings.clone()); - match dataset_connection { - DatasetConnection::None => Ok(Default::default()), - DatasetConnection::All => { - // Find sections of documents that are related to the users question - let related_context = transaction - .query( - " - SELECT - id, - text - FROM - chunks - WHERE - document_id IN ( - SELECT id FROM documents WHERE dataset_id IN ( - SELECT id FROM datasets WHERE team_id IN ( - SELECT team_id FROM team_users - WHERE user_id = current_app_user() - AND team_id = $1 - ) - ) - ) - ORDER BY - embeddings <-> $2 - LIMIT $3; - ", - &[&team_id, &embedding_data, &(limit as i64)], - ) - .await?; - - // Just get the text from the returned rows - let related_context: Vec = related_context - .into_iter() - .map(|content| RelatedContext { - chunk_id: content.get(0), - chunk_text: content.get(1), - }) - .collect(); - Ok(related_context) - } - DatasetConnection::Selected => { - // Find sections of documents that are related to the users question - let related_context = transaction - .query( - " + // Find sections of documents that are related to the users question + let related_context = transaction + .query( + " SELECT id, text @@ -97,19 +51,18 @@ pub async fn get_related_context( embeddings <-> $3 LIMIT $4; ", - &[&team_id, &datasets, &embedding_data, &(limit as i64)], - ) - .await?; + &[&team_id, &datasets, &embedding_data, &(limit as i64)], + ) + .await?; + + // Just get the text from the returned rows + let related_context: Vec = related_context + .into_iter() + .map(|content| RelatedContext { + chunk_id: content.get(0), + chunk_text: content.get(1), + }) + .collect(); - // Just get the text from the returned rows - let related_context: Vec = related_context - .into_iter() - .map(|content| RelatedContext { - chunk_id: content.get(0), - chunk_text: content.get(1), - }) - .collect(); - Ok(related_context) - } - } + Ok(related_context) } diff --git a/crates/embeddings-api/src/lib.rs b/crates/embeddings-api/src/lib.rs index 42d24ae0d..b93af5c59 100644 --- a/crates/embeddings-api/src/lib.rs +++ b/crates/embeddings-api/src/lib.rs @@ -43,25 +43,23 @@ pub struct Embedding { pub text: String, } -pub async fn get_embeddings(input: &str) -> Result, Box> { +pub async fn get_embeddings( + input: &str, + api_end_point: &str, + model: &str, +) -> Result, Box> { let client = Client::new(); - let openai_endpoint = if let Ok(domain) = std::env::var("EMBEDDINGS_API_ENDPOINT") { - domain - } else { - "http://embeddings-api:80/embeddings".to_string() - }; - let text = String::from_utf8_lossy(input.as_bytes()).to_string(); let calling_json = EmbeddingRequest { input: text.clone(), - model: "text-embedding-ada-002".to_string(), + model: model.to_string(), user: None, }; //send request let response = client - .post(openai_endpoint) + .post(api_end_point) .json(&calling_json) .send() .await?; @@ -83,7 +81,13 @@ mod tests { #[tokio::test] async fn test_get_embeddings() { let input = "The food was delicious and the waiter...".to_string(); - let embeddings = get_embeddings(&input).await.unwrap(); + let embeddings = get_embeddings( + &input, + "http://embeddings-api:80/embeddings", + "text-embedding-ada-002", + ) + .await + .unwrap(); println!("{:#?}", embeddings); } diff --git a/crates/pipeline-job/src/main.rs b/crates/pipeline-job/src/main.rs index 187c79d73..3a7812c80 100644 --- a/crates/pipeline-job/src/main.rs +++ b/crates/pipeline-job/src/main.rs @@ -78,7 +78,12 @@ async fn main() -> Result<(), Box> { } if let Some(embedding) = unprocessed.get(0) { - let embeddings = embeddings_api::get_embeddings(&embedding.text).await; + let embeddings = embeddings_api::get_embeddings( + &embedding.text, + &embedding.base_url, + &embedding.model, + ) + .await; if let Ok(embeddings) = embeddings { let embedding_data = pgvector::Vector::from(embeddings); client diff --git a/crates/ui-pages/prompts/dataset_connection.rs b/crates/ui-pages/prompts/dataset_connection.rs deleted file mode 100644 index 64218c05d..000000000 --- a/crates/ui-pages/prompts/dataset_connection.rs +++ /dev/null @@ -1,29 +0,0 @@ -#![allow(non_snake_case)] -use daisy_rsx::*; -use db::DatasetConnection; -use dioxus::prelude::*; - -#[inline_props] -pub fn DatasetConnection(cx: Scope, connection: DatasetConnection) -> Element { - match connection { - DatasetConnection::All => cx.render(rsx!( - Label { - class: "mr-2", - label_role: LabelRole::Highlight, - "Use All the Teams Datasets" - } - )), - DatasetConnection::None => cx.render(rsx!( - Label { - class: "mr-2", - label_role: LabelRole::Highlight, - "Don't use any datasets" - } - )), - DatasetConnection::Selected => cx.render(rsx!(Label { - class: "mr-2", - label_role: LabelRole::Highlight, - "Use selected Datasets" - })), - } -} diff --git a/crates/ui-pages/prompts/form.rs b/crates/ui-pages/prompts/form.rs index 5968534c3..40ddb1f22 100644 --- a/crates/ui-pages/prompts/form.rs +++ b/crates/ui-pages/prompts/form.rs @@ -1,7 +1,6 @@ #![allow(non_snake_case)] -use super::dataset_connection_to_string; use daisy_rsx::{select::SelectOption, *}; -use db::{Dataset, DatasetConnection, Model, Visibility}; +use db::{Dataset, Model, Visibility}; use dioxus::prelude::*; #[inline_props] @@ -13,7 +12,6 @@ pub fn Form( system_prompt: String, datasets: Vec, selected_dataset_ids: Vec, - dataset_connection: DatasetConnection, models: Vec, model_id: i32, visibility: Visibility, @@ -23,7 +21,6 @@ pub fn Form( max_tokens: i32, trim_ratio: i32, temperature: f32, - top_p: f32, ) -> Element { cx.render(rsx!( form { @@ -116,64 +113,57 @@ pub fn Form( tab_name: "Datasets", div { class: "flex flex-col mt-3", - Select { - name: "dataset_connection", - label: "How shall we handle datasets with this prompt?", - help_text: "The prompt will be passed to the model", - value: "{dataset_connection_to_string(*dataset_connection)}", - required: true, - SelectOption { - value: "{dataset_connection_to_string(DatasetConnection::All)}", - selected_value: "{dataset_connection_to_string(*dataset_connection)}", - dataset_connection_to_string(DatasetConnection::All) - } - SelectOption { - value: "{dataset_connection_to_string(DatasetConnection::None)}", - selected_value: "{dataset_connection_to_string(*dataset_connection)}", - dataset_connection_to_string(DatasetConnection::None) - } - SelectOption { - value: "{dataset_connection_to_string(DatasetConnection::Selected)}", - selected_value: "{dataset_connection_to_string(*dataset_connection)}", - dataset_connection_to_string(DatasetConnection::Selected) - } + Alert { + class: "mb-4", + "Select which datasets you wish to attach to this prompt" } - - Select { - name: "datasets", - label: "Select datasets to connect to this prompt", - label_class: "mt-4", - help_text: "These datasets will only be used when the above is set to 'Use Selected Datasets'", - value: &name, - multiple: true, - datasets.iter().map(|dataset| { - if selected_dataset_ids.contains(&dataset.id) { - cx.render(rsx!( - option { - value: "{dataset.id}", - selected: true, - "{dataset.name}" - } - )) - } else { - cx.render(rsx!( - option { - value: "{dataset.id}", - "{dataset.name}" - } - )) + table { + class: "table table-sm", + thead { + tr { + th { + "Dataset" + } + th { + "Model" + } + th { + "Add?" + } } - }) - } - - Input { - input_type: InputType::Number, - name: "max_chunks", - label: "Maximum number of Chunks", - label_class: "mt-4", - help_text: "We don't add more chunks to the prompt than this.", - value: "{*max_chunks}", - required: true + } + tbody { + datasets.iter().map(|dataset| { + rsx!( + tr { + td { + "{dataset.name}" + } + td { + "{dataset.embeddings_model_name}" + } + td { + if selected_dataset_ids.contains(&dataset.id) { + cx.render(rsx!( + CheckBox { + checked: true, + name: "datasets", + value: "{dataset.id}" + } + )) + } else { + cx.render(rsx!( + CheckBox { + name: "datasets", + value: "{dataset.id}" + } + )) + } + } + } + ) + }) + } } } } @@ -249,12 +239,11 @@ pub fn Form( Input { input_type: InputType::Number, - step: "0.1", - name: "top_p", - label: "Alternative to Temperature", + name: "max_chunks", + label: "Maximum number of Chunks", label_class: "mt-4", - help_text: "Value between 0 and 2.", - value: "{*top_p}", + help_text: "We don't add more chunks to the prompt than this.", + value: "{*max_chunks}", required: true } } diff --git a/crates/ui-pages/prompts/index.rs b/crates/ui-pages/prompts/index.rs index 074c66196..7c2f18a10 100644 --- a/crates/ui-pages/prompts/index.rs +++ b/crates/ui-pages/prompts/index.rs @@ -3,7 +3,7 @@ use crate::app_layout::{Layout, SideBar}; use assets::files::*; use daisy_rsx::*; use db::authz::Rbac; -use db::{queries::prompts::Prompt, Dataset, DatasetConnection, Model, Visibility}; +use db::{queries::prompts::Prompt, Dataset, Model, Visibility}; use dioxus::prelude::*; #[inline_props] @@ -57,7 +57,6 @@ pub fn Page( class: "table table-sm", thead { th { "Name" } - th { "Dataset(s)" } th { "Visibility" } th { "Model" } th { "Updated" } @@ -74,11 +73,6 @@ pub fn Page( td { "{prompt.name}" } - td { - super::dataset_connection::DatasetConnection { - connection: prompt.dataset_connection - } - } td { super::visibility::VisLabel { visibility: prompt.visibility @@ -141,7 +135,6 @@ pub fn Page( system_prompt: prompt.system_prompt.clone().unwrap_or("".to_string()), datasets: datasets.clone(), selected_dataset_ids: split_datasets(&prompt.selected_datasets), - dataset_connection: prompt.dataset_connection, visibility: prompt.visibility, models: models.clone(), model_id: prompt.model_id, @@ -150,7 +143,6 @@ pub fn Page( max_tokens: prompt.max_tokens, trim_ratio: prompt.trim_ratio, temperature: prompt.temperature.unwrap_or(0.7), - top_p: prompt.top_p.unwrap_or(0.0), } )) }) @@ -164,7 +156,6 @@ pub fn Page( name: "".to_string(), system_prompt: "".to_string(), datasets: datasets.clone(), - dataset_connection: DatasetConnection::None, selected_dataset_ids: Default::default(), models: models.clone(), visibility: Visibility::Private, @@ -173,8 +164,7 @@ pub fn Page( max_chunks: 10, max_tokens: 1024, trim_ratio: 80, - temperature: 0.7, - top_p: 0.0, + temperature: 0.7 } } }) diff --git a/crates/ui-pages/prompts/mod.rs b/crates/ui-pages/prompts/mod.rs index e6b89d130..8f7ff655a 100644 --- a/crates/ui-pages/prompts/mod.rs +++ b/crates/ui-pages/prompts/mod.rs @@ -1,24 +1,6 @@ -pub mod dataset_connection; pub mod delete; pub mod form; pub mod index; pub mod visibility; -use db::DatasetConnection; pub use index::index; - -pub fn dataset_connection_to_string(connection: DatasetConnection) -> String { - match connection { - DatasetConnection::All => "All".to_string(), - DatasetConnection::Selected => "Selected".to_string(), - _ => "None".to_string(), - } -} - -pub fn string_to_dataset_connection(connection: &str) -> DatasetConnection { - match connection { - "All" => DatasetConnection::All, - "Selected" => DatasetConnection::Selected, - _ => DatasetConnection::None, - } -}