導入
Meta AI は、2021 年末にwav2vec2 XLS-R (以下「XLS-R」)を導入しました。XLS-R は、言語間音声表現学習用の機械学習 (以下「ML」) モデルであり、128 言語にわたる 400,000 時間を超える公開音声オーディオでトレーニングされました。リリース時に、このモデルは、53 言語にわたる約 50,000 時間の音声オーディオでトレーニングされた Meta AI のXLSR-53言語間モデルを飛躍的に上回りました。
このガイドでは、 Kaggle Notebookを使用して、自動音声認識 (「ASR」) 用に XLS-R を微調整する手順について説明します。モデルはチリのスペイン語で微調整されますが、一般的な手順に従って、必要なさまざまな言語で XLS-R を微調整できます。
微調整されたモデルで推論を実行する方法については、付属のチュートリアルで説明されており、このガイドは 2 部構成の第 1 部となります。この微調整ガイドは少し長くなったため、推論に特化した別のガイドを作成することにしました。
すでに ML のバックグラウンドがあり、基本的な ASR の概念を理解していることが前提となります。初心者はビルド手順を理解しにくい場合があります。
XLS-R の背景について
2020 年に導入されたオリジナルの wav2vec2 モデルは、960 時間のLibrispeechデータセットの音声と約 53,200 時間のLibriVoxデータセットの音声で事前トレーニングされました。リリース時には、9,500 万のパラメータを持つBASEモデルと 3 億 1,700 万のパラメータを持つLARGEモデルの 2 つのモデル サイズが利用可能でした。
一方、XLS-R は、5 つのデータセットからの多言語音声オーディオで事前トレーニングされました。
- VoxPopuli : 欧州議会の議会演説を 23 のヨーロッパ言語で合計約 372,000 時間分収録した音声。
- 多言語 Librispeech : 8 つのヨーロッパ言語にわたる合計約 50,000 時間の音声オーディオ。オーディオ データの大部分 (約 44,000 時間) は英語です。
- CommonVoice : 60 の言語にわたる合計約 7,000 時間の音声オーディオ。
- VoxLingua107 : YouTube コンテンツに基づく、107 言語にわたる合計約 6,600 時間の音声オーディオ。
- BABEL : 電話での会話音声に基づいた、アフリカとアジアの 17 言語にわたる合計約 1,100 時間の音声オーディオ。
XLS-R モデルには、3 億個のパラメータを持つXLS-R (0.3B) 、10 億個のパラメータを持つXLS-R (1B) 、20 億個のパラメータを持つXLS-R (2B)の 3 つがあります。このガイドでは、XLS-R (0.3B) モデルを使用します。
アプローチ
wav2vev2モデルを微調整する方法について素晴らしい記事がいくつかあり、おそらくこれはある種の「ゴールド スタンダード」でしょう。もちろん、ここでの一般的なアプローチは、他のガイドで紹介されているものと似ています。次のようになります。
- オーディオ データと関連するテキスト転写のトレーニング データセットを読み込みます。
- データセット内のテキスト転写から語彙を作成します。
- 入力データから特徴を抽出し、テキスト転写をラベルのシーケンスに変換する wav2vec2 プロセッサを初期化します。
- 処理された入力データに対して wav2vec2 XLS-R を微調整します。
ただし、このガイドと他のガイドの間には 3 つの重要な違いがあります。
- このガイドでは、関連する ML および ASR の概念について、それほど多くの「インライン」の説明は提供されていません。
- 個々のノートブック セルの各サブセクションには、特定のセルの使用法や目的に関する詳細が含まれますが、読者には既に ML のバックグラウンドがあり、基本的な ASR の概念を理解していることが前提となります。
- 構築する Kaggle ノートブックでは、ユーティリティ メソッドが最上位のセル内に整理されます。
- 多くの微調整ノートブックは、一種の「意識の流れ」タイプのレイアウトになる傾向がありますが、私はすべてのユーティリティ メソッドをまとめて整理することにしました。wav2vec2 を初めて使用する場合は、このアプローチに戸惑うかもしれません。ただし、繰り返しになりますが、各セルの専用サブセクションで各セルの目的を説明するときは、できるだけ明確にするよう努めています。wav2vec2 について学び始めたばかりの場合は、私の HackerNoon の記事「wav2vec2 for Automatic Speech Recognition in Plain English」をざっと読むと役立つかもしれません。
- このガイドでは、微調整の手順のみを説明します。
- はじめにで述べたように、生成する微調整された XLS-R モデルで推論を実行する方法については、別のガイドを作成することにしました。これは、このガイドが長くなりすぎないようにするためです。
前提条件と始める前に
ガイドを完了するには、次のものが必要です。
- 既存のKaggle アカウント。既存の Kaggle アカウントがない場合は、作成する必要があります。
- 既存のWeights and Biases アカウント ("WandB") 。既存の Weights and Biases アカウントがない場合は、作成する必要があります。
- WandB API キー。WandB API キーをお持ちでない場合は、こちらの手順に従ってください。
- Python の中級レベルの知識。
- Kaggle Notebooks の操作に関する中級レベルの知識。
- ML 概念に関する中級レベルの知識。
- ASR 概念に関する基本的な知識。
ノートブックの構築を始める前に、すぐ下の 2 つのサブセクションを確認すると役立つ場合があります。これらのサブセクションの内容は次のとおりです。
- トレーニングデータセット。
- トレーニング中に使用される単語エラー率 (「WER」) メトリック。
トレーニングデータセット
はじめにで述べたように、XLS-Rモデルはチリのスペイン語で微調整されます。具体的なデータセットは、Guevara-Rukozらが開発したチリのスペイン語音声データセットです。これはOpenSLRからダウンロードできます。データセットは、(1)チリの男性話者の2,636の音声録音と(2)チリの女性話者の1,738の音声録音の2つのサブデータセットで構成されています。
各サブデータセットには、 line_index.tsv
インデックス ファイルが含まれています。各インデックス ファイルの各行には、オーディオ ファイル名と、関連ファイル内のオーディオの転写のペアが含まれています。例:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
便宜上、チリのスペイン語音声データセットを Kaggle にアップロードしました。チリの男性話者の録音用に 1 つの Kaggle データセットがあり、チリの女性話者の録音用に 1 つの Kaggle データセットがあります。これらの Kaggle データセットは、このガイドの手順に従って構築する Kaggle ノートブックに追加されます。
単語誤り率 (WER)
WER は、自動音声認識モデルのパフォーマンスを測定するために使用できる 1 つの指標です。WER は、テキスト予測がテキスト参照にどれだけ近いかを測定するメカニズムを提供します。WER は、次の 3 種類のエラーを記録することでこれを実現します。
置換 (
S
): 予測に参照内の類似語とは異なる単語が含まれている場合に、置換エラーが記録されます。たとえば、予測で参照内の単語のスペルが間違っている場合に発生します。削除(
D
):予測に参照に存在しない単語が含まれている場合、削除エラーが記録されます。挿入(
I
):予測に参照に存在する単語が含まれていない場合、挿入エラーが記録されます。
明らかに、WER は単語レベルで機能します。WER メトリックの式は次のとおりです。
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
スペイン語での簡単な WER の例は次のとおりです。
prediction: "Él está saliendo." reference: "Él está saltando."
表は予測の誤差を視覚化するのに役立ちます。
文章 | 単語1 | 単語2 | 単語3 |
---|---|---|---|
予測 | エル | そうです | サリエンド |
参照 | エル | そうです | サルタンド |
| 正しい | 正しい | 代替 |
予測には、置換エラーが 1 つ、削除エラーが 0 つ、挿入エラーが 0 つ含まれています。したがって、この例の WER は次のようになります。
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
単語エラー率が必ずしも具体的なエラーが何であるかを教えてくれるとは限らないことは明らかです。上記の例では、WER は予測されたテキストのWORD 3にエラーが含まれていることを識別しますが、予測で文字iとeが間違っていることは教えてくれません。文字エラー率 (CER) などの他の指標は、より正確なエラー分析に使用できます。
ファインチューニングノートブックの構築
これで、微調整ノートブックの構築を開始する準備が整いました。
- ステップ 1とステップ 2では、Kaggle Notebook 環境の設定手順を説明します。
- ステップ 3 では、ノートブック自体の構築手順を説明します。このステップには、微調整ノートブックの 32 個のセルを表す 32 個のサブステップが含まれています。
- ステップ 4では、ノートブックの実行、トレーニングの監視、モデルの保存の手順を説明します。
ステップ1 - WandB APIキーを取得する
Kaggle Notebook は、WandB API キーを使用してトレーニング実行データを WandB に送信するように設定する必要があります。そのためには、キーをコピーする必要があります。
-
www.wandb.com
で WandB にログインします。 -
www.wandb.ai/authorize
に移動します。 - 次のステップで使用するために API キーをコピーします。
ステップ 2 - Kaggle 環境の設定
ステップ 2.1 - 新しい Kaggle ノートブックを作成する
- Kaggleにログインします。
- 新しい Kaggle ノートブックを作成します。
- もちろん、ノートブックの名前は必要に応じて変更できます。このガイドでは、ノートブック名
xls-r-300m-chilean-spanish-asr
を使用します。
ステップ 2.2 - WandB API キーの設定
WandB API キーを安全に保存するために、 Kaggle Secret が使用されます。
- Kaggle Notebook のメイン メニューで[アドオン] をクリックします。
- ポップアップメニューから「シークレット」を選択します。
- ラベルフィールドにラベル
WANDB_API_KEY
を入力し、値に WandB API キーを入力します。 -
WANDB_API_KEY
ラベル フィールドの左側にある[添付]チェックボックスがオンになっていることを確認します。 - 「完了」をクリックします。
ステップ 2.3 - トレーニング データセットの追加
チリのスペイン語音声データセットは、 2 つの異なるデータセットとして Kaggle にアップロードされました。
これら両方のデータセットを Kaggle ノートブックに追加します。
ステップ3 - 微調整ノートブックの構築
次の 32 のサブステップでは、微調整ノートブックの 32 個のセルそれぞれを順番に構築します。
ステップ 3.1 - セル 1: パッケージのインストール
微調整ノートブックの最初のセルは依存関係をインストールします。最初のセルを次のように設定します。
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
- 最初の行は、
torchaudio
パッケージを最新バージョンにアップグレードします。torchaudiotorchaudio
、オーディオ ファイルを読み込み、オーディオ データを再サンプリングするために使用されます。 - 2 行目は、後で使用する HuggingFace
Datasets
ライブラリのload_metric
メソッドを使用するために必要なjiwer
パッケージをインストールします。
ステップ 3.2 - セル 2: Python パッケージのインポート
2 番目のセルは必要な Python パッケージをインポートします。2 番目のセルを次のように設定します。
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
- これらのパッケージのほとんどはすでにご存知でしょう。ノートブックでのそれらの使用法については、後続のセルを作成するときに説明します。
- HuggingFace
transformers
ライブラリと関連するWav2Vec2*
クラスが、微調整に使用される機能のバックボーンを提供していることは言及する価値があります。
ステップ 3.3 - セル 3: WER メトリックの読み込み
3 番目のセルは、HuggingFace WER 評価メトリックをインポートします。3 番目のセルを次のように設定します。
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
- 前述したように、WER は評価/ホールドアウト データに対するモデルのパフォーマンスを測定するために使用されます。
ステップ 3.4 - セル 4: WandB へのログイン
4 番目のセルは、ステップ 2.2で設定したWANDB_API_KEY
シークレットを取得します。4 番目のセルを次のように設定します。
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
- API キーは、トレーニング実行データが WandB に送信されるように Kaggle Notebook を構成するために使用されます。
ステップ 3.5 - セル 5: 定数の設定
5 番目のセルは、ノートブック全体で使用される定数を設定します。5 番目のセルを次のように設定します。
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
- ノートブックでは、このセルに考えられるすべての定数が表示されるわけではありません。定数で表すことができる値の一部はインラインのままになっています。
- 上記の定数の多くは、その使い方が自明です。そうでないものについては、次のサブステップでその使い方を説明します。
ステップ 3.6 - セル 6: インデックス ファイルの読み取り、テキストのクリーンアップ、語彙の作成のためのユーティリティ メソッド
6 番目のセルは、データセット インデックス ファイル (上記のトレーニング データセットのサブセクションを参照) の読み取り、および転写テキストのクリーニングと語彙の作成を行うユーティリティ メソッドを定義します。6 番目のセルを次のように設定します。
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
read_index_file_data
メソッドは、line_index.tsv
データセット インデックス ファイルを読み取り、オーディオ ファイル名と転写データを含むリストのリストを生成します。例:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
-
truncate_training_dataset
メソッドは、ステップ 3.5で設定されたNUM_LOAD_FROM_EACH_SET
定数を使用して、リスト インデックス ファイルのデータを切り捨てます。具体的には、NUM_LOAD_FROM_EACH_SET
定数は、各データセットからロードするオーディオ サンプルの数を指定するために使用されます。このガイドでは、この数は1600
に設定されており、最終的に合計3200
のオーディオ サンプルがロードされることを意味します。すべてのサンプルをロードするには、NUM_LOAD_FROM_EACH_SET
を文字列値all
に設定します。 -
clean_text
メソッドは、ステップ 3.5でSPECIAL_CHARS
に割り当てられた正規表現で指定された文字を各テキスト転写から削除するために使用されます。句読点を含むこれらの文字は、オーディオ機能とテキスト転写間のマッピングを学習するようにモデルをトレーニングするときに意味的な価値を提供しないため、削除できます。 -
create_vocab
メソッドは、クリーンなテキスト転写から語彙を作成します。簡単に言えば、クリーンなテキスト転写のセットからすべての一意の文字を抽出します。生成された語彙の例は、ステップ 3.14で確認できます。
ステップ 3.7 - セル 7: オーディオ データの読み込みと再サンプリングのためのユーティリティ メソッド
7 番目のセルは、 torchaudio
を使用してオーディオ データを読み込み、再サンプリングするユーティリティ メソッドを定義します。7 番目のセルを次のように設定します。
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
-
read_audio_data
メソッドは、指定されたオーディオ ファイルを読み込み、オーディオ データのtorch.Tensor
多次元マトリックスとオーディオのサンプリング レートを返します。トレーニング データ内のすべてのオーディオ ファイルのサンプリング レートは48000
Hz です。この「元の」サンプリング レートは、ステップ 3.5の定数ORIG_SAMPLING_RATE
によって取得されます。 -
resample
メソッドは、オーディオ データをサンプリング レート48000
から16000
にダウンサンプリングするために使用されます。wav2vec2 は、16000
Hz でサンプリングされたオーディオで事前トレーニングされています。したがって、微調整に使用するオーディオはすべて同じサンプリング レートである必要があります。この場合、オーディオ サンプルは48000
Hz から16000
Hz にダウンサンプリングする必要があります。16000 Hz16000
、ステップ 3.5の定数TGT_SAMPLING_RATE
によって取得されます。
ステップ 3.8 - セル 8: トレーニング用データを準備するためのユーティリティ メソッド
8 番目のセルは、オーディオとトランスクリプション データを処理するユーティリティ メソッドを定義します。8 番目のセルを次のように設定します。
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
-
process_speech_audio
メソッドは、提供されたトレーニング サンプルから入力値を返します。 -
process_target_text
メソッドは、各テキスト転写をラベルのリスト、つまり語彙内の文字を参照するインデックスのリストとしてエンコードします。サンプルのエンコードは、ステップ 3.15で確認できます。
ステップ 3.9 - セル 9: 単語エラー率を計算するユーティリティ メソッド
9 番目のセルは、最後のユーティリティ メソッド セルであり、参照転写と予測転写の間の単語エラー率を計算するメソッドが含まれています。9 番目のセルを次のように設定します。
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
ステップ 3.10 - セル 10: トレーニング データの読み取り
10 番目のセルは、ステップ 3.6で定義されたread_index_file_data
メソッドを使用して、男性話者の録音と女性話者の録音のトレーニング データ インデックス ファイルを読み取ります。10 番目のセルを次のように設定します。
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
- ご覧のとおり、この時点ではトレーニング データは 2 つの性別固有のリストで管理されています。データは切り捨てられた後、ステップ 3.12で結合されます。
ステップ 3.11 - セル 11: トレーニング データの切り捨て
11 番目のセルは、ステップ 3.6で定義されたtruncate_training_dataset
メソッドを使用してトレーニング データ リストを切り捨てます。11 番目のセルを次のように設定します。
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
- 念のため、ステップ 3.5で設定した
NUM_LOAD_FROM_EACH_SET
定数は、各データセットから保持するサンプルの数を定義します。このガイドでは、定数は1600
に設定されており、合計3200
サンプルになります。
ステップ 3.12 - セル 12: トレーニング サンプル データの結合
12 番目のセルは、切り捨てられたトレーニング データ リストを結合します。12 番目のセルを次のように設定します。
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
ステップ 3.13 - セル 13: クリーニング転写テスト
13 番目のセルは、各トレーニング データ サンプルを反復処理し、ステップ 3.6で定義されたclean_text
メソッドを使用して、関連する転写テキストをクリーンアップします。13 番目のセルを次のように設定します。
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
ステップ 3.14 - セル 14: 語彙の作成
14 番目のセルは、前のステップでクリーンアップされた転写とステップ 3.6で定義されたcreate_vocab
メソッドを使用して語彙を作成します。14 番目のセルを次のように設定します。
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
語彙は、文字をキーとし、語彙インデックスを値とする辞書として保存されます。
vocab_dict
を印刷すると、次の出力が生成されます。
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
ステップ 3.15 - セル 15: 語彙に単語区切り文字を追加する
15 番目のセルは、語彙に単語区切り文字|
を追加します。15 番目のセルを次のように設定します。
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
単語区切り文字は、テキスト転写をラベルのリストとしてトークン化するときに使用されます。具体的には、単語の終わりを定義するために使用され、ステップ 3.17で説明するように、
Wav2Vec2CTCTokenizer
クラスを初期化するときに使用されます。たとえば、次のリストは、ステップ3.14の語彙を使用して
no te entiendo nada
をエンコードします。
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
- 当然、「なぜ単語区切り文字を定義する必要があるのか」という疑問が湧いてくるかもしれません。たとえば、英語とスペイン語の書き言葉では単語の終わりは空白で示されるため、スペース文字を単語区切り文字として使用するのは簡単なはずです。英語とスペイン語は数千ある言語の中の 2 つの言語に過ぎず、すべての書き言葉が単語の境界を示すためにスペースを使用するわけではないことに注意してください。
ステップ 3.16 - セル 16: 語彙のエクスポート
16 番目のセルは語彙をファイルにダンプします。16 番目のセルを次のように設定します。
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
- 語彙ファイルは次のステップ 3.17で
Wav2Vec2CTCTokenizer
クラスを初期化するために使用されます。
ステップ 3.17 - セル 17: トークナイザーを初期化する
17 番目のセルは、 Wav2Vec2CTCTokenizer
のインスタンスを初期化します。17 番目のセルを次のように設定します。
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
トークナイザーは、テキストの転写をエンコードし、ラベルのリストをテキストにデコードするために使用されます。
tokenizer
、unk_token
に[UNK]
が割り当てられ、pad_token
に[PAD]
が割り当てられて初期化されることに注意してください。前者はテキスト転写内の不明なトークンを表すために使用され、後者は異なる長さの転写のバッチを作成するときに転写を埋め込むために使用されます。これらの 2 つの値は、トークナイザーによって語彙に追加されます。このステップでトークナイザーを初期化すると、語彙に 2 つの追加トークン、つまりそれぞれ文の始まりと終わりを区別するために使用される
<s>
と/</s>
も追加されます。このステップでは、ステップ 3.15で語彙に文字を追加したことに応じて、パイプ記号が単語の終わりを区切るために使用されることを反映して、
|
がword_delimiter_token
に明示的に割り当てられます。|
記号はword_delimiter_token
のデフォルト値です。したがって、明示的に設定する必要はありませんでしたが、わかりやすくするために設定しました。word_delimiter_token
と同様に、パイプ記号|
がテキスト転写内の空白文字の置き換えに使用されることを反映して、replace_word_delimiter_char
に 1 つのスペースが明示的に割り当てられます。空白はreplace_word_delimiter_char
のデフォルト値です。したがって、これも明示的に設定する必要はありませんでしたが、わかりやすくするために設定しました。tokenizer
のget_vocab()
メソッドを呼び出すと、完全なトークナイザー語彙を印刷できます。
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
ステップ 3.18 - セル 18: 特徴抽出器の初期化
18 番目のセルは、 Wav2Vec2FeatureExtractor
のインスタンスを初期化します。18 番目のセルを次のように設定します。
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
- 特徴抽出器は、入力データ(このユースケースでは当然オーディオデータ)から特徴を抽出するために使用されます。ステップ 3.20で、各トレーニング データ サンプルのオーディオ データを読み込みます。
-
Wav2Vec2FeatureExtractor
初期化子に渡されるパラメータ値はすべてデフォルト値ですが、return_attention_mask
はデフォルトでFalse
に設定されます。デフォルト値はわかりやすくするために表示/渡されています。 -
feature_size
パラメータは、入力特徴(つまり、オーディオ データ特徴)の次元サイズを指定します。このパラメータのデフォルト値は1
です。 -
sampling_rate
、オーディオデータをデジタル化するサンプリングレートを特徴抽出器に伝えます。ステップ 3.7で説明したように、wav2vec2 は16000
Hz でサンプリングされたオーディオで事前トレーニングされているため、このパラメータのデフォルト値は16000
です。 -
padding_value
パラメータは、異なる長さのオーディオ サンプルをバッチ処理するときに必要な、オーディオ データのパディングに使用する値を指定します。デフォルト値は0.0
です。 -
do_normalize
、入力データを標準正規分布に変換するかどうかを指定するために使用されます。デフォルト値はTrue
です。Wav2Vec2FeatureExtractor クラスのドキュメントには、「[正規化] により、一部のモデルのパフォーマンスが大幅に向上する可能性があります」と記載されていますWav2Vec2FeatureExtractor
-
return_attention_mask
パラメータは、アテンション マスクを渡すかどうかを指定します。このユース ケースでは、値はTrue
に設定されています。
ステップ 3.19 - セル 19: プロセッサの初期化
19 番目のセルはWav2Vec2Processor
のインスタンスを初期化します。19 番目のセルを次のように設定します。
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Wav2Vec2Processor
クラスは、ステップ 3.17とステップ 3.18のtokenizer
とfeature_extractor
それぞれ 1 つのプロセッサに結合します。Wav2Vec2Processor
クラス インスタンスでsave_pretrained
メソッドを呼び出すことによって、プロセッサ構成を保存できることに注意してください。
processor.save_pretrained(OUTPUT_DIR_PATH)
ステップ 3.20 - セル 20: オーディオ データの読み込み
20 番目のセルは、 all_training_samples
リストで指定された各オーディオ ファイルを読み込みます。20 番目のセルを次のように設定します。
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
- オーディオ データは
torch.Tensor
として返され、辞書のリストとしてall_input_data
に保存されます。各辞書には、特定のサンプルのオーディオ データと、オーディオのテキスト転写が含まれています。 -
read_audio_data
メソッドはオーディオ データのサンプリング レートも返すことに注意してください。このユース ケースではすべてのオーディオ ファイルのサンプリング レートが48000
Hz であることがわかっているため、この手順ではサンプリング レートは無視されます。
ステップ 3.21 - セル 21: all_input_data
Pandas DataFrame に変換する
21 番目のセルは、 all_input_data
リストを Pandas DataFrame に変換して、データの操作を容易にします。21 番目のセルを次のように設定します。
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
ステップ 3.22 - セル 22: オーディオ データとテキストの転写の処理
22 番目のセルは、ステップ 3.19で初期化されたprocessor
を使用して、各オーディオ データ サンプルから特徴を抽出し、各テキスト転写をラベルのリストとしてエンコードします。22 番目のセルを次のように設定します。
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
ステップ 3.23 - セル 23: 入力データをトレーニングデータセットと検証データセットに分割する
23 番目のセルは、ステップ 3.5のSPLIT_PCT
定数を使用して、 all_input_data_df
DataFrame をトレーニング データセットと評価 (検証) データセットに分割します。23 番目のセルを次のように設定します。
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
- このガイドでは
SPLIT_PCT
値は0.10
です。これは、すべての入力データの 10% が評価用に保持され、データの 90% がトレーニング/微調整に使用されることを意味します。 - トレーニング サンプルは合計 3,200 個あるため、320 個のサンプルが評価に使用され、残りの 2,880 個のサンプルはモデルの微調整に使用されます。
ステップ 3.24 - セル 24: トレーニングおよび検証データセットをDataset
オブジェクトに変換する
24 番目のセルは、 train_data_df
およびvalid_data_df
DataFrame をDataset
オブジェクトに変換します。24 番目のセルを次のように設定します。
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
ステップ 3.30でわかるように、
Dataset
オブジェクトは HuggingFaceTrainer
クラス インスタンスによって使用されます。これらのオブジェクトには、データセット自体だけでなく、データセットに関するメタデータも含まれています。
train_data
とvalid_data
を印刷して、両方のDataset
オブジェクトのメタデータを表示できます。
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
ステップ 3.25 - セル 25: 事前トレーニング済みモデルの初期化
25 番目のセルは、事前トレーニング済みの XLS-R (0.3) モデルを初期化します。25 番目のセルを次のように設定します。
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
-
Wav2Vec2ForCTC
で呼び出されるfrom_pretrained
メソッドは、指定されたモデルの事前トレーニング済みの重みを読み込むことを指定します。 -
MODEL
定数はステップ3.5で指定され、XLS-R(0.3)モデルを反映してfacebook/wav2vec2-xls-r-300m
に設定されました。 -
ctc_loss_reduction
パラメータは、コネクショニスト時間分類 ("CTC") 損失関数の出力に適用する削減のタイプを指定します。CTC 損失は、連続入力 (この場合はオーディオ データ) とターゲット シーケンス (この場合はテキスト転写) 間の損失を計算するために使用されます。値をmean
に設定すると、入力バッチの出力損失がターゲットの長さで除算されます。次に、バッチ全体の平均が計算され、削減が損失値に適用されます。 -
pad_token_id
、バッチ処理時にパディングに使用するトークンを指定します。これは、ステップ3.17でトークナイザーを初期化するときに設定された[PAD]
idに設定されます。 -
vocab_size
パラメータは、モデルの語彙サイズを定義します。これは、ステップ 3.17でトークナイザーを初期化した後の語彙サイズであり、ネットワークの前方部分の出力層ノードの数を反映します。
ステップ 3.26 - セル 26: 特徴抽出器の重みの固定
26 番目のセルは、特徴抽出器の事前トレーニング済みの重みを固定します。26 番目のセルを次のように設定します。
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
ステップ 3.27 - セル 27: トレーニング引数の設定
27 番目のセルは、 Trainer
インスタンスに渡されるトレーニング引数を初期化します。27 番目のセルを次のように設定します。
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
-
TrainingArguments
クラスは100 を超えるパラメータを受け入れます。 -
save_safetensors
パラメータがFalse
の場合、safetensors
形式を使用する代わりに、微調整されたモデルをpickle
ファイルに保存することを指定します。 -
group_by_length
パラメータがTrue
の場合、ほぼ同じ長さのサンプルをグループ化する必要があることを示します。これにより、パディングが最小限に抑えられ、トレーニングの効率が向上します。 -
per_device_train_batch_size
、トレーニング ミニバッチあたりのサンプル数を設定します。このパラメータは、ステップ 3.5で割り当てられたTRAIN_BATCH_SIZE
定数によって18
に設定されます。これは、エポックあたり 160 ステップを意味します。 -
per_device_eval_batch_size
、評価(ホールドアウト)ミニバッチあたりのサンプル数を設定します。このパラメータは、ステップ 3.5で割り当てられたEVAL_BATCH_SIZE
定数によって10
に設定されます。 -
num_train_epochs
はトレーニング エポックの数を設定します。このパラメータは、ステップ 3.5で割り当てられたTRAIN_EPOCHS
定数によって30
に設定されます。これは、トレーニング中に合計 4,800 ステップを意味します。 -
gradient_checkpointing
パラメータがTrue
の場合、勾配計算をチェックポイントすることでメモリを節約できますが、逆方向パスの速度は低下します。 -
evaluation_strategy
パラメータをsteps
に設定すると、トレーニング中にパラメータeval_steps
で指定された間隔で評価が実行され、ログに記録されます。 -
logging_strategy
パラメータをsteps
に設定すると、トレーニング実行の統計がパラメータlogging_steps
で指定された間隔で記録されることを意味します。 -
save_strategy
パラメータをsteps
に設定すると、微調整されたモデルのチェックポイントが、パラメータsave_steps
で指定された間隔で保存されることを意味します。 -
eval_steps
、ホールドアウト データの評価間のステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたEVAL_STEPS
定数によって100
に設定されます。 -
save_steps
、微調整されたモデルのチェックポイントが保存されるまでのステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたSAVE_STEPS
定数によって3200
に設定されます。 -
logging_steps
、トレーニング実行統計のログ間のステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたLOGGING_STEPS
定数によって100
に設定されます。 -
learning_rate
パラメータは初期学習率を設定します。このパラメータは、ステップ3.5で割り当てられたLEARNING_RATE
定数によって1e-4
に設定されます。 -
warmup_steps
パラメータは、学習率を 0 からlearning_rate
で設定された値まで線形にウォームアップするステップ数を設定します。このパラメータは、ステップ 3.5で割り当てられたWARMUP_STEPS
定数によって800
に設定されます。
ステップ 3.28 - セル 28: データ コレータ ロジックの定義
28 番目のセルは、入力シーケンスとターゲット シーケンスを動的にパディングするためのロジックを定義します。28 番目のセルを次のように設定します。
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
- トレーニングと評価の入力ラベルのペアは、ステップ 3.30ですぐに初期化される
Trainer
インスタンスにミニバッチで渡されます。入力シーケンスとラベル シーケンスの長さは各ミニバッチで異なるため、一部のシーケンスはすべて同じ長さになるようにパディングする必要があります。 -
DataCollatorCTCWithPadding
クラスは、ミニバッチ データを動的にパディングします。padding
パラメータをTrue
に設定すると、短いオーディオ入力機能シーケンスとラベル シーケンスの長さがミニバッチ内の最長シーケンスと同じになるように指定されます。 - オーディオ入力特徴は、ステップ 3.18で特徴抽出器を初期化するときに設定された値
0.0
でパディングされます。 - ラベル入力は、まずステップ 3.17でトークナイザーを初期化するときに設定されたパディング値でパディングされます。これらの値は
-100
に置き換えられ、WER メトリックを計算するときにこれらのラベルは無視されます。
ステップ 3.29 - セル 29: データ コレータのインスタンスの初期化
29 番目のセルは、前の手順で定義したデータ コレータのインスタンスを初期化します。29 番目のセルを次のように設定します。
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
ステップ 3.30 - セル 30: トレーナーの初期化
30 番目のセルは、 Trainer
クラスのインスタンスを初期化します。30 番目のセルを次のように設定します。
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
- ご覧のとおり、
Trainer
クラスは次のように初期化されます。- ステップ 3.25で初期化された事前トレーニング済み
model
。 - ステップ 3.29で初期化されたデータ コレータ。
- ステップ3.27で初期化されたトレーニング引数。
- ステップ3.9で定義されたWER評価方法。
- ステップ 3.24の
train_data
Dataset
オブジェクト。 - ステップ 3.24の
valid_data
Dataset
オブジェクト。
- ステップ 3.25で初期化された事前トレーニング済み
-
tokenizer
パラメータは、processor.feature_extractor
に割り当てられ、data_collator
と連携して、各ミニバッチの最大長の入力に自動的に入力を埋め込みます。
ステップ 3.31 - セル 31: モデルの微調整
31 番目のセルは、 Trainer
クラス インスタンスのtrain
メソッドを呼び出して、モデルを微調整します。31 番目のセルを次のように設定します。
### CELL 31: Finetune the model ### trainer.train()
ステップ3.32 - セル32: 微調整したモデルを保存する
32 番目のセルは最後のノートブック セルです。Trainer Trainer
のsave_model
メソッドを呼び出して、微調整されたモデルを保存します。32 番目のセルを次のように設定します。
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
ステップ4 - モデルのトレーニングと保存
ステップ 4.1 - モデルのトレーニング
ノートブックのすべてのセルが構築されたので、微調整を開始します。
Kaggle Notebook をNVIDIA GPU P100アクセラレータで実行するように設定します。
ノートブックを Kaggle にコミットします。
WandB アカウントにログインし、関連する実行を見つけて、トレーニング実行データを監視します。
NVIDIA GPU P100 アクセラレータを使用すると、30 エポックを超えるトレーニングには約 5 時間かかります。ホールドアウト データの WER は、トレーニング終了時に約 0.15 に低下します。最先端の結果ではありませんが、微調整されたモデルは、多くのアプリケーションで十分に役立ちます。
ステップ4.2 - モデルの保存
微調整されたモデルは、ステップ 3.5で指定された定数OUTPUT_DIR_PATH
で指定された Kaggle ディレクトリに出力されます。モデル出力には次のファイルが含まれる必要があります。
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
これらのファイルはローカルにダウンロードできます。さらに、モデル ファイルを使用して新しいKaggle モデルを作成することもできます。Kaggleモデルは、付属の推論ガイドとともに使用され、微調整されたモデルで推論を実行します。
- Kaggle アカウントにログインします。 「モデル」 > 「新しいモデル」をクリックします。
- 「モデルタイトル」フィールドに、微調整したモデルのタイトルを追加します。
- 「モデルの作成」をクリックします。
- 「モデルの詳細ページに移動」をクリックします。
- 「モデルバリエーション」の下にある「新しいバリエーションの追加」をクリックします。
- フレームワーク選択メニューからトランスフォーマーを選択します。
- 「新しいバリエーションを追加」をクリックします。
- 微調整したモデル ファイルを[データのアップロード]ウィンドウにドラッグ アンド ドロップします。または、 [ファイルの参照]ボタンをクリックしてファイル エクスプローラー ウィンドウを開き、微調整したモデル ファイルを選択します。
- ファイルが Kaggle にアップロードされたら、 「作成」をクリックしてKaggle モデルを作成します。
結論
wav2vec2 XLS-R の微調整おめでとうございます。これらの一般的な手順を使用して、必要な他の言語でモデルを微調整できることを覚えておいてください。このガイドで生成された微調整済みモデルで推論を実行するのは非常に簡単です。推論手順については、このガイドとは別のコンパニオン ガイドで概説します。コンパニオン ガイドを見つけるには、私の HackerNoon ユーザー名を検索してください。