"""Thin FastAPI server inside the F5-TTS container. Loads model + vocoder ONCE at startup (heavy: ~5s, ~5GB VRAM). POST /synthesize runs inference and writes the WAV to a shared volume; the response is JSON with the output path and metadata — not the WAV bytes, since chapter-length renders are 20-30MB and both skald and the f5 container share /audio anyway. Why not Gradio's API: Gradio's /gradio_api/call/* shape is event- stream + polling; this is a single POST + immediate response. Right for skald's "render one chapter, then move on" loop. """ import logging import time import uuid from pathlib import Path import soundfile as sf import torch from fastapi import FastAPI, HTTPException from omegaconf import OmegaConf from pydantic import BaseModel, Field from cached_path import cached_path from importlib.resources import files from hydra.utils import get_class from f5_tts.infer.utils_infer import ( infer_process, load_model, load_vocoder, preprocess_ref_audio_text, ) log = logging.getLogger("f5-server") logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s %(message)s") # ─── model state ───────────────────────────────────────────────── DEVICE = "cuda" if torch.cuda.is_available() else "cpu" MODEL_NAME = "F5TTS_v1_Base" VOCODER_NAME = "vocos" AUDIO_ROOT = Path("/audio") VOICES_ROOT = Path("/voices") _model = None _vocoder = None def _load_models() -> None: """One-time model + vocoder load. ~5-8s wall-clock on first call.""" global _model, _vocoder if _model is not None: return log.info("loading vocoder=%s device=%s", VOCODER_NAME, DEVICE) _vocoder = load_vocoder(vocoder_name=VOCODER_NAME, device=DEVICE) cfg_path = str(files("f5_tts").joinpath(f"configs/{MODEL_NAME}.yaml")) model_cfg = OmegaConf.load(cfg_path) model_cls = get_class(f"f5_tts.model.{model_cfg.model.backbone}") model_arc = model_cfg.model.arch # F5TTS_v1_Base ships as a HuggingFace artifact; cached_path # handles the resolution + downloads to HF_HOME. ckpt_file = str( cached_path("hf://SWivid/F5-TTS/F5TTS_v1_Base/model_1250000.safetensors") ) vocab_file = str(files("f5_tts").joinpath("infer/examples/vocab.txt")) log.info("loading model=%s ckpt=%s", MODEL_NAME, ckpt_file) _model = load_model( model_cls, model_arc, ckpt_file, mel_spec_type=VOCODER_NAME, vocab_file=vocab_file, device=DEVICE, ) log.info("model + vocoder loaded; ready") # ─── FastAPI app ───────────────────────────────────────────────── class SynthesizeRequest(BaseModel): # The text we want to synthesize. Long-form OK — F5-TTS chunks # internally via infer_process. gen_text: str = Field(min_length=1) # Reference audio path (inside the f5-tts container). Defaults # to the staged lj_speech clip. ref_audio_path: str = "/voices/lj_speech.wav" # Reference transcript. Defaults to the bundled lj_speech.txt. ref_text: str | None = None # Output filename, relative to /audio (the shared output dir). # If omitted, a UUID-based name is assigned. output_filename: str | None = None # Speech speed (0.5-2.0). Default 1.0 = natural pace. speed: float = Field(default=1.0, ge=0.3, le=2.0) # Cross-fade between chunks; F5 default is 0.15s. Bigger smooths # chunk boundaries on long-form prose at the cost of pacing. cross_fade_duration: float = Field(default=0.15, ge=0.0, le=1.0) class SynthesizeResponse(BaseModel): ok: bool output_path: str sample_rate_hz: int duration_seconds: float elapsed_ms: int chars_in: int app = FastAPI(title="f5-tts-server", version="0.1.0") @app.on_event("startup") def _startup() -> None: _load_models() @app.get("/healthz") def healthz() -> dict: return { "ok": True, "device": DEVICE, "model": MODEL_NAME, "vocoder": VOCODER_NAME, "loaded": _model is not None, } @app.post("/synthesize", response_model=SynthesizeResponse) def synthesize(req: SynthesizeRequest) -> SynthesizeResponse: if _model is None: raise HTTPException(503, "model not loaded yet — retry shortly") ref_audio_path = Path(req.ref_audio_path) if not ref_audio_path.is_file(): raise HTTPException(400, f"ref_audio_path not found: {ref_audio_path}") # If no explicit ref_text, try sidecar .txt then fall back to "" # (which triggers F5's auto-ASR). ref_text = req.ref_text if ref_text is None: sidecar = ref_audio_path.with_suffix(".txt") if sidecar.is_file(): ref_text = sidecar.read_text().strip() else: ref_text = "" output_filename = req.output_filename or f"{uuid.uuid4().hex}.wav" if "/" in output_filename or ".." in output_filename: raise HTTPException(400, "output_filename must be a bare name, no path parts") output_path = AUDIO_ROOT / output_filename output_path.parent.mkdir(parents=True, exist_ok=True) started = time.monotonic() ref_audio_processed, ref_text_processed = preprocess_ref_audio_text( str(ref_audio_path), ref_text ) audio_segment, final_sample_rate, _ = infer_process( ref_audio_processed, ref_text_processed, req.gen_text, _model, _vocoder, mel_spec_type=VOCODER_NAME, speed=req.speed, cross_fade_duration=req.cross_fade_duration, device=DEVICE, ) elapsed_ms = int((time.monotonic() - started) * 1000) sf.write(str(output_path), audio_segment, final_sample_rate, subtype="PCM_16") duration_s = float(len(audio_segment)) / float(final_sample_rate) log.info( "synthesized chars=%d -> %s (sr=%d, dur=%.2fs, elapsed=%dms)", len(req.gen_text), output_path, final_sample_rate, duration_s, elapsed_ms, ) return SynthesizeResponse( ok=True, output_path=str(output_path), sample_rate_hz=final_sample_rate, duration_seconds=duration_s, elapsed_ms=elapsed_ms, chars_in=len(req.gen_text), )