[go: nahoru, domu]

Skip to content

Commit

Permalink
fix: Use embeddings model attached to dataset. (bionic-gpt#378)
Browse files Browse the repository at this point in the history
  • Loading branch information
9876691 committed Feb 21, 2024
1 parent b3000c2 commit b36edc1
Show file tree
Hide file tree
Showing 20 changed files with 283 additions and 268 deletions.
2 changes: 1 addition & 1 deletion .devcontainer/.bash_aliases
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ alias watch-app='mold -run cargo watch --workdir /workspace/ -w crates/daisy-rsx
alias wa=watch-app
alias watch-pipeline='npm install --prefix /workspace/crates/asset-pipeline && npm run start --prefix /workspace/crates/asset-pipeline'
alias wp=watch-pipeline
alias watch-embeddings='mold -run cargo watch --workdir /workspace/ -w crates/open-api -w crates/pipeline-job --no-gitignore -x "run --bin pipeline-job"'
alias watch-embeddings='mold -run cargo watch --workdir /workspace/ -w crates/embeddings-api -w crates/pipeline-job --no-gitignore -x "run --bin pipeline-job"'
alias we=watch-embeddings
alias watch-tailwind='cd /workspace/crates/asset-pipeline && npx tailwindcss -i ./input.css -o ./dist/output.css --watch'
alias wt=watch-tailwind
Expand Down
2 changes: 1 addition & 1 deletion crates/asset-pipeline/input.css
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ a {
}

/** In the chat the first p has margin, we need to remove it **/
response-formatter > p:first-child, streaming-chat > p:first-child {
.response-formatter > p:first-child, streaming-chat > p:first-child {
margin-top: 0;
}

Expand Down
37 changes: 21 additions & 16 deletions crates/axum-server/src/prompt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,27 @@ pub async fn execute_prompt(
};

// Turn the users message into something the vector database can use
let embeddings = embeddings_api::get_embeddings(&question)
.await
.map_err(|e| CustomError::ExternalApi(e.to_string()))?;

tracing::info!(prompt.name);
// Get related context
let related_context = db::get_related_context(
transaction,
prompt.dataset_connection,
prompt_id,
team_id,
prompt.max_chunks,
embeddings,
)
.await?;
tracing::info!("Retrieved {} chunks", related_context.len());
let mut related_context = Default::default();
if let (Some(embeddings_base_url), Some(embeddings_model)) =
(prompt.embeddings_base_url, prompt.embeddings_model)
{
let embeddings =
embeddings_api::get_embeddings(&question, &embeddings_base_url, &embeddings_model)
.await
.map_err(|e| CustomError::ExternalApi(e.to_string()))?;

tracing::info!(prompt.name);
// Get related context
related_context = db::get_related_context(
transaction,
prompt_id,
team_id,
prompt.max_chunks,
embeddings,
)
.await?;
tracing::info!("Retrieved {} chunks", related_context.len());
}

// Get the maximum required amount of chat history
let chat_history = if let Some(conversation_id) = conversation_id {
Expand Down
31 changes: 8 additions & 23 deletions crates/axum-server/src/prompts/form.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use db::authz;
use db::Pool;
use db::{queries, Transaction};
use serde::Deserialize;
use ui_pages::{prompts::string_to_dataset_connection, string_to_visibility};
use ui_pages::string_to_visibility;
use validator::Validate;

#[derive(Deserialize, Validate, Default, Debug)]
Expand All @@ -16,15 +16,14 @@ pub struct NewPromptTemplate {
#[validate(length(min = 1, message = "The name is mandatory"))]
pub name: String,
pub system_prompt: String,
pub dataset_connection: String,
pub model_id: i32,
pub datasets: Option<String>,
#[serde(default)]
pub datasets: Vec<i32>,
pub max_history_items: i32,
pub max_chunks: i32,
pub max_tokens: i32,
pub trim_ratio: i32,
pub temperature: f32,
pub top_p: f32,
pub visibility: String,
}

Expand Down Expand Up @@ -57,14 +56,12 @@ pub async fn upsert(
&new_prompt_template.model_id,
&new_prompt_template.name,
&visibility,
&string_to_dataset_connection(&new_prompt_template.dataset_connection),
&system_prompt,
&new_prompt_template.max_history_items,
&new_prompt_template.max_chunks,
&new_prompt_template.max_tokens,
&new_prompt_template.trim_ratio,
&new_prompt_template.temperature,
&new_prompt_template.top_p,
&id,
)
.await?;
Expand Down Expand Up @@ -92,14 +89,12 @@ pub async fn upsert(
&new_prompt_template.model_id,
&new_prompt_template.name,
&visibility,
&string_to_dataset_connection(&new_prompt_template.dataset_connection),
&system_prompt,
&new_prompt_template.max_history_items,
&new_prompt_template.max_chunks,
&new_prompt_template.max_tokens,
&new_prompt_template.trim_ratio,
&new_prompt_template.temperature,
&new_prompt_template.top_p,
)
.one()
.await?;
Expand All @@ -126,23 +121,13 @@ pub async fn upsert(
async fn insert_datasets(
transaction: &Transaction<'_>,
prompt_id: i32,
datasets: Option<String>,
datasets: Vec<i32>,
) -> Result<(), CustomError> {
// Create the connections to any datasets
if let Some(datasets) = datasets {
// The environments we have selected for the ser come in as a comma
// separated list of ids.
let datasets: Vec<i32> = datasets
.split(',')
.map(|e| e.parse::<i32>().unwrap_or(-1))
.filter(|e| *e != -1)
.collect();

for dataset in datasets {
queries::prompts::insert_prompt_dataset()
.bind(transaction, &prompt_id, &dataset)
.await?;
}
for dataset in datasets {
queries::prompts::insert_prompt_dataset()
.bind(transaction, &prompt_id, &dataset)
.await?;
}

Ok(())
Expand Down
4 changes: 1 addition & 3 deletions crates/axum-server/src/team/new_team.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use axum::{
use db::authz;
use db::types;
use db::Pool;
use db::{queries, DatasetConnection, Visibility};
use db::{queries, Visibility};
use serde::Deserialize;
use validator::Validate;

Expand Down Expand Up @@ -59,14 +59,12 @@ pub async fn new_team(
&model.id,
&"Default (Exclude All Datasets)",
&Visibility::Private,
&DatasetConnection::None,
&system_prompt,
&3,
&10,
&1024,
&80,
&0.7,
&0.1,
)
.one()
.await?;
Expand Down
105 changes: 105 additions & 0 deletions crates/daisy-rsx/check_box.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
#![allow(non_snake_case)]
use dioxus::prelude::*;

/// Colour scheme of a check box, expressed as a CSS class name.
///
/// The default variant is [`CheckBoxScheme::Default`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum CheckBoxScheme {
    /// Maps to the `checkbox-default` class.
    #[default]
    Default,
    /// Maps to the `checkbox-primary` class.
    Primary,
    /// Maps to the `checkbox-outline` class.
    Outline,
    /// Maps to the `checkbox-warning` class.
    Danger,
}

impl CheckBoxScheme {
    /// Returns the CSS class name that corresponds to this scheme.
    pub fn to_string(&self) -> &'static str {
        match *self {
            Self::Default => "checkbox-default",
            Self::Primary => "checkbox-primary",
            Self::Outline => "checkbox-outline",
            Self::Danger => "checkbox-warning",
        }
    }
}

/// Size of a check box, expressed as a CSS class name.
///
/// The default variant is [`CheckBoxSize::Default`], which renders with the
/// same class as [`CheckBoxSize::Small`].
#[derive(Default, Copy, Clone, Debug, PartialEq, Eq)]
pub enum CheckBoxSize {
    /// Maps to the `checkbox-sm` class.
    #[default]
    Default,
    /// Maps to the `checkbox-sm` class.
    Small,
    /// Maps to the `checkbox-xs` class.
    ExtraSmall,
    /// Maps to the `checkbox-lg` class.
    Large,
    /// Maps to the `checkbox-md` class.
    Medium,
}

impl CheckBoxSize {
    /// Returns the CSS class name that corresponds to this size.
    pub fn to_string(&self) -> &'static str {
        match *self {
            // The default size and the small size share one class.
            Self::Default | Self::Small => "checkbox-sm",
            Self::ExtraSmall => "checkbox-xs",
            Self::Medium => "checkbox-md",
            Self::Large => "checkbox-lg",
        }
    }
}

/// Properties accepted by the `CheckBox` component.
#[derive(Props)]
pub struct CheckBoxProps<'a> {
/// Child nodes rendered inside the `input` element.
children: Element<'a>,
/// Optional value for the element's `id` attribute.
id: Option<&'a str>,
/// When `Some(true)` the `checked` attribute is emitted; `None` and
/// `Some(false)` both leave it off.
checked: Option<bool>,
/// Optional extra CSS classes, inserted after the base `checkbox` class.
class: Option<&'a str>,
/// Value for the element's `name` attribute.
name: &'a str,
/// Value for the element's `value` attribute.
value: &'a str,
/// Optional size; `CheckBoxSize::default()` is used when absent.
checkbox_size: Option<CheckBoxSize>,
/// Optional colour scheme; `CheckBoxScheme::default()` is used when absent.
checkbox_scheme: Option<CheckBoxScheme>,
}

/// Renders an `input` element of type `checkbox`.
///
/// The element's `class` attribute is built as
/// `"checkbox {class} {scheme} {size}"`, where `class` is the caller-supplied
/// extra classes (empty when absent) and the scheme and size fall back to
/// their `Default` variants when not given. The `checked` attribute is only
/// set when `checked` is `Some(true)`.
pub fn CheckBox<'a>(cx: Scope<'a, CheckBoxProps<'a>>) -> Element {
    // Missing scheme/size props resolve to the enums' `#[default]` variants.
    let checkbox_scheme = cx.props.checkbox_scheme.unwrap_or_default();
    let checkbox_size = cx.props.checkbox_size.unwrap_or_default();

    let class = cx.props.class.unwrap_or("");

    // `None` and `Some(false)` both mean "do not emit the attribute".
    let checked = (cx.props.checked == Some(true)).then_some("checked");

    let class = format!(
        "checkbox {} {} {}",
        class,
        checkbox_scheme.to_string(),
        checkbox_size.to_string()
    );

    cx.render(rsx!(
        input {
            "type": "checkbox",
            class: "{class}",
            id: cx.props.id,
            name: cx.props.name,
            value: cx.props.value,
            checked: checked,
            &cx.props.children,
        }
    ))
}
2 changes: 2 additions & 0 deletions crates/daisy-rsx/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub mod avatar;
pub mod blank_slate;
pub mod button;
pub mod card;
pub mod check_box;
pub mod drawer;
pub mod drop_down;
pub mod input;
Expand All @@ -24,6 +25,7 @@ pub use avatar::{Avatar, AvatarSize, AvatarType};
pub use blank_slate::BlankSlate;
pub use button::{Button, ButtonScheme, ButtonSize, ButtonType};
pub use card::{Box, BoxBody, BoxHeader};
pub use check_box::{CheckBox, CheckBoxScheme, CheckBoxSize};
pub use drawer::{Drawer, DrawerBody, DrawerFooter};
pub use drop_down::{Direction, DropDown, DropDownLink};
pub use input::{Input, InputSize, InputType};
Expand Down
4 changes: 1 addition & 3 deletions crates/db/authz.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};

use crate::queries;
use crate::{types, DatasetConnection, Permission, Transaction, Visibility};
use crate::{types, Permission, Transaction, Visibility};

#[derive(Serialize, Deserialize, Debug)]
pub struct Authentication {
Expand Down Expand Up @@ -114,14 +114,12 @@ pub async fn setup_user_if_not_already_registered(
&model.id,
&"Default (Exclude All Datasets)",
&Visibility::Private,
&DatasetConnection::None,
&system_prompt,
&3,
&10,
&1024,
&100,
&0.7,
&0.1,
)
.one()
.await?;
Expand Down
3 changes: 1 addition & 2 deletions crates/db/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ pub use queries::teams::GetUsers as Member;
pub use queries::teams::{Team, TeamOwner};
pub use queries::users::User;
pub use types::public::{
AuditAccessType, AuditAction, ChatStatus, DatasetConnection, ModelType, Permission, Role,
Visibility,
AuditAccessType, AuditAction, ChatStatus, ModelType, Permission, Role, Visibility,
};
pub use vector_search::{get_related_context, RelatedContext};

Expand Down
14 changes: 14 additions & 0 deletions crates/db/migrations/20240221131717_refactor_prompts.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- migrate:up
-- Remove the dataset_connection and top_p settings from prompts.
-- The columns are dropped first so that nothing in this table still uses the
-- enum type when it is dropped.
ALTER TABLE prompts DROP COLUMN dataset_connection;
ALTER TABLE prompts DROP COLUMN top_p;
DROP TYPE dataset_connection;

-- migrate:down
-- Re-create the enum type before the column that depends on it.
-- Values that existed before the up migration are not restored: every row
-- gets dataset_connection = 'None' and a NULL top_p.
CREATE TYPE dataset_connection AS ENUM (
'All',
'None',
'Selected'
);
COMMENT ON TYPE dataset_connection IS 'A prompt can use all datasets, no datasets or selected datasets.';
ALTER TABLE prompts ADD COLUMN dataset_connection dataset_connection NOT NULL DEFAULT 'None';
ALTER TABLE prompts ADD COLUMN top_p REAL CHECK (top_p >= 0 AND top_p <= 1);
18 changes: 17 additions & 1 deletion crates/db/queries/chunks.sql
Original file line number Diff line number Diff line change
@@ -1,7 +1,23 @@
--! unprocessed_chunks : Chunk()
SELECT
id,
text
text,
(SELECT
base_url
FROM
models
WHERE
id IN (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN
(SELECT dataset_id FROM documents d WHERE d.id = document_id))
) as base_url,
(SELECT
name
FROM
models
WHERE
id IN (SELECT embeddings_model_id FROM datasets ds WHERE ds.id IN
(SELECT dataset_id FROM documents d WHERE d.id = document_id))
) as model
FROM
chunks
WHERE
Expand Down
Loading

0 comments on commit b36edc1

Please sign in to comment.