In [1]:
!pip install --upgrade datasets[audio] transformers accelerate evaluate jiwer tensorboard gradio

[0m

In [2]:
# import huggingface_hub

# huggingface_hub.login()

## Imports

In [3]:
from datasets import load_dataset, DatasetDict
from transformers import (
    WhisperTokenizer, 
    WhisperProcessor, 
    WhisperFeatureExtractor, 
    WhisperForConditionalGeneration, 
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer
) 
from datasets import Audio
from dataclasses import dataclass
from typing import Any, Dict, List, Union

import torch
import evaluate

In [4]:
model_id = 'openai/whisper-tiny'
out_dir = 'whisper_tiny_atco2_v2'
epochs = 10
batch_size = 32

## Load Dataset

In [5]:
atc_dataset_train = load_dataset('jlvdoorn/atco2-asr-atcosim', split='train')
atc_dataset_valid = load_dataset('jlvdoorn/atco2-asr-atcosim', split='validation')

In [6]:
print(atc_dataset_train)
print(atc_dataset_valid)

Dataset({
    features: ['audio', 'text', 'info'],
    num_rows: 8092
})
Dataset({
    features: ['audio', 'text', 'info'],
    num_rows: 2026
})


In [7]:
print(atc_dataset_train[0])

{'audio': {'path': 'LKPR_RUZYNE_Radar_120_520MHz_20201025_091112.wav', 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
       -6.10351562e-05, -6.10351562e-05, -6.10351562e-05]), 'sampling_rate': 16000}, 'text': 'Oscar Kilo Papa Mike Bravo descend flight level one hundred level one hundred Oscar Kilo Papa Mike Bravo ', 'info': 'LKPR\nPraha Ruzyne\nRadar\nAKEVA ARVEG BAGRU BAROX BAVIN BEKVI ELMEK ELPON ERASU EVEMI KENOK KUVIX LETNA RATEV RISUK SOMIS SULOV TIPRU UTORO\nBLA131 BLA1XQ BTI7PY CTN480 DLH3NL DLH9TP ETD72E EWG6HP FIN1DH IRA711 KLM44K MLD863 MLD864 OKHBT OKLLZ OKMHZ OKPHM OKWUS17 OKYAI14 RYR1JU RYR4945 SXS7D THY32B THY6577 TIE790J UAE73  \nAll Charter Air Baltic Croatia Lufthansa Etihad Eurowings Finn Iranair Klm Moldova Oklahoma Okapi Alfa Ryan Sunexpress Turkish Time Emirates'}


## Feature Extractor and Tokenizer 

In [8]:
feature_extractor = WhisperFeatureExtractor.from_pretrained(model_id)

In [9]:
tokenizer = WhisperTokenizer.from_pretrained(model_id, language='English', task='transcribe')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [10]:
processor = WhisperProcessor.from_pretrained(model_id, language='English', task='transcribe')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


### Prepare Data

In [11]:
atc_dataset_train = atc_dataset_train.cast_column('audio', Audio(sampling_rate=16000))
atc_dataset_valid = atc_dataset_valid.cast_column('audio', Audio(sampling_rate=16000))

In [12]:
print(atc_dataset_train[0])

{'audio': {'path': 'LKPR_RUZYNE_Radar_120_520MHz_20201025_091112.wav', 'array': array([ 0.00000000e+00,  0.00000000e+00,  0.00000000e+00, ...,
       -6.10351562e-05, -6.10351562e-05, -6.10351562e-05]), 'sampling_rate': 16000}, 'text': 'Oscar Kilo Papa Mike Bravo descend flight level one hundred level one hundred Oscar Kilo Papa Mike Bravo ', 'info': 'LKPR\nPraha Ruzyne\nRadar\nAKEVA ARVEG BAGRU BAROX BAVIN BEKVI ELMEK ELPON ERASU EVEMI KENOK KUVIX LETNA RATEV RISUK SOMIS SULOV TIPRU UTORO\nBLA131 BLA1XQ BTI7PY CTN480 DLH3NL DLH9TP ETD72E EWG6HP FIN1DH IRA711 KLM44K MLD863 MLD864 OKHBT OKLLZ OKMHZ OKPHM OKWUS17 OKYAI14 RYR1JU RYR4945 SXS7D THY32B THY6577 TIE790J UAE73  \nAll Charter Air Baltic Croatia Lufthansa Etihad Eurowings Finn Iranair Klm Moldova Oklahoma Okapi Alfa Ryan Sunexpress Turkish Time Emirates'}


In [13]:
def prepare_dataset(batch):
    audio = batch['audio']

    batch['input_features'] = feature_extractor(audio['array'], sampling_rate=audio['sampling_rate']).input_features[0]

    batch['labels'] = tokenizer(batch['text']).input_ids

    return batch

In [14]:
atc_dataset_train = atc_dataset_train.map(
    prepare_dataset, 
    num_proc=4
)

atc_dataset_valid = atc_dataset_valid.map(
    prepare_dataset, 
    num_proc=4
)

### Data Collator

