Introduction
Meta AI introduced wav2vec2 XLS-R ("XLS-R") at the end of 2021. XLS-R is a machine learning ("ML") model for cross-lingual speech representation learning, and it was trained on over 400,000 hours of publicly available speech audio across 128 languages. Upon its release, the model represented a leap over Meta AI's XLSR-53 cross-lingual model, which was trained on approximately 50,000 hours of speech audio across 53 languages.
This guide explains the steps to finetune XLS-R for automatic speech recognition ("ASR") using a Kaggle Notebook. The model will be finetuned on Chilean Spanish, but the general steps can be followed to finetune XLS-R on other languages.
Running inference on the finetuned model will be described in a companion tutorial, making this guide the first of two parts. I decided to create a separate inference-specific guide because this finetuning guide became a bit long.
It is assumed you have an existing ML background and that you understand basic ASR concepts. Beginners may have a difficult time following/understanding the build steps.
A Bit of Background on XLS-R
The original wav2vec2 model introduced in 2020 was pretrained on either 960 hours of LibriSpeech dataset speech audio or ~53,200 hours of LibriVox dataset speech audio. Upon its release, two model sizes were available: the BASE model with 95 million parameters and the LARGE model with 317 million parameters.
XLS-R, on the other hand, was pretrained on multilingual speech audio from 5 datasets:
- VoxPopuli: A total of ~372,000 hours of parliamentary speech audio from the European Parliament across 23 European languages.
- Multilingual Librispeech: A total of ~50,000 hours of speech audio across eight European languages, with the majority (~44,000 hours) of audio data in English.
- CommonVoice: A total of ~7,000 hours of speech audio across 60 languages.
- VoxLingua107: A total of ~6,600 hours of speech audio across 107 languages based on YouTube content.
- BABEL: A total of ~1,100 hours of speech audio across 17 African and Asian languages based on conversational telephone speech.
There are 3 XLS-R models: XLS-R (0.3B) with 300 million parameters, XLS-R (1B) with 1 billion parameters, and XLS-R (2B) with 2 billion parameters. This guide will use the XLS-R (0.3B) model.
Approach
There are some great write-ups on how to finetune wav2vec2 models, with perhaps this one being a "gold standard" of sorts. Of course, the general approach here mimics what you will find in other guides. You will:
- Load a training dataset of audio data and associated text transcriptions.
- Create a vocabulary from the text transcriptions in the dataset.
- Initialize a wav2vec2 processor that will extract features from the input data, as well as convert text transcriptions to sequences of labels.
- Finetune wav2vec2 XLS-R on the processed input data.
However, there are three key differences between this guide and others:
- The guide does not provide as much "inline" discussion on relevant ML and ASR concepts.
  - While each sub-section on individual notebook cells will include details on the use/purpose of the particular cell, it is assumed you have an existing ML background and that you understand basic ASR concepts.
- The Kaggle Notebook that you will build organizes utility methods in top-level cells.
  - Whereas many finetuning notebooks tend to have a sort of "stream-of-consciousness"-type layout, I elected to organize all utility methods together. If you're new to wav2vec2, you may find this approach confusing. However, to reiterate, I do my best to be explicit when explaining the purpose of each cell in each cell's dedicated sub-section. If you are just learning about wav2vec2, you might benefit from taking a quick glance at my HackerNoon article wav2vec2 for Automatic Speech Recognition in Plain English.
- This guide describes the steps for finetuning only.
  - As mentioned in the Introduction, I opted to create a separate companion guide on how to run inference on the finetuned XLS-R model that you will generate. This was done to prevent this guide from becoming excessively long.
Prerequisites and Before You Get Started
To complete the guide, you will need to have:
- An existing Kaggle account. If you don't have an existing Kaggle account, you need to create one.
- An existing Weights and Biases account ("WandB"). If you don't have an existing Weights and Biases account, you need to create one.
- A WandB API key. If you don't have a WandB API key, follow the steps here.
- Intermediate knowledge of Python.
- Intermediate knowledge of working with Kaggle Notebooks.
- Intermediate knowledge of ML concepts.
- Basic knowledge of ASR concepts.
Before you get started with building the notebook, it may be helpful to review the two sub-sections directly below. They describe:
- The training dataset.
- The Word Error Rate ("WER") metric used during training.
Training Dataset
As mentioned in the Introduction, the XLS-R model will be finetuned on Chilean Spanish. The specific dataset is the Chilean Spanish Speech Data Set developed by Guevara-Rukoz et al. It is available for download on OpenSLR. The dataset consists of two sub-datasets: (1) 2,636 audio recordings of Chilean male speakers and (2) 1,738 audio recordings of Chilean female speakers.
Each sub-dataset includes a line_index.tsv index file. Each line of an index file pairs an audio filename (without its .wav extension) with a transcription of the audio in that file, separated by a tab, e.g.:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche
clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
I have uploaded the Chilean Spanish Speech Data Set to Kaggle for convenience. There is one Kaggle dataset for the recordings of Chilean male speakers and one Kaggle dataset for the recordings of Chilean female speakers. These Kaggle datasets will be added to the Kaggle Notebook that you will build following the steps in this guide.
Word Error Rate (WER)
WER is one metric that can be used to measure the performance of automatic speech recognition models. WER provides a mechanism to measure how close a text prediction is to a text reference. WER accomplishes this by recording errors of 3 types:
- substitutions (S): A substitution error is recorded when the prediction contains a word that is different from the analogous word in the reference. For example, this occurs when the prediction misspells a word in the reference.
- deletions (D): A deletion error is recorded when the prediction does not contain a word that is present in the reference.
- insertions (I): An insertion error is recorded when the prediction contains a word that is not present in the reference.
As its name suggests, WER works at the word level. The formula for the WER metric is as follows:
WER = (S + D + I)/N
where:
S = number of substitution errors
D = number of deletion errors
I = number of insertion errors
N = number of words in the reference
A simple WER example in Spanish is as follows:
prediction: "Él está saliendo."
reference: "Él está saltando."
A table helps to visualize the errors in the prediction:
| TEXT | WORD 1 | WORD 2 | WORD 3 |
|---|---|---|---|
| prediction | Él | está | saliendo |
| reference | Él | está | saltando |
| | correct | correct | substitution |
The prediction contains 1 substitution error, 0 deletion errors, and 0 insertion errors. So, the WER for this example is:
WER = (1 + 0 + 0) / 3 = 1/3 ≈ 0.33
Note that WER does not tell us what the specific errors are. In the example above, WER identifies that WORD 3 contains an error in the predicted text, but it doesn't tell us that the characters i and e are wrong in the prediction. Other metrics, such as the Character Error Rate ("CER"), can be used for more precise error analysis.
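To make the WER arithmetic above concrete, here is a minimal sketch that counts substitutions, deletions, and insertions with a standard edit-distance (Levenshtein) alignment. It is written in plain Python for illustration only and is not the jiwer implementation behind the WER metric used in this guide; the function names are hypothetical. Applying the same alignment to characters instead of words gives a simple CER that counts every character, including spaces and punctuation.

```python
def edit_distance(ref, hyp):
    """Minimum number of substitutions, deletions, and insertions turning ref into hyp."""
    # dp[i][j] holds the distance between ref[:i] and hyp[:j]
    dp = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
    for i in range(len(ref) + 1):
        dp[i][0] = i  # delete all i reference tokens
    for j in range(len(hyp) + 1):
        dp[0][j] = j  # insert all j predicted tokens
    for i in range(1, len(ref) + 1):
        for j in range(1, len(hyp) + 1):
            cost = 0 if ref[i - 1] == hyp[j - 1] else 1
            dp[i][j] = min(
                dp[i - 1][j] + 1,         # deletion
                dp[i][j - 1] + 1,         # insertion
                dp[i - 1][j - 1] + cost,  # substitution (or match when cost is 0)
            )
    return dp[len(ref)][len(hyp)]

def word_error_rate(reference: str, prediction: str) -> float:
    ref_words = reference.split()
    return edit_distance(ref_words, prediction.split()) / len(ref_words)

def char_error_rate(reference: str, prediction: str) -> float:
    return edit_distance(list(reference), list(prediction)) / len(reference)

reference = "Él está saltando."
prediction = "Él está saliendo."
print(word_error_rate(reference, prediction))  # 1/3, roughly 0.33
print(char_error_rate(reference, prediction))  # 2/17, roughly 0.12
```

Splitting on whitespace keeps the trailing period attached to the last word, which does not change the result here since that word is a substitution either way.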
Building the Finetuning Notebook
You are now ready to start building the finetuning notebook.
- Step 1 and Step 2 guide you through setting up your Kaggle Notebook environment.
- Step 3 guides you through building the notebook itself. It contains 32 sub-steps representing the 32 cells of the finetuning notebook.
- Step 4 guides you through running the notebook, monitoring training, and saving the model.
Step 1 - Fetch Your WandB API Key
Your Kaggle Notebook must be configured to send training run data to WandB using your WandB API key. In order to do that, you need to copy it.
- Log in to WandB at www.wandb.com.
- Navigate to www.wandb.ai/authorize.
- Copy your API key for use in the next step.
Step 2 - Setting Up Your Kaggle Environment
Step 2.1 - Creating a New Kaggle Notebook
- Log in to Kaggle.
- Create a new Kaggle Notebook.
- Of course, the name of the notebook can be changed as desired. This guide uses the notebook name xls-r-300m-chilean-spanish-asr.
Step 2.2 - Setting Your WandB API Key
A Kaggle Secret will be used to securely store your WandB API key.
- Click Add-ons on the Kaggle Notebook main menu.
- Select Secret from the pop-up menu.
- Enter the label WANDB_API_KEY in the Label field and enter your WandB API key for the value.
- Ensure that the Attached checkbox to the left of the WANDB_API_KEY label field is checked.
- Click Done.
Step 2.3 - Adding the Training Datasets
The Chilean Spanish Speech Data Set has been uploaded to Kaggle as 2 distinct datasets:
Add both of these datasets to your Kaggle Notebook.
Step 3 - Building the Finetuning Notebook
The following 32 sub-steps build each of the finetuning notebook's 32 cells in order.
Step 3.1 - CELL 1: Installing Packages
The first cell of the finetuning notebook installs dependencies. Set the first cell to:
### CELL 1: Install Packages ###
!pip install --upgrade torchaudio
!pip install jiwer
- The first line upgrades the torchaudio package to the latest version. torchaudio will be used to load audio files and resample audio data.
- The second line installs the jiwer package, which is required by the WER metric loaded with the HuggingFace Datasets library load_metric method used later.
Step 3.2 - CELL 2: Importing Python Packages
The second cell imports required Python packages. Set the second cell to:
### CELL 2: Import Python packages ###
import wandb
from kaggle_secrets import UserSecretsClient
import math
import re
import numpy as np
import pandas as pd
import torch
import torchaudio
import json
from typing import Any, Dict, List, Optional, Union
from dataclasses import dataclass
from datasets import Dataset, load_metric, load_dataset, Audio
from transformers import Wav2Vec2CTCTokenizer
from transformers import Wav2Vec2FeatureExtractor
from transformers import Wav2Vec2Processor
from transformers import Wav2Vec2ForCTC
from transformers import TrainingArguments
from transformers import Trainer
- You are probably already familiar with most of these packages. Their use in the notebook will be explained as subsequent cells are built.
- It is worth mentioning that the HuggingFace transformers library and associated Wav2Vec2* classes provide the backbone of the functionality used for finetuning.
Step 3.3 - CELL 3: Loading WER Metric
The third cell loads the HuggingFace WER evaluation metric. Set the third cell to:
### CELL 3: Load WER metric ###
wer_metric = load_metric("wer")
- As mentioned earlier, WER will be used to measure the performance of the model on evaluation/holdout data.
Step 3.4 - CELL 4: Logging into WandB
The fourth cell retrieves your WANDB_API_KEY secret that was set in Step 2.2. Set the fourth cell to:
### CELL 4: Login to WandB ###
user_secrets = UserSecretsClient()
wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
wandb.login(key = wandb_api_key)
- The API key is used to configure the Kaggle Notebook so that training run data is sent to WandB.
Step 3.5 - CELL 5: Setting Constants
The fifth cell sets constants that will be used throughout the notebook. Set the fifth cell to:
### CELL 5: Constants ###
# Training data
TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/"
TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/"
EXT = ".wav"
NUM_LOAD_FROM_EACH_SET = 1600
# Vocabulary
VOCAB_FILE_PATH = "/kaggle/working/"
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"
# Sampling rates
ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000
# Training/validation data split
SPLIT_PCT = 0.10
# Model parameters
MODEL = "facebook/wav2vec2-xls-r-300m"
USE_SAFETENSORS = False
# Training arguments
OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr"
TRAIN_BATCH_SIZE = 18
EVAL_BATCH_SIZE = 10
TRAIN_EPOCHS = 30
SAVE_STEPS = 3200
EVAL_STEPS = 100
LOGGING_STEPS = 100
LEARNING_RATE = 1e-4
WARMUP_STEPS = 800
- The notebook doesn't surface every conceivable constant in this cell. Some values that could be represented by constants have been left inline.
- The use of many of the constants above should be self-evident. For those that are not, their use will be explained in the following sub-steps.
Step 3.6 - CELL 6: Utility Methods for Reading Index Files, Cleaning Text, and Creating Vocabulary
The sixth cell defines utility methods for reading the dataset index files (see the Training Dataset sub-section above), as well as for cleaning transcription text and creating the vocabulary. Set the sixth cell to:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ###
def read_index_file_data(path: str, filename: str):
    data = []
    with open(path + filename, "r", encoding = "utf8") as f:
        lines = f.readlines()
    for line in lines:
        # Each line is "<audio filename>\t<transcription>"
        file_and_text = line.split("\t")
        data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")])
    return data

def truncate_training_dataset(dataset: list) -> list:
    if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower():
        # Keep every sample when NUM_LOAD_FROM_EACH_SET is "all"
        return dataset
    else:
        return dataset[:NUM_LOAD_FROM_EACH_SET]

def clean_text(text: str) -> str:
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    cleaned_text = cleaned_text.lower()
    return cleaned_text

def create_vocab(data):
    vocab_list = []
    for index in range(len(data)):
        text = data[index][1]
        words = text.split(" ")
        for word in words:
            chars = list(word)
            for char in chars:
                if char not in vocab_list:
                    vocab_list.append(char)
    return vocab_list
- The read_index_file_data method reads a line_index.tsv dataset index file and produces a list of lists with audio file path and transcription data, e.g.:
[
    ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
    ...
]
- The truncate_training_dataset method truncates a list of index file data using the NUM_LOAD_FROM_EACH_SET constant set in Step 3.5. Specifically, the NUM_LOAD_FROM_EACH_SET constant is used to specify the number of audio samples that should be loaded from each dataset. For the purposes of this guide, the number is set at 1600, which means a total of 3200 audio samples will eventually be loaded. To load all samples, set NUM_LOAD_FROM_EACH_SET to the string value all.
- The clean_text method strips each text transcription of the characters specified by the regular expression assigned to SPECIAL_CHARS in Step 3.5 and converts the text to lowercase. These characters, inclusive of punctuation, can be eliminated as they don't provide any semantic value when training the model to learn mappings between audio features and text transcriptions.
- The create_vocab method creates a vocabulary from clean text transcriptions. Simply, it extracts all unique characters from the set of cleaned text transcriptions. You will see an example of the generated vocabulary in Step 3.14.
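As a quick check of the Step 3.6 helpers, the following sketch runs clean_text and create_vocab on the two example transcriptions shown in the Training Dataset sub-section. The helpers are copied from CELL 5 and CELL 6 (create_vocab is written slightly more compactly but behaves the same) so that the snippet runs outside the notebook; the two-sample list is an illustration, not the real training data.

```python
import re

# Copied from CELL 5 so this sketch runs on its own
SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\‘\’\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]"

def clean_text(text: str) -> str:
    # Remove special characters, then lowercase
    cleaned_text = re.sub(SPECIAL_CHARS, "", text)
    return cleaned_text.lower()

def create_vocab(data):
    # Collect unique characters in order of first appearance (spaces excluded)
    vocab_list = []
    for _, text in data:
        for word in text.split(" "):
            for char in word:
                if char not in vocab_list:
                    vocab_list.append(char)
    return vocab_list

# Hypothetical samples in the [audio path, transcription] shape produced by read_index_file_data
samples = [
    ["clm_08421_01719502739.wav", "Es un viaje de negocios solamente voy por una noche"],
    ["clm_02436_02011517900.wav", "Se usa para incitar a alguien a sacar el mayor provecho del dia presente"],
]
samples = [[path, clean_text(text)] for path, text in samples]
vocab_list = create_vocab(samples)
vocab_dict = {v: i for i, v in enumerate(vocab_list)}

print(samples[0][1])  # es un viaje de negocios solamente voy por una noche
print(vocab_list)     # ['e', 's', 'u', 'n', 'v', 'i', 'a', 'j', 'd', 'g', 'o', 'c', 'l', 'm', 't', 'y', 'p', 'r', 'h']
```

Because characters are indexed in order of first appearance, the exact indices in your vocab_dict depend on the order of your training samples.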
Step 3.7 - CELL 7: Utility Methods for Loading and Resampling Audio Data
The seventh cell defines utility methods using torchaudio to load and resample audio data. Set the seventh cell to:
### CELL 7: Utility methods for loading and resampling audio data ###
def read_audio_data(file):
    speech_array, sampling_rate = torchaudio.load(file, normalize = True)
    return speech_array, sampling_rate

def resample(waveform):
    transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE)
    waveform = transform(waveform)
    return waveform[0]
- The read_audio_data method loads a specified audio file and returns the audio data as a 2-D torch.Tensor of shape (channels, samples), along with the sampling rate of the audio. All the audio files in the training data have a sampling rate of 48000 Hz. This "original" sampling rate is captured by the constant ORIG_SAMPLING_RATE in Step 3.5.
- The resample method downsamples audio data from a sampling rate of 48000 Hz to 16000 Hz and returns the first channel of the resampled waveform. wav2vec2 is pretrained on audio sampled at 16000 Hz, so any audio used for finetuning must have the same sampling rate. The 16000 Hz target rate is captured by the constant TGT_SAMPLING_RATE in Step 3.5.
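To put the resampling step in perspective, the sketch below works through the sample counts for a hypothetical 3.5-second clip and estimates how many frames the wav2vec2 convolutional feature encoder produces from the resampled audio. The kernel sizes and strides are assumed to be the Wav2Vec2Config defaults in transformers, and the helper names are hypothetical.

```python
# Assumed Wav2Vec2Config defaults for the convolutional feature encoder
CONV_KERNELS = (10, 3, 3, 3, 3, 2, 2)
CONV_STRIDES = (5, 2, 2, 2, 2, 2, 2)

ORIG_SAMPLING_RATE = 48000
TGT_SAMPLING_RATE = 16000

def resampled_length(num_samples: int) -> int:
    """Number of samples after resampling 48 kHz audio to 16 kHz (exact when divisible by 3)."""
    return num_samples * TGT_SAMPLING_RATE // ORIG_SAMPLING_RATE

def encoder_output_frames(num_samples: int) -> int:
    """Frames emitted by a stack of unpadded 1-D convolutions."""
    length = num_samples
    for kernel, stride in zip(CONV_KERNELS, CONV_STRIDES):
        length = (length - kernel) // stride + 1
    return length

clip_seconds = 3.5
orig_samples = int(clip_seconds * ORIG_SAMPLING_RATE)
tgt_samples = resampled_length(orig_samples)
print(orig_samples, tgt_samples, encoder_output_frames(tgt_samples))  # 168000 56000 174
```

At 16000 Hz the encoder emits roughly 50 frames per second (about one frame per 20 ms). Each frame receives one CTC prediction, so a transcription has to fit within the available frames.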
Step 3.8 - CELL 8: Utility Methods to Prepare Data for Training
The eighth cell defines utility methods that process the audio and transcription data. Set the eighth cell to:
### CELL 8: Utility methods to prepare input data for training ###
def process_speech_audio(speech_array, sampling_rate):
    input_values = processor(speech_array, sampling_rate = sampling_rate).input_values
    return input_values[0]

def process_target_text(target_text):
    with processor.as_target_processor():
        encoding = processor(target_text).input_ids
    return encoding
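- The process_speech_audio method passes a single audio sample through the processor's feature extractor and returns the resulting input values.
- The process_target_text method encodes each text transcription as a list of labels, i.e. a list of indices referring to characters in the vocabulary. You will see a sample encoding in Step 3.15.
- Both methods reference the processor object, which is not initialized until Step 3.19. This works because the methods are first called in Step 3.22.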
Step 3.9 - CELL 9: Utility Method to Calculate Word Error Rate
The ninth cell is the final utility method cell and contains the method to calculate the Word Error Rate between a reference transcription and a predicted transcription. Set the ninth cell to:
### CELL 9: Utility method to calculate Word Error Rate ###
def compute_wer(pred):
    pred_logits = pred.predictions
    # Pick the most likely token id for each audio frame
    pred_ids = np.argmax(pred_logits, axis = -1)
    # Labels are padded with -100; map them to the pad token so they can be decoded
    pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id
    pred_str = processor.batch_decode(pred_ids)
    # Do not merge repeated characters in the reference text
    label_str = processor.batch_decode(pred.label_ids, group_tokens = False)
    wer = wer_metric.compute(predictions = pred_str, references = label_str)
    return {"wer": wer}
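compute_wer is easier to follow with a sense of how CTC output is turned back into text. The sketch below mimics, in a simplified way, what processor.batch_decode does for a single sequence: collapse runs of identical ids, drop [PAD] (which Wav2Vec2ForCTC also uses as the CTC blank), and map | back to a space. The toy vocabulary, the id sequences, and the decode helper are hypothetical. The sketch also shows why the -100 entries (the value used to pad labels so the loss ignores them) are replaced with the pad id before decoding, and why labels are decoded with group_tokens = False: grouping would collapse the legitimate double l in llama.

```python
# Hypothetical toy vocabulary (id -> token); [PAD] doubles as the CTC blank
ID_TO_TOKEN = {0: "l", 1: "a", 2: "m", 3: "[PAD]", 4: "|"}
PAD_ID = 3

def decode(ids, group_tokens=True):
    """Simplified stand-in for decoding one sequence with a CTC tokenizer."""
    if group_tokens:
        # CTC predicts one id per audio frame, so runs of the same id are merged
        ids = [t for k, t in enumerate(ids) if k == 0 or t != ids[k - 1]]
    # Remove [PAD]/blank ids, then turn the word delimiter back into a space
    text = "".join(ID_TO_TOKEN[t] for t in ids if t != PAD_ID)
    return text.replace("|", " ").strip()

# Per-frame ids, as np.argmax(pred_logits, axis = -1) would yield for one sample
pred_frame_ids = [0, 0, 3, 0, 1, 1, 2, 3, 1]

# Reference labels padded with -100, then mapped to the pad id as in compute_wer
label_ids = [0, 0, 1, 2, 1, -100, -100]
label_ids = [PAD_ID if t == -100 else t for t in label_ids]

print(decode(pred_frame_ids))                  # llama
print(decode(label_ids, group_tokens=False))   # llama
print(decode(label_ids, group_tokens=True))    # lama
```

In the predictions, the blank between the two runs of l is what allows CTC to emit a double letter.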
Step 3.10 - CELL 10: Reading Training Data
The tenth cell reads the training data index files for the recordings of male speakers and the recordings of female speakers using the read_index_file_data method defined in Step 3.6. Set the tenth cell to:
### CELL 10: Read training data ###
training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv")
training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
- As seen, the training data is managed in two gender-specific lists at this point. Data will be combined in Step 3.12 after truncation.
Step 3.11 - CELL 11: Truncating Training Data
The eleventh cell truncates the training data lists using the truncate_training_dataset method defined in Step 3.6. Set the eleventh cell to:
### CELL 11: Truncate training data ###
training_samples_male_cl = truncate_training_dataset(training_samples_male_cl)
training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
- As a reminder, the NUM_LOAD_FROM_EACH_SET constant set in Step 3.5 defines the quantity of samples to keep from each dataset. The constant is set to 1600 in this guide for a total of 3200 samples.
Step 3.12 - CELL 12: Combining Training Samples Data
The twelfth cell combines the truncated training data lists. Set the twelfth cell to:
### CELL 12: Combine training samples data ###
all_training_samples = training_samples_male_cl + training_samples_female_cl
Step 3.13 - CELL 13: Cleaning Transcription Text
The thirteenth cell iterates over each training data sample and cleans the associated transcription text using the clean_text method defined in Step 3.6. Set the thirteenth cell to:
### CELL 13: Clean transcription text ###
for index in range(len(all_training_samples)):
    all_training_samples[index][1] = clean_text(all_training_samples[index][1])
Step 3.14 - CELL 14: Creating the Vocabulary
The fourteenth cell creates a vocabulary using the cleaned transcriptions from the previous step and the create_vocab method defined in Step 3.6. Set the fourteenth cell to:
### CELL 14: Create vocabulary ###
vocab_list = create_vocab(all_training_samples)
vocab_dict = {v: i for i, v in enumerate(vocab_list)}
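- The vocabulary is stored as a dictionary with characters as keys and vocabulary indices as values.
- You can print vocab_dict, which should produce output similar to the following. The exact indices depend on the order in which characters first appear in the training transcriptions.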
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
Step 3.15 - CELL 15: Adding Word Delimiter to the Vocabulary
The fifteenth cell adds the word delimiter character | to the vocabulary. Set the fifteenth cell to:
### CELL 15: Add word delimiter to vocabulary ###
vocab_dict["|"] = len(vocab_dict)
- The word delimiter character is used when tokenizing text transcriptions as a list of labels. Specifically, it marks the boundary between words, and it is passed to the Wav2Vec2CTCTokenizer class when the tokenizer is initialized, as will be seen in Step 3.17.
- For example, the following list encodes no te entiendo nada using the vocabulary from Step 3.14 with the word delimiter added:
# Encoded text
[6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1]
# Vocabulary
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
- A question that might naturally arise is: "Why is it necessary to define a word delimiter character?" After all, the ends of words in written English and Spanish are marked by whitespace, so it might seem simpler to use the space character as a word delimiter. Remember, however, that English and Spanish are just two languages among thousands, and not all written languages use a space to mark word boundaries.
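The encoding shown above can be reproduced with a few lines of Python. This is a minimal sketch of the idea, not the Wav2Vec2CTCTokenizer implementation: spaces are swapped for the | delimiter and each character is looked up in the vocabulary. The [UNK] id used here is hypothetical; the real tokenizer assigns its own ids to special tokens, as Step 3.17 shows.

```python
# Vocabulary from Step 3.15 (characters plus the "|" word delimiter)
vocab_dict = {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
# Hypothetical id for unknown characters (the real tokenizer picks its own)
vocab_dict["[UNK]"] = len(vocab_dict)
id_to_char = {i: c for c, i in vocab_dict.items()}

def encode(text):
    # Spaces become the word delimiter; unseen characters map to [UNK]
    return [vocab_dict.get(c, vocab_dict["[UNK]"]) for c in text.replace(" ", "|")]

def decode(ids):
    # Map ids back to characters and turn the delimiter back into spaces
    return "".join(id_to_char[i] for i in ids).replace("|", " ")

ids = encode("no te entiendo nada")
print(ids)          # [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1]
print(decode(ids))  # no te entiendo nada
```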
Step 3.16 - CELL 16: Exporting Vocabulary
The sixteenth cell dumps the vocabulary to a file. Set the sixteenth cell to:
### CELL 16: Export vocabulary ###
with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file:
    json.dump(vocab_dict, vocab_file)
- The vocabulary file will be used in the next step, Step 3.17, to initialize the Wav2Vec2CTCTokenizer class.
Step 3.17 - CELL 17: Initializing the Tokenizer
The seventeenth cell initializes an instance of Wav2Vec2CTCTokenizer. Set the seventeenth cell to:
### CELL 17: Initialize tokenizer ###
tokenizer = Wav2Vec2CTCTokenizer(
VOCAB_FILE_PATH + "vocab.json",
unk_token = "[UNK]",
pad_token = "[PAD]",
word_delimiter_token = "|",
replace_word_delimiter_char = " "
)
- The tokenizer is used for encoding text transcriptions and decoding a list of labels back to text.
- Note that the tokenizer is initialized with [UNK] assigned to unk_token and [PAD] assigned to pad_token, with the former used to represent unknown tokens in text transcriptions and the latter used to pad transcriptions when creating batches of transcriptions with different lengths. These two values will be added to the vocabulary by the tokenizer.
- Initialization of the tokenizer in this step will also add two additional tokens to the vocabulary, namely <s> and </s>, which are used to demarcate the beginning and end of sentences respectively.
- | is assigned to word_delimiter_token explicitly in this step to reflect that the pipe symbol will be used to demarcate word boundaries, in accordance with our addition of the character to the vocabulary in Step 3.15. The | symbol is the default value for word_delimiter_token. So, it did not need to be explicitly set but was done so for the sake of clarity.
- Similarly as with word_delimiter_token, a single space is explicitly assigned to replace_word_delimiter_char, reflecting that the pipe symbol | is converted back to a blank space when labels are decoded to text. (When encoding, the tokenizer performs the reverse substitution, replacing blank spaces in text transcriptions with |.) A single space is the default value for replace_word_delimiter_char. So, it also did not need to be explicitly set but was done so for the sake of clarity.
- You can print the full tokenizer vocabulary by calling the get_vocab() method on tokenizer.
vocab = tokenizer.get_vocab()
print(vocab)
# Output:
{'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
Step 3.18 - CELL 18: Initializing the Feature Extractor
The eighteenth cell initializes an instance of Wav2Vec2FeatureExtractor. Set the eighteenth cell to:
### CELL 18: Initialize feature extractor ###
feature_extractor = Wav2Vec2FeatureExtractor(
feature_size = 1,
sampling_rate = 16000,
padding_value = 0.0,
do_normalize = True,
return_attention_mask = True
)
- The feature extractor is used to extract features from input data which is, of course, audio data in this use case. You will load the audio data for each training data sample in Step 3.20.
- The parameter values passed to the Wav2Vec2FeatureExtractor initializer are all default values, with the exception of return_attention_mask, which defaults to False. The default values are shown/passed for the sake of clarity.
- The feature_size parameter specifies the dimension size of input features (i.e. audio data features). The default value of this parameter is 1.
- sampling_rate tells the feature extractor the sampling rate, in Hz, at which the audio data should be digitized. As discussed in Step 3.7, wav2vec2 is pretrained on audio sampled at 16000 Hz and hence 16000 is the default value for this parameter.
- The padding_value parameter specifies the value that is used when padding audio data, as required when batching audio samples of different lengths. The default value is 0.0.
- do_normalize is used to specify if input data should be normalized to zero mean and unit variance. The default value is True. The Wav2Vec2FeatureExtractor class documentation notes that "[normalizing] can help to significantly improve the performance for some models."
- The return_attention_mask parameter specifies whether an attention mask, marking which values are real audio and which are padding, should be returned. The value is set to True for this use case.
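The following NumPy sketch illustrates what do_normalize, padding_value, and return_attention_mask amount to for a batch of two waveforms of different lengths. It is a simplified approximation, not the Wav2Vec2FeatureExtractor source: each waveform is scaled to zero mean and unit variance using only its real samples (the 1e-7 epsilon is an assumption), then padded with padding_value, and the attention mask marks real samples with 1 and padding with 0. The random waveforms and helper names are hypothetical.

```python
import numpy as np

def normalize(waveform: np.ndarray) -> np.ndarray:
    # Zero-mean, unit-variance scaling; the epsilon (assumed 1e-7) avoids division by zero on silence
    return (waveform - waveform.mean()) / np.sqrt(waveform.var() + 1e-7)

def pad_batch(waveforms, padding_value=0.0):
    """Pad 1-D waveforms to the longest length and build a matching attention mask."""
    max_len = max(len(w) for w in waveforms)
    input_values = np.full((len(waveforms), max_len), padding_value, dtype=np.float32)
    attention_mask = np.zeros((len(waveforms), max_len), dtype=np.int32)
    for row, w in enumerate(waveforms):
        input_values[row, : len(w)] = w
        attention_mask[row, : len(w)] = 1  # 1 = real audio, 0 = padding
    return input_values, attention_mask

rng = np.random.default_rng(0)
# Two hypothetical 16 kHz clips: 1.0 s and 0.75 s long
batch = [normalize(rng.normal(0.2, 0.05, size=n)) for n in (16000, 12000)]
input_values, attention_mask = pad_batch(batch)
print(input_values.shape, attention_mask.sum(axis=1))  # (2, 16000) [16000 12000]
```

Normalizing before padding keeps the padded zeros from skewing each clip's mean and variance.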
Step 3.19 - CELL 19: Initializing the Processor
The nineteenth cell initializes an instance of Wav2Vec2Processor. Set the nineteenth cell to:
### CELL 19: Initialize processor ###
processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
- The Wav2Vec2Processor class combines tokenizer and feature_extractor from Step 3.17 and Step 3.18 respectively into a single processor.
- Note that the processor configuration can be saved by calling the save_pretrained method on the Wav2Vec2Processor class instance.
processor.save_pretrained(OUTPUT_DIR_PATH)
Step 3.20 - CELL 20: Loading Audio Data
The twentieth cell loads each audio file specified in the all_training_samples list. Set the twentieth cell to:
### CELL 20: Load audio data ###
all_input_data = []
for index in range(len(all_training_samples)):
    speech_array, sampling_rate = read_audio_data(all_training_samples[index][0])
    all_input_data.append({
        "input_values": speech_array,
        "labels": all_training_samples[index][1]
    })
- Audio data is returned as a torch.Tensor and stored in all_input_data as a list of dictionaries. Each dictionary contains the audio data for a particular sample, along with the text transcription of the audio.
- Note that the read_audio_data method returns the sampling rate of the audio data as well. Since we know that the sampling rate is 48000 Hz for all audio files in this use case, the sampling rate is ignored in this step.
Step 3.21 - CELL 21: Converting all_input_data to a Pandas DataFrame
The twenty-first cell converts the all_input_data list to a Pandas DataFrame to make it easier to manipulate the data. Set the twenty-first cell to:
### CELL 21: Convert audio training data list to Pandas DataFrame ###
all_input_data_df = pd.DataFrame(data = all_input_data)
Step 3.22 - CELL 22: Processing Audio Data and Text Transcriptions
The twenty-second cell uses the processor initialized in Step 3.19 to extract features from each audio data sample and to encode each text transcription as a list of labels. Set the twenty-second cell to:
### CELL 22: Process audio data and text transcriptions ###
all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000))
all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
Step 3.23 - CELL 23: Splitting Input Data into Training and Validation Datasets
The twenty-third cell splits the all_input_data_df DataFrame into training and evaluation (validation) datasets using the SPLIT_PCT constant from Step 3.5. Set the twenty-third cell to:
### CELL 23: Split input data into training and validation datasets ###
split = math.floor(len(all_input_data_df) * SPLIT_PCT)
valid_data_df = all_input_data_df.iloc[-split:]
train_data_df = all_input_data_df.iloc[:-split]
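- The SPLIT_PCT value is 0.10 in this guide, meaning 10% of all input data will be held out for evaluation and 90% of the data will be used for training/finetuning.
- Since there are a total of 3,200 training samples, 320 samples will be used for evaluation with the remaining 2,880 samples used to finetune the model.
- Note that the data is not shuffled before splitting, so the evaluation samples are the last 320 rows of all_input_data_df, all of which come from the female-speaker dataset.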
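The split and the resulting training schedule can be checked with a few lines of arithmetic. The sketch below uses a plain list as a stand-in for the rows of all_input_data_df and assumes training on a single device without gradient accumulation. It computes the split size from the number of rows rather than from NUM_LOAD_FROM_EACH_SET, which also works when NUM_LOAD_FROM_EACH_SET is set to all.

```python
import math

NUM_SAMPLES = 3200        # 1600 male + 1600 female recordings
SPLIT_PCT = 0.10
TRAIN_BATCH_SIZE = 18
TRAIN_EPOCHS = 30

samples = list(range(NUM_SAMPLES))  # stand-in for the rows of all_input_data_df
split = math.floor(len(samples) * SPLIT_PCT)
valid, train = samples[-split:], samples[:-split]

# A partial final batch still counts as a step
steps_per_epoch = math.ceil(len(train) / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * TRAIN_EPOCHS
print(len(train), len(valid), steps_per_epoch, total_steps)  # 2880 320 160 4800
```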
Step 3.24 - CELL 24: Converting Training and Validation Datasets to Dataset Objects
The twenty-fourth cell converts the train_data_df and valid_data_df DataFrames to Dataset objects. Set the twenty-fourth cell to:
### CELL 24: Convert training and validation datasets to Dataset objects ###
train_data = Dataset.from_pandas(train_data_df)
valid_data = Dataset.from_pandas(valid_data_df)
- Dataset objects are consumed by HuggingFace Trainer class instances, as you will see in Step 3.30.
- These objects contain metadata about the dataset as well as the dataset itself.
- You can print train_data and valid_data to view the metadata for both Dataset objects.
print(train_data)
print(valid_data)
# Output:
Dataset({
features: ['input_values', 'labels'],
num_rows: 2880
})
Dataset({
features: ['input_values', 'labels'],
num_rows: 320
})
Step 3.25 - CELL 25: Initializing the Pretrained Model
The twenty-fifth cell initializes the pretrained XLS-R (0.3B) model. Set the twenty-fifth cell to:
### CELL 25: Initialize pretrained model ###
model = Wav2Vec2ForCTC.from_pretrained(
MODEL,
ctc_loss_reduction = "mean",
pad_token_id = processor.tokenizer.pad_token_id,
vocab_size = len(processor.tokenizer)
)
- The from_pretrained method called on Wav2Vec2ForCTC specifies that we want to load the pretrained weights for the specified model.
- The MODEL constant was specified in Step 3.5 and was set to facebook/wav2vec2-xls-r-300m, reflecting the XLS-R (0.3B) model.
- The ctc_loss_reduction parameter specifies the type of reduction to apply to the output of the Connectionist Temporal Classification ("CTC") loss function. CTC loss is used to calculate the loss between a continuous input, in this case audio data, and a target sequence, in this case text transcriptions. By setting the value to mean, each sample's loss in a batch is divided by its target length, and the mean of these values over the batch is then taken.
- pad_token_id specifies the token to be used for padding when batching. It is set to the [PAD] id set when initializing the tokenizer in Step 3.17. In Wav2Vec2ForCTC, this id also serves as the CTC blank token.
- The vocab_size parameter defines the vocabulary size of the model. It is the vocabulary size after initialization of the tokenizer in Step 3.17, and it sets the number of output units of the linear classification head added on top of the pretrained model, i.e. one logit per vocabulary token for each audio frame. This head is newly initialized because the pretrained XLS-R checkpoint does not include one.
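To illustrate what ctc_loss_reduction = "mean" does, the sketch below applies the same reduction rule as PyTorch's CTCLoss with reduction="mean" to hypothetical per-sample losses and target lengths, and compares it with a plain sum. The numbers and the ctc_reduce helper are made up for illustration.

```python
# Hypothetical per-sample CTC losses (each summed over its sequence) and target lengths
sample_losses = [12.0, 30.0, 9.0]
target_lengths = [4, 5, 3]

def ctc_reduce(losses, lengths, reduction="mean"):
    if reduction == "sum":
        return sum(losses)
    if reduction == "mean":
        # Divide each loss by its target length, then average over the batch
        per_token = [loss / length for loss, length in zip(losses, lengths)]
        return sum(per_token) / len(per_token)
    raise ValueError(f"unsupported reduction: {reduction}")

print(ctc_reduce(sample_losses, target_lengths, "mean"))  # 4.0
print(ctc_reduce(sample_losses, target_lengths, "sum"))   # 51.0
```

Dividing by the target length keeps long transcriptions from dominating the batch loss simply because they contain more characters.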
Step 3.26 - CELL 26: Freezing Feature Extractor Weights
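The twenty-sixth cell freezes the pretrained weights of the feature extractor, the convolutional network that encodes the raw audio waveform, so that they are not updated during finetuning. In more recent releases of Transformers, freeze_feature_extractor is deprecated in favor of the equivalent freeze_feature_encoder method. Set the twenty-sixth cell to: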
### CELL 26: Freeze feature extractor ###
model.freeze_feature_extractor()
Step 3.27 - CELL 27: Setting Training Arguments
The twenty-seventh cell initializes the training arguments that will be passed to a Trainer instance. Set the twenty-seventh cell to:
### CELL 27: Set training arguments ###
training_args = TrainingArguments(
    output_dir = OUTPUT_DIR_PATH,
    save_safetensors = False,
    group_by_length = True,
    per_device_train_batch_size = TRAIN_BATCH_SIZE,
    per_device_eval_batch_size = EVAL_BATCH_SIZE,
    num_train_epochs = TRAIN_EPOCHS,
    gradient_checkpointing = True,
    evaluation_strategy = "steps",
    save_strategy = "steps",
    logging_strategy = "steps",
    eval_steps = EVAL_STEPS,
    save_steps = SAVE_STEPS,
    logging_steps = LOGGING_STEPS,
    learning_rate = LEARNING_RATE,
    warmup_steps = WARMUP_STEPS
)
- The TrainingArguments class accepts more than 100 parameters.
- The save_safetensors parameter, when False, specifies that the finetuned model should be saved to a pickle file instead of using the safetensors format.
- The group_by_length parameter, when True, indicates that samples of approximately the same length should be grouped together. This minimizes padding and improves training efficiency.
- The per_device_train_batch_size parameter sets the number of samples per training mini-batch. This parameter is set to 18 via the TRAIN_BATCH_SIZE constant assigned in Step 3.5. With 2,880 training samples, this implies 160 steps per epoch.
- The per_device_eval_batch_size parameter sets the number of samples per evaluation (holdout) mini-batch. This parameter is set to 10 via the EVAL_BATCH_SIZE constant assigned in Step 3.5.
- The num_train_epochs parameter sets the number of training epochs. This parameter is set to 30 via the TRAIN_EPOCHS constant assigned in Step 3.5. This implies 4,800 total steps during training (160 steps × 30 epochs).
- The gradient_checkpointing parameter, when True, saves memory by storing only some of the intermediate activations during the forward pass and recomputing the rest during the backward pass, which results in slower backward passes.
- The evaluation_strategy parameter, when set to "steps", means that evaluation will be performed and logged during training at an interval specified by the eval_steps parameter.
- The logging_strategy parameter, when set to "steps", means that training run statistics will be logged at an interval specified by the logging_steps parameter.
- The save_strategy parameter, when set to "steps", means that a checkpoint of the finetuned model will be saved at an interval specified by the save_steps parameter.
- The eval_steps parameter sets the number of steps between evaluations of holdout data. This parameter is set to 100 via the EVAL_STEPS constant assigned in Step 3.5.
- The save_steps parameter sets the number of steps after which a checkpoint of the finetuned model is saved. This parameter is set to 3200 via the SAVE_STEPS constant assigned in Step 3.5.
- The logging_steps parameter sets the number of steps between logs of training run statistics. This parameter is set to 100 via the LOGGING_STEPS constant assigned in Step 3.5.
- The learning_rate parameter sets the initial learning rate. This parameter is set to 1e-4 via the LEARNING_RATE constant assigned in Step 3.5.
- The warmup_steps parameter sets the number of steps over which the learning rate is linearly warmed up from 0 to the value set by learning_rate. This parameter is set to 800 via the WARMUP_STEPS constant assigned in Step 3.5.
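The step counts quoted above can be double-checked with a short pure-Python sketch. It assumes a single GPU with no gradient accumulation (so one optimizer step equals one mini-batch) and the Trainer's default "linear" learning rate scheduler, which warms up linearly over warmup_steps and then decays linearly to 0 by the final step. learning_rate_at is a hypothetical helper written for illustration; the Trainer builds its own scheduler internally.

```python
import math

# Constants as described in Step 3.5, plus the training set size printed in Step 3.24.
NUM_TRAIN_SAMPLES = 2880
TRAIN_BATCH_SIZE = 18
TRAIN_EPOCHS = 30
EVAL_STEPS = 100
SAVE_STEPS = 3200
LEARNING_RATE = 1e-4
WARMUP_STEPS = 800

# Assumption: one device, no gradient accumulation, last partial batch kept.
steps_per_epoch = math.ceil(NUM_TRAIN_SAMPLES / TRAIN_BATCH_SIZE)
total_steps = steps_per_epoch * TRAIN_EPOCHS


def learning_rate_at(step):
    """Learning rate under an assumed linear warmup followed by linear decay to 0."""
    if step < WARMUP_STEPS:
        return LEARNING_RATE * step / WARMUP_STEPS
    remaining = (total_steps - step) / (total_steps - WARMUP_STEPS)
    return LEARNING_RATE * max(0.0, remaining)


print(steps_per_epoch, total_steps)  # 160 4800
print(total_steps // EVAL_STEPS)     # 48 evaluation rounds
print(total_steps // SAVE_STEPS)     # 1 intermediate checkpoint
for step in (0, 400, 800, 2800, 4800):
    print(step, learning_rate_at(step))
```

Under these assumptions, the learning rate peaks at 1e-4 at step 800, and the single intermediate checkpoint is written at step 3200.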
Step 3.28 - CELL 28: Defining Data Collator Logic
The twenty-eighth cell defines the logic for dynamically padding input and target sequences. Set the twenty-eighth cell to:
### CELL 28: Define data collator logic ###
@dataclass
class DataCollatorCTCWithPadding:
    processor: Wav2Vec2Processor
    padding: Union[bool, str] = True
    max_length: Optional[int] = None
    max_length_labels: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    pad_to_multiple_of_labels: Optional[int] = None

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{"input_values": feature["input_values"]} for feature in features]
        label_features = [{"input_ids": feature["labels"]} for feature in features]

        batch = self.processor.pad(
            input_features,
            padding = self.padding,
            max_length = self.max_length,
            pad_to_multiple_of = self.pad_to_multiple_of,
            return_tensors = "pt",
        )
        with self.processor.as_target_processor():
            labels_batch = self.processor.pad(
                label_features,
                padding = self.padding,
                max_length = self.max_length_labels,
                pad_to_multiple_of = self.pad_to_multiple_of_labels,
                return_tensors = "pt",
            )

        labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100)
        batch["labels"] = labels

        return batch
- Training and evaluation input-label pairs are passed in mini-batches to the Trainer instance that will be initialized momentarily in Step 3.30. Since the input sequences and label sequences vary in length in each mini-batch, some sequences must be padded so that they are all of the same length.
- The DataCollatorCTCWithPadding class dynamically pads mini-batch data. The padding parameter, when set to True, specifies that shorter audio input feature sequences and label sequences should be padded to the length of the longest sequence in a mini-batch.
- Audio input features are padded with the value 0.0 set when initializing the feature extractor in Step 3.18.
- Label inputs are first padded with the padding value set when initializing the tokenizer in Step 3.17. These padded positions are then replaced by -100 so that they are ignored when calculating the CTC loss.
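The padding rules above can be sketched in pure Python without PyTorch. This is a simplified stand-in for DataCollatorCTCWithPadding, not the processor's actual implementation: pad_batch is a hypothetical helper, plain lists replace tensors, and PAD_TOKEN_ID = 0 is an assumed value (the real id comes from the tokenizer built in Step 3.17).

```python
PAD_TOKEN_ID = 0     # assumed [PAD] id, for illustration only
IGNORE_INDEX = -100  # label value that marks positions to ignore


def pad_batch(input_values, label_ids, pad_token_id=PAD_TOKEN_ID):
    """Pad inputs with 0.0 and labels with -100 up to the longest sequence in the mini-batch."""
    max_input_len = max(len(x) for x in input_values)
    max_label_len = max(len(y) for y in label_ids)

    # Audio inputs: right-pad with 0.0.
    padded_inputs = [list(x) + [0.0] * (max_input_len - len(x)) for x in input_values]

    # Labels: pad with the pad token first and build an attention mask (1 = real token).
    padded_labels, label_masks = [], []
    for y in label_ids:
        n_pad = max_label_len - len(y)
        padded_labels.append(list(y) + [pad_token_id] * n_pad)
        label_masks.append([1] * len(y) + [0] * n_pad)

    # Mirror masked_fill(attention_mask.ne(1), -100): swap padded positions for -100.
    labels = [
        [tok if keep == 1 else IGNORE_INDEX for tok, keep in zip(row, mask)]
        for row, mask in zip(padded_labels, label_masks)
    ]
    return {"input_values": padded_inputs, "labels": labels}


batch = pad_batch(
    input_values=[[0.1, -0.2, 0.3], [0.5]],
    label_ids=[[7, 3], [5, 9, 2, 4]],
)
print(batch["input_values"])  # [[0.1, -0.2, 0.3], [0.5, 0.0, 0.0]]
print(batch["labels"])        # [[7, 3, -100, -100], [5, 9, 2, 4]]
```

Because every sequence is padded to the longest one in its mini-batch, grouping samples of similar length (group_by_length = True in Step 3.27) keeps the number of padded positions small.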
Step 3.29 - CELL 29: Initializing Instance of Data Collator
The twenty-ninth cell initializes an instance of the data collator defined in the previous step. Set the twenty-ninth cell to:
### CELL 29: Initialize instance of data collator ###
data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
Step 3.30 - CELL 30: Initializing the Trainer
The thirtieth cell initializes an instance of the Trainer class. Set the thirtieth cell to:
### CELL 30: Initialize trainer ###
trainer = Trainer(
    model = model,
    data_collator = data_collator,
    args = training_args,
    compute_metrics = compute_wer,
    train_dataset = train_data,
    eval_dataset = valid_data,
    tokenizer = processor.feature_extractor
)
- As seen, the Trainer class is initialized with:
  - The pretrained model initialized in Step 3.25.
  - The data collator initialized in Step 3.29.
  - The training arguments initialized in Step 3.27.
  - The WER evaluation method defined in Step 3.9.
  - The train_data Dataset object from Step 3.24.
  - The valid_data Dataset object from Step 3.24.
- The tokenizer parameter is assigned to processor.feature_extractor. Because a custom data_collator is supplied, padding each mini-batch to its maximum-length input is handled by data_collator. The Trainer also saves the object passed as tokenizer alongside the model, which is why preprocessor_config.json appears among the saved model files.
Step 3.31 - CELL 31: Finetuning the Model
The thirty-first cell calls the train method on the Trainer class instance to finetune the model. Set the thirty-first cell to:
### CELL 31: Finetune the model ###
trainer.train()
Step 3.32 - CELL 32: Saving the Finetuned Model
The thirty-second cell is the last notebook cell. It saves the finetuned model by calling the save_model method on the Trainer instance. Set the thirty-second cell to:
### CELL 32: Save the finetuned model ###
trainer.save_model(OUTPUT_DIR_PATH)
Step 4 - Training and Saving the Model
Step 4.1 - Training the Model
Now that all the cells of the notebook have been built, it’s time to start finetuning.
- Set the Kaggle Notebook to run with the NVIDIA GPU P100 accelerator.
- Commit the notebook on Kaggle.
- Monitor training run data by logging into your WandB account and locating the associated run.
Training over 30 epochs should take ~5 hours using the NVIDIA GPU P100 accelerator. The WER on holdout data should drop to ~0.15 at the end of training. It’s not quite a state-of-the-art result, but the finetuned model is still sufficiently useful for many applications.
Step 4.2 - Saving the Model
The finetuned model will be output to the Kaggle directory specified by the OUTPUT_DIR_PATH constant set in Step 3.5. The model output should include the following files:
pytorch_model.bin
config.json
preprocessor_config.json
vocab.json
training_args.bin
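Before downloading the files or uploading them to Kaggle, it can be handy to confirm that the output directory contains everything listed above. The following sketch uses only the Python standard library; missing_model_files is a hypothetical helper, and the expected file list is copied from the list above rather than taken from any Transformers specification. Checkpoint subdirectories written during training are not checked.

```python
from pathlib import Path

# File names copied from the list above; treated as the expected minimum output.
EXPECTED_FILES = [
    "pytorch_model.bin",
    "config.json",
    "preprocessor_config.json",
    "vocab.json",
    "training_args.bin",
]


def missing_model_files(output_dir):
    """Return the expected files that are absent from output_dir (hypothetical helper)."""
    directory = Path(output_dir)
    return [name for name in EXPECTED_FILES if not (directory / name).is_file()]
```

For example, calling missing_model_files(OUTPUT_DIR_PATH) in a new cell after Step 3.32 should return an empty list if the saved output matches the list above.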
These files can be downloaded locally. Additionally, you can create a new Kaggle Model using the model files. The Kaggle Model will be used with the companion inference guide to run inference on the finetuned model.
- Log in to your Kaggle account. Click on Models > New Model.
- Add a title for your finetuned model in the Model Title field.
- Click on Create Model.
- Click on Go to model detail page.
- Click on Add new variation under Model Variations.
- Select Transformers from the Framework select menu.
- Click on Add new variation.
- Drag and drop your finetuned model files into the Upload Data window. Alternatively, click on the Browse Files button to open a file explorer window and select your finetuned model files.
- Once the files have been uploaded to Kaggle, click on Create to create the Kaggle Model.
Conclusion
Congratulations on finetuning wav2vec2 XLS-R! Remember that you can use these general steps to finetune the model on other languages that you desire. Running inference on the finetuned model generated in this guide is fairly straightforward. The inference steps will be outlined in a separate companion guide to this one. Please search for my HackerNoon username to find the companion guide.