Ability to train from memory #544

Merged: 8 commits, Nov 28, 2020
Changes from 1 commit
Improve progress tracking while training
n1t0 committed Nov 28, 2020
commit b08bdb930759b18db5c333b71be43470bc953b9f
61 changes: 25 additions & 36 deletions bindings/python/src/tokenizer.rs
@@ -12,7 +12,6 @@ use tk::tokenizer::{
     Model, PaddingDirection, PaddingParams, PaddingStrategy, PostProcessor, TokenizerImpl,
     TruncationParams, TruncationStrategy,
 };
-use tk::utils::iter::ResultShunt;
 use tokenizers as tk;
 
 use super::decoders::PyDecoder;
@@ -1085,51 +1084,41 @@ impl PyTokenizer {
         })
     }
 
-    #[args(trainer = "None")]
+    #[args(trainer = "None", length = "None")]
     fn train_from_iterator(
         &mut self,
         iterator: &PyAny,
         trainer: Option<&mut PyTrainer>,
+        length: Option<usize>,
     ) -> PyResult<()> {
+        use crate::utils::PySendIterator;
+
         let mut trainer =
             trainer.map_or_else(|| self.tokenizer.get_model().get_trainer(), |t| t.clone());
-        let (send, recv) = std::sync::mpsc::sync_channel(256);
-        let mut sender = Some(send);
-        let iterator: PyIterator = iterator.iter()?;
 
-        crossbeam::thread::scope(|s| {
-            let _train_handle = s.spawn(|_| {
-                self.tokenizer
-                    .train(&mut trainer, recv.into_iter())
-                    .map(|_| {})
-            });
-
-            ResultShunt::process(
-                // Each element of the iterator can either be:
-                // - An iterator, to allow batching
-                // - A string
-                iterator.flat_map(|seq| match seq {
-                    Ok(s) => {
-                        if let Ok(iter) = s.iter() {
-                            itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
-                        } else {
-                            itertools::Either::Right(std::iter::once(s.extract::<&str>()))
-                        }
-                    }
-                    Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
-                }),
-                |iter| {
-                    if let Some(send) = sender.take() {
-                        for seq in iter {
-                            send.send(seq)
-                                .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
-                        }
-                    }
-                    Ok(())
-                },
-            )?
-        })
-        .unwrap()
+        let py_send = PySendIterator::new(
+            // Each element of the iterator can either be:
+            // - An iterator, to allow batching
+            // - A string
+            iterator.iter()?.flat_map(|seq| match seq {
+                Ok(s) => {
+                    if let Ok(iter) = s.iter() {
+                        itertools::Either::Left(iter.map(|i| i?.extract::<&str>()))
+                    } else {
+                        itertools::Either::Right(std::iter::once(s.extract::<&str>()))
+                    }
+                }
+                Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
+            }),
+            length,
+        );
+
+        py_send.execute(|iter| {
+            self.tokenizer
+                .train(&mut trainer, iter)
+                .map(|_| {})
+                .map_err(|e| exceptions::PyException::new_err(e.to_string()))
+        })
     }
 
     /// Apply all the post-processing steps to the given encodings.
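
The Either-based flat_map above is what lets train_from_iterator accept, from Python, either a plain string or a nested iterator of strings (for batching). A minimal standalone sketch of that pattern, where Item and flatten are illustrative names rather than types from the bindings:

    use itertools::Either;

    // Hypothetical input items: a batch of sequences or a single sequence,
    // mirroring what the binding accepts from the Python iterator.
    enum Item<'a> {
        Batch(Vec<&'a str>),
        Single(&'a str),
    }

    fn flatten<'a>(items: impl Iterator<Item = Item<'a>>) -> impl Iterator<Item = &'a str> {
        items.flat_map(|item| match item {
            // Either lets both match arms return one concrete iterator type,
            // so flat_map can consume them uniformly.
            Item::Batch(v) => Either::Left(v.into_iter()),
            Item::Single(s) => Either::Right(std::iter::once(s)),
        })
    }

    fn main() {
        let items = vec![Item::Batch(vec!["a", "b"]), Item::Single("c")];
        let flat: Vec<&str> = flatten(items.into_iter()).collect();
        assert_eq!(flat, ["a", "b", "c"]);
    }
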
69 changes: 69 additions & 0 deletions bindings/python/src/utils/mod.rs
@@ -12,6 +12,75 @@ pub use normalization::*;
 pub use pretokenization::*;
 pub use regex::*;
 
+// PySendIterator
+
+use std::sync::mpsc::{sync_channel, IntoIter};
+use tk::utils::iter::ResultShunt;
+
+pub struct MaybeSizedIterator<I> {
+    length: Option<usize>,
+    iter: I,
+}
+
+impl<I> Iterator for MaybeSizedIterator<I>
+where
+    I: Iterator,
+{
+    type Item = I::Item;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        self.iter.next()
+    }
+
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        (self.length.unwrap_or(0), None)
+    }
+}
+
+pub struct PySendIterator<I: Iterator> {
+    iter: I,
+    length: Option<usize>,
+}
+
+impl<I, T> PySendIterator<I>
+where
+    I: Iterator<Item = PyResult<T>>,
+    T: Send,
+{
+    pub fn new(iter: I, length: Option<usize>) -> Self {
+        PySendIterator { iter, length }
+    }
+
+    pub fn execute<F>(self, mut scope: F) -> PyResult<()>
+    where
+        F: FnMut(MaybeSizedIterator<IntoIter<T>>) -> PyResult<()> + Send + Sync,
+    {
+        let (send, recv) = sync_channel(256);
+        let mut sender = Some(send);
+
+        crossbeam::thread::scope(|s| {
+            let length = self.length;
+            s.spawn(move |_| {
+                scope(MaybeSizedIterator {
+                    length,
+                    iter: recv.into_iter(),
+                })
+            });
+
+            ResultShunt::process(self.iter, |iter| {
+                if let Some(send) = sender.take() {
+                    for i in iter {
+                        send.send(i)
+                            .map_err(|e| exceptions::PyException::new_err(e.to_string()))?;
+                    }
+                }
+                Ok(())
+            })?
+        })
+        .unwrap()
+    }
+}
+
 // PyChar
 // This type is a temporary hack to accept `char` as argument
 // To be removed once https://github.com/PyO3/pyo3/pull/1282 has been released
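
PySendIterator::execute inverts the usual roles: training runs on a spawned scoped thread while the calling thread drives the Python iterator, since the PyO3 items are tied to the GIL and cannot be sent across threads. The bounded channel provides backpressure, and MaybeSizedIterator restores a lower-bound size_hint (a channel's IntoIter reports (0, None)) so the trainer can still size a progress bar from the user-supplied length. A minimal sketch of the hand-off, using std::thread::scope in place of the crossbeam::thread::scope used here:

    use std::sync::mpsc::sync_channel;
    use std::thread;

    fn main() {
        // Bounded channel: the producer blocks once 256 items are in flight,
        // giving backpressure instead of buffering the whole corpus.
        let (send, recv) = sync_channel::<String>(256);

        thread::scope(|s| {
            // Consumer: drains the channel as if it were the training iterator.
            s.spawn(move || {
                let mut count = 0;
                for seq in recv.into_iter() {
                    count += seq.len();
                }
                println!("consumed {} bytes", count);
            });

            // Producer stays on the calling thread; in the binding these
            // sequences come from the Python iterator.
            for seq in ["hello", "world"] {
                send.send(seq.to_string()).unwrap();
            }
            // Dropping the sender closes the channel and ends the consumer loop.
            drop(send);
        });
    }
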
73 changes: 64 additions & 9 deletions tokenizers/src/tokenizer/mod.rs
@@ -23,6 +23,7 @@ use serde::de::DeserializeOwned;
 use serde::export::Formatter;
 use serde::{Deserialize, Serialize};
 
+use crate::utils::iter::ResultShunt;
 use crate::utils::parallelism::*;
 use crate::utils::progress::{ProgressBar, ProgressStyle};
@@ -124,15 +125,16 @@ pub trait Decoder {
 }
 
 /// A `Trainer` has the responsibility to train a model. We feed it with lines/sentences
-/// and it returns a `Model` when done.
+/// and then it can train the given `Model`.
 pub trait Trainer {
     type Model: Model + Sized;
     /// Whether we should show progress during the training.
     fn should_show_progress(&self) -> bool;
     /// The actual training method. This will return a new trained Model as well as a list
     /// of `special_tokens` to be added directly to the tokenizer along with the model.
     fn train(&self, model: &mut Self::Model) -> Result<Vec<AddedToken>>;
-    /// Process an iterator of sequences already pre-processed by the Tokenizer
+    /// Process an iterator of sequences, calling `process` for each of them in order to
+    /// pre-process the said sequence as relevant.
     fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
     where
         I: Iterator<Item = S> + Send,
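
Under the reworked trait, feed pre-processes and accumulates sequences, and train then mutates the Model in place. A hypothetical toy trainer sketching that division of labor (WordCounter and CountModel are illustrative names, not types from the crate, and the error type is simplified to String):

    use std::collections::HashMap;

    struct CountModel {
        counts: HashMap<String, u64>,
    }

    struct WordCounter {
        words: HashMap<String, u64>,
    }

    impl WordCounter {
        // Mirrors Trainer::feed: `process` turns one raw sequence into the
        // pre-tokenized words the trainer actually counts.
        fn feed<I, F>(&mut self, iterator: I, process: F) -> Result<(), String>
        where
            I: Iterator<Item = String>,
            F: Fn(&str) -> Result<Vec<String>, String>,
        {
            for seq in iterator {
                for word in process(&seq)? {
                    *self.words.entry(word).or_insert(0) += 1;
                }
            }
            Ok(())
        }

        // Mirrors Trainer::train: consumes what feed accumulated and updates
        // the model in place.
        fn train(&self, model: &mut CountModel) {
            model.counts = self.words.clone();
        }
    }

    fn main() {
        let mut trainer = WordCounter { words: HashMap::new() };
        let corpus = vec!["hello world".to_string(), "hello again".to_string()];
        trainer
            .feed(corpus.into_iter(), |seq| {
                // Stand-in for normalization + pre-tokenization: whitespace split.
                Ok(seq.split_whitespace().map(str::to_owned).collect())
            })
            .unwrap();

        let mut model = CountModel { counts: HashMap::new() };
        trainer.train(&mut model);
        assert_eq!(model.counts["hello"], 2);
    }
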
@@ -962,12 +964,20 @@
         .collect()
     }
 
+    /// Train our Model from files
     pub fn train_from_files<T>(&mut self, trainer: &mut T, files: Vec<String>) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
     {
+        let mut len = 0;
+        for file in files.iter() {
+            len += File::open(file)
+                .and_then(|f| f.metadata())
+                .map(|m| m.len())?;
+        }
+
         let max_read = 1_000_000;
-        use crate::utils::iter::ResultShunt;
 
         ResultShunt::process(
             files.into_iter().flat_map(|filename| {
                 match File::open(filename) {
@@ -981,12 +991,52 @@ where
                     Err(e) => itertools::Either::Right(std::iter::once(Err(e))),
                 }
             }),
-            |iter| self.train(trainer, iter).map(|_| {}),
+            |sequences| -> Result<()> {
+                let progress = if trainer.should_show_progress() {
+                    let progress = ProgressBar::new(len);
+                    progress.set_style(
+                        ProgressStyle::default_bar()
+                            .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {percent:>18!}%"),
+                    );
+                    progress
+                        .set_message(&format!("Pre-processing files ({:.2} Mo)", len / 1_000_000));
+                    progress.set_draw_delta(len / 100); // Redraw only every 2%
+                    Some(progress)
+                } else {
+                    None
+                };
+
+                trainer.feed(
+                    sequences.map(|s| {
+                        if let Some(progress) = &progress {
+                            progress.inc(s.len() as u64)
+                        }
+                        s
+                    }),
+                    |seq| {
+                        let normalized = self.do_normalize(seq.as_ref())?;
+                        let pre_tokenized = self.do_pre_tokenize(normalized)?;
+                        Ok(pre_tokenized
+                            .get_splits(OffsetReferential::Original, OffsetType::Byte)
+                            .into_iter()
+                            .map(|(s, _, _)| s.to_owned())
+                            .collect())
+                    },
+                )?;
+
+                if let Some(pbar) = progress {
+                    pbar.finish();
+                }
+                let special_tokens = trainer.train(&mut self.model)?;
+                self.add_special_tokens(&special_tokens);
+
+                Ok(())
+            },
         )??;
         Ok(self)
     }
 
-    /// Train a model and replace our current Model, using the given Trainer
+    /// Train our Model, using the given Trainer and iterator
     pub fn train<T, I, S>(&mut self, trainer: &mut T, sequences: I) -> Result<&mut Self>
     where
         T: Trainer<Model = M> + Sync,
@@ -1002,17 +1052,22 @@
                     .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
             );
             progress.set_message("Pre-processing sequences");
-            progress.set_draw_delta(len / 100); // Redraw only every 2%
+            if len > 0 {
+                progress.set_draw_delta(len / 100); // Redraw only every 2%
+            } else {
+                // Trying to have a good default to avoid progress tracking being the bottleneck
+                progress.set_draw_delta(1000);
+            }
             Some(progress)
         } else {
             None
         };
 
         trainer.feed(
             sequences.map(|s| {
-                // if let Some(progress) = &progress {
-                //     progress.inc(1)
-                // }
+                if let Some(progress) = &progress {
+                    progress.inc(1)
+                }
                 s
             }),
             |seq| {
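
The draw-delta guard matters because the progress bar otherwise redraws on every inc: with an unknown total, len is 0 (the size_hint lower bound), the old code set a delta of 0, and progress rendering could become the bottleneck the commit message refers to. A sketch of the guarded setup, assuming the indicatif API of the version in use at the time (set_draw_delta existed then and was removed in later indicatif releases):

    use indicatif::{ProgressBar, ProgressStyle};

    // Sketch of the guarded draw-delta logic above; `make_progress` is an
    // illustrative helper, not a function from the crate.
    fn make_progress(len: u64) -> ProgressBar {
        let progress = ProgressBar::new(len);
        progress.set_style(
            ProgressStyle::default_bar()
                .template("[{elapsed_precise}] {msg:<40!} {wide_bar} {pos:<9!}/{len:>9!}"),
        );
        progress.set_message("Pre-processing sequences");
        if len > 0 {
            // Known total: redraw roughly once per percent of progress.
            progress.set_draw_delta(len / 100);
        } else {
            // Unknown total: a fixed delta keeps redraws off the hot path.
            progress.set_draw_delta(1000);
        }
        progress
    }
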