Giới thiệu
Meta AI đã giới thiệu wav2vec2 XLS-R ("XLS-R") vào cuối năm 2021. XLS-R là mô hình máy học ("ML") để học cách biểu đạt giọng nói đa ngôn ngữ; và nó đã được đào tạo trên hơn 400.000 giờ âm thanh giọng nói có sẵn công khai trên 128 ngôn ngữ. Sau khi phát hành, mô hình này thể hiện một bước nhảy vọt so với mô hình đa ngôn ngữ XLSR-53 của Meta AI, mô hình này đã được đào tạo về khoảng 50.000 giờ âm thanh lời nói trên 53 ngôn ngữ.
Hướng dẫn này giải thích các bước để tinh chỉnh XLS-R để nhận dạng giọng nói tự động ("ASR") bằng Sổ tay Kaggle . Mô hình này sẽ được tinh chỉnh bằng tiếng Tây Ban Nha Chile, nhưng bạn có thể làm theo các bước chung để tinh chỉnh XLS-R trên các ngôn ngữ khác nhau mà bạn mong muốn.
Việc chạy suy luận trên mô hình tinh chỉnh sẽ được mô tả trong phần hướng dẫn đi kèm, khiến hướng dẫn này trở thành phần đầu tiên trong hai phần. Tôi quyết định tạo một hướng dẫn riêng dành riêng cho suy luận vì hướng dẫn tinh chỉnh này hơi dài.
Giả sử bạn đã có nền tảng ML hiện có và hiểu các khái niệm ASR cơ bản. Người mới bắt đầu có thể gặp khó khăn khi theo dõi/hiểu các bước xây dựng.
Một chút thông tin cơ bản về XLS-R
Mô hình wav2vec2 ban đầu được giới thiệu vào năm 2020 đã được đào tạo trước trên 960 giờ âm thanh giọng nói của tập dữ liệu Librispeech và ~53.200 giờ âm thanh giọng nói của tập dữ liệu LibriVox . Sau khi phát hành, có hai kích cỡ mô hình: mô hình BASE với 95 triệu tham số và mô hình LARGE với 317 triệu tham số.
Mặt khác, XLS-R đã được huấn luyện trước về âm thanh lời nói đa ngôn ngữ từ 5 bộ dữ liệu:
- VoxPopuli : Tổng cộng có ~372.000 giờ âm thanh bài phát biểu trên 23 ngôn ngữ phát biểu nghị viện Châu Âu từ quốc hội Châu Âu.
- Thư viện đa ngôn ngữ : Tổng cộng ~50.000 giờ âm thanh giọng nói trên tám ngôn ngữ Châu Âu, với phần lớn (~44.000 giờ) dữ liệu âm thanh bằng tiếng Anh.
- CommonVoice : Tổng cộng ~7.000 giờ âm thanh giọng nói trên 60 ngôn ngữ.
- VoxLingua107 : Tổng cộng ~6.600 giờ âm thanh giọng nói trên 107 ngôn ngữ dựa trên nội dung YouTube.
- BABEL : Tổng cộng ~1.100 giờ âm thanh giọng nói trên 17 ngôn ngữ Châu Phi và Châu Á dựa trên lời nói đàm thoại qua điện thoại.
Có 3 mẫu XLS-R: XLS-R (0,3B) với 300 triệu thông số, XLS-R (1B) với 1 tỷ thông số và XLS-R (2B) với 2 tỷ thông số. Hướng dẫn này sẽ sử dụng mẫu XLS-R (0,3B).
Tiếp cận
Có một số bài viết hay về cách tinh chỉnh các mô hình wav2vev2 , có lẽ mô hình này là một loại "tiêu chuẩn vàng". Tất nhiên, cách tiếp cận chung ở đây bắt chước những gì bạn sẽ tìm thấy trong các hướng dẫn khác. Bạn sẽ:
- Tải tập dữ liệu huấn luyện gồm dữ liệu âm thanh và bản ghi văn bản liên quan.
- Tạo từ vựng từ bản phiên âm văn bản trong tập dữ liệu.
- Khởi tạo bộ xử lý wav2vec2 sẽ trích xuất các tính năng từ dữ liệu đầu vào cũng như chuyển đổi bản phiên âm văn bản thành chuỗi nhãn.
- Tinh chỉnh wav2vec2 XLS-R trên dữ liệu đầu vào đã xử lý.
Tuy nhiên, có ba điểm khác biệt chính giữa hướng dẫn này và các hướng dẫn khác:
- Hướng dẫn không cung cấp nhiều cuộc thảo luận "nội tuyến" về các khái niệm ML và ASR có liên quan.
- Mặc dù mỗi phần phụ trên các ô sổ tay riêng lẻ sẽ bao gồm thông tin chi tiết về việc sử dụng/mục đích của ô cụ thể, nhưng giả định rằng bạn có nền tảng ML hiện có và bạn hiểu các khái niệm ASR cơ bản.
- Sổ tay Kaggle mà bạn sẽ xây dựng sẽ sắp xếp các phương thức tiện ích trong các ô cấp cao nhất.
- Trong khi nhiều sổ ghi chép tinh chỉnh có xu hướng có kiểu bố cục kiểu "dòng ý thức", tôi đã quyết định sắp xếp tất cả các phương pháp hữu ích lại với nhau. Nếu mới làm quen với wav2vec2, bạn có thể thấy cách tiếp cận này khó hiểu. Tuy nhiên, để nhắc lại, tôi cố gắng hết sức để giải thích rõ ràng mục đích của từng ô trong phần phụ dành riêng cho mỗi ô. Nếu bạn mới tìm hiểu về wav2vec2, bạn có thể được hưởng lợi từ việc xem nhanh bài viết HackerNoon wav2vec2 của tôi về Nhận dạng giọng nói tự động bằng tiếng Anh đơn giản .
- Hướng dẫn này chỉ mô tả các bước để tinh chỉnh.
- Như đã đề cập trong phần Giới thiệu , tôi đã chọn tạo một hướng dẫn đồng hành riêng về cách chạy suy luận trên mô hình XLS-R đã được tinh chỉnh mà bạn sẽ tạo. Điều này được thực hiện để tránh hướng dẫn này trở nên quá dài.
Điều kiện tiên quyết và trước khi bạn bắt đầu
Để hoàn thành hướng dẫn, bạn sẽ cần phải có:
- Một tài khoản Kaggle hiện có. Nếu chưa có tài khoản Kaggle, bạn cần tạo một tài khoản.
- Tài khoản Trọng số và Xu hướng hiện có ("WandB") . Nếu hiện tại bạn chưa có tài khoản Trọng lượng và Xu hướng, bạn cần tạo một tài khoản.
- Khóa API WandB. Nếu bạn không có khóa API WandB, hãy làm theo các bước tại đây .
- Kiến thức trung cấp về Python.
- Kiến thức trung cấp về làm việc với Kaggle Notebooks.
- Kiến thức trung cấp về các khái niệm ML.
- Kiến thức cơ bản về các khái niệm ASR.
Trước khi bắt đầu xây dựng sổ ghi chép, bạn có thể xem lại hai phần phụ ngay bên dưới. Họ mô tả:
- Tập dữ liệu huấn luyện.
- Số liệu Tỷ lệ Lỗi Từ ("WER") được sử dụng trong quá trình đào tạo.
Tập dữ liệu đào tạo
Như đã đề cập trong phần Giới thiệu , mẫu XLS-R sẽ được tinh chỉnh bằng tiếng Tây Ban Nha ở Chile. Tập dữ liệu cụ thể là Tập dữ liệu lời nói tiếng Tây Ban Nha ở Chile được phát triển bởi Guevara-Rukoz et al. Nó có sẵn để tải xuống trên OpenSLR . Bộ dữ liệu bao gồm hai bộ dữ liệu phụ: (1) 2.636 bản ghi âm của những người nói tiếng Chile là nam và (2) 1.738 bản ghi âm của những người nói tiếng Chile là nữ.
Mỗi tập dữ liệu con bao gồm một tệp chỉ mục line_index.tsv
. Mỗi dòng của mỗi tệp chỉ mục chứa một cặp tên tệp âm thanh và bản ghi âm của tệp được liên kết, ví dụ:
clm_08421_01719502739 Es un viaje de negocios solamente voy por una noche clm_02436_02011517900 Se usa para incitar a alguien a sacar el mayor provecho del dia presente
Tôi đã tải Bộ dữ liệu lời nói tiếng Tây Ban Nha tiếng Chile lên Kaggle để thuận tiện. Có một tập dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nam giới và một bộ dữ liệu Kaggle cho các bản ghi âm của những người nói tiếng Chile là nữ . Các bộ dữ liệu Kaggle này sẽ được thêm vào Sổ tay Kaggle mà bạn sẽ xây dựng theo các bước trong hướng dẫn này.
Tỷ lệ lỗi từ (WER)
WER là một số liệu có thể được sử dụng để đo lường hiệu suất của các mô hình nhận dạng giọng nói tự động. WER cung cấp một cơ chế để đo lường mức độ gần gũi của dự đoán văn bản với tham chiếu văn bản. WER thực hiện điều này bằng cách ghi lại 3 loại lỗi:
thay thế (
S
): Lỗi thay thế được ghi lại khi dự đoán chứa một từ khác với từ tương tự trong tham chiếu. Ví dụ: điều này xảy ra khi dự đoán viết sai chính tả một từ trong tài liệu tham khảo.xóa (
D
): Lỗi xóa được ghi lại khi dự đoán chứa một từ không có trong tham chiếu.phần chèn (
I
): Lỗi chèn được ghi lại khi dự đoán không chứa từ nào có trong tham chiếu.
Rõ ràng, WER hoạt động ở cấp độ từ. Công thức cho số liệu WER như sau:
WER = (S + D + I)/N where: S = number of substition errors D = number of deletion errors I = number of insertion errors N = number of words in the reference
Một ví dụ WER đơn giản bằng tiếng Tây Ban Nha như sau:
prediction: "Él está saliendo." reference: "Él está saltando."
Một bảng giúp hình dung các lỗi trong dự đoán:
CHỮ | TỪ 1 | TỪ 2 | TỪ 3 |
---|---|---|---|
sự dự đoán | Él | está | saliendo |
thẩm quyền giải quyết | Él | está | muối |
| Chính xác | Chính xác | thay thế |
Dự đoán có 1 lỗi thay thế, 0 lỗi xóa và 0 lỗi chèn. Vì vậy, WER cho ví dụ này là:
WER = 1 + 0 + 0 / 3 = 1/3 = 0.33
Rõ ràng là Tỷ lệ lỗi từ không nhất thiết cho chúng ta biết những lỗi cụ thể nào đang tồn tại. Trong ví dụ trên, WER xác định rằng WORD 3 có lỗi trong văn bản được dự đoán, nhưng nó không cho chúng ta biết rằng các ký tự i và e sai trong dự đoán. Các số liệu khác, chẳng hạn như Tỷ lệ lỗi ký tự ("CER"), có thể được sử dụng để phân tích lỗi chính xác hơn.
Xây dựng sổ tay tinh chỉnh
Bây giờ bạn đã sẵn sàng để bắt đầu xây dựng sổ ghi chép tinh chỉnh.
- Bước 1 và Bước 2 hướng dẫn bạn thiết lập môi trường Kaggle Notebook.
- Bước 3 hướng dẫn bạn cách xây dựng sổ ghi chép. Nó chứa 32 bước phụ đại diện cho 32 ô của sổ ghi chép tinh chỉnh.
- Bước 4 hướng dẫn bạn cách chạy sổ ghi chép, theo dõi quá trình đào tạo và lưu mô hình.
Bước 1 - Tìm nạp khóa API WandB của bạn
Sổ tay Kaggle của bạn phải được định cấu hình để gửi dữ liệu chạy đào tạo tới WandB bằng khóa API WandB của bạn. Để làm được điều đó, bạn cần sao chép nó.
- Đăng nhập vào WandB tại
www.wandb.com
. - Điều hướng đến
www.wandb.ai/authorize
. - Sao chép khóa API của bạn để sử dụng trong bước tiếp theo.
Bước 2 - Thiết lập môi trường Kaggle của bạn
Bước 2.1 - Tạo sổ tay Kaggle mới
- Đăng nhập vào Kaggle.
- Tạo Sổ tay Kaggle mới.
- Tất nhiên, tên sổ ghi chép có thể được thay đổi theo ý muốn. Hướng dẫn này sử dụng tên sổ ghi chép
xls-r-300m-chilean-spanish-asr
.
Bước 2.2 - Đặt khóa API WandB của bạn
Bí mật Kaggle sẽ được sử dụng để lưu trữ khóa API WandB của bạn một cách an toàn.
- Nhấp vào Tiện ích bổ sung trên menu chính của Kaggle Notebook.
- Chọn Bí mật từ menu bật lên.
- Nhập nhãn
WANDB_API_KEY
vào trường Nhãn và nhập khóa API WandB của bạn cho giá trị. - Đảm bảo rằng hộp kiểm Đã đính kèm ở bên trái của trường nhãn
WANDB_API_KEY
được chọn. - Nhấp vào Xong .
Bước 2.3 - Thêm bộ dữ liệu đào tạo
Tập dữ liệu giọng nói tiếng Tây Ban Nha ở Chile đã được tải lên Kaggle dưới dạng 2 tập dữ liệu riêng biệt:
Thêm cả hai bộ dữ liệu này vào Sổ tay Kaggle của bạn.
Bước 3 - Xây dựng sổ tay tinh chỉnh
32 bước phụ sau đây sẽ xây dựng từng ô trong số 32 ô của sổ ghi chép tinh chỉnh theo thứ tự.
Bước 3.1 - Ô 1: Cài đặt gói
Ô đầu tiên của sổ ghi chép tinh chỉnh sẽ cài đặt các phần phụ thuộc. Đặt ô đầu tiên thành:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
- Dòng đầu tiên nâng cấp gói
torchaudio
lên phiên bản mới nhất.torchaudio
sẽ được sử dụng để tải các tệp âm thanh và lấy mẫu lại dữ liệu âm thanh. - Dòng thứ hai cài đặt gói
jiwer
được yêu cầu để sử dụng phương thứcload_metric
của thư viện HuggingFaceDatasets
được sử dụng sau này.
Bước 3.2 - Ô 2: Nhập gói Python
Việc nhập ô thứ hai yêu cầu các gói Python. Đặt ô thứ hai thành:
### CELL 2: Import Python packages ### import wandb from kaggle_secrets import UserSecretsClient import math import re import numpy as np import pandas as pd import torch import torchaudio import json from typing import Any, Dict, List, Optional, Union from dataclasses import dataclass from datasets import Dataset, load_metric, load_dataset, Audio from transformers import Wav2Vec2CTCTokenizer from transformers import Wav2Vec2FeatureExtractor from transformers import Wav2Vec2Processor from transformers import Wav2Vec2ForCTC from transformers import TrainingArguments from transformers import Trainer
- Có lẽ bạn đã quen thuộc với hầu hết các gói này. Việc sử dụng chúng trong sổ ghi chép sẽ được giải thích khi các ô tiếp theo được xây dựng.
- Điều đáng nói là thư viện
transformers
HuggingFace và các lớpWav2Vec2*
liên quan cung cấp nền tảng cho chức năng được sử dụng để tinh chỉnh.
Bước 3.3 - Ô 3: Đang tải số liệu WER
Ô thứ ba nhập số liệu đánh giá HuggingFace WER. Đặt ô thứ ba thành:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
- Như đã đề cập trước đó, WER sẽ được sử dụng để đo lường hiệu suất của mô hình trên dữ liệu đánh giá/nắm giữ.
Bước 3.4 - Ô 4: Đăng nhập vào WandB
Ô thứ tư lấy bí mật WANDB_API_KEY
của bạn đã được đặt ở Bước 2.2 . Đặt ô thứ tư thành:
### CELL 4: Login to WandB ### user_secrets = UserSecretsClient() wandb_api_key = user_secrets.get_secret("WANDB_API_KEY") wandb.login(key = wandb_api_key)
- Khóa API được sử dụng để định cấu hình Kaggle Notebook để dữ liệu chạy đào tạo được gửi đến WandB.
Bước 3.5 - Ô 5: Đặt hằng số
Ô thứ năm đặt các hằng số sẽ được sử dụng trong toàn bộ sổ ghi chép. Đặt ô thứ năm thành:
### CELL 5: Constants ### # Training data TRAINING_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-chile-male/" TRAINING_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-chile-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 1600 # Vocabulary VOCAB_FILE_PATH = "/kaggle/working/" SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000 # Training/validation data split SPLIT_PCT = 0.10 # Model parameters MODEL = "facebook/wav2vec2-xls-r-300m" USE_SAFETENSORS = False # Training arguments OUTPUT_DIR_PATH = "/kaggle/working/xls-r-300m-chilean-spanish-asr" TRAIN_BATCH_SIZE = 18 EVAL_BATCH_SIZE = 10 TRAIN_EPOCHS = 30 SAVE_STEPS = 3200 EVAL_STEPS = 100 LOGGING_STEPS = 100 LEARNING_RATE = 1e-4 WARMUP_STEPS = 800
- Sổ ghi chép không hiển thị mọi hằng số có thể tưởng tượng được trong ô này. Một số giá trị có thể được biểu thị bằng hằng số đã được giữ nguyên dòng.
- Việc sử dụng nhiều hằng số ở trên là hiển nhiên. Đối với những trường hợp không, việc sử dụng chúng sẽ được giải thích trong các bước phụ sau.
Bước 3.6 - Ô 6: Các phương pháp tiện ích để đọc tệp chỉ mục, làm sạch văn bản và tạo từ vựng
Ô thứ sáu xác định các phương thức tiện ích để đọc các tệp chỉ mục tập dữ liệu (xem phần phụ Tập dữ liệu huấn luyện ở trên), cũng như để làm sạch văn bản phiên âm và tạo từ vựng. Đặt ô thứ sáu thành:
### CELL 6: Utility methods for reading index files, cleaning text, and creating vocabulary ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def truncate_training_dataset(dataset: list) -> list: if type(NUM_LOAD_FROM_EACH_SET) == str and "all" == NUM_LOAD_FROM_EACH_SET.lower(): return else: return dataset[:NUM_LOAD_FROM_EACH_SET] def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def create_vocab(data): vocab_list = [] for index in range(len(data)): text = data[index][1] words = text.split(" ") for word in words: chars = list(word) for char in chars: if char not in vocab_list: vocab_list.append(char) return vocab_list
Phương thức
read_index_file_data
đọc tệp chỉ mục tập dữ liệuline_index.tsv
và tạo danh sách các danh sách có tên tệp âm thanh và dữ liệu phiên âm, ví dụ:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
- Phương thức
truncate_training_dataset
cắt bớt dữ liệu tệp chỉ mục danh sách bằng cách sử dụng hằngNUM_LOAD_FROM_EACH_SET
được đặt ở Bước 3.5 . Cụ thể, hằng sốNUM_LOAD_FROM_EACH_SET
được sử dụng để chỉ định số lượng mẫu âm thanh cần tải từ mỗi tập dữ liệu. Vì mục đích của hướng dẫn này, con số được đặt ở1600
, nghĩa là cuối cùng tổng cộng3200
mẫu âm thanh sẽ được tải. Để tải tất cả các mẫu, hãy đặtNUM_LOAD_FROM_EACH_SET
thành giá trị chuỗiall
. - Phương thức
clean_text
được sử dụng để loại bỏ từng phiên âm văn bản của các ký tự được chỉ định bởi biểu thức chính quy được gán choSPECIAL_CHARS
trong Bước 3.5 . Những ký tự này, bao gồm cả dấu câu, có thể bị loại bỏ vì chúng không cung cấp bất kỳ giá trị ngữ nghĩa nào khi đào tạo mô hình để tìm hiểu ánh xạ giữa các tính năng âm thanh và bản chép lời văn bản. - Phương thức
create_vocab
tạo từ vựng từ các bản chép lại văn bản rõ ràng. Đơn giản, nó trích xuất tất cả các ký tự duy nhất từ bộ phiên âm văn bản đã được làm sạch. Bạn sẽ thấy ví dụ về từ vựng được tạo ở Bước 3.14 .
Bước 3.7 - CELL 7: Các phương pháp tiện ích để tải và lấy mẫu lại dữ liệu âm thanh
Ô thứ bảy xác định các phương thức tiện ích sử dụng torchaudio
để tải và lấy mẫu lại dữ liệu âm thanh. Đặt ô thứ bảy thành:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
- Phương thức
read_audio_data
tải một tệp âm thanh được chỉ định và trả về một ma trận đa chiềutorch.Tensor
của dữ liệu âm thanh cùng với tốc độ lấy mẫu của âm thanh. Tất cả các tệp âm thanh trong dữ liệu huấn luyện đều có tốc độ lấy mẫu là48000
Hz. Tốc độ lấy mẫu "gốc" này được ghi lại bằng hằng sốORIG_SAMPLING_RATE
trong Bước 3.5 . - Phương pháp
resample
được sử dụng để giảm mẫu dữ liệu âm thanh từ tốc độ lấy mẫu từ48000
xuống16000
. wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở16000
Hz. Theo đó, bất kỳ âm thanh nào được sử dụng để tinh chỉnh đều phải có cùng tốc độ lấy mẫu. Trong trường hợp này, mẫu âm thanh phải được giảm tần số lấy mẫu từ48000
Hz xuống16000
Hz.16000
Hz được ghi lại bằng hằng sốTGT_SAMPLING_RATE
ở Bước 3.5 .
Bước 3.8 - Ô 8: Các phương pháp hữu ích để chuẩn bị dữ liệu cho đào tạo
Ô thứ tám xác định các phương thức tiện ích xử lý dữ liệu âm thanh và phiên âm. Đặt ô thứ tám thành:
### CELL 8: Utility methods to prepare input data for training ### def process_speech_audio(speech_array, sampling_rate): input_values = processor(speech_array, sampling_rate = sampling_rate).input_values return input_values[0] def process_target_text(target_text): with processor.as_target_processor(): encoding = processor(target_text).input_ids return encoding
- Phương thức
process_speech_audio
trả về giá trị đầu vào từ mẫu đào tạo được cung cấp. - Phương thức
process_target_text
mã hóa mỗi bản phiên âm văn bản dưới dạng danh sách các nhãn - tức là danh sách các chỉ mục đề cập đến các ký tự trong từ vựng. Bạn sẽ thấy mã hóa mẫu ở Bước 3.15 .
Bước 3.9 - CELL 9: Phương pháp tiện ích để tính tỷ lệ lỗi từ
Ô thứ chín là ô phương thức tiện ích cuối cùng và chứa phương pháp tính Tỷ lệ lỗi từ giữa bản phiên âm tham chiếu và bản phiên âm dự đoán. Đặt ô thứ chín thành:
### CELL 9: Utility method to calculate Word Error Rate def compute_wer(pred): pred_logits = pred.predictions pred_ids = np.argmax(pred_logits, axis = -1) pred.label_ids[pred.label_ids == -100] = processor.tokenizer.pad_token_id pred_str = processor.batch_decode(pred_ids) label_str = processor.batch_decode(pred.label_ids, group_tokens = False) wer = wer_metric.compute(predictions = pred_str, references = label_str) return {"wer": wer}
Bước 3.10 - Ô 10: Đọc dữ liệu huấn luyện
Ô thứ mười đọc các tệp chỉ mục dữ liệu huấn luyện cho bản ghi của người nói nam và bản ghi của người nói nữ bằng phương pháp read_index_file_data
được xác định trong Bước 3.6 . Đặt ô thứ mười thành:
### CELL 10: Read training data ### training_samples_male_cl = read_index_file_data(TRAINING_DATA_PATH_MALE, "line_index.tsv") training_samples_female_cl = read_index_file_data(TRAINING_DATA_PATH_FEMALE, "line_index.tsv")
- Như đã thấy, dữ liệu đào tạo được quản lý theo hai danh sách dành riêng cho giới tính tại thời điểm này. Dữ liệu sẽ được kết hợp ở Bước 3.12 sau khi cắt bớt.
Bước 3.11 - Ô 11: Cắt bớt dữ liệu huấn luyện
Ô thứ mười một cắt bớt danh sách dữ liệu huấn luyện bằng phương pháp truncate_training_dataset
được xác định trong Bước 3.6 . Đặt ô thứ mười một thành:
### CELL 11: Truncate training data ### training_samples_male_cl = truncate_training_dataset(training_samples_male_cl) training_samples_female_cl = truncate_training_dataset(training_samples_female_cl)
- Xin nhắc lại, hằng số
NUM_LOAD_FROM_EACH_SET
được đặt ở Bước 3.5 xác định số lượng mẫu cần giữ lại từ mỗi tập dữ liệu. Hằng số được đặt thành1600
trong hướng dẫn này cho tổng số3200
mẫu.
Bước 3.12 - Ô 12: Kết hợp dữ liệu mẫu huấn luyện
Ô thứ mười hai kết hợp các danh sách dữ liệu huấn luyện bị cắt bớt. Đặt ô thứ mười hai thành:
### CELL 12: Combine training samples data ### all_training_samples = training_samples_male_cl + training_samples_female_cl
Bước 3.13 - Ô 13: Kiểm tra phiên mã sạch
Ô thứ mười ba lặp lại từng mẫu dữ liệu huấn luyện và xóa văn bản phiên mã liên quan bằng phương pháp clean_text
được xác định trong Bước 3.6 . Đặt ô thứ mười ba thành:
for index in range(len(all_training_samples)): all_training_samples[index][1] = clean_text(all_training_samples[index][1])
Bước 3.14 - Ô 14: Tạo từ vựng
Ô thứ mười bốn tạo từ vựng bằng cách sử dụng các bản phiên âm đã được làm sạch từ bước trước và phương thức create_vocab
được xác định ở Bước 3.6 . Đặt ô thứ mười bốn thành:
### CELL 14: Create vocabulary ### vocab_list = create_vocab(all_training_samples) vocab_dict = {v: i for i, v in enumerate(vocab_list)}
Từ vựng được lưu trữ dưới dạng từ điển với các ký tự là khóa và chỉ mục từ vựng là giá trị.
Bạn có thể in
vocab_dict
để tạo ra kết quả sau:
{'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32}
Bước 3.15 - Ô 15: Thêm dấu phân cách từ vào từ vựng
Ô thứ mười lăm thêm ký tự phân cách từ |
đến từ vựng. Đặt ô thứ mười lăm thành:
### CELL 15: Add word delimiter to vocabulary ### vocab_dict["|"] = len(vocab_dict)
Ký tự phân cách từ được sử dụng khi mã hóa bản phiên âm văn bản dưới dạng danh sách nhãn. Cụ thể, nó được sử dụng để xác định phần cuối của một từ và nó được sử dụng khi khởi tạo lớp
Wav2Vec2CTCTokenizer
, như sẽ thấy trong Bước 3.17 .Ví dụ: danh sách sau đây mã hóa
no te entiendo nada
bằng cách sử dụng từ vựng ở Bước 3.14 :
# Encoded text [6, 14, 33, 9, 5, 33, 5, 6, 9, 3, 5, 6, 8, 14, 33, 6, 1, 8, 1] # Vocabulary {'l': 0, 'a': 1, 'v': 2, 'i': 3, 'g': 4, 'e': 5, 'n': 6, 'c': 7, 'd': 8, 't': 9, 'u': 10, 'r': 11, 'j': 12, 's': 13, 'o': 14, 'h': 15, 'm': 16, 'q': 17, 'b': 18, 'p': 19, 'y': 20, 'f': 21, 'z': 22, 'á': 23, 'ú': 24, 'í': 25, 'ó': 26, 'é': 27, 'ñ': 28, 'x': 29, 'k': 30, 'w': 31, 'ü': 32, '|': 33}
- Một câu hỏi có thể nảy sinh một cách tự nhiên là: "Tại sao cần xác định ký tự phân cách từ?" Ví dụ: phần cuối của các từ trong văn bản tiếng Anh và tiếng Tây Ban Nha được đánh dấu bằng khoảng trắng nên việc sử dụng ký tự khoảng trắng làm dấu phân cách từ là một vấn đề đơn giản. Hãy nhớ rằng tiếng Anh và tiếng Tây Ban Nha chỉ là hai ngôn ngữ trong số hàng nghìn ngôn ngữ; và không phải tất cả các ngôn ngữ viết đều sử dụng dấu cách để đánh dấu ranh giới từ.
Bước 3.16 - Ô 16: Xuất từ vựng
Ô thứ mười sáu chuyển từ vựng vào một tập tin. Đặt ô thứ mười sáu thành:
### CELL 16: Export vocabulary ### with open(VOCAB_FILE_PATH + "vocab.json", "w", encoding = "utf8") as vocab_file: json.dump(vocab_dict, vocab_file)
- Tệp từ vựng sẽ được sử dụng trong bước tiếp theo, Bước 3.17 , để khởi tạo lớp
Wav2Vec2CTCTokenizer
.
Bước 3.17 - CELL 17: Khởi tạo Tokenizer
Ô thứ mười bảy khởi tạo một phiên bản của Wav2Vec2CTCTokenizer
. Đặt ô thứ mười bảy thành:
### CELL 17: Initialize tokenizer ### tokenizer = Wav2Vec2CTCTokenizer( VOCAB_FILE_PATH + "vocab.json", unk_token = "[UNK]", pad_token = "[PAD]", word_delimiter_token = "|", replace_word_delimiter_char = " " )
Trình mã thông báo được sử dụng để mã hóa bản phiên âm văn bản và giải mã danh sách nhãn trở lại văn bản.
Lưu ý rằng
tokenizer
được khởi tạo với[UNK]
được gán chounk_token
và[PAD]
được gán chopad_token
, với mã trước đây được sử dụng để biểu thị các mã thông báo không xác định trong bản phiên âm văn bản và mã thông báo sau được sử dụng để đệm phiên âm khi tạo các lô phiên âm có độ dài khác nhau. Hai giá trị này sẽ được thêm vào từ vựng bằng mã thông báo.Việc khởi tạo mã thông báo trong bước này cũng sẽ thêm hai mã thông báo bổ sung vào từ vựng, đó là
<s>
và/</s>
, được sử dụng để phân định lần lượt phần đầu và phần cuối của câu.|
được gán choword_delimiter_token
một cách rõ ràng trong bước này để phản ánh rằng ký hiệu ống sẽ được sử dụng để phân định ranh giới cuối từ theo cách chúng ta thêm ký tự vào từ vựng trong Bước 3.15 .|
ký hiệu là giá trị mặc định choword_delimiter_token
. Vì vậy, nó không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.Tương tự như với
word_delimiter_token
, một khoảng trắng được gán rõ ràng choreplace_word_delimiter_char
phản ánh rằng ký hiệu ống|
sẽ được sử dụng để thay thế các ký tự khoảng trống trong phiên âm văn bản. Khoảng trống là giá trị mặc định choreplace_word_delimiter_char
. Vì vậy, nó cũng không cần phải được thiết lập rõ ràng nhưng được thực hiện vì mục đích rõ ràng.Bạn có thể in toàn bộ từ vựng của tokenizer bằng cách gọi phương thức
get_vocab()
trêntokenizer
.
vocab = tokenizer.get_vocab() print(vocab) # Output: {'e': 0, 's': 1, 'u': 2, 'n': 3, 'v': 4, 'i': 5, 'a': 6, 'j': 7, 'd': 8, 'g': 9, 'o': 10, 'c': 11, 'l': 12, 'm': 13, 't': 14, 'y': 15, 'p': 16, 'r': 17, 'h': 18, 'ñ': 19, 'ó': 20, 'b': 21, 'q': 22, 'f': 23, 'ú': 24, 'z': 25, 'é': 26, 'í': 27, 'x': 28, 'á': 29, 'w': 30, 'k': 31, 'ü': 32, '|': 33, '<s>': 34, '</s>': 35, '[UNK]': 36, '[PAD]': 37}
Bước 3.18 - Ô 18: Khởi tạo Trình trích xuất tính năng
Ô thứ mười tám khởi tạo một phiên bản của Wav2Vec2FeatureExtractor
. Đặt ô thứ mười tám thành:
### CELL 18: Initialize feature extractor ### feature_extractor = Wav2Vec2FeatureExtractor( feature_size = 1, sampling_rate = 16000, padding_value = 0.0, do_normalize = True, return_attention_mask = True )
- Trình trích xuất tính năng được sử dụng để trích xuất các tính năng từ dữ liệu đầu vào, tất nhiên là dữ liệu âm thanh trong trường hợp sử dụng này. Bạn sẽ tải dữ liệu âm thanh cho từng mẫu dữ liệu huấn luyện ở Bước 3.20 .
- Các giá trị tham số được truyền tới trình khởi tạo
Wav2Vec2FeatureExtractor
đều là các giá trị mặc định, ngoại trừreturn_attention_mask
được mặc định làFalse
. Các giá trị mặc định được hiển thị/chuyển đi nhằm mục đích rõ ràng. - Tham số
feature_size
chỉ định kích thước kích thước của các tính năng đầu vào (tức là tính năng dữ liệu âm thanh). Giá trị mặc định của tham số này là1
. -
sampling_rate
cho trình trích xuất tính năng biết tốc độ lấy mẫu mà tại đó dữ liệu âm thanh sẽ được số hóa. Như đã thảo luận ở Bước 3.7 , wav2vec2 được huấn luyện trước trên âm thanh được lấy mẫu ở16000
Hz và do đó16000
là giá trị mặc định cho tham số này. - Tham số
padding_value
chỉ định giá trị được sử dụng khi đệm dữ liệu âm thanh, theo yêu cầu khi phân nhóm các mẫu âm thanh có độ dài khác nhau. Giá trị mặc định là0.0
. -
do_normalize
được sử dụng để chỉ định xem dữ liệu đầu vào có nên được chuyển đổi sang phân phối chuẩn chuẩn hay không. Giá trị mặc định làTrue
. Tài liệu lớpWav2Vec2FeatureExtractor
lưu ý rằng "[chuẩn hóa] có thể giúp cải thiện đáng kể hiệu suất cho một số kiểu máy." - Các tham số
return_attention_mask
chỉ định xem mặt nạ chú ý có được chuyển hay không. Giá trị được đặt thànhTrue
cho trường hợp sử dụng này.
Bước 3.19 - CELL 19: Khởi tạo bộ xử lý
Ô thứ mười chín khởi tạo một phiên bản của Wav2Vec2Processor
. Đặt ô thứ mười chín thành:
### CELL 19: Initialize processor ### processor = Wav2Vec2Processor(feature_extractor = feature_extractor, tokenizer = tokenizer)
Lớp
Wav2Vec2Processor
kết hợptokenizer
vàfeature_extractor
từ Bước 3.17 và Bước 3.18 tương ứng vào một bộ xử lý duy nhất.Lưu ý rằng cấu hình bộ xử lý có thể được lưu bằng cách gọi phương thức
save_pretrained
trên phiên bản lớpWav2Vec2Processor
.
processor.save_pretrained(OUTPUT_DIR_PATH)
Bước 3.20 - CELL 20: Đang tải dữ liệu âm thanh
Ô thứ 20 tải từng tệp âm thanh được chỉ định trong danh sách all_training_samples
. Đặt ô thứ hai mươi thành:
### CELL 20: Load audio data ### all_input_data = [] for index in range(len(all_training_samples)): speech_array, sampling_rate = read_audio_data(all_training_samples[index][0]) all_input_data.append({ "input_values": speech_array, "labels": all_training_samples[index][1] })
- Dữ liệu âm thanh được trả về dưới dạng
torch.Tensor
và được lưu trữ trongall_input_data
dưới dạng danh sách từ điển. Mỗi từ điển chứa dữ liệu âm thanh cho một mẫu cụ thể, cùng với bản phiên âm văn bản của âm thanh. - Lưu ý rằng phương thức
read_audio_data
cũng trả về tốc độ lấy mẫu của dữ liệu âm thanh. Vì chúng ta biết rằng tốc độ lấy mẫu là48000
Hz cho tất cả các tệp âm thanh trong trường hợp sử dụng này nên tốc độ lấy mẫu sẽ bị bỏ qua trong bước này.
Bước 3.21 - CELL 21: Chuyển đổi all_input_data
thành Pandas DataFrame
Ô thứ 21 chuyển đổi danh sách all_input_data
thành Pandas DataFrame để giúp thao tác dữ liệu dễ dàng hơn. Đặt ô thứ 21 thành:
### CELL 21: Convert audio training data list to Pandas DataFrame ### all_input_data_df = pd.DataFrame(data = all_input_data)
Bước 3.22 - CELL 22: Xử lý dữ liệu âm thanh và phiên âm văn bản
Ô thứ 22 sử dụng processor
được khởi tạo ở Bước 3.19 để trích xuất các tính năng từ từng mẫu dữ liệu âm thanh và mã hóa từng bản phiên âm văn bản dưới dạng danh sách các nhãn. Đặt ô thứ hai mươi hai thành:
### CELL 22: Process audio data and text transcriptions ### all_input_data_df["input_values"] = all_input_data_df["input_values"].apply(lambda x: process_speech_audio(resample(x), 16000)) all_input_data_df["labels"] = all_input_data_df["labels"].apply(lambda x: process_target_text(x))
Bước 3.23 - CELL 23: Tách dữ liệu đầu vào thành tập dữ liệu huấn luyện và xác thực
Ô thứ hai mươi ba chia DataFrame all_input_data_df
thành các tập dữ liệu huấn luyện và đánh giá (xác thực) bằng cách sử dụng hằng số SPLIT_PCT
từ Bước 3.5 . Đặt ô thứ hai mươi ba thành:
### CELL 23: Split input data into training and validation datasets ### split = math.floor((NUM_LOAD_FROM_EACH_SET * 2) * SPLIT_PCT) valid_data_df = all_input_data_df.iloc[-split:] train_data_df = all_input_data_df.iloc[:-split]
- Giá trị
SPLIT_PCT
là0.10
trong hướng dẫn này, nghĩa là 10% tất cả dữ liệu đầu vào sẽ được giữ lại để đánh giá và 90% dữ liệu sẽ được sử dụng để đào tạo/tinh chỉnh. - Vì có tổng cộng 3.200 mẫu huấn luyện nên 320 mẫu sẽ được sử dụng để đánh giá và 2.880 mẫu còn lại được sử dụng để tinh chỉnh mô hình.
Bước 3.24 - CELL 24: Chuyển đổi tập dữ liệu huấn luyện và xác thực thành đối tượng Dataset
Ô thứ 24 chuyển đổi DataFrames train_data_df
và valid_data_df
thành các đối tượng Dataset
. Đặt ô thứ 24 thành:
### CELL 24: Convert training and validation datasets to Dataset objects ### train_data = Dataset.from_pandas(train_data_df) valid_data = Dataset.from_pandas(valid_data_df)
Các đối tượng
Dataset
được sử dụng bởi các phiên bản của lớp HuggingFaceTrainer
, như bạn sẽ thấy trong Bước 3.30 .Các đối tượng này chứa siêu dữ liệu về tập dữ liệu cũng như chính tập dữ liệu đó.
Bạn có thể in
train_data
vàvalid_data
để xem siêu dữ liệu cho cả hai đối tượngDataset
.
print(train_data) print(valid_data) # Output: Dataset({ features: ['input_values', 'labels'], num_rows: 2880 }) Dataset({ features: ['input_values', 'labels'], num_rows: 320 })
Bước 3.25 - CELL 25: Khởi tạo mô hình tiền huấn luyện
Ô thứ 25 khởi tạo mô hình XLS-R (0,3) đã được huấn luyện trước. Đặt ô thứ 25 thành:
### CELL 25: Initialize pretrained model ### model = Wav2Vec2ForCTC.from_pretrained( MODEL, ctc_loss_reduction = "mean", pad_token_id = processor.tokenizer.pad_token_id, vocab_size = len(processor.tokenizer) )
- Phương thức
from_pretrained
được gọi trênWav2Vec2ForCTC
chỉ định rằng chúng ta muốn tải các trọng số đã được huấn luyện trước cho mô hình đã chỉ định. - Hằng số
MODEL
được chỉ định ở Bước 3.5 và được đặt thànhfacebook/wav2vec2-xls-r-300m
phản ánh mô hình XLS-R (0.3). - Tham số
ctc_loss_reduction
chỉ định loại giảm để áp dụng cho đầu ra của hàm mất mát Phân loại thời gian kết nối ("CTC"). Suy hao CTC được sử dụng để tính toán suy hao giữa đầu vào liên tục, trong trường hợp này là dữ liệu âm thanh và chuỗi đích, trong trường hợp này là bản chép lại văn bản. Bằng cách đặt giá trịmean
, tổn thất đầu ra của một lô đầu vào sẽ được chia cho độ dài mục tiêu. Sau đó, giá trị trung bình của lô sẽ được tính toán và mức giảm được áp dụng cho các giá trị tổn thất. -
pad_token_id
chỉ định mã thông báo được sử dụng để đệm khi tạo khối. Nó được đặt thành id[PAD]
được đặt khi khởi tạo mã thông báo ở Bước 3.17 . - Tham số
vocab_size
xác định kích thước từ vựng của mô hình. Đó là kích thước từ vựng sau khi khởi tạo mã thông báo ở Bước 3.17 và phản ánh số lượng nút lớp đầu ra của phần chuyển tiếp của mạng.
Bước 3.26 - CELL 26: Tính năng đóng băng Trọng lượng trích xuất
Ô thứ 26 đóng băng các trọng số đã được huấn luyện trước của bộ trích xuất đặc điểm. Đặt ô thứ hai mươi sáu thành:
### CELL 26: Freeze feature extractor ### model.freeze_feature_extractor()
Bước 3.27 - Ô 27: Thiết lập đối số huấn luyện
Ô thứ 27 khởi tạo các đối số huấn luyện sẽ được chuyển đến phiên bản Trainer
. Đặt ô thứ hai mươi bảy thành:
### CELL 27: Set training arguments ### training_args = TrainingArguments( output_dir = OUTPUT_DIR_PATH, save_safetensors = False, group_by_length = True, per_device_train_batch_size = TRAIN_BATCH_SIZE, per_device_eval_batch_size = EVAL_BATCH_SIZE, num_train_epochs = TRAIN_EPOCHS, gradient_checkpointing = True, evaluation_strategy = "steps", save_strategy = "steps", logging_strategy = "steps", eval_steps = EVAL_STEPS, save_steps = SAVE_STEPS, logging_steps = LOGGING_STEPS, learning_rate = LEARNING_RATE, warmup_steps = WARMUP_STEPS )
- Lớp
TrainingArguments
chấp nhận hơn 100 tham số . - Tham số
save_safetensors
khiFalse
chỉ định rằng mô hình đã tinh chỉnh sẽ được lưu vào tệppickle
thay vì sử dụng định dạngsafetensors
. - Tham số
group_by_length
khiTrue
cho biết rằng các mẫu có độ dài xấp xỉ nhau sẽ được nhóm lại với nhau. Điều này giảm thiểu phần đệm và cải thiện hiệu quả đào tạo. -
per_device_train_batch_size
đặt số lượng mẫu cho mỗi đợt đào tạo nhỏ. Tham số này được đặt thành18
thông qua hằng sốTRAIN_BATCH_SIZE
được chỉ định ở Bước 3.5 . Điều này ngụ ý 160 bước mỗi kỷ nguyên. -
per_device_eval_batch_size
đặt số lượng mẫu cho mỗi lô nhỏ đánh giá (tạm giữ). Tham số này được đặt thành10
thông qua hằng sốEVAL_BATCH_SIZE
được chỉ định ở Bước 3.5 . -
num_train_epochs
đặt số lượng kỷ nguyên đào tạo. Tham số này được đặt thành30
thông qua hằng sốTRAIN_EPOCHS
được chỉ định ở Bước 3.5 . Điều này ngụ ý tổng số 4.800 bước trong quá trình đào tạo. - Tham số
gradient_checkpointing
khiTrue
giúp tiết kiệm bộ nhớ bằng cách kiểm tra các phép tính gradient nhưng dẫn đến tốc độ lùi lại chậm hơn. - Tham số
evaluation_strategy
khi được đặt thànhsteps
có nghĩa là việc đánh giá sẽ được thực hiện và ghi lại trong quá trình đào tạo ở khoảng thời gian được chỉ định bởi tham sốeval_steps
. - Tham
logging_strategy
khi được đặt thànhsteps
có nghĩa là số liệu thống kê về lần chạy huấn luyện sẽ được ghi lại theo khoảng thời gian được chỉ định bởi thamlogging_steps
. - Tham số
save_strategy
khi được đặt thànhsteps
có nghĩa là điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu trong khoảng thời gian được chỉ định bởi tham sốsave_steps
. -
eval_steps
đặt số bước giữa các lần đánh giá dữ liệu loại trừ. Tham số này được đặt thành100
thông qua hằng sốEVAL_STEPS
được gán ở Bước 3.5 . -
save_steps
đặt số bước sau đó điểm kiểm tra của mô hình đã tinh chỉnh sẽ được lưu. Tham số này được đặt thành3200
thông qua hằng sốSAVE_STEPS
được gán ở Bước 3.5 . -
logging_steps
đặt số bước giữa các nhật ký thống kê về lần chạy tập luyện. Tham số này được đặt thành100
thông qua hằng sốLOGGING_STEPS
được gán ở Bước 3.5 . - Tham số
learning_rate
đặt tốc độ học ban đầu. Tham số này được đặt thành1e-4
thông qua hằng sốLEARNING_RATE
được gán ở Bước 3.5 . - Tham số
warmup_steps
đặt số bước để tăng tốc độ học tập một cách tuyến tính từ 0 đến giá trị dolearning_rate
đặt. Tham số này được đặt thành800
thông qua hằng sốWARMUP_STEPS
được gán ở Bước 3.5 .
Bước 3.28 - Ô 28: Xác định logic của bộ đối chiếu dữ liệu
Ô thứ hai mươi tám xác định logic cho các chuỗi mục tiêu và đầu vào đệm động. Đặt ô thứ hai mươi tám thành:
### CELL 28: Define data collator logic ### @dataclass class DataCollatorCTCWithPadding: processor: Wav2Vec2Processor padding: Union[bool, str] = True max_length: Optional[int] = None max_length_labels: Optional[int] = None pad_to_multiple_of: Optional[int] = None pad_to_multiple_of_labels: Optional[int] = None def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]: input_features = [{"input_values": feature["input_values"]} for feature in features] label_features = [{"input_ids": feature["labels"]} for feature in features] batch = self.processor.pad( input_features, padding = self.padding, max_length = self.max_length, pad_to_multiple_of = self.pad_to_multiple_of, return_tensors = "pt", ) with self.processor.as_target_processor(): labels_batch = self.processor.pad( label_features, padding = self.padding, max_length = self.max_length_labels, pad_to_multiple_of = self.pad_to_multiple_of_labels, return_tensors = "pt", ) labels = labels_batch["input_ids"].masked_fill(labels_batch.attention_mask.ne(1), -100) batch["labels"] = labels return batch
- Các cặp nhãn đầu vào đào tạo và đánh giá được chuyển theo từng đợt nhỏ tới phiên bản
Trainer
. Các cặp này sẽ được khởi tạo trong giây lát ở Bước 3.30 . Do các chuỗi đầu vào và chuỗi nhãn có độ dài khác nhau trong mỗi lô nhỏ nên một số chuỗi phải được đệm để chúng có cùng độ dài. - Lớp
DataCollatorCTCWithPadding
tự động đệm dữ liệu theo lô nhỏ. Thông sốpadding
khi được đặt thànhTrue
chỉ định rằng chuỗi tính năng đầu vào âm thanh ngắn hơn và chuỗi nhãn phải có cùng độ dài với chuỗi dài nhất trong một lô nhỏ. - Các tính năng đầu vào âm thanh được đệm bằng giá trị
0.0
được đặt khi khởi chạy trình trích xuất tính năng ở Bước 3.18 . - Đầu vào nhãn trước tiên được đệm bằng giá trị đệm được đặt khi khởi tạo bộ mã thông báo ở Bước 3.17 . Các giá trị này được thay thế bằng
-100
để các nhãn này bị bỏ qua khi tính chỉ số WER.
Bước 3.29 - CELL 29: Khởi tạo Instance của Data Collator
Ô thứ 29 khởi tạo một phiên bản của bộ đối chiếu dữ liệu được xác định ở bước trước. Đặt ô thứ hai mươi chín thành:
### CELL 29: Initialize instance of data collator ### data_collator = DataCollatorCTCWithPadding(processor = processor, padding = True)
Bước 3.30 - Ô 30: Khởi tạo Trainer
Ô thứ ba mươi khởi tạo một thể hiện của lớp Trainer
. Đặt ô thứ ba mươi thành:
### CELL 30: Initialize trainer ### trainer = Trainer( model = model, data_collator = data_collator, args = training_args, compute_metrics = compute_wer, train_dataset = train_data, eval_dataset = valid_data, tokenizer = processor.feature_extractor )
- Như đã thấy, lớp
Trainer
được khởi tạo bằng:-
model
được huấn luyện trước được khởi tạo ở Bước 3.25 . - Trình đối chiếu dữ liệu được khởi tạo ở Bước 3.29 .
- Các đối số huấn luyện được khởi tạo ở Bước 3.27 .
- Phương pháp đánh giá WER được xác định ở Bước 3.9 .
- Đối tượng
Dataset
train_data
từ Bước 3.24 . - Đối tượng
Dataset
valid_data
từ Bước 3.24 .
-
- Tham số
tokenizer
được gán choprocessor.feature_extractor
và hoạt động vớidata_collator
để tự động đệm đầu vào vào đầu vào có độ dài tối đa của mỗi lô nhỏ.
Bước 3.31 - Ô 31: Tinh chỉnh mô hình
Ô thứ 31 gọi phương thức train
trên phiên bản lớp Trainer
để tinh chỉnh mô hình. Đặt ô thứ ba mươi mốt thành:
### CELL 31: Finetune the model ### trainer.train()
Bước 3.32 - Ô 32: Lưu mô hình đã tinh chỉnh
Ô thứ ba mươi hai là ô sổ tay cuối cùng. Nó lưu mô hình đã tinh chỉnh bằng cách gọi phương thức save_model
trên phiên bản Trainer
. Đặt ô thứ ba mươi giây thành:
### CELL 32: Save the finetuned model ### trainer.save_model(OUTPUT_DIR_PATH)
Bước 4 - Đào tạo và lưu mô hình
Bước 4.1 - Đào tạo mô hình
Bây giờ tất cả các ô của sổ ghi chép đã được tạo xong, đã đến lúc bắt đầu tinh chỉnh.
Đặt Máy tính xách tay Kaggle chạy với bộ tăng tốc NVIDIA GPU P100 .
Cam kết sổ ghi chép trên Kaggle.
Giám sát dữ liệu lần chạy tập luyện bằng cách đăng nhập vào tài khoản WandB của bạn và định vị lần chạy liên quan.
Quá trình đào tạo trên 30 kỷ nguyên sẽ mất khoảng 5 giờ bằng cách sử dụng bộ tăng tốc NVIDIA GPU P100. WER trên dữ liệu loại trừ sẽ giảm xuống ~ 0,15 khi kết thúc khóa đào tạo. Đây không hẳn là một kết quả hiện đại nhưng mô hình đã được tinh chỉnh vẫn đủ hữu ích cho nhiều ứng dụng.
Bước 4.2 - Lưu mô hình
Mô hình đã tinh chỉnh sẽ được xuất ra thư mục Kaggle được chỉ định bởi hằng số OUTPUT_DIR_PATH
được chỉ định trong Bước 3.5 . Đầu ra của mô hình phải bao gồm các tệp sau:
pytorch_model.bin config.json preprocessor_config.json vocab.json training_args.bin
Những tập tin này có thể được tải xuống cục bộ. Ngoài ra, bạn có thể tạo Mô hình Kaggle mới bằng cách sử dụng các tệp mô hình. Mô hình Kaggle sẽ được sử dụng cùng với hướng dẫn suy luận đồng hành để chạy suy luận trên mô hình đã được tinh chỉnh.
- Đăng nhập vào tài khoản Kaggle của bạn. Bấm vào Mô hình > Mô hình mới .
- Thêm tiêu đề cho mô hình đã tinh chỉnh của bạn trong trường Tiêu đề mô hình .
- Bấm vào Tạo mô hình .
- Bấm vào Đi tới trang chi tiết mô hình .
- Nhấp vào Thêm biến thể mới trong Biến thể mẫu .
- Chọn Transformers từ menu chọn Framework .
- Nhấp vào Thêm biến thể mới .
- Kéo và thả các tệp mô hình đã tinh chỉnh của bạn vào cửa sổ Tải lên dữ liệu . Ngoài ra, hãy nhấp vào nút Duyệt tệp để mở cửa sổ trình khám phá tệp và chọn các tệp mô hình đã được tinh chỉnh của bạn.
- Khi các file đã được tải lên Kaggle, hãy nhấp vào Create để tạo Kaggle Model .
Phần kết luận
Chúc mừng bạn đã hoàn thiện wav2vec2 XLS-R! Hãy nhớ rằng bạn có thể sử dụng các bước chung này để tinh chỉnh mô hình trên các ngôn ngữ khác mà bạn mong muốn. Việc chạy suy luận trên mô hình tinh chỉnh được tạo trong hướng dẫn này khá đơn giản. Các bước suy luận sẽ được trình bày trong hướng dẫn đồng hành riêng cho hướng dẫn này. Vui lòng tìm kiếm tên người dùng HackerNoon của tôi để tìm hướng dẫn đồng hành.