diff --git a/sdk/batch/speechmatics/batch/_models.py b/sdk/batch/speechmatics/batch/_models.py index 84970b11..8b235969 100644 --- a/sdk/batch/speechmatics/batch/_models.py +++ b/sdk/batch/speechmatics/batch/_models.py @@ -97,8 +97,9 @@ class TranscriptionConfig: enable_partials: Enable partial transcript results. max_delay: Maximum delay for transcript delivery. max_delay_mode: Mode for handling max delay. + transcript_filtering_config: Configuration for filtering transcription. + defaults to None. """ - language: str = "en" operating_point: OperatingPoint = OperatingPoint.ENHANCED output_locale: Optional[str] = None @@ -112,11 +113,13 @@ class TranscriptionConfig: enable_partials: Optional[bool] = None max_delay: Optional[float] = None max_delay_mode: Optional[str] = None + transcript_filtering_config: Optional[TranscriptFilteringConfig] = None def to_dict(self) -> dict[str, Any]: - """Convert to dictionary, excluding None values.""" - return {k: v for k, v in asdict(self).items() if v is not None} - + result: dict[str, Any] = {k: v for k, v in asdict(self).items() if v is not None} + if self.transcript_filtering_config is not None: + result["transcript_filtering_config"] = self.transcript_filtering_config.to_dict() + return result @dataclass class OutputConfig: @@ -129,7 +132,6 @@ def to_dict(self) -> dict[str, Any]: """Convert to dictionary, excluding None values.""" return {k: v for k, v in asdict(self).items() if v is not None} - @dataclass class AlignmentConfig: """Configuration for alignment jobs.""" @@ -266,6 +268,16 @@ def to_dict(self) -> dict[str, Any]: """Convert to dictionary, excluding None values.""" return {k: v for k, v in asdict(self).items() if v is not None} +@dataclass +class TranscriptFilteringConfig: + """Configuration for transcript filtering.""" + + remove_disfluencies: bool = False + replacements: Optional[list[dict[str, str]]] = None + + def to_dict(self) -> dict[str, Any]: + """Convert to dictionary, excluding None values.""" + return {k: v for k, v in asdict(self).items() if v is not None} @dataclass class JobConfig: @@ -337,7 +349,6 @@ def to_dict(self) -> dict[str, Any]: config["audio_events_config"] = self.audio_events_config.to_dict() if self.output_config: config["output_config"] = self.output_config.to_dict() - return config @classmethod @@ -347,7 +358,9 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig: transcription_config = None if "transcription_config" in data: - tc_data = data["transcription_config"] + tc_data = data["transcription_config"].copy() + if "transcript_filtering_config" in tc_data and isinstance(tc_data["transcript_filtering_config"], dict): + tc_data["transcript_filtering_config"] = TranscriptFilteringConfig(**tc_data["transcript_filtering_config"]) transcription_config = TranscriptionConfig(**tc_data) alignment_config = None @@ -405,6 +418,11 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig: fd_data = data["fetch_data"] fetch_data = FetchData(**fd_data) + output_config = None + if "output_config" in data: + oc_data = data["output_config"] + output_config = OutputConfig(**oc_data) + return cls( type=job_type, fetch_data=fetch_data, @@ -419,6 +437,7 @@ def from_dict(cls, data: dict[str, Any]) -> JobConfig: topic_detection_config=topic_detection_config, auto_chapters_config=auto_chapters_config, audio_events_config=audio_events_config, + output_config=output_config, ) diff --git a/tests/batch/test_models.py b/tests/batch/test_models.py new file mode 100644 index 00000000..d262685e --- /dev/null +++ b/tests/batch/test_models.py @@ -0,0 +1,129 @@ +from speechmatics.batch._models import JobConfig, TranscriptFilteringConfig, TranscriptionConfig + + +class TestTranscriptFilteringConfigToDict: + def test_remove_disfluencies_true_serializes_correctly(self): + config = TranscriptionConfig( + transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True) + ) + result = config.to_dict() + assert result["transcript_filtering_config"] == {"remove_disfluencies": True} + + def test_remove_disfluencies_false_included_in_output(self): + config = TranscriptionConfig( + transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=False) + ) + result = config.to_dict() + assert result["transcript_filtering_config"] == {"remove_disfluencies": False} + + def test_none_excluded_from_output(self): + config = TranscriptionConfig() + result = config.to_dict() + assert "transcript_filtering_config" not in result + + def test_replacements_serialized(self): + replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}] + config = TranscriptionConfig( + transcript_filtering_config=TranscriptFilteringConfig(replacements=replacements) + ) + result = config.to_dict() + assert result["transcript_filtering_config"] == { + "remove_disfluencies": False, + "replacements": replacements, + } + + def test_replacements_absent_when_none(self): + config = TranscriptionConfig( + transcript_filtering_config=TranscriptFilteringConfig(remove_disfluencies=True) + ) + result = config.to_dict() + assert "replacements" not in result["transcript_filtering_config"] + + def test_replacements_and_remove_disfluencies_together(self): + replacements = [{"from": "gonna", "to": "going to"}] + config = TranscriptionConfig( + transcript_filtering_config=TranscriptFilteringConfig( + remove_disfluencies=True, replacements=replacements + ) + ) + result = config.to_dict() + assert result["transcript_filtering_config"] == { + "remove_disfluencies": True, + "replacements": replacements, + } + + +class TestTranscriptFilteringConfigFromDict: + def test_dict_form_deserializes_to_config_object(self): + data = { + "type": "transcription", + "transcription_config": { + "language": "en", + "transcript_filtering_config": {"remove_disfluencies": True}, + }, + } + job_config = JobConfig.from_dict(data) + assert job_config.transcription_config is not None + tfc = job_config.transcription_config.transcript_filtering_config + assert isinstance(tfc, TranscriptFilteringConfig) + assert tfc.remove_disfluencies is True + + def test_absent_field_is_none(self): + data = { + "type": "transcription", + "transcription_config": {"language": "en"}, + } + job_config = JobConfig.from_dict(data) + assert job_config.transcription_config is not None + assert job_config.transcription_config.transcript_filtering_config is None + + def test_dict_with_replacements_deserializes(self): + replacements = [{"from": "um", "to": ""}, {"from": "uh", "to": ""}] + data = { + "type": "transcription", + "transcription_config": { + "language": "en", + "transcript_filtering_config": {"replacements": replacements}, + }, + } + job_config = JobConfig.from_dict(data) + assert job_config.transcription_config is not None + tfc = job_config.transcription_config.transcript_filtering_config + assert isinstance(tfc, TranscriptFilteringConfig) + assert tfc.replacements == replacements + assert tfc.remove_disfluencies is False + + def test_dict_with_replacements_and_remove_disfluencies_deserializes(self): + replacements = [{"from": "gonna", "to": "going to"}] + data = { + "type": "transcription", + "transcription_config": { + "language": "en", + "transcript_filtering_config": { + "remove_disfluencies": True, + "replacements": replacements, + }, + }, + } + job_config = JobConfig.from_dict(data) + assert job_config.transcription_config is not None + tfc = job_config.transcription_config.transcript_filtering_config + assert isinstance(tfc, TranscriptFilteringConfig) + assert tfc.remove_disfluencies is True + assert tfc.replacements == replacements + + +class TestOutputConfigFromDict: + def test_output_config_deserialized(self): + data = { + "type": "transcription", + "output_config": {"generate_lattice": True}, + } + job_config = JobConfig.from_dict(data) + assert job_config.output_config is not None + assert job_config.output_config.generate_lattice is True + + def test_absent_output_config_is_none(self): + data = {"type": "transcription"} + job_config = JobConfig.from_dict(data) + assert job_config.output_config is None