Este é um guia complementar para Trabalhar com wav2vec2 Parte 1 - Ajuste fino do XLS-R para reconhecimento automático de fala (o "guia da Parte 1"). Escrevi o guia da Parte 1 sobre como ajustar o modelo wav2vec2 XLS-R ("XLS-R") do Meta AI no espanhol chileno. Supõe-se que você tenha concluído esse guia e gerado seu próprio modelo XLS-R ajustado. Este guia explicará as etapas para executar inferência em seu modelo XLS-R ajustado por meio de um Kaggle Notebook .
Para completar o guia, você precisará ter:
spanish-asr-inference
.Este guia usa o conjunto de dados de fala do espanhol peruano como fonte de dados de teste. Assim como o conjunto de dados da fala do espanhol chileno , o conjunto de dados de falantes peruanos também consiste em dois subconjuntos de dados: 2.918 gravações de falantes peruanos do sexo masculino e 2.529 gravações de falantes peruanos do sexo feminino.
Este conjunto de dados foi carregado no Kaggle como 2 conjuntos de dados distintos:
Adicione esses dois conjuntos de dados ao seu Kaggle Notebook clicando em Add Input .
Você deveria ter salvo seu modelo ajustado na Etapa 4 do guia Trabalhando com wav2vec2 Parte 1 - Ajuste fino do XLS-R para reconhecimento automático de fala como um modelo Kaggle .
Adicione seu modelo ajustado ao seu Kaggle Notebook clicando em Add Input .
As 16 subetapas a seguir constroem cada uma das 16 células do caderno de inferência em ordem. Você notará que muitos dos mesmos métodos utilitários do guia da Parte 1 são usados aqui.
A primeira célula do bloco de notas de inferência instala dependências. Defina a primeira célula como:
### CELL 1: Install Packages ### !pip install --upgrade torchaudio !pip install jiwer
As importações da segunda célula exigiam pacotes Python. Defina a segunda célula como:
### CELL 2: Import Python packages ### import re import math import random import pandas as pd import torchaudio from datasets import load_metric from transformers import pipeline
A terceira célula importa a métrica de avaliação HuggingFace WER. Defina a terceira célula como:
### CELL 3: Load WER metric ### wer_metric = load_metric("wer")
A quarta célula define constantes que serão usadas em todo o notebook. Defina a quarta célula como:
### CELL 4: Constants ### # Testing data TEST_DATA_PATH_MALE = "/kaggle/input/google-spanish-speakers-peru-male/" TEST_DATA_PATH_FEMALE = "/kaggle/input/google-spanish-speakers-peru-female/" EXT = ".wav" NUM_LOAD_FROM_EACH_SET = 3 # Special characters SPECIAL_CHARS = r"[\d\,\-\;\!\¡\?\¿\।\'\'\"\–\'\:\/\.\“\”\৷\…\‚\॥\\]" # Sampling rates ORIG_SAMPLING_RATE = 48000 TGT_SAMPLING_RATE = 16000
A quinta célula define métodos utilitários para ler os arquivos de índice do conjunto de dados, bem como para limpar o texto da transcrição e gerar um conjunto aleatório de amostras a partir dos dados de teste. Defina a quinta célula como:
### CELL 5: Utility methods for reading index files, cleaning text, random indices generator ### def read_index_file_data(path: str, filename: str): data = [] with open(path + filename, "r", encoding = "utf8") as f: lines = f.readlines() for line in lines: file_and_text = line.split("\t") data.append([path + file_and_text[0] + EXT, file_and_text[1].replace("\n", "")]) return data def clean_text(text: str) -> str: cleaned_text = re.sub(SPECIAL_CHARS, "", text) cleaned_text = cleaned_text.lower() return cleaned_text def get_random_samples(dataset: list, num: int) -> list: used = [] samples = [] for i in range(num): a = -1 while a == -1 or a in used: a = math.floor(len(dataset) * random.random()) samples.append(dataset[a]) used.append(a) return samples
O método read_index_file_data
lê um arquivo de índice do conjunto de dados line_index.tsv
e produz uma lista de listas com nomes de arquivos de áudio e dados de transcrição, por exemplo:
[ ["/kaggle/input/google-spanish-speakers-chile-male/clm_08421_01719502739", "Es un viaje de negocios solamente voy por una noche"] ... ]
clean_text
é usado para retirar cada transcrição de texto dos caracteres especificados pela expressão regular atribuída a SPECIAL_CHARS
na Etapa 2.4 . Esses caracteres, inclusive a pontuação, podem ser eliminados, pois não fornecem nenhum valor semântico ao treinar o modelo para aprender mapeamentos entre recursos de áudio e transcrições de texto.get_random_samples
retorna um conjunto de amostras de teste aleatórias com a quantidade definida pela constante NUM_LOAD_FROM_EACH_SET
na Etapa 2.4 . A sexta célula define métodos utilitários usando torchaudio
para carregar e reamostrar dados de áudio. Defina a sexta célula como:
### CELL 7: Utility methods for loading and resampling audio data ### def read_audio_data(file): speech_array, sampling_rate = torchaudio.load(file, normalize = True) return speech_array, sampling_rate def resample(waveform): transform = torchaudio.transforms.Resample(ORIG_SAMPLING_RATE, TGT_SAMPLING_RATE) waveform = transform(waveform) return waveform[0]
read_audio_data
carrega um arquivo de áudio especificado e retorna uma matriz multidimensional torch.Tensor
dos dados de áudio junto com a taxa de amostragem do áudio. Todos os arquivos de áudio nos dados de treinamento têm uma taxa de amostragem de 48000
Hz. Esta taxa de amostragem "original" é capturada pela constante ORIG_SAMPLING_RATE
na Etapa 2.4 .resample
é usado para reduzir a resolução de dados de áudio de uma taxa de amostragem de 48000
para a taxa de amostragem alvo de 16000
. A sétima célula lê os arquivos de índice de dados de teste para as gravações de falantes do sexo masculino e as gravações de falantes do sexo feminino usando o método read_index_file_data
definido na Etapa 2.5 . Defina a sétima célula como:
### CELL 7: Read test data ### test_data_male = read_index_file_data(TEST_DATA_PATH_MALE, "line_index.tsv") test_data_female = read_index_file_data(TEST_DATA_PATH_FEMALE, "line_index.tsv")
A oitava célula gera conjuntos de amostras de teste aleatórias usando o método get_random_samples
definido na Etapa 2.5 . Defina a oitava célula como:
### CELL 8: Generate lists of random test samples ### random_test_samples_male = get_random_samples(test_data_male, NUM_LOAD_FROM_EACH_SET) random_test_samples_female = get_random_samples(test_data_female, NUM_LOAD_FROM_EACH_SET)
A nona célula combina as amostras de teste masculinas e as amostras de teste femininas em uma única lista. Defina a nona célula como:
### CELL 9: Combine test data ### all_test_samples = random_test_samples_male + random_test_samples_female
A décima célula itera sobre cada amostra de dados de teste e limpa o texto de transcrição associado usando o método clean_text
definido na Etapa 2.5 . Defina a décima célula como:
### CELL 10: Clean text transcriptions ### for index in range(len(all_test_samples)): all_test_samples[index][1] = clean_text(all_test_samples[index][1])
A décima primeira célula carrega cada arquivo de áudio especificado na lista all_test_samples
. Defina a décima primeira célula como:
### CELL 11: Load audio data ### all_test_data = [] for index in range(len(all_test_samples)): speech_array, sampling_rate = read_audio_data(all_test_samples[index][0]) all_test_data.append({ "raw": speech_array, "sampling_rate": sampling_rate, "target_text": all_test_samples[index][1] })
torch.Tensor
e armazenados em all_test_data
como uma lista de dicionários. Cada dicionário contém os dados de áudio de uma amostra específica, a taxa de amostragem e a transcrição do texto do áudio. A décima segunda célula reamostra os dados de áudio para a taxa de amostragem alvo de 16000
. Defina a décima segunda célula como:
### CELL 12: Resample audio data and cast to NumPy arrays ### all_test_data = [{"raw": resample(sample["raw"]).numpy(), "sampling_rate": TGT_SAMPLING_RATE, "target_text": sample["target_text"]} for sample in all_test_data]
A décima terceira célula inicializa uma instância da classe pipeline
da biblioteca transformer
HuggingFace. Defina a décima terceira célula como:
### CELL 13: Initialize instance of Automatic Speech Recognition Pipeline ### transcriber = pipeline("automatic-speech-recognition", model = "YOUR_FINETUNED_MODEL_PATH")
O parâmetro model
deve ser definido como o caminho para o seu modelo ajustado adicionado ao Kaggle Notebook na Etapa 1.3 , por exemplo:
transcriber = pipeline("automatic-speech-recognition", model = "/kaggle/input/xls-r-300m-chilean-spanish/transformers/hardy-pine/1")
A décima quarta célula chama o transcriber
inicializado na etapa anterior nos dados de teste para gerar previsões de texto. Defina a décima quarta célula como:
### CELL 14: Generate transcriptions ### transcriptions = transcriber(all_test_data)
A décima quinta célula calcula as pontuações WER para cada previsão, bem como uma pontuação WER geral para todas as previsões. Defina a décima quinta célula como:
### CELL 15: Calculate WER metrics ### predictions = [transcription["text"] for transcription in transcriptions] references = [transcription["target_text"][0] for transcription in transcriptions] wers = [] for p in range(len(predictions)): wer = wer_metric.compute(predictions = [predictions[p]], references = [references[p]]) wers.append(wer) zipped = list(zip(predictions, references, wers)) df = pd.DataFrame(zipped, columns=["Prediction", "Reference", "WER"]) wer = wer_metric.compute(predictions = predictions, references = references)
A décima sexta e última célula simplesmente imprime os cálculos do WER da etapa anterior. Defina a décima sexta célula como:
### CELL 16: Output WER metrics ### pd.set_option("display.max_colwidth", None) print(f"Overall WER: {wer}") print(df)
Como o notebook gera previsões em amostras aleatórias de dados de teste, a saída variará cada vez que o notebook for executado. A seguinte saída foi gerada em uma execução do notebook com NUM_LOAD_FROM_EACH_SET
definido como 3
para um total de 6 amostras de teste:
Overall WER: 0.013888888888888888 Prediction \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de setiembre Reference \ 0 quiero que me reserves el mejor asiento del teatro 1 el llano en llamas es un clásico de juan rulfo 2 el cuadro de los alcatraces es una de las pinturas más famosas de diego rivera 3 hay tres cafés que están abiertos hasta las once de la noche 4 quiero que me recomiendes una dieta pero donde uno pueda comer algo no puras verduras 5 cuántos albergues se abrieron después del terremoto del diecinueve de septiembre WER 0 0.000000 1 0.000000 2 0.000000 3 0.000000 4 0.000000 5 0.090909
Como pode ser visto, o modelo fez um excelente trabalho! Cometeu apenas um erro com a sexta amostra (índice 5
), escrevendo incorretamente a palavra septiembre
como setiembre
. É claro que executar o notebook novamente com amostras de teste diferentes e, mais importante, com um número maior de amostras de teste, produzirá resultados diferentes e mais informativos. No entanto, estes dados limitados sugerem que o modelo pode ter um bom desempenho em diferentes dialetos do espanhol – ou seja, foi treinado em espanhol chileno, mas parece ter um bom desempenho em espanhol peruano.
Se você está apenas aprendendo como trabalhar com modelos wav2vec2, espero que o guia Trabalhando com wav2vec2 Parte 1 - Ajuste fino do XLS-R para reconhecimento automático de fala e este guia tenham sido úteis para você. Conforme mencionado, o modelo ajustado gerado pelo guia da Parte 1 não é exatamente o que há de mais moderno, mas ainda assim deve ser útil para muitas aplicações. Feliz edifício!