Spaces:
Running on Zero
Running on Zero
update to 0.1.4 version
Browse files- omnivoice/cli/demo.py +1 -2
- omnivoice/cli/infer.py +3 -2
- omnivoice/cli/infer_batch.py +84 -54
- omnivoice/data/dataset.py +6 -13
- omnivoice/eval/mos/utmos.py +2 -2
- omnivoice/eval/speaker_similarity/sim.py +2 -2
- omnivoice/eval/utils.py +1 -1
- omnivoice/eval/wer/hubert.py +2 -2
- omnivoice/eval/wer/minimax.py +2 -2
- omnivoice/eval/wer/seedtts.py +2 -2
- omnivoice/models/omnivoice.py +46 -32
- omnivoice/scripts/denoise_audio.py +5 -4
- omnivoice/scripts/extract_audio_tokens_add_noise.py +2 -8
- omnivoice/scripts/jsonl_to_webdataset.py +7 -4
- omnivoice/utils/audio.py +166 -181
- omnivoice/utils/data_utils.py +5 -4
- requirements.txt +1 -0
omnivoice/cli/demo.py
CHANGED
|
@@ -213,8 +213,7 @@ def build_demo(
|
|
| 213 |
except Exception as e:
|
| 214 |
return None, f"Error: {type(e).__name__}: {e}"
|
| 215 |
|
| 216 |
-
waveform = audio[0]
|
| 217 |
-
waveform = (waveform * 32767).astype(np.int16)
|
| 218 |
return (sampling_rate, waveform), "Done."
|
| 219 |
|
| 220 |
# Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
|
|
|
|
| 213 |
except Exception as e:
|
| 214 |
return None, f"Error: {type(e).__name__}: {e}"
|
| 215 |
|
| 216 |
+
waveform = (audio[0] * 32767).astype(np.int16)
|
|
|
|
| 217 |
return (sampling_rate, waveform), "Done."
|
| 218 |
|
| 219 |
# Allow external wrappers (e.g. spaces.GPU for ZeroGPU Spaces)
|
omnivoice/cli/infer.py
CHANGED
|
@@ -23,7 +23,8 @@ import argparse
|
|
| 23 |
import logging
|
| 24 |
|
| 25 |
import torch
|
| 26 |
-
|
|
|
|
| 27 |
|
| 28 |
from omnivoice.models.omnivoice import OmniVoice
|
| 29 |
from omnivoice.utils.common import str2bool
|
|
@@ -149,7 +150,7 @@ def main():
|
|
| 149 |
class_temperature=args.class_temperature,
|
| 150 |
)
|
| 151 |
|
| 152 |
-
|
| 153 |
logging.info(f"Saved to {args.output}")
|
| 154 |
|
| 155 |
|
|
|
|
| 23 |
import logging
|
| 24 |
|
| 25 |
import torch
|
| 26 |
+
|
| 27 |
+
import soundfile as sf
|
| 28 |
|
| 29 |
from omnivoice.models.omnivoice import OmniVoice
|
| 30 |
from omnivoice.utils.common import str2bool
|
|
|
|
| 150 |
class_temperature=args.class_temperature,
|
| 151 |
)
|
| 152 |
|
| 153 |
+
sf.write(args.output, audios[0], model.sampling_rate)
|
| 154 |
logging.info(f"Saved to {args.output}")
|
| 155 |
|
| 156 |
|
omnivoice/cli/infer_batch.py
CHANGED
|
@@ -42,10 +42,11 @@ from concurrent.futures import ProcessPoolExecutor, as_completed
|
|
| 42 |
from typing import List, Optional, Tuple
|
| 43 |
|
| 44 |
import torch
|
| 45 |
-
import torchaudio
|
| 46 |
from tqdm import tqdm
|
| 47 |
|
| 48 |
from omnivoice.models.omnivoice import OmniVoice
|
|
|
|
|
|
|
| 49 |
from omnivoice.utils.audio import load_audio
|
| 50 |
from omnivoice.utils.common import str2bool
|
| 51 |
from omnivoice.utils.data_utils import read_test_list
|
|
@@ -79,11 +80,17 @@ def get_parser():
|
|
| 79 |
type=str,
|
| 80 |
required=True,
|
| 81 |
help="Path to the JSONL file containing test samples. "
|
| 82 |
-
|
| 83 |
-
'"
|
| 84 |
-
'"
|
| 85 |
-
'"
|
| 86 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
)
|
| 88 |
parser.add_argument(
|
| 89 |
"--res_dir",
|
|
@@ -135,8 +142,7 @@ def get_parser():
|
|
| 135 |
"--batch_duration",
|
| 136 |
type=float,
|
| 137 |
default=1000.0,
|
| 138 |
-
help="Maximum total duration (reference + generated) per batch (seconds).
|
| 139 |
-
"Only effective for parallel_chunk / no chunk mode.",
|
| 140 |
)
|
| 141 |
parser.add_argument(
|
| 142 |
"--batch_size",
|
|
@@ -239,7 +245,7 @@ def process_init(rank_queue, model_checkpoint, warmup=0):
|
|
| 239 |
dummy_ref_audio = (
|
| 240 |
torch.randn(1, SAMPLING_RATE),
|
| 241 |
SAMPLING_RATE,
|
| 242 |
-
) # 1s
|
| 243 |
for i in range(warmup):
|
| 244 |
worker_model.generate(
|
| 245 |
text=["hello"],
|
|
@@ -255,40 +261,58 @@ def process_init(rank_queue, model_checkpoint, warmup=0):
|
|
| 255 |
def estimate_sample_total_duration(
|
| 256 |
duration_estimator: RuleDurationEstimator,
|
| 257 |
text: str,
|
| 258 |
-
ref_text: str,
|
| 259 |
-
ref_audio_path: str,
|
| 260 |
gen_duration: Optional[float] = None,
|
| 261 |
) -> float:
|
| 262 |
-
|
| 263 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
if gen_duration is None:
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
|
| 270 |
total_duration = ref_duration + gen_duration
|
| 271 |
return total_duration
|
| 272 |
|
| 273 |
|
| 274 |
-
def
|
| 275 |
samples: List[Tuple],
|
| 276 |
duration_estimator: RuleDurationEstimator,
|
| 277 |
-
|
| 278 |
-
)
|
| 279 |
sample_with_duration = []
|
| 280 |
for sample in samples:
|
| 281 |
-
|
| 282 |
total_duration = estimate_sample_total_duration(
|
| 283 |
-
duration_estimator,
|
| 284 |
-
text,
|
| 285 |
-
ref_text,
|
| 286 |
-
ref_audio_path,
|
| 287 |
-
gen_duration=dur,
|
| 288 |
)
|
| 289 |
sample_with_duration.append((sample, total_duration))
|
| 290 |
-
|
| 291 |
sample_with_duration.sort(key=lambda x: x[1], reverse=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
batches = []
|
| 293 |
current_batch = []
|
| 294 |
current_total_duration = 0.0
|
|
@@ -319,19 +343,7 @@ def cluster_samples_by_batch_size(
|
|
| 319 |
batch_size: int,
|
| 320 |
) -> List[List[Tuple]]:
|
| 321 |
"""Split samples into fixed-size batches, sorted by duration to minimize padding."""
|
| 322 |
-
sample_with_duration =
|
| 323 |
-
for sample in samples:
|
| 324 |
-
save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
|
| 325 |
-
total_duration = estimate_sample_total_duration(
|
| 326 |
-
duration_estimator,
|
| 327 |
-
text,
|
| 328 |
-
ref_text,
|
| 329 |
-
ref_audio_path,
|
| 330 |
-
gen_duration=dur,
|
| 331 |
-
)
|
| 332 |
-
sample_with_duration.append((sample, total_duration))
|
| 333 |
-
|
| 334 |
-
sample_with_duration.sort(key=lambda x: x[1], reverse=True)
|
| 335 |
sorted_samples = [s for s, _ in sample_with_duration]
|
| 336 |
|
| 337 |
batches = [
|
|
@@ -359,9 +371,10 @@ def run_inference_batch(
|
|
| 359 |
langs = []
|
| 360 |
durations = []
|
| 361 |
speeds = []
|
|
|
|
| 362 |
|
| 363 |
for sample in batch_samples:
|
| 364 |
-
save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd = sample
|
| 365 |
save_names.append(save_name)
|
| 366 |
ref_texts.append(ref_text)
|
| 367 |
ref_audio_paths.append(ref_audio_path)
|
|
@@ -369,15 +382,17 @@ def run_inference_batch(
|
|
| 369 |
langs.append(lang_id)
|
| 370 |
durations.append(dur)
|
| 371 |
speeds.append(spd)
|
|
|
|
| 372 |
|
| 373 |
start_time = time.time()
|
| 374 |
audios = worker_model.generate(
|
| 375 |
text=texts,
|
| 376 |
language=langs,
|
| 377 |
-
ref_audio=ref_audio_paths,
|
| 378 |
-
ref_text=ref_texts,
|
| 379 |
duration=durations if any(d is not None for d in durations) else None,
|
| 380 |
speed=speeds if any(s is not None for s in speeds) else None,
|
|
|
|
| 381 |
**gen_kwargs,
|
| 382 |
)
|
| 383 |
batch_synth_time = time.time() - start_time
|
|
@@ -385,7 +400,7 @@ def run_inference_batch(
|
|
| 385 |
results = []
|
| 386 |
for save_name, audio in zip(save_names, audios):
|
| 387 |
save_path = os.path.join(res_dir, save_name + ".wav")
|
| 388 |
-
|
| 389 |
audio_duration = audio.shape[-1] / worker_model.sampling_rate
|
| 390 |
results.append(
|
| 391 |
(
|
|
@@ -436,13 +451,14 @@ def main():
|
|
| 436 |
samples.append(
|
| 437 |
(
|
| 438 |
s["id"],
|
| 439 |
-
s
|
| 440 |
-
s
|
| 441 |
s["text"],
|
| 442 |
lang_id,
|
| 443 |
lang_name,
|
| 444 |
s.get("duration"),
|
| 445 |
s.get("speed"),
|
|
|
|
| 446 |
)
|
| 447 |
)
|
| 448 |
|
|
@@ -457,18 +473,32 @@ def main():
|
|
| 457 |
) as executor:
|
| 458 |
futures = []
|
| 459 |
|
| 460 |
-
# parallel_chunk / no chunk
|
| 461 |
logging.info("Running batch inference")
|
| 462 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 463 |
duration_estimator = RuleDurationEstimator()
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 472 |
|
| 473 |
args_dict = vars(args)
|
| 474 |
|
|
|
|
| 42 |
from typing import List, Optional, Tuple
|
| 43 |
|
| 44 |
import torch
|
|
|
|
| 45 |
from tqdm import tqdm
|
| 46 |
|
| 47 |
from omnivoice.models.omnivoice import OmniVoice
|
| 48 |
+
import soundfile as sf
|
| 49 |
+
|
| 50 |
from omnivoice.utils.audio import load_audio
|
| 51 |
from omnivoice.utils.common import str2bool
|
| 52 |
from omnivoice.utils.data_utils import read_test_list
|
|
|
|
| 80 |
type=str,
|
| 81 |
required=True,
|
| 82 |
help="Path to the JSONL file containing test samples. "
|
| 83 |
+
"Each line is a JSON object with the following fields: "
|
| 84 |
+
'"id" (str, required): unique name for the output file; '
|
| 85 |
+
'"text" (str, required): text to synthesize; '
|
| 86 |
+
'"ref_audio" (str): path to reference audio for voice cloning; '
|
| 87 |
+
'"ref_text" (str): transcript of the reference audio; '
|
| 88 |
+
'"instruct" (str): instruction for voice design (used when ref_audio is absent); '
|
| 89 |
+
'"language_id" (str): language code, e.g. "en"; '
|
| 90 |
+
'"language_name" (str): language name, e.g. "English"; '
|
| 91 |
+
'"duration" (float): target duration in seconds; '
|
| 92 |
+
'"speed" (float): speaking speed multiplier. '
|
| 93 |
+
"Only id and text are required; all other fields are optional.",
|
| 94 |
)
|
| 95 |
parser.add_argument(
|
| 96 |
"--res_dir",
|
|
|
|
| 142 |
"--batch_duration",
|
| 143 |
type=float,
|
| 144 |
default=1000.0,
|
| 145 |
+
help="Maximum total duration (reference + generated) per batch (seconds).",
|
|
|
|
| 146 |
)
|
| 147 |
parser.add_argument(
|
| 148 |
"--batch_size",
|
|
|
|
| 245 |
dummy_ref_audio = (
|
| 246 |
torch.randn(1, SAMPLING_RATE),
|
| 247 |
SAMPLING_RATE,
|
| 248 |
+
) # 1s dummy audio
|
| 249 |
for i in range(warmup):
|
| 250 |
worker_model.generate(
|
| 251 |
text=["hello"],
|
|
|
|
| 261 |
def estimate_sample_total_duration(
|
| 262 |
duration_estimator: RuleDurationEstimator,
|
| 263 |
text: str,
|
| 264 |
+
ref_text: Optional[str],
|
| 265 |
+
ref_audio_path: Optional[str],
|
| 266 |
gen_duration: Optional[float] = None,
|
| 267 |
) -> float:
|
| 268 |
+
"""Estimate total duration (ref + generated) for a single sample.
|
| 269 |
+
|
| 270 |
+
When ``ref_audio_path`` is ``None`` (instruct / voice-design mode),
|
| 271 |
+
the reference duration is treated as 0 and only the estimated generated
|
| 272 |
+
duration contributes to the total.
|
| 273 |
+
"""
|
| 274 |
+
if ref_audio_path is not None:
|
| 275 |
+
ref_wav = load_audio(ref_audio_path, SAMPLING_RATE)
|
| 276 |
+
ref_duration = ref_wav.shape[-1] / SAMPLING_RATE
|
| 277 |
+
else:
|
| 278 |
+
ref_duration = 0
|
| 279 |
|
| 280 |
if gen_duration is None:
|
| 281 |
+
if ref_audio_path is not None:
|
| 282 |
+
gen_duration = duration_estimator.estimate_duration(
|
| 283 |
+
text, ref_text or "", ref_duration, low_threshold=2.0
|
| 284 |
+
)
|
| 285 |
+
else:
|
| 286 |
+
gen_duration = duration_estimator.estimate_duration(
|
| 287 |
+
text, "Nice to meet you.", 0.5, low_threshold=2.0
|
| 288 |
+
)
|
| 289 |
|
| 290 |
total_duration = ref_duration + gen_duration
|
| 291 |
return total_duration
|
| 292 |
|
| 293 |
|
| 294 |
+
def _sort_samples_by_duration(
|
| 295 |
samples: List[Tuple],
|
| 296 |
duration_estimator: RuleDurationEstimator,
|
| 297 |
+
) -> List[Tuple[Tuple, float]]:
|
| 298 |
+
"""Return (sample, total_duration) pairs sorted by duration descending."""
|
| 299 |
sample_with_duration = []
|
| 300 |
for sample in samples:
|
| 301 |
+
_, ref_text, ref_audio_path, text, _, _, dur, _, _ = sample
|
| 302 |
total_duration = estimate_sample_total_duration(
|
| 303 |
+
duration_estimator, text, ref_text, ref_audio_path, gen_duration=dur
|
|
|
|
|
|
|
|
|
|
|
|
|
| 304 |
)
|
| 305 |
sample_with_duration.append((sample, total_duration))
|
|
|
|
| 306 |
sample_with_duration.sort(key=lambda x: x[1], reverse=True)
|
| 307 |
+
return sample_with_duration
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
def cluster_samples_by_duration(
|
| 311 |
+
samples: List[Tuple],
|
| 312 |
+
duration_estimator: RuleDurationEstimator,
|
| 313 |
+
batch_duration: float,
|
| 314 |
+
) -> List[List[Tuple]]:
|
| 315 |
+
sample_with_duration = _sort_samples_by_duration(samples, duration_estimator)
|
| 316 |
batches = []
|
| 317 |
current_batch = []
|
| 318 |
current_total_duration = 0.0
|
|
|
|
| 343 |
batch_size: int,
|
| 344 |
) -> List[List[Tuple]]:
|
| 345 |
"""Split samples into fixed-size batches, sorted by duration to minimize padding."""
|
| 346 |
+
sample_with_duration = _sort_samples_by_duration(samples, duration_estimator)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
sorted_samples = [s for s, _ in sample_with_duration]
|
| 348 |
|
| 349 |
batches = [
|
|
|
|
| 371 |
langs = []
|
| 372 |
durations = []
|
| 373 |
speeds = []
|
| 374 |
+
instructs = []
|
| 375 |
|
| 376 |
for sample in batch_samples:
|
| 377 |
+
save_name, ref_text, ref_audio_path, text, lang_id, lang_name, dur, spd, instruct = sample
|
| 378 |
save_names.append(save_name)
|
| 379 |
ref_texts.append(ref_text)
|
| 380 |
ref_audio_paths.append(ref_audio_path)
|
|
|
|
| 382 |
langs.append(lang_id)
|
| 383 |
durations.append(dur)
|
| 384 |
speeds.append(spd)
|
| 385 |
+
instructs.append(instruct)
|
| 386 |
|
| 387 |
start_time = time.time()
|
| 388 |
audios = worker_model.generate(
|
| 389 |
text=texts,
|
| 390 |
language=langs,
|
| 391 |
+
ref_audio=ref_audio_paths if any(p is not None for p in ref_audio_paths) else None,
|
| 392 |
+
ref_text=ref_texts if any(t is not None for t in ref_texts) else None,
|
| 393 |
duration=durations if any(d is not None for d in durations) else None,
|
| 394 |
speed=speeds if any(s is not None for s in speeds) else None,
|
| 395 |
+
instruct=instructs if any(i is not None for i in instructs) else None,
|
| 396 |
**gen_kwargs,
|
| 397 |
)
|
| 398 |
batch_synth_time = time.time() - start_time
|
|
|
|
| 400 |
results = []
|
| 401 |
for save_name, audio in zip(save_names, audios):
|
| 402 |
save_path = os.path.join(res_dir, save_name + ".wav")
|
| 403 |
+
sf.write(save_path, audio, worker_model.sampling_rate)
|
| 404 |
audio_duration = audio.shape[-1] / worker_model.sampling_rate
|
| 405 |
results.append(
|
| 406 |
(
|
|
|
|
| 451 |
samples.append(
|
| 452 |
(
|
| 453 |
s["id"],
|
| 454 |
+
s.get("ref_text"),
|
| 455 |
+
s.get("ref_audio"),
|
| 456 |
s["text"],
|
| 457 |
lang_id,
|
| 458 |
lang_name,
|
| 459 |
s.get("duration"),
|
| 460 |
s.get("speed"),
|
| 461 |
+
s.get("instruct"),
|
| 462 |
)
|
| 463 |
)
|
| 464 |
|
|
|
|
| 473 |
) as executor:
|
| 474 |
futures = []
|
| 475 |
|
|
|
|
| 476 |
logging.info("Running batch inference")
|
| 477 |
|
| 478 |
+
# Split samples by mode (voice-clone vs non-voice-clone) before
|
| 479 |
+
# clustering so that each batch is homogeneous. Mixing ref_audio
|
| 480 |
+
# and non-ref_audio samples in the same batch would crash in
|
| 481 |
+
# generate() → create_voice_clone_prompt().
|
| 482 |
+
clone_samples = [s for s in samples if s[2] is not None]
|
| 483 |
+
other_samples = [s for s in samples if s[2] is None]
|
| 484 |
+
|
| 485 |
duration_estimator = RuleDurationEstimator()
|
| 486 |
+
batches = []
|
| 487 |
+
for subset in (clone_samples, other_samples):
|
| 488 |
+
if not subset:
|
| 489 |
+
continue
|
| 490 |
+
if args.batch_size > 0:
|
| 491 |
+
batches.extend(
|
| 492 |
+
cluster_samples_by_batch_size(
|
| 493 |
+
subset, duration_estimator, args.batch_size
|
| 494 |
+
)
|
| 495 |
+
)
|
| 496 |
+
else:
|
| 497 |
+
batches.extend(
|
| 498 |
+
cluster_samples_by_duration(
|
| 499 |
+
subset, duration_estimator, args.batch_duration
|
| 500 |
+
)
|
| 501 |
+
)
|
| 502 |
|
| 503 |
args_dict = vars(args)
|
| 504 |
|
omnivoice/data/dataset.py
CHANGED
|
@@ -44,8 +44,9 @@ from typing import Any, Dict, Iterator, List, Optional, Tuple
|
|
| 44 |
|
| 45 |
import torch
|
| 46 |
import torch.distributed as dist
|
| 47 |
-
import torchaudio
|
| 48 |
import webdataset as wds
|
|
|
|
|
|
|
| 49 |
from torch.utils.data import IterableDataset
|
| 50 |
|
| 51 |
|
|
@@ -54,12 +55,8 @@ def load_audio_webdataset(data, sample_rate: int = 24000, device="cpu"):
|
|
| 54 |
Load audio from bytes data and resample to the target sample rate if needed.
|
| 55 |
Return a tensor of shape (1, num_samples)
|
| 56 |
"""
|
| 57 |
-
audio
|
| 58 |
audio = audio.to(device)
|
| 59 |
-
if audio.size(dim=0) > 1:
|
| 60 |
-
audio = torch.mean(audio, dim=0)
|
| 61 |
-
if sr != sample_rate:
|
| 62 |
-
audio = torchaudio.functional.resample(audio, sr, sample_rate)
|
| 63 |
return audio
|
| 64 |
|
| 65 |
|
|
@@ -433,13 +430,9 @@ class JsonlDatasetReader(IterableDataReader):
|
|
| 433 |
)
|
| 434 |
continue
|
| 435 |
try:
|
| 436 |
-
waveform
|
| 437 |
-
|
| 438 |
-
|
| 439 |
-
if sr != self.sample_rate:
|
| 440 |
-
waveform = torchaudio.functional.resample(
|
| 441 |
-
waveform, sr, self.sample_rate
|
| 442 |
-
)
|
| 443 |
if self.normalize_audio:
|
| 444 |
waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
|
| 445 |
meta["audio_duration"] = waveform.shape[1] / self.sample_rate
|
|
|
|
| 44 |
|
| 45 |
import torch
|
| 46 |
import torch.distributed as dist
|
|
|
|
| 47 |
import webdataset as wds
|
| 48 |
+
|
| 49 |
+
from omnivoice.utils.audio import load_audio, load_audio_bytes
|
| 50 |
from torch.utils.data import IterableDataset
|
| 51 |
|
| 52 |
|
|
|
|
| 55 |
Load audio from bytes data and resample to the target sample rate if needed.
|
| 56 |
Return a tensor of shape (1, num_samples)
|
| 57 |
"""
|
| 58 |
+
audio = torch.from_numpy(load_audio_bytes(data, sample_rate))
|
| 59 |
audio = audio.to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
return audio
|
| 61 |
|
| 62 |
|
|
|
|
| 430 |
)
|
| 431 |
continue
|
| 432 |
try:
|
| 433 |
+
waveform = torch.from_numpy(
|
| 434 |
+
load_audio(audio_path, self.sample_rate)
|
| 435 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 436 |
if self.normalize_audio:
|
| 437 |
waveform = (waveform / (waveform.abs().max() + 1e-7)) * 0.9
|
| 438 |
meta["audio_duration"] = waveform.shape[1] / self.sample_rate
|
omnivoice/eval/mos/utmos.py
CHANGED
|
@@ -32,7 +32,7 @@ import torch
|
|
| 32 |
from tqdm import tqdm
|
| 33 |
|
| 34 |
from omnivoice.eval.models.utmos import UTMOS22Strong
|
| 35 |
-
from omnivoice.eval.utils import
|
| 36 |
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
|
| 38 |
warnings.filterwarnings("ignore")
|
|
@@ -140,7 +140,7 @@ def run_utmos_worker(file_idx, wav_path, language_name):
|
|
| 140 |
return file_idx, wav_path, language_name, f"File not found: {wav_path}", "error"
|
| 141 |
|
| 142 |
# Load and preprocess waveform
|
| 143 |
-
speech =
|
| 144 |
|
| 145 |
# Compute score
|
| 146 |
# UTMOS expects input shape (Batch, Time)
|
|
|
|
| 32 |
from tqdm import tqdm
|
| 33 |
|
| 34 |
from omnivoice.eval.models.utmos import UTMOS22Strong
|
| 35 |
+
from omnivoice.eval.utils import load_eval_waveform
|
| 36 |
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
|
| 38 |
warnings.filterwarnings("ignore")
|
|
|
|
| 140 |
return file_idx, wav_path, language_name, f"File not found: {wav_path}", "error"
|
| 141 |
|
| 142 |
# Load and preprocess waveform
|
| 143 |
+
speech = load_eval_waveform(wav_path, worker_sr, device=worker_device)
|
| 144 |
|
| 145 |
# Compute score
|
| 146 |
# UTMOS expects input shape (Batch, Time)
|
omnivoice/eval/speaker_similarity/sim.py
CHANGED
|
@@ -33,7 +33,7 @@ import torch
|
|
| 33 |
from tqdm import tqdm
|
| 34 |
|
| 35 |
from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
|
| 36 |
-
from omnivoice.eval.utils import
|
| 37 |
from omnivoice.utils.data_utils import read_test_list
|
| 38 |
|
| 39 |
warnings.filterwarnings("ignore")
|
|
@@ -144,7 +144,7 @@ def worker_init(
|
|
| 144 |
@torch.no_grad()
|
| 145 |
def get_embedding(wav_path: str) -> torch.Tensor:
|
| 146 |
"""Extract embedding for a single file."""
|
| 147 |
-
speech =
|
| 148 |
return worker_model([speech])
|
| 149 |
|
| 150 |
|
|
|
|
| 33 |
from tqdm import tqdm
|
| 34 |
|
| 35 |
from omnivoice.eval.models.ecapa_tdnn_wavlm import ECAPA_TDNN_WAVLM
|
| 36 |
+
from omnivoice.eval.utils import load_eval_waveform
|
| 37 |
from omnivoice.utils.data_utils import read_test_list
|
| 38 |
|
| 39 |
warnings.filterwarnings("ignore")
|
|
|
|
| 144 |
@torch.no_grad()
|
| 145 |
def get_embedding(wav_path: str) -> torch.Tensor:
|
| 146 |
"""Extract embedding for a single file."""
|
| 147 |
+
speech = load_eval_waveform(wav_path, worker_sr, device=worker_device, max_seconds=120)
|
| 148 |
return worker_model([speech])
|
| 149 |
|
| 150 |
|
omnivoice/eval/utils.py
CHANGED
|
@@ -23,7 +23,7 @@ import soundfile as sf
|
|
| 23 |
import torch
|
| 24 |
|
| 25 |
|
| 26 |
-
def
|
| 27 |
fname: str,
|
| 28 |
sample_rate: int,
|
| 29 |
dtype: str = "float32",
|
|
|
|
| 23 |
import torch
|
| 24 |
|
| 25 |
|
| 26 |
+
def load_eval_waveform(
|
| 27 |
fname: str,
|
| 28 |
sample_rate: int,
|
| 29 |
dtype: str = "float32",
|
omnivoice/eval/wer/hubert.py
CHANGED
|
@@ -31,7 +31,7 @@ import numpy as np
|
|
| 31 |
import torch
|
| 32 |
from tqdm import tqdm
|
| 33 |
|
| 34 |
-
from omnivoice.eval.utils import
|
| 35 |
from omnivoice.eval.wer.common import process_one
|
| 36 |
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
|
|
@@ -166,7 +166,7 @@ def run_eval_worker(data_chunk, batch_size):
|
|
| 166 |
try:
|
| 167 |
dataset = [
|
| 168 |
{
|
| 169 |
-
"array":
|
| 170 |
item["wav_path"], sample_rate=16000, return_numpy=True
|
| 171 |
),
|
| 172 |
"sampling_rate": 16000,
|
|
|
|
| 31 |
import torch
|
| 32 |
from tqdm import tqdm
|
| 33 |
|
| 34 |
+
from omnivoice.eval.utils import load_eval_waveform
|
| 35 |
from omnivoice.eval.wer.common import process_one
|
| 36 |
from omnivoice.utils.data_utils import read_test_list
|
| 37 |
|
|
|
|
| 166 |
try:
|
| 167 |
dataset = [
|
| 168 |
{
|
| 169 |
+
"array": load_eval_waveform(
|
| 170 |
item["wav_path"], sample_rate=16000, return_numpy=True
|
| 171 |
),
|
| 172 |
"sampling_rate": 16000,
|
omnivoice/eval/wer/minimax.py
CHANGED
|
@@ -34,7 +34,7 @@ import torch
|
|
| 34 |
import zhconv
|
| 35 |
from tqdm import tqdm
|
| 36 |
|
| 37 |
-
from omnivoice.eval.utils import
|
| 38 |
from omnivoice.eval.wer.common import log_metrics, process_one
|
| 39 |
from omnivoice.eval.wer.text_norm_omni import text_normalize
|
| 40 |
from omnivoice.utils.data_utils import read_test_list
|
|
@@ -275,7 +275,7 @@ class SpeechEvalDataset(torch.utils.data.Dataset):
|
|
| 275 |
|
| 276 |
def __getitem__(self, index):
|
| 277 |
item = self.data_list[index]
|
| 278 |
-
waveform =
|
| 279 |
return {
|
| 280 |
"array": waveform,
|
| 281 |
"sampling_rate": 16000,
|
|
|
|
| 34 |
import zhconv
|
| 35 |
from tqdm import tqdm
|
| 36 |
|
| 37 |
+
from omnivoice.eval.utils import load_eval_waveform
|
| 38 |
from omnivoice.eval.wer.common import log_metrics, process_one
|
| 39 |
from omnivoice.eval.wer.text_norm_omni import text_normalize
|
| 40 |
from omnivoice.utils.data_utils import read_test_list
|
|
|
|
| 275 |
|
| 276 |
def __getitem__(self, index):
|
| 277 |
item = self.data_list[index]
|
| 278 |
+
waveform = load_eval_waveform(item["wav_path"], sample_rate=16000, return_numpy=True)
|
| 279 |
return {
|
| 280 |
"array": waveform,
|
| 281 |
"sampling_rate": 16000,
|
omnivoice/eval/wer/seedtts.py
CHANGED
|
@@ -34,7 +34,7 @@ import zhconv
|
|
| 34 |
from tqdm import tqdm
|
| 35 |
from zhon.hanzi import punctuation
|
| 36 |
|
| 37 |
-
from omnivoice.eval.utils import
|
| 38 |
from omnivoice.eval.wer.common import process_one
|
| 39 |
from omnivoice.utils.data_utils import read_test_list
|
| 40 |
|
|
@@ -228,7 +228,7 @@ def run_eval_worker(data_chunk, lang, batch_size):
|
|
| 228 |
# Load waveforms as arrays, truncating to 30s
|
| 229 |
dataset = [
|
| 230 |
{
|
| 231 |
-
"array":
|
| 232 |
item["wav_path"], sample_rate=16000, return_numpy=True
|
| 233 |
)[: 16000 * 30],
|
| 234 |
"sampling_rate": 16000,
|
|
|
|
| 34 |
from tqdm import tqdm
|
| 35 |
from zhon.hanzi import punctuation
|
| 36 |
|
| 37 |
+
from omnivoice.eval.utils import load_eval_waveform
|
| 38 |
from omnivoice.eval.wer.common import process_one
|
| 39 |
from omnivoice.utils.data_utils import read_test_list
|
| 40 |
|
|
|
|
| 228 |
# Load waveforms as arrays, truncating to 30s
|
| 229 |
dataset = [
|
| 230 |
{
|
| 231 |
+
"array": load_eval_waveform(
|
| 232 |
item["wav_path"], sample_rate=16000, return_numpy=True
|
| 233 |
)[: 16000 * 30],
|
| 234 |
"sampling_rate": 16000,
|
omnivoice/models/omnivoice.py
CHANGED
|
@@ -36,10 +36,11 @@ from dataclasses import dataclass, fields
|
|
| 36 |
from functools import partial
|
| 37 |
from typing import Any, List, Optional, Union
|
| 38 |
|
|
|
|
|
|
|
| 39 |
import torch
|
| 40 |
import torch.nn as nn
|
| 41 |
import torch.nn.functional as F
|
| 42 |
-
import torchaudio
|
| 43 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 44 |
from transformers import (
|
| 45 |
AutoFeatureExtractor,
|
|
@@ -310,12 +311,14 @@ class OmniVoice(PreTrainedModel):
|
|
| 310 |
@torch.inference_mode()
|
| 311 |
def transcribe(
|
| 312 |
self,
|
| 313 |
-
audio: Union[str, tuple
|
| 314 |
) -> str:
|
| 315 |
"""Transcribe audio using the loaded Whisper ASR model.
|
| 316 |
|
| 317 |
Args:
|
| 318 |
-
audio: File path or (waveform, sample_rate) tuple.
|
|
|
|
|
|
|
| 319 |
|
| 320 |
Returns:
|
| 321 |
Transcribed text.
|
|
@@ -329,12 +332,11 @@ class OmniVoice(PreTrainedModel):
|
|
| 329 |
return self._asr_pipe(audio)["text"].strip()
|
| 330 |
else:
|
| 331 |
waveform, sr = audio
|
| 332 |
-
if waveform.
|
| 333 |
-
waveform = waveform.
|
| 334 |
-
|
| 335 |
-
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
| 336 |
audio_input = {
|
| 337 |
-
"array": waveform
|
| 338 |
"sampling_rate": sr,
|
| 339 |
}
|
| 340 |
return self._asr_pipe(audio_input)["text"].strip()
|
|
@@ -475,7 +477,7 @@ class OmniVoice(PreTrainedModel):
|
|
| 475 |
speed: Union[float, list[Optional[float]], None] = None,
|
| 476 |
generation_config: Optional[OmniVoiceGenerationConfig] = None,
|
| 477 |
**kwargs,
|
| 478 |
-
) -> list[
|
| 479 |
"""Generate speech audio given text in various modes.
|
| 480 |
|
| 481 |
Supports three modes:
|
|
@@ -522,8 +524,10 @@ class OmniVoice(PreTrainedModel):
|
|
| 522 |
audio_chunk_threshold: Only apply chunking if estimated audio
|
| 523 |
duration exceeds this threshold (seconds).
|
| 524 |
Returns:
|
| 525 |
-
``audios`` a list of
|
| 526 |
-
consistent with the model's audio tokenizer
|
|
|
|
|
|
|
| 527 |
"""
|
| 528 |
|
| 529 |
if self.audio_tokenizer is None or self.text_tokenizer is None:
|
|
@@ -611,17 +615,19 @@ class OmniVoice(PreTrainedModel):
|
|
| 611 |
ref_wav = load_audio(ref_audio, self.sampling_rate)
|
| 612 |
else:
|
| 613 |
waveform, sr = ref_audio
|
| 614 |
-
if waveform.
|
| 615 |
-
waveform = waveform.
|
| 616 |
-
if waveform.
|
| 617 |
-
waveform =
|
|
|
|
|
|
|
| 618 |
if sr != self.sampling_rate:
|
| 619 |
-
waveform =
|
| 620 |
-
waveform, sr, self.sampling_rate
|
| 621 |
)
|
| 622 |
ref_wav = waveform
|
| 623 |
|
| 624 |
-
ref_rms =
|
| 625 |
if 0 < ref_rms < 0.1:
|
| 626 |
ref_wav = ref_wav * 0.1 / ref_rms
|
| 627 |
|
|
@@ -640,13 +646,13 @@ class OmniVoice(PreTrainedModel):
|
|
| 640 |
lead_sil=100,
|
| 641 |
trail_sil=200,
|
| 642 |
)
|
| 643 |
-
if ref_wav.
|
| 644 |
raise ValueError(
|
| 645 |
"Reference audio is empty after silence removal. "
|
| 646 |
"Try setting preprocess_prompt=False."
|
| 647 |
)
|
| 648 |
|
| 649 |
-
ref_duration = ref_wav.
|
| 650 |
if ref_duration > 20.0:
|
| 651 |
logger.warning(
|
| 652 |
"Reference audio is %.1fs long (>20s). This may cause slower "
|
|
@@ -664,10 +670,14 @@ class OmniVoice(PreTrainedModel):
|
|
| 664 |
logger.debug("Auto-transcribed ref_text: %s", ref_text)
|
| 665 |
|
| 666 |
chunk_size = self.audio_tokenizer.config.hop_length
|
| 667 |
-
clip_size = int(ref_wav.
|
| 668 |
ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
|
|
|
|
|
|
|
|
|
|
|
|
|
| 669 |
ref_audio_tokens = self.audio_tokenizer.encode(
|
| 670 |
-
|
| 671 |
).audio_codes.squeeze(
|
| 672 |
0
|
| 673 |
) # (C, T)
|
|
@@ -686,7 +696,7 @@ class OmniVoice(PreTrainedModel):
|
|
| 686 |
tokens: Union[torch.Tensor, List[torch.Tensor]],
|
| 687 |
rms: Union[float, None],
|
| 688 |
gen_config: OmniVoiceGenerationConfig,
|
| 689 |
-
) ->
|
| 690 |
"""
|
| 691 |
Args:
|
| 692 |
tokens: Audio tokens — either a single tensor of shape
|
|
@@ -694,7 +704,7 @@ class OmniVoice(PreTrainedModel):
|
|
| 694 |
rms: RMS of the reference audio for volume adjustment.
|
| 695 |
gen_config: Generation config for post-processing options.
|
| 696 |
Returns:
|
| 697 |
-
Decoded and post-processed audio
|
| 698 |
"""
|
| 699 |
tokenizer_device = self.audio_tokenizer.device
|
| 700 |
if isinstance(tokens, list):
|
|
@@ -702,6 +712,7 @@ class OmniVoice(PreTrainedModel):
|
|
| 702 |
self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
|
| 703 |
.audio_values[0]
|
| 704 |
.cpu()
|
|
|
|
| 705 |
for t in tokens
|
| 706 |
]
|
| 707 |
audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
|
|
@@ -710,28 +721,30 @@ class OmniVoice(PreTrainedModel):
|
|
| 710 |
self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
|
| 711 |
.audio_values[0]
|
| 712 |
.cpu()
|
|
|
|
| 713 |
)
|
| 714 |
|
| 715 |
-
|
| 716 |
audio_waveform,
|
| 717 |
postprocess_output=gen_config.postprocess_output,
|
| 718 |
ref_rms=rms,
|
| 719 |
)
|
|
|
|
| 720 |
|
| 721 |
def _post_process_audio(
|
| 722 |
self,
|
| 723 |
-
generated_audio:
|
| 724 |
postprocess_output: bool,
|
| 725 |
ref_rms: Union[float, None],
|
| 726 |
-
) ->
|
| 727 |
"""Optionally remove long silences, adjust volume, and add edge padding.
|
| 728 |
|
| 729 |
Args:
|
| 730 |
-
generated_audio:
|
| 731 |
postprocess_output: If True, remove long silences and apply fade/pad.
|
| 732 |
ref_rms: RMS of the reference audio for volume normalisation.
|
| 733 |
Returns:
|
| 734 |
-
Processed
|
| 735 |
"""
|
| 736 |
if postprocess_output:
|
| 737 |
generated_audio = remove_silence(
|
|
@@ -745,9 +758,7 @@ class OmniVoice(PreTrainedModel):
|
|
| 745 |
if ref_rms is not None and ref_rms < 0.1:
|
| 746 |
generated_audio = generated_audio * ref_rms / 0.1
|
| 747 |
elif ref_rms is None:
|
| 748 |
-
|
| 749 |
-
# to avoid clipping while keeping a comfortable volume level.
|
| 750 |
-
peak = generated_audio.abs().max()
|
| 751 |
if peak > 1e-6:
|
| 752 |
generated_audio = generated_audio / peak * 0.5
|
| 753 |
|
|
@@ -1549,6 +1560,9 @@ def _combine_text(text, ref_text: Optional[str] = None) -> str:
|
|
| 1549 |
# filter out newline / carriage-return characters
|
| 1550 |
full_text = re.sub(r"[\r\n]+", "", full_text)
|
| 1551 |
|
|
|
|
|
|
|
|
|
|
| 1552 |
# collapse consecutive spaces / tabs into a single space
|
| 1553 |
full_text = re.sub(r"[ \t]+", " ", full_text)
|
| 1554 |
|
|
|
|
| 36 |
from functools import partial
|
| 37 |
from typing import Any, List, Optional, Union
|
| 38 |
|
| 39 |
+
import librosa
|
| 40 |
+
import numpy as np
|
| 41 |
import torch
|
| 42 |
import torch.nn as nn
|
| 43 |
import torch.nn.functional as F
|
|
|
|
| 44 |
from torch.nn.attention.flex_attention import create_block_mask
|
| 45 |
from transformers import (
|
| 46 |
AutoFeatureExtractor,
|
|
|
|
| 311 |
@torch.inference_mode()
|
| 312 |
def transcribe(
|
| 313 |
self,
|
| 314 |
+
audio: Union[str, tuple],
|
| 315 |
) -> str:
|
| 316 |
"""Transcribe audio using the loaded Whisper ASR model.
|
| 317 |
|
| 318 |
Args:
|
| 319 |
+
audio: File path or ``(waveform, sample_rate)`` tuple.
|
| 320 |
+
Waveform can be a numpy array or torch.Tensor of shape
|
| 321 |
+
``(1, T)`` or ``(T,)``.
|
| 322 |
|
| 323 |
Returns:
|
| 324 |
Transcribed text.
|
|
|
|
| 332 |
return self._asr_pipe(audio)["text"].strip()
|
| 333 |
else:
|
| 334 |
waveform, sr = audio
|
| 335 |
+
if isinstance(waveform, torch.Tensor):
|
| 336 |
+
waveform = waveform.cpu().numpy()
|
| 337 |
+
waveform = np.squeeze(waveform) # (1, T) or (T,) → (T,)
|
|
|
|
| 338 |
audio_input = {
|
| 339 |
+
"array": waveform,
|
| 340 |
"sampling_rate": sr,
|
| 341 |
}
|
| 342 |
return self._asr_pipe(audio_input)["text"].strip()
|
|
|
|
| 477 |
speed: Union[float, list[Optional[float]], None] = None,
|
| 478 |
generation_config: Optional[OmniVoiceGenerationConfig] = None,
|
| 479 |
**kwargs,
|
| 480 |
+
) -> list[np.ndarray]:
|
| 481 |
"""Generate speech audio given text in various modes.
|
| 482 |
|
| 483 |
Supports three modes:
|
|
|
|
| 524 |
audio_chunk_threshold: Only apply chunking if estimated audio
|
| 525 |
duration exceeds this threshold (seconds).
|
| 526 |
Returns:
|
| 527 |
+
``audios`` a list of 1-D ``np.ndarray`` with shape ``(T,)`` and
|
| 528 |
+
sampling rate consistent with the model's audio tokenizer
|
| 529 |
+
(usually 24 000 Hz). Can be saved directly with
|
| 530 |
+
``soundfile.write("out.wav", audios[0], model.sampling_rate)``.
|
| 531 |
"""
|
| 532 |
|
| 533 |
if self.audio_tokenizer is None or self.text_tokenizer is None:
|
|
|
|
| 615 |
ref_wav = load_audio(ref_audio, self.sampling_rate)
|
| 616 |
else:
|
| 617 |
waveform, sr = ref_audio
|
| 618 |
+
if isinstance(waveform, torch.Tensor):
|
| 619 |
+
waveform = waveform.cpu().numpy()
|
| 620 |
+
if waveform.ndim == 1:
|
| 621 |
+
waveform = waveform[np.newaxis, :]
|
| 622 |
+
if waveform.shape[0] > 1:
|
| 623 |
+
waveform = np.mean(waveform, axis=0, keepdims=True)
|
| 624 |
if sr != self.sampling_rate:
|
| 625 |
+
waveform = librosa.resample(
|
| 626 |
+
waveform, orig_sr=sr, target_sr=self.sampling_rate,
|
| 627 |
)
|
| 628 |
ref_wav = waveform
|
| 629 |
|
| 630 |
+
ref_rms = float(np.sqrt(np.mean(ref_wav ** 2)))
|
| 631 |
if 0 < ref_rms < 0.1:
|
| 632 |
ref_wav = ref_wav * 0.1 / ref_rms
|
| 633 |
|
|
|
|
| 646 |
lead_sil=100,
|
| 647 |
trail_sil=200,
|
| 648 |
)
|
| 649 |
+
if ref_wav.shape[-1] == 0:
|
| 650 |
raise ValueError(
|
| 651 |
"Reference audio is empty after silence removal. "
|
| 652 |
"Try setting preprocess_prompt=False."
|
| 653 |
)
|
| 654 |
|
| 655 |
+
ref_duration = ref_wav.shape[-1] / self.sampling_rate
|
| 656 |
if ref_duration > 20.0:
|
| 657 |
logger.warning(
|
| 658 |
"Reference audio is %.1fs long (>20s). This may cause slower "
|
|
|
|
| 670 |
logger.debug("Auto-transcribed ref_text: %s", ref_text)
|
| 671 |
|
| 672 |
chunk_size = self.audio_tokenizer.config.hop_length
|
| 673 |
+
clip_size = int(ref_wav.shape[-1] % chunk_size)
|
| 674 |
ref_wav = ref_wav[:, :-clip_size] if clip_size > 0 else ref_wav
|
| 675 |
+
# numpy → torch at tokenizer boundary
|
| 676 |
+
ref_wav_tensor = torch.from_numpy(ref_wav).to(
|
| 677 |
+
self.audio_tokenizer.device
|
| 678 |
+
)
|
| 679 |
ref_audio_tokens = self.audio_tokenizer.encode(
|
| 680 |
+
ref_wav_tensor.unsqueeze(0),
|
| 681 |
).audio_codes.squeeze(
|
| 682 |
0
|
| 683 |
) # (C, T)
|
|
|
|
| 696 |
tokens: Union[torch.Tensor, List[torch.Tensor]],
|
| 697 |
rms: Union[float, None],
|
| 698 |
gen_config: OmniVoiceGenerationConfig,
|
| 699 |
+
) -> np.ndarray:
|
| 700 |
"""
|
| 701 |
Args:
|
| 702 |
tokens: Audio tokens — either a single tensor of shape
|
|
|
|
| 704 |
rms: RMS of the reference audio for volume adjustment.
|
| 705 |
gen_config: Generation config for post-processing options.
|
| 706 |
Returns:
|
| 707 |
+
Decoded and post-processed audio array of shape (T,).
|
| 708 |
"""
|
| 709 |
tokenizer_device = self.audio_tokenizer.device
|
| 710 |
if isinstance(tokens, list):
|
|
|
|
| 712 |
self.audio_tokenizer.decode(t.to(tokenizer_device).unsqueeze(0))
|
| 713 |
.audio_values[0]
|
| 714 |
.cpu()
|
| 715 |
+
.numpy()
|
| 716 |
for t in tokens
|
| 717 |
]
|
| 718 |
audio_waveform = cross_fade_chunks(chunk_audios, self.sampling_rate)
|
|
|
|
| 721 |
self.audio_tokenizer.decode(tokens.to(tokenizer_device).unsqueeze(0))
|
| 722 |
.audio_values[0]
|
| 723 |
.cpu()
|
| 724 |
+
.numpy()
|
| 725 |
)
|
| 726 |
|
| 727 |
+
audio_waveform = self._post_process_audio(
|
| 728 |
audio_waveform,
|
| 729 |
postprocess_output=gen_config.postprocess_output,
|
| 730 |
ref_rms=rms,
|
| 731 |
)
|
| 732 |
+
return audio_waveform.squeeze(0)
|
| 733 |
|
| 734 |
def _post_process_audio(
|
| 735 |
self,
|
| 736 |
+
generated_audio: np.ndarray,
|
| 737 |
postprocess_output: bool,
|
| 738 |
ref_rms: Union[float, None],
|
| 739 |
+
) -> np.ndarray:
|
| 740 |
"""Optionally remove long silences, adjust volume, and add edge padding.
|
| 741 |
|
| 742 |
Args:
|
| 743 |
+
generated_audio: Numpy array of shape (1, T).
|
| 744 |
postprocess_output: If True, remove long silences and apply fade/pad.
|
| 745 |
ref_rms: RMS of the reference audio for volume normalisation.
|
| 746 |
Returns:
|
| 747 |
+
Processed numpy array of shape (1, T).
|
| 748 |
"""
|
| 749 |
if postprocess_output:
|
| 750 |
generated_audio = remove_silence(
|
|
|
|
| 758 |
if ref_rms is not None and ref_rms < 0.1:
|
| 759 |
generated_audio = generated_audio * ref_rms / 0.1
|
| 760 |
elif ref_rms is None:
|
| 761 |
+
peak = np.abs(generated_audio).max()
|
|
|
|
|
|
|
| 762 |
if peak > 1e-6:
|
| 763 |
generated_audio = generated_audio / peak * 0.5
|
| 764 |
|
|
|
|
| 1560 |
# filter out newline / carriage-return characters
|
| 1561 |
full_text = re.sub(r"[\r\n]+", "", full_text)
|
| 1562 |
|
| 1563 |
+
# replace Chinese parentheses with English ones
|
| 1564 |
+
full_text = full_text.replace("\uff08", "(").replace("\uff09", ")")
|
| 1565 |
+
|
| 1566 |
# collapse consecutive spaces / tabs into a single space
|
| 1567 |
full_text = re.sub(r"[ \t]+", " ", full_text)
|
| 1568 |
|
omnivoice/scripts/denoise_audio.py
CHANGED
|
@@ -73,6 +73,7 @@ from tqdm.auto import tqdm
|
|
| 73 |
|
| 74 |
from omnivoice.data.batching import StreamLengthGroupDataset
|
| 75 |
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
|
|
|
| 76 |
from omnivoice.utils.common import str2bool
|
| 77 |
|
| 78 |
SIDON_INPUT_SAMPLE_RATE = 16_000
|
|
@@ -367,10 +368,10 @@ def extract_seamless_m4t_features(
|
|
| 367 |
|
| 368 |
def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
|
| 369 |
buffer = io.BytesIO()
|
| 370 |
-
audio = waveform.to(dtype=torch.float32).cpu()
|
| 371 |
-
if audio.ndim ==
|
| 372 |
-
audio = audio.
|
| 373 |
-
|
| 374 |
return {"__key__": key, "flac": buffer.getvalue()}
|
| 375 |
|
| 376 |
|
|
|
|
| 73 |
|
| 74 |
from omnivoice.data.batching import StreamLengthGroupDataset
|
| 75 |
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
| 76 |
+
import soundfile as sf
|
| 77 |
from omnivoice.utils.common import str2bool
|
| 78 |
|
| 79 |
SIDON_INPUT_SAMPLE_RATE = 16_000
|
|
|
|
| 368 |
|
| 369 |
def serialise_flac(key: str, waveform: torch.Tensor, sample_rate: int) -> dict:
    """Encode a waveform tensor as FLAC bytes for a WebDataset sample.

    Args:
        key: sample key for the WebDataset record.
        waveform: audio tensor, either ``(T,)`` or channels-first ``(C, T)``.
        sample_rate: sampling rate in Hz.

    Returns:
        Dict with ``__key__`` and FLAC-encoded ``flac`` bytes.
    """
    samples = waveform.to(dtype=torch.float32).cpu().numpy()
    # soundfile expects (T, C) for multi-channel input.
    if samples.ndim == 2:
        samples = samples.T
    out = io.BytesIO()
    sf.write(out, samples, sample_rate, format="FLAC")
    return {"__key__": key, "flac": out.getvalue()}
|
| 376 |
|
| 377 |
|
omnivoice/scripts/extract_audio_tokens_add_noise.py
CHANGED
|
@@ -66,13 +66,13 @@ from typing import Any
|
|
| 66 |
import numpy as np
|
| 67 |
import torch
|
| 68 |
import torch.nn.functional as F
|
| 69 |
-
import torchaudio
|
| 70 |
import webdataset as wds
|
| 71 |
from torch.utils.data import DataLoader, IterableDataset
|
| 72 |
from tqdm.auto import tqdm
|
| 73 |
from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
|
| 74 |
|
| 75 |
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
|
|
|
| 76 |
from omnivoice.utils.common import str2bool
|
| 77 |
|
| 78 |
warnings.filterwarnings(
|
|
@@ -207,13 +207,7 @@ def serialise_numpy(key: str, tokens: np.ndarray) -> dict:
|
|
| 207 |
|
| 208 |
def _load_aug_audio(data, sample_rate=24000):
|
| 209 |
"""Simple audio loader for augmentation files."""
|
| 210 |
-
|
| 211 |
-
wav, sr = torchaudio.load(b)
|
| 212 |
-
if wav.shape[0] > 1:
|
| 213 |
-
wav = wav.mean(dim=0, keepdim=True)
|
| 214 |
-
if sr != sample_rate:
|
| 215 |
-
wav = torchaudio.functional.resample(wav, sr, sample_rate)
|
| 216 |
-
return wav
|
| 217 |
|
| 218 |
|
| 219 |
class SimpleWorkerSampler:
|
|
|
|
| 66 |
import numpy as np
|
| 67 |
import torch
|
| 68 |
import torch.nn.functional as F
|
|
|
|
| 69 |
import webdataset as wds
|
| 70 |
from torch.utils.data import DataLoader, IterableDataset
|
| 71 |
from tqdm.auto import tqdm
|
| 72 |
from transformers import AutoFeatureExtractor, HiggsAudioV2TokenizerModel
|
| 73 |
|
| 74 |
from omnivoice.data.dataset import JsonlDatasetReader, WebDatasetReader
|
| 75 |
+
from omnivoice.utils.audio import load_audio_bytes
|
| 76 |
from omnivoice.utils.common import str2bool
|
| 77 |
|
| 78 |
warnings.filterwarnings(
|
|
|
|
| 207 |
|
| 208 |
def _load_aug_audio(data, sample_rate=24000):
    """Simple audio loader for augmentation files."""
    waveform = load_audio_bytes(data, sample_rate)
    return torch.from_numpy(waveform)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 211 |
|
| 212 |
|
| 213 |
class SimpleWorkerSampler:
|
omnivoice/scripts/jsonl_to_webdataset.py
CHANGED
|
@@ -65,10 +65,13 @@ from concurrent.futures import (
|
|
| 65 |
from itertools import islice
|
| 66 |
from pathlib import Path
|
| 67 |
|
| 68 |
-
import
|
| 69 |
import webdataset as wds
|
| 70 |
from tqdm import tqdm
|
| 71 |
|
|
|
|
|
|
|
|
|
|
| 72 |
from omnivoice.utils.common import str2bool
|
| 73 |
|
| 74 |
|
|
@@ -164,16 +167,16 @@ def process_audio_item(meta, target_sr):
|
|
| 164 |
if not os.path.exists(audio_path):
|
| 165 |
raise FileNotFoundError(f"{audio_path} not found")
|
| 166 |
|
| 167 |
-
waveform, sr =
|
| 168 |
audio_duration = waveform.shape[1] / sr
|
| 169 |
meta["audio_duration"] = audio_duration
|
| 170 |
|
| 171 |
if target_sr and sr != target_sr:
|
| 172 |
-
waveform =
|
| 173 |
sr = target_sr
|
| 174 |
|
| 175 |
audio_buffer = io.BytesIO()
|
| 176 |
-
|
| 177 |
audio_bytes = audio_buffer.getvalue()
|
| 178 |
|
| 179 |
sample = {
|
|
|
|
| 65 |
from itertools import islice
|
| 66 |
from pathlib import Path
|
| 67 |
|
| 68 |
+
import librosa
|
| 69 |
import webdataset as wds
|
| 70 |
from tqdm import tqdm
|
| 71 |
|
| 72 |
+
import soundfile as sf
|
| 73 |
+
|
| 74 |
+
from omnivoice.utils.audio import load_waveform
|
| 75 |
from omnivoice.utils.common import str2bool
|
| 76 |
|
| 77 |
|
|
|
|
| 167 |
if not os.path.exists(audio_path):
|
| 168 |
raise FileNotFoundError(f"{audio_path} not found")
|
| 169 |
|
| 170 |
+
waveform, sr = load_waveform(audio_path)
|
| 171 |
audio_duration = waveform.shape[1] / sr
|
| 172 |
meta["audio_duration"] = audio_duration
|
| 173 |
|
| 174 |
if target_sr and sr != target_sr:
|
| 175 |
+
waveform = librosa.resample(waveform, orig_sr=sr, target_sr=target_sr)
|
| 176 |
sr = target_sr
|
| 177 |
|
| 178 |
audio_buffer = io.BytesIO()
|
| 179 |
+
sf.write(audio_buffer, waveform.T, sr, format="FLAC")
|
| 180 |
audio_bytes = audio_buffer.getvalue()
|
| 181 |
|
| 182 |
sample = {
|
omnivoice/utils/audio.py
CHANGED
|
@@ -17,83 +17,157 @@
|
|
| 17 |
|
| 18 |
"""Audio I/O and processing utilities.
|
| 19 |
|
| 20 |
-
Provides functions for loading, resampling, silence removal,
|
| 21 |
-
cross-fading, and format conversion.
|
| 22 |
-
|
|
|
|
|
|
|
| 23 |
"""
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
import numpy as np
|
| 26 |
-
import
|
| 27 |
-
import torchaudio
|
| 28 |
from pydub import AudioSegment
|
| 29 |
from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
|
| 30 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"""
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
Parameters:
|
| 37 |
audio_path: path of the audio.
|
| 38 |
sampling_rate: target sampling rate.
|
| 39 |
|
| 40 |
Returns:
|
| 41 |
-
|
| 42 |
-
PyTorch tensor of shape (1, T)
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
)
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 63 |
)
|
| 64 |
-
if waveform.shape[0] > 1:
|
| 65 |
-
waveform = torch.mean(waveform, dim=0, keepdim=True)
|
| 66 |
|
| 67 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
|
| 69 |
|
| 70 |
def remove_silence(
|
| 71 |
-
audio:
|
| 72 |
sampling_rate: int,
|
| 73 |
mid_sil: int = 300,
|
| 74 |
lead_sil: int = 100,
|
| 75 |
trail_sil: int = 300,
|
| 76 |
-
):
|
| 77 |
-
"""
|
| 78 |
-
Remove middle silences longer than mid_sil ms, and edge silences longer than edge_sil ms
|
| 79 |
|
| 80 |
Parameters:
|
| 81 |
-
audio:
|
| 82 |
sampling_rate: sampling rate of the audio.
|
| 83 |
-
mid_sil:
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
trail_sil: the duration of added trailing silence in ms.
|
| 87 |
|
| 88 |
Returns:
|
| 89 |
-
|
| 90 |
-
and T is number of audio samples
|
| 91 |
"""
|
| 92 |
-
|
| 93 |
-
wave = tensor_to_audiosegment(audio, sampling_rate)
|
| 94 |
|
| 95 |
if mid_sil > 0:
|
| 96 |
-
# Split audio using silences longer than mid_sil
|
| 97 |
non_silent_segs = split_on_silence(
|
| 98 |
wave,
|
| 99 |
min_silence_len=mid_sil,
|
|
@@ -101,17 +175,13 @@ def remove_silence(
|
|
| 101 |
keep_silence=mid_sil,
|
| 102 |
seek_step=10,
|
| 103 |
)
|
| 104 |
-
|
| 105 |
-
# Concatenate all non-silent segments
|
| 106 |
wave = AudioSegment.silent(duration=0)
|
| 107 |
for seg in non_silent_segs:
|
| 108 |
wave += seg
|
| 109 |
|
| 110 |
-
# Remove silence longer than 0.1 seconds in the begining and ending of wave
|
| 111 |
wave = remove_silence_edges(wave, lead_sil, trail_sil, -50)
|
| 112 |
|
| 113 |
-
|
| 114 |
-
return audiosegment_to_tensor(wave)
|
| 115 |
|
| 116 |
|
| 117 |
def remove_silence_edges(
|
|
@@ -119,25 +189,12 @@ def remove_silence_edges(
|
|
| 119 |
lead_sil: int = 100,
|
| 120 |
trail_sil: int = 300,
|
| 121 |
silence_threshold: float = -50,
|
| 122 |
-
):
|
| 123 |
-
"""
|
| 124 |
-
Remove edge silences longer than `keep_silence` ms.
|
| 125 |
-
|
| 126 |
-
Parameters:
|
| 127 |
-
audio: an AudioSegment object.
|
| 128 |
-
keep_silence: kept silence in the edge.
|
| 129 |
-
only_edge: If true, only remove edge silences.
|
| 130 |
-
silence_threshold: the threshold of silence.
|
| 131 |
-
|
| 132 |
-
Returns:
|
| 133 |
-
An AudioSegment object
|
| 134 |
-
"""
|
| 135 |
-
# Remove heading silence
|
| 136 |
start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
|
| 137 |
start_idx = max(0, start_idx - lead_sil)
|
| 138 |
audio = audio[start_idx:]
|
| 139 |
|
| 140 |
-
# Remove trailing silence
|
| 141 |
audio = audio.reverse()
|
| 142 |
start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
|
| 143 |
start_idx = max(0, start_idx - trail_sil)
|
|
@@ -147,80 +204,22 @@ def remove_silence_edges(
|
|
| 147 |
return audio
|
| 148 |
|
| 149 |
|
| 150 |
-
def audiosegment_to_tensor(aseg):
|
| 151 |
-
"""
|
| 152 |
-
Convert a pydub.AudioSegment to PyTorch audio tensor
|
| 153 |
-
"""
|
| 154 |
-
audio_data = np.array(aseg.get_array_of_samples())
|
| 155 |
-
|
| 156 |
-
# Convert to float32 and normalize to [-1, 1] range
|
| 157 |
-
audio_data = audio_data.astype(np.float32) / 32768.0
|
| 158 |
-
|
| 159 |
-
# Handle channels
|
| 160 |
-
if aseg.channels == 1:
|
| 161 |
-
# Mono channel: add channel dimension (T) -> (1, T)
|
| 162 |
-
tensor_data = torch.from_numpy(audio_data).unsqueeze(0)
|
| 163 |
-
else:
|
| 164 |
-
# Multi-channel: reshape to (C, T)
|
| 165 |
-
tensor_data = torch.from_numpy(audio_data.reshape(-1, aseg.channels).T)
|
| 166 |
-
|
| 167 |
-
return tensor_data
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
def tensor_to_audiosegment(tensor, sample_rate):
|
| 171 |
-
"""
|
| 172 |
-
Convert a PyTorch audio tensor to pydub.AudioSegment
|
| 173 |
-
|
| 174 |
-
Parameters:
|
| 175 |
-
tensor: Tensor with shape (C, T), where C is the number of channels
|
| 176 |
-
and T is the time steps
|
| 177 |
-
sample_rate: Audio sample rate
|
| 178 |
-
"""
|
| 179 |
-
# Convert tensor to numpy array
|
| 180 |
-
assert isinstance(tensor, torch.Tensor)
|
| 181 |
-
audio_np = tensor.cpu().numpy()
|
| 182 |
-
|
| 183 |
-
# Convert to int16 type (common format for pydub)
|
| 184 |
-
# Assumes tensor values are in [-1, 1] range as floating point
|
| 185 |
-
audio_np = (audio_np * 32768.0).clip(-32768, 32767).astype(np.int16)
|
| 186 |
-
|
| 187 |
-
# Convert to byte stream
|
| 188 |
-
# For multi-channel audio, pydub requires interleaved format
|
| 189 |
-
# (e.g., left-right-left-right)
|
| 190 |
-
if audio_np.shape[0] > 1:
|
| 191 |
-
# Convert to interleaved format
|
| 192 |
-
audio_np = audio_np.transpose(1, 0).flatten()
|
| 193 |
-
audio_bytes = audio_np.tobytes()
|
| 194 |
-
|
| 195 |
-
# Create AudioSegment
|
| 196 |
-
audio_segment = AudioSegment(
|
| 197 |
-
data=audio_bytes,
|
| 198 |
-
sample_width=2,
|
| 199 |
-
frame_rate=sample_rate,
|
| 200 |
-
channels=tensor.shape[0],
|
| 201 |
-
)
|
| 202 |
-
|
| 203 |
-
return audio_segment
|
| 204 |
-
|
| 205 |
-
|
| 206 |
def fade_and_pad_audio(
|
| 207 |
-
audio:
|
| 208 |
pad_duration: float = 0.1,
|
| 209 |
fade_duration: float = 0.1,
|
| 210 |
sample_rate: int = 24000,
|
| 211 |
-
) ->
|
| 212 |
-
"""
|
| 213 |
-
Applies a smooth fade-in and fade-out to the audio, and then pads both sides
|
| 214 |
-
with pure silence to prevent abrupt starts and ends (clicks/pops).
|
| 215 |
|
| 216 |
Args:
|
| 217 |
-
audio:
|
| 218 |
-
pad_duration:
|
| 219 |
-
fade_duration:
|
| 220 |
-
sample_rate:
|
| 221 |
|
| 222 |
Returns:
|
| 223 |
-
Processed
|
| 224 |
"""
|
| 225 |
if audio.shape[-1] == 0:
|
| 226 |
return audio
|
|
@@ -228,59 +227,53 @@ def fade_and_pad_audio(
|
|
| 228 |
fade_samples = int(fade_duration * sample_rate)
|
| 229 |
pad_samples = int(pad_duration * sample_rate)
|
| 230 |
|
| 231 |
-
processed = audio.
|
| 232 |
|
| 233 |
if fade_samples > 0:
|
| 234 |
k = min(fade_samples, processed.shape[-1] // 2)
|
| 235 |
-
|
| 236 |
if k > 0:
|
| 237 |
-
fade_in =
|
| 238 |
-
|
| 239 |
-
)[None, :]
|
| 240 |
-
processed[..., :k] = processed[..., :k] * fade_in
|
| 241 |
|
| 242 |
-
fade_out =
|
| 243 |
-
|
| 244 |
-
)[None, :]
|
| 245 |
-
processed[..., -k:] = processed[..., -k:] * fade_out
|
| 246 |
|
| 247 |
if pad_samples > 0:
|
| 248 |
-
silence =
|
| 249 |
(processed.shape[0], pad_samples),
|
| 250 |
dtype=processed.dtype,
|
| 251 |
-
device=processed.device,
|
| 252 |
)
|
| 253 |
-
processed =
|
| 254 |
|
| 255 |
return processed
|
| 256 |
|
| 257 |
|
| 258 |
def trim_long_audio(
|
| 259 |
-
audio:
|
| 260 |
sampling_rate: int,
|
| 261 |
max_duration: float = 15.0,
|
| 262 |
min_duration: float = 3.0,
|
| 263 |
trim_threshold: float = 20.0,
|
| 264 |
-
) ->
|
| 265 |
-
"""Trim audio to <= max_duration by splitting at the largest silence gap.
|
| 266 |
|
| 267 |
Only trims when the audio exceeds *trim_threshold* seconds.
|
| 268 |
|
| 269 |
Args:
|
| 270 |
-
audio:
|
| 271 |
-
sampling_rate:
|
| 272 |
-
max_duration:
|
| 273 |
-
min_duration:
|
| 274 |
-
trim_threshold:
|
| 275 |
|
| 276 |
Returns:
|
| 277 |
-
Trimmed
|
| 278 |
"""
|
| 279 |
-
duration = audio.
|
| 280 |
if duration <= trim_threshold:
|
| 281 |
return audio
|
| 282 |
|
| 283 |
-
seg =
|
| 284 |
nonsilent = detect_nonsilent(
|
| 285 |
seg, min_silence_len=100, silence_thresh=-40, seek_step=10
|
| 286 |
)
|
|
@@ -290,7 +283,6 @@ def trim_long_audio(
|
|
| 290 |
max_ms = int(max_duration * 1000)
|
| 291 |
min_ms = int(min_duration * 1000)
|
| 292 |
|
| 293 |
-
# Walk through speech regions; at each gap pick the latest split <= max_duration
|
| 294 |
best_split = 0
|
| 295 |
for start, end in nonsilent:
|
| 296 |
if start > best_split and start <= max_ms:
|
|
@@ -302,56 +294,49 @@ def trim_long_audio(
|
|
| 302 |
best_split = min(max_ms, len(seg))
|
| 303 |
|
| 304 |
trimmed = seg[:best_split]
|
| 305 |
-
return
|
| 306 |
|
| 307 |
|
| 308 |
def cross_fade_chunks(
|
| 309 |
-
chunks: list[
|
| 310 |
sample_rate: int,
|
| 311 |
silence_duration: float = 0.3,
|
| 312 |
-
) ->
|
| 313 |
-
"""Concatenate audio chunks with
|
| 314 |
-
|
| 315 |
-
Each boundary is structured as: fade-out tail → silence buffer → fade-in head.
|
| 316 |
-
This avoids click artifacts from direct concatenation or overlapping mismatch.
|
| 317 |
|
| 318 |
Args:
|
| 319 |
-
chunks:
|
| 320 |
-
sample_rate:
|
| 321 |
-
silence_duration:
|
| 322 |
|
| 323 |
Returns:
|
| 324 |
-
Merged
|
| 325 |
"""
|
| 326 |
if len(chunks) == 1:
|
| 327 |
return chunks[0]
|
| 328 |
|
| 329 |
total_n = int(silence_duration * sample_rate)
|
| 330 |
fade_n = total_n // 3
|
| 331 |
-
silence_n = fade_n
|
| 332 |
-
merged = chunks[0].
|
| 333 |
|
| 334 |
for chunk in chunks[1:]:
|
| 335 |
-
dev, dt = merged.device, merged.dtype
|
| 336 |
parts = [merged]
|
| 337 |
|
| 338 |
-
|
| 339 |
-
fout_n = min(fade_n, merged.size(-1))
|
| 340 |
if fout_n > 0:
|
| 341 |
-
w_out =
|
| 342 |
-
parts[-1][..., -fout_n:] =
|
| 343 |
|
| 344 |
-
|
| 345 |
-
parts.append(torch.zeros(chunks[0].shape[0], silence_n, device=dev, dtype=dt))
|
| 346 |
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
fin_n = min(fade_n, fade_in.size(-1))
|
| 350 |
if fin_n > 0:
|
| 351 |
-
w_in =
|
| 352 |
-
fade_in[..., :fin_n] =
|
| 353 |
|
| 354 |
parts.append(fade_in)
|
| 355 |
-
merged =
|
| 356 |
|
| 357 |
return merged
|
|
|
|
| 17 |
|
| 18 |
"""Audio I/O and processing utilities.
|
| 19 |
|
| 20 |
+
Provides functions for loading, resampling, silence removal,
|
| 21 |
+
chunking, cross-fading, and format conversion.
|
| 22 |
+
|
| 23 |
+
All public functions in this module operate on **numpy float32 arrays**
|
| 24 |
+
with shape ``(C, T)`` (channels-first).
|
| 25 |
"""
|
| 26 |
|
| 27 |
+
import io
|
| 28 |
+
import logging
|
| 29 |
+
|
| 30 |
+
import librosa
|
| 31 |
import numpy as np
|
| 32 |
+
import soundfile as sf
|
|
|
|
| 33 |
from pydub import AudioSegment
|
| 34 |
from pydub.silence import detect_leading_silence, detect_nonsilent, split_on_silence
|
| 35 |
|
| 36 |
+
logger = logging.getLogger(__name__)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ---------------------------------------------------------------------------
|
| 40 |
+
# Loading
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
|
| 43 |
+
|
| 44 |
+
def load_waveform(audio_path: str):
    """Read an audio file, returning ``(data, sample_rate)``.

    Backends, tried in order:
      1. soundfile — covers WAV/FLAC/OGG etc., no ffmpeg needed.
      2. librosa — covers MP3/M4A etc. via audioread + ffmpeg.

    Returns:
        (data, sample_rate) where data is a numpy float32 array of
        shape (C, T).
    """
    try:
        samples, rate = sf.read(audio_path, dtype="float32", always_2d=True)
    except Exception:
        # Formats soundfile cannot decode (MP3/M4A, ...) go through librosa.
        samples, rate = librosa.load(audio_path, sr=None, mono=False)
        if samples.ndim == 1:
            samples = samples[np.newaxis, :]
        return samples, rate
    # soundfile yields (T, C); convert to channels-first (C, T).
    return samples.T, rate
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_audio(audio_path: str, sampling_rate: int) -> np.ndarray:
    """Load a waveform from file as mono and resample to the target rate.

    Parameters:
        audio_path: path of the audio.
        sampling_rate: target sampling rate.

    Returns:
        Numpy float32 array of shape (1, T).
    """
    waveform, native_sr = load_waveform(audio_path)

    # Downmix to mono while keeping the channel axis.
    if waveform.shape[0] > 1:
        waveform = np.mean(waveform, axis=0, keepdims=True)

    if native_sr != sampling_rate:
        waveform = librosa.resample(
            waveform, orig_sr=native_sr, target_sr=sampling_rate
        )

    return waveform
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def load_audio_bytes(raw: bytes, sampling_rate: int) -> np.ndarray:
    """Decode in-memory audio bytes to a mono waveform at the target rate.

    Parameters:
        raw: raw audio file bytes (e.g. from WebDataset).
        sampling_rate: target sampling rate.

    Returns:
        Numpy float32 array of shape (1, T).
    """
    stream = io.BytesIO(raw)

    try:
        samples, native_sr = sf.read(stream, dtype="float32", always_2d=True)
        # soundfile yields (T, C); convert to channels-first (C, T).
        samples = samples.T
    except Exception:
        # Fall back to librosa for formats soundfile cannot decode.
        stream.seek(0)
        samples, native_sr = librosa.load(stream, sr=None, mono=False)
        if samples.ndim == 1:
            samples = samples[np.newaxis, :]

    if samples.shape[0] > 1:
        samples = np.mean(samples, axis=0, keepdims=True)
    if native_sr != sampling_rate:
        samples = librosa.resample(
            samples, orig_sr=native_sr, target_sr=sampling_rate
        )

    return samples
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
# ---------------------------------------------------------------------------
|
| 124 |
+
# Audio processing (all numpy in / numpy out)
|
| 125 |
+
# ---------------------------------------------------------------------------
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def numpy_to_audiosegment(audio: np.ndarray, sample_rate: int) -> AudioSegment:
    """Convert a numpy float32 array of shape (C, T) to a pydub AudioSegment."""
    # Scale [-1, 1] floats to int16 PCM, clipping out-of-range samples.
    pcm = (audio * 32768.0).clip(-32768, 32767).astype(np.int16)
    if pcm.shape[0] > 1:
        # pydub expects interleaved frames: L R L R ...
        pcm = pcm.T.flatten()
    return AudioSegment(
        data=pcm.tobytes(),
        sample_width=2,
        frame_rate=sample_rate,
        channels=audio.shape[0],
    )
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def audiosegment_to_numpy(aseg: AudioSegment) -> np.ndarray:
    """Convert a pydub AudioSegment to a numpy float32 array of shape (C, T)."""
    samples = np.array(aseg.get_array_of_samples()).astype(np.float32) / 32768.0
    if aseg.channels > 1:
        # De-interleave: (T*C,) → (T, C) → (C, T)
        return samples.reshape(-1, aseg.channels).T
    return samples[np.newaxis, :]
|
| 147 |
|
| 148 |
|
| 149 |
def remove_silence(
|
| 150 |
+
audio: np.ndarray,
|
| 151 |
sampling_rate: int,
|
| 152 |
mid_sil: int = 300,
|
| 153 |
lead_sil: int = 100,
|
| 154 |
trail_sil: int = 300,
|
| 155 |
+
) -> np.ndarray:
|
| 156 |
+
"""Remove middle silences longer than *mid_sil* ms and trim edge silences.
|
|
|
|
| 157 |
|
| 158 |
Parameters:
|
| 159 |
+
audio: numpy array with shape (C, T).
|
| 160 |
sampling_rate: sampling rate of the audio.
|
| 161 |
+
mid_sil: middle-silence threshold in ms (0 to skip).
|
| 162 |
+
lead_sil: kept leading silence in ms.
|
| 163 |
+
trail_sil: kept trailing silence in ms.
|
|
|
|
| 164 |
|
| 165 |
Returns:
|
| 166 |
+
Numpy array with shape (C, T').
|
|
|
|
| 167 |
"""
|
| 168 |
+
wave = numpy_to_audiosegment(audio, sampling_rate)
|
|
|
|
| 169 |
|
| 170 |
if mid_sil > 0:
|
|
|
|
| 171 |
non_silent_segs = split_on_silence(
|
| 172 |
wave,
|
| 173 |
min_silence_len=mid_sil,
|
|
|
|
| 175 |
keep_silence=mid_sil,
|
| 176 |
seek_step=10,
|
| 177 |
)
|
|
|
|
|
|
|
| 178 |
wave = AudioSegment.silent(duration=0)
|
| 179 |
for seg in non_silent_segs:
|
| 180 |
wave += seg
|
| 181 |
|
|
|
|
| 182 |
wave = remove_silence_edges(wave, lead_sil, trail_sil, -50)
|
| 183 |
|
| 184 |
+
return audiosegment_to_numpy(wave)
|
|
|
|
| 185 |
|
| 186 |
|
| 187 |
def remove_silence_edges(
|
|
|
|
| 189 |
lead_sil: int = 100,
|
| 190 |
trail_sil: int = 300,
|
| 191 |
silence_threshold: float = -50,
|
| 192 |
+
) -> AudioSegment:
|
| 193 |
+
"""Remove edge silences, keeping *lead_sil* / *trail_sil* ms."""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
|
| 195 |
start_idx = max(0, start_idx - lead_sil)
|
| 196 |
audio = audio[start_idx:]
|
| 197 |
|
|
|
|
| 198 |
audio = audio.reverse()
|
| 199 |
start_idx = detect_leading_silence(audio, silence_threshold=silence_threshold)
|
| 200 |
start_idx = max(0, start_idx - trail_sil)
|
|
|
|
| 204 |
return audio
|
| 205 |
|
| 206 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 207 |
def fade_and_pad_audio(
    audio: np.ndarray,
    pad_duration: float = 0.1,
    fade_duration: float = 0.1,
    sample_rate: int = 24000,
) -> np.ndarray:
    """Apply fade-in/out ramps and pad both ends with silence.

    Prevents click/pop artifacts from abrupt starts and ends.

    Args:
        audio: numpy array of shape (C, T).
        pad_duration: silence padding duration per side (seconds).
        fade_duration: fade curve duration (seconds).
        sample_rate: audio sampling rate.

    Returns:
        Processed numpy array of shape (C, T_new).
    """
    if audio.shape[-1] == 0:
        return audio

    out = audio.copy()
    n_fade = int(fade_duration * sample_rate)
    n_pad = int(pad_duration * sample_rate)

    if n_fade > 0:
        # Cap the ramp at half the signal so fade-in and fade-out never overlap.
        ramp_len = min(n_fade, out.shape[-1] // 2)
        if ramp_len > 0:
            ramp = np.linspace(0, 1, ramp_len, dtype=np.float32)[np.newaxis, :]
            out[..., :ramp_len] *= ramp
            # Reversed ramp == linspace(1, 0, ramp_len): the fade-out curve.
            out[..., -ramp_len:] *= ramp[:, ::-1]

    if n_pad > 0:
        pad = np.zeros((out.shape[0], n_pad), dtype=out.dtype)
        out = np.concatenate([pad, out, pad], axis=-1)

    return out
|
| 249 |
|
| 250 |
|
| 251 |
def trim_long_audio(
|
| 252 |
+
audio: np.ndarray,
|
| 253 |
sampling_rate: int,
|
| 254 |
max_duration: float = 15.0,
|
| 255 |
min_duration: float = 3.0,
|
| 256 |
trim_threshold: float = 20.0,
|
| 257 |
+
) -> np.ndarray:
|
| 258 |
+
"""Trim audio to <= *max_duration* by splitting at the largest silence gap.
|
| 259 |
|
| 260 |
Only trims when the audio exceeds *trim_threshold* seconds.
|
| 261 |
|
| 262 |
Args:
|
| 263 |
+
audio: numpy array of shape (C, T).
|
| 264 |
+
sampling_rate: audio sampling rate.
|
| 265 |
+
max_duration: maximum duration in seconds.
|
| 266 |
+
min_duration: minimum duration in seconds.
|
| 267 |
+
trim_threshold: only trim if audio is longer than this (seconds).
|
| 268 |
|
| 269 |
Returns:
|
| 270 |
+
Trimmed numpy array.
|
| 271 |
"""
|
| 272 |
+
duration = audio.shape[-1] / sampling_rate
|
| 273 |
if duration <= trim_threshold:
|
| 274 |
return audio
|
| 275 |
|
| 276 |
+
seg = numpy_to_audiosegment(audio, sampling_rate)
|
| 277 |
nonsilent = detect_nonsilent(
|
| 278 |
seg, min_silence_len=100, silence_thresh=-40, seek_step=10
|
| 279 |
)
|
|
|
|
| 283 |
max_ms = int(max_duration * 1000)
|
| 284 |
min_ms = int(min_duration * 1000)
|
| 285 |
|
|
|
|
| 286 |
best_split = 0
|
| 287 |
for start, end in nonsilent:
|
| 288 |
if start > best_split and start <= max_ms:
|
|
|
|
| 294 |
best_split = min(max_ms, len(seg))
|
| 295 |
|
| 296 |
trimmed = seg[:best_split]
|
| 297 |
+
return audiosegment_to_numpy(trimmed)
|
| 298 |
|
| 299 |
|
| 300 |
def cross_fade_chunks(
    chunks: list[np.ndarray],
    sample_rate: int,
    silence_duration: float = 0.3,
) -> np.ndarray:
    """Concatenate audio chunks with silence gaps and cross-fade at boundaries.

    Each boundary is rendered as: a linear fade-out on the tail of the audio
    merged so far, an explicit silence gap, then a linear fade-in on the head
    of the next chunk.  Both the fade length and the silence gap are one third
    of ``silence_duration * sample_rate`` samples.

    Args:
        chunks: non-empty list of numpy arrays, each shaped (C, T).
        sample_rate: audio sample rate in Hz.
        silence_duration: total silence gap duration in seconds.

    Returns:
        Merged numpy array of shape (C, T_total).

    Raises:
        ValueError: if *chunks* is empty.
    """
    if not chunks:
        raise ValueError("chunks must be a non-empty list")
    if len(chunks) == 1:
        # Nothing to merge; return the single chunk unchanged (no copy).
        return chunks[0]

    total_n = int(silence_duration * sample_rate)
    fade_n = total_n // 3   # samples spent on each fade ramp
    silence_n = fade_n      # explicit silence inserted between fades
    # Work on a private copy so the caller's first chunk is never mutated.
    merged = chunks[0].copy()

    for chunk in chunks[1:]:
        parts = [merged]

        # Fade out the tail of what has been merged so far (in place on the
        # private copy, so caller data is untouched).
        fout_n = min(fade_n, merged.shape[-1])
        if fout_n > 0:
            w_out = np.linspace(1, 0, fout_n, dtype=np.float32)[np.newaxis, :]
            parts[-1][..., -fout_n:] *= w_out

        # Silence gap; match the merged dtype so concatenation never upcasts.
        parts.append(
            np.zeros((chunks[0].shape[0], silence_n), dtype=merged.dtype)
        )

        # Fade in the head of the next chunk (on a copy of it).
        fade_in = chunk.copy()
        fin_n = min(fade_n, fade_in.shape[-1])
        if fin_n > 0:
            w_in = np.linspace(0, 1, fin_n, dtype=np.float32)[np.newaxis, :]
            fade_in[..., :fin_n] *= w_in

        parts.append(fade_in)
        merged = np.concatenate(parts, axis=-1)

    return merged
|
omnivoice/utils/data_utils.py
CHANGED
|
@@ -29,10 +29,10 @@ from pathlib import Path
|
|
| 29 |
def read_test_list(path):
|
| 30 |
"""Read a JSONL test list file.
|
| 31 |
|
| 32 |
-
Each line should be a JSON object
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
|
| 37 |
Returns a list of dicts.
|
| 38 |
"""
|
|
@@ -58,6 +58,7 @@ def read_test_list(path):
|
|
| 58 |
"language_name": obj.get("language_name"),
|
| 59 |
"duration": obj.get("duration"),
|
| 60 |
"speed": obj.get("speed"),
|
|
|
|
| 61 |
}
|
| 62 |
samples.append(sample)
|
| 63 |
return samples
|
|
|
|
| 29 |
def read_test_list(path):
|
| 30 |
"""Read a JSONL test list file.
|
| 31 |
|
| 32 |
+
Each line should be a JSON object. Only ``id`` and ``text`` are required;
|
| 33 |
+
all other fields are optional (default to ``None``):
|
| 34 |
+
id, text, ref_audio, ref_text, instruct,
|
| 35 |
+
language_id, language_name, duration, speed
|
| 36 |
|
| 37 |
Returns a list of dicts.
|
| 38 |
"""
|
|
|
|
| 58 |
"language_name": obj.get("language_name"),
|
| 59 |
"duration": obj.get("duration"),
|
| 60 |
"speed": obj.get("speed"),
|
| 61 |
+
"instruct": obj.get("instruct"),
|
| 62 |
}
|
| 63 |
samples.append(sample)
|
| 64 |
return samples
|
requirements.txt
CHANGED
|
@@ -5,6 +5,7 @@ transformers==5.3
|
|
| 5 |
accelerate
|
| 6 |
pydub
|
| 7 |
soundfile
|
|
|
|
| 8 |
numpy
|
| 9 |
gradio
|
| 10 |
hf_transfer
|
|
|
|
| 5 |
accelerate
|
| 6 |
pydub
|
| 7 |
soundfile
|
| 8 |
+
librosa
|
| 9 |
numpy
|
| 10 |
gradio
|
| 11 |
hf_transfer
|