skald/engines/f5-tts/server.py
Kayos d1631ddffe engines: import f5-tts + kokoro + tortoise sidecars into the tree
The python FastAPI sidecars have lived ad-hoc at /mnt/cache/appdata/
<engine>/build/ on Lucy without version control. Bringing them into
the skald repo so the engine code travels with the cross-engine
routing it depends on.

This commit lands the VANILLA version of each engine on main:

  engines/f5-tts/    SWivid F5-TTS (CC-BY-NC weights flagged)
  engines/kokoro/    hexgrad Kokoro-82M (Apache 2.0 top to bottom)
  engines/tortoise/  neonbjb Tortoise-TTS (Apache 2.0 top to bottom)

Engine-specific kludges (question doubling, GPU coordination,
pause-duration tuning) get layered on engine/* branches per the
README. Main stays the safe-to-read baseline.
2026-05-14 09:40:01 -07:00

184 lines
6.2 KiB
Python

"""Thin FastAPI server inside the F5-TTS container.
Loads model + vocoder ONCE at startup (heavy: ~5s, ~5GB VRAM).
POST /synthesize runs inference and writes the WAV to a shared
volume; the response is JSON with the output path and metadata —
not the WAV bytes, since chapter-length renders are 20-30MB and
both skald and the f5 container share /audio anyway.
Why not Gradio's API: Gradio's /gradio_api/call/* shape is event-
stream + polling; this is a single POST + immediate response.
Right for skald's "render one chapter, then move on" loop.
"""
import logging
import time
import uuid
from pathlib import Path
import soundfile as sf
import torch
from fastapi import FastAPI, HTTPException
from omegaconf import OmegaConf
from pydantic import BaseModel, Field
from cached_path import cached_path
from importlib.resources import files
from hydra.utils import get_class
from f5_tts.infer.utils_infer import (
infer_process,
load_model,
load_vocoder,
preprocess_ref_audio_text,
)
# Configure root logging at import time so model-load progress and the
# per-request summaries share one timestamped format.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
log = logging.getLogger("f5-server")

# ─── model state ─────────────────────────────────────────────────
# Use the GPU when torch reports one; otherwise fall back to CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Config stem: resolved to f5_tts/configs/<MODEL_NAME>.yaml at load time.
MODEL_NAME = "F5TTS_v1_Base"
# Vocoder identifier; also passed as mel_spec_type to load/inference calls.
VOCODER_NAME = "vocos"
# Directory the rendered WAV files are written into.
AUDIO_ROOT = Path("/audio")
# Directory holding reference voices (the default reference clip lives here).
VOICES_ROOT = Path("/voices")

# Module-level singletons, populated by _load_models().
_model = None
_vocoder = None
def _load_models() -> None:
    """Populate the module-level model and vocoder.

    Returns immediately when the model is already loaded. Otherwise the
    vocoder is loaded first, the model class and architecture are read
    from the packaged YAML config for MODEL_NAME, the checkpoint is
    resolved through cached_path, and the model is built on DEVICE.
    """
    global _model, _vocoder

    already_loaded = _model is not None
    if already_loaded:
        return

    log.info("loading vocoder=%s device=%s", VOCODER_NAME, DEVICE)
    _vocoder = load_vocoder(vocoder_name=VOCODER_NAME, device=DEVICE)

    # Both the model config and the vocab file are resources packaged
    # inside the f5_tts distribution.
    package_root = files("f5_tts")
    config_path = str(package_root.joinpath(f"configs/{MODEL_NAME}.yaml"))
    config = OmegaConf.load(config_path)
    backbone_cls = get_class(f"f5_tts.model.{config.model.backbone}")
    arch = config.model.arch
    vocab_path = str(package_root.joinpath("infer/examples/vocab.txt"))

    # F5TTS_v1_Base ships as a HuggingFace artifact; cached_path
    # handles the resolution + downloads to HF_HOME.
    checkpoint_path = str(
        cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors")
    )

    log.info("loading model=%s ckpt=%s", MODEL_NAME, checkpoint_path)
    _model = load_model(
        backbone_cls,
        arch,
        checkpoint_path,
        mel_spec_type=VOCODER_NAME,
        vocab_file=vocab_path,
        device=DEVICE,
    )
    log.info("model + vocoder loaded; ready")
# ─── FastAPI app ─────────────────────────────────────────────────
class SynthesizeRequest(BaseModel):
    """JSON body accepted by POST /synthesize."""

    # The text we want to synthesize. Long-form OK — F5-TTS chunks
    # internally via infer_process.
    gen_text: str = Field(min_length=1)
    # Reference audio path (inside the f5-tts container). Defaults
    # to the staged lj_speech clip.
    ref_audio_path: str = "/voices/lj_speech.wav"
    # Reference transcript. When omitted (None), the handler reads the
    # .txt file next to ref_audio_path if one exists (lj_speech.txt for
    # the default clip); otherwise it passes "" through to F5.
    ref_text: str | None = None
    # Output filename, relative to /audio (the shared output dir).
    # If omitted, a UUID-based name is assigned.
    output_filename: str | None = None
    # Speech speed multiplier; validated to the range 0.3-2.0.
    # Default 1.0 = natural pace.
    speed: float = Field(default=1.0, ge=0.3, le=2.0)
    # Cross-fade between chunks, in seconds (validated to 0.0-1.0);
    # F5 default is 0.15s. Bigger smooths chunk boundaries on long-form
    # prose at the cost of pacing.
    cross_fade_duration: float = Field(default=0.15, ge=0.0, le=1.0)
class SynthesizeResponse(BaseModel):
    """JSON body returned by POST /synthesize after a successful render."""

    # Success flag; the handler sets it to True on the success path.
    ok: bool
    # Path of the written WAV file (AUDIO_ROOT joined with the filename).
    output_path: str
    # Sample rate reported by inference for the rendered audio.
    sample_rate_hz: int
    # Rendered length: sample count divided by the sample rate.
    duration_seconds: float
    # Wall-clock time for reference preprocessing plus inference; the
    # WAV write happens after this measurement is taken.
    elapsed_ms: int
    # Length of the submitted gen_text, in characters.
    chars_in: int
app = FastAPI(
    title="f5-tts-server",
    version="0.1.0",
)


@app.on_event("startup")
def _startup() -> None:
    """Startup hook: load the model and vocoder before serving requests."""
    _load_models()
@app.get("/healthz")
def healthz() -> dict:
    """Report the configured device, model, and vocoder, plus load state.

    ``loaded`` is True once the module-level model has been populated.
    """
    is_loaded = _model is not None
    status = dict(
        ok=True,
        device=DEVICE,
        model=MODEL_NAME,
        vocoder=VOCODER_NAME,
        loaded=is_loaded,
    )
    return status
def _resolve_ref_text(ref_audio_path: Path, ref_text: str | None) -> str:
    """Pick the reference transcript to pair with *ref_audio_path*.

    An explicit transcript wins. Otherwise the ``.txt`` file next to the
    audio is used when present; as a last resort ``""`` is returned
    (which triggers F5's auto-ASR).
    """
    if ref_text is not None:
        return ref_text
    sidecar = ref_audio_path.with_suffix(".txt")
    if sidecar.is_file():
        # Decode explicitly: transcripts may contain non-ASCII text, and
        # the default encoding depends on the container's locale.
        return sidecar.read_text(encoding="utf-8").strip()
    return ""


@app.post("/synthesize", response_model=SynthesizeResponse)
def synthesize(req: SynthesizeRequest) -> SynthesizeResponse:
    """Render ``req.gen_text`` to a WAV file under AUDIO_ROOT.

    The reference clip and transcript are preprocessed, inference is run
    with the loaded model and vocoder, and the result is written as
    16-bit PCM. The response carries the output path and metadata, not
    the audio bytes.

    Raises:
        HTTPException: 503 when the model has not been loaded yet;
            400 when ``ref_audio_path`` is not an existing file or
            ``output_filename`` contains ``/`` or ``..``;
            500 when inference returns no audio.
    """
    if _model is None:
        raise HTTPException(503, "model not loaded yet — retry shortly")

    ref_audio_path = Path(req.ref_audio_path)
    if not ref_audio_path.is_file():
        raise HTTPException(400, f"ref_audio_path not found: {ref_audio_path}")

    ref_text = _resolve_ref_text(ref_audio_path, req.ref_text)

    output_filename = req.output_filename or f"{uuid.uuid4().hex}.wav"
    # Keep writes inside AUDIO_ROOT: reject anything that looks like a path.
    if "/" in output_filename or ".." in output_filename:
        raise HTTPException(400, "output_filename must be a bare name, no path parts")
    output_path = AUDIO_ROOT / output_filename
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # The timer covers reference preprocessing + inference only; the WAV
    # write below is deliberately outside the measured window.
    started = time.monotonic()
    ref_audio_processed, ref_text_processed = preprocess_ref_audio_text(
        str(ref_audio_path), ref_text
    )
    audio_segment, final_sample_rate, _ = infer_process(
        ref_audio_processed,
        ref_text_processed,
        req.gen_text,
        _model,
        _vocoder,
        mel_spec_type=VOCODER_NAME,
        speed=req.speed,
        cross_fade_duration=req.cross_fade_duration,
        device=DEVICE,
    )
    elapsed_ms = int((time.monotonic() - started) * 1000)

    if audio_segment is None:
        # NOTE(review): F5's inference appears to return None when no
        # chunk produced audio — confirm against f5_tts.infer.utils_infer.
        # Fail with a clear message instead of crashing inside sf.write.
        raise HTTPException(500, "inference produced no audio")

    sf.write(str(output_path), audio_segment, final_sample_rate, subtype="PCM_16")
    duration_s = float(len(audio_segment)) / float(final_sample_rate)
    log.info(
        "synthesized chars=%d -> %s (sr=%d, dur=%.2fs, elapsed=%dms)",
        len(req.gen_text), output_path, final_sample_rate, duration_s, elapsed_ms,
    )
    return SynthesizeResponse(
        ok=True,
        output_path=str(output_path),
        sample_rate_hz=final_sample_rate,
        duration_seconds=duration_s,
        elapsed_ms=elapsed_ms,
        chars_in=len(req.gen_text),
    )