zhu-han committed on
Commit
9e4e0d2
·
1 Parent(s): 32ffa33

Update to version 0.1.4

Browse files
omnivoice/cli/demo.py CHANGED
@@ -213,8 +213,7 @@ def build_demo(
213
  except Exception as e:
214
  return None, f"Error: {type(e).__name__}: {e}"
215
 
216
- waveform = audio[0].squeeze(0).numpy() # (T,)
217
- waveform = (waveform * 32767).astype(np.int16)
218
  return (sampling_rate, waveform), "Done."
219
 
220
  # Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
 
213
  except Exception as e:
214
  return None, f"Error: {type(e).__name__}: {e}"
215
 
216
+ waveform = (audio[0] * 32767).astype(np.int16)
 
217
  return (sampling_rate, waveform), "Done."
218
 
219
  # Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
omnivoice/cli/infer.py CHANGED
@@ -23,7 +23,8 @@ import argparse
23
  import logging
24
 
25
  import torch
26
- import torchaudio
 
27
 
28
  from omnivoice.models.omnivoice import OmniVoice
29
  from omnivoice.utils.common import str2bool
@@ -149,7 +150,7 @@ def main():
149
  class_temperature=args.class_temperature,
150
  )
151
 
152
- torchaudio.save(args.output, audios[0], model.sampling_rate)
153
  logging.info(f"Saved to {args.output}")
154
 
155
 
 
23
  import logging
24
 
25
  import torch
26
+
27
+ import soundfile as sf
28
 
29
  from omnivoice.models.omnivoice import OmniVoice
30
  from omnivoice.utils.common import str2bool
 
150
  class_temperature=args.class_temperature,
151
  )
152
 
153
+ sf.write(args.output, audios[0], model.sampling_rate)
154
  logging.info(f"Saved to {args.output}")
155
 
156
 
omnivoice/cli/infer_batch.py CHANGED
@@ -42,10 +42,11 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
42
  from typing import List, Optional, Tuple
43
 
44
  import torch
45
- import torchaudio
46
  from tqdm import tqdm
47
 
48
  from omnivoice.models.omnivoice import OmniVoice
 
 
49
  from omnivoice.utils.audio import load_audio
50
  from omnivoice.utils.common import str2bool
51
  from omnivoice.utils.data_utils import read_test_list
@@ -79,11 +80,17 @@ def get_parser():
79
  type=str,
80
  required=True,
81
  help="Path to the JSONL file containing test samples. "
82
- 'Each line is a JSON object: {"id": "name", "text": "...", '
83
- '"ref_audio": "/path.wav", "ref_text": "...", '
84
- '"language_id": "en", "language_name": "English", '
85
- '"duration": 10.0, "speed": 1.2}. '
86
- "language_id, language_name, duration, and speed are optional.",
 
 
 
 
 
 
87
  )
88
  parser.add_argument(
89
  "--res_dir",
@@ -135,8 +142,7 @@ def get_parser():
135
  "--batch_duration",
136
  type=float,
137
  default=1000.0,
138
- help="Maximum total duration (reference + generated) per batch (seconds). "
139
- "Only effective for parallel_chunk / no chunk mode.",
140
  )
