Implement the Starcoder 2 model architecture (#522)
* Initial implementation

* Initial implementation

* Fix

* Fix

* Update aici

* Clippy

* Implement isq, anymoe

* Implement device mapping

* Cast for f32

* Add support for xlora and lora
EricLBuehler committed Jul 2, 2024
1 parent bb31e12 commit 52c3edf
Showing 20 changed files with 2,246 additions and 183 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions mistralrs-core/Cargo.toml
@@ -68,6 +68,7 @@ mistralrs-vision = { version = "0.1.13", path = "../mistralrs-vision" }
csv = "1.3.0"
reqwest.workspace = true
base64.workspace = true
bytemuck_derive = "1.7.0"

[features]
pyo3_macros = ["pyo3"]
21 changes: 13 additions & 8 deletions mistralrs-core/src/aici/bintokens.rs
@@ -1,12 +1,11 @@
// Originally from https://github.com/microsoft/aici/blob/64f0b551dee49e320e9b3b92289f3d6f2e888276/aicirt/src/bintokens.rs
// Licensed under the MIT license

use crate::aici::{bytes::TokRxInfo, toktree::TokTrie};
use crate::aici::bytes::TokRxInfo;
use anyhow::{anyhow, bail, Result};
use serde::{Deserialize, Serialize};
use std::collections::{BTreeMap, HashMap};
use tokenizers::{normalizers::Sequence, NormalizerWrapper, Tokenizer};
use tracing::{error, warn};
use tracing::warn;

use super::toktree::TokTrie;

#[derive(Serialize, Deserialize)]
pub struct ByteTokenizer {
@@ -17,9 +16,13 @@ pub struct ByteTokenizer {
token_bytes: Vec<Vec<u8>>,
pub special: BTreeMap<String, u32>,
}
fn is_self_mapped(c: char) -> bool {

// useful when debugging this: https://www.cogsci.ed.ac.uk/~richard/utf-8.cgi

const fn is_self_mapped(c: char) -> bool {
matches!(c, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}')
}

fn build_char_map() -> HashMap<char, u8> {
let mut res = HashMap::default();
let mut k = 0x100u32;
@@ -88,7 +91,6 @@ impl ByteTokenizer {
bail!("can't determine decoder type: {:?}", hft.get_decoder());
}

#[allow(clippy::cast_possible_truncation)]
let vocab_size = hft.get_vocab_size(true) as u32;
let added = hft.get_added_tokens_decoder();

@@ -116,6 +118,9 @@ impl ByteTokenizer {
let char_map = build_char_map();

for tok_id in 0..vocab_size {
if added.contains_key(&tok_id) {
continue;
}
if let Some(tok_name) = res.hf_tokenizer.id_to_token(tok_id) {
if is_byte_fallback {
if tok_name.len() == 6 && tok_name.starts_with("<0x") && tok_name.ends_with('>')
@@ -142,7 +147,7 @@ impl ByteTokenizer {
let bytes = match bytes {
Ok(b) => b,
Err(e) => {
error!("error: {} for {:?}", e, tok_name);
println!("error: {} for {:?}", e, tok_name);
continue;
}
};
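The `is_self_mapped` / `build_char_map` pair touched in this file implements the GPT-2-style byte-to-unicode table used by byte-level BPE tokenizers: bytes whose characters are directly printable map to themselves, and every other byte is assigned a fresh codepoint starting at U+0100. The following is a self-contained sketch of that mapping, not the crate's exact code; the `main` demo and its assertions are illustrative assumptions.

```rust
use std::collections::HashMap;

// Bytes whose visible character already represents them (printable ASCII and
// most of Latin-1) map to themselves; everything else gets a fresh codepoint.
const fn is_self_mapped(c: char) -> bool {
    matches!(c, '!'..='~' | '\u{00A1}'..='\u{00AC}' | '\u{00AE}'..='\u{00FF}')
}

// Build the char -> byte table used by the GPT-2 byte-level BPE scheme: the
// non-self-mapped bytes are assigned consecutive codepoints starting at U+0100.
fn build_char_map() -> HashMap<char, u8> {
    let mut res = HashMap::new();
    let mut next = 0x100u32;
    for byte in 0u8..=255 {
        let c = byte as char;
        if is_self_mapped(c) {
            res.insert(c, byte);
        } else {
            res.insert(char::from_u32(next).unwrap(), byte);
            next += 1;
        }
    }
    res
}

fn main() {
    let map = build_char_map();
    // 'Ġ' (U+0120) is the familiar byte-level BPE stand-in for the space byte:
    // space (0x20) is the 33rd non-self-mapped byte, so it lands on 0x100 + 32.
    assert_eq!(map[&'\u{0120}'], b' ');
    assert_eq!(map[&'A'], b'A');
    println!("mapped {} bytes", map.len());
}
```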
21 changes: 20 additions & 1 deletion mistralrs-core/src/aici/bytes.rs
@@ -1,12 +1,31 @@
use std::mem::size_of;

use bytemuck_derive::{Pod, Zeroable};

pub(crate) type TokenId = u32;

#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct TokRxInfo {
pub vocab_size: u32,
pub tok_eos: TokenId,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
pub struct U32Pair(pub u32, pub u32);

pub fn vec_from_bytes<T: bytemuck::Pod>(bytes: &[u8]) -> Vec<T> {
if bytes.len() % size_of::<T>() != 0 {
panic!(
"vecT: got {} bytes, needed multiple of {}",
bytes.len(),
size_of::<T>()
);
}
bytemuck::cast_slice(bytes).to_vec()
}

pub fn to_hex_string(bytes: &[u8]) -> String {
bytes
.iter()
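The new `Pod`/`Zeroable` derives plus `#[repr(C)]` are what let `vec_from_bytes` reinterpret a raw byte buffer through `bytemuck::cast_slice` without hand-written unsafe code. A minimal usage sketch, assuming both the `bytemuck` and `bytemuck_derive` crates are available as in this diff; the byte buffer is built from `u32`s so the cast is alignment-safe:

```rust
use bytemuck_derive::{Pod, Zeroable};

// Mirrors the pair type from the diff: #[repr(C)] plus the Pod/Zeroable
// derives guarantee a fixed, padding-free layout that bytemuck can cast into.
#[derive(Clone, Copy, PartialEq, Eq, Debug, Zeroable, Pod)]
#[repr(C)]
struct U32Pair(u32, u32);

// Same shape as the helper in the diff: reject size mismatches, then let
// bytemuck do the checked reinterpretation.
fn vec_from_bytes<T: bytemuck::Pod>(bytes: &[u8]) -> Vec<T> {
    assert!(
        bytes.len() % std::mem::size_of::<T>() == 0,
        "got {} bytes, needed a multiple of {}",
        bytes.len(),
        std::mem::size_of::<T>()
    );
    bytemuck::cast_slice(bytes).to_vec()
}

fn main() {
    // Build the byte buffer from u32s so it is guaranteed to be 4-byte aligned
    // (cast_slice panics on misaligned input).
    let words: Vec<u32> = vec![1, 2, 3, 4];
    let bytes: &[u8] = bytemuck::cast_slice(&words);

    let pairs: Vec<U32Pair> = vec_from_bytes(bytes);
    assert_eq!(pairs, vec![U32Pair(1, 2), U32Pair(3, 4)]);

    // The same buffer round-trips back into plain u32s.
    let roundtrip: Vec<u32> = vec_from_bytes(bytes);
    assert_eq!(roundtrip, words);
}
```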
48 changes: 24 additions & 24 deletions mistralrs-core/src/aici/cfg.rs
@@ -7,9 +7,8 @@ use cfgrammar::{
};
use lrtable::{from_yacc, Action, Minimiser, StIdx, StateTable};
use rustc_hash::FxHashMap;
use std::sync::{Arc, RwLock};
use std::sync::RwLock;
use std::vec;
use tracing::debug;
use vob::{vob, Vob};

type StorageT = u32;
@@ -24,15 +23,14 @@ enum ParseResult {
Continue,
}

#[derive(Clone)]
struct CfgStats {
yacc_actions: usize,
states_pushed: usize,
}

pub struct CfgParser {
grm: Arc<YaccGrammar<StorageT>>,
stable: Arc<StateTable<StorageT>>,
grm: YaccGrammar<StorageT>,
stable: StateTable<StorageT>,
lexer: Lexer,
byte_states: Vec<ByteState>,
pat_idx_to_tidx: Vec<TIdx<u32>>,
@@ -131,10 +129,10 @@ impl CfgParser {
};

if false {
debug!("core\n{}\n\n", sgraph.pp(&grm, true));
println!("core\n{}\n\n", sgraph.pp(&grm, true));
for pidx in grm.iter_pidxs() {
let prod = grm.prod(pidx);
debug!("{:?} -> {}", prod, prod.len());
println!("{:?} -> {}", prod, prod.len());
}
}

@@ -170,29 +168,29 @@ impl CfgParser {
.collect::<Vec<_>>();

for ridx in grm.iter_rules() {
let rname = grm.rule_name_str(ridx);
if rname.to_uppercase() != rname {
let rule_name = grm.rule_name_str(ridx);
if rule_name.to_uppercase() != rule_name {
continue;
}
for pidx in grm.rule_to_prods(ridx) {
let toks = grm.prod(*pidx);
if let [Symbol::Token(tidx)] = toks {
let idx = *tidx_to_pat_idx.get(tidx).unwrap();
// this doesn't seem very useful
// friendly_pattern_names[idx] = rname.to_string();
if rname == "SKIP" {
// friendly_pattern_names[idx] = rule_name.to_string();
if rule_name == "SKIP" {
skip_patterns.set(idx, true);
}
}
}
}

debug!("patterns: {:?}", friendly_pattern_names);
println!("patterns: {:?}", friendly_pattern_names);

let mut vobset = VobSet::new();
// all-zero has to be inserted first
let _all0 = vobset.get(&vob![false; patterns.len()]);
let all1 = vobset.get(&vob![true; patterns.len()]);
let _all0 = vobset.insert_or_get(&vob![false; patterns.len()]);
let all1 = vobset.insert_or_get(&vob![true; patterns.len()]);

// TIME: 27ms
let dfa = Lexer::from(patterns, &mut vobset);
@@ -225,13 +223,13 @@ impl CfgParser {
}
}

vobset.get(&r)
vobset.insert_or_get(&r)
})
.collect::<Vec<_>>();

let mut cfg = CfgParser {
grm: grm.into(),
stable: stable.into(),
grm,
stable,
lexer: dfa,
byte_states: vec![byte_state],
pat_idx_to_tidx,
@@ -252,7 +250,7 @@ impl CfgParser {
// compute viable set of initial tokens
cfg.byte_states[0].viable = cfg.viable_vobidx(cfg_start);
if LOG_PARSER {
debug!(
println!(
"initial viable: {:?}",
cfg.vobset.resolve(cfg.byte_states[0].viable)
);
@@ -283,7 +281,7 @@ impl CfgParser {
let act = self.stable.action(stidx, lexeme);

if LOG_PARSER {
debug!(
println!(
"parse: {:?} {:?} -> {:?}",
pstack,
self.friendly_token_name(lexeme),
@@ -316,10 +314,10 @@ impl CfgParser {

#[allow(dead_code)]
fn print_viable(&self, lbl: &str, vob: &Vob) {
debug!("viable tokens {}:", lbl);
println!("viable tokens {}:", lbl);
for (idx, b) in vob.iter().enumerate() {
if b {
debug!(" {}: {}", idx, self.friendly_pattern_names[idx]);
println!(" {}: {}", idx, self.friendly_pattern_names[idx]);
}
}
}
@@ -348,7 +346,7 @@ impl CfgParser {
Some((ls, Some(pat_idx))) => ("parse", self.run_parser(pat_idx, &top, ls)),
};
if LOG_PARSER {
debug!(
println!(
" -> {} {}",
info,
if res.is_none() { "error" } else { "ok" }
@@ -375,14 +373,16 @@ impl CfgParser {
let mut s = self.stats.write().unwrap();
s.yacc_actions += 1;
}

if LOG_PARSER {
println!();
}
let pstack = self.pstack_for(top);
if self.skip_patterns[pat_idx] {
let stidx = *pstack.last().unwrap();
let viable = self.viable_vobidx(stidx);
//self.print_viable("reset", &viable);
if LOG_PARSER {
debug!("parse: {:?} skip", pstack);
println!("parse: {:?} skip", pstack);
}
// reset viable states - they have been narrowed down to SKIP
self.mk_byte_state(ls, top.parse_stack_idx, viable)
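These hunks drop the `Arc` wrappers (the parser now owns its grammar and state table) and rename `VobSet::get` to `insert_or_get`, which spells out the interning behaviour. `VobSet` itself is not shown in this diff, so the following is only a dependency-free sketch of that insert-or-get pattern, with `Vec<bool>` standing in for the `vob` crate's `Vob`:

```rust
use std::collections::HashMap;

// A minimal interner in the spirit of the VobSet used above: every distinct
// bit-vector gets a stable index, and insert_or_get returns the existing
// index when the vector has been seen before.
#[derive(Default)]
struct VobSet {
    vobs: Vec<Vec<bool>>,
    by_vob: HashMap<Vec<bool>, usize>,
}

impl VobSet {
    fn insert_or_get(&mut self, v: &[bool]) -> usize {
        if let Some(&idx) = self.by_vob.get(v) {
            return idx;
        }
        let idx = self.vobs.len();
        self.vobs.push(v.to_vec());
        self.by_vob.insert(v.to_vec(), idx);
        idx
    }

    fn resolve(&self, idx: usize) -> &[bool] {
        &self.vobs[idx]
    }
}

fn main() {
    let mut set = VobSet::default();
    // The all-false vector is inserted first so it always resolves to index 0,
    // mirroring the "all-zero has to be inserted first" comment in the diff.
    let all0 = set.insert_or_get(&[false; 4]);
    let all1 = set.insert_or_get(&[true; 4]);
    assert_eq!(all0, 0);
    assert_eq!(all1, 1);
    // Re-inserting an existing vector returns the original index.
    assert_eq!(set.insert_or_get(&[true; 4]), all1);
    assert_eq!(set.resolve(all1), &[true; 4]);
}
```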