Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions examples/getting-started/01_simple_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
- Audio file contains the spoken text
"""

import os
from pathlib import Path

from fishaudio import FishAudio
from fishaudio.utils import save

Expand All @@ -43,7 +44,7 @@ def main():
save(audio, output_file)

print(f"✓ Audio saved to {output_file}")
print(f" File size: {os.path.getsize(output_file) / 1024:.2f} KB")
print(f" File size: {Path(output_file).stat().st_size / 1024:.2f} KB")


if __name__ == "__main__":
Expand Down
13 changes: 12 additions & 1 deletion examples/getting_started.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@
}
},
"outputs": [],
"source": "from dotenv import load_dotenv\nfrom fishaudio import FishAudio\nfrom fishaudio.utils import play\n# from fishaudio.utils import save # Uncomment if saving audio to file\n\nload_dotenv()\n\nclient = FishAudio()"
"source": [
"from dotenv import load_dotenv\n",
"\n",
"from fishaudio import FishAudio\n",
"from fishaudio.utils import play\n",
"\n",
"# from fishaudio.utils import save # Uncomment if saving audio to file\n",
"\n",
"load_dotenv()\n",
"\n",
"client = FishAudio()"
]
},
{
"cell_type": "markdown",
Expand Down
29 changes: 29 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,5 +83,34 @@ pages = [
{title = "Exceptions", name="fishaudio/exceptions", contents = ["fishaudio.exceptions.*"] },
]

[tool.ruff.lint]
extend-select = [
"F", # Pyflakes rules
"W", # PyCodeStyle warnings
"E", # PyCodeStyle errors
"I", # Sort imports properly
"UP", # Warn if certain things can changed due to newer Python versions
"C4", # Catch incorrect use of comprehensions, dict, list, etc
"FA", # Enforce from __future__ import annotations
"ISC", # Good use of string concatenation
"ICN", # Use common import conventions
"RET", # Good return practices
"SIM", # Common simplification rules
"TID", # Some good import practices
"TC", # Enforce importing certain types in a TYPE_CHECKING block
"PTH", # Use pathlib instead of os.path
"TD", # Be diligent with TODO comments
"NPY", # Some numpy-specific things
]
ignore = [
"E501", # Line too long (handled by ruff format)
]

[tool.ruff.lint.flake8-type-checking]
runtime-evaluated-base-classes = ["pydantic.BaseModel"]

[tool.ruff.lint.pyupgrade]
keep-runtime-typing = true

[tool.uv.sources]
fish-audio-sdk = { workspace = true }
2 changes: 2 additions & 0 deletions scripts/copy_docs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
python scripts/copy_docs.py sdk docs # In CI context
"""

from __future__ import annotations

import argparse
import shutil
from pathlib import Path
Expand Down
14 changes: 7 additions & 7 deletions src/fish_audio_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from .apis import Session
from .exceptions import HttpCodeErr, WebSocketErr
from .schemas import (
APICreditEntity,
ASRRequest,
TTSRequest,
ReferenceAudio,
Prosody,
PaginatedResponse,
CloseEvent,
ModelEntity,
APICreditEntity,
PaginatedResponse,
Prosody,
ReferenceAudio,
StartEvent,
TextEvent,
CloseEvent,
TTSRequest,
)
from .websocket import WebSocketSession, AsyncWebSocketSession
from .websocket import AsyncWebSocketSession, WebSocketSession

__all__ = [
"Session",
Expand Down
9 changes: 7 additions & 2 deletions src/fish_audio_sdk/apis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from typing import Generator, Literal
from __future__ import annotations

from typing import TYPE_CHECKING, Literal

import ormsgpack

Expand All @@ -7,13 +9,16 @@
APICreditEntity,
ASRRequest,
ASRResponse,
ModelEntity,
Backends,
ModelEntity,
PackageEntity,
PaginatedResponse,
TTSRequest,
)

if TYPE_CHECKING:
from collections.abc import Generator


class Session(RemoteCall):
@convert_stream
Expand Down
15 changes: 6 additions & 9 deletions src/fish_audio_sdk/io.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,20 @@
from __future__ import annotations

