|
| 1 | +import unittest |
| 2 | +from unittest import mock |
| 3 | + |
| 4 | +from datasets import Dataset |
| 5 | +from mocks.mock_dump_db import MockDumpDB |
| 6 | + |
| 7 | +from bsmetadata.preprocessing_utils import WebsiteDescPreprocessor |
| 8 | + |
| 9 | + |
| 10 | +def mock_sent_tokenize(text): |
| 11 | + return [text] |
| 12 | + |
| 13 | + |
| 14 | +class WebsiteDescPreprocessorTester(unittest.TestCase): |
| 15 | + @mock.patch("bsmetadata.preprocessing_tools.website_desc_utils.DumpDB") |
| 16 | + def setUp(self, mock_db) -> None: |
| 17 | + mock_db.return_value = MockDumpDB("some/path") |
| 18 | + self.website_processor = WebsiteDescPreprocessor() |
| 19 | + self.example_ids = [0, 1, 2] |
| 20 | + self.example_text = ["test text 1", "test text 2", "test text 3"] |
| 21 | + self.example_metadata = [ |
| 22 | + [{"key": "url", "type": "global", "value": "https://www.xyz.com"}], |
| 23 | + [ |
| 24 | + {"key": "url", "type": "global", "value": "http://sometitle.com"}, |
| 25 | + {"key": "url", "type": "global", "value": "http://notfound.com"}, |
| 26 | + ], |
| 27 | + [{"key": "url", "type": "global", "value": "https://www.test.com"}], |
| 28 | + ] |
| 29 | + |
| 30 | + self.example_dict = {"id": self.example_ids, "metadata": self.example_metadata, "text": self.example_text} |
| 31 | + |
| 32 | + @mock.patch("bsmetadata.preprocessing_tools.website_desc_utils.nltk.sent_tokenize", new=mock_sent_tokenize) |
| 33 | + def test_website_metadata_processor(self): |
| 34 | + ds = Dataset.from_dict(self.example_dict) |
| 35 | + ds = ds.map(lambda ex: self.website_processor.preprocess(ex), batched=True) |
| 36 | + target_metadata = [ |
| 37 | + [ |
| 38 | + {"key": "url", "type": "global", "value": "https://www.xyz.com"}, |
| 39 | + {"key": "website_description", "type": "global", "value": "XYZ is a U.S. based company."}, |
| 40 | + ], |
| 41 | + [ |
| 42 | + {"key": "url", "type": "global", "value": "http://sometitle.com"}, |
| 43 | + {"key": "url", "type": "global", "value": "http://notfound.com"}, |
| 44 | + {"key": "website_description", "type": "global", "value": "SomeTitle is a U.S. based company."}, |
| 45 | + ], |
| 46 | + [ |
| 47 | + {"key": "url", "type": "global", "value": "https://www.test.com"}, |
| 48 | + {"key": "website_description", "type": "global", "value": "Test is a U.S. based company."}, |
| 49 | + ], |
| 50 | + ] |
| 51 | + self.assertEqual(ds[:]["metadata"], target_metadata) |
| 52 | + |
| 53 | + |
| 54 | +if __name__ == "__main__": |
| 55 | + unittest.main() |
0 commit comments