Upload 9 files
Browse files- __init__.py +0 -0
- audio.py +119 -0
- feature_extractor.py +170 -0
- silero_vad.onnx +3 -0
- tokenizer.py +278 -0
- transcribe.py +1272 -0
- utils.py +157 -0
- vad.py +291 -0
- version.py +3 -0
__init__.py
ADDED
|
File without changes
|
audio.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""We use the PyAV library to decode the audio: https://github.com/PyAV-Org/PyAV
|
| 2 |
+
|
| 3 |
+
The advantage of PyAV is that it bundles the FFmpeg libraries so there is no additional
|
| 4 |
+
system dependencies. FFmpeg does not need to be installed on the system.
|
| 5 |
+
|
| 6 |
+
However, the API is quite low-level so we need to manipulate audio frames directly.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import gc
|
| 10 |
+
import io
|
| 11 |
+
import itertools
|
| 12 |
+
|
| 13 |
+
from typing import BinaryIO, Union
|
| 14 |
+
|
| 15 |
+
import av
|
| 16 |
+
import numpy as np
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def decode_audio(
|
| 20 |
+
input_file: Union[str, BinaryIO],
|
| 21 |
+
sampling_rate: int = 16000,
|
| 22 |
+
split_stereo: bool = False,
|
| 23 |
+
):
|
| 24 |
+
"""Decodes the audio.
|
| 25 |
+
|
| 26 |
+
Args:
|
| 27 |
+
input_file: Path to the input file or a file-like object.
|
| 28 |
+
sampling_rate: Resample the audio to this sample rate.
|
| 29 |
+
split_stereo: Return separate left and right channels.
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
A float32 Numpy array.
|
| 33 |
+
|
| 34 |
+
If `split_stereo` is enabled, the function returns a 2-tuple with the
|
| 35 |
+
separated left and right channels.
|
| 36 |
+
"""
|
| 37 |
+
resampler = av.audio.resampler.AudioResampler(
|
| 38 |
+
format="s16",
|
| 39 |
+
layout="mono" if not split_stereo else "stereo",
|
| 40 |
+
rate=sampling_rate,
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
raw_buffer = io.BytesIO()
|
| 44 |
+
dtype = None
|
| 45 |
+
|
| 46 |
+
with av.open(input_file, mode="r", metadata_errors="ignore") as container:
|
| 47 |
+
frames = container.decode(audio=0)
|
| 48 |
+
frames = _ignore_invalid_frames(frames)
|
| 49 |
+
frames = _group_frames(frames, 500000)
|
| 50 |
+
frames = _resample_frames(frames, resampler)
|
| 51 |
+
|
| 52 |
+
for frame in frames:
|
| 53 |
+
array = frame.to_ndarray()
|
| 54 |
+
dtype = array.dtype
|
| 55 |
+
raw_buffer.write(array)
|
| 56 |
+
|
| 57 |
+
# It appears that some objects related to the resampler are not freed
|
| 58 |
+
# unless the garbage collector is manually run.
|
| 59 |
+
del resampler
|
| 60 |
+
gc.collect()
|
| 61 |
+
|
| 62 |
+
audio = np.frombuffer(raw_buffer.getbuffer(), dtype=dtype)
|
| 63 |
+
|
| 64 |
+
# Convert s16 back to f32.
|
| 65 |
+
audio = audio.astype(np.float32) / 32768.0
|
| 66 |
+
|
| 67 |
+
if split_stereo:
|
| 68 |
+
left_channel = audio[0::2]
|
| 69 |
+
right_channel = audio[1::2]
|
| 70 |
+
return left_channel, right_channel
|
| 71 |
+
|
| 72 |
+
return audio
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _ignore_invalid_frames(frames):
|
| 76 |
+
iterator = iter(frames)
|
| 77 |
+
|
| 78 |
+
while True:
|
| 79 |
+
try:
|
| 80 |
+
yield next(iterator)
|
| 81 |
+
except StopIteration:
|
| 82 |
+
break
|
| 83 |
+
except av.error.InvalidDataError:
|
| 84 |
+
continue
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def _group_frames(frames, num_samples=None):
|
| 88 |
+
fifo = av.audio.fifo.AudioFifo()
|
| 89 |
+
|
| 90 |
+
for frame in frames:
|
| 91 |
+
frame.pts = None # Ignore timestamp check.
|
| 92 |
+
fifo.write(frame)
|
| 93 |
+
|
| 94 |
+
if num_samples is not None and fifo.samples >= num_samples:
|
| 95 |
+
yield fifo.read()
|
| 96 |
+
|
| 97 |
+
if fifo.samples > 0:
|
| 98 |
+
yield fifo.read()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _resample_frames(frames, resampler):
|
| 102 |
+
# Add None to flush the resampler.
|
| 103 |
+
for frame in itertools.chain(frames, [None]):
|
| 104 |
+
yield from resampler.resample(frame)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def pad_or_trim(array, length: int, *, axis: int = -1):
|
| 108 |
+
"""
|
| 109 |
+
Pad or trim the audio array to N_SAMPLES, as expected by the encoder.
|
| 110 |
+
"""
|
| 111 |
+
if array.shape[axis] > length:
|
| 112 |
+
array = array.take(indices=range(length), axis=axis)
|
| 113 |
+
|
| 114 |
+
if array.shape[axis] < length:
|
| 115 |
+
pad_widths = [(0, 0)] * array.ndim
|
| 116 |
+
pad_widths[axis] = (0, length - array.shape[axis])
|
| 117 |
+
array = np.pad(array, pad_widths)
|
| 118 |
+
|
| 119 |
+
return array
|
feature_extractor.py
ADDED
|
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# Adapted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/whisper/feature_extraction_whisper.py # noqa: E501
|
| 5 |
+
class FeatureExtractor:
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
feature_size=80,
|
| 9 |
+
sampling_rate=16000,
|
| 10 |
+
hop_length=160,
|
| 11 |
+
chunk_length=30,
|
| 12 |
+
n_fft=400,
|
| 13 |
+
):
|
| 14 |
+
self.n_fft = n_fft
|
| 15 |
+
self.hop_length = hop_length
|
| 16 |
+
self.chunk_length = chunk_length
|
| 17 |
+
self.n_samples = chunk_length * sampling_rate
|
| 18 |
+
self.nb_max_frames = self.n_samples // hop_length
|
| 19 |
+
self.time_per_frame = hop_length / sampling_rate
|
| 20 |
+
self.sampling_rate = sampling_rate
|
| 21 |
+
self.mel_filters = self.get_mel_filters(
|
| 22 |
+
sampling_rate, n_fft, n_mels=feature_size
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
def get_mel_filters(self, sr, n_fft, n_mels=128, dtype=np.float32):
|
| 26 |
+
# Initialize the weights
|
| 27 |
+
n_mels = int(n_mels)
|
| 28 |
+
weights = np.zeros((n_mels, int(1 + n_fft // 2)), dtype=dtype)
|
| 29 |
+
|
| 30 |
+
# Center freqs of each FFT bin
|
| 31 |
+
fftfreqs = np.fft.rfftfreq(n=n_fft, d=1.0 / sr)
|
| 32 |
+
|
| 33 |
+
# 'Center freqs' of mel bands - uniformly spaced between limits
|
| 34 |
+
min_mel = 0.0
|
| 35 |
+
max_mel = 45.245640471924965
|
| 36 |
+
|
| 37 |
+
mels = np.linspace(min_mel, max_mel, n_mels + 2)
|
| 38 |
+
|
| 39 |
+
mels = np.asanyarray(mels)
|
| 40 |
+
|
| 41 |
+
# Fill in the linear scale
|
| 42 |
+
f_min = 0.0
|
| 43 |
+
f_sp = 200.0 / 3
|
| 44 |
+
freqs = f_min + f_sp * mels
|
| 45 |
+
|
| 46 |
+
# And now the nonlinear scale
|
| 47 |
+
min_log_hz = 1000.0 # beginning of log region (Hz)
|
| 48 |
+
min_log_mel = (min_log_hz - f_min) / f_sp # same (Mels)
|
| 49 |
+
logstep = np.log(6.4) / 27.0 # step size for log region
|
| 50 |
+
|
| 51 |
+
# If we have vector data, vectorize
|
| 52 |
+
log_t = mels >= min_log_mel
|
| 53 |
+
freqs[log_t] = min_log_hz * np.exp(logstep * (mels[log_t] - min_log_mel))
|
| 54 |
+
|
| 55 |
+
mel_f = freqs
|
| 56 |
+
|
| 57 |
+
fdiff = np.diff(mel_f)
|
| 58 |
+
ramps = np.subtract.outer(mel_f, fftfreqs)
|
| 59 |
+
|
| 60 |
+
for i in range(n_mels):
|
| 61 |
+
# lower and upper slopes for all bins
|
| 62 |
+
lower = -ramps[i] / fdiff[i]
|
| 63 |
+
upper = ramps[i + 2] / fdiff[i + 1]
|
| 64 |
+
|
| 65 |
+
# .. then intersect them with each other and zero
|
| 66 |
+
weights[i] = np.maximum(0, np.minimum(lower, upper))
|
| 67 |
+
|
| 68 |
+
# Slaney-style mel is scaled to be approx constant energy per channel
|
| 69 |
+
enorm = 2.0 / (mel_f[2 : n_mels + 2] - mel_f[:n_mels])
|
| 70 |
+
weights *= enorm[:, np.newaxis]
|
| 71 |
+
|
| 72 |
+
return weights
|
| 73 |
+
|
| 74 |
+
def fram_wave(self, waveform, center=True):
|
| 75 |
+
"""
|
| 76 |
+
Transform a raw waveform into a list of smaller waveforms.
|
| 77 |
+
The window length defines how much of the signal is
|
| 78 |
+
contain in each frame (smalle waveform), while the hope length defines the step
|
| 79 |
+
between the beginning of each new frame.
|
| 80 |
+
Centering is done by reflecting the waveform which is first centered around
|
| 81 |
+
`frame_idx * hop_length`.
|
| 82 |
+
"""
|
| 83 |
+
frames = []
|
| 84 |
+
for i in range(0, waveform.shape[0] + 1, self.hop_length):
|
| 85 |
+
half_window = (self.n_fft - 1) // 2 + 1
|
| 86 |
+
if center:
|
| 87 |
+
start = i - half_window if i > half_window else 0
|
| 88 |
+
end = (
|
| 89 |
+
i + half_window
|
| 90 |
+
if i < waveform.shape[0] - half_window
|
| 91 |
+
else waveform.shape[0]
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
frame = waveform[start:end]
|
| 95 |
+
|
| 96 |
+
if start == 0:
|
| 97 |
+
padd_width = (-i + half_window, 0)
|
| 98 |
+
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
| 99 |
+
|
| 100 |
+
elif end == waveform.shape[0]:
|
| 101 |
+
padd_width = (0, (i - waveform.shape[0] + half_window))
|
| 102 |
+
frame = np.pad(frame, pad_width=padd_width, mode="reflect")
|
| 103 |
+
|
| 104 |
+
else:
|
| 105 |
+
frame = waveform[i : i + self.n_fft]
|
| 106 |
+
frame_width = frame.shape[0]
|
| 107 |
+
if frame_width < waveform.shape[0]:
|
| 108 |
+
frame = np.lib.pad(
|
| 109 |
+
frame,
|
| 110 |
+
pad_width=(0, self.n_fft - frame_width),
|
| 111 |
+
mode="constant",
|
| 112 |
+
constant_values=0,
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
frames.append(frame)
|
| 116 |
+
return np.stack(frames, 0)
|
| 117 |
+
|
| 118 |
+
def stft(self, frames, window):
|
| 119 |
+
"""
|
| 120 |
+
Calculates the complex Short-Time Fourier Transform (STFT) of the given framed signal.
|
| 121 |
+
Should give the same results as `torch.stft`.
|
| 122 |
+
"""
|
| 123 |
+
frame_size = frames.shape[1]
|
| 124 |
+
fft_size = self.n_fft
|
| 125 |
+
|
| 126 |
+
if fft_size is None:
|
| 127 |
+
fft_size = frame_size
|
| 128 |
+
|
| 129 |
+
if fft_size < frame_size:
|
| 130 |
+
raise ValueError("FFT size must greater or equal the frame size")
|
| 131 |
+
# number of FFT bins to store
|
| 132 |
+
num_fft_bins = (fft_size >> 1) + 1
|
| 133 |
+
|
| 134 |
+
data = np.empty((len(frames), num_fft_bins), dtype=np.complex64)
|
| 135 |
+
fft_signal = np.zeros(fft_size)
|
| 136 |
+
|
| 137 |
+
for f, frame in enumerate(frames):
|
| 138 |
+
if window is not None:
|
| 139 |
+
np.multiply(frame, window, out=fft_signal[:frame_size])
|
| 140 |
+
else:
|
| 141 |
+
fft_signal[:frame_size] = frame
|
| 142 |
+
data[f] = np.fft.fft(fft_signal, axis=0)[:num_fft_bins]
|
| 143 |
+
return data.T
|
| 144 |
+
|
| 145 |
+
def __call__(self, waveform, padding=True, chunk_length=None):
|
| 146 |
+
"""
|
| 147 |
+
Compute the log-Mel spectrogram of the provided audio, gives similar results
|
| 148 |
+
whisper's original torch implementation with 1e-5 tolerance.
|
| 149 |
+
"""
|
| 150 |
+
if chunk_length is not None:
|
| 151 |
+
self.n_samples = chunk_length * self.sampling_rate
|
| 152 |
+
self.nb_max_frames = self.n_samples // self.hop_length
|
| 153 |
+
|
| 154 |
+
if padding:
|
| 155 |
+
waveform = np.pad(waveform, [(0, self.n_samples)])
|
| 156 |
+
|
| 157 |
+
window = np.hanning(self.n_fft + 1)[:-1]
|
| 158 |
+
|
| 159 |
+
frames = self.fram_wave(waveform)
|
| 160 |
+
stft = self.stft(frames, window=window)
|
| 161 |
+
magnitudes = np.abs(stft[:, :-1]) ** 2
|
| 162 |
+
|
| 163 |
+
filters = self.mel_filters
|
| 164 |
+
mel_spec = filters @ magnitudes
|
| 165 |
+
|
| 166 |
+
log_spec = np.log10(np.clip(mel_spec, a_min=1e-10, a_max=None))
|
| 167 |
+
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
| 168 |
+
log_spec = (log_spec + 4.0) / 4.0
|
| 169 |
+
|
| 170 |
+
return log_spec
|
silero_vad.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:591f853590d11ddde2f2a54f9e7ccecb2533a8af7716330e8adfa6f3849787a9
|
| 3 |
+
size 1807524
|
tokenizer.py
ADDED
|
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import string
|
| 2 |
+
|
| 3 |
+
from functools import cached_property
|
| 4 |
+
from typing import List, Optional, Tuple
|
| 5 |
+
|
| 6 |
+
import tokenizers
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class Tokenizer:
|
| 10 |
+
"""Simple wrapper around a tokenizers.Tokenizer."""
|
| 11 |
+
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
tokenizer: tokenizers.Tokenizer,
|
| 15 |
+
multilingual: bool,
|
| 16 |
+
task: Optional[str] = None,
|
| 17 |
+
language: Optional[str] = None,
|
| 18 |
+
):
|
| 19 |
+
self.tokenizer = tokenizer
|
| 20 |
+
|
| 21 |
+
if multilingual:
|
| 22 |
+
if task not in _TASKS:
|
| 23 |
+
raise ValueError(
|
| 24 |
+
"'%s' is not a valid task (accepted tasks: %s)"
|
| 25 |
+
% (task, ", ".join(_TASKS))
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
if language not in _LANGUAGE_CODES:
|
| 29 |
+
raise ValueError(
|
| 30 |
+
"'%s' is not a valid language code (accepted language codes: %s)"
|
| 31 |
+
% (language, ", ".join(_LANGUAGE_CODES))
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
self.task = self.tokenizer.token_to_id("<|%s|>" % task)
|
| 35 |
+
self.language = self.tokenizer.token_to_id("<|%s|>" % language)
|
| 36 |
+
self.language_code = language
|
| 37 |
+
else:
|
| 38 |
+
self.task = None
|
| 39 |
+
self.language = None
|
| 40 |
+
self.language_code = "en"
|
| 41 |
+
|
| 42 |
+
@cached_property
|
| 43 |
+
def transcribe(self) -> int:
|
| 44 |
+
return self.tokenizer.token_to_id("<|transcribe|>")
|
| 45 |
+
|
| 46 |
+
@cached_property
|
| 47 |
+
def translate(self) -> int:
|
| 48 |
+
return self.tokenizer.token_to_id("<|translate|>")
|
| 49 |
+
|
| 50 |
+
@cached_property
|
| 51 |
+
def sot(self) -> int:
|
| 52 |
+
return self.tokenizer.token_to_id("<|startoftranscript|>")
|
| 53 |
+
|
| 54 |
+
@cached_property
|
| 55 |
+
def sot_lm(self) -> int:
|
| 56 |
+
return self.tokenizer.token_to_id("<|startoflm|>")
|
| 57 |
+
|
| 58 |
+
@cached_property
|
| 59 |
+
def sot_prev(self) -> int:
|
| 60 |
+
return self.tokenizer.token_to_id("<|startofprev|>")
|
| 61 |
+
|
| 62 |
+
@cached_property
|
| 63 |
+
def eot(self) -> int:
|
| 64 |
+
return self.tokenizer.token_to_id("<|endoftext|>")
|
| 65 |
+
|
| 66 |
+
@cached_property
|
| 67 |
+
def no_timestamps(self) -> int:
|
| 68 |
+
return self.tokenizer.token_to_id("<|notimestamps|>")
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def timestamp_begin(self) -> int:
|
| 72 |
+
return self.no_timestamps + 1
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def sot_sequence(self) -> List[int]:
|
| 76 |
+
sequence = [self.sot]
|
| 77 |
+
|
| 78 |
+
if self.language is not None:
|
| 79 |
+
sequence.append(self.language)
|
| 80 |
+
|
| 81 |
+
if self.task is not None:
|
| 82 |
+
sequence.append(self.task)
|
| 83 |
+
|
| 84 |
+
return sequence
|
| 85 |
+
|
| 86 |
+
def encode(self, text: str) -> List[int]:
|
| 87 |
+
return self.tokenizer.encode(text, add_special_tokens=False).ids
|
| 88 |
+
|
| 89 |
+
def decode(self, tokens: List[int]) -> str:
|
| 90 |
+
text_tokens = [token for token in tokens if token < self.eot]
|
| 91 |
+
return self.tokenizer.decode(text_tokens)
|
| 92 |
+
|
| 93 |
+
def decode_with_timestamps(self, tokens: List[int]) -> str:
|
| 94 |
+
outputs = [[]]
|
| 95 |
+
|
| 96 |
+
for token in tokens:
|
| 97 |
+
if token >= self.timestamp_begin:
|
| 98 |
+
timestamp = f"<|{(token - self.timestamp_begin) * 0.02:.2f}|>"
|
| 99 |
+
outputs.append(timestamp)
|
| 100 |
+
outputs.append([])
|
| 101 |
+
else:
|
| 102 |
+
outputs[-1].append(token)
|
| 103 |
+
|
| 104 |
+
return "".join(
|
| 105 |
+
[s if isinstance(s, str) else self.tokenizer.decode(s) for s in outputs]
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
def split_to_word_tokens(
|
| 109 |
+
self, tokens: List[int]
|
| 110 |
+
) -> Tuple[List[str], List[List[int]]]:
|
| 111 |
+
if self.language_code in {"zh", "ja", "th", "lo", "my", "yue"}:
|
| 112 |
+
# These languages don't typically use spaces, so it is difficult to split words
|
| 113 |
+
# without morpheme analysis. Here, we instead split words at any
|
| 114 |
+
# position where the tokens are decoded as valid unicode points
|
| 115 |
+
return self.split_tokens_on_unicode(tokens)
|
| 116 |
+
|
| 117 |
+
return self.split_tokens_on_spaces(tokens)
|
| 118 |
+
|
| 119 |
+
def split_tokens_on_unicode(
|
| 120 |
+
self, tokens: List[int]
|
| 121 |
+
) -> Tuple[List[str], List[List[int]]]:
|
| 122 |
+
decoded_full = self.decode_with_timestamps(tokens)
|
| 123 |
+
replacement_char = "\ufffd"
|
| 124 |
+
|
| 125 |
+
words = []
|
| 126 |
+
word_tokens = []
|
| 127 |
+
current_tokens = []
|
| 128 |
+
unicode_offset = 0
|
| 129 |
+
|
| 130 |
+
for token in tokens:
|
| 131 |
+
current_tokens.append(token)
|
| 132 |
+
decoded = self.decode_with_timestamps(current_tokens)
|
| 133 |
+
|
| 134 |
+
try:
|
| 135 |
+
replacement_char_index = decoded.index(replacement_char)
|
| 136 |
+
replacement_char_index += unicode_offset
|
| 137 |
+
except ValueError:
|
| 138 |
+
replacement_char_index = None
|
| 139 |
+
|
| 140 |
+
if replacement_char_index is None or (
|
| 141 |
+
replacement_char_index < len(decoded_full)
|
| 142 |
+
and decoded_full[replacement_char_index] == replacement_char
|
| 143 |
+
):
|
| 144 |
+
words.append(decoded)
|
| 145 |
+
word_tokens.append(current_tokens)
|
| 146 |
+
current_tokens = []
|
| 147 |
+
unicode_offset += len(decoded)
|
| 148 |
+
|
| 149 |
+
return words, word_tokens
|
| 150 |
+
|
| 151 |
+
def split_tokens_on_spaces(
|
| 152 |
+
self, tokens: List[int]
|
| 153 |
+
) -> Tuple[List[str], List[List[int]]]:
|
| 154 |
+
subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
|
| 155 |
+
words = []
|
| 156 |
+
word_tokens = []
|
| 157 |
+
|
| 158 |
+
for subword, subword_tokens in zip(subwords, subword_tokens_list):
|
| 159 |
+
special = subword_tokens[0] >= self.eot
|
| 160 |
+
with_space = subword.startswith(" ")
|
| 161 |
+
punctuation = subword.strip() in string.punctuation
|
| 162 |
+
if special or with_space or punctuation or len(words) == 0:
|
| 163 |
+
words.append(subword)
|
| 164 |
+
word_tokens.append(subword_tokens)
|
| 165 |
+
else:
|
| 166 |
+
words[-1] = words[-1] + subword
|
| 167 |
+
word_tokens[-1].extend(subword_tokens)
|
| 168 |
+
|
| 169 |
+
return words, word_tokens
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
_TASKS = (
|
| 173 |
+
"transcribe",
|
| 174 |
+
"translate",
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
_LANGUAGE_CODES = (
|
| 178 |
+
"af",
|
| 179 |
+
"am",
|
| 180 |
+
"ar",
|
| 181 |
+
"as",
|
| 182 |
+
"az",
|
| 183 |
+
"ba",
|
| 184 |
+
"be",
|
| 185 |
+
"bg",
|
| 186 |
+
"bn",
|
| 187 |
+
"bo",
|
| 188 |
+
"br",
|
| 189 |
+
"bs",
|
| 190 |
+
"ca",
|
| 191 |
+
"cs",
|
| 192 |
+
"cy",
|
| 193 |
+
"da",
|
| 194 |
+
"de",
|
| 195 |
+
"el",
|
| 196 |
+
"en",
|
| 197 |
+
"es",
|
| 198 |
+
"et",
|
| 199 |
+
"eu",
|
| 200 |
+
"fa",
|
| 201 |
+
"fi",
|
| 202 |
+
"fo",
|
| 203 |
+
"fr",
|
| 204 |
+
"gl",
|
| 205 |
+
"gu",
|
| 206 |
+
"ha",
|
| 207 |
+
"haw",
|
| 208 |
+
"he",
|
| 209 |
+
"hi",
|
| 210 |
+
"hr",
|
| 211 |
+
"ht",
|
| 212 |
+
"hu",
|
| 213 |
+
"hy",
|
| 214 |
+
"id",
|
| 215 |
+
"is",
|
| 216 |
+
"it",
|
| 217 |
+
"ja",
|
| 218 |
+
"jw",
|
| 219 |
+
"ka",
|
| 220 |
+
"kk",
|
| 221 |
+
"km",
|
| 222 |
+
"kn",
|
| 223 |
+
"ko",
|
| 224 |
+
"la",
|
| 225 |
+
"lb",
|
| 226 |
+
"ln",
|
| 227 |
+
"lo",
|
| 228 |
+
"lt",
|
| 229 |
+
"lv",
|
| 230 |
+
"mg",
|
| 231 |
+
"mi",
|
| 232 |
+
"mk",
|
| 233 |
+
"ml",
|
| 234 |
+
"mn",
|
| 235 |
+
"mr",
|
| 236 |
+
"ms",
|
| 237 |
+
"mt",
|
| 238 |
+
"my",
|
| 239 |
+
"ne",
|
| 240 |
+
"nl",
|
| 241 |
+
"nn",
|
| 242 |
+
"no",
|
| 243 |
+
"oc",
|
| 244 |
+
"pa",
|
| 245 |
+
"pl",
|
| 246 |
+
"ps",
|
| 247 |
+
"pt",
|
| 248 |
+
"ro",
|
| 249 |
+
"ru",
|
| 250 |
+
"sa",
|
| 251 |
+
"sd",
|
| 252 |
+
"si",
|
| 253 |
+
"sk",
|
| 254 |
+
"sl",
|
| 255 |
+
"sn",
|
| 256 |
+
"so",
|
| 257 |
+
"sq",
|
| 258 |
+
"sr",
|
| 259 |
+
"su",
|
| 260 |
+
"sv",
|
| 261 |
+
"sw",
|
| 262 |
+
"ta",
|
| 263 |
+
"te",
|
| 264 |
+
"tg",
|
| 265 |
+
"th",
|
| 266 |
+
"tk",
|
| 267 |
+
"tl",
|
| 268 |
+
"tr",
|
| 269 |
+
"tt",
|
| 270 |
+
"uk",
|
| 271 |
+
"ur",
|
| 272 |
+
"uz",
|
| 273 |
+
"vi",
|
| 274 |
+
"yi",
|
| 275 |
+
"yo",
|
| 276 |
+
"zh",
|
| 277 |
+
"yue",
|
| 278 |
+
)
|
transcribe.py
ADDED
|
@@ -0,0 +1,1272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import zlib
|
| 6 |
+
|
| 7 |
+
from inspect import signature
|
| 8 |
+
from typing import BinaryIO, Iterable, List, NamedTuple, Optional, Tuple, Union
|
| 9 |
+
|
| 10 |
+
import ctranslate2
|
| 11 |
+
import numpy as np
|
| 12 |
+
import tokenizers
|
| 13 |
+
|
| 14 |
+
from faster_whisper.audio import decode_audio, pad_or_trim
|
| 15 |
+
from faster_whisper.feature_extractor import FeatureExtractor
|
| 16 |
+
from faster_whisper.tokenizer import _LANGUAGE_CODES, Tokenizer
|
| 17 |
+
from faster_whisper.utils import download_model, format_timestamp, get_end, get_logger
|
| 18 |
+
from faster_whisper.vad import (
|
| 19 |
+
SpeechTimestampsMap,
|
| 20 |
+
VadOptions,
|
| 21 |
+
collect_chunks,
|
| 22 |
+
get_speech_timestamps,
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Word(NamedTuple):
|
| 27 |
+
start: float
|
| 28 |
+
end: float
|
| 29 |
+
word: str
|
| 30 |
+
probability: float
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class Segment(NamedTuple):
|
| 34 |
+
id: int
|
| 35 |
+
seek: int
|
| 36 |
+
start: float
|
| 37 |
+
end: float
|
| 38 |
+
text: str
|
| 39 |
+
tokens: List[int]
|
| 40 |
+
temperature: float
|
| 41 |
+
avg_logprob: float
|
| 42 |
+
compression_ratio: float
|
| 43 |
+
no_speech_prob: float
|
| 44 |
+
words: Optional[List[Word]]
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class TranscriptionOptions(NamedTuple):
|
| 48 |
+
beam_size: int
|
| 49 |
+
best_of: int
|
| 50 |
+
patience: float
|
| 51 |
+
length_penalty: float
|
| 52 |
+
repetition_penalty: float
|
| 53 |
+
no_repeat_ngram_size: int
|
| 54 |
+
log_prob_threshold: Optional[float]
|
| 55 |
+
no_speech_threshold: Optional[float]
|
| 56 |
+
compression_ratio_threshold: Optional[float]
|
| 57 |
+
condition_on_previous_text: bool
|
| 58 |
+
prompt_reset_on_temperature: float
|
| 59 |
+
temperatures: List[float]
|
| 60 |
+
initial_prompt: Optional[Union[str, Iterable[int]]]
|
| 61 |
+
prefix: Optional[str]
|
| 62 |
+
suppress_blank: bool
|
| 63 |
+
suppress_tokens: Optional[List[int]]
|
| 64 |
+
without_timestamps: bool
|
| 65 |
+
max_initial_timestamp: float
|
| 66 |
+
word_timestamps: bool
|
| 67 |
+
prepend_punctuations: str
|
| 68 |
+
append_punctuations: str
|
| 69 |
+
max_new_tokens: Optional[int]
|
| 70 |
+
clip_timestamps: Union[str, List[float]]
|
| 71 |
+
hallucination_silence_threshold: Optional[float]
|
| 72 |
+
hotwords: Optional[str]
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
class TranscriptionInfo(NamedTuple):
|
| 76 |
+
language: str
|
| 77 |
+
language_probability: float
|
| 78 |
+
duration: float
|
| 79 |
+
duration_after_vad: float
|
| 80 |
+
all_language_probs: Optional[List[Tuple[str, float]]]
|
| 81 |
+
transcription_options: TranscriptionOptions
|
| 82 |
+
vad_options: VadOptions
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class WhisperModel:
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
model_size_or_path: str,
|
| 89 |
+
device: str = "auto",
|
| 90 |
+
device_index: Union[int, List[int]] = 0,
|
| 91 |
+
compute_type: str = "default",
|
| 92 |
+
cpu_threads: int = 0,
|
| 93 |
+
num_workers: int = 1,
|
| 94 |
+
download_root: Optional[str] = None,
|
| 95 |
+
local_files_only: bool = False,
|
| 96 |
+
files: dict = None,
|
| 97 |
+
**model_kwargs,
|
| 98 |
+
):
|
| 99 |
+
"""Initializes the Whisper model.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
model_size_or_path: Size of the model to use (tiny, tiny.en, base, base.en,
|
| 103 |
+
small, small.en, distil-small.en, medium, medium.en, distil-medium.en, large-v1,
|
| 104 |
+
large-v2, large-v3, large, distil-large-v2 or distil-large-v3), a path to a
|
| 105 |
+
converted model directory, or a CTranslate2-converted Whisper model ID from the HF Hub.
|
| 106 |
+
When a size or a model ID is configured, the converted model is downloaded
|
| 107 |
+
from the Hugging Face Hub.
|
| 108 |
+
device: Device to use for computation ("cpu", "cuda", "auto").
|
| 109 |
+
device_index: Device ID to use.
|
| 110 |
+
The model can also be loaded on multiple GPUs by passing a list of IDs
|
| 111 |
+
(e.g. [0, 1, 2, 3]). In that case, multiple transcriptions can run in parallel
|
| 112 |
+
when transcribe() is called from multiple Python threads (see also num_workers).
|
| 113 |
+
compute_type: Type to use for computation.
|
| 114 |
+
See https://opennmt.net/CTranslate2/quantization.html.
|
| 115 |
+
cpu_threads: Number of threads to use when running on CPU (4 by default).
|
| 116 |
+
A non zero value overrides the OMP_NUM_THREADS environment variable.
|
| 117 |
+
num_workers: When transcribe() is called from multiple Python threads,
|
| 118 |
+
having multiple workers enables true parallelism when running the model
|
| 119 |
+
(concurrent calls to self.model.generate() will run in parallel).
|
| 120 |
+
This can improve the global throughput at the cost of increased memory usage.
|
| 121 |
+
download_root: Directory where the models should be saved. If not set, the models
|
| 122 |
+
are saved in the standard Hugging Face cache directory.
|
| 123 |
+
local_files_only: If True, avoid downloading the file and return the path to the
|
| 124 |
+
local cached file if it exists.
|
| 125 |
+
files: Load model files from the memory. This argument is a dictionary mapping file names
|
| 126 |
+
to file contents as file-like or bytes objects. If this is set, model_path acts as an
|
| 127 |
+
identifier for this model.
|
| 128 |
+
"""
|
| 129 |
+
self.logger = get_logger()
|
| 130 |
+
|
| 131 |
+
tokenizer_bytes, preprocessor_bytes = None, None
|
| 132 |
+
if files:
|
| 133 |
+
model_path = model_size_or_path
|
| 134 |
+
tokenizer_bytes = files.pop("tokenizer.json", None)
|
| 135 |
+
preprocessor_bytes = files.pop("preprocessor_config.json", None)
|
| 136 |
+
elif os.path.isdir(model_size_or_path):
|
| 137 |
+
model_path = model_size_or_path
|
| 138 |
+
else:
|
| 139 |
+
model_path = download_model(
|
| 140 |
+
model_size_or_path,
|
| 141 |
+
local_files_only=local_files_only,
|
| 142 |
+
cache_dir=download_root,
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
self.model = ctranslate2.models.Whisper(
|
| 146 |
+
model_path,
|
| 147 |
+
device=device,
|
| 148 |
+
device_index=device_index,
|
| 149 |
+
compute_type=compute_type,
|
| 150 |
+
intra_threads=cpu_threads,
|
| 151 |
+
inter_threads=num_workers,
|
| 152 |
+
files=files,
|
| 153 |
+
**model_kwargs,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
tokenizer_file = os.path.join(model_path, "tokenizer.json")
|
| 157 |
+
if tokenizer_bytes:
|
| 158 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_buffer(tokenizer_bytes)
|
| 159 |
+
elif os.path.isfile(tokenizer_file):
|
| 160 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_file(tokenizer_file)
|
| 161 |
+
else:
|
| 162 |
+
self.hf_tokenizer = tokenizers.Tokenizer.from_pretrained(
|
| 163 |
+
"openai/whisper-tiny" + ("" if self.model.is_multilingual else ".en")
|
| 164 |
+
)
|
| 165 |
+
self.feat_kwargs = self._get_feature_kwargs(model_path, preprocessor_bytes)
|
| 166 |
+
self.feature_extractor = FeatureExtractor(**self.feat_kwargs)
|
| 167 |
+
self.num_samples_per_token = self.feature_extractor.hop_length * 2
|
| 168 |
+
self.frames_per_second = (
|
| 169 |
+
self.feature_extractor.sampling_rate // self.feature_extractor.hop_length
|
| 170 |
+
)
|
| 171 |
+
self.tokens_per_second = (
|
| 172 |
+
self.feature_extractor.sampling_rate // self.num_samples_per_token
|
| 173 |
+
)
|
| 174 |
+
self.input_stride = 2
|
| 175 |
+
self.time_precision = 0.02
|
| 176 |
+
self.max_length = 448
|
| 177 |
+
|
| 178 |
+
@property
|
| 179 |
+
def supported_languages(self) -> List[str]:
|
| 180 |
+
"""The languages supported by the model."""
|
| 181 |
+
return list(_LANGUAGE_CODES) if self.model.is_multilingual else ["en"]
|
| 182 |
+
|
| 183 |
+
def _get_feature_kwargs(self, model_path, preprocessor_bytes=None) -> dict:
|
| 184 |
+
config = {}
|
| 185 |
+
try:
|
| 186 |
+
config_path = os.path.join(model_path, "preprocessor_config.json")
|
| 187 |
+
if preprocessor_bytes:
|
| 188 |
+
config = json.loads(preprocessor_bytes)
|
| 189 |
+
elif os.path.isfile(config_path):
|
| 190 |
+
with open(config_path, "r", encoding="utf-8") as file:
|
| 191 |
+
config = json.load(file)
|
| 192 |
+
else:
|
| 193 |
+
return config
|
| 194 |
+
valid_keys = signature(FeatureExtractor.__init__).parameters.keys()
|
| 195 |
+
return {k: v for k, v in config.items() if k in valid_keys}
|
| 196 |
+
except json.JSONDecodeError as e:
|
| 197 |
+
self.logger.warning("Could not load preprocessor config: %s", e)
|
| 198 |
+
|
| 199 |
+
return config
|
| 200 |
+
|
| 201 |
+
def transcribe(
|
| 202 |
+
self,
|
| 203 |
+
audio: Union[str, BinaryIO, np.ndarray],
|
| 204 |
+
language: Optional[str] = None,
|
| 205 |
+
task: str = "transcribe",
|
| 206 |
+
beam_size: int = 5,
|
| 207 |
+
best_of: int = 5,
|
| 208 |
+
patience: float = 1,
|
| 209 |
+
length_penalty: float = 1,
|
| 210 |
+
repetition_penalty: float = 1,
|
| 211 |
+
no_repeat_ngram_size: int = 0,
|
| 212 |
+
temperature: Union[float, List[float], Tuple[float, ...]] = [
|
| 213 |
+
0.0,
|
| 214 |
+
0.2,
|
| 215 |
+
0.4,
|
| 216 |
+
0.6,
|
| 217 |
+
0.8,
|
| 218 |
+
1.0,
|
| 219 |
+
],
|
| 220 |
+
compression_ratio_threshold: Optional[float] = 2.4,
|
| 221 |
+
log_prob_threshold: Optional[float] = -1.0,
|
| 222 |
+
no_speech_threshold: Optional[float] = 0.6,
|
| 223 |
+
condition_on_previous_text: bool = True,
|
| 224 |
+
prompt_reset_on_temperature: float = 0.5,
|
| 225 |
+
initial_prompt: Optional[Union[str, Iterable[int]]] = None,
|
| 226 |
+
prefix: Optional[str] = None,
|
| 227 |
+
suppress_blank: bool = True,
|
| 228 |
+
suppress_tokens: Optional[List[int]] = [-1],
|
| 229 |
+
without_timestamps: bool = False,
|
| 230 |
+
max_initial_timestamp: float = 1.0,
|
| 231 |
+
word_timestamps: bool = False,
|
| 232 |
+
prepend_punctuations: str = "\"'“¿([{-",
|
| 233 |
+
append_punctuations: str = "\"'.。,,!!??::”)]}、",
|
| 234 |
+
vad_filter: bool = False,
|
| 235 |
+
vad_parameters: Optional[Union[dict, VadOptions]] = None,
|
| 236 |
+
max_new_tokens: Optional[int] = None,
|
| 237 |
+
chunk_length: Optional[int] = None,
|
| 238 |
+
clip_timestamps: Union[str, List[float]] = "0",
|
| 239 |
+
hallucination_silence_threshold: Optional[float] = None,
|
| 240 |
+
hotwords: Optional[str] = None,
|
| 241 |
+
language_detection_threshold: Optional[float] = None,
|
| 242 |
+
language_detection_segments: int = 1,
|
| 243 |
+
) -> Tuple[Iterable[Segment], TranscriptionInfo]:
|
| 244 |
+
"""Transcribes an input file.
|
| 245 |
+
|
| 246 |
+
Arguments:
|
| 247 |
+
audio: Path to the input file (or a file-like object), or the audio waveform.
|
| 248 |
+
language: The language spoken in the audio. It should be a language code such
|
| 249 |
+
as "en" or "fr". If not set, the language will be detected in the first 30 seconds
|
| 250 |
+
of audio.
|
| 251 |
+
task: Task to execute (transcribe or translate).
|
| 252 |
+
beam_size: Beam size to use for decoding.
|
| 253 |
+
best_of: Number of candidates when sampling with non-zero temperature.
|
| 254 |
+
patience: Beam search patience factor.
|
| 255 |
+
length_penalty: Exponential length penalty constant.
|
| 256 |
+
repetition_penalty: Penalty applied to the score of previously generated tokens
|
| 257 |
+
(set > 1 to penalize).
|
| 258 |
+
no_repeat_ngram_size: Prevent repetitions of ngrams with this size (set 0 to disable).
|
| 259 |
+
temperature: Temperature for sampling. It can be a tuple of temperatures,
|
| 260 |
+
which will be successively used upon failures according to either
|
| 261 |
+
`compression_ratio_threshold` or `log_prob_threshold`.
|
| 262 |
+
compression_ratio_threshold: If the gzip compression ratio is above this value,
|
| 263 |
+
treat as failed.
|
| 264 |
+
log_prob_threshold: If the average log probability over sampled tokens is
|
| 265 |
+
below this value, treat as failed.
|
| 266 |
+
no_speech_threshold: If the no_speech probability is higher than this value AND
|
| 267 |
+
the average log probability over sampled tokens is below `log_prob_threshold`,
|
| 268 |
+
consider the segment as silent.
|
| 269 |
+
condition_on_previous_text: If True, the previous output of the model is provided
|
| 270 |
+
as a prompt for the next window; disabling may make the text inconsistent across
|
| 271 |
+
windows, but the model becomes less prone to getting stuck in a failure loop,
|
| 272 |
+
such as repetition looping or timestamps going out of sync.
|
| 273 |
+
prompt_reset_on_temperature: Resets prompt if temperature is above this value.
|
| 274 |
+
Arg has effect only if condition_on_previous_text is True.
|
| 275 |
+
initial_prompt: Optional text string or iterable of token ids to provide as a
|
| 276 |
+
prompt for the first window.
|
| 277 |
+
prefix: Optional text to provide as a prefix for the first window.
|
| 278 |
+
suppress_blank: Suppress blank outputs at the beginning of the sampling.
|
| 279 |
+
suppress_tokens: List of token IDs to suppress. -1 will suppress a default set
|
| 280 |
+
of symbols as defined in the model config.json file.
|
| 281 |
+
without_timestamps: Only sample text tokens.
|
| 282 |
+
max_initial_timestamp: The initial timestamp cannot be later than this.
|
| 283 |
+
word_timestamps: Extract word-level timestamps using the cross-attention pattern
|
| 284 |
+
and dynamic time warping, and include the timestamps for each word in each segment.
|
| 285 |
+
prepend_punctuations: If word_timestamps is True, merge these punctuation symbols
|
| 286 |
+
with the next word
|
| 287 |
+
append_punctuations: If word_timestamps is True, merge these punctuation symbols
|
| 288 |
+
with the previous word
|
| 289 |
+
vad_filter: Enable the voice activity detection (VAD) to filter out parts of the audio
|
| 290 |
+
without speech. This step is using the Silero VAD model
|
| 291 |
+
https://github.com/snakers4/silero-vad.
|
| 292 |
+
vad_parameters: Dictionary of Silero VAD parameters or VadOptions class (see available
|
| 293 |
+
parameters and default values in the class `VadOptions`).
|
| 294 |
+
max_new_tokens: Maximum number of new tokens to generate per-chunk. If not set,
|
| 295 |
+
the maximum will be set by the default max_length.
|
| 296 |
+
chunk_length: The length of audio segments. If it is not None, it will overwrite the
|
| 297 |
+
default chunk_length of the FeatureExtractor.
|
| 298 |
+
clip_timestamps:
|
| 299 |
+
Comma-separated list start,end,start,end,... timestamps (in seconds) of clips to
|
| 300 |
+
process. The last end timestamp defaults to the end of the file.
|
| 301 |
+
vad_filter will be ignored if clip_timestamps is used.
|
| 302 |
+
hallucination_silence_threshold:
|
| 303 |
+
When word_timestamps is True, skip silent periods longer than this threshold
|
| 304 |
+
(in seconds) when a possible hallucination is detected
|
| 305 |
+
hotwords:
|
| 306 |
+
Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
|
| 307 |
+
language_detection_threshold: If the maximum probability of the language tokens is higher
|
| 308 |
+
than this value, the language is detected.
|
| 309 |
+
language_detection_segments: Number of segments to consider for the language detection.
|
| 310 |
+
Returns:
|
| 311 |
+
A tuple with:
|
| 312 |
+
|
| 313 |
+
- a generator over transcribed segments
|
| 314 |
+
- an instance of TranscriptionInfo
|
| 315 |
+
"""
|
| 316 |
+
sampling_rate = self.feature_extractor.sampling_rate
|
| 317 |
+
|
| 318 |
+
if not isinstance(audio, np.ndarray):
|
| 319 |
+
audio = decode_audio(audio, sampling_rate=sampling_rate)
|
| 320 |
+
|
| 321 |
+
duration = audio.shape[0] / sampling_rate
|
| 322 |
+
duration_after_vad = duration
|
| 323 |
+
|
| 324 |
+
self.logger.info(
|
| 325 |
+
"Processing audio with duration %s", format_timestamp(duration)
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
if vad_filter and clip_timestamps == "0":
|
| 329 |
+
if vad_parameters is None:
|
| 330 |
+
vad_parameters = VadOptions()
|
| 331 |
+
elif isinstance(vad_parameters, dict):
|
| 332 |
+
vad_parameters = VadOptions(**vad_parameters)
|
| 333 |
+
speech_chunks = get_speech_timestamps(audio, vad_parameters)
|
| 334 |
+
audio = collect_chunks(audio, speech_chunks)
|
| 335 |
+
duration_after_vad = audio.shape[0] / sampling_rate
|
| 336 |
+
|
| 337 |
+
self.logger.info(
|
| 338 |
+
"VAD filter removed %s of audio",
|
| 339 |
+
format_timestamp(duration - duration_after_vad),
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
| 343 |
+
self.logger.debug(
|
| 344 |
+
"VAD filter kept the following audio segments: %s",
|
| 345 |
+
", ".join(
|
| 346 |
+
"[%s -> %s]"
|
| 347 |
+
% (
|
| 348 |
+
format_timestamp(chunk["start"] / sampling_rate),
|
| 349 |
+
format_timestamp(chunk["end"] / sampling_rate),
|
| 350 |
+
)
|
| 351 |
+
for chunk in speech_chunks
|
| 352 |
+
),
|
| 353 |
+
)
|
| 354 |
+
|
| 355 |
+
else:
|
| 356 |
+
speech_chunks = None
|
| 357 |
+
|
| 358 |
+
features = self.feature_extractor(audio, chunk_length=chunk_length)
|
| 359 |
+
|
| 360 |
+
encoder_output = None
|
| 361 |
+
all_language_probs = None
|
| 362 |
+
|
| 363 |
+
if language is None:
|
| 364 |
+
if not self.model.is_multilingual:
|
| 365 |
+
language = "en"
|
| 366 |
+
language_probability = 1
|
| 367 |
+
else:
|
| 368 |
+
if (
|
| 369 |
+
language_detection_segments is None
|
| 370 |
+
or language_detection_segments < 1
|
| 371 |
+
):
|
| 372 |
+
language_detection_segments = 1
|
| 373 |
+
seek = 0
|
| 374 |
+
detected_language_info = {}
|
| 375 |
+
content_frames = (
|
| 376 |
+
features.shape[-1] - self.feature_extractor.nb_max_frames
|
| 377 |
+
)
|
| 378 |
+
while (
|
| 379 |
+
seek <= content_frames
|
| 380 |
+
and seek
|
| 381 |
+
< self.feature_extractor.nb_max_frames * language_detection_segments
|
| 382 |
+
):
|
| 383 |
+
segment = features[
|
| 384 |
+
:, seek : seek + self.feature_extractor.nb_max_frames
|
| 385 |
+
]
|
| 386 |
+
encoder_output = self.encode(segment)
|
| 387 |
+
# results is a list of tuple[str, float] with language names and
|
| 388 |
+
# probabilities.
|
| 389 |
+
results = self.model.detect_language(encoder_output)[0]
|
| 390 |
+
# Parse language names to strip out markers
|
| 391 |
+
all_language_probs = [
|
| 392 |
+
(token[2:-2], prob) for (token, prob) in results
|
| 393 |
+
]
|
| 394 |
+
# Get top language token and probability
|
| 395 |
+
language, language_probability = all_language_probs[0]
|
| 396 |
+
if (
|
| 397 |
+
language_detection_threshold is None
|
| 398 |
+
or language_probability > language_detection_threshold
|
| 399 |
+
):
|
| 400 |
+
break
|
| 401 |
+
detected_language_info.setdefault(language, []).append(
|
| 402 |
+
language_probability
|
| 403 |
+
)
|
| 404 |
+
seek += segment.shape[-1]
|
| 405 |
+
else:
|
| 406 |
+
# If no language detected for all segments, the majority vote of the highest
|
| 407 |
+
# projected languages for all segments is used to determine the language.
|
| 408 |
+
language = max(
|
| 409 |
+
detected_language_info,
|
| 410 |
+
key=lambda lang: len(detected_language_info[lang]),
|
| 411 |
+
)
|
| 412 |
+
language_probability = max(detected_language_info[language])
|
| 413 |
+
|
| 414 |
+
self.logger.info(
|
| 415 |
+
"Detected language '%s' with probability %.2f",
|
| 416 |
+
language,
|
| 417 |
+
language_probability,
|
| 418 |
+
)
|
| 419 |
+
else:
|
| 420 |
+
if not self.model.is_multilingual and language != "en":
|
| 421 |
+
self.logger.warning(
|
| 422 |
+
"The current model is English-only but the language parameter is set to '%s'; "
|
| 423 |
+
"using 'en' instead." % language
|
| 424 |
+
)
|
| 425 |
+
language = "en"
|
| 426 |
+
|
| 427 |
+
language_probability = 1
|
| 428 |
+
|
| 429 |
+
tokenizer = Tokenizer(
|
| 430 |
+
self.hf_tokenizer,
|
| 431 |
+
self.model.is_multilingual,
|
| 432 |
+
task=task,
|
| 433 |
+
language=language,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
options = TranscriptionOptions(
|
| 437 |
+
beam_size=beam_size,
|
| 438 |
+
best_of=best_of,
|
| 439 |
+
patience=patience,
|
| 440 |
+
length_penalty=length_penalty,
|
| 441 |
+
repetition_penalty=repetition_penalty,
|
| 442 |
+
no_repeat_ngram_size=no_repeat_ngram_size,
|
| 443 |
+
log_prob_threshold=log_prob_threshold,
|
| 444 |
+
no_speech_threshold=no_speech_threshold,
|
| 445 |
+
compression_ratio_threshold=compression_ratio_threshold,
|
| 446 |
+
condition_on_previous_text=condition_on_previous_text,
|
| 447 |
+
prompt_reset_on_temperature=prompt_reset_on_temperature,
|
| 448 |
+
temperatures=(
|
| 449 |
+
temperature if isinstance(temperature, (list, tuple)) else [temperature]
|
| 450 |
+
),
|
| 451 |
+
initial_prompt=initial_prompt,
|
| 452 |
+
prefix=prefix,
|
| 453 |
+
suppress_blank=suppress_blank,
|
| 454 |
+
suppress_tokens=get_suppressed_tokens(tokenizer, suppress_tokens),
|
| 455 |
+
without_timestamps=without_timestamps,
|
| 456 |
+
max_initial_timestamp=max_initial_timestamp,
|
| 457 |
+
word_timestamps=word_timestamps,
|
| 458 |
+
prepend_punctuations=prepend_punctuations,
|
| 459 |
+
append_punctuations=append_punctuations,
|
| 460 |
+
max_new_tokens=max_new_tokens,
|
| 461 |
+
clip_timestamps=clip_timestamps,
|
| 462 |
+
hallucination_silence_threshold=hallucination_silence_threshold,
|
| 463 |
+
hotwords=hotwords,
|
| 464 |
+
)
|
| 465 |
+
|
| 466 |
+
segments = self.generate_segments(features, tokenizer, options, encoder_output)
|
| 467 |
+
|
| 468 |
+
if speech_chunks:
|
| 469 |
+
segments = restore_speech_timestamps(segments, speech_chunks, sampling_rate)
|
| 470 |
+
|
| 471 |
+
info = TranscriptionInfo(
|
| 472 |
+
language=language,
|
| 473 |
+
language_probability=language_probability,
|
| 474 |
+
duration=duration,
|
| 475 |
+
duration_after_vad=duration_after_vad,
|
| 476 |
+
transcription_options=options,
|
| 477 |
+
vad_options=vad_parameters,
|
| 478 |
+
all_language_probs=all_language_probs,
|
| 479 |
+
)
|
| 480 |
+
|
| 481 |
+
return segments, info
|
| 482 |
+
|
| 483 |
+
def generate_segments(
|
| 484 |
+
self,
|
| 485 |
+
features: np.ndarray,
|
| 486 |
+
tokenizer: Tokenizer,
|
| 487 |
+
options: TranscriptionOptions,
|
| 488 |
+
encoder_output: Optional[ctranslate2.StorageView] = None,
|
| 489 |
+
) -> Iterable[Segment]:
|
| 490 |
+
content_frames = features.shape[-1] - self.feature_extractor.nb_max_frames
|
| 491 |
+
content_duration = float(content_frames * self.feature_extractor.time_per_frame)
|
| 492 |
+
|
| 493 |
+
if isinstance(options.clip_timestamps, str):
|
| 494 |
+
options = options._replace(
|
| 495 |
+
clip_timestamps=[
|
| 496 |
+
float(ts)
|
| 497 |
+
for ts in (
|
| 498 |
+
options.clip_timestamps.split(",")
|
| 499 |
+
if options.clip_timestamps
|
| 500 |
+
else []
|
| 501 |
+
)
|
| 502 |
+
]
|
| 503 |
+
)
|
| 504 |
+
seek_points: List[int] = [
|
| 505 |
+
round(ts * self.frames_per_second) for ts in options.clip_timestamps
|
| 506 |
+
]
|
| 507 |
+
if len(seek_points) == 0:
|
| 508 |
+
seek_points.append(0)
|
| 509 |
+
if len(seek_points) % 2 == 1:
|
| 510 |
+
seek_points.append(content_frames)
|
| 511 |
+
seek_clips: List[Tuple[int, int]] = list(
|
| 512 |
+
zip(seek_points[::2], seek_points[1::2])
|
| 513 |
+
)
|
| 514 |
+
|
| 515 |
+
punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
|
| 516 |
+
|
| 517 |
+
idx = 0
|
| 518 |
+
clip_idx = 0
|
| 519 |
+
seek = seek_clips[clip_idx][0]
|
| 520 |
+
all_tokens = []
|
| 521 |
+
prompt_reset_since = 0
|
| 522 |
+
|
| 523 |
+
if options.initial_prompt is not None:
|
| 524 |
+
if isinstance(options.initial_prompt, str):
|
| 525 |
+
initial_prompt = " " + options.initial_prompt.strip()
|
| 526 |
+
initial_prompt_tokens = tokenizer.encode(initial_prompt)
|
| 527 |
+
all_tokens.extend(initial_prompt_tokens)
|
| 528 |
+
else:
|
| 529 |
+
all_tokens.extend(options.initial_prompt)
|
| 530 |
+
|
| 531 |
+
last_speech_timestamp = 0.0
|
| 532 |
+
# NOTE: This loop is obscurely flattened to make the diff readable.
|
| 533 |
+
# A later commit should turn this into a simpler nested loop.
|
| 534 |
+
# for seek_clip_start, seek_clip_end in seek_clips:
|
| 535 |
+
# while seek < seek_clip_end
|
| 536 |
+
while clip_idx < len(seek_clips):
|
| 537 |
+
seek_clip_start, seek_clip_end = seek_clips[clip_idx]
|
| 538 |
+
if seek_clip_end > content_frames:
|
| 539 |
+
seek_clip_end = content_frames
|
| 540 |
+
if seek < seek_clip_start:
|
| 541 |
+
seek = seek_clip_start
|
| 542 |
+
if seek >= seek_clip_end:
|
| 543 |
+
clip_idx += 1
|
| 544 |
+
if clip_idx < len(seek_clips):
|
| 545 |
+
seek = seek_clips[clip_idx][0]
|
| 546 |
+
continue
|
| 547 |
+
time_offset = seek * self.feature_extractor.time_per_frame
|
| 548 |
+
window_end_time = float(
|
| 549 |
+
(seek + self.feature_extractor.nb_max_frames)
|
| 550 |
+
* self.feature_extractor.time_per_frame
|
| 551 |
+
)
|
| 552 |
+
segment_size = min(
|
| 553 |
+
self.feature_extractor.nb_max_frames,
|
| 554 |
+
content_frames - seek,
|
| 555 |
+
seek_clip_end - seek,
|
| 556 |
+
)
|
| 557 |
+
segment = features[:, seek : seek + segment_size]
|
| 558 |
+
segment_duration = segment_size * self.feature_extractor.time_per_frame
|
| 559 |
+
segment = pad_or_trim(segment, self.feature_extractor.nb_max_frames)
|
| 560 |
+
|
| 561 |
+
if self.logger.isEnabledFor(logging.DEBUG):
|
| 562 |
+
self.logger.debug(
|
| 563 |
+
"Processing segment at %s", format_timestamp(time_offset)
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
previous_tokens = all_tokens[prompt_reset_since:]
|
| 567 |
+
prompt = self.get_prompt(
|
| 568 |
+
tokenizer,
|
| 569 |
+
previous_tokens,
|
| 570 |
+
without_timestamps=options.without_timestamps,
|
| 571 |
+
prefix=options.prefix if seek == 0 else None,
|
| 572 |
+
hotwords=options.hotwords,
|
| 573 |
+
)
|
| 574 |
+
|
| 575 |
+
if seek > 0 or encoder_output is None:
|
| 576 |
+
encoder_output = self.encode(segment)
|
| 577 |
+
|
| 578 |
+
(
|
| 579 |
+
result,
|
| 580 |
+
avg_logprob,
|
| 581 |
+
temperature,
|
| 582 |
+
compression_ratio,
|
| 583 |
+
) = self.generate_with_fallback(encoder_output, prompt, tokenizer, options)
|
| 584 |
+
|
| 585 |
+
if options.no_speech_threshold is not None:
|
| 586 |
+
# no voice activity check
|
| 587 |
+
should_skip = result.no_speech_prob > options.no_speech_threshold
|
| 588 |
+
|
| 589 |
+
if (
|
| 590 |
+
options.log_prob_threshold is not None
|
| 591 |
+
and avg_logprob > options.log_prob_threshold
|
| 592 |
+
):
|
| 593 |
+
# don't skip if the logprob is high enough, despite the no_speech_prob
|
| 594 |
+
should_skip = False
|
| 595 |
+
|
| 596 |
+
if should_skip:
|
| 597 |
+
self.logger.debug(
|
| 598 |
+
"No speech threshold is met (%f > %f)",
|
| 599 |
+
result.no_speech_prob,
|
| 600 |
+
options.no_speech_threshold,
|
| 601 |
+
)
|
| 602 |
+
|
| 603 |
+
# fast-forward to the next segment boundary
|
| 604 |
+
seek += segment_size
|
| 605 |
+
continue
|
| 606 |
+
|
| 607 |
+
tokens = result.sequences_ids[0]
|
| 608 |
+
|
| 609 |
+
previous_seek = seek
|
| 610 |
+
current_segments = []
|
| 611 |
+
|
| 612 |
+
# anomalous words are very long/short/improbable
|
| 613 |
+
def word_anomaly_score(word: dict) -> float:
|
| 614 |
+
probability = word.get("probability", 0.0)
|
| 615 |
+
duration = word["end"] - word["start"]
|
| 616 |
+
score = 0.0
|
| 617 |
+
if probability < 0.15:
|
| 618 |
+
score += 1.0
|
| 619 |
+
if duration < 0.133:
|
| 620 |
+
score += (0.133 - duration) * 15
|
| 621 |
+
if duration > 2.0:
|
| 622 |
+
score += duration - 2.0
|
| 623 |
+
return score
|
| 624 |
+
|
| 625 |
+
def is_segment_anomaly(segment: Optional[dict]) -> bool:
|
| 626 |
+
if segment is None or not segment["words"]:
|
| 627 |
+
return False
|
| 628 |
+
words = [w for w in segment["words"] if w["word"] not in punctuation]
|
| 629 |
+
words = words[:8]
|
| 630 |
+
score = sum(word_anomaly_score(w) for w in words)
|
| 631 |
+
return score >= 3 or score + 0.01 >= len(words)
|
| 632 |
+
|
| 633 |
+
def next_words_segment(segments: List[dict]) -> Optional[dict]:
|
| 634 |
+
return next((s for s in segments if s["words"]), None)
|
| 635 |
+
|
| 636 |
+
single_timestamp_ending = (
|
| 637 |
+
len(tokens) >= 2
|
| 638 |
+
and tokens[-2] < tokenizer.timestamp_begin <= tokens[-1]
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
consecutive_timestamps = [
|
| 642 |
+
i
|
| 643 |
+
for i in range(len(tokens))
|
| 644 |
+
if i > 0
|
| 645 |
+
and tokens[i] >= tokenizer.timestamp_begin
|
| 646 |
+
and tokens[i - 1] >= tokenizer.timestamp_begin
|
| 647 |
+
]
|
| 648 |
+
|
| 649 |
+
if len(consecutive_timestamps) > 0:
|
| 650 |
+
slices = list(consecutive_timestamps)
|
| 651 |
+
if single_timestamp_ending:
|
| 652 |
+
slices.append(len(tokens))
|
| 653 |
+
|
| 654 |
+
last_slice = 0
|
| 655 |
+
for current_slice in slices:
|
| 656 |
+
sliced_tokens = tokens[last_slice:current_slice]
|
| 657 |
+
start_timestamp_position = (
|
| 658 |
+
sliced_tokens[0] - tokenizer.timestamp_begin
|
| 659 |
+
)
|
| 660 |
+
end_timestamp_position = (
|
| 661 |
+
sliced_tokens[-1] - tokenizer.timestamp_begin
|
| 662 |
+
)
|
| 663 |
+
start_time = (
|
| 664 |
+
time_offset + start_timestamp_position * self.time_precision
|
| 665 |
+
)
|
| 666 |
+
end_time = (
|
| 667 |
+
time_offset + end_timestamp_position * self.time_precision
|
| 668 |
+
)
|
| 669 |
+
|
| 670 |
+
current_segments.append(
|
| 671 |
+
dict(
|
| 672 |
+
seek=seek,
|
| 673 |
+
start=start_time,
|
| 674 |
+
end=end_time,
|
| 675 |
+
tokens=sliced_tokens,
|
| 676 |
+
)
|
| 677 |
+
)
|
| 678 |
+
last_slice = current_slice
|
| 679 |
+
|
| 680 |
+
if single_timestamp_ending:
|
| 681 |
+
# single timestamp at the end means no speech after the last timestamp.
|
| 682 |
+
seek += segment_size
|
| 683 |
+
else:
|
| 684 |
+
# otherwise, ignore the unfinished segment and seek to the last timestamp
|
| 685 |
+
last_timestamp_position = (
|
| 686 |
+
tokens[last_slice - 1] - tokenizer.timestamp_begin
|
| 687 |
+
)
|
| 688 |
+
seek += last_timestamp_position * self.input_stride
|
| 689 |
+
|
| 690 |
+
else:
|
| 691 |
+
duration = segment_duration
|
| 692 |
+
timestamps = [
|
| 693 |
+
token for token in tokens if token >= tokenizer.timestamp_begin
|
| 694 |
+
]
|
| 695 |
+
if len(timestamps) > 0 and timestamps[-1] != tokenizer.timestamp_begin:
|
| 696 |
+
last_timestamp_position = timestamps[-1] - tokenizer.timestamp_begin
|
| 697 |
+
duration = last_timestamp_position * self.time_precision
|
| 698 |
+
|
| 699 |
+
current_segments.append(
|
| 700 |
+
dict(
|
| 701 |
+
seek=seek,
|
| 702 |
+
start=time_offset,
|
| 703 |
+
end=time_offset + duration,
|
| 704 |
+
tokens=tokens,
|
| 705 |
+
)
|
| 706 |
+
)
|
| 707 |
+
|
| 708 |
+
seek += segment_size
|
| 709 |
+
|
| 710 |
+
if options.word_timestamps:
|
| 711 |
+
self.add_word_timestamps(
|
| 712 |
+
current_segments,
|
| 713 |
+
tokenizer,
|
| 714 |
+
encoder_output,
|
| 715 |
+
segment_size,
|
| 716 |
+
options.prepend_punctuations,
|
| 717 |
+
options.append_punctuations,
|
| 718 |
+
last_speech_timestamp=last_speech_timestamp,
|
| 719 |
+
)
|
| 720 |
+
|
| 721 |
+
if not single_timestamp_ending:
|
| 722 |
+
last_word_end = get_end(current_segments)
|
| 723 |
+
if last_word_end is not None and last_word_end > time_offset:
|
| 724 |
+
seek = round(last_word_end * self.frames_per_second)
|
| 725 |
+
|
| 726 |
+
# skip silence before possible hallucinations
|
| 727 |
+
if options.hallucination_silence_threshold is not None:
|
| 728 |
+
threshold = options.hallucination_silence_threshold
|
| 729 |
+
|
| 730 |
+
# if first segment might be a hallucination, skip leading silence
|
| 731 |
+
first_segment = next_words_segment(current_segments)
|
| 732 |
+
if first_segment is not None and is_segment_anomaly(first_segment):
|
| 733 |
+
gap = first_segment["start"] - time_offset
|
| 734 |
+
if gap > threshold:
|
| 735 |
+
seek = previous_seek + round(gap * self.frames_per_second)
|
| 736 |
+
continue
|
| 737 |
+
|
| 738 |
+
# skip silence before any possible hallucination that is surrounded
|
| 739 |
+
# by silence or more hallucinations
|
| 740 |
+
hal_last_end = last_speech_timestamp
|
| 741 |
+
for si in range(len(current_segments)):
|
| 742 |
+
segment = current_segments[si]
|
| 743 |
+
if not segment["words"]:
|
| 744 |
+
continue
|
| 745 |
+
if is_segment_anomaly(segment):
|
| 746 |
+
next_segment = next_words_segment(
|
| 747 |
+
current_segments[si + 1 :]
|
| 748 |
+
)
|
| 749 |
+
if next_segment is not None:
|
| 750 |
+
hal_next_start = next_segment["words"][0]["start"]
|
| 751 |
+
else:
|
| 752 |
+
hal_next_start = time_offset + segment_duration
|
| 753 |
+
silence_before = (
|
| 754 |
+
segment["start"] - hal_last_end > threshold
|
| 755 |
+
or segment["start"] < threshold
|
| 756 |
+
or segment["start"] - time_offset < 2.0
|
| 757 |
+
)
|
| 758 |
+
silence_after = (
|
| 759 |
+
hal_next_start - segment["end"] > threshold
|
| 760 |
+
or is_segment_anomaly(next_segment)
|
| 761 |
+
or window_end_time - segment["end"] < 2.0
|
| 762 |
+
)
|
| 763 |
+
if silence_before and silence_after:
|
| 764 |
+
seek = round(
|
| 765 |
+
max(time_offset + 1, segment["start"])
|
| 766 |
+
* self.frames_per_second
|
| 767 |
+
)
|
| 768 |
+
if content_duration - segment["end"] < threshold:
|
| 769 |
+
seek = content_frames
|
| 770 |
+
current_segments[si:] = []
|
| 771 |
+
break
|
| 772 |
+
hal_last_end = segment["end"]
|
| 773 |
+
|
| 774 |
+
last_word_end = get_end(current_segments)
|
| 775 |
+
if last_word_end is not None:
|
| 776 |
+
last_speech_timestamp = last_word_end
|
| 777 |
+
|
| 778 |
+
for segment in current_segments:
|
| 779 |
+
tokens = segment["tokens"]
|
| 780 |
+
text = tokenizer.decode(tokens)
|
| 781 |
+
|
| 782 |
+
if segment["start"] == segment["end"] or not text.strip():
|
| 783 |
+
continue
|
| 784 |
+
|
| 785 |
+
all_tokens.extend(tokens)
|
| 786 |
+
idx += 1
|
| 787 |
+
|
| 788 |
+
yield Segment(
|
| 789 |
+
id=idx,
|
| 790 |
+
seek=seek,
|
| 791 |
+
start=segment["start"],
|
| 792 |
+
end=segment["end"],
|
| 793 |
+
text=text,
|
| 794 |
+
tokens=tokens,
|
| 795 |
+
temperature=temperature,
|
| 796 |
+
avg_logprob=avg_logprob,
|
| 797 |
+
compression_ratio=compression_ratio,
|
| 798 |
+
no_speech_prob=result.no_speech_prob,
|
| 799 |
+
words=(
|
| 800 |
+
[Word(**word) for word in segment["words"]]
|
| 801 |
+
if options.word_timestamps
|
| 802 |
+
else None
|
| 803 |
+
),
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
if (
|
| 807 |
+
not options.condition_on_previous_text
|
| 808 |
+
or temperature > options.prompt_reset_on_temperature
|
| 809 |
+
):
|
| 810 |
+
if options.condition_on_previous_text:
|
| 811 |
+
self.logger.debug(
|
| 812 |
+
"Reset prompt. prompt_reset_on_temperature threshold is met %f > %f",
|
| 813 |
+
temperature,
|
| 814 |
+
options.prompt_reset_on_temperature,
|
| 815 |
+
)
|
| 816 |
+
|
| 817 |
+
prompt_reset_since = len(all_tokens)
|
| 818 |
+
|
| 819 |
+
def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
|
| 820 |
+
# When the model is running on multiple GPUs, the encoder output should be moved
|
| 821 |
+
# to the CPU since we don't know which GPU will handle the next job.
|
| 822 |
+
to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
|
| 823 |
+
|
| 824 |
+
features = np.expand_dims(features, 0)
|
| 825 |
+
features = get_ctranslate2_storage(features)
|
| 826 |
+
|
| 827 |
+
return self.model.encode(features, to_cpu=to_cpu)
|
| 828 |
+
|
| 829 |
+
def generate_with_fallback(
|
| 830 |
+
self,
|
| 831 |
+
encoder_output: ctranslate2.StorageView,
|
| 832 |
+
prompt: List[int],
|
| 833 |
+
tokenizer: Tokenizer,
|
| 834 |
+
options: TranscriptionOptions,
|
| 835 |
+
) -> Tuple[ctranslate2.models.WhisperGenerationResult, float, float, float]:
|
| 836 |
+
decode_result = None
|
| 837 |
+
all_results = []
|
| 838 |
+
below_cr_threshold_results = []
|
| 839 |
+
|
| 840 |
+
max_initial_timestamp_index = int(
|
| 841 |
+
round(options.max_initial_timestamp / self.time_precision)
|
| 842 |
+
)
|
| 843 |
+
if options.max_new_tokens is not None:
|
| 844 |
+
max_length = len(prompt) + options.max_new_tokens
|
| 845 |
+
else:
|
| 846 |
+
max_length = self.max_length
|
| 847 |
+
|
| 848 |
+
if max_length > self.max_length:
|
| 849 |
+
raise ValueError(
|
| 850 |
+
f"The length of the prompt is {len(prompt)}, and the `max_new_tokens` "
|
| 851 |
+
f"{max_length - len(prompt)}. Thus, the combined length of the prompt "
|
| 852 |
+
f"and `max_new_tokens` is: {max_length}. This exceeds the "
|
| 853 |
+
f"`max_length` of the Whisper model: {self.max_length}. "
|
| 854 |
+
"You should either reduce the length of your prompt, or "
|
| 855 |
+
"reduce the value of `max_new_tokens`, "
|
| 856 |
+
f"so that their combined length is less that {self.max_length}."
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
for temperature in options.temperatures:
|
| 860 |
+
if temperature > 0:
|
| 861 |
+
kwargs = {
|
| 862 |
+
"beam_size": 1,
|
| 863 |
+
"num_hypotheses": options.best_of,
|
| 864 |
+
"sampling_topk": 0,
|
| 865 |
+
"sampling_temperature": temperature,
|
| 866 |
+
}
|
| 867 |
+
else:
|
| 868 |
+
kwargs = {
|
| 869 |
+
"beam_size": options.beam_size,
|
| 870 |
+
"patience": options.patience,
|
| 871 |
+
}
|
| 872 |
+
|
| 873 |
+
result = self.model.generate(
|
| 874 |
+
encoder_output,
|
| 875 |
+
[prompt],
|
| 876 |
+
length_penalty=options.length_penalty,
|
| 877 |
+
repetition_penalty=options.repetition_penalty,
|
| 878 |
+
no_repeat_ngram_size=options.no_repeat_ngram_size,
|
| 879 |
+
max_length=max_length,
|
| 880 |
+
return_scores=True,
|
| 881 |
+
return_no_speech_prob=True,
|
| 882 |
+
suppress_blank=options.suppress_blank,
|
| 883 |
+
suppress_tokens=options.suppress_tokens,
|
| 884 |
+
max_initial_timestamp_index=max_initial_timestamp_index,
|
| 885 |
+
**kwargs,
|
| 886 |
+
)[0]
|
| 887 |
+
|
| 888 |
+
tokens = result.sequences_ids[0]
|
| 889 |
+
|
| 890 |
+
# Recover the average log prob from the returned score.
|
| 891 |
+
seq_len = len(tokens)
|
| 892 |
+
cum_logprob = result.scores[0] * (seq_len**options.length_penalty)
|
| 893 |
+
avg_logprob = cum_logprob / (seq_len + 1)
|
| 894 |
+
|
| 895 |
+
text = tokenizer.decode(tokens).strip()
|
| 896 |
+
compression_ratio = get_compression_ratio(text)
|
| 897 |
+
|
| 898 |
+
decode_result = (
|
| 899 |
+
result,
|
| 900 |
+
avg_logprob,
|
| 901 |
+
temperature,
|
| 902 |
+
compression_ratio,
|
| 903 |
+
)
|
| 904 |
+
all_results.append(decode_result)
|
| 905 |
+
|
| 906 |
+
needs_fallback = False
|
| 907 |
+
|
| 908 |
+
if options.compression_ratio_threshold is not None:
|
| 909 |
+
if compression_ratio > options.compression_ratio_threshold:
|
| 910 |
+
needs_fallback = True # too repetitive
|
| 911 |
+
|
| 912 |
+
self.logger.debug(
|
| 913 |
+
"Compression ratio threshold is not met with temperature %.1f (%f > %f)",
|
| 914 |
+
temperature,
|
| 915 |
+
compression_ratio,
|
| 916 |
+
options.compression_ratio_threshold,
|
| 917 |
+
)
|
| 918 |
+
else:
|
| 919 |
+
below_cr_threshold_results.append(decode_result)
|
| 920 |
+
|
| 921 |
+
if (
|
| 922 |
+
options.log_prob_threshold is not None
|
| 923 |
+
and avg_logprob < options.log_prob_threshold
|
| 924 |
+
):
|
| 925 |
+
needs_fallback = True # average log probability is too low
|
| 926 |
+
|
| 927 |
+
self.logger.debug(
|
| 928 |
+
"Log probability threshold is not met with temperature %.1f (%f < %f)",
|
| 929 |
+
temperature,
|
| 930 |
+
avg_logprob,
|
| 931 |
+
options.log_prob_threshold,
|
| 932 |
+
)
|
| 933 |
+
|
| 934 |
+
if (
|
| 935 |
+
options.no_speech_threshold is not None
|
| 936 |
+
and result.no_speech_prob > options.no_speech_threshold
|
| 937 |
+
and options.log_prob_threshold is not None
|
| 938 |
+
and avg_logprob < options.log_prob_threshold
|
| 939 |
+
):
|
| 940 |
+
needs_fallback = False # silence
|
| 941 |
+
|
| 942 |
+
if not needs_fallback:
|
| 943 |
+
break
|
| 944 |
+
else:
|
| 945 |
+
# all failed, select the result with the highest average log probability
|
| 946 |
+
decode_result = max(
|
| 947 |
+
below_cr_threshold_results or all_results, key=lambda x: x[1]
|
| 948 |
+
)
|
| 949 |
+
# to pass final temperature for prompt_reset_on_temperature
|
| 950 |
+
decode_result = (
|
| 951 |
+
decode_result[0],
|
| 952 |
+
decode_result[1],
|
| 953 |
+
temperature,
|
| 954 |
+
decode_result[3],
|
| 955 |
+
)
|
| 956 |
+
|
| 957 |
+
return decode_result
|
| 958 |
+
|
| 959 |
+
def get_prompt(
|
| 960 |
+
self,
|
| 961 |
+
tokenizer: Tokenizer,
|
| 962 |
+
previous_tokens: List[int],
|
| 963 |
+
without_timestamps: bool = False,
|
| 964 |
+
prefix: Optional[str] = None,
|
| 965 |
+
hotwords: Optional[str] = None,
|
| 966 |
+
) -> List[int]:
|
| 967 |
+
prompt = []
|
| 968 |
+
|
| 969 |
+
if previous_tokens or (hotwords and not prefix):
|
| 970 |
+
prompt.append(tokenizer.sot_prev)
|
| 971 |
+
if hotwords and not prefix:
|
| 972 |
+
hotwords_tokens = tokenizer.encode(" " + hotwords.strip())
|
| 973 |
+
if len(hotwords_tokens) >= self.max_length // 2:
|
| 974 |
+
hotwords_tokens = hotwords_tokens[: self.max_length // 2 - 1]
|
| 975 |
+
prompt.extend(hotwords_tokens)
|
| 976 |
+
if previous_tokens:
|
| 977 |
+
prompt.extend(previous_tokens[-(self.max_length // 2 - 1) :])
|
| 978 |
+
|
| 979 |
+
prompt.extend(tokenizer.sot_sequence)
|
| 980 |
+
|
| 981 |
+
if without_timestamps:
|
| 982 |
+
prompt.append(tokenizer.no_timestamps)
|
| 983 |
+
|
| 984 |
+
if prefix:
|
| 985 |
+
prefix_tokens = tokenizer.encode(" " + prefix.strip())
|
| 986 |
+
if len(prefix_tokens) >= self.max_length // 2:
|
| 987 |
+
prefix_tokens = prefix_tokens[: self.max_length // 2 - 1]
|
| 988 |
+
if not without_timestamps:
|
| 989 |
+
prompt.append(tokenizer.timestamp_begin)
|
| 990 |
+
prompt.extend(prefix_tokens)
|
| 991 |
+
|
| 992 |
+
return prompt
|
| 993 |
+
|
| 994 |
+
def add_word_timestamps(
|
| 995 |
+
self,
|
| 996 |
+
segments: List[dict],
|
| 997 |
+
tokenizer: Tokenizer,
|
| 998 |
+
encoder_output: ctranslate2.StorageView,
|
| 999 |
+
num_frames: int,
|
| 1000 |
+
prepend_punctuations: str,
|
| 1001 |
+
append_punctuations: str,
|
| 1002 |
+
last_speech_timestamp: float,
|
| 1003 |
+
) -> None:
|
| 1004 |
+
if len(segments) == 0:
|
| 1005 |
+
return
|
| 1006 |
+
|
| 1007 |
+
text_tokens_per_segment = [
|
| 1008 |
+
[token for token in segment["tokens"] if token < tokenizer.eot]
|
| 1009 |
+
for segment in segments
|
| 1010 |
+
]
|
| 1011 |
+
|
| 1012 |
+
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
|
| 1013 |
+
alignment = self.find_alignment(
|
| 1014 |
+
tokenizer, text_tokens, encoder_output, num_frames
|
| 1015 |
+
)
|
| 1016 |
+
word_durations = np.array([word["end"] - word["start"] for word in alignment])
|
| 1017 |
+
word_durations = word_durations[word_durations.nonzero()]
|
| 1018 |
+
median_duration = np.median(word_durations) if len(word_durations) > 0 else 0.0
|
| 1019 |
+
median_duration = min(0.7, float(median_duration))
|
| 1020 |
+
max_duration = median_duration * 2
|
| 1021 |
+
|
| 1022 |
+
# hack: truncate long words at sentence boundaries.
|
| 1023 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
| 1024 |
+
if len(word_durations) > 0:
|
| 1025 |
+
sentence_end_marks = ".。!!??"
|
| 1026 |
+
# ensure words at sentence boundaries
|
| 1027 |
+
# are not longer than twice the median word duration.
|
| 1028 |
+
for i in range(1, len(alignment)):
|
| 1029 |
+
if alignment[i]["end"] - alignment[i]["start"] > max_duration:
|
| 1030 |
+
if alignment[i]["word"] in sentence_end_marks:
|
| 1031 |
+
alignment[i]["end"] = alignment[i]["start"] + max_duration
|
| 1032 |
+
elif alignment[i - 1]["word"] in sentence_end_marks:
|
| 1033 |
+
alignment[i]["start"] = alignment[i]["end"] - max_duration
|
| 1034 |
+
|
| 1035 |
+
merge_punctuations(alignment, prepend_punctuations, append_punctuations)
|
| 1036 |
+
|
| 1037 |
+
time_offset = (
|
| 1038 |
+
segments[0]["seek"]
|
| 1039 |
+
* self.feature_extractor.hop_length
|
| 1040 |
+
/ self.feature_extractor.sampling_rate
|
| 1041 |
+
)
|
| 1042 |
+
|
| 1043 |
+
word_index = 0
|
| 1044 |
+
|
| 1045 |
+
for segment, text_tokens in zip(segments, text_tokens_per_segment):
|
| 1046 |
+
saved_tokens = 0
|
| 1047 |
+
words = []
|
| 1048 |
+
|
| 1049 |
+
while word_index < len(alignment) and saved_tokens < len(text_tokens):
|
| 1050 |
+
timing = alignment[word_index]
|
| 1051 |
+
|
| 1052 |
+
if timing["word"]:
|
| 1053 |
+
words.append(
|
| 1054 |
+
dict(
|
| 1055 |
+
word=timing["word"],
|
| 1056 |
+
start=round(time_offset + timing["start"], 2),
|
| 1057 |
+
end=round(time_offset + timing["end"], 2),
|
| 1058 |
+
probability=timing["probability"],
|
| 1059 |
+
)
|
| 1060 |
+
)
|
| 1061 |
+
|
| 1062 |
+
saved_tokens += len(timing["tokens"])
|
| 1063 |
+
word_index += 1
|
| 1064 |
+
|
| 1065 |
+
# hack: truncate long words at segment boundaries.
|
| 1066 |
+
# a better segmentation algorithm based on VAD should be able to replace this.
|
| 1067 |
+
if len(words) > 0:
|
| 1068 |
+
# ensure the first and second word after a pause is not longer than
|
| 1069 |
+
# twice the median word duration.
|
| 1070 |
+
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
|
| 1071 |
+
words[0]["end"] - words[0]["start"] > max_duration
|
| 1072 |
+
or (
|
| 1073 |
+
len(words) > 1
|
| 1074 |
+
and words[1]["end"] - words[0]["start"] > max_duration * 2
|
| 1075 |
+
)
|
| 1076 |
+
):
|
| 1077 |
+
if (
|
| 1078 |
+
len(words) > 1
|
| 1079 |
+
and words[1]["end"] - words[1]["start"] > max_duration
|
| 1080 |
+
):
|
| 1081 |
+
boundary = max(
|
| 1082 |
+
words[1]["end"] / 2, words[1]["end"] - max_duration
|
| 1083 |
+
)
|
| 1084 |
+
words[0]["end"] = words[1]["start"] = boundary
|
| 1085 |
+
words[0]["start"] = max(0, words[0]["end"] - max_duration)
|
| 1086 |
+
|
| 1087 |
+
# prefer the segment-level start timestamp if the first word is too long.
|
| 1088 |
+
if (
|
| 1089 |
+
segment["start"] < words[0]["end"]
|
| 1090 |
+
and segment["start"] - 0.5 > words[0]["start"]
|
| 1091 |
+
):
|
| 1092 |
+
words[0]["start"] = max(
|
| 1093 |
+
0, min(words[0]["end"] - median_duration, segment["start"])
|
| 1094 |
+
)
|
| 1095 |
+
else:
|
| 1096 |
+
segment["start"] = words[0]["start"]
|
| 1097 |
+
|
| 1098 |
+
# prefer the segment-level end timestamp if the last word is too long.
|
| 1099 |
+
if (
|
| 1100 |
+
segment["end"] > words[-1]["start"]
|
| 1101 |
+
and segment["end"] + 0.5 < words[-1]["end"]
|
| 1102 |
+
):
|
| 1103 |
+
words[-1]["end"] = max(
|
| 1104 |
+
words[-1]["start"] + median_duration, segment["end"]
|
| 1105 |
+
)
|
| 1106 |
+
else:
|
| 1107 |
+
segment["end"] = words[-1]["end"]
|
| 1108 |
+
|
| 1109 |
+
last_speech_timestamp = segment["end"]
|
| 1110 |
+
|
| 1111 |
+
segment["words"] = words
|
| 1112 |
+
|
| 1113 |
+
def find_alignment(
|
| 1114 |
+
self,
|
| 1115 |
+
tokenizer: Tokenizer,
|
| 1116 |
+
text_tokens: List[int],
|
| 1117 |
+
encoder_output: ctranslate2.StorageView,
|
| 1118 |
+
num_frames: int,
|
| 1119 |
+
median_filter_width: int = 7,
|
| 1120 |
+
) -> List[dict]:
|
| 1121 |
+
if len(text_tokens) == 0:
|
| 1122 |
+
return []
|
| 1123 |
+
|
| 1124 |
+
result = self.model.align(
|
| 1125 |
+
encoder_output,
|
| 1126 |
+
tokenizer.sot_sequence,
|
| 1127 |
+
[text_tokens],
|
| 1128 |
+
num_frames,
|
| 1129 |
+
median_filter_width=median_filter_width,
|
| 1130 |
+
)[0]
|
| 1131 |
+
|
| 1132 |
+
text_token_probs = result.text_token_probs
|
| 1133 |
+
|
| 1134 |
+
alignments = result.alignments
|
| 1135 |
+
text_indices = np.array([pair[0] for pair in alignments])
|
| 1136 |
+
time_indices = np.array([pair[1] for pair in alignments])
|
| 1137 |
+
|
| 1138 |
+
words, word_tokens = tokenizer.split_to_word_tokens(
|
| 1139 |
+
text_tokens + [tokenizer.eot]
|
| 1140 |
+
)
|
| 1141 |
+
if len(word_tokens) <= 1:
|
| 1142 |
+
# return on eot only
|
| 1143 |
+
# >>> np.pad([], (1, 0))
|
| 1144 |
+
# array([0.])
|
| 1145 |
+
# This results in crashes when we lookup jump_times with float, like
|
| 1146 |
+
# IndexError: arrays used as indices must be of integer (or boolean) type
|
| 1147 |
+
return []
|
| 1148 |
+
word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
|
| 1149 |
+
if len(word_boundaries) <= 1:
|
| 1150 |
+
return []
|
| 1151 |
+
|
| 1152 |
+
jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
|
| 1153 |
+
jump_times = time_indices[jumps] / self.tokens_per_second
|
| 1154 |
+
start_times = jump_times[word_boundaries[:-1]]
|
| 1155 |
+
end_times = jump_times[word_boundaries[1:]]
|
| 1156 |
+
word_probabilities = [
|
| 1157 |
+
np.mean(text_token_probs[i:j])
|
| 1158 |
+
for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
|
| 1159 |
+
]
|
| 1160 |
+
|
| 1161 |
+
return [
|
| 1162 |
+
dict(
|
| 1163 |
+
word=word, tokens=tokens, start=start, end=end, probability=probability
|
| 1164 |
+
)
|
| 1165 |
+
for word, tokens, start, end, probability in zip(
|
| 1166 |
+
words, word_tokens, start_times, end_times, word_probabilities
|
| 1167 |
+
)
|
| 1168 |
+
]
|
| 1169 |
+
|
| 1170 |
+
|
| 1171 |
+
def restore_speech_timestamps(
|
| 1172 |
+
segments: Iterable[Segment],
|
| 1173 |
+
speech_chunks: List[dict],
|
| 1174 |
+
sampling_rate: int,
|
| 1175 |
+
) -> Iterable[Segment]:
|
| 1176 |
+
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
| 1177 |
+
|
| 1178 |
+
for segment in segments:
|
| 1179 |
+
if segment.words:
|
| 1180 |
+
words = []
|
| 1181 |
+
for word in segment.words:
|
| 1182 |
+
# Ensure the word start and end times are resolved to the same chunk.
|
| 1183 |
+
middle = (word.start + word.end) / 2
|
| 1184 |
+
chunk_index = ts_map.get_chunk_index(middle)
|
| 1185 |
+
word = word._replace(
|
| 1186 |
+
start=ts_map.get_original_time(word.start, chunk_index),
|
| 1187 |
+
end=ts_map.get_original_time(word.end, chunk_index),
|
| 1188 |
+
)
|
| 1189 |
+
words.append(word)
|
| 1190 |
+
|
| 1191 |
+
segment = segment._replace(
|
| 1192 |
+
start=words[0].start,
|
| 1193 |
+
end=words[-1].end,
|
| 1194 |
+
words=words,
|
| 1195 |
+
)
|
| 1196 |
+
|
| 1197 |
+
else:
|
| 1198 |
+
segment = segment._replace(
|
| 1199 |
+
start=ts_map.get_original_time(segment.start),
|
| 1200 |
+
end=ts_map.get_original_time(segment.end),
|
| 1201 |
+
)
|
| 1202 |
+
|
| 1203 |
+
yield segment
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def get_ctranslate2_storage(segment: np.ndarray) -> ctranslate2.StorageView:
|
| 1207 |
+
segment = np.ascontiguousarray(segment)
|
| 1208 |
+
segment = ctranslate2.StorageView.from_array(segment)
|
| 1209 |
+
return segment
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def get_compression_ratio(text: str) -> float:
|
| 1213 |
+
text_bytes = text.encode("utf-8")
|
| 1214 |
+
return len(text_bytes) / len(zlib.compress(text_bytes))
|
| 1215 |
+
|
| 1216 |
+
|
| 1217 |
+
def get_suppressed_tokens(
|
| 1218 |
+
tokenizer: Tokenizer,
|
| 1219 |
+
suppress_tokens: Optional[List[int]],
|
| 1220 |
+
) -> Optional[List[int]]:
|
| 1221 |
+
if not suppress_tokens or -1 in suppress_tokens:
|
| 1222 |
+
return suppress_tokens
|
| 1223 |
+
|
| 1224 |
+
suppress_tokens = list(suppress_tokens)
|
| 1225 |
+
|
| 1226 |
+
# Ensure the following special tokens are suppressed when the user does
|
| 1227 |
+
# not use the default set (-1).
|
| 1228 |
+
suppress_tokens.extend(
|
| 1229 |
+
[
|
| 1230 |
+
tokenizer.transcribe,
|
| 1231 |
+
tokenizer.translate,
|
| 1232 |
+
tokenizer.sot,
|
| 1233 |
+
tokenizer.sot_prev,
|
| 1234 |
+
tokenizer.sot_lm,
|
| 1235 |
+
]
|
| 1236 |
+
)
|
| 1237 |
+
|
| 1238 |
+
return sorted(set(suppress_tokens))
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
def merge_punctuations(alignment: List[dict], prepended: str, appended: str) -> None:
|
| 1242 |
+
# merge prepended punctuations
|
| 1243 |
+
i = len(alignment) - 2
|
| 1244 |
+
j = len(alignment) - 1
|
| 1245 |
+
while i >= 0:
|
| 1246 |
+
previous = alignment[i]
|
| 1247 |
+
following = alignment[j]
|
| 1248 |
+
if previous["word"].startswith(" ") and previous["word"].strip() in prepended:
|
| 1249 |
+
# prepend it to the following word
|
| 1250 |
+
following["word"] = previous["word"] + following["word"]
|
| 1251 |
+
following["tokens"] = previous["tokens"] + following["tokens"]
|
| 1252 |
+
previous["word"] = ""
|
| 1253 |
+
previous["tokens"] = []
|
| 1254 |
+
else:
|
| 1255 |
+
j = i
|
| 1256 |
+
i -= 1
|
| 1257 |
+
|
| 1258 |
+
# merge appended punctuations
|
| 1259 |
+
i = 0
|
| 1260 |
+
j = 1
|
| 1261 |
+
while j < len(alignment):
|
| 1262 |
+
previous = alignment[i]
|
| 1263 |
+
following = alignment[j]
|
| 1264 |
+
if not previous["word"].endswith(" ") and following["word"] in appended:
|
| 1265 |
+
# append it to the previous word
|
| 1266 |
+
previous["word"] = previous["word"] + following["word"]
|
| 1267 |
+
previous["tokens"] = previous["tokens"] + following["tokens"]
|
| 1268 |
+
following["word"] = ""
|
| 1269 |
+
following["tokens"] = []
|
| 1270 |
+
else:
|
| 1271 |
+
i = j
|
| 1272 |
+
j += 1
|
utils.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
import huggingface_hub
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
|
| 12 |
+
_MODELS = {
|
| 13 |
+
"tiny.en": "Systran/faster-whisper-tiny.en",
|
| 14 |
+
"tiny": "Systran/faster-whisper-tiny",
|
| 15 |
+
"base.en": "Systran/faster-whisper-base.en",
|
| 16 |
+
"base": "Systran/faster-whisper-base",
|
| 17 |
+
"small.en": "Systran/faster-whisper-small.en",
|
| 18 |
+
"small": "Systran/faster-whisper-small",
|
| 19 |
+
"medium.en": "Systran/faster-whisper-medium.en",
|
| 20 |
+
"medium": "Systran/faster-whisper-medium",
|
| 21 |
+
"large-v1": "Systran/faster-whisper-large-v1",
|
| 22 |
+
"large-v2": "Systran/faster-whisper-large-v2",
|
| 23 |
+
"large-v3": "Systran/faster-whisper-large-v3",
|
| 24 |
+
"large": "Systran/faster-whisper-large-v3",
|
| 25 |
+
"distil-large-v2": "Systran/faster-distil-whisper-large-v2",
|
| 26 |
+
"distil-medium.en": "Systran/faster-distil-whisper-medium.en",
|
| 27 |
+
"distil-small.en": "Systran/faster-distil-whisper-small.en",
|
| 28 |
+
"distil-large-v3": "Systran/faster-distil-whisper-large-v3",
|
| 29 |
+
}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def available_models() -> List[str]:
|
| 33 |
+
"""Returns the names of available models."""
|
| 34 |
+
return list(_MODELS.keys())
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def get_assets_path():
|
| 38 |
+
"""Returns the path to the assets directory."""
|
| 39 |
+
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_logger():
|
| 43 |
+
"""Returns the module logger."""
|
| 44 |
+
return logging.getLogger("faster_whisper")
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def download_model(
|
| 48 |
+
size_or_id: str,
|
| 49 |
+
output_dir: Optional[str] = None,
|
| 50 |
+
local_files_only: bool = False,
|
| 51 |
+
cache_dir: Optional[str] = None,
|
| 52 |
+
):
|
| 53 |
+
"""Downloads a CTranslate2 Whisper model from the Hugging Face Hub.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
size_or_id: Size of the model to download from https://huggingface.co/Systran
|
| 57 |
+
(tiny, tiny.en, base, base.en, small, small.en, distil-small.en, medium, medium.en,
|
| 58 |
+
distil-medium.en, large-v1, large-v2, large-v3, large, distil-large-v2,
|
| 59 |
+
distil-large-v3), or a CTranslate2-converted model ID from the Hugging Face Hub
|
| 60 |
+
(e.g. Systran/faster-whisper-large-v3).
|
| 61 |
+
output_dir: Directory where the model should be saved. If not set, the model is saved in
|
| 62 |
+
the cache directory.
|
| 63 |
+
local_files_only: If True, avoid downloading the file and return the path to the local
|
| 64 |
+
cached file if it exists.
|
| 65 |
+
cache_dir: Path to the folder where cached files are stored.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
The path to the downloaded model.
|
| 69 |
+
|
| 70 |
+
Raises:
|
| 71 |
+
ValueError: if the model size is invalid.
|
| 72 |
+
"""
|
| 73 |
+
if re.match(r".*/.*", size_or_id):
|
| 74 |
+
repo_id = size_or_id
|
| 75 |
+
else:
|
| 76 |
+
repo_id = _MODELS.get(size_or_id)
|
| 77 |
+
if repo_id is None:
|
| 78 |
+
raise ValueError(
|
| 79 |
+
"Invalid model size '%s', expected one of: %s"
|
| 80 |
+
% (size_or_id, ", ".join(_MODELS.keys()))
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
allow_patterns = [
|
| 84 |
+
"config.json",
|
| 85 |
+
"preprocessor_config.json",
|
| 86 |
+
"model.bin",
|
| 87 |
+
"tokenizer.json",
|
| 88 |
+
"vocabulary.*",
|
| 89 |
+
]
|
| 90 |
+
|
| 91 |
+
kwargs = {
|
| 92 |
+
"local_files_only": local_files_only,
|
| 93 |
+
"allow_patterns": allow_patterns,
|
| 94 |
+
"tqdm_class": disabled_tqdm,
|
| 95 |
+
}
|
| 96 |
+
|
| 97 |
+
if output_dir is not None:
|
| 98 |
+
kwargs["local_dir"] = output_dir
|
| 99 |
+
kwargs["local_dir_use_symlinks"] = False
|
| 100 |
+
|
| 101 |
+
if cache_dir is not None:
|
| 102 |
+
kwargs["cache_dir"] = cache_dir
|
| 103 |
+
|
| 104 |
+
try:
|
| 105 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
| 106 |
+
except (
|
| 107 |
+
huggingface_hub.utils.HfHubHTTPError,
|
| 108 |
+
requests.exceptions.ConnectionError,
|
| 109 |
+
) as exception:
|
| 110 |
+
logger = get_logger()
|
| 111 |
+
logger.warning(
|
| 112 |
+
"An error occured while synchronizing the model %s from the Hugging Face Hub:\n%s",
|
| 113 |
+
repo_id,
|
| 114 |
+
exception,
|
| 115 |
+
)
|
| 116 |
+
logger.warning(
|
| 117 |
+
"Trying to load the model directly from the local cache, if it exists."
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
kwargs["local_files_only"] = True
|
| 121 |
+
return huggingface_hub.snapshot_download(repo_id, **kwargs)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def format_timestamp(
|
| 125 |
+
seconds: float,
|
| 126 |
+
always_include_hours: bool = False,
|
| 127 |
+
decimal_marker: str = ".",
|
| 128 |
+
) -> str:
|
| 129 |
+
assert seconds >= 0, "non-negative timestamp expected"
|
| 130 |
+
milliseconds = round(seconds * 1000.0)
|
| 131 |
+
|
| 132 |
+
hours = milliseconds // 3_600_000
|
| 133 |
+
milliseconds -= hours * 3_600_000
|
| 134 |
+
|
| 135 |
+
minutes = milliseconds // 60_000
|
| 136 |
+
milliseconds -= minutes * 60_000
|
| 137 |
+
|
| 138 |
+
seconds = milliseconds // 1_000
|
| 139 |
+
milliseconds -= seconds * 1_000
|
| 140 |
+
|
| 141 |
+
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
|
| 142 |
+
return (
|
| 143 |
+
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class disabled_tqdm(tqdm):
|
| 148 |
+
def __init__(self, *args, **kwargs):
|
| 149 |
+
kwargs["disable"] = True
|
| 150 |
+
super().__init__(*args, **kwargs)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_end(segments: List[dict]) -> Optional[float]:
|
| 154 |
+
return next(
|
| 155 |
+
(w["end"] for s in reversed(segments) for w in reversed(s["words"])),
|
| 156 |
+
segments[-1]["end"] if segments else None,
|
| 157 |
+
)
|
vad.py
ADDED
|
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import bisect
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
import warnings
|
| 5 |
+
|
| 6 |
+
from typing import List, NamedTuple, Optional
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from faster_whisper.utils import get_assets_path
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# The code below is adapted from https://github.com/snakers4/silero-vad.
|
| 14 |
+
class VadOptions(NamedTuple):
|
| 15 |
+
"""VAD options.
|
| 16 |
+
|
| 17 |
+
Attributes:
|
| 18 |
+
threshold: Speech threshold. Silero VAD outputs speech probabilities for each audio chunk,
|
| 19 |
+
probabilities ABOVE this value are considered as SPEECH. It is better to tune this
|
| 20 |
+
parameter for each dataset separately, but "lazy" 0.5 is pretty good for most datasets.
|
| 21 |
+
min_speech_duration_ms: Final speech chunks shorter min_speech_duration_ms are thrown out.
|
| 22 |
+
max_speech_duration_s: Maximum duration of speech chunks in seconds. Chunks longer
|
| 23 |
+
than max_speech_duration_s will be split at the timestamp of the last silence that
|
| 24 |
+
lasts more than 100ms (if any), to prevent aggressive cutting. Otherwise, they will be
|
| 25 |
+
split aggressively just before max_speech_duration_s.
|
| 26 |
+
min_silence_duration_ms: In the end of each speech chunk wait for min_silence_duration_ms
|
| 27 |
+
before separating it
|
| 28 |
+
window_size_samples: Audio chunks of window_size_samples size are fed to the silero VAD model.
|
| 29 |
+
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
| 30 |
+
Values other than these may affect model performance!!
|
| 31 |
+
speech_pad_ms: Final speech chunks are padded by speech_pad_ms each side
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
threshold: float = 0.5
|
| 35 |
+
min_speech_duration_ms: int = 250
|
| 36 |
+
max_speech_duration_s: float = float("inf")
|
| 37 |
+
min_silence_duration_ms: int = 2000
|
| 38 |
+
window_size_samples: int = 1024
|
| 39 |
+
speech_pad_ms: int = 400
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_speech_timestamps(
|
| 43 |
+
audio: np.ndarray,
|
| 44 |
+
vad_options: Optional[VadOptions] = None,
|
| 45 |
+
**kwargs,
|
| 46 |
+
) -> List[dict]:
|
| 47 |
+
"""This method is used for splitting long audios into speech chunks using silero VAD.
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
audio: One dimensional float array.
|
| 51 |
+
vad_options: Options for VAD processing.
|
| 52 |
+
kwargs: VAD options passed as keyword arguments for backward compatibility.
|
| 53 |
+
|
| 54 |
+
Returns:
|
| 55 |
+
List of dicts containing begin and end samples of each speech chunk.
|
| 56 |
+
"""
|
| 57 |
+
if vad_options is None:
|
| 58 |
+
vad_options = VadOptions(**kwargs)
|
| 59 |
+
|
| 60 |
+
threshold = vad_options.threshold
|
| 61 |
+
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
| 62 |
+
max_speech_duration_s = vad_options.max_speech_duration_s
|
| 63 |
+
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
| 64 |
+
window_size_samples = vad_options.window_size_samples
|
| 65 |
+
speech_pad_ms = vad_options.speech_pad_ms
|
| 66 |
+
|
| 67 |
+
if window_size_samples not in [512, 1024, 1536]:
|
| 68 |
+
warnings.warn(
|
| 69 |
+
"Unusual window_size_samples! Supported window_size_samples:\n"
|
| 70 |
+
" - [512, 1024, 1536] for 16000 sampling_rate"
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
sampling_rate = 16000
|
| 74 |
+
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
| 75 |
+
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
| 76 |
+
max_speech_samples = (
|
| 77 |
+
sampling_rate * max_speech_duration_s
|
| 78 |
+
- window_size_samples
|
| 79 |
+
- 2 * speech_pad_samples
|
| 80 |
+
)
|
| 81 |
+
min_silence_samples = sampling_rate * min_silence_duration_ms / 1000
|
| 82 |
+
min_silence_samples_at_max_speech = sampling_rate * 98 / 1000
|
| 83 |
+
|
| 84 |
+
audio_length_samples = len(audio)
|
| 85 |
+
|
| 86 |
+
model = get_vad_model()
|
| 87 |
+
state = model.get_initial_state(batch_size=1)
|
| 88 |
+
|
| 89 |
+
speech_probs = []
|
| 90 |
+
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
| 91 |
+
chunk = audio[current_start_sample : current_start_sample + window_size_samples]
|
| 92 |
+
if len(chunk) < window_size_samples:
|
| 93 |
+
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
| 94 |
+
speech_prob, state = model(chunk, state, sampling_rate)
|
| 95 |
+
speech_probs.append(speech_prob)
|
| 96 |
+
|
| 97 |
+
triggered = False
|
| 98 |
+
speeches = []
|
| 99 |
+
current_speech = {}
|
| 100 |
+
neg_threshold = threshold - 0.15
|
| 101 |
+
|
| 102 |
+
# to save potential segment end (and tolerate some silence)
|
| 103 |
+
temp_end = 0
|
| 104 |
+
# to save potential segment limits in case of maximum segment size reached
|
| 105 |
+
prev_end = next_start = 0
|
| 106 |
+
|
| 107 |
+
for i, speech_prob in enumerate(speech_probs):
|
| 108 |
+
if (speech_prob >= threshold) and temp_end:
|
| 109 |
+
temp_end = 0
|
| 110 |
+
if next_start < prev_end:
|
| 111 |
+
next_start = window_size_samples * i
|
| 112 |
+
|
| 113 |
+
if (speech_prob >= threshold) and not triggered:
|
| 114 |
+
triggered = True
|
| 115 |
+
current_speech["start"] = window_size_samples * i
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
if (
|
| 119 |
+
triggered
|
| 120 |
+
and (window_size_samples * i) - current_speech["start"] > max_speech_samples
|
| 121 |
+
):
|
| 122 |
+
if prev_end:
|
| 123 |
+
current_speech["end"] = prev_end
|
| 124 |
+
speeches.append(current_speech)
|
| 125 |
+
current_speech = {}
|
| 126 |
+
# previously reached silence (< neg_thres) and is still not speech (< thres)
|
| 127 |
+
if next_start < prev_end:
|
| 128 |
+
triggered = False
|
| 129 |
+
else:
|
| 130 |
+
current_speech["start"] = next_start
|
| 131 |
+
prev_end = next_start = temp_end = 0
|
| 132 |
+
else:
|
| 133 |
+
current_speech["end"] = window_size_samples * i
|
| 134 |
+
speeches.append(current_speech)
|
| 135 |
+
current_speech = {}
|
| 136 |
+
prev_end = next_start = temp_end = 0
|
| 137 |
+
triggered = False
|
| 138 |
+
continue
|
| 139 |
+
|
| 140 |
+
if (speech_prob < neg_threshold) and triggered:
|
| 141 |
+
if not temp_end:
|
| 142 |
+
temp_end = window_size_samples * i
|
| 143 |
+
# condition to avoid cutting in very short silence
|
| 144 |
+
if (window_size_samples * i) - temp_end > min_silence_samples_at_max_speech:
|
| 145 |
+
prev_end = temp_end
|
| 146 |
+
if (window_size_samples * i) - temp_end < min_silence_samples:
|
| 147 |
+
continue
|
| 148 |
+
else:
|
| 149 |
+
current_speech["end"] = temp_end
|
| 150 |
+
if (
|
| 151 |
+
current_speech["end"] - current_speech["start"]
|
| 152 |
+
) > min_speech_samples:
|
| 153 |
+
speeches.append(current_speech)
|
| 154 |
+
current_speech = {}
|
| 155 |
+
prev_end = next_start = temp_end = 0
|
| 156 |
+
triggered = False
|
| 157 |
+
continue
|
| 158 |
+
|
| 159 |
+
if (
|
| 160 |
+
current_speech
|
| 161 |
+
and (audio_length_samples - current_speech["start"]) > min_speech_samples
|
| 162 |
+
):
|
| 163 |
+
current_speech["end"] = audio_length_samples
|
| 164 |
+
speeches.append(current_speech)
|
| 165 |
+
|
| 166 |
+
for i, speech in enumerate(speeches):
|
| 167 |
+
if i == 0:
|
| 168 |
+
speech["start"] = int(max(0, speech["start"] - speech_pad_samples))
|
| 169 |
+
if i != len(speeches) - 1:
|
| 170 |
+
silence_duration = speeches[i + 1]["start"] - speech["end"]
|
| 171 |
+
if silence_duration < 2 * speech_pad_samples:
|
| 172 |
+
speech["end"] += int(silence_duration // 2)
|
| 173 |
+
speeches[i + 1]["start"] = int(
|
| 174 |
+
max(0, speeches[i + 1]["start"] - silence_duration // 2)
|
| 175 |
+
)
|
| 176 |
+
else:
|
| 177 |
+
speech["end"] = int(
|
| 178 |
+
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
| 179 |
+
)
|
| 180 |
+
speeches[i + 1]["start"] = int(
|
| 181 |
+
max(0, speeches[i + 1]["start"] - speech_pad_samples)
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
speech["end"] = int(
|
| 185 |
+
min(audio_length_samples, speech["end"] + speech_pad_samples)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
return speeches
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
| 192 |
+
"""Collects and concatenates audio chunks."""
|
| 193 |
+
if not chunks:
|
| 194 |
+
return np.array([], dtype=np.float32)
|
| 195 |
+
|
| 196 |
+
return np.concatenate([audio[chunk["start"] : chunk["end"]] for chunk in chunks])
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class SpeechTimestampsMap:
|
| 200 |
+
"""Helper class to restore original speech timestamps."""
|
| 201 |
+
|
| 202 |
+
def __init__(self, chunks: List[dict], sampling_rate: int, time_precision: int = 2):
|
| 203 |
+
self.sampling_rate = sampling_rate
|
| 204 |
+
self.time_precision = time_precision
|
| 205 |
+
self.chunk_end_sample = []
|
| 206 |
+
self.total_silence_before = []
|
| 207 |
+
|
| 208 |
+
previous_end = 0
|
| 209 |
+
silent_samples = 0
|
| 210 |
+
|
| 211 |
+
for chunk in chunks:
|
| 212 |
+
silent_samples += chunk["start"] - previous_end
|
| 213 |
+
previous_end = chunk["end"]
|
| 214 |
+
|
| 215 |
+
self.chunk_end_sample.append(chunk["end"] - silent_samples)
|
| 216 |
+
self.total_silence_before.append(silent_samples / sampling_rate)
|
| 217 |
+
|
| 218 |
+
def get_original_time(
|
| 219 |
+
self,
|
| 220 |
+
time: float,
|
| 221 |
+
chunk_index: Optional[int] = None,
|
| 222 |
+
) -> float:
|
| 223 |
+
if chunk_index is None:
|
| 224 |
+
chunk_index = self.get_chunk_index(time)
|
| 225 |
+
|
| 226 |
+
total_silence_before = self.total_silence_before[chunk_index]
|
| 227 |
+
return round(total_silence_before + time, self.time_precision)
|
| 228 |
+
|
| 229 |
+
def get_chunk_index(self, time: float) -> int:
|
| 230 |
+
sample = int(time * self.sampling_rate)
|
| 231 |
+
return min(
|
| 232 |
+
bisect.bisect(self.chunk_end_sample, sample),
|
| 233 |
+
len(self.chunk_end_sample) - 1,
|
| 234 |
+
)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@functools.lru_cache
|
| 238 |
+
def get_vad_model():
|
| 239 |
+
"""Returns the VAD model instance."""
|
| 240 |
+
path = os.path.join(get_assets_path(), "silero_vad.onnx")
|
| 241 |
+
return SileroVADModel(path)
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class SileroVADModel:
|
| 245 |
+
def __init__(self, path):
|
| 246 |
+
try:
|
| 247 |
+
import onnxruntime
|
| 248 |
+
except ImportError as e:
|
| 249 |
+
raise RuntimeError(
|
| 250 |
+
"Applying the VAD filter requires the onnxruntime package"
|
| 251 |
+
) from e
|
| 252 |
+
|
| 253 |
+
opts = onnxruntime.SessionOptions()
|
| 254 |
+
opts.inter_op_num_threads = 1
|
| 255 |
+
opts.intra_op_num_threads = 1
|
| 256 |
+
opts.log_severity_level = 4
|
| 257 |
+
|
| 258 |
+
self.session = onnxruntime.InferenceSession(
|
| 259 |
+
path,
|
| 260 |
+
providers=["CPUExecutionProvider"],
|
| 261 |
+
sess_options=opts,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
def get_initial_state(self, batch_size: int):
|
| 265 |
+
h = np.zeros((2, batch_size, 64), dtype=np.float32)
|
| 266 |
+
c = np.zeros((2, batch_size, 64), dtype=np.float32)
|
| 267 |
+
return h, c
|
| 268 |
+
|
| 269 |
+
def __call__(self, x, state, sr: int):
|
| 270 |
+
if len(x.shape) == 1:
|
| 271 |
+
x = np.expand_dims(x, 0)
|
| 272 |
+
if len(x.shape) > 2:
|
| 273 |
+
raise ValueError(
|
| 274 |
+
f"Too many dimensions for input audio chunk {len(x.shape)}"
|
| 275 |
+
)
|
| 276 |
+
if sr / x.shape[1] > 31.25:
|
| 277 |
+
raise ValueError("Input audio chunk is too short")
|
| 278 |
+
|
| 279 |
+
h, c = state
|
| 280 |
+
|
| 281 |
+
ort_inputs = {
|
| 282 |
+
"input": x,
|
| 283 |
+
"h": h,
|
| 284 |
+
"c": c,
|
| 285 |
+
"sr": np.array(sr, dtype="int64"),
|
| 286 |
+
}
|
| 287 |
+
|
| 288 |
+
out, h, c = self.session.run(None, ort_inputs)
|
| 289 |
+
state = (h, c)
|
| 290 |
+
|
| 291 |
+
return out, state
|
version.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Version information."""
|
| 2 |
+
|
| 3 |
+
__version__ = "1.0.2"
|