import dataclasses
import typing
from collections.abc import AsyncGenerator, Awaitable, Generator
from http.client import responses as http_responses
from typing import (
Any,
AsyncGenerator,
Awaitable,
Callable,
Generator,
Generic,
TypeVar,
)

from typing_extensions import Concatenate, ParamSpec

import httpx
import httpx._client
import httpx._types
from typing_extensions import Concatenate, ParamSpec

from .exceptions import HttpCodeErr

Expand Down Expand Up @@ -194,8 +193,7 @@ def sync_wrapper(self: RemoteCall, *args: P.args, **kwargs: P.kwargs) -> R:
return exc.value
raise RuntimeError("Generator did not stop")

call = IOCallDescriptor(async_wrapper, sync_wrapper)
return call
return IOCallDescriptor(async_wrapper, sync_wrapper)


GStream = G[Generator[bytes, bytes, None]]
Expand Down Expand Up @@ -257,5 +255,4 @@ def sync_wrapper(

raise RuntimeError("Generator did not stop")

call = StreamIOCallDescriptor(async_wrapper, sync_wrapper)
return call
return StreamIOCallDescriptor(async_wrapper, sync_wrapper)
3 changes: 2 additions & 1 deletion src/fish_audio_sdk/schemas.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from __future__ import annotations

import datetime
import decimal
from typing import Annotated, Generic, Literal, TypeVar

from pydantic import BaseModel, Field


Backends = Literal["speech-1.5", "speech-1.6", "agent-x0", "s1", "s1-mini"]

Item = TypeVar("Item")
Expand Down
7 changes: 3 additions & 4 deletions src/fish_audio_sdk/websocket.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
from collections.abc import AsyncGenerator, AsyncIterable, Generator, Iterable
from concurrent.futures import ThreadPoolExecutor
from typing import AsyncGenerator, AsyncIterable, Generator, Iterable

import httpx
import ormsgpack
from httpx_ws import WebSocketDisconnect, connect_ws, aconnect_ws
from httpx_ws import WebSocketDisconnect, aconnect_ws, connect_ws

from .exceptions import WebSocketErr

from .schemas import Backends, CloseEvent, StartEvent, TTSRequest, TextEvent
from .schemas import Backends, CloseEvent, StartEvent, TextEvent, TTSRequest


class WebSocketSession:
Expand Down
2 changes: 1 addition & 1 deletion src/fishaudio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

from .core import AsyncClientWrapper, ClientWrapper
from .resources import (
ASRClient,
AccountClient,
ASRClient,
AsyncAccountClient,
AsyncASRClient,
AsyncTTSClient,
Expand Down
24 changes: 12 additions & 12 deletions src/fishaudio/core/client_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,20 @@

import os
from json import JSONDecodeError
from typing import Any, Dict, Optional
from typing import Any, Optional

import httpx

from .._version import __version__
from ..exceptions import (
from fishaudio._version import __version__
from fishaudio.exceptions import (
APIError,
AuthenticationError,
NotFoundError,
PermissionError,
RateLimitError,
ServerError,
)

from .request_options import RequestOptions


Expand All @@ -32,16 +33,15 @@ def _raise_for_status(response: httpx.Response) -> None:
# Raise specific exception based on status code
if status == 401:
raise AuthenticationError(status, message, response.text)
elif status == 403:
if status == 403:
raise PermissionError(status, message, response.text)
elif status == 404:
if status == 404:
raise NotFoundError(status, message, response.text)
elif status == 429:
if status == 429:
raise RateLimitError(status, message, response.text)
elif status >= 500:
if status >= 500:
raise ServerError(status, message, response.text)
else:
raise APIError(status, message, response.text)
raise APIError(status, message, response.text)


class BaseClientWrapper:
Expand All @@ -61,8 +61,8 @@ def __init__(
self.base_url = base_url

def get_headers(
self, additional_headers: Optional[Dict[str, str]] = None
) -> Dict[str, str]:
self, additional_headers: Optional[dict[str, str]] = None
) -> dict[str, str]:
"""Build headers including authentication and user agent."""
headers = {
"Authorization": f"Bearer {self.api_key}",
Expand All @@ -73,7 +73,7 @@ def get_headers(
return headers

def _prepare_request_kwargs(
self, request_options: Optional[RequestOptions], kwargs: Dict[str, Any]
self, request_options: Optional[RequestOptions], kwargs: dict[str, Any]
) -> None:
"""Prepare request kwargs by merging headers, timeout, and query params."""
# Merge headers
Expand Down
2 changes: 1 addition & 1 deletion src/fishaudio/core/iterators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Audio stream wrappers with collection utilities."""

from typing import AsyncIterator, Iterator
from collections.abc import AsyncIterator, Iterator


class AudioStream:
Expand Down
6 changes: 3 additions & 3 deletions src/fishaudio/core/request_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Request-level options for API calls."""

from typing import Dict, Optional
from typing import Optional

import httpx

Expand All @@ -21,8 +21,8 @@ def __init__(
*,
timeout: Optional[float] = None,
max_retries: Optional[int] = None,
additional_headers: Optional[Dict[str, str]] = None,
additional_query_params: Optional[Dict[str, str]] = None,
additional_headers: Optional[dict[str, str]] = None,
additional_query_params: Optional[dict[str, str]] = None,
):
self.timeout = timeout
self.max_retries = max_retries
Expand Down
4 changes: 2 additions & 2 deletions src/fishaudio/core/websocket_options.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""WebSocket-level options for WebSocket connections."""

from typing import Any, Dict, Optional
from typing import Any, Optional


class WebSocketOptions:
Expand Down Expand Up @@ -40,7 +40,7 @@ def __init__(
self.max_message_size_bytes = max_message_size_bytes
self.queue_size = queue_size

def to_httpx_ws_kwargs(self) -> Dict[str, Any]:
def to_httpx_ws_kwargs(self) -> dict[str, Any]:
"""Convert to kwargs dict for httpx_ws aconnect_ws/connect_ws."""
kwargs = {}
if self.keepalive_ping_timeout_seconds is not None:
Expand Down
4 changes: 2 additions & 2 deletions src/fishaudio/resources/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

from typing import Optional

from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions
from ..types import Credits, Package
from fishaudio.core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions
from fishaudio.types import Credits, Package


class AccountClient:
Expand Down
4 changes: 2 additions & 2 deletions src/fishaudio/resources/asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

import ormsgpack

from ..core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions
from ..types import ASRResponse
from fishaudio.core import OMIT, AsyncClientWrapper, ClientWrapper, RequestOptions
from fishaudio.types import ASRResponse


class ASRClient:
Expand Down
11 changes: 6 additions & 5 deletions src/fishaudio/resources/realtime.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
"""Real-time WebSocket streaming helpers."""

from typing import Any, AsyncIterator, Dict, Iterator, Optional
from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional

import ormsgpack
from httpx_ws import WebSocketDisconnect

from ..exceptions import WebSocketError
from fishaudio.exceptions import WebSocketError


def _should_stop(data: Dict[str, Any]) -> bool:
def _should_stop(data: dict[str, Any]) -> bool:
"""
Check if WebSocket event signals stream should stop.

Expand All @@ -21,7 +22,7 @@ def _should_stop(data: Dict[str, Any]) -> bool:
return data.get("event") == "finish" and data.get("reason") == "stop"


def _process_audio_event(data: Dict[str, Any]) -> Optional[bytes]:
def _process_audio_event(data: dict[str, Any]) -> Optional[bytes]:
"""
Process a WebSocket audio event.

Expand All @@ -36,7 +37,7 @@ def _process_audio_event(data: Dict[str, Any]) -> Optional[bytes]:
"""
if data.get("event") == "audio":
return data.get("audio")
elif data.get("event") == "finish" and data.get("reason") == "error":
if data.get("event") == "finish" and data.get("reason") == "error":
raise WebSocketError("WebSocket stream ended with error")
return None # Ignore unknown events

Expand Down
Loading