Wav2Vec2.0 in Unity Sentis: export to ONNX

The layers of Wav2Vec2.0.

Lately I have been digging into Automatic Speech Recognition (ASR) systems and how they can be used in Unity to analyze speech without requiring a high-end graphics card. That brought me back to integrating Hugging Face models inside Unity, so let’s have a look at the method I used.

This is the first of a series of two posts; here we will focus on preparing the model for the outside world.

Finding an ASR model

First, we need to choose an ASR model. As of February 2025, Whisper is known to be an excellent generalist model. However, it is resource-hungry and not very good at detecting isolated words. On the other hand, Wav2Vec2.0 is a simpler model, and it gives better results for word-level detection while using fewer resources.

I actually implemented both, but for the sake of this post we will look at Wav2Vec2.0. One of the reasons is that Whisper, being more popular, already has some nice integrations (whisper.unity, simpler export to ONNX).

We will go with the specific checkpoint (trained model) Wav2Vec2 LJSpeech Gruut. It is a phoneme model: its output won’t be a word such as “hello”, but a string in IPA, such as “hɛlˈoʊ”. This makes it easier to check whether a non-word was transcribed correctly.

Usage in Python

Our Wav2Vec2Phoneme model is available on Hugging Face, whose libraries are mostly written in Python, so let’s start with that language.

We will define our first function to import the model checkpoint:

import pathlib
from itertools import groupby

from datasets import load_dataset
import librosa
import numpy as np
import onnx
import onnxruntime as ort
import torch
import transformers


def get_model():
    """
    Get the Wav2Vec2.0 model.

    The processor will be a Wav2Vec2Processor object.

    Source code: https://github.com/huggingface/transformers/tree/main/src/transformers/models/wav2vec2
    Looking at the code, the processor for audio may only be a feature extractor.

    :return tuple: Model and processor.
    """
    checkpoint = "bookbot/wav2vec2-ljspeech-gruut"

    model = transformers.AutoModelForCTC.from_pretrained(checkpoint)
    processor = transformers.AutoProcessor.from_pretrained(checkpoint)
    return model, processor

Now let’s define some input audio data.

def get_audio_data(local_file=None):
    """Read an audio file as an audio array."""
    if local_file is None:
        # load dummy dataset and read soundfiles
        ds = load_dataset("patrickvonplaten/librispeech_asr_dummy", "clean", split="validation")
        audio_array = ds[0]["audio"]["array"]
    else:
        # or, read a single audio file
        audio_array, _ = librosa.load(local_file, sr=16000)
    return audio_array

Finally, let’s run the model and get a prediction.


def get_logits(audio_array, processor, model, sampling_rate=16000):
    """
    Get the logits for the audio.

    For Unity: https://discussions.unity.com/t/how-to-use-logits-output-type/1540746

    :param audio_array: Input audio
    :param processor: Processor or feature extractor for the input.
    :param model: Actual model.
    :param sampling_rate: Sampling rate.
    :return: Logits tensor.
    """
    # In simple cases, inputs = (audio_array - audio_array.mean()) / sqrt(audio_array.var())
    inputs = processor(
        audio_array,
        return_tensors="pt",
        padding=True,
        sampling_rate=sampling_rate
    )

    with torch.no_grad():
        logits = model(inputs["input_values"]).logits
    return logits


def decode_phonemes(
    ids: torch.Tensor,
    processor: transformers.Wav2Vec2Processor,
    ignore_stress: bool = False
) -> str:
    """
    CTC-like decoding.
    First removes consecutive duplicates, then removes special tokens.
    """
    # removes consecutive duplicates
    ids = [id_ for id_, _ in groupby(ids)]

    special_token_ids = processor.tokenizer.all_special_ids + [
        processor.tokenizer.word_delimiter_token_id
    ]
    # converts id to token, skipping special tokens
    phonemes = [processor.decode(id_) for id_ in ids if id_ not in special_token_ids]

    # joins phonemes
    prediction = " ".join(phonemes)

    # whether to ignore IPA stress marks
    if ignore_stress:
        prediction = prediction.replace("ˈ", "").replace("ˌ", "")

    return prediction


def predicted_answer(logits, processor):
    """From a matrix of logits, get the best prediction match."""
    predicted_ids = torch.argmax(logits, dim=-1)
    my_prediction = decode_phonemes(predicted_ids[0], processor, ignore_stress=True)
    return my_prediction


def run_analysis():
    """Run the model."""
    model, processor = get_model()
    sampling_rate = processor.feature_extractor.sampling_rate
    audio_array = get_audio_data()
    logits = get_logits(
        audio_array,
        processor,
        model,
        sampling_rate
    )
    prediction = predicted_answer(logits, processor)
    return prediction

The model provider actually published a simple Python use case, and it is pretty straightforward to follow. The idea is to get an audio array from an audio file at 16 kHz, pass it to a processor, and then feed it to the model. Let’s try to understand it.

Input

First, we pass the audio through a simple Wav2Vec2Processor. From what I tried, it only normalizes the input data, using a Z-score normalization. It may also convert the audio array to 16 kHz. So from our initial audio array, we get an array of the same size, but with a better data range.

Removing the processor does not seem to change the model outputs, so I guess it ships with an internal Z-score normalization.
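
To make that concrete, here is a minimal sketch of the normalization the processor appears to apply on simple inputs, reusing the numpy and torch imports from the top of the post. The eps value and the batch shape are my assumptions, not the library’s implementation.

def normalize_like_processor(audio_array, eps=1e-7):
    """Zero-mean, unit-variance (Z-score) normalization of a raw audio array."""
    audio_array = np.asarray(audio_array, dtype=np.float32)
    normalized = (audio_array - audio_array.mean()) / np.sqrt(audio_array.var() + eps)
    # Return a batch of one, shaped like processor(...)["input_values"]
    return torch.from_numpy(normalized).unsqueeze(0)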

Now, let’s pass the processed audio array as the model input.

Output

The model output comes in the form of logits, which is simply a matrix of prediction scores for each timestamp and each phoneme. In C#, this would be a matrix of size float[n_timestamps][n_possible_phonemes]. This output matrix will be useful later, but for now let’s get the first predicted answer.
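
To make the shape concrete, here is a quick check on the logits returned by get_logits above; the exact dimensions depend on the audio length and the checkpoint’s vocabulary.

# Shape check: logits is (batch, n_timestamps, n_possible_phonemes).
model, processor = get_model()
audio_array = get_audio_data()
logits = get_logits(audio_array, processor, model)

print(logits.shape)              # torch.Size([1, n_timestamps, vocab_size])
print(len(processor.tokenizer))  # vocab_size: number of possible phoneme tokens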

Decoding the output

The model output is a matrix of prediction scores; let’s get phonemes from it.

We apply an Arg Max function, which gives us the index of the token with the highest prediction score for each timestamp. We now have a list int[n_timestamps] (remember, Arg Max selects the index).

As the matrix contains a prediction score for each phoneme, getting the index of the most probable phoneme gives us the ID of the phoneme itself. To map IDs back to phonemes, we can have a look at vocab.json. We can build a simple int → string map from it, and voilà, our output is a string of phonemes; a minimal sketch follows below.
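
Here is that sketch, assuming vocab.json from the checkpoint has been downloaded next to the script and that the special token names below match the checkpoint’s defaults (check the actual vocab to be sure). It reuses the groupby and torch imports from the top of the post.

import json

def decode_with_vocab(logits, vocab_path="vocab.json"):
    """Greedy CTC-style decoding using only vocab.json, without the processor."""
    with open(vocab_path, encoding="utf-8") as f:
        token_to_id = json.load(f)                  # {"h": 12, ...}
    id_to_token = {v: k for k, v in token_to_id.items()}

    ids = torch.argmax(logits, dim=-1)[0].tolist()  # int[n_timestamps]
    ids = [id_ for id_, _ in groupby(ids)]          # collapse consecutive duplicates

    special_tokens = {"<pad>", "<s>", "</s>", "<unk>", "|"}  # assumed names
    phonemes = [id_to_token.get(id_, "") for id_ in ids]
    return " ".join(p for p in phonemes if p and p not in special_tokens)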

Export to ONNX

Now that we understand our model, let’s use it in Unity. The official, and best, way to integrate AI into Unity is through Unity Sentis. As it only accepts the ONNX file format, we will go with this open-source format. We will only export the model part (not the processor).

Luckily enough, someone already did it; the only change needed is to set opset_version to 15 to be compatible with Sentis.

 
def convert_to_onnx(onnx_model_name="wav2vec2.onnx"):
    """
    Convert a Wav2Vec2 model using a low-level method.

    https://github.com/ccoreilly/wav2vec2-service/blob/master/convert_torch_to_onnx.py

    :param string onnx_model_name: Name of the output file.
    """
    model, _ = get_model()
    audio_len = 250000

    model_input = torch.randn(1, audio_len, requires_grad=True)

    torch.onnx.export(
        model,                          # model being run
        model_input,                    # model input (or a tuple for multiple inputs)
        onnx_model_name,                # where to save the model (can be a file or file-like object)
        export_params=True,             # store the trained parameter weights inside the model file
        opset_version=15,               # the ONNX version to export the model to, 15 is recommended version for Unity
        do_constant_folding=True,       # whether to execute constant folding for optimization
        input_names=['input'],          # the model's input names
        output_names=['output'],        # the model's output names
        dynamic_axes={
            'input': {1: 'audio_len'},    # variable length axes
            'output': {1: 'audio_len'}
        }
    )
    print(f"Model saved to {onnx_model_name} successfully!")

For some models, Hugging Face provides Optimum to do the conversion from safetensors to ONNX with a CLI. It was not available for this checkpoint of Wav2Vec2, so I crafted some code.

If you are lucky enough, running the code will yield a nice .onnx file.

Checking if everything works: import the ONNX file in Python

It is a good idea to check that, at this point, our model is still functional. We will load the ONNX file with ONNX Runtime, run it on some demo data, and try to get the phonemes back.


def run_from_onnx(audio_array, model_file="wav2vec2.onnx", check_model=False):
    """
    Run the Wav2Vec2.0 model from an ONNX file.
     
    :param list[float] audio_array: Input data, at 16 kHz.
    :param string model_file: Path to the ONNX file.  
    :param bool check_model: To apply supplementary check to the model file. 
    :return string: Predicted phonemes. 
    """
    audio_array = np.array(audio_array, dtype=np.float32).reshape((1, -1))
    if check_model:
        onnx_model = onnx.load(model_file)
        onnx.checker.check_model(onnx_model)

    ort_sess = ort.InferenceSession(model_file)
    logits = ort_sess.run(None, {'input': audio_array})[0]
    
    # The following is to pass from the logits to the prediction, not the hardest part
    _, processor = get_model()
    prediction = predicted_answer(torch.as_tensor(logits), processor)
    return prediction

We can see that the answer is the same. Another cool thing with ONNX is that it is easy to visualize; we can use netron.app.
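
As a quick sanity check, here is a sketch that reuses the helpers above to compare both paths on the same dummy audio, assuming the exported file is named wav2vec2.onnx.

# Compare the pure-PyTorch pipeline with the ONNX Runtime one.
pytorch_prediction = run_analysis()
onnx_prediction = run_from_onnx(get_audio_data(), model_file="wav2vec2.onnx")

assert pytorch_prediction == onnx_prediction, "ONNX export changed the output!"
print(onnx_prediction)  # a string of IPA phonemes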

Conclusion

To recap, we exported a Python model that takes audio as input and returns its phoneme representation as output. We used ONNX as an interchange file format. As our final goal is to integrate it within Unity, we will cover that part in the next post!
