[go: nahoru, domu]

Skip to content

Commit

Permalink
fix: API to create prompts based on dataset. (bionic-gpt#40)
Browse files Browse the repository at this point in the history
  • Loading branch information
9876691 committed Sep 29, 2023
1 parent 30010b6 commit e0a47a2
Show file tree
Hide file tree
Showing 72 changed files with 400 additions and 4,407 deletions.
1 change: 1 addition & 0 deletions .devcontainer/.bash_aliases
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ alias gcb='git checkout -b'
alias gcr='f() { git checkout -b $1 origin/$1; }; f'
alias gitsetup='git config --global user.name \$NAME && git config --global user.email \$EMAIL'
alias gsu='git submodule update --recursive --remote'
alias gdb='git branch | grep -v "main" | xargs git branch -D'

# Database
alias dbmate='dbmate --no-dump-schema --migrations-dir /workspace/crates/db/migrations'
Expand Down
33 changes: 15 additions & 18 deletions .devcontainer/envoy.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,20 @@ static_resources:
- match:
prefix: "/v1"
route:
cluster: llm-api
cluster: app
# Disable timeout for SSE
# https://medium.com/@kaitmore/server-sent-events-http-2-and-envoy-6927c70368bb
timeout: 0s
typed_per_filter_config:
envoy.filters.http.ext_authz:
"@type": type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute
disabled: true

# These are requests coming from the front end typescript
- match:
prefix: "/completions"
route:
cluster: app
# Disable timeout for SSE
# https://medium.com/@kaitmore/server-sent-events-http-2-and-envoy-6927c70368bb
timeout: 0s
Expand Down Expand Up @@ -128,20 +141,4 @@ static_resources:
address:
socket_address:
address: development
port_value: 7703

# The LLM API
- name: llm-api
connect_timeout: 10s
type: strict_dns
lb_policy: round_robin
dns_lookup_family: V4_ONLY
load_assignment:
cluster_name: llm-api
endpoints:
- lb_endpoints:
- endpoint:
address:
socket_address:
address: llm-api
port_value: 8080
port_value: 7703
2 changes: 2 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ Create all the database tables with

The website uses a zola theme. This will need to be loaded with

`git submodule init`

`gsu`

## Starting the services
Expand Down
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion crates/asset-pipeline/web-components/streaming-chat.ts
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ export class StreamingChat extends HTMLElement {
const signal = this.controller.signal;

try {
const response = await fetch('/v1/completions', {
const response = await fetch(`/completions/${chatId}`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Expand Down
1 change: 1 addition & 0 deletions crates/axum-server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ tokio = { version = "1", default-features = false, features = ["macros", "rt-mul
tokio-util = "0"
validator = { version = "0", features = ["derive"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
tracing = "0"
tracing-subscriber = { version="0", features = ["env-filter"] }
tower-http = { version = "0", default-features = false, features = ["fs", "trace"] }
Expand Down
1 change: 1 addition & 0 deletions crates/axum-server/src/api_keys/new.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub async fn new_api_key(
.bind(
&transaction,
&new_api_key.prompt_id,
&current_user.user_id,
&new_api_key.name,
&api_key,
)
Expand Down
132 changes: 101 additions & 31 deletions crates/axum-server/src/api_reverse_proxy.rs
Original file line number Diff line number Diff line change
@@ -1,47 +1,117 @@
//! Run with
//!
//! ```not_rust
//! $ cargo run -p example-http-proxy
//! ```
//!
//! In another terminal:
//!
//! ```not_rust
//! $ curl -v -x "127.0.0.1:3000" https://tokio.rs
//! ```
//!
//! Example is based on <https://github.com/hyperium/hyper/blob/master/examples/http_proxy.rs>

use axum::{
body::Body,
extract::State,
http::Request,
http::StatusCode,
response::{IntoResponse, Response},
Extension, RequestExt,
};
use http::Uri;
use hyper::client::HttpConnector;

use db::{queries, Pool};
use serde::{Deserialize, Serialize};

/// Outbound HTTP client held in the router state and used by `handler` to
/// forward requests upstream.
// NOTE(review): `HttpConnector` is used without a TLS wrapper here, so `https`
// base URLs presumably will not work with this client — confirm.
type Client = hyper::client::Client<HttpConnector, Body>;

/// A single chat message inside a `Completion` request body.
#[derive(Serialize, Deserialize, Debug)]
struct Message {
/// Speaker of the message; `handler` writes `"user"` for the generated prompt.
role: String,
/// Text of the message.
content: String,
}

/// JSON body of a completions request as parsed from the caller and
/// re-serialized for the upstream LLM API.
#[derive(Serialize, Deserialize, Debug)]
struct Completion {
/// Model identifier, passed through unchanged.
model: String,
/// Optional streaming flag, passed through unchanged.
// NOTE(review): OpenAI-style chat APIs name this flag `stream`; confirm that
// `streaming` is the intended wire name, otherwise the caller's flag is lost.
streaming: Option<bool>,
/// Chat messages; `handler` replaces these with a single generated prompt.
messages: Vec<Message>,
/// Optional sampling temperature, passed through unchanged.
temperature: Option<f32>,
}

/// Reverse-proxy handler for API-key authenticated requests.
///
/// Reads the `Authorization` header, uses its value to look up the prompt and
/// the API key, and builds the upstream URI from the prompt's `base_url` plus
/// the incoming path and query. Requests whose path-and-query ends with
/// `/completions` have their body parsed as a `Completion`, their messages
/// replaced by one generated prompt, and are re-posted upstream; every other
/// request is forwarded with only its URI rewritten.
///
/// Returns `UNAUTHORIZED` when the header is missing or the API key lookup
/// fails, and `BAD_REQUEST` for the other mapped failures (prompt lookup,
/// body extraction, JSON parsing, prompt execution, upstream request).
///
/// NOTE(review): this listing is a rendered diff without +/- markers, so the
/// previous and the new version of the function are interleaved. The first
/// `mut req` parameter and the first body (ending at the first
/// `.into_response())`) appear to be the removed version — confirm against
/// the repository before compiling.
pub async fn handler(
Extension(pool): Extension<Pool>,
State(client): State<Client>,
// NOTE(review): duplicate parameter — the `hyper::Body` line looks like the
// removed side of the diff, the `Body` line like its replacement.
mut req: Request<hyper::Body>,
mut req: Request<Body>,
) -> Result<Response, StatusCode> {
// --- Apparently the removed implementation: forward everything to a fixed
// --- `http://llm-api:8080` host without any authentication.
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);

let uri = format!("http://llm-api:8080{path_query}");

*req.uri_mut() = Uri::try_from(uri).unwrap();

Ok(client
.request(req)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?
.into_response())
// --- Apparently the added implementation starts here.
// The raw `Authorization` header value is used as the lookup key below.
// NOTE(review): no scheme prefix (e.g. `Bearer `) is stripped here; whether
// the queries expect the raw header value is not visible — confirm.
if let Some(api_key) = req.headers().get("Authorization") {
// NOTE(review): these `unwrap()`s panic on pool/transaction failure instead
// of returning a status code.
let mut db_client = pool.get().await.unwrap();
let transaction = db_client.transaction().await.unwrap();

// NOTE(review): `to_str().unwrap()` panics when the header value is not
// visible ASCII, i.e. on caller-controlled input.
// A failed prompt lookup is reported as BAD_REQUEST, and it runs before the
// API key check that would report UNAUTHORIZED.
let prompt = queries::prompts::prompt_by_api_key()
.bind(&transaction, &api_key.to_str().unwrap())
.one()
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;

// Shadows the header value with the API key row returned by the query.
let api_key = queries::api_keys::find_api_key()
.bind(&transaction, &api_key.to_str().unwrap())
.one()
.await
.map_err(|_| StatusCode::UNAUTHORIZED)?;

// Keep the incoming path and query string; fall back to the bare path when
// the URI has no path-and-query component.
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);

// The upstream host comes from the prompt row rather than a fixed address.
let base_url = prompt.base_url;
let uri = format!("{base_url}{path_query}");

// If we are completions we need to add the prompt to the request
// NOTE(review): `path_query` includes the query string, so a request such as
// `/v1/completions?x=1` does not match this suffix check.
if path_query.ends_with("/completions") {
// Scope the transaction to the user that owns the API key before running
// the prompt queries.
super::rls::set_row_level_security_user_id(&transaction, api_key.user_id)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;

// Read the whole request body as a string and parse it as a `Completion`.
let body: String = req.extract().await.map_err(|_| StatusCode::BAD_REQUEST)?;
let completion: Completion =
serde_json::from_str(&body).map_err(|_| StatusCode::BAD_REQUEST)?;

// NOTE(review): the last argument is the string literal "message.message",
// not text taken from `completion.messages`; the caller's messages are not
// passed to `execute_prompt` here — confirm this is intended.
let generated_prompt = crate::prompt::execute_prompt(
&transaction,
prompt.id,
prompt.organisation_id,
"message.message",
)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?;

// Replace the caller's messages with a single user message holding the
// generated prompt; the remaining fields are carried over unchanged.
let completion = Completion {
messages: vec![Message {
role: "user".to_string(),
content: generated_prompt,
}],
..completion
};

let completion_json =
serde_json::to_string(&completion).map_err(|_| StatusCode::BAD_REQUEST)?;

// Create a new request
// Only `content-type` is set; none of the original request headers are
// copied onto the upstream request.
let req = Request::post(uri)
.header("content-type", "application/json")
.body(Body::from(completion_json))
.map_err(|_| StatusCode::BAD_REQUEST)?;

// Any upstream client error is reported to the caller as BAD_REQUEST.
Ok(client
.request(req)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?
.into_response())
} else {
// Anything that is not completions gets passed direct to the LLM API

// Only the URI is rewritten, so the original method, headers (including
// `Authorization`) and body are forwarded as received.
// NOTE(review): `unwrap()` panics if `base_url` plus the path is not a
// valid URI.
*req.uri_mut() = Uri::try_from(uri).unwrap();

Ok(client
.request(req)
.await
.map_err(|_| StatusCode::BAD_REQUEST)?
.into_response())
}
} else {
// No `Authorization` header at all.
Err(StatusCode::UNAUTHORIZED)
}
}
Loading

0 comments on commit e0a47a2

Please sign in to comment.