diff --git a/funasr/metrics/common.py b/funasr/metrics/common.py index 2443e0dc5..11bd2cdfb 100644 --- a/funasr/metrics/common.py +++ b/funasr/metrics/common.py @@ -9,8 +9,9 @@ import json import logging import sys - from itertools import groupby + +from rapidfuzz.distance import Levenshtein import numpy as np import six @@ -155,7 +156,6 @@ def calculate_cer_ctc(self, ys_hat, ys_pad): :return: average sentence-level CER score :rtype float """ - import editdistance cers, char_ref_lens = [], [] for i, y in enumerate(ys_hat): @@ -175,7 +175,7 @@ def calculate_cer_ctc(self, ys_hat, ys_pad): hyp_chars = "".join(seq_hat) ref_chars = "".join(seq_true) if len(ref_chars) > 0: - cers.append(editdistance.eval(hyp_chars, ref_chars)) + cers.append(Levenshtein.distance(hyp_chars, ref_chars)) char_ref_lens.append(len(ref_chars)) cer_ctc = float(sum(cers)) / sum(char_ref_lens) if cers else None @@ -214,16 +214,15 @@ def calculate_cer(self, seqs_hat, seqs_true): :return: average sentence-level CER score :rtype float """ - import editdistance char_eds, char_ref_lens = [], [] for i, seq_hat_text in enumerate(seqs_hat): seq_true_text = seqs_true[i] hyp_chars = seq_hat_text.replace(" ", "") ref_chars = seq_true_text.replace(" ", "") - char_eds.append(editdistance.eval(hyp_chars, ref_chars)) + char_eds.append(Levenshtein.distance(hyp_chars, ref_chars)) char_ref_lens.append(len(ref_chars)) - return float(sum(char_eds)) / sum(char_ref_lens) + return float(sum(char_eds)) / sum(char_ref_lens) if char_eds else None def calculate_wer(self, seqs_hat, seqs_true): """Calculate sentence-level WER score. @@ -233,13 +232,12 @@ def calculate_wer(self, seqs_hat, seqs_true): :return: average sentence-level WER score :rtype float """ - import editdistance word_eds, word_ref_lens = [], [] for i, seq_hat_text in enumerate(seqs_hat): seq_true_text = seqs_true[i] hyp_words = seq_hat_text.split() ref_words = seq_true_text.split() - word_eds.append(editdistance.eval(hyp_words, ref_words)) + word_eds.append(Levenshtein.distance(hyp_words, ref_words)) word_ref_lens.append(len(ref_words)) - return float(sum(word_eds)) / sum(word_ref_lens) + return float(sum(word_eds)) / sum(word_ref_lens) if word_eds else None diff --git a/setup.py b/setup.py index fdc693e1d..de2f5d4be 100644 --- a/setup.py +++ b/setup.py @@ -35,7 +35,7 @@ "jaconv", # Speaker & evaluation "umap_learn", - "editdistance>=0.5.2", + "rapidfuzz>=3.0.0", # Optional (training/enhancement) "torch_complex", "tensorboardX", @@ -44,7 +44,7 @@ ], # train: The modules invoked when training only. "train": [ - "editdistance", + "rapidfuzz>=3.0.0", ], # all: The modules should be optionally installled due to some reason. # Please consider moving them to "install" occasionally