Whisper transcription on Emily

Mix.install(
  [
    {:emily, "~> 0.3"},
    {:bumblebee, "~> 0.6"},
    {:tokenizers, "~> 0.5"},
    {:nx, "~> 0.10"},
    {:kino, "~> 0.14"},
    {:req, "~> 0.5"}
  ],
  config: [
    nx: [default_backend: Emily.Backend]
  ]
)

Overview

This notebook runs openai/whisper-tiny on Emily.Backend in three modes:

  1. A canned clip — a short public-domain WAV, transcribed the moment you evaluate the cell.
  2. A live recording — press record in the browser, click Transcribe, see the result.
  3. Your own input — swap the Req.get!/1 URL for another WAV, or substitute a File.read!/1 call.

All three feed the same Nx.Serving built below. Under the hood: Bumblebee's mel featurizer, Whisper's encoder self-attention, the cached encoder state, and the decoder with cross-attention — every op dispatches to MLX through Emily.Compiler.

The checkpoint is ~150 MB on first fetch. The numerical pin lives in test/emily/conformance/whisper_full_test.exs.

Loading the model

Whisper ships as four separate Bumblebee artefacts: the model (encoder + decoder weights), the featurizer (audio → mel), the tokenizer (token ids → text), and the generation config (decoder sampling strategy, special-token ids).

repo = {:hf, "openai/whisper-tiny"}

{:ok, whisper} = Bumblebee.load_model(repo)
{:ok, featurizer} = Bumblebee.load_featurizer(repo)
{:ok, tokenizer} = Bumblebee.load_tokenizer(repo)
{:ok, generation_config} = Bumblebee.load_generation_config(repo)

Building a serving

serving =
  Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
    defn_options: [compiler: Emily.Compiler]
  )

Bumblebee.Audio.speech_to_text_whisper/5 builds an Nx.Serving that composes the featurizer, the encoder forward pass, the autoregressive decoder loop (with a cached encoder state), and the tokenizer detokenize step. The compiler: Emily.Compiler attachment on :defn_options routes every compiled computation through Emily.
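Nx.Serving.run/2 executes in the calling process, which is fine for a single notebook cell. If several cells or clients should share one compiled model, the serving can instead be hosted in a supervised process and called by name. A sketch using Nx.Serving's process API; WhisperServing is an arbitrary name chosen here, not something Emily or Bumblebee defines:

```elixir
# Sketch: supervise the serving so concurrent callers share one compiled model.
# `WhisperServing` is an arbitrary process name.
{:ok, _pid} = Kino.start_child({Nx.Serving, serving: serving, name: WhisperServing})

# Callers then use batched_run/2 instead of run/2:
#   %{chunks: chunks} = Nx.Serving.batched_run(WhisperServing, audio_tensor)
```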

A canned transcription

An ~11 second JFK inauguration clip from the whisper.cpp samples. The WAV is 16 kHz 16-bit mono PCM — exactly what Whisper's featurizer expects — so the parser below just scans for the RIFF data chunk and reinterprets its bytes as signed 16-bit ints normalised into [-1, 1]. No ffmpeg, no resampling.

defmodule WAV do
  @moduledoc false

  @doc "Decode 16-bit PCM mono 16 kHz WAV to a 1-D f32 Nx tensor in [-1, 1]."
  def decode!(<<"RIFF", _size::little-32, "WAVE", rest::binary>>) do
    {%{bits: 16, channels: 1, rate: 16_000}, data} = scan_chunks(rest, nil, nil)
    data |> Nx.from_binary(:s16) |> Nx.divide(32_768.0) |> Nx.as_type(:f32)
  end

  defp scan_chunks(<<"fmt ", size::little-32, fmt::binary-size(size), rest::binary>>, _, data) do
    <<_fmt_tag::little-16, channels::little-16, rate::little-32, _byte_rate::little-32,
      _block_align::little-16, bits::little-16, _rest_fmt::binary>> = fmt

    scan_chunks(rest, %{bits: bits, channels: channels, rate: rate}, data)
  end

  defp scan_chunks(<<"data", size::little-32, data::binary-size(size), _::binary>>, fmt, _),
    do: {fmt, data}

  defp scan_chunks(
         <<_id::binary-4, size::little-32, _::binary-size(size), rest::binary>>,
         fmt,
         data
       ),
       do: scan_chunks(rest, fmt, data)
end

jfk_wav =
  Req.get!(
    "https://raw.githubusercontent.com/ggml-org/whisper.cpp/master/samples/jfk.wav"
  ).body

jfk_audio = WAV.decode!(jfk_wav)

%{chunks: chunks} = Nx.Serving.run(serving, jfk_audio)
chunks |> Enum.map_join(& &1.text) |> String.trim()

You should see something close to "And so, my fellow Americans, ask not what your country can do for you — ask what you can do for your country." Punctuation and capitalisation may still vary slightly between runs: greedy decoding is deterministic in exact arithmetic, but small floating-point differences in the backend can flip near-tie token choices.
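Mode 3 from the overview swaps the download for a local file. A sketch, where "speech.wav" is a placeholder path: the file must already be 16 kHz 16-bit mono PCM, because WAV.decode!/1 does no resampling.

```elixir
# "speech.wav" is a placeholder; supply any 16 kHz 16-bit mono PCM WAV.
my_audio =
  "speech.wav"
  |> File.read!()
  |> WAV.decode!()

%{chunks: chunks} = Nx.Serving.run(serving, my_audio)
chunks |> Enum.map_join(& &1.text) |> String.trim()
```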

Transcribing a live recording

Kino.Control.form/2 wraps a Kino.Input.audio/2 widget with a submit button. Clicking Transcribe reads the recorded PCM f32 buffer, downmixes to mono if the browser captured stereo, and pushes the tensor through the same serving. Kino.Frame.render/2 swaps the output in place each time.

audio_input = Kino.Input.audio("Audio", format: :pcm_f32, sampling_rate: 16_000)

form =
  Kino.Control.form([audio: audio_input], submit: "Transcribe", reset_on_submit: [:audio])

frame = Kino.Frame.new()

Kino.render(form)
Kino.render(frame)

Kino.listen(form, fn
  %{type: :submit, data: %{audio: nil}} ->
    Kino.Frame.render(
      frame,
      Kino.Markdown.new("_No audio recorded — record a clip above and click Transcribe._")
    )

  %{type: :submit, data: %{audio: %{file_ref: ref, num_channels: channels}}} ->
    samples =
      ref
      |> Kino.Input.file_path()
      |> File.read!()
      |> Nx.from_binary(:f32)

    audio =
      if channels > 1 do
        samples |> Nx.reshape({:auto, channels}) |> Nx.mean(axes: [1])
      else
        samples
      end

    %{chunks: chunks} = Nx.Serving.run(serving, audio)
    text = chunks |> Enum.map_join(& &1.text) |> String.trim()
    Kino.Frame.render(frame, Kino.Markdown.new("**Transcript:** #{text}"))
end)

Whisper's context window is 30 seconds. For longer recordings, pass chunk_num_seconds: 30 when building the serving — Bumblebee will split the audio, transcribe each chunk with overlap, and stitch the outputs back together. See the Bumblebee.Audio moduledoc for the long-form example.
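The long-form variant is just the same serving built with one extra option, as a configuration sketch:

```elixir
# Long-form serving: Bumblebee splits audio into 30 s chunks with overlap
# and stitches the per-chunk transcripts back together.
long_form_serving =
  Bumblebee.Audio.speech_to_text_whisper(whisper, featurizer, tokenizer, generation_config,
    defn_options: [compiler: Emily.Compiler],
    chunk_num_seconds: 30
  )
```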

Telemetry

Whisper's autoregressive decoder reads tensors back host-side constantly — every sampled token, every early-stopping check, every caching step. Each readback fires [:emily, :to_binary, :stop]. Printing every span would drown the notebook, so aggregate: count the readbacks, sum their durations, and print a single summary.

{:ok, stats} = Agent.start_link(fn -> %{count: 0, total_native: 0, bytes: 0} end)

:telemetry.attach(
  "whisper-to-binary",
  [:emily, :to_binary, :stop],
  fn _event, %{duration: duration}, %{byte_size: bytes}, _config ->
    # telemetry durations arrive in native time units, not nanoseconds
    Agent.update(stats, fn s ->
      %{count: s.count + 1, total_native: s.total_native + duration, bytes: s.bytes + bytes}
    end)
  end,
  nil
)

Nx.Serving.run(serving, jfk_audio)

summary = Agent.get(stats, & &1)
total_ms = System.convert_time_unit(summary.total_native, :native, :millisecond)

IO.puts(
  "to_binary: #{summary.count} spans, " <>
    "#{Float.round(summary.bytes / 1_048_576, 2)} MiB total, " <>
    "#{total_ms} ms cumulative"
)

The count is high — Whisper's decoder touches many small tensors (per-step cache slices, EOS probes, timestamp checks), so span counts in the tens of thousands for ten seconds of audio are normal. What matters is the cumulative readback time vs. end-to-end latency: if they're close, the decoder is readback-bound and a batched or async-aware driver would help.
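The readback-bound check can be made concrete. A sketch: ReadbackCheck is a hypothetical helper, not part of Emily; in the notebook you would feed it the cumulative readback milliseconds from the summary above and a wall-clock measurement from :timer.tc/1.

```elixir
defmodule ReadbackCheck do
  @doc "Fraction of end-to-end latency spent in host readbacks."
  def fraction(readback_ms, elapsed_ms) when elapsed_ms > 0,
    do: readback_ms / elapsed_ms
end

# In the notebook, time one run (microseconds) and compare against `total_ms`:
#
#   {elapsed_us, _} = :timer.tc(fn -> Nx.Serving.run(serving, jfk_audio) end)
#   ReadbackCheck.fraction(total_ms, div(elapsed_us, 1000))

# With illustrative numbers: 120 ms of readbacks out of 480 ms end to end.
ReadbackCheck.fraction(120, 480)
# => 0.25
```

A fraction approaching 1.0 means almost all latency is host readback, which is the case where a batched or async-aware driver would pay off.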

See Emily.Telemetry for the full event catalogue, including [:emily, :fallback, *] for spotting any op that routes through Nx.BinaryBackend, [:emily, :eval, *] for explicit lazy-graph flushes via Emily.eval/1, and [:emily, :memory, :stats] for MLX allocator polling.
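Fallbacks can be counted with the same Agent pattern as the readback summary. A sketch: the handler below assumes the fallback metadata carries an :op key naming the operation (check Emily.Telemetry for the actual shape); the Agent bookkeeping is exercised directly at the end so the mechanics are visible without running the model.

```elixir
# Handler fun: counts fallback ops by name. The `:op` metadata key is an
# assumption — consult Emily.Telemetry for the real metadata shape.
handler = fn _event, _measurements, meta, agent ->
  Agent.update(agent, fn counts -> Map.update(counts, meta[:op], 1, &(&1 + 1)) end)
end

{:ok, fallbacks} = Agent.start_link(fn -> %{} end)

# In the notebook, attach it and run the serving:
#
#   :telemetry.attach("whisper-fallbacks", [:emily, :fallback, :stop], handler, fallbacks)
#   Nx.Serving.run(serving, jfk_audio)

# Exercise the bookkeeping directly with two synthetic events:
handler.([:emily, :fallback, :stop], %{}, %{op: :fft}, fallbacks)
handler.([:emily, :fallback, :stop], %{}, %{op: :fft}, fallbacks)
Agent.get(fallbacks, & &1)
# => %{fft: 2}
```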