Skip to content

Commit cbd883e

Browse files
authored
feat(medcat-trainer): improved / fixed demo screen (#293)
* feat(medcat-trainer): improved demo screen, model pack selection, new clinical text component, improved demo filtering concept picker
* fix(medcat-trainer): CU-8699ryke4: test type fixes
* fix(medcat-trainer): pr review changes
* fix(medcat-trainer): address auto feedback

---------

Co-authored-by: Tom Searle <tom@cogstack.org>
1 parent 046754b commit cbd883e

16 files changed

Lines changed: 929 additions & 121 deletions

medcat-trainer/webapp/api/api/model_cache.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from medcat.vocab import Vocab
1212
from medcat.utils.legacy.convert_cdb import get_cdb_from_old
1313

14-
from api.models import ConceptDB
14+
from api.models import ConceptDB, ModelPack
1515

1616
"""
1717
Module level caches for CDBs, Vocabs and CAT instances.
@@ -163,6 +163,22 @@ def get_medcat_from_model_pack(project, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
163163
return cat
164164

165165

166+
def get_medcat_from_model_pack_id(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> CAT:
    """
    Return the CAT instance for a ModelPack id, loading and caching it on a miss.

    Cache entries are keyed as ``mp<modelpack_id>`` in ``cat_map``. On a miss
    the ModelPack row is fetched, the pack is loaded from its file path, the
    result is stored in ``cat_map`` and ``_clear_models`` is then invoked on
    the same map.

    Args:
        modelpack_id: primary key of the ModelPack to load.
        cat_map: cache mapping of cache key -> CAT instance.

    Returns:
        The cached or freshly loaded CAT instance.

    Raises:
        ModelPack.DoesNotExist: propagated from the ORM lookup when no
            ModelPack has the given id.
    """
    cache_key = f'mp{modelpack_id}'
    if cache_key not in cat_map:
        pack_path = ModelPack.objects.get(id=modelpack_id).model_pack.path
        logger.info('Loading model pack from:%s', pack_path)
        loaded_cat = CAT.load_model_pack(pack_path)
        cat_map[cache_key] = loaded_cat
        _clear_models(cat_map=cat_map)
        # Return the local reference rather than re-reading the map, so the
        # result does not depend on what _clear_models did to the cache.
        return loaded_cat
    return cat_map[cache_key]
180+
181+
166182
def get_medcat(project,
167183
cdb_map: Dict[str, CDB]=CDB_MAP,
168184
vocab_map: Dict[str, Vocab]=VOCAB_MAP,
@@ -204,6 +220,16 @@ def clear_cached_medcat(project, cat_map: Dict[str, CAT]=CAT_MAP):
204220
del cat_map[cat_id]
205221

206222

223+
def is_model_pack_loaded(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> bool:
    """
    Report whether ``cat_map`` holds an entry for the given ModelPack id.

    Args:
        modelpack_id: id of the ModelPack to look up.
        cat_map: cache mapping of cache key -> CAT instance.

    Returns:
        True when the key ``mp<modelpack_id>`` is present in ``cat_map``.
    """
    cache_key = f'mp{modelpack_id}'
    return cache_key in cat_map
225+
226+
227+
def clear_cached_medcat_by_model_pack_id(modelpack_id: int, cat_map: Dict[str, CAT]=CAT_MAP) -> None:
    """
    Drop the cache entry for the given ModelPack id, if there is one.

    Args:
        modelpack_id: id of the ModelPack whose entry should be removed.
        cat_map: cache mapping of cache key -> CAT instance.
    """
    # pop with a default is a no-op when the key ``mp<modelpack_id>`` is absent.
    cat_map.pop(f'mp{modelpack_id}', None)
231+
232+
207233
def get_cached_cdb(cdb_id: str, cdb_map: Dict[str, CDB]=CDB_MAP) -> CDB:
208234
from api.utils import clear_cdb_cnf_addons
209235
if cdb_id not in cdb_map:

medcat-trainer/webapp/api/api/views.py

Lines changed: 84 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import_concepts_from_cdb
2525
from .data_utils import upload_projects_export
2626
from .metrics import calculate_metrics
27-
from .model_cache import get_medcat, get_cached_cdb, VOCAB_MAP, clear_cached_medcat, CAT_MAP, CDB_MAP, is_model_loaded
27+
from .model_cache import get_medcat, get_medcat_from_model_pack_id, get_cached_cdb, VOCAB_MAP, clear_cached_medcat, clear_cached_medcat_by_model_pack_id, is_model_pack_loaded, CAT_MAP, CDB_MAP, is_model_loaded
2828
from .permissions import *
2929
from .serializers import *
3030
from .solr_utils import collections_available, search_collection, ensure_concept_searchable
@@ -637,32 +637,87 @@ def update_meta_annotation(request):
637637

638638
@api_view(http_method_names=['POST'])
639639
def annotate_text(request):
640-
p_id = request.data['project_id']
641-
message = request.data['message']
642-
cuis = request.data['cuis']
643-
if message is None or p_id is None:
644-
return HttpResponseBadRequest('No message to annotate')
640+
message = request.data.get('message')
641+
cuis = request.data.get('cuis', [])
642+
p_id = request.data.get('project_id')
643+
modelpack_id = request.data.get('modelpack_id')
644+
include_sub_concepts = request.data.get('include_sub_concepts', False)
645645

646-
project = ProjectAnnotateEntities.objects.get(id=p_id)
646+
if message is None or (p_id is None and modelpack_id is None):
647+
return HttpResponseBadRequest('No message to annotate')
647648

648-
cat = get_medcat(project=project)
649-
cat.config.components.linking.filters.cuis = set(cuis)
649+
if modelpack_id is not None:
650+
try:
651+
cat = get_medcat_from_model_pack_id(int(modelpack_id))
652+
except (ValueError, TypeError):
653+
logger.warning(f'Invalid modelpack_id received for project:{p_id}')
654+
return HttpResponseBadRequest('Invalid modelpack_id for project')
655+
except ModelPack.DoesNotExist:
656+
logger.warning(f'ModelPack does not exist received for project:{p_id}')
657+
return HttpResponseBadRequest('ModelPack does not exist for project')
658+
else:
659+
project = ProjectAnnotateEntities.objects.get(id=p_id)
660+
cat = get_medcat(project=project)
661+
662+
# Normalise cuis to a set[str]
663+
if isinstance(cuis, str):
664+
cuis_set = {c.strip() for c in cuis.split(',') if c.strip()}
665+
elif isinstance(cuis, (list, tuple, set)):
666+
cuis_set = {str(c).strip() for c in cuis if str(c).strip()}
667+
else:
668+
cuis_set = set()
669+
670+
# Expand CUIs to include sub-concepts if requested
671+
if include_sub_concepts and cuis_set and cat.cdb:
672+
expanded_cuis = set(cuis_set)
673+
for parent_cui in cuis_set:
674+
try:
675+
child_cuis = get_all_ch(parent_cui, cat.cdb)
676+
expanded_cuis.update(child_cuis)
677+
except Exception as e:
678+
logger.warning(f'Failed to get children for CUI {parent_cui}: {e}')
679+
cuis_set = expanded_cuis
680+
681+
curr_cuis = cat.config.components.linking.filters
682+
cat.config.components.linking.filters.cuis = cuis_set
650683
spacy_doc = cat(message)
684+
cat.config.components.linking.filters = curr_cuis
651685

652686
ents = []
653687
anno_tkns = []
654688
for ent in spacy_doc.linked_ents:
655689
cnt = Entity.objects.filter(label=ent.cui).count()
656690
inc_ent = all(tkn not in anno_tkns for tkn in ent)
657691
if inc_ent and cnt != 0:
692+
meta_annotations = []
693+
if 'meta_cat_meta_anns' in ent.get_available_addon_paths():
694+
meta_anns = ent.get_addon_data('meta_cat_meta_anns')
695+
for meta_ann_task, pred in meta_anns.items():
696+
# Extract value and confidence from pred
697+
# pred can be a dict, object, or string
698+
if isinstance(pred, dict):
699+
pred_value = pred.get('value', str(pred))
700+
pred_confidence = pred.get('confidence', None)
701+
elif hasattr(pred, 'value'):
702+
pred_value = pred.value
703+
pred_confidence = getattr(pred, 'confidence', None)
704+
else:
705+
pred_value = str(pred)
706+
pred_confidence = None
707+
meta_annotations.append({
708+
'task': meta_ann_task,
709+
'value': pred_value,
710+
'confidence': pred_confidence
711+
})
658712
anno_tkns.extend([tkn for tkn in ent])
659713
entity = Entity.objects.get(label=ent.cui)
660714
ents.append({
661715
'entity': entity.id,
662716
'value': ent.base.text,
663717
'start_ind': ent.base.start_char_index,
664718
'end_ind': ent.base.end_char_index,
665-
'acc': ent.context_similarity
719+
'acc': ent.context_similarity,
720+
'meta_annotations': meta_annotations
666721
})
667722

668723
ents.sort(key=lambda e: e['start_ind'])
@@ -752,7 +807,7 @@ def upload_deployment(request):
752807

753808

754809
@api_view(http_method_names=['GET', 'DELETE'])
755-
def cache_model(request, project_id):
810+
def cache_project_model(request, project_id):
756811
try:
757812
project = ProjectAnnotateEntities.objects.get(id=project_id)
758813
is_loaded = is_model_loaded(project)
@@ -772,6 +827,24 @@ def cache_model(request, project_id):
772827
return Response({'message': f'{str(e)}'}, 500)
773828

774829

830+
@api_view(http_method_names=['GET', 'DELETE'])
def cache_modelpack(request, modelpack_id: int):
    """
    Load a model pack into the CAT cache (GET) or remove it from the cache (DELETE).

    Args:
        request: incoming request; only ``request.method`` is read.
        modelpack_id: id of the ModelPack, taken from the URL.

    Returns:
        Response:
            - ``'success'`` / 200 after a GET or DELETE completes.
            - 404 when ``ModelPack.DoesNotExist`` is raised during loading.
            - 405 for any other method.
            - ``{'message': <error text>}`` / 500 for any other exception.
    """
    try:
        if request.method == 'GET':
            # Only trigger a load when the pack is not already in the cache.
            if not is_model_pack_loaded(modelpack_id):
                get_medcat_from_model_pack_id(modelpack_id)
            return Response('success', 200)
        if request.method == 'DELETE':
            clear_cached_medcat_by_model_pack_id(modelpack_id)
            return Response('success', 200)
        # Defensive fallback: api_view already limits the accepted methods,
        # so answer with "method not allowed" rather than "not found".
        return Response('Invalid method', 405)
    except ModelPack.DoesNotExist:
        return Response(f'ModelPack with id:{modelpack_id} does not exist', 404)
    except Exception as e:
        # Keep the traceback in the server log; the client only gets the message.
        logger.exception('Failed to update model pack cache for id:%s', modelpack_id)
        return Response({'message': f'{str(e)}'}, 500)
846+
847+
775848

776849
@api_view(http_method_names=['GET'])
777850
def model_loaded(_):

medcat-trainer/webapp/api/core/urls.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@
5050
path('api/project-progress/', api.views.project_progress),
5151
path('api/concept-db-search-index-created/', api.views.concept_search_index_available),
5252
path('api/model-loaded/', api.views.model_loaded),
53-
path('api/cache-model/<int:project_id>/', api.views.cache_model),
53+
path('api/cache-project-model/<int:project_id>/', api.views.cache_project_model),
54+
path('api/cache-modelpack/<int:modelpack_id>/', api.views.cache_modelpack),
5455
path('api/upload-deployment/', api.views.upload_deployment),
5556
path('api/model-concept-children/<int:cdb_id>/', api.views.cdb_cui_children),
5657
path('api/metrics/<int:report_id>/', api.views.view_metrics),
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
11
/// <reference types="vite/client" />
2+
3+
// Ambient module declaration for imports whose path ends in `.vue`:
// each such module is typed as having a default export of type
// DefineComponent<object, object, any> (type imported from 'vue').
declare module '*.vue' {
4+
import type { DefineComponent } from 'vue'
5+
const component: DefineComponent<object, object, any>
6+
export default component
7+
}

medcat-trainer/webapp/frontend/src/components/anns/AddAnnotation.vue

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ export default {
133133
this.loading = false
134134
this.errorMessage = err.response.data.message || 'Error loading model.'
135135
})
136+
136137
},
137138
cancel () {
138139
this.$emit('request:addAnnotationComplete')

0 commit comments

Comments
 (0)