Skip to content

Commit 74f47b1

Browse files
committed
Add data type property to rows
1 parent 7d44497 commit 74f47b1

27 files changed

Lines changed: 217 additions & 176 deletions

machine/corpora/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from .alignment_row import AlignmentRow
55
from .corpora_utils import batch
66
from .corpus import Corpus
7+
from .data_type import DataType
78
from .dbl_bundle_text_corpus import DblBundleTextCorpus
89
from .dictionary_alignment_corpus import DictionaryAlignmentCorpus
910
from .dictionary_text_corpus import DictionaryTextCorpus
@@ -102,6 +103,7 @@
102103
"batch",
103104
"Corpus",
104105
"create_versification_ref_corpus",
106+
"DataType",
105107
"DblBundleTextCorpus",
106108
"DictionaryAlignmentCorpus",
107109
"DictionaryTextCorpus",

machine/corpora/corpora_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,14 +49,14 @@ def get_split_indices(
4949
return set(rand.sample(range(corpus_size), min(split_size, corpus_size)))
5050

5151

52-
def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str]]:
52+
def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str, int]]:
5353
file_patterns = list(file_patterns)
5454
if len(file_patterns) == 1 and os.path.isfile(file_patterns[0]):
55-
yield ("*all*", file_patterns[0])
55+
yield ("*all*", file_patterns[0], 0)
5656
else:
5757
for i, file_pattern in enumerate(file_patterns):
5858
if os.path.isfile(file_pattern):
59-
yield (str(i), file_pattern)
59+
yield (str(i), file_pattern, i)
6060
continue
6161

6262
if "*" not in file_pattern and "?" not in file_pattern and not os.path.exists(file_pattern):
@@ -89,7 +89,7 @@ def get_files(file_patterns: Iterable[str]) -> Iterable[Tuple[str, str]]:
8989
updated_id += group
9090
if len(updated_id) > 0:
9191
id = updated_id
92-
yield (id, filename)
92+
yield (id, filename, i)
9393

9494

9595
def gen(iterable: Iterable[T] = []) -> Generator[T, None, None]:

machine/corpora/data_type.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from enum import Enum, auto
2+
3+
4+
class DataType(Enum): # TODO what options to include? Does a verse=SENTENCE for our purposes?
5+
GLOSS = auto()
6+
PHRASE = auto()
7+
SENTENCE = auto()
8+
PASSAGE = auto()
9+
DOCUMENT = auto()

machine/corpora/memory_text.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
from typing import Generator, Iterable
22

33
from .corpora_utils import gen
4+
from .data_type import DataType
45
from .text import Text
56
from .text_row import TextRow
67

78

89
class MemoryText(Text):
9-
def __init__(self, id: str, rows: Iterable[TextRow] = []) -> None:
10+
def __init__(self, id: str, rows: Iterable[TextRow] = [], data_type: DataType = DataType.SENTENCE) -> None:
1011
self._id = id
1112
self._rows = list(rows)
13+
if any([r.data_type != data_type for r in self._rows]):
14+
raise ValueError(f"{type(data_type)} of rows must match text {type(data_type)} {data_type}")
15+
self._data_type = data_type
1216

1317
@property
1418
def id(self) -> str:
@@ -18,5 +22,9 @@ def id(self) -> str:
1822
def sort_key(self) -> str:
1923
return self._id
2024

25+
@property
26+
def data_type(self) -> DataType:
27+
return self._data_type
28+
2129
def _get_rows(self) -> Generator[TextRow, None, None]:
2230
return gen(self._rows)

machine/corpora/n_parallel_text_corpus.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Callable, Iterable, List, Optional, Sequence, Set, cast
33

44
from ..scripture.verse_ref import Versification
5+
from .data_type import DataType
56
from .n_parallel_text_corpus_base import NParallelTextCorpusBase
67
from .n_parallel_text_row import NParallelTextRow
78
from .scripture_ref import ScriptureRef
@@ -14,6 +15,7 @@ class _RangeRow:
1415
refs: List[Any]
1516
segment: List[str]
1617
is_sentence_start: bool = False
18+
data_type: DataType = DataType.SENTENCE
1719

1820
@property
1921
def is_in_range(self):
@@ -36,6 +38,7 @@ def __init__(self, n: int):
3638
self.text_id = ""
3739
self.versifications: Optional[List[Versification]] = None
3840
self.row_ref_comparer = None
41+
self.data_type = DataType.SENTENCE
3942

4043
@property
4144
def is_in_range(self) -> bool:
@@ -44,6 +47,7 @@ def is_in_range(self) -> bool:
4447
def add_text_row(self, row: TextRow, index: int):
4548
self.text_id = row.text_id
4649
self.rows[index].refs.append(row.ref)
50+
self.rows[index].data_type = row.data_type
4751
if self.rows[index].is_empty:
4852
self.rows[index].is_sentence_start = row.is_sentence_start
4953
self.rows[index].segment.extend(row.segment)
@@ -53,6 +57,7 @@ def create_row(self) -> NParallelTextRow:
5357
reference_refs: List[Any] = [r.refs[0] if len(r.refs) > 0 else None for r in self.rows if len(r.refs) > 0]
5458
for i in range(len(self.rows)):
5559
row = self.rows[i]
60+
self.data_type = row.data_type
5661

