Ability to train from memory #544

Merged: 8 commits, Nov 28, 2020
Changes from 1 commit
Improve docs and fix tests around training
n1t0 committed Nov 28, 2020
commit a20fbf74b91570bcd6815446ab084de33b73e4fa
1 change: 1 addition & 0 deletions bindings/python/examples/train_with_datasets.py
@@ -15,5 +15,6 @@ def batch_iterator():
for i in range(0, len(dataset["train"]), batch_length):
yield dataset["train"][i : i + batch_length]["text"]


# And finally train
bpe_tokenizer.train_from_iterator(batch_iterator(), length=len(dataset["train"]))
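
The hunk above shows only a fragment of bindings/python/examples/train_with_datasets.py. For orientation, here is a minimal self-contained sketch of the same pattern, assuming the Hugging Face datasets library and a bare BPE setup; the dataset name, trainer settings, and pre-tokenizer choice are illustrative, not part of this commit:

    from datasets import load_dataset
    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    # Load a corpus into memory/cache with the datasets library
    dataset = load_dataset("wikitext", "wikitext-103-raw-v1")

    # A bare BPE tokenizer; the exact model and pre-tokenizer are illustrative
    bpe_tokenizer = Tokenizer(models.BPE())
    bpe_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(special_tokens=["[UNK]"])

    batch_length = 1000

    def batch_iterator():
        # Yield one list of raw strings per batch
        for i in range(0, len(dataset["train"]), batch_length):
            yield dataset["train"][i : i + batch_length]["text"]

    # Train directly from the in-memory batches; `length` only feeds the progress bar
    bpe_tokenizer.train_from_iterator(
        batch_iterator(), trainer=trainer, length=len(dataset["train"])
    )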
39 changes: 39 additions & 0 deletions bindings/python/py_src/tokenizers/__init__.pyi
@@ -1022,6 +1022,45 @@ class Tokenizer:
:obj:`Optional[int]`: An optional id, :obj:`None` if out of vocabulary
"""
pass
def train(self, files, trainer=None):
"""
Train the Tokenizer using the given files.

Reads the files line by line, while keeping all the whitespace, even new lines.
If you want to train from data stored in memory, you can check
:meth:`~tokenizers.Tokenizer.train_from_iterator`

Args:
files (:obj:`List[str]`):
A list of paths to the files that we should use for training

trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
An optional trainer that should be used to train our Model
"""
pass
def train_from_iterator(self, iterator, trainer=None, length=None):
"""
Train the Tokenizer using the provided iterator.

You can provide anything that is a Python Iterator

* A list of sequences :obj:`List[str]`
* A generator that yields :obj:`str` or :obj:`List[str]`
* A Numpy array of strings
* ...

Args:
iterator (:obj:`Iterator`):
Any iterator over strings or lists of strings

trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
An optional trainer that should be used to train our Model

length (:obj:`int`, `optional`):
The total number of sequences in the iterator. This is used to
provide meaningful progress tracking
"""
pass
@property
def truncation(self):
"""
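
Taken together, the two new stubs document both training entry points. A rough usage sketch, where the file paths, trainer settings, and pre-tokenizer are placeholders rather than part of the commit:

    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    tokenizer = Tokenizer(models.BPE(unk_token="[UNK]"))
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(special_tokens=["[UNK]"])

    # 1) Train from files on disk (read line by line, whitespace preserved)
    tokenizer.train(["path/to/file-a.txt", "path/to/file-b.txt"], trainer=trainer)

    # 2) Train from data already in memory: a list, a generator, a numpy array, ...
    data = ["Hello world", "How are you?", "Training from memory works too"]
    tokenizer.train_from_iterator(data, trainer=trainer, length=len(data))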
33 changes: 33 additions & 0 deletions bindings/python/src/tokenizer.rs
@@ -1068,7 +1068,20 @@ impl PyTokenizer {
Ok(self.tokenizer.add_special_tokens(&tokens))
}

/// Train the Tokenizer using the given files.
///
/// Reads the files line by line, while keeping all the whitespace, even new lines.
/// If you want to train from data stored in memory, you can check
/// :meth:`~tokenizers.Tokenizer.train_from_iterator`
///
/// Args:
/// files (:obj:`List[str]`):
/// A list of paths to the files that we should use for training
///
/// trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
/// An optional trainer that should be used to train our Model
#[args(trainer = "None")]
#[text_signature = "(self, files, trainer = None)"]
fn train(&mut self, files: Vec<String>, trainer: Option<&mut PyTrainer>) -> PyResult<()> {
let mut trainer =
trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
@@ -1084,7 +1097,27 @@ impl PyTokenizer {
})
}

/// Train the Tokenizer using the provided iterator.
///
/// You can provide anything that is a Python Iterator
///
/// * A list of sequences :obj:`List[str]`
/// * A generator that yields :obj:`str` or :obj:`List[str]`
/// * A Numpy array of strings
/// * ...
///
/// Args:
/// iterator (:obj:`Iterator`):
/// Any iterator over strings or lists of strings
///
/// trainer (:obj:`~tokenizers.trainers.Trainer`, `optional`):
/// An optional trainer that should be used to train our Model
///
/// length (:obj:`int`, `optional`):
/// The total number of sequences in the iterator. This is used to
/// provide meaningful progress tracking
#[args(trainer = "None", length = "None")]
#[text_signature = "(self, iterator, trainer=None, length=None)"]
fn train_from_iterator(
&mut self,
iterator: &PyAny,
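
The Rust binding above mirrors the Python docstring. One detail worth an example is the length hint: a plain generator exposes no len(), so passing length explicitly is the only way to keep the progress bar meaningful. A hedged sketch, with a made-up corpus and settings:

    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.BpeTrainer(show_progress=True)

    lines = ["first sentence", "second sentence", "third sentence"] * 1000

    def line_gen():
        # A generator has no len(), so the binding cannot infer the total count
        for line in lines:
            yield line

    # Passing `length` explicitly keeps the progress bar accurate
    tokenizer.train_from_iterator(line_gen(), trainer=trainer, length=len(lines))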
6 changes: 3 additions & 3 deletions tokenizers/README.md
@@ -71,7 +71,7 @@ use std::path::Path;
fn main() -> Result<()> {
let vocab_size: usize = 100;

-let trainer = BpeTrainerBuilder::new()
+let mut trainer = BpeTrainerBuilder::new()
.show_progress(true)
.vocab_size(vocab_size)
.min_frequency(0)
@@ -97,8 +97,8 @@ fn main() -> Result<()> {

let pretty = false;
tokenizer
-.train(
-&trainer,
+.train_from_files(
+&mut trainer,
vec!["path/to/vocab.txt".to_string()],
)?
.save("tokenizer.json", pretty)?;
6 changes: 3 additions & 3 deletions tokenizers/src/lib.rs
@@ -58,7 +58,7 @@
//! fn main() -> Result<()> {
//! let vocab_size: usize = 100;
//!
-//! let trainer = BpeTrainerBuilder::new()
+//! let mut trainer = BpeTrainerBuilder::new()
//! .show_progress(true)
//! .vocab_size(vocab_size)
//! .min_frequency(0)
@@ -84,8 +84,8 @@
//!
//! let pretty = false;
//! tokenizer
-//! .train(
-//! &trainer,
+//! .train_from_files(
+//! &mut trainer,
//! vec!["path/to/vocab.txt".to_string()],
//! )?
//! .save("tokenizer.json", pretty)?;
15 changes: 7 additions & 8 deletions tokenizers/src/models/bpe/trainer.rs
@@ -138,22 +138,21 @@ impl BpeTrainerBuilder {
}
}

-/// In charge of training a `BPE` model from a mapping of words to word counts.
+/// In charge of training a `BPE` model
///
/// # Examples
///
/// ```
-/// use std::collections::HashMap;
/// use tokenizers::tokenizer::Trainer;
/// use tokenizers::models::bpe::{BPE, BpeTrainer};
///
-/// let word_counts: HashMap<String, u32> = [
-///     (String::from("Hello"), 1),
-///     (String::from("World"), 1),
-/// ].iter().cloned().collect();
-/// let trainer = BpeTrainer::default();
+/// let sequences = vec![ "Hello", "World" ];
+///
+/// let mut trainer = BpeTrainer::default();
+/// trainer.feed(sequences.iter(), |s| Ok(vec![s.to_owned()]));
+///
/// let mut model = BPE::default();
-/// let special_tokens = trainer.train(word_counts, &mut model).unwrap();
+/// let special_tokens = trainer.train(&mut model).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
12 changes: 6 additions & 6 deletions tokenizers/tests/documentation.rs
@@ -20,7 +20,7 @@ fn train_tokenizer() {
.build()
.unwrap();

-let trainer = BpeTrainerBuilder::new()
+let mut trainer = BpeTrainerBuilder::new()
.show_progress(false)
.vocab_size(vocab_size)
.min_frequency(0)
@@ -35,7 +35,7 @@ fn train_tokenizer() {

let pretty = true;
tokenizer
-.train(&trainer, vec!["data/small.txt".to_string()])
+.train_from_files(&mut trainer, vec!["data/small.txt".to_string()])
.unwrap()
.save("data/tokenizer.json", pretty)
.unwrap();
@@ -80,7 +80,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
// START quicktour_init_trainer
use tokenizers::models::bpe::BpeTrainer;

-let trainer = BpeTrainer::builder()
+let mut trainer = BpeTrainer::builder()
.special_tokens(vec![
AddedToken::from("[UNK]", true),
AddedToken::from("[CLS]", true),
@@ -102,7 +102,7 @@ fn quicktour_slow_train() -> tokenizers::Result<()> {
"data/wikitext-103-raw/wiki.test.raw".into(),
"data/wikitext-103-raw/wiki.valid.raw".into(),
];
-tokenizer.train(&trainer, files)?;
+tokenizer.train_from_files(&mut trainer, files)?;
// END quicktour_train
// START quicktour_save
tokenizer.save("data/tokenizer-wiki.json", false)?;
@@ -403,7 +403,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
// START bert_train_tokenizer
use tokenizers::models::{wordpiece::WordPieceTrainer, TrainerWrapper};

-let trainer: TrainerWrapper = WordPieceTrainer::builder()
+let mut trainer: TrainerWrapper = WordPieceTrainer::builder()
.vocab_size(30_522)
.special_tokens(vec![
AddedToken::from("[UNK]", true),
@@ -419,7 +419,7 @@ fn train_pipeline_bert() -> tokenizers::Result<()> {
"data/wikitext-103-raw/wiki.test.raw".into(),
"data/wikitext-103-raw/wiki.valid.raw".into(),
];
-bert_tokenizer.train(&trainer, files)?;
+bert_tokenizer.train_from_files(&mut trainer, files)?;

bert_tokenizer.save("data/bert-wiki.json", false)?;
// END bert_train_tokenizer
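
The documentation tests above keep training from files on the Rust side. From the Python bindings, the same BERT-style WordPiece setup can now be fed from memory instead; a hedged sketch, with a toy corpus and a simplified pre-tokenizer standing in for the wikitext files and full BERT pipeline (nothing here is part of the commit):

    from tokenizers import Tokenizer, models, pre_tokenizers, trainers

    bert_tokenizer = Tokenizer(models.WordPiece(unk_token="[UNK]"))
    bert_tokenizer.pre_tokenizer = pre_tokenizers.Whitespace()
    trainer = trainers.WordPieceTrainer(
        vocab_size=30_522,
        special_tokens=["[UNK]", "[CLS]", "[SEP]", "[PAD]", "[MASK]"],
    )

    # Any in-memory iterable of strings works in place of the raw wiki files
    wiki_lines = ["The quick brown fox", "jumps over the lazy dog"]
    bert_tokenizer.train_from_iterator(wiki_lines, trainer=trainer, length=len(wiki_lines))
    bert_tokenizer.save("bert-wiki.json")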
4 changes: 2 additions & 2 deletions tokenizers/tests/training.rs
@@ -20,9 +20,9 @@ fn bpe_values_after_training() {
)
.build()
.unwrap();
-let trainer = tokenizer.get_model().get_trainer();
+let mut trainer = tokenizer.get_model().get_trainer();
tokenizer
.train(&trainer, vec!["./data/small.txt".to_string()])
.train_from_files(&mut trainer, vec!["./data/small.txt".to_string()])
.unwrap();
assert_eq!(tokenizer.get_model().dropout, Some(0.1));
assert_eq!(tokenizer.get_model().unk_token, Some("[UNK]".to_string()));
9 changes: 7 additions & 2 deletions tokenizers/tests/unigram.rs
@@ -7,7 +7,7 @@ use std::path::Path;
use tokenizers::models::unigram::Lattice;
use tokenizers::models::unigram::Unigram;
use tokenizers::models::unigram::UnigramTrainer;
-use tokenizers::tokenizer::{Model, Trainer};
+use tokenizers::tokenizer::Model;

#[test]
fn test_unigram_from_file() {
@@ -56,7 +56,12 @@ fn test_train_unigram_from_file() {
.build()
.unwrap();
let mut model = Unigram::default();
-trainer.train(word_counts, &mut model).unwrap();
+
+let sentences: Vec<_> = word_counts
+    .iter()
+    .map(|(s, i)| (s.to_owned(), *i))
+    .collect();
+trainer.do_train(sentences, &mut model).unwrap();
assert_eq!(model.get_vocab_size(), 719);
}

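
On the Rust side, the Unigram test now feeds (sentence, count) pairs into do_train directly. From Python, the equivalent in-memory training would go through the generic iterator entry point; a rough sketch, assuming the bindings expose Unigram and UnigramTrainer and with purely illustrative settings:

    from tokenizers import Tokenizer, models, trainers

    # Assumed to exist in the Python bindings of this era (0.10 series)
    tokenizer = Tokenizer(models.Unigram())
    trainer = trainers.UnigramTrainer(vocab_size=100, show_progress=False)

    # A tiny in-memory corpus; no files are involved at any point
    corpus = ["Hello", "World", "Hello World"] * 50
    tokenizer.train_from_iterator(corpus, trainer=trainer, length=len(corpus))
    print(tokenizer.get_vocab_size())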