In [15]:
@dataclass
class DataCollatorSpeechSeq2SeqWithPadding:
    processor: Any
    decoder_start_token_id: int

    def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
        input_features = [{'input_features': feature['input_features']} for feature in features]
        batch = self.processor.feature_extractor.pad(input_features, return_tensors='pt')

        label_features = [{'input_ids': feature['labels']} for feature in features]
        labels_batch = self.processor.tokenizer.pad(label_features, return_tensors='pt')

        labels = labels_batch['input_ids'].masked_fill(labels_batch.attention_mask.ne(1), -100)

        if (labels[:, 0] == self.decoder_start_token_id).all().cpu().item():
            labels = labels[:, 1:]

        batch['labels'] = labels

        return batch

## Whisper Model

In [16]:
model = WhisperForConditionalGeneration.from_pretrained(model_id)

In [17]:
model.generation_config.task = 'transcribe'

model.generation_config.forced_decoder_ids = None

In [18]:
data_collator = DataCollatorSpeechSeq2SeqWithPadding(
    processor=processor,
    decoder_start_token_id=model.config.decoder_start_token_id,
)

### Evaluation Metrics

In [19]:
metric = evaluate.load('wer')

In [20]:
def compute_metrics(pred):
    pred_ids = pred.predictions
    label_ids = pred.label_ids

    # replace -100 with the pad_token_id
    label_ids[label_ids == -100] = tokenizer.pad_token_id

    # we do not want to group tokens when computing the metrics
    pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    wer = 100 * metric.compute(predictions=pred_str, references=label_str)

    return {'wer': wer}

### Define the Training Configuration

In [21]:
training_args = Seq2SeqTrainingArguments(
    output_dir=out_dir, 
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    gradient_accumulation_steps=1, 
    learning_rate=0.00001,
    warmup_steps=1000,
    bf16=True,
    fp16=False,
    num_train_epochs=epochs,
    evaluation_strategy='epoch',
    logging_strategy='epoch',
    save_strategy='epoch',
    predict_with_generate=True,
    generation_max_length=225,
    report_to=['tensorboard'],
    load_best_model_at_end=True,
    metric_for_best_model='wer',
    greater_is_better=False,
    dataloader_num_workers=8,
    save_total_limit=2,
    lr_scheduler_type='constant',
    seed=42,
    data_seed=42
)



In [22]:
trainer = Seq2SeqTrainer(
    args=training_args,
    model=model,
    train_dataset=atc_dataset_train,
    eval_dataset=atc_dataset_valid,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    tokenizer=processor.feature_extractor,
)

### Training

In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Wer
1,0.4413,0.165822,11.994632
2,0.1235,0.125956,11.684281
3,0.0699,0.114142,8.333333
4,0.0405,0.110083,6.915786
5,0.024,0.113671,6.420903
6,0.0146,0.111585,5.976346
7,0.0086,0.112418,7.867807
8,0.0054,0.111893,6.601241
9,0.0032,0.113409,6.160879
10,0.0027,0.1125,5.930213


You have passed task=transcribe, but also have set `forced_decoder_ids` to [[1, 50259], [2, 50359], [3, 50363]] which creates a conflict. `forced_decoder_ids` will be ignored in favor of task=transcribe.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token.As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 503

TrainOutput(global_step=2530, training_loss=0.07335695365200871, metrics={'train_runtime': 2287.9828, 'train_samples_per_second': 35.367, 'train_steps_per_second': 1.106, 'total_flos': 1.9921601839104e+18, 'train_loss': 0.07335695365200871, 'epoch': 10.0})

In [24]:
model.save_pretrained(f"{out_dir}/best_model")
tokenizer.save_pretrained(f"{out_dir}/best_model")
processor.save_pretrained(f"{out_dir}/best_model")

Non-default generation parameters: {'max_length': 448, 'suppress_tokens': [1, 2, 7, 8, 9, 10, 14, 25, 26, 27, 28, 29, 31, 58, 59, 60, 61, 62, 63, 90, 91, 92, 93, 359, 503, 522, 542, 873, 893, 902, 918, 922, 931, 1350, 1853, 1982, 2460, 2627, 3246, 3253, 3268, 3536, 3846, 3961, 4183, 4667, 6585, 6647, 7273, 9061, 9383, 10428, 10929, 11938, 12033, 12331, 12562, 13793, 14157, 14635, 15265, 15618, 16553, 16604, 18362, 18956, 20075, 21675, 22520, 26130, 26161, 26435, 28279, 29464, 31650, 32302, 32470, 36865, 42863, 47425, 49870, 50254, 50258, 50358, 50359, 50360, 50361, 50362], 'begin_suppress_tokens': [220, 50257]}


[]

In [27]:
!zip -r whisper_tiny_atco2_v2 whisper_tiny_atco2_v2

  adding: whisper_tiny_atco2_v2/ (stored 0%)
  adding: whisper_tiny_atco2_v2/runs/ (stored 0%)
  adding: whisper_tiny_atco2_v2/runs/Jul19_05-51-05_e7459306a578/ (stored 0%)
  adding: whisper_tiny_atco2_v2/runs/Jul19_05-51-05_e7459306a578/events.out.tfevents.1721368268.e7459306a578.2499.0 (deflated 64%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/ (stored 0%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/config.json (deflated 63%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/generation_config.json (deflated 71%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/model.safetensors (deflated 8%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/preprocessor_config.json (deflated 42%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/training_args.bin (deflated 52%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/optimizer.pt (deflated 7%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/scheduler.pt (deflated 57%)
  adding: whisper_tiny_atco2_v2/checkpoint-1518/rng_state.pth (deflated 25