141
  parser.add_argument(
142
  "--batch_size",
@@ -239,7 +245,7 @@ def process_init(rank_queue, model_checkpoint, warmup=0):
239
  dummy_ref_audio = (
240
  torch.randn(1, SAMPLING_RATE),
241
  SAMPLING_RATE,
242
- ) # 1s silence
243
  for i in range(warmup):
244
  worker_model.generate(
245
  text=["hello"],
@@ -255,40 +261,58 @@ def process_init(rank_queue, model_checkpoint, warmup=0):
255
  def estimate_sample_total_duration(
256
  duration_estimator: RuleDurationEstimator,
257
  text: str,
258
- ref_text: str,
259
- ref_audio_path: str,
260
  gen_duration: Optional[float] = None,
261
  ) -> float:
262
- ref_wav = load_audio(ref_audio_path, SAMPLING_RATE)
263
- ref_duration = ref_wav.shape[-1] / SAMPLING_RATE
 
 
 
 
 
 
 
 
 
264
 
265
  if gen_duration is None:
266
- gen_duration = duration_estimator.estimate_duration(
267
- text, ref_text, ref_duration, low_threshold=2.0
268
- )
 
 
 
 
 
269
 
270
  total_duration = ref_duration + gen_duration
271
  return total_duration
272
 
273
 
274
- def cluster_samples_by_duration(
275
  samples: List[Tuple],
276
  duration_estimator: RuleDurationEstimator,
277
- batch_duration: float,
278
- ) -> List[List[Tuple]]:
279
  sample_with_duration = []
280
  for sample in samples:
281
- save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
282
  total_duration = estimate_sample_total_duration(
283
- duration_estimator,
284
- text,
285
- ref_text,
286
- ref_audio_path,
287
- gen_duration=dur,
288
  )
289
  sample_with_duration.append((sample, total_duration))
290
-
291
  sample_with_duration.sort(key=lambda x: x[1], reverse=True)
 
 
 
 
 
 
 
 
 
292
  batches = []
293
  current_batch = []
294
  current_total_duration = 0.0
@@ -319,19 +343,7 @@ def cluster_samples_by_batch_size(
319
  batch_size: int,
320
  ) -> List[List[Tuple]]:
321
  """Split samples into fixed-size batches, sorted by duration to minimize padding."""
322
- sample_with_duration = []
323
- for sample in samples:
324
- save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
325
- total_duration = estimate_sample_total_duration(
326
- duration_estimator,
327
- text,
328
- ref_text,
329
- ref_audio_path,
330
- gen_duration=dur,
331
- )
332
- sample_with_duration.append((sample, total_duration))
333
-
334
- sample_with_duration.sort(key=lambda x: x[1], reverse=True)
335
  sorted_samples = [s for s, _ in sample_with_duration]
336
 
337
  batches = [
@@ -359,9 +371,10 @@ def run_inference_batch(
359
  langs = []
360
  durations = []
361
  speeds = []
 
362
 
363
  for sample in batch_samples:
364
- save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
365
  save_names.append(save_name)
366
  ref_texts.append(ref_text)
367
  ref_audio_paths.append(ref_audio_path)
@@ -369,15 +382,17 @@ def run_inference_batch(
369
  langs.append(lang_id)
370
  durations.append(dur)
371
  speeds.append(spd)
 
372
 
373
  start_time = time.time()
374
  audios = worker_model.generate(
375
  text=texts,
376
  language=langs,
377
- ref_audio=ref_audio_paths,
378
- ref_text=ref_texts,
379
  duration=durations if any(d is not None for d in durations) else None,
380
  speed=speeds if any(s is not None for s in speeds) else None,
 
381
  **gen_kwargs,
382
  )
383
  batch_synth_time = time.time() - start_time
@@ -385,7 +400,7 @@ def run_inference_batch(
385
  results = []
386
  for save_name, audio in zip(save_names, audios):
387
  save_path = os.path.join(res_dir, save_name + ".wav")
388
- torchaudio.save(save_path, audio, worker_model.sampling_rate)
389
  audio_duration = audio.shape[-1] / worker_model.sampling_rate
390
  results.append(
391
  (
@@ -436,13 +451,14 @@ def main():
436
  samples.append(
437
  (
438
  s["id"],
439
- s["ref_text"],
440
- s["ref_audio"],
441
  s["text"],
442
  lang_id,
443
  lang_name,
444
  s.get("duration"),
445
  s.get("speed"),
 
446
  )
447
  )
448
 
@@ -457,18 +473,32 @@ def main():
457
  ) as executor:
458
  futures = []
459
 
460
- # parallel_chunk / no chunk
461
  logging.info("Running batch inference")
462
 
 
 
 
 
 
 
 
463
  duration_estimator = RuleDurationEstimator()
464
- if args.batch_size > 0:
465
- batches = cluster_samples_by_batch_size(
466
- samples, duration_estimator, args.batch_size
467
- )
468
- else:
469
- batches = cluster_samples_by_duration(
470
- samples, duration_estimator, args.batch_duration
471
- )
 
 
 
 
 
 
 
 
472
 
473
  args_dict = vars(args)
474
 
 
42
  from typing import List, Optional, Tuple
43
 
44
  import torch
 
45
  from tqdm import tqdm
46
 
47
  from omnivoice.models.omnivoice import OmniVoice
48
+ import soundfile as sf
49
+
50
  from omnivoice.utils.audio import load_audio
51
  from omnivoice.utils.common import str2bool
52
  from omnivoice.utils.data_utils import read_test_list
 
80
  type=str,
81
  required=True,
82
  help="Path to the JSONL file containing test samples. "
83
+ "Each line is a JSON object with the following fields: "
84
+ '"id" (str, required): unique name for the output file; '
85
+ '"text" (str, required): text to synthesize; '
86
+ '"ref_audio" (str): path to reference audio for voice cloning; '
87
+ '"ref_text" (str): transcript of the reference audio; '
88
+ '"instruct" (str): instruction for voice design (used when ref_audio is absent); '
89
+ '"language_id" (str): language code, e.g. "en"; '
90
+ '"language_name" (str): language name, e.g. "English"; '
91
+ '"duration" (float): target duration in seconds; '
92
+ '"speed" (float): speaking speed multiplier. '
93
+ "Only id and text are required; all other fields are optional.",
94
  )
95
  parser.add_argument(
96
  "--res_dir",
 
142
  "--batch_duration",
143
  type=float,
144
  default=1000.0,
145
+ help="Maximum total duration (reference + generated) per batch (seconds).",
 
146
  )
147
  parser.add_argument(
148
  "--batch_size",
 
245
  dummy_ref_audio = (
246
  torch.randn(1, SAMPLING_RATE),
247
  SAMPLING_RATE,
248
+ ) # 1s dummy audio
249
  for i in range(warmup):
250
  worker_model.generate(
251
  text=["hello"],
 
261
  def estimate_sample_total_duration(
262
  duration_estimator: RuleDurationEstimator,
263
  text: str,
264
+ ref_text: Optional[str],
265
+ ref_audio_path: Optional[str],
266
  gen_duration: Optional[float] = None,
267
  ) -> float:
268
+ """Estimate total duration (ref + generated) for a single sample.
269
+
270
+ When ``ref_audio_path`` is ``None`` (instruct / voice-design mode),
271
+ the reference duration is treated as 0 and only the estimated generated
272
+ duration contributes to the total.
273
+ """
274
+ if ref_audio_path is not None:
275
+ ref_wav = load_audio(ref_audio_path, SAMPLING_RATE)
276
+ ref_duration = ref_wav.shape[-1] / SAMPLING_RATE
277
+ else:
278
+ ref_duration = 0
279
 
280
  if gen_duration is None:
281
+ if ref_audio_path is not None:
282
+ gen_duration = duration_estimator.estimate_duration(
283
+ text, ref_text or "", ref_duration, low_threshold=2.0
284
+ )
285
+ else:
286
+ gen_duration = duration_estimator.estimate_duration(
287
+ text, "Nice to meet you.", 0.5, low_threshold=2.0
288
+ )
289
 
290
  total_duration = ref_duration + gen_duration
291
  return total_duration
292
 
293
 
294
+ def _sort_samples_by_duration(
295
  samples: List[Tuple],
296
  duration_estimator: RuleDurationEstimator,
297
+ ) -> List[Tuple[Tuple, float]]:
298
+ """Return (sample, total_duration) pairs sorted by duration descending."""
299
  sample_with_duration = []
300
  for sample in samples:
301
+ _, ref_text, ref_audio_path, text, _, _, dur, _, _ = sample
302
  total_duration = estimate_sample_total_duration(
303
+ duration_estimator, text, ref_text, ref_audio_path, gen_duration=dur
 
 
 
 
304
  )
305
  sample_with_duration.append((sample, total_duration))
 
306
  sample_with_duration.sort(key=lambda x: x[1], reverse=True)
307
+ return sample_with_duration
308
+
309
+
310
+ def cluster_samples_by_duration(
311
+ samples: List[Tuple],
312
+ duration_estimator: RuleDurationEstimator,
313
+ batch_duration: float,
314
+ ) -> List[List[Tuple]]:
315
+ sample_with_duration = _sort_samples_by_duration(samples, duration_estimator)
316
  batches = []
317
  current_batch = []
318
  current_total_duration = 0.0
 
343
  batch_size: int,
344
  ) -> List[List[Tuple]]:
345
  """Split samples into fixed-size batches, sorted by duration to minimize padding."""
346
+ sample_with_duration = _sort_samples_by_duration(samples, duration_estimator)
 
 
 
 
 
 
 
 
 
 
 
 
347
  sorted_samples = [s for s, _ in sample_with_duration]
348
 
349
  batches = [
 
371
  langs = []
372
  durations = []
373
  speeds = []
374
+ instructs = []
375
 
376
  for sample in batch_samples:
377
+ save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd, instruct = sample
378
  save_names.append(save_name)
379
  ref_texts.append(ref_text)
380
  ref_audio_paths.append(ref_audio_path)
 
382
  langs.append(lang_id)
383
  durations.append(dur)
384
  speeds.append(spd)
385
+ instructs.append(instruct)
386
 
387
  start_time = time.time()
388
  audios = worker_model.generate(
389
  text=texts,
390
  language=langs,
391
+ ref_audio=ref_audio_paths if any(p is not None for p in ref_audio_paths) else None,
392
+ ref_text=ref_texts if any(t is not None for t in ref_texts) else None,
393
  duration=durations if any(d is not None for d in durations) else None,
394
  speed=speeds if any(s is not None for s in speeds) else None,
395
+ instruct=instructs if any(i is not None for i in instructs) else None,
396
  **gen_kwargs,
397
  )
398
  batch_synth_time = time.time() - start_time
 
400
  results = []
401
  for save_name, audio in zip(save_names, audios):
402
  save_path = os.path.join(res_dir, save_name + ".wav")
403
+ sf.write(save_path, audio, worker_model.sampling_rate)
404
  audio_duration = audio.shape[-1] / worker_model.sampling_rate
405
  results.append(
406
  (
 
451
  samples.append(
452
  (
453
  s["id"],
454
+ s.get("ref_text"),
455
+ s.get("ref_audio"),
456
  s["text"],
457
  lang_id,
458
  lang_name,
459
  s.get("duration"),
460
  s.get("speed"),
461
+ s.get("instruct"),
462
  )
463
  )
464
 
 
473
  ) as executor:
474
  futures = []
475
 
 
476
  logging.info("Running batch inference")
477
 
478
+ # Split samples by mode (voice-clone vs non-voice-clone) before
479
+ # clustering so that each batch is homogeneous. Mixing ref_audio
480
+ # and non-ref_audio samples in the same batch would crash in
481
+ # generate() → create_voice_clone_prompt().
482
+ clone_samples = [s for s in samples if s[2] is not None]
483
+ other_samples = [s for s in samples if s[2] is None]
484
+
485
  duration_estimator = RuleDurationEstimator()
486
+ batches = []
487
+ for subset in (clone_samples, other_samples):
488
+ if not subset:
489
+ continue
490
+ if args.batch_size > 0:
491
+ batches.extend(
492
+ cluster_samples_by_batch_size(
493
+ subset, duration_estimator, args.batch_size
494
+ )
495
+ )
496
+ else:
497
+ batches.extend(
498
+ cluster_samples_by_duration(
499
+ subset, duration_estimator, args.batch_duration
500
+ )
501
+ )
502
 
503
  args_dict = vars(args)
504
 
omnivoice/data/dataset.py CHANGED
@@ -44,8 +44,9 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple
44
 
45
  import torch
46
  import torch.distributed as dist
47
- import torchaudio
48
  import webdataset as wds
 
 
49
  from torch.utils.data import IterableDataset
50
 
51
 
@@ -54,12 +55,8 @@ def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"):
54
  Load audio from bytes data and resample to the target sample rate if needed.
55
  Return a tensor of shape (1, num_samples)
56
  """
57
- audio, sr = torchaudio.load(io.BytesIO(data))
58
  audio = audio.to(device)
59
- if audio.size(dim=0) > 1:
60
- audio = torch.mean(audio, dim=0)
61
- if sr != sample_rate:
62
- audio = torchaudio.functional.resample(audio, sr, sample_rate)
63
  return audio
64
 
65
 
@@ -433,13 +430,9 @@ class JsonlDatasetReader(IterableDataReader):
433
  )
434
  continue
435
  try:
436
- waveform, sr = torchaudio.load(audio_path)
437
- if waveform.shape[0] > 1:
438
- waveform = waveform.mean(dim=0, keepdim=True)
439
- if sr != self.sample_rate:
440
- waveform = torchaudio.functional.resample(
441
- waveform, sr, self.sample_rate
442
- )
443
  if self.normalize_audio:
444
  waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
445
  meta["audio_duration"] = waveform.shape[1] / self.sample_rate
 
44
 
45
  import torch
46
  import torch.distributed as dist
 
47
  import webdataset as wds
48
+
49
+ from omnivoice.utils.audio import load_audio, load_audio_bytes
50
  from torch.utils.data import IterableDataset
51
 
52
 
 
55
  Load audio from bytes data and resample to the target sample rate if needed.
56
  Return a tensor of shape (1, num_samples)
57
  """
58
+ audio = torch.from_numpy(load_audio_bytes(data, sample_rate))
59
  audio = audio.to(device)
 
 
 
 
60
  return audio
61
 
62
 
 
430
  )
431
  continue
432
  try:
433
+ waveform = torch.from_numpy(
434
+ load_audio(audio_path, self.sample_rate)
435
+ )
 
 
 
 
436
  if self.normalize_audio:
437
  waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
438
  meta["audio_duration"] = waveform.shape[1] / self.sample_rate
omnivoice/eval/mos/utmos.py CHANGED
@@ -32,7 +32,7 @@ import torch
32
  from tqdm import tqdm
33
 
34
  from omnivoice.eval.models.utmos import UTMOS22Strong
35
- from omnivoice.eval.utils import load_waveform
36
  from omnivoice.utils.data_utils import read_test_list
37
 
38
  warnings.filterwarnings("ignore")
@@ -140,7 +140,7 @@ def run_utmos_worker(file_idx, wav_path, language_name):
140
  return file_idx, wav_path, language_name, f"File not found: {wav_path}", "error"
141
 
142
  # Load and preprocess waveform
143
- speech = load_waveform(wav_path, worker_sr, device=worker_device)
144
 
145
  # Compute score
146
  # UTMOS expects input shape (Batch, Time)
 
32
  from tqdm import tqdm
33
 
34
  from omnivoice.eval.models.utmos import UTMOS22Strong
35
+ from omnivoice.eval.utils import load_eval_waveform
36
  from omnivoice.utils.data_utils import read_test_list
37
 
38
  warnings.filterwarnings("ignore")
 
140
  return file_idx, wav_path, language_name, f"File not found: {wav_path}", "error"
141
 
142
  # Load and preprocess waveform
143
+ speech = load_eval_waveform(wav_path, worker_sr, device=worker_device)
144
 
145
  # Compute score
146
  # UTMOS expects input shape (Batch, Time)
omnivoice/eval/speaker_similarity/sim.py CHANGED
@@ -33,7 +33,7 @@ import torch
33
  from tqdm import tqdm
34
 
35
  from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
36
- from omnivoice.eval.utils import load_waveform
37
  from omnivoice.utils.data_utils import read_test_list
38
 
39
  warnings.filterwarnings("ignore")
@@ -144,7 +144,7 @@ def worker_init(
144
  @torch.no_grad()
145
  def get_embedding(wav_path: str) -> torch.Tensor:
146
  """Extract embedding for a single file."""
147
- speech = load_waveform(wav_path, worker_sr, device=worker_device, max_seconds=120)
148
  return worker_model([speech])
149
 
150
 
 
33
  from tqdm import tqdm
34
 
35
  from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
36
+ from omnivoice.eval.utils import load_eval_waveform
37
  from omnivoice.utils.data_utils import read_test_list
38
 
39
  warnings.filterwarnings("ignore")
 
144
  @torch.no_grad()
145
  def get_embedding(wav_path: str) -> torch.Tensor:
146
  """Extract embedding for a single file."""
147
+ speech = load_eval_waveform(wav_path, worker_sr, device=worker_device, max_seconds=120)
148
  return worker_model([speech])
149
 
150
 
omnivoice/eval/utils.py CHANGED
@@ -23,7 +23,7 @@ import soundfile as sf
23
  import torch
24
 
25
 
26
- def load_waveform(
27
  fname: str,
28
  sample_rate: int,
29
  dtype: str = "float32",
 
23
  import torch
24
 
25
 
26
+ def load_eval_waveform(
27
  fname: str,
28
  sample_rate: int,
29
  dtype: str = "float32",
omnivoice/eval/wer/hubert.py CHANGED
@@ -31,7 +31,7 @@ import numpy as np
31
  import torch
32
  from tqdm import tqdm
33
 
34
- from omnivoice.eval.utils import load_waveform
35
  from omnivoice.eval.wer.common import process_one
36
  from omnivoice.utils.data_utils import read_test_list
37
 
@@ -166,7 +166,7 @@ def run_eval_worker(data_chunk, batch_size):
166
  try:
167
  dataset = [
168
  {
169
- "array": load_waveform(
170
  item["wav_path"], sample_rate=16000, return_numpy=True
171
  ),
172
  "sampling_rate": 16000,
 
31
  import torch
32
  from tqdm import tqdm
33
 
34
+ from omnivoice.eval.utils import load_eval_waveform
35
  from omnivoice.eval.wer.common import process_one
36
  from omnivoice.utils.data_utils import read_test_list
37
 
 
166
  try:
167
  dataset = [
168
  {
169
+ "array": load_eval_waveform(
170
  item["wav_path"], sample_rate=16000, return_numpy=True
171
  ),
172
  "sampling_rate": 16000,
omnivoice/eval/wer/minimax.py CHANGED
@@ -34,7 +34,7 @@ import torch
34
  import zhconv
35
  from tqdm import tqdm
36
 
37
- from omnivoice.eval.utils import load_waveform
38
  from omnivoice.eval.wer.common import log_metrics, process_one
39
  from omnivoice.eval.wer.text_norm_omni import text_normalize
40
  from omnivoice.utils.data_utils import read_test_list
@@ -275,7 +275,7 @@ class SpeechEvalDataset(torch.utils.data.Dataset):
275
 
276
  def __getitem__(self, index):
277
  item = self.data_list[index]
278
- waveform = load_waveform(item["wav_path"], sample_rate=16000, return_numpy=True)
279
  return {
280
  "array": waveform,
281
  "sampling_rate": 16000,
 
34
  import zhconv
35
  from tqdm import tqdm
36
 
37
+ from omnivoice.eval.utils import load_eval_waveform
38
  from omnivoice.eval.wer.common import log_metrics, process_one
39
  from omnivoice.eval.wer.text_norm_omni import text_normalize
40
  from omnivoice.utils.data_utils import read_test_list
 
275
 
276
  def __getitem__(self, index):
277
  item = self.data_list[index]
278
+ waveform = load_eval_waveform(item["wav_path"], sample_rate=16000, return_numpy=True)
279
  return {
280
  "array": waveform,
281
  "sampling_rate": 16000,
omnivoice/eval/wer/seedtts.py CHANGED
@@ -34,7 +34,7 @@ import zhconv
34
  from tqdm import tqdm
35
  from zhon.hanzi import punctuation
36
 
37
- from omnivoice.eval.utils import load_waveform
38
  from omnivoice.eval.wer.common import process_one
39
  from omnivoice.utils.data_utils import read_test_list
40
 
@@ -228,7 +228,7 @@ def run_eval_worker(data_chunk, lang, batch_size):
228
  # Load waveforms as arrays, truncating to 30s
229
  dataset = [
230
  {
231
- "array": load_waveform(
232
  item["wav_path"], sample_rate=16000, return_numpy=True
233
  )[: 16000 * 30],
234
  "sampling_rate": 16000,
 
34
  from tqdm import tqdm
35
  from zhon.hanzi import punctuation
36
 
37
+ from omnivoice.eval.utils import load_eval_waveform
38
  from omnivoice.eval.wer.common import process_one
39
  from omnivoice.utils.data_utils import read_test_list
40
 
 
228
  # Load waveforms as arrays, truncating to 30s
229
  dataset = [
230
  {
231
+ "array": load_eval_waveform(
232
  item["wav_path"], sample_rate=16000, return_numpy=True
233
  )[: 16000 * 30],
234
  "sampling_rate": 16000,
omnivoice/models/omnivoice.py CHANGED
@@ -36,10 +36,11 @@ from dataclasses import dataclass, fields
36
  from functools import partial
37
  from typing import Any, List, Optional, Union
38
 
 
 
39
  import torch
40
  import torch.nn as nn
41
  import torch.nn.functional as F
42
- import torchaudio
43
  from torch.nn.attention.flex_attention import create_block_mask
44
  from transformers import (
45
  AutoFeatureExtractor,
@@ -310,12 +311,14 @@ class OmniVoice(PreTrainedModel):
310
  @torch.inference_mode()
311
  def transcribe(
312
  self,
313
- audio: Union[str, tuple[torch.Tensor, int]],
314
  ) -> str:
315
  """Transcribe audio using the loaded Whisper ASR model.
316
 
317
  Args:
318
- audio: File path or (waveform, sample_rate) tuple.
 
 
319
 
320
  Returns:
321
  Transcribed text.
@@ -329,12 +332,11 @@ class OmniVoice(PreTrainedModel):
329
  return self._asr_pipe(audio)["text"].strip()
330
  else:
331
  waveform, sr = audio
332
- if waveform.dim() == 1:
333
- waveform = waveform.unsqueeze(0)
334
- if waveform.size(0) > 1:
335
- waveform = torch.mean(waveform, dim=0, keepdim=True)
336
  audio_input = {
337
- "array": waveform.squeeze(0).cpu().numpy(),
338
  "sampling_rate": sr,
339
  }
340
  return self._asr_pipe(audio_input)["text"].strip()
@@ -475,7 +477,7 @@ class OmniVoice(PreTrainedModel):
475
  speed: Union[float, list[Optional[float]], None] = None,
476
  generation_config: Optional[OmniVoiceGenerationConfig] = None,
477
  **kwargs,
478
- ) -> list[torch.Tensor]:
479
  """Generate speech audio given text in various modes.
480
 
481
  Supports three modes:
@@ -522,8 +524,10 @@ class OmniVoice(PreTrainedModel):
522
  audio_chunk_threshold: Only apply chunking if estimated audio
523
  duration exceeds this threshold (seconds).
524
  Returns:
525
- ``audios`` a list of 2-D ``torch.Tensor``, with the shape (1, T) and sampling rate
526
- consistent with the model's audio tokenizer (usually 24000 Hz).
 
 
527
  """
528
 
529
  if self.audio_tokenizer is None or self.text_tokenizer is None:
@@ -611,17 +615,19 @@ class OmniVoice(PreTrainedModel):
611
  ref_wav = load_audio(ref_audio, self.sampling_rate)
612
  else:
613
  waveform, sr = ref_audio
614
- if waveform.dim() == 1:
615
- waveform = waveform.unsqueeze(0)
616
- if waveform.size(0) > 1:
617
- waveform = torch.mean(waveform, dim=0, keepdim=True)
 
 
618
  if sr != self.sampling_rate:
619
- waveform = torchaudio.functional.resample(
620
- waveform, sr, self.sampling_rate
621
  )
622
  ref_wav = waveform
623
 
624
- ref_rms = torch.sqrt(torch.mean(torch.square(ref_wav))).item()
625
  if 0 < ref_rms < 0.1:
626
  ref_wav = ref_wav * 0.1 / ref_rms
627
 
@@ -640,13 +646,13 @@ class OmniVoice(PreTrainedModel):
640
  lead_sil=100,
641
  trail_sil=200,
642
  )
643
- if ref_wav.size(-1) == 0:
644
  raise ValueError(
645
  "Reference audio is empty after silence removal. "
646
  "Try setting preprocess_prompt=False."
647
  )
648
 
649
- ref_duration = ref_wav.size(-1) / self.sampling_rate
650
  if ref_duration > 20.0:
651
  logger.warning(
652
  "Reference audio is %.1fs long (>20s). This may cause slower "
@@ -664,10 +670,14 @@ class OmniVoice(PreTrainedModel):
664
  logger.debug("Auto-transcribed ref_text: %s", ref_text)
665
 
666
  chunk_size = self.audio_tokenizer.config.hop_length
667
- clip_size = int(ref_wav.size(-1) % chunk_size)
668
  ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
 
 
 
 
669
  ref_audio_tokens = self.audio_tokenizer.encode(
670
- ref_wav.unsqueeze(0).to(self.audio_tokenizer.device),
671
  ).audio_codes.squeeze(
672
  0
673
  ) # (C, T)
@@ -686,7 +696,7 @@ class OmniVoice(PreTrainedModel):
686
  tokens: Union[torch.Tensor, List[torch.Tensor]],
687
  rms: Union[float, None],
688
  gen_config: OmniVoiceGenerationConfig,
689
- ) -> torch.Tensor:
690
  """
691
  Args:
692
  tokens: Audio tokens — either a single tensor of shape
@@ -694,7 +704,7 @@ class OmniVoice(PreTrainedModel):
694
  rms: RMS of the reference audio for volume adjustment.
695
  gen_config: Generation config for post-processing options.
696
  Returns:
697
- Decoded and post-processed audio tensor of shape (1, T).
698
  """
699
  tokenizer_device = self.audio_tokenizer.device
700
  if isinstance(tokens, list):
@@ -702,6 +712,7 @@ class OmniVoice(PreTrainedModel):
702
  self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
703
  .audio_values[0]
704
  .cpu()
 
705
  for t in tokens
706
  ]
707
  audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
@@ -710,28 +721,30 @@ class OmniVoice(PreTrainedModel):
710
  self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
711
  .audio_values[0]
712
  .cpu()
 
713
  )
714
 
715
- return self._post_process_audio(
716
  audio_waveform,
717
  postprocess_output=gen_config.postprocess_output,
718
  ref_rms=rms,
719
  )
 
720
 
721
  def _post_process_audio(
722
  self,
723
- generated_audio: torch.Tensor,
724
  postprocess_output: bool,
725
  ref_rms: Union[float, None],
726
- ) -> torch.Tensor:
727
  """Optionally remove long silences, adjust volume, and add edge padding.
728
 
729
  Args:
730
- generated_audio: Audio tensor of shape (1, T).
731
  postprocess_output: If True, remove long silences and apply fade/pad.
732
  ref_rms: RMS of the reference audio for volume normalisation.
733
  Returns:
734
- Processed audio tensor of shape (1, T).
735
  """
736
  if postprocess_output:
737
  generated_audio = remove_silence(
@@ -745,9 +758,7 @@ class OmniVoice(PreTrainedModel):
745
  if ref_rms is not None and ref_rms < 0.1:
746
  generated_audio = generated_audio * ref_rms / 0.1
747
  elif ref_rms is None:
748
- # No reference audio (voice design): peak-normalize to 0.5
749
- # to avoid clipping while keeping a comfortable volume level.
750
- peak = generated_audio.abs().max()
751
  if peak > 1e-6:
752
  generated_audio = generated_audio / peak * 0.5
753
 
@@ -1549,6 +1560,9 @@ def _combine_text(text, ref_text: Optional[str] = None) -> str:
1549
  # filter out newline / carriage-return characters
1550
  full_text = re.sub(r"[\r\n]+", "", full_text)
1551
 
 
 
 
1552
  # collapse consecutive spaces / tabs into a single space
1553
  full_text = re.sub(r"[ \t]+", " ", full_text)
1554
 
 
36
  from functools import partial
37
  from typing import Any, List, Optional, Union
38
 
39
+ import librosa
40
+ import numpy as np
41
  import torch
42
  import torch.nn as nn
43
  import torch.nn.functional as F
 
44
  from torch.nn.attention.flex_attention import create_block_mask
45
  from transformers import (
46
  AutoFeatureExtractor,
 
311
  @torch.inference_mode()
312
  def transcribe(
313
  self,
314
+ audio: Union[str, tuple],
315
  ) -> str:
316
  """Transcribe audio using the loaded Whisper ASR model.
317
 
318
  Args:
319
+ audio: File path or ``(waveform, sample_rate)`` tuple.
320
+ Waveform can be a numpy array or torch.Tensor of shape
321
+ ``(1, T)`` or ``(T,)``.
322
 
323
  Returns:
324
  Transcribed text.
 
332
  return self._asr_pipe(audio)["text"].strip()
333
  else:
334
  waveform, sr = audio
335
+ if isinstance(waveform, torch.Tensor):
336
+ waveform = waveform.cpu().numpy()
337
+ waveform = np.squeeze(waveform) # (1, T) or (T,) → (T,)
 
338
  audio_input = {
339
+ "array": waveform,
340
  "sampling_rate": sr,
341
  }
342
  return self._asr_pipe(audio_input)["text"].strip()
 
477
  speed: Union[float, list[Optional[float]], None] = None,
478
  generation_config: Optional[OmniVoiceGenerationConfig] = None,
479
  **kwargs,
480
+ ) -> list[np.ndarray]:
481
  """Generate speech audio given text in various modes.
482
 
483
  Supports three modes:
 
524
  audio_chunk_threshold: Only apply chunking if estimated audio
525
  duration exceeds this threshold (seconds).
526
  Returns:
527
+ ``audios`` a list of 1-D ``np.ndarray`` with shape ``(T,)`` and
528
+ sampling rate consistent with the model's audio tokenizer
529
+ (usually 24 000 Hz). Can be saved directly with
530
+ ``soundfile.write("out.wav", audios[0], model.sampling_rate)``.
531
  """
532
 
533
  if self.audio_tokenizer is None or self.text_tokenizer is None:
 
615
  ref_wav = load_audio(ref_audio, self.sampling_rate)
616
  else:
617
  waveform, sr = ref_audio
618
+ if isinstance(waveform, torch.Tensor):
619
+ waveform = waveform.cpu().numpy()
620
+ if waveform.ndim == 1:
621
+ waveform = waveform[np.newaxis, :]
622
+ if waveform.shape[0] > 1:
623
+ waveform = np.mean(waveform, axis=0, keepdims=True)
624
  if sr != self.sampling_rate:
625
+ waveform = librosa.resample(
626
+ waveform, orig_sr=sr, target_sr=self.sampling_rate,
627
  )
628
  ref_wav = waveform
629
 
630
+ ref_rms = float(np.sqrt(np.mean(ref_wav ** 2)))
631
  if 0 < ref_rms < 0.1:
632
  ref_wav = ref_wav * 0.1 / ref_rms
633
 
 
646
  lead_sil=100,
647
  trail_sil=200,
648
  )
649
+ if ref_wav.shape[-1] == 0:
650
  raise ValueError(
651
  "Reference audio is empty after silence removal. "
652
  "Try setting preprocess_prompt=False."
653
  )
654
 
655
+ ref_duration = ref_wav.shape[-1] / self.sampling_rate
656
  if ref_duration > 20.0:
657
  logger.warning(
658
  "Reference audio is %.1fs long (>20s). This may cause slower "
 
670
  logger.debug("Auto-transcribed ref_text: %s", ref_text)
671
 
672
  chunk_size = self.audio_tokenizer.config.hop_length
673
+ clip_size = int(ref_wav.shape[-1] % chunk_size)
674
  ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
675
+ # numpy → torch at tokenizer boundary
676
+ ref_wav_tensor = torch.from_numpy(ref_wav).to(
677
+ self.audio_tokenizer.device
678
+ )
679
  ref_audio_tokens = self.audio_tokenizer.encode(
680
+ ref_wav_tensor.unsqueeze(0),
681
  ).audio_codes.squeeze(
682
  0
683
  ) # (C, T)
 
696
  tokens: Union[torch.Tensor, List[torch.Tensor]],
697
  rms: Union[float, None],
698
  gen_config: OmniVoiceGenerationConfig,
699
+ ) -> np.ndarray:
700
  """
701
  Args:
702
  tokens: Audio tokens — either a single tensor of shape
 
704
  rms: RMS of the reference audio for volume adjustment.
705
  gen_config: Generation config for post-processing options.
706
  Returns:
707
+ Decoded and post-processed audio array of shape (T,).
708
  """
709
  tokenizer_device = self.audio_tokenizer.device
710
  if isinstance(tokens, list):
 
712
  self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
713
  .audio_values[0]
714
  .cpu()
715
+ .numpy()
716
  for t in tokens
717
  ]
718
  audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
 
721
  self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
722
  .audio_values[0]
723
  .cpu()
724
+ .numpy()
725
  )
726
 
727
+ audio_waveform = self._post_process_audio(
728
  audio_waveform,
729
  postprocess_output=gen_config.postprocess_output,
730
  ref_rms=rms,
731
  )
732
+ return audio_waveform.squeeze(0)
733
 
734
  def _post_process_audio(
735
  self,
736
+ generated_audio: np.ndarray,
737
  postprocess_output: bool,
738
  ref_rms: Union[float, None],
739
+ ) -> np.ndarray:
740
  """Optionally remove long silences, adjust volume, and add edge padding.
741
 
742
  Args:
743
+ generated_audio: Numpy array of shape (1, T).
744
  postprocess_output: If True, remove long silences and apply fade/pad.
745
  ref_rms: RMS of the reference audio for volume normalisation.
746
  Returns:
747
+ Processed numpy array of shape (1, T).
748
  """
749
  if postprocess_output:
750
  generated_audio = remove_silence(
 
758
  if ref_rms is not None and ref_rms < 0.1:
759
  generated_audio = generated_audio * ref_rms / 0.1
760
  elif ref_rms is None:
761
+ peak = np.abs(generated_audio).max()
 
 
762
  if peak > 1e-6:
763
  generated_audio = generated_audio / peak * 0.5
764
 
 
1560
  # filter out newline / carriage-return characters
1561
  full_text = re.sub(r"[\r\n]+", "", full_text)
1562
 
1563
+ # replace Chinese parentheses with English ones
1564
+ full_text = full_text.replace("\uff08", "(").replace("\uff09", ")")
1565
+
1566
  # collapse consecutive spaces / tabs into a single space
1567
  full_text = re.sub(r"[ \t]+", " ", full_text)
1568
 
omnivoice/scripts/denoise_audio.py CHANGED
@@ -73,6 +73,7 @@ from tqdm.auto import tqdm
73
 
74
  from omnivoice.data.batching import StreamLengthGroupDataset
75
  from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
 
76
  from omnivoice.utils.common import str2bool
77
 
78
  SIDON_INPUT_SAMPLE_RATE = 16_000
@@ -367,10 +368,10 @@ def extract_seamless_m4t_features(
367
 
368
  def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
369
  buffer = io.BytesIO()
370
- audio = waveform.to(dtype=torch.float32).cpu()
371
- if audio.ndim == 1:
372
- audio = audio.unsqueeze(0)
373
- torchaudio.save(buffer, audio, sample_rate, format="flac", bits_per_sample=16)
374
  return {"__key__": key, "flac": buffer.getvalue()}
375
 
376
 
 
73
 
74
  from omnivoice.data.batching import StreamLengthGroupDataset
75
  from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
76
+ import soundfile as sf
77
  from omnivoice.utils.common import str2bool
78
 
79
  SIDON_INPUT_SAMPLE_RATE = 16_000
 
368
 
369
def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
    """Encode a waveform to 16-bit FLAC bytes for a WebDataset record.

    Args:
        key: Sample key for the WebDataset record.
        waveform: Audio tensor of shape (T,) or (C, T); values are
            presumably floats in [-1, 1] — soundfile quantises them to PCM.
        sample_rate: Sampling rate of the waveform.

    Returns:
        Dict with ``__key__`` and the FLAC-encoded ``flac`` payload.
    """
    buffer = io.BytesIO()
    audio = waveform.to(dtype=torch.float32).cpu().numpy()
    if audio.ndim == 2:
        audio = audio.T  # (C, T) → (T, C) for soundfile
    # subtype="PCM_16" pins the 16-bit depth this pipeline previously
    # requested via torchaudio's bits_per_sample=16 (FLAC's default subtype
    # is PCM_16, but stating it guards against library-default changes).
    sf.write(buffer, audio, sample_rate, format="FLAC", subtype="PCM_16")
    return {"__key__": key, "flac": buffer.getvalue()}
376
 
377
 
omnivoice/scripts/extract_audio_tokens_add_noise.py CHANGED
@@ -66,13 +66,13 @@ from typing import Any
66
  import numpy as np
67
  import torch
68
  import torch.nn.functional as F
69
- import torchaudio
70
  import webdataset as wds
71
  from torch.utils.data import DataLoader, IterableDataset
72
  from tqdm.auto import tqdm
73
  from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
74
 
75
  from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
 
76
  from omnivoice.utils.common import str2bool
77
 
78
  warnings.filterwarnings(
@@ -207,13 +207,7 @@ def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
207
 
208
  def _load_aug_audio(data, sample_rate=24000):
209
  """Simple audio loader for augmentation files."""
210
- with io.BytesIO(data) as b:
211
- wav, sr = torchaudio.load(b)
212
- if wav.shape[0] > 1:
213
- wav = wav.mean(dim=0, keepdim=True)
214
- if sr != sample_rate:
215
- wav = torchaudio.functional.resample(wav, sr, sample_rate)
216
- return wav
217
 
218
 
219
  class SimpleWorkerSampler:
 
66
  import numpy as np
67
  import torch
68
  import torch.nn.functional as F
 
69
  import webdataset as wds
70
  from torch.utils.data import DataLoader, IterableDataset
71
  from tqdm.auto import tqdm
72
  from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
73
 
74
  from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
75
+ from omnivoice.utils.audio import load_audio_bytes
76
  from omnivoice.utils.common import str2bool
77
 
78
  warnings.filterwarnings(
 
207
 
208
def _load_aug_audio(data, sample_rate=24000):
    """Simple audio loader for augmentation files."""
    # load_audio_bytes returns a mono, resampled float32 array of shape (1, T).
    waveform = load_audio_bytes(data, sample_rate)
    return torch.from_numpy(waveform)
 
 
 
 
 
 
211
 
212
 
213
  class SimpleWorkerSampler:
omnivoice/scripts/jsonl_to_webdataset.py CHANGED
@@ -65,10 +65,13 @@ from concurrent.futures import (
65
  from itertools import islice
66
  from pathlib import Path
67
 
68
- import torchaudio
69
  import webdataset as wds
70
  from tqdm import tqdm
71
 
 
 
 
72
  from omnivoice.utils.common import str2bool
73
 
74
 
@@ -164,16 +167,16 @@ def process_audio_item(meta, target_sr):
164
  if not os.path.exists(audio_path):
165
  raise FileNotFoundError(f"{audio_path} not found")
166
 
167
- waveform, sr = torchaudio.load(audio_path)
168
  audio_duration = waveform.shape[1] / sr
169
  meta["audio_duration"] = audio_duration
170
 
171
  if target_sr and sr != target_sr:
172
- waveform = torchaudio.functional.resample(waveform, sr, target_sr)
173
  sr = target_sr
174
 
175
  audio_buffer = io.BytesIO()
176
- torchaudio.save(audio_buffer, waveform, sr, format="flac", bits_per_sample=16)
177
  audio_bytes = audio_buffer.getvalue()
178
 
179
  sample = {
 
65
  from itertools import islice
66
  from pathlib import Path
67
 
68
+ import librosa
69
  import webdataset as wds
70
  from tqdm import tqdm
71
 
72
+ import soundfile as sf
73
+
74
+ from omnivoice.utils.audio import load_waveform
75
  from omnivoice.utils.common import str2bool
76
 
77
 
 
167
  if not os.path.exists(audio_path):
168
  raise FileNotFoundError(f"{audio_path} not found")
169
 
170
+ waveform, sr = load_waveform(audio_path)
171
  audio_duration = waveform.shape[1] / sr
172
  meta["audio_duration"] = audio_duration
173
 
174
  if target_sr and sr != target_sr:
175
+ waveform = librosa.resample(waveform, orig_sr=sr, target_sr=target_sr)
176
  sr = target_sr
177
 
178
  audio_buffer = io.BytesIO()
179
+ sf.write(audio_buffer, waveform.T, sr, format="FLAC")
180
  audio_bytes = audio_buffer.getvalue()
181
 
182
  sample = {
omnivoice/utils/audio.py CHANGED
@@ -17,83 +17,157 @@
17
 
18
  """Audio I/O and processing utilities.
19
 
20
- Provides functions for loading, resampling, silence removal, chunking,
21
- cross-fading, and format conversion. Used by ``OmniVoice.generate()`` during
22
- inference post-processing.
 
 
23
  """
24
 
 
 
 
 
25
  import numpy as np
26
- import torch
27
- import torchaudio
28
  from pydub import AudioSegment
29
  from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
30
 
 
 
 
 
 
 
31
 
32
- def load_audio(audio_path: str, sampling_rate: int):
 
 
 
 
 
 
 
 
 
 
33
  """
34
- Load the waveform with torchaudio and resampling if needed.
 
 
 
 
 
 
 
 
 
 
 
 
35
 
36
  Parameters:
37
  audio_path: path of the audio.
38
  sampling_rate: target sampling rate.
39
 
40
  Returns:
41
- Loaded prompt waveform with target sampling rate,
42
- PyTorch tensor of shape (1, T)
43
  """
44
- try:
45
- waveform, prompt_sampling_rate = torchaudio.load(
46
- audio_path, backend="soundfile"
 
 
 
 
 
 
47
  )
48
- except (RuntimeError, OSError):
49
- # Fallback via pydub+ffmpeg for formats torchaudio can't handle
50
- aseg = AudioSegment.from_file(audio_path)
51
- audio_data = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0
52
- if aseg.channels == 1:
53
- waveform = torch.from_numpy(audio_data).unsqueeze(0)
54
- else:
55
- waveform = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
56
- prompt_sampling_rate = aseg.frame_rate
57
-
58
- if prompt_sampling_rate != sampling_rate:
59
- waveform = torchaudio.functional.resample(
60
- waveform,
61
- orig_freq=prompt_sampling_rate,
62
- new_freq=sampling_rate,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  )
64
- if waveform.shape[0] > 1:
65
- waveform = torch.mean(waveform, dim=0, keepdim=True)
66
 
67
- return waveform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  def remove_silence(
71
- audio: torch.Tensor,
72
  sampling_rate: int,
73
  mid_sil: int = 300,
74
  lead_sil: int = 100,
75
  trail_sil: int = 300,
76
- ):
77
- """
78
- Remove middle silences longer than mid_sil ms, and edge silences longer than edge_sil ms
79
 
80
  Parameters:
81
- audio: PyTorch tensor with shape (C, T).
82
  sampling_rate: sampling rate of the audio.
83
- mid_sil: the duration of silences in the middle of audio to be removed in ms.
84
- if mid_sil <= 0, no middle silence will be removed.
85
- edge_sil: the duration of silences in the edge of audio to be removed in ms.
86
- trail_sil: the duration of added trailing silence in ms.
87
 
88
  Returns:
89
- PyTorch tensor with shape (C, T), where C is number of channels
90
- and T is number of audio samples
91
  """
92
- # Load audio file
93
- wave = tensor_to_audiosegment(audio, sampling_rate)
94
 
95
  if mid_sil > 0:
96
- # Split audio using silences longer than mid_sil
97
  non_silent_segs = split_on_silence(
98
  wave,
99
  min_silence_len=mid_sil,
@@ -101,17 +175,13 @@ def remove_silence(
101
  keep_silence=mid_sil,
102
  seek_step=10,
103
  )
104
-
105
- # Concatenate all non-silent segments
106
  wave = AudioSegment.silent(duration=0)
107
  for seg in non_silent_segs:
108
  wave += seg
109
 
110
- # Remove silence longer than 0.1 seconds in the begining and ending of wave
111
  wave = remove_silence_edges(wave, lead_sil, trail_sil, -50)
112
 
113
- # Convert to PyTorch tensor
114
- return audiosegment_to_tensor(wave)
115
 
116
 
117
  def remove_silence_edges(
@@ -119,25 +189,12 @@ def remove_silence_edges(
119
  lead_sil: int = 100,
120
  trail_sil: int = 300,
121
  silence_threshold: float = -50,
122
- ):
123
- """
124
- Remove edge silences longer than `keep_silence` ms.
125
-
126
- Parameters:
127
- audio: an AudioSegment object.
128
- keep_silence: kept silence in the edge.
129
- only_edge: If true, only remove edge silences.
130
- silence_threshold: the threshold of silence.
131
-
132
- Returns:
133
- An AudioSegment object
134
- """
135
- # Remove heading silence
136
  start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
137
  start_idx = max(0, start_idx - lead_sil)
138
  audio = audio[start_idx:]
139
 
140
- # Remove trailing silence
141
  audio = audio.reverse()
142
  start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
143
  start_idx = max(0, start_idx - trail_sil)
@@ -147,80 +204,22 @@ def remove_silence_edges(
147
  return audio
148
 
149
 
150
- def audiosegment_to_tensor(aseg):
151
- """
152
- Convert a pydub.AudioSegment to PyTorch audio tensor
153
- """
154
- audio_data = np.array(aseg.get_array_of_samples())
155
-
156
- # Convert to float32 and normalize to [-1, 1] range
157
- audio_data = audio_data.astype(np.float32) / 32768.0
158
-
159
- # Handle channels
160
- if aseg.channels == 1:
161
- # Mono channel: add channel dimension (T) -> (1, T)
162
- tensor_data = torch.from_numpy(audio_data).unsqueeze(0)
163
- else:
164
- # Multi-channel: reshape to (C, T)
165
- tensor_data = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
166
-
167
- return tensor_data
168
-
169
-
170
- def tensor_to_audiosegment(tensor, sample_rate):
171
- """
172
- Convert a PyTorch audio tensor to pydub.AudioSegment
173
-
174
- Parameters:
175
- tensor: Tensor with shape (C, T), where C is the number of channels
176
- and T is the time steps
177
- sample_rate: Audio sample rate
178
- """
179
- # Convert tensor to numpy array
180
- assert isinstance(tensor, torch.Tensor)
181
- audio_np = tensor.cpu().numpy()
182
-
183
- # Convert to int16 type (common format for pydub)
184
- # Assumes tensor values are in [-1, 1] range as floating point
185
- audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16)
186
-
187
- # Convert to byte stream
188
- # For multi-channel audio, pydub requires interleaved format
189
- # (e.g., left-right-left-right)
190
- if audio_np.shape[0] > 1:
191
- # Convert to interleaved format
192
- audio_np = audio_np.transpose(1, 0).flatten()
193
- audio_bytes = audio_np.tobytes()
194
-
195
- # Create AudioSegment
196
- audio_segment = AudioSegment(
197
- data=audio_bytes,
198
- sample_width=2,
199
- frame_rate=sample_rate,
200
- channels=tensor.shape[0],
201
- )
202
-
203
- return audio_segment
204
-
205
-
206
  def fade_and_pad_audio(
207
- audio: torch.Tensor,
208
  pad_duration: float = 0.1,
209
  fade_duration: float = 0.1,
210
  sample_rate: int = 24000,
211
- ) -> torch.Tensor:
212
- """
213
- Applies a smooth fade-in and fade-out to the audio, and then pads both sides
214
- with pure silence to prevent abrupt starts and ends (clicks/pops).
215
 
216
  Args:
217
- audio: PyTorch tensor of shape (C, T) containing audio data.
218
- pad_duration: Duration of pure silence to add to each end (in seconds).
219
- fade_duration: Duration of the fade-in/out curve (in seconds).
220
- sample_rate: Audio sampling rate.
221
 
222
  Returns:
223
- Processed sequence tensor with shape (C, T_new)
224
  """
225
  if audio.shape[-1] == 0:
226
  return audio
@@ -228,59 +227,53 @@ def fade_and_pad_audio(
228
  fade_samples = int(fade_duration * sample_rate)
229
  pad_samples = int(pad_duration * sample_rate)
230
 
231
- processed = audio.clone()
232
 
233
  if fade_samples > 0:
234
  k = min(fade_samples, processed.shape[-1] // 2)
235
-
236
  if k > 0:
237
- fade_in = torch.linspace(
238
- 0, 1, k, device=processed.device, dtype=processed.dtype
239
- )[None, :]
240
- processed[..., :k] = processed[..., :k] * fade_in
241
 
242
- fade_out = torch.linspace(
243
- 1, 0, k, device=processed.device, dtype=processed.dtype
244
- )[None, :]
245
- processed[..., -k:] = processed[..., -k:] * fade_out
246
 
247
  if pad_samples > 0:
248
- silence = torch.zeros(
249
  (processed.shape[0], pad_samples),
250
  dtype=processed.dtype,
251
- device=processed.device,
252
  )
253
- processed = torch.cat([silence, processed, silence], dim=-1)
254
 
255
  return processed
256
 
257
 
258
  def trim_long_audio(
259
- audio: torch.Tensor,
260
  sampling_rate: int,
261
  max_duration: float = 15.0,
262
  min_duration: float = 3.0,
263
  trim_threshold: float = 20.0,
264
- ) -> torch.Tensor:
265
- """Trim audio to <= max_duration by splitting at the largest silence gap.
266
 
267
  Only trims when the audio exceeds *trim_threshold* seconds.
268
 
269
  Args:
270
- audio: Audio tensor of shape (C, T).
271
- sampling_rate: Audio sampling rate.
272
- max_duration: Maximum duration in seconds.
273
- min_duration: Minimum duration in seconds.
274
- trim_threshold: Only trim if audio is longer than this (seconds).
275
 
276
  Returns:
277
- Trimmed audio tensor.
278
  """
279
- duration = audio.size(-1) / sampling_rate
280
  if duration <= trim_threshold:
281
  return audio
282
 
283
- seg = tensor_to_audiosegment(audio, sampling_rate)
284
  nonsilent = detect_nonsilent(
285
  seg, min_silence_len=100, silence_thresh=-40, seek_step=10
286
  )
@@ -290,7 +283,6 @@ def trim_long_audio(
290
  max_ms = int(max_duration * 1000)
291
  min_ms = int(min_duration * 1000)
292
 
293
- # Walk through speech regions; at each gap pick the latest split <= max_duration
294
  best_split = 0
295
  for start, end in nonsilent:
296
  if start > best_split and start <= max_ms:
@@ -302,56 +294,49 @@ def trim_long_audio(
302
  best_split = min(max_ms, len(seg))
303
 
304
  trimmed = seg[:best_split]
305
- return audiosegment_to_tensor(trimmed)
306
 
307
 
308
  def cross_fade_chunks(
309
- chunks: list[torch.Tensor],
310
  sample_rate: int,
311
  silence_duration: float = 0.3,
312
- ) -> torch.Tensor:
313
- """Concatenate audio chunks with a short silence gap and fade at boundaries.
314
-
315
- Each boundary is structured as: fade-out tail → silence buffer → fade-in head.
316
- This avoids click artifacts from direct concatenation or overlapping mismatch.
317
 
318
  Args:
319
- chunks: List of audio tensors, each (C, T).
320
- sample_rate: Audio sample rate.
321
- silence_duration: Total silence gap duration in seconds.
322
 
323
  Returns:
324
- Merged audio tensor (C, T_total).
325
  """
326
  if len(chunks) == 1:
327
  return chunks[0]
328
 
329
  total_n = int(silence_duration * sample_rate)
330
  fade_n = total_n // 3
331
- silence_n = fade_n # middle silent gap
332
- merged = chunks[0].clone()
333
 
334
  for chunk in chunks[1:]:
335
- dev, dt = merged.device, merged.dtype
336
  parts = [merged]
337
 
338
- # Fade out tail of current merged audio
339
- fout_n = min(fade_n, merged.size(-1))
340
  if fout_n > 0:
341
- w_out = torch.linspace(1, 0, fout_n, device=dev, dtype=dt)[None, :]
342
- parts[-1][..., -fout_n:] = parts[-1][..., -fout_n:] * w_out
343
 
344
- # Silent buffer between chunks
345
- parts.append(torch.zeros(chunks[0].shape[0], silence_n, device=dev, dtype=dt))
346
 
347
- # Fade in head of next chunk
348
- fade_in = chunk.clone()
349
- fin_n = min(fade_n, fade_in.size(-1))
350
  if fin_n > 0:
351
- w_in = torch.linspace(0, 1, fin_n, device=dev, dtype=dt)[None, :]
352
- fade_in[..., :fin_n] = fade_in[..., :fin_n] * w_in
353
 
354
  parts.append(fade_in)
355
- merged = torch.cat(parts, dim=-1)
356
 
357
  return merged
 
17
 
18
  """Audio I/O and processing utilities.
19
 
20
+ Provides functions for loading, resampling, silence removal,
21
+ chunking, cross-fading, and format conversion.
22
+
23
+ All public functions in this module operate on **numpy float32 arrays**
24
+ with shape ``(C, T)`` (channels-first).
25
  """
26
 
27
+ import io
28
+ import logging
29
+
30
+ import librosa
31
  import numpy as np
32
+ import soundfile as sf
 
33
  from pydub import AudioSegment
34
  from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
35
 
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ # ---------------------------------------------------------------------------
40
+ # Loading
41
+ # ---------------------------------------------------------------------------
42
 
43
+
44
def load_waveform(audio_path: str):
    """Load audio from a file path, returning ``(data, sample_rate)``.

    Tries two backends in order:
      1. soundfile — covers WAV/FLAC/OGG etc., no ffmpeg needed.
      2. librosa  — covers MP3/M4A etc. via audioread + ffmpeg.

    Args:
        audio_path: Path to the audio file.

    Returns:
        ``(data, sample_rate)`` where ``data`` is a numpy float32 array of
        shape ``(C, T)``.
    """
    try:
        data, sr = sf.read(audio_path, dtype="float32", always_2d=True)
        return data.T, sr  # soundfile yields (T, C); transpose to (C, T)
    except (RuntimeError, OSError) as err:
        # soundfile cannot decode MP3/M4A etc. — it raises LibsndfileError,
        # a RuntimeError subclass. Catching only (RuntimeError, OSError)
        # instead of bare Exception keeps genuine bugs (TypeError,
        # MemoryError, ...) from being silently masked by the fallback.
        logger.debug(
            "soundfile could not read %s (%s); falling back to librosa",
            audio_path,
            err,
        )

    data, sr = librosa.load(audio_path, sr=None, mono=False)
    if data.ndim == 1:
        # librosa returns (T,) for mono input; normalise to (C, T).
        data = data[np.newaxis, :]
    return data, sr
64
+
65
+
66
def load_audio(audio_path: str, sampling_rate: int) -> np.ndarray:
    """Load a waveform from file, downmix to mono, and resample.

    Parameters:
        audio_path: path of the audio.
        sampling_rate: target sampling rate.

    Returns:
        Numpy float32 array of shape (1, T).
    """
    waveform, native_sr = load_waveform(audio_path)

    # Downmix multi-channel audio to a single channel.
    if waveform.shape[0] > 1:
        waveform = waveform.mean(axis=0, keepdims=True)

    # Resample only when the native rate differs from the target.
    if native_sr != sampling_rate:
        waveform = librosa.resample(
            waveform,
            orig_sr=native_sr,
            target_sr=sampling_rate,
        )

    return waveform
88
+
89
+
90
def load_audio_bytes(raw: bytes, sampling_rate: int) -> np.ndarray:
    """Decode in-memory audio bytes, downmix to mono, and resample.

    Parameters:
        raw: raw audio file bytes (e.g. from WebDataset).
        sampling_rate: target sampling rate.

    Returns:
        Numpy float32 array of shape (1, T).
    """
    stream = io.BytesIO(raw)

    try:
        decoded, native_sr = sf.read(stream, dtype="float32", always_2d=True)
        decoded = decoded.T  # soundfile yields (T, C); transpose to (C, T)
    except Exception:
        # soundfile cannot decode e.g. MP3/M4A; rewind and retry via librosa.
        stream.seek(0)
        decoded, native_sr = librosa.load(stream, sr=None, mono=False)
        if decoded.ndim == 1:
            # librosa returns (T,) for mono input; normalise to (C, T).
            decoded = decoded[np.newaxis, :]

    # Downmix multi-channel audio to a single channel.
    if decoded.shape[0] > 1:
        decoded = decoded.mean(axis=0, keepdims=True)

    # Resample only when the native rate differs from the target.
    if native_sr != sampling_rate:
        decoded = librosa.resample(
            decoded,
            orig_sr=native_sr,
            target_sr=sampling_rate,
        )

    return decoded
121
+
122
+
123
+ # ---------------------------------------------------------------------------
124
+ # Audio processing (all numpy in / numpy out)
125
+ # ---------------------------------------------------------------------------
126
+
127
+
128
def numpy_to_audiosegment(audio: np.ndarray, sample_rate: int) -> AudioSegment:
    """Convert a numpy float32 array of shape (C, T) to a pydub AudioSegment.

    Values are assumed to lie in [-1, 1] and are quantised to 16-bit PCM.
    """
    n_channels = audio.shape[0]
    pcm = np.clip(audio * 32768.0, -32768, 32767).astype(np.int16)
    if n_channels > 1:
        # pydub expects interleaved samples (L R L R ...).
        pcm = pcm.T.flatten()
    return AudioSegment(
        data=pcm.tobytes(),
        sample_width=2,  # 16-bit PCM
        frame_rate=sample_rate,
        channels=n_channels,
    )
139
+
140
+
141
def audiosegment_to_numpy(aseg: AudioSegment) -> np.ndarray:
    """Convert a pydub AudioSegment to a numpy float32 array of shape (C, T)."""
    # NOTE(review): the / 32768.0 scaling assumes 16-bit samples — confirm
    # callers never pass 8/24-bit segments.
    samples = np.asarray(aseg.get_array_of_samples(), dtype=np.float32) / 32768.0
    if aseg.channels > 1:
        # Interleaved (T*C,) → (T, C) → (C, T).
        return samples.reshape(-1, aseg.channels).T
    return samples[np.newaxis, :]
147
 
148
 
149
  def remove_silence(
150
+ audio: np.ndarray,
151
  sampling_rate: int,
152
  mid_sil: int = 300,
153
  lead_sil: int = 100,
154
  trail_sil: int = 300,
155
+ ) -> np.ndarray:
156
+ """Remove middle silences longer than *mid_sil* ms and trim edge silences.
 
157
 
158
  Parameters:
159
+ audio: numpy array with shape (C, T).
160
  sampling_rate: sampling rate of the audio.
161
+ mid_sil: middle-silence threshold in ms (0 to skip).
162
+ lead_sil: kept leading silence in ms.
163
+ trail_sil: kept trailing silence in ms.
 
164
 
165
  Returns:
166
+ Numpy array with shape (C, T').
 
167
  """
168
+ wave = numpy_to_audiosegment(audio, sampling_rate)
 
169
 
170
  if mid_sil > 0:
 
171
  non_silent_segs = split_on_silence(
172
  wave,
173
  min_silence_len=mid_sil,
 
175
  keep_silence=mid_sil,
176
  seek_step=10,
177
  )
 
 
178
  wave = AudioSegment.silent(duration=0)
179
  for seg in non_silent_segs:
180
  wave += seg
181
 
 
182
  wave = remove_silence_edges(wave, lead_sil, trail_sil, -50)
183
 
184
+ return audiosegment_to_numpy(wave)
 
185
 
186
 
187
  def remove_silence_edges(
 
189
  lead_sil: int = 100,
190
  trail_sil: int = 300,
191
  silence_threshold: float = -50,
192
+ ) -> AudioSegment:
193
+ """Remove edge silences, keeping *lead_sil* / *trail_sil* ms."""
 
 
 
 
 
 
 
 
 
 
 
 
194
  start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
195
  start_idx = max(0, start_idx - lead_sil)
196
  audio = audio[start_idx:]
197
 
 
198
  audio = audio.reverse()
199
  start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
200
  start_idx = max(0, start_idx - trail_sil)
 
204
  return audio
205
 
206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
207
def fade_and_pad_audio(
    audio: np.ndarray,
    pad_duration: float = 0.1,
    fade_duration: float = 0.1,
    sample_rate: int = 24000,
) -> np.ndarray:
    """Apply a linear fade-in/out and surround the audio with silence.

    Prevents click/pop artifacts from abrupt starts and ends.

    Args:
        audio: numpy array of shape (C, T).
        pad_duration: silence padding duration per side (seconds).
        fade_duration: fade curve duration (seconds).
        sample_rate: audio sampling rate.

    Returns:
        Processed numpy array of shape (C, T_new).
    """
    if audio.shape[-1] == 0:
        return audio

    n_fade = int(fade_duration * sample_rate)
    n_pad = int(pad_duration * sample_rate)
    out = audio.copy()  # never mutate the caller's buffer

    if n_fade > 0:
        # Cap the ramp at half the signal so fade-in/out never overlap.
        ramp_len = min(n_fade, out.shape[-1] // 2)
        if ramp_len > 0:
            ramp = np.linspace(0.0, 1.0, ramp_len, dtype=np.float32)
            out[..., :ramp_len] *= ramp[np.newaxis, :]
            out[..., -ramp_len:] *= ramp[::-1][np.newaxis, :]

    if n_pad > 0:
        pad = np.zeros((out.shape[0], n_pad), dtype=out.dtype)
        out = np.concatenate((pad, out, pad), axis=-1)

    return out
249
 
250
 
251
  def trim_long_audio(
252
+ audio: np.ndarray,
253
  sampling_rate: int,
254
  max_duration: float = 15.0,
255
  min_duration: float = 3.0,
256
  trim_threshold: float = 20.0,
257
+ ) -> np.ndarray:
258
+ """Trim audio to <= *max_duration* by splitting at the largest silence gap.
259
 
260
  Only trims when the audio exceeds *trim_threshold* seconds.
261
 
262
  Args:
263
+ audio: numpy array of shape (C, T).
264
+ sampling_rate: audio sampling rate.
265
+ max_duration: maximum duration in seconds.
266
+ min_duration: minimum duration in seconds.
267
+ trim_threshold: only trim if audio is longer than this (seconds).
268
 
269
  Returns:
270
+ Trimmed numpy array.
271
  """
272
+ duration = audio.shape[-1] / sampling_rate
273
  if duration <= trim_threshold:
274
  return audio
275
 
276
+ seg = numpy_to_audiosegment(audio, sampling_rate)
277
  nonsilent = detect_nonsilent(
278
  seg, min_silence_len=100, silence_thresh=-40, seek_step=10
279
  )
 
283
  max_ms = int(max_duration * 1000)
284
  min_ms = int(min_duration * 1000)
285
 
 
286
  best_split = 0
287
  for start, end in nonsilent:
288
  if start > best_split and start <= max_ms:
 
294
  best_split = min(max_ms, len(seg))
295
 
296
  trimmed = seg[:best_split]
297
+ return audiosegment_to_numpy(trimmed)
298
 
299
 
300
def cross_fade_chunks(
    chunks: list[np.ndarray],
    sample_rate: int,
    silence_duration: float = 0.3,
) -> np.ndarray:
    """Concatenate audio chunks with silence gaps and fades at boundaries.

    Each boundary is: fade-out tail → silent gap → fade-in head, which
    avoids click artifacts from direct concatenation.

    Args:
        chunks: list of numpy arrays, each (C, T).
        sample_rate: audio sample rate.
        silence_duration: total silence gap duration in seconds.

    Returns:
        Merged numpy array (C, T_total).
    """
    if len(chunks) == 1:
        return chunks[0]

    gap_total = int(silence_duration * sample_rate)
    fade_len = gap_total // 3  # one third fade-out, one third gap, one third fade-in
    n_channels = chunks[0].shape[0]

    merged = chunks[0].copy()
    for nxt in chunks[1:]:
        # Fade out the tail of what has been merged so far.
        tail_len = min(fade_len, merged.shape[-1])
        if tail_len > 0:
            ramp_down = np.linspace(1.0, 0.0, tail_len, dtype=np.float32)
            merged[..., -tail_len:] *= ramp_down[np.newaxis, :]

        # Fade in the head of the next chunk (on a copy — inputs untouched).
        head = nxt.copy()
        head_len = min(fade_len, head.shape[-1])
        if head_len > 0:
            ramp_up = np.linspace(0.0, 1.0, head_len, dtype=np.float32)
            head[..., :head_len] *= ramp_up[np.newaxis, :]

        gap = np.zeros((n_channels, fade_len), dtype=np.float32)
        merged = np.concatenate((merged, gap, head), axis=-1)

    return merged
omnivoice/utils/data_utils.py CHANGED
@@ -29,10 +29,10 @@ from pathlib import Path
29
  def read_test_list(path):
30
  """Read a JSONL test list file.
31
 
32
- Each line should be a JSON object with fields:
33
- id, text, ref_audio, ref_text, language_id, language_name, duration, speed
34
-
35
- language_id, language_name, duration, and speed are optional (default to None).
36
 
37
  Returns a list of dicts.
38
  """
@@ -58,6 +58,7 @@ def read_test_list(path):
58
  "language_name": obj.get("language_name"),
59
  "duration": obj.get("duration"),
60
  "speed": obj.get("speed"),
 
61
  }
62
  samples.append(sample)
63
  return samples
 
29
  def read_test_list(path):
30
  """Read a JSONL test list file.
31
 
32
+ Each line should be a JSON object. Only ``id`` and ``text`` are required;
33
+ all other fields are optional (default to ``None``):
34
+ id, text, ref_audio, ref_text, instruct,
35
+ language_id, language_name, duration, speed
36
 
37
  Returns a list of dicts.
38
  """
 
58
  "language_name": obj.get("language_name"),
59
  "duration": obj.get("duration"),
60
  "speed": obj.get("speed"),
61
+ "instruct": obj.get("instruct"),
62
  }
63
  samples.append(sample)
64
  return samples
requirements.txt CHANGED
@@ -5,6 +5,7 @@ transformers==5.3
5
  accelerate
6
  pydub
7
  soundfile
 
8
  numpy
9
  gradio
10
  hf_transfer
 
5
  accelerate
6
  pydub
7
  soundfile
8
+ librosa
9
  numpy
10
  gradio
11
  hf_transfer