5762
if (
5863
self.versifications is not None
@@ -62,7 +67,7 @@ def create_row(self) -> NParallelTextRow:
6267
refs[i] = [cast(ScriptureRef, r).change_versification(self.versifications[i]) for r in reference_refs]
6368
else:
6469
refs[i] = row.refs.copy()
65-
n_parallel_text_row = NParallelTextRow(self.text_id, refs)
70+
n_parallel_text_row = NParallelTextRow(self.text_id, refs, self.data_type)
6671
n_parallel_text_row.n_segments = [r.segment.copy() for r in self.rows]
6772
n_parallel_text_row.n_flags = [
6873
TextRowFlags.SENTENCE_START if r.is_sentence_start else TextRowFlags.NONE for r in self.rows
@@ -288,6 +293,7 @@ def _create_rows(
288293
yield range_info.create_row()
289294

290295
default_refs = [[r.ref for r in rows if r is not None][0]]
296+
data_type = DataType.SENTENCE
291297

292298
text_id: Optional[str] = None
293299
refs: List[List[Any]] = []
@@ -298,6 +304,7 @@ def _create_rows(
298304
for i in range(len(rows)):
299305
row = rows[i]
300306
if row is not None:
307+
data_type = row.data_type
301308
text_id = text_id or row.text_id
302309
if self.corpora[i].is_scripture:
303310
refs[i] = self._correct_versification([row.ref] if row.ref is None else default_refs, i)
@@ -314,7 +321,7 @@ def _create_rows(
314321
)
315322
refs = [r or default_refs for r in refs]
316323

317-
new_row = NParallelTextRow(cast(str, text_id), refs)
324+
new_row = NParallelTextRow(cast(str, text_id), refs, data_type)
318325
new_row.n_segments = [r.segment if r is not None else [] for r in rows]
319326
new_row.n_flags = flags
320327
yield new_row

machine/corpora/n_parallel_text_row.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,19 @@
11
from typing import Any, Sequence
22

3+
from .data_type import DataType
34
from .text_row import TextRowFlags
45

56

67
class NParallelTextRow:
7-
def __init__(self, text_id: str, n_refs: Sequence[Sequence[Any]]):
8+
def __init__(self, text_id: str, n_refs: Sequence[Sequence[Any]], data_type: DataType = DataType.SENTENCE):
89
if len([n_ref for n_ref in n_refs if n_ref is not None and len(n_ref) > 0]) == 0:
910
raise ValueError(f"Refs must be provided but n_refs={n_refs}")
1011
self._text_id = text_id
1112
self._n_refs = n_refs
1213
self._n = len(n_refs)
1314
self.n_segments: Sequence[Sequence[str]] = [[] for _ in range(0, self._n)]
1415
self.n_flags: Sequence[TextRowFlags] = [TextRowFlags.SENTENCE_START for _ in range(0, self._n)]
16+
self._data_type = data_type
1517

1618
@property
1719
def text_id(self) -> str:
@@ -21,6 +23,10 @@ def text_id(self) -> str:
2123
def ref(self) -> Any:
2224
return self._n_refs[0][0]
2325

26+
@property
27+
def data_type(self) -> DataType:
28+
return self._data_type
29+
2430
@property
2531
def n_refs(self) -> Sequence[Sequence[Any]]:
2632
return self._n_refs
@@ -42,6 +48,6 @@ def text(self, i: int) -> str:
4248
return " ".join(self.n_segments[i])
4349

4450
def invert(self) -> "NParallelTextRow":
45-
inverted_row = NParallelTextRow(self._text_id, list(reversed(self._n_refs)))
51+
inverted_row = NParallelTextRow(self._text_id, list(reversed(self._n_refs)), data_type=self.data_type)
4652
inverted_row.n_flags = list(reversed(self.n_flags))
4753
return inverted_row

machine/corpora/parallel_text_corpus.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from .aligned_word_pair import AlignedWordPair
2727
from .corpora_utils import get_split_indices
2828
from .corpus import Corpus
29+
from .data_type import DataType
2930
from .parallel_text_row import ParallelTextRow
3031
from .token_processors import escape_spaces, lowercase, normalize, unescape_spaces
3132

@@ -401,10 +402,11 @@ def to_hf_dataset(
401402
ref_column: Optional[str] = "ref",
402403
translation_column: str = "translation",
403404
alignment_column: Optional[str] = "alignment",
405+
data_type_column: Optional[str] = "data_type",
404406
) -> Dataset:
405407
try:
406408
from datasets.arrow_dataset import Dataset
407-
from datasets.features.features import Features, FeatureType, Sequence, Value
409+
from datasets.features.features import ClassLabel, Features, FeatureType, Sequence, Value
408410
from datasets.features.translation import Translation
409411
except ImportError:
410412
raise RuntimeError("datasets is not installed.")
@@ -416,6 +418,8 @@ def to_hf_dataset(
416418
features_dict[ref_column] = Sequence(Value("string"))
417419
if alignment_column is not None:
418420
features_dict[alignment_column] = Sequence({source_lang: Value("int32"), target_lang: Value("int32")})
421+
if data_type_column is not None:
422+
features_dict[data_type_column] = ClassLabel(names=[e.name for e in DataType])
419423
features = Features(features_dict)
420424

421425
def iterable() -> Iterable[dict]:
@@ -426,6 +430,8 @@ def iterable() -> Iterable[dict]:
426430
example[text_id_column] = row.text_id
427431
if ref_column is not None:
428432
example[ref_column] = row.refs
433+
if data_type_column is not None:
434+
example[data_type_column] = row.data_type.name
429435
example[translation_column] = {source_lang: row.source_text, target_lang: row.target_text}
430436
if alignment_column is not None:
431437
src_indices: List[int] = []

machine/corpora/parallel_text_row.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Any, Collection, Optional, Sequence
44

55
from .aligned_word_pair import AlignedWordPair
6+
from .data_type import DataType
67
from .text_row import TextRowFlags
78

89

@@ -17,6 +18,7 @@ def __init__(
1718
aligned_word_pairs: Optional[Collection[AlignedWordPair]] = None,
1819
source_flags: TextRowFlags = TextRowFlags.SENTENCE_START,
1920
target_flags: TextRowFlags = TextRowFlags.SENTENCE_START,
21+
data_type: DataType = DataType.SENTENCE,
2022
) -> None:
2123
if not text_id:
2224
raise ValueError("A text_id must be set.")
@@ -25,6 +27,7 @@ def __init__(
2527
self._text_id = text_id
2628
self._source_refs = source_refs
2729
self._target_refs = target_refs
30+
self._data_type = data_type
2831
self.source_segment = source_segment
2932
self.target_segment = target_segment
3033
self.aligned_word_pairs = aligned_word_pairs
@@ -51,6 +54,10 @@ def ref(self) -> Any:
5154
def refs(self) -> Sequence[Any]:
5255
return self.target_refs if len(self.source_refs) == 0 else self.source_refs
5356

57+
@property
58+
def data_type(self) -> DataType:
59+
return self._data_type
60+
5461
@property
5562
def is_source_sentence_start(self) -> bool:
5663
return TextRowFlags.SENTENCE_START in self.source_flags
@@ -107,4 +114,5 @@ def invert(self) -> ParallelTextRow:
107114
None if self.aligned_word_pairs is None else [wp.invert() for wp in self.aligned_word_pairs],
108115
self.target_flags,
109116
self.source_flags,
117+
self.data_type,
110118
)

machine/corpora/paratext_backup_terms_corpus.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from zipfile import ZipFile
33

44
from ..utils.typeshed import StrPath
5+
from .data_type import DataType
56
from .dictionary_text_corpus import DictionaryTextCorpus
67
from .key_term import KeyTerm
78
from .memory_text import MemoryText
@@ -25,5 +26,12 @@ def __init__(self, filename: StrPath, term_categories: Sequence[str], use_term_g
2526
f"{settings.biblical_terms_file_name}"
2627
)
2728

28-
text = MemoryText(text_id, [TextRow(text_id, key_term.id, key_term.renderings) for key_term in key_terms])
29+
text = MemoryText(
30+
text_id,
31+
[
32+
TextRow(text_id, key_term.id, key_term.renderings, data_type=DataType.GLOSS)
33+
for key_term in key_terms
34+
],
35+
data_type=DataType.GLOSS,
36+
)
2937
self._add_text(text)

machine/corpora/scripture_text.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@
44
from ..scripture.verse_ref import VerseRef, Versification
55
from ..utils.context_managed_generator import ContextManagedGenerator
66
from .corpora_utils import gen, get_scripture_text_sort_key
7+
from .data_type import DataType
78
from .scripture_ref import ScriptureElement, ScriptureRef
89
from .text_base import TextBase
910
from .text_row import TextRow, TextRowFlags
1011

1112

1213
class ScriptureText(TextBase):
1314
def __init__(self, id: str, versification: Optional[Versification] = None) -> None:
14-
super().__init__(id, get_scripture_text_sort_key(id))
15+
super().__init__(id, get_scripture_text_sort_key(id), data_type=DataType.SENTENCE)
1516
self._versification = ENGLISH_VERSIFICATION if versification is None else versification
1617

1718
@property

0 commit comments

Comments
 (0)