Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for Piper TTS #1130

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions src/pipecat/services/piper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#
# Copyright (c) 2024–2025, Daily
#
# SPDX-License-Identifier: BSD 2-Clause License
#

from typing import AsyncGenerator

import aiohttp
from loguru import logger

from pipecat.frames.frames import (
ErrorFrame,
Frame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.ai_services import TTSService

# This assumes a running TTS service running: https://github.com/rhasspy/piper/blob/master/src/python_run/README_http.md


class PiperTTSService(TTSService):
"""Piper TTS service implementation.

Provides integration with Piper's TTS server.
"""

def __init__(
self,
*,
base_url: str,
aiohttp_session: aiohttp.ClientSession | None = None,
sample_rate: int = 24000,
**kwargs,
):
"""Initialize the PiperTTSService class instance.

Args:
base_url (str): Base URL of the Piper TTS server (should not end with a slash).
aiohttp_session (aiohttp.ClientSession, optional): Optional aiohttp session to use for requests. Defaults to None.
sample_rate (int, optional): Sample rate in Hz. Defaults to 24000.
**kwargs (dict): Additional keyword arguments.
"""
super().__init__(sample_rate=sample_rate, **kwargs)
if not aiohttp_session:
aiohttp_session = aiohttp.ClientSession()

if base_url.endswith("/"):
logger.warning("Base URL ends with a slash, this is not allowed.")
base_url = base_url[:-1]

self._settings = {"base_url": base_url}
self.set_voice("voice_id")
self._aiohttp_session = aiohttp_session

def can_generate_metrics(self) -> bool:
return True

async def run_tts(self, text: str) -> AsyncGenerator[Frame, None]:
logger.debug(f"Generating TTS: [{text}]")

url = self._settings["base_url"] + "/?text=" + text.replace(".", "").replace("*", "")

await self.start_ttfb_metrics()

async with self._aiohttp_session.get(url) as r:
if r.status != 200:
text = await r.text()
logger.error(f"{self} error getting audio (status: {r.status}, error: {text})")
yield ErrorFrame(f"Error getting audio (status: {r.status}, error: {text})")
return

await self.start_tts_usage_metrics(text)

yield TTSStartedFrame()

buffer = bytearray()
async for chunk in r.content.iter_chunked(1024):
if len(chunk) > 0:
await self.stop_ttfb_metrics()
# Append new chunk to the buffer.
buffer.extend(chunk)

# Check if buffer has enough data for processing.
while (
len(buffer) >= 48000
): # Assuming at least 0.5 seconds of audio data at 24000 Hz
# Process the buffer up to a safe size for resampling.
process_data = buffer[:48000]
# Remove processed data from buffer.
buffer = buffer[48000:]

frame = TTSAudioRawFrame(process_data, self._sample_rate, 1)
yield frame

# Process any remaining data in the buffer.
if len(buffer) > 0:
frame = TTSAudioRawFrame(buffer, self._sample_rate, 1)
yield frame

yield TTSStoppedFrame()
1 change: 1 addition & 0 deletions test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ pyaudio~=0.2.14
pydantic~=2.8.2
pyloudnorm~=0.1.1
pyht~=0.1.4
pytest-aiohttp==1.1.0
python-dotenv~=1.0.1
silero-vad~=5.1
soxr~=0.5.0
Expand Down
101 changes: 101 additions & 0 deletions tests/test_piper_tts.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
"""Tests for PiperTTSService."""

import asyncio

import pytest
from aiohttp import web

from pipecat.frames.frames import (
ErrorFrame,
TTSAudioRawFrame,
TTSStartedFrame,
TTSStoppedFrame,
)
from pipecat.services.piper import PiperTTSService


@pytest.mark.asyncio
async def test_run_piper_tts_success(aiohttp_client):
"""Test successful TTS generation with chunked audio data.

Checks frames for TTSStartedFrame -> TTSAudioRawFrame -> TTSStoppedFrame.
"""

async def handler(request):
# The service expects a /?text= param
# Here we're just returning dummy chunked bytes to simulate an audio response
text_query = request.rel_url.query.get("text", "")
print(f"Mock server received text param: {text_query}")

# Prepare a StreamResponse with chunked data
resp = web.StreamResponse(
status=200,
reason="OK",
headers={"Content-Type": "audio/raw"},
)
await resp.prepare(request)

# Write out some chunked byte data
# In reality, you’d return WAV data or similar
data_chunk_1 = b"\x00\x01\x02\x03" * 12000 # 48000 bytes
data_chunk_2 = b"\x04\x05\x06\x07" * 6000 # another chunk
await resp.write(data_chunk_1)
await asyncio.sleep(0.01) # simulate async chunk delay
await resp.write(data_chunk_2)
await resp.write_eof()

return resp

# Create an aiohttp test server
app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)

# Remove trailing slash if present in the test URL
base_url = str(client.make_url("")).rstrip("/")

# Instantiate PiperTTSService with our mock server
tts_service = PiperTTSService(base_url=base_url)

# Collect frames from the generator
frames = []
async for frame in tts_service.run_tts("Hello world."):
frames.append(frame)

# Ensure we received frames in the expected order/types
assert len(frames) >= 3, "Expecting at least TTSStartedFrame, TTSAudioRawFrame, TTSStoppedFrame"
assert isinstance(frames[0], TTSStartedFrame), "First frame must be TTSStartedFrame"
assert isinstance(frames[-1], TTSStoppedFrame), "Last frame must be TTSStoppedFrame"

# Check we have at least one TTSAudioRawFrame
audio_frames = [f for f in frames if isinstance(f, TTSAudioRawFrame)]
assert len(audio_frames) > 0, "Should have received at least one TTSAudioRawFrame"
for a_frame in audio_frames:
assert a_frame.sample_rate == 24000, "Sample rate should match the default (24000)"


@pytest.mark.asyncio
async def test_run_piper_tts_error(aiohttp_client):
"""Test how the service handles a non-200 response from the server.

Expects an ErrorFrame to be returned.
"""

async def handler(_request):
# Return an error status for any request
return web.Response(status=404, text="Not found")

app = web.Application()
app.router.add_get("/", handler)
client = await aiohttp_client(app)
base_url = str(client.make_url("")).rstrip("/")

tts_service = PiperTTSService(base_url=base_url)

frames = []
async for frame in tts_service.run_tts("Error case."):
frames.append(frame)

assert len(frames) == 1, "Should only receive a single ErrorFrame"
assert isinstance(frames[0], ErrorFrame), "Must receive an ErrorFrame for 404"
assert "status: 404" in frames[0].error, "ErrorFrame should contain details about the 404"
Loading