Skip to content

Commit a806334

Browse files
authored
issue(medcat): CU-869c43hgw Show requirement installation message upon scripts install (#332)
* CU-869c43hgw: Fix a logged message format * CU-869c43hgw: Add instructions for requirements install * CU-869c43hgw: Implement requirements file version fix functionality * CU-869c43hgw: Add simple tests for script downloading * CU-869c43hgw: Add a requirements change test * CU-869c43hgw: Fix syntax error typo * CU-869c43hgw: Fix import and typing issues * CU-869c43hgw: Fix typo in tests * CU-869c43hgw: Fix test time versioning * CU-869c43hgw: Fix small issue in tests * CU-869c43hgw: Fix minor issues in tests * CU-869c43hgw: Fix tests issue
1 parent 66219d8 commit a806334

2 files changed

Lines changed: 59 additions & 1 deletion

File tree

medcat-v2/medcat/utils/download_scripts.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@
88
import importlib.metadata
99
import tempfile
1010
import zipfile
11+
import sys
1112
from pathlib import Path
1213
import requests
1314
import logging
1415
import argparse
16+
import re
1517

1618

1719
logger = logging.getLogger(__name__)
@@ -67,7 +69,7 @@ def _determine_url(overwrite_url: str | None,
6769
else:
6870
tag = _find_latest_scripts_tag(version)
6971

70-
logger.info("Fetching scripts for MedCAT %s → tag %s}",
72+
logger.info("Fetching scripts for MedCAT %s → tag %s",
7173
version, tag)
7274

7375
# Download the GitHub auto-generated zipball
@@ -110,6 +112,23 @@ def _extract_zip(dest: Path, zip_path: Path):
110112
logger.info("Scripts extracted to: %s", dest)
111113

112114

115+
def _fix_requirements(dest: Path, current_version: str):
116+
requirements_file = dest / "requirements.txt"
117+
original = requirements_file.read_text(encoding="utf-8")
118+
119+
updated, count = re.subn(
120+
pattern=r"(medcat\[.*?\])[><=!~]+[\d.]+",
121+
repl=rf"\1~={current_version}",
122+
string=original,
123+
)
124+
125+
if count == 0:
126+
return
127+
128+
requirements_file.write_text(updated, encoding="utf-8")
129+
130+
131+
113132
def fetch_scripts(destination: str | Path = ".",
114133
overwrite_url: str | None = None,
115134
overwrite_tag: str | None = None) -> Path:
@@ -130,6 +149,11 @@ def fetch_scripts(destination: str | Path = ".",
130149
with tempfile.NamedTemporaryFile() as tmp:
131150
_download_zip(zip_url, tmp)
132151
_extract_zip(dest, Path(tmp.name))
152+
_fix_requirements(dest, _get_medcat_version())
153+
logger.info(
154+
"You also need to install the requiements by doing:\n"
155+
"%s -m pip install -r %s/requirements.txt",
156+
sys.executable, str(destination))
133157
return dest
134158

135159

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from medcat.utils import download_scripts
2+
3+
import os
4+
import unittest
5+
import unittest.mock
6+
import tempfile
7+
8+
9+
class ScriptsDownloadTest(unittest.TestCase):
10+
use_version = "2.5"
11+
12+
@classmethod
13+
def setUpClass(cls):
14+
cls._temp_dir = tempfile.TemporaryDirectory()
15+
with unittest.mock.patch(
16+
"medcat.utils.download_scripts._get_medcat_version"
17+
) as mock_get_version:
18+
mock_get_version.return_value = cls.use_version
19+
cls.scripts_path = download_scripts.fetch_scripts(cls._temp_dir.name)
20+
21+
def test_can_download(self):
22+
self.assertTrue(os.path.exists(self.scripts_path))
23+
self.assertTrue(os.path.isdir(self.scripts_path))
24+
self.assertTrue(os.listdir(self.scripts_path))
25+
26+
def test_has_requirements(self):
27+
self.assertIn('requirements.txt', os.listdir(self.scripts_path))
28+
29+
def test_requirements_define_correct_version(self):
30+
req_path = os.path.join(self.scripts_path, 'requirements.txt')
31+
with open(req_path) as f:
32+
medcat_line = [line.strip() for line in f if "medcat" in line][0]
33+
self.assertIn(self.use_version, medcat_line)
34+
self.assertTrue(medcat_line.endswith(self.use_version))

0 commit comments

Comments
 (0)