Train custom data #2309

Closed
wants to merge 3 commits into from

Changes from 1 commit
Created with Colaboratory
FenNlay committed Nov 10, 2022
commit d2351bb372d3c37bdd07ef554e6ab77337f71429
358 changes: 358 additions & 0 deletions Fabs_version_train_your_first_coqui_STT_model.ipynb
@@ -0,0 +1,358 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/github/FenNlay/STT/blob/train_custom_data/Fabs_version_train_your_first_coqui_STT_model.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"id": "f79d99ef",
"metadata": {
"id": "f79d99ef"
},
"source": [
"# Train your first 🐸 STT model 💫\n",
"\n",
"👋 Hello and welcome to Coqui (🐸) STT \n",
"\n",
"The goal of this notebook is to show you a **typical workflow** for **training** and **testing** an STT model with 🐸.\n",
"\n",
"Let's train a very small model on a very small amount of data so we can iterate quickly.\n",
"\n",
"In this notebook, we will:\n",
"\n",
"1. Download data and format it for 🐸 STT.\n",
"2. Configure the training and testing runs.\n",
"3. Train a new model.\n",
"4. Test the model and display its performance.\n",
"\n",
"So, let's jump right in!\n",
"\n",
"*PS - If you just want a working, off-the-shelf model, check out the [🐸 Model Zoo](https://www.coqui.ai/models)*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa2aec78",
"metadata": {
"id": "fa2aec78"
},
"outputs": [],
"source": [
"## Install Coqui STT\n",
"! pip install -U pip\n",
"! pip install coqui_stt_training"
]
},
{
"cell_type": "markdown",
"id": "be5fe49c",
"metadata": {
"id": "be5fe49c"
},
"source": [
"## ✅ Download & format sample data for English\n",
"\n",
"**First things first**: we need some data.\n",
"\n",
"We're training a Speech-to-Text model, so we need some _speech_ and we need some _text_. Specificially, we want _transcribed speech_. Let's download an English audio file and its transcript and then format them for 🐸 STT. \n",
"\n",
"🐸 STT expects to find information about your data in a CSV file, where each line contains:\n",
"\n",
"1. the **path** to an audio file\n",
"2. the **size** of that audio file\n",
"3. the **transcript** of that audio file.\n",
"\n",
"Formatting the audio and transcript isn't too difficult in this case. We define `download_sample_data()` which does all the work. If you have a custom dataset, you will want to write a custom data importer.\n",
"\n",
"**Second things second**: we want an alphabet. The output layer of a typical* 🐸 STT model represents letters in the alphabet. Let's download an English alphabet from Coqui and use that.\n",
"\n",
"*_If you are working with languages with large character sets (e.g. Chinese), you can set `bytes_output_mode=True` instead of supplying an `alphabet.txt` file. In this case, the output layer of the STT model will correspond to individual UTF-8 bytes instead of individual characters._"
]
},
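{
"cell_type": "markdown",
"id": "importer-sketch-note",
"metadata": {
"id": "importer-sketch-note"
},
"source": [
"*Aside: the next cell is a minimal sketch of what a custom importer could look like, assuming a hypothetical dataset laid out as paired `.wav`/`.txt` files in a single folder. It is not needed for the rest of this notebook; adapt the paths and transcript cleaning to your own data.*"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "importer-sketch",
"metadata": {
"id": "importer-sketch"
},
"outputs": [],
"source": [
"### Sketch: custom importer for a folder of paired .wav/.txt files (hypothetical layout)\n",
"import glob\n",
"import os\n",
"\n",
"import pandas\n",
"\n",
"def import_paired_wav_txt(data_dir, csv_name=\"data_custom.csv\"):\n",
"    rows = []\n",
"    for wav_path in sorted(glob.glob(os.path.join(data_dir, \"*.wav\"))):\n",
"        txt_path = os.path.splitext(wav_path)[0] + \".txt\"\n",
"        if not os.path.exists(txt_path):\n",
"            continue  # skip clips that have no transcript\n",
"        with open(txt_path, \"r\") as fin:\n",
"            transcript = fin.read().strip().lower()\n",
"        rows.append({\n",
"            \"wav_filename\": os.path.abspath(wav_path),\n",
"            \"wav_filesize\": os.path.getsize(wav_path),\n",
"            \"transcript\": transcript,\n",
"        })\n",
"    df = pandas.DataFrame(rows, columns=[\"wav_filename\", \"wav_filesize\", \"transcript\"])\n",
"    df.to_csv(os.path.join(data_dir, csv_name), index=False)\n",
"    return df\n",
"\n",
"# Example usage (hypothetical path):\n",
"# import_paired_wav_txt(\"my_dataset/\")"
]
},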
{
"cell_type": "code",
"execution_count": null,
"id": "53945462",
"metadata": {
"scrolled": true,
"id": "53945462"
},
"outputs": [],
"source": [
"### Download sample data\n",
"import os\n",
"import pandas\n",
"import urllib.request\n",
"from coqui_stt_training.util.downloader import maybe_download\n",
"\n",
"### Import data\n",
"!curl -o cv_en.txt \"https://datasets-server.huggingface.co/first-rows?dataset=common_voice&config=en&split=validated\"\n",
"\n",
"def download_sample_data():\n",
" data_dir=\"english/\"\n",
"\n",
" # Download data + alphabet\n",
" fhandle = open(\"cv_en.txt\")\n",
" for line in fhandle:\n",
" x = line.split(\"],\")\n",
" y = line.split(\",{\")\n",
" \n",
" sentences = []\n",
" for i in x:\n",
" if i.startswith(\"\\\"sentence\"):\n",
" i = i.replace(\"\\\"sentence\\\":\\\"\", \"\")\n",
" i = i.replace(\"\\\\\\\"\", \"\")\n",
" stp = \"\\\"\"\n",
" i = i.split(stp, 1)[0]\n",
" sentences.append(i)\n",
" \n",
" audios = []\n",
" count = 0\n",
" for j in y:\n",
" if j.startswith(\"\\\"src\"):\n",
" j = j.replace(\"\\\"src\\\":\\\"\", \"\")\n",
" stp = \"\\\"\"\n",
" j = j.split(stp, 1)[0]\n",
" audios.append(j)\n",
" \n",
" a = 0\n",
" audio_file = []\n",
" for l in audios:\n",
" name = f\"audio{a}.wav\"\n",
" file = maybe_download(f\"{name}\", data_dir, f\"{l}\")\n",
" audio_file.append(file)\n",
" a += 1\n",
"\n",
" b = 0\n",
" transcript_file = []\n",
" for k in sentences:\n",
" name = f\"audio{b}.txt\"\n",
" with open(f\"/content/{data_dir}{name}\", \"w\") as wf:\n",
" wf.write(f\"{k}\")\n",
" transcript_file.append(f\"{data_dir}{name}\")\n",
" b += 1\n",
"\n",
"\n",
" alphabet = maybe_download(\"alphabet.txt\", data_dir, \"https://raw.githubusercontent.com/coqui-ai/STT/main/data/alphabet.txt\")\n",
" \n",
" # Format data\n",
" df = pandas.DataFrame(data=[(1, 1)],\n",
" columns=[\"wav_filename\", \"wav_filesize\"])\n",
" for m in audio_file:\n",
" new_row = {\"wav_filename\": os.path.abspath(m), \"wav_filesize\": os.path.getsize(m)}\n",
" df = df.append(new_row, ignore_index=True)\n",
" trans = [1]\n",
" for n in transcript_file:\n",
" with open(n, \"r\") as fin:\n",
" transcript = \" \".join(fin.read().strip().lower().split(\" \")).replace(\".\", \"\").replace(\",\", \"\").replace(\"’\", \"\").replace(\"?\", \"\").replace(\"!\", \"\").replace(\"-\", \"\")\n",
" trans.append(transcript)\n",
" df['transcript'] = trans\n",
" df = df.iloc[1: , :]\n",
" \n",
" # Save formatted CSV \n",
" df.to_csv(os.path.join(data_dir, \"data_cv.csv\"), index=False)\n",
"\n",
"# Download and format data\n",
"download_sample_data()"
]
},
{
"cell_type": "markdown",
"id": "96e8b708",
"metadata": {
"id": "96e8b708"
},
"source": [
"### 👀 Take a look at the data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fa2aec77",
"metadata": {
"id": "fa2aec77"
},
"outputs": [],
"source": [
"csv_file = open(\"english/data_cv.csv\", \"r\")\n",
"print(csv_file.read())"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c046277",
"metadata": {
"id": "6c046277"
},
"outputs": [],
"source": [
"alphabet_file = open(\"english/alphabet.txt\", \"r\")\n",
"print(alphabet_file.read())"
]
},
{
"cell_type": "markdown",
"id": "d9dfac21",
"metadata": {
"id": "d9dfac21"
},
"source": [
"## ✅ Configure & set hyperparameters\n",
"\n",
"Coqui STT comes with a long list of hyperparameters you can tweak. We've set default values, but you will often want to set your own. You can use `initialize_globals_from_args()` to do this. \n",
"\n",
"You must **always** configure the paths to your data, and you must **always** configure your alphabet. Additionally, here we show how you can specify the size of hidden layers (`n_hidden`), the number of epochs to train for (`epochs`), and to initialize a new model from scratch (`load_train=\"init\"`)."
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "d264fdec",
"metadata": {
"id": "d264fdec"
},
"outputs": [],
"source": [
"from coqui_stt_training.util.config import initialize_globals_from_args\n",
"\n",
"initialize_globals_from_args(\n",
" alphabet_config_path=\"english/alphabet.txt\",\n",
" checkpoint_dir=\"ckpt_dir\",\n",
" train_files=[\"english/data_cv.csv\"],\n",
" dev_files=[\"english/data_cv.csv\"],\n",
" test_files=[\"english/data_cv.csv\"],\n",
" load_train=\"init\",\n",
" n_hidden=200,\n",
" epochs=100,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "799c1425",
"metadata": {
"id": "799c1425"
},
"source": [
"### 👀 View all Config settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "03b33d2b",
"metadata": {
"id": "03b33d2b"
},
"outputs": [],
"source": [
"from coqui_stt_training.util.config import Config\n",
"\n",
"# Take a peek at the entire Config\n",
"print(Config.to_json())"
]
},
{
"cell_type": "markdown",
"id": "ae82fd75",
"metadata": {
"id": "ae82fd75"
},
"source": [
"## ✅ Train a new model\n",
"\n",
"Let's kick off a training run 🚀🚀🚀 (using the configure you set above).\n",
"\n",
"This notebook should work on either a GPU or a CPU. However, in case you're running this on _multiple_ GPUs we want to only use one, because the sample dataset (one audio file) is too small to split across multiple GPUs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "550a504e",
"metadata": {
"scrolled": true,
"id": "550a504e"
},
"outputs": [],
"source": [
"from coqui_stt_training.train import train\n",
"\n",
"# use maximum one GPU\n",
"os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"0\"\n",
"\n",
"train()"
]
},
{
"cell_type": "markdown",
"id": "9f6dc959",
"metadata": {
"id": "9f6dc959"
},
"source": [
"## ✅ Test the model\n",
"\n",
"We made it! 🙌\n",
"\n",
"Let's kick off the testing run, which displays performance metrics.\n",
"\n",
"We're committing the cardinal sin of ML 😈 (aka - testing on our training data) so you don't want to deploy this model into production. In this notebook we're focusing on the workflow itself, so it's forgivable 😇\n",
"\n",
"You can see from the test output that our tiny model has overfit to the data, and basically memorized this one sentence.\n",
"\n",
"When you start training your own models, make sure your testing data doesn't include your training data 😅"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd42bc7a",
"metadata": {
"id": "dd42bc7a"
},
"outputs": [],
"source": [
"from coqui_stt_training.evaluate import test\n",
"\n",
"test()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
},
"colab": {
"provenance": [],
"collapsed_sections": [],
"include_colab_link": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}