Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 26 additions & 7 deletions sdk/batch/speechmatics/batch/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,9 @@ class TranscriptionConfig:
enable_partials: Enable partial transcript results.
max_delay: Maximum delay for transcript delivery.
max_delay_mode: Mode for handling max delay.
transcript_filtering_config: Configuration for filtering transcription.
defaults to None.
"""

language: str = "en"
operating_point: OperatingPoint = OperatingPoint.ENHANCED
output_locale: Optional[str] = None
Expand All @@ -112,11 +113,13 @@ class TranscriptionConfig:
enable_partials: Optional[bool] = None
max_delay: Optional[float] = None
max_delay_mode: Optional[str] = None
transcript_filtering_config: Optional[TranscriptFilteringConfig] = None

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary, excluding None values."""
return {k: v for k, v in asdict(self).items() if v is not None}

result: dict[str, Any] = {k: v for k, v in asdict(self).items() if v is not None}
if self.transcript_filtering_config is not None:
result["transcript_filtering_config"] = self.transcript_filtering_config.to_dict()
return result

@dataclass
class OutputConfig:
Expand All @@ -129,7 +132,6 @@ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary, excluding None values."""
return {k: v for k, v in asdict(self).items() if v is not None}


@dataclass
class AlignmentConfig:
"""Configuration for alignment jobs."""
Expand Down Expand Up @@ -266,6 +268,16 @@ def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary, excluding None values."""
return {k: v for k, v in asdict(self).items() if v is not None}

@dataclass
class TranscriptFilteringConfig:
"""Configuration for transcript filtering."""

remove_disfluencies: bool = False
replacements: Optional[list[dict[str, str]]] = None

def to_dict(self) -> dict[str, Any]:
"""Convert to dictionary, excluding None values."""
return {k: v for k, v in asdict(self).items() if v is not None}

@dataclass
class JobConfig:
Expand Down Expand Up @@ -337,7 +349,6 @@ def to_dict(self) -> dict[str, Any]:
config["audio_events_config"] = self.audio_events_config.to_dict()
if self.output_config:
config["output_config"] = self.output_config.to_dict()

return config

@classmethod
Expand All @@ -347,7 +358,9 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig:

transcription_config = None
if "transcription_config" in data:
tc_data = data["transcription_config"]
tc_data = data["transcription_config"].copy()
if "transcript_filtering_config" in tc_data and isinstance(tc_data["transcript_filtering_config"], dict):
tc_data["transcript_filtering_config"] = TranscriptFilteringConfig(**tc_data["transcript_filtering_config"])
transcription_config = TranscriptionConfig(**tc_data)

alignment_config = None
Expand Down Expand Up @@ -405,6 +418,11 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig:
fd_data = data["fetch_data"]
fetch_data = FetchData(**fd_data)

output_config = None
if "output_config" in data:
oc_data = data["output_config"]
output_config = OutputConfig(**oc_data)

return cls(
type=job_type,
fetch_data=fetch_data,
Expand All @@ -419,6 +437,7 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig:
topic_detection_config=topic_detection_config,
auto_chapters_config=auto_chapters_config,
audio_events_config=audio_events_config,
output_config=output_config,
)


Expand Down
129 changes: 129 additions & 0 deletions tests/batch/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from speechmatics.batch._models import JobConfig, TranscriptFilteringConfig, TranscriptionConfig


class TestTranscriptFilteringConfigToDict:
def test_remove_disfluencies_true_serializes_correctly(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)
)
result = config.to_dict()
assert result["transcript_filtering_config"] == {"remove_disfluencies": True}

def test_remove_disfluencies_false_included_in_output(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False)
)
result = config.to_dict()
assert result["transcript_filtering_config"] == {"remove_disfluencies": False}

def test_none_excluded_from_output(self):
config = TranscriptionConfig()
result = config.to_dict()
assert "transcript_filtering_config" not in result

def test_replacements_serialized(self):
replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}]
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements)
)
result = config.to_dict()
assert result["transcript_filtering_config"] == {
"remove_disfluencies": False,
"replacements": replacements,
}

def test_replacements_absent_when_none(self):
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True)
)
result = config.to_dict()
assert "replacements" not in result["transcript_filtering_config"]

def test_replacements_and_remove_disfluencies_together(self):
replacements = [{"from": "gonna", "to": "going to"}]
config = TranscriptionConfig(
transcript_filtering_config=TranscriptFilteringConfig(
remove_disfluencies=True, replacements=replacements
)
)
result = config.to_dict()
assert result["transcript_filtering_config"] == {
"remove_disfluencies": True,
"replacements": replacements,
}


class TestTranscriptFilteringConfigFromDict:
def test_dict_form_deserializes_to_config_object(self):
data = {
"type": "transcription",
"transcription_config": {
"language": "en",
"transcript_filtering_config": {"remove_disfluencies": True},
},
}
job_config = JobConfig.from_dict(data)
assert job_config.transcription_config is not None
tfc = job_config.transcription_config.transcript_filtering_config
assert isinstance(tfc, TranscriptFilteringConfig)
assert tfc.remove_disfluencies is True

def test_absent_field_is_none(self):
data = {
"type": "transcription",
"transcription_config": {"language": "en"},
}
job_config = JobConfig.from_dict(data)
assert job_config.transcription_config is not None
assert job_config.transcription_config.transcript_filtering_config is None

def test_dict_with_replacements_deserializes(self):
replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}]
data = {
"type": "transcription",
"transcription_config": {
"language": "en",
"transcript_filtering_config": {"replacements": replacements},
},
}
job_config = JobConfig.from_dict(data)
assert job_config.transcription_config is not None
tfc = job_config.transcription_config.transcript_filtering_config
assert isinstance(tfc, TranscriptFilteringConfig)
assert tfc.replacements == replacements
assert tfc.remove_disfluencies is False

def test_dict_with_replacements_and_remove_disfluencies_deserializes(self):
replacements = [{"from": "gonna", "to": "going to"}]
data = {
"type": "transcription",
"transcription_config": {
"language": "en",
"transcript_filtering_config": {
"remove_disfluencies": True,
"replacements": replacements,
},
},
}
job_config = JobConfig.from_dict(data)
assert job_config.transcription_config is not None
tfc = job_config.transcription_config.transcript_filtering_config
assert isinstance(tfc, TranscriptFilteringConfig)
assert tfc.remove_disfluencies is True
assert tfc.replacements == replacements


class TestOutputConfigFromDict:
def test_output_config_deserialized(self):
data = {
"type": "transcription",
"output_config": {"generate_lattice": True},
}
job_config = JobConfig.from_dict(data)
assert job_config.output_config is not None
assert job_config.output_config.generate_lattice is True

def test_absent_output_config_is_none(self):
data = {"type": "transcription"}
job_config = JobConfig.from_dict(data)
assert job_config.output_config is None