3232#include < algorithm>
3333#include < cassert>
3434#include < exception>
35+ #include < functional>
3536#include < map>
3637#include < optional>
3738#include < ranges>
@@ -45,48 +46,67 @@ Q_PLUGIN_METADATA(IID "studio.manivault.ScatterplotPlugin")
4546using namespace mv;
4647using namespace mv ::util;
4748
48- // returns a selection map between source and target
49- static std::optional<const mv::LinkedData*> getSelectionMapping (const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
49+ // This only checks the immedeate parent and is deliberately not recursive
50+ // We might consider the latter in the future, but might need to cover edge cases
51+ static bool parentHasSameNumPoints (const mv::Dataset<DatasetImpl> data, const mv::Dataset<Points>& other) {
52+ if (data->isDerivedData ()) {
53+ const auto parent = data->getParent ();
54+ if (parent->getDataType () == PointType) {
55+ const auto parentPoints = mv::Dataset<Points>(parent);
56+ return parentPoints->getNumPoints () == other->getNumPoints ();
57+ }
58+ }
59+ return false ;
60+ }
61+
62+ using CheckFunc = std::function<bool (const mv::LinkedData& linkedData, const mv::Dataset<Points>& target)>;
63+
64+ static std::optional<const mv::LinkedData*> getSelectionMapping (const mv::Dataset<Points>& source, const mv::Dataset<Points>& target, CheckFunc checkMapping) {
5065 const std::vector<mv::LinkedData>& linkedDatas = source->getLinkedData ();
5166
52- // find linked data between source and target OR source and target's parent, if target is derived and they have the same number of points
53- const auto it = std::ranges::find_if (linkedDatas, [&target](const mv::LinkedData& obj) {
54-
55- // This only checks the immedeate parent and is deliberately not recursive
56- // We might consider the latter in the future, but would need to cover more edge cases
57- auto isParentOf = [&target](const mv::Dataset<Points>& linkedTarget) -> bool {
58- if (target->isDerivedData ()) {
59- const auto parent = target->getParent ();
60- if (parent->getDataType () == PointType) {
61- const auto parentPoints = mv::Dataset<Points>(parent);
62- return parentPoints->getNumPoints () == target->getNumPoints ();
63- }
64- }
65- return false ;
66- };
67+ if (linkedDatas.empty ())
68+ return std::nullopt ;
6769
68- return obj.getTargetDataset () == target || isParentOf (obj.getTargetDataset ());
70+ // find linked data between source and target OR source and target's parent, if target is derived and they have the same number of points
71+ const auto it = std::ranges::find_if (linkedDatas, [&target, &checkMapping](const mv::LinkedData& linkedData) -> bool {
72+ return checkMapping (linkedData, target);
6973 });
7074
7175 if (it != linkedDatas.end ()) {
7276 return &(*it); // return pointer to the found object
7377 }
7478
7579 return std::nullopt ; // nothing found
80+
7681}
7782
78- // returns whether there is a selection map from source to target that covers all elements in the target
79- static bool checkSelectionMapping (const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
83+ static std::optional<const mv::LinkedData*> getSelectionMappingColorsToPositions (const mv::Dataset<Points>& colors, const mv::Dataset<Points>& positions) {
84+ auto testTargetAndParent = [](const mv::LinkedData& linkedData, const mv::Dataset<Points>& positions) -> bool {
85+ const Dataset<DatasetImpl> mapTargetData = linkedData.getTargetDataset ();
86+ return mapTargetData == positions || parentHasSameNumPoints (mapTargetData, positions);
87+ };
88+
89+ return getSelectionMapping (colors, positions, testTargetAndParent);
90+ }
8091
81- // Check if there is a mapping
82- const auto it = getSelectionMapping (source, target);
92+ static std::optional<const mv::LinkedData*> getSelectionMappingPositionsToColors (const mv::Dataset<Points>& positions, const mv::Dataset<Points>& colors) {
93+
94+ auto testTarget = [](const mv::LinkedData& linkedData, const mv::Dataset<Points>& colors) -> bool {
95+ return linkedData.getTargetDataset () == colors;
96+ };
8397
84- if (!it.has_value () || it.value () == nullptr )
85- return false ;
98+ auto mapping = getSelectionMapping (positions, colors, testTarget);
8699
100+ if (!mapping.has_value () && parentHasSameNumPoints (positions, colors)) {
101+ mapping = getSelectionMapping (positions->getParent <Points>(), colors, testTarget);
102+ }
103+
104+ return mapping;
105+ }
106+
107+ static bool checkSurjectiveMapping (const mv::LinkedData& linkedData, const std::uint32_t numPointsInTarget) {
87108 // Check if the mapping is surjective, i.e. hits all elements in the target
88- const std::map<std::uint32_t , std::vector<std::uint32_t >>& linkedMap = it.value ()->getMapping ().getMap ();
89- const std::uint32_t numPointsInTarget = target->getNumPoints ();
109+ const std::map<std::uint32_t , std::vector<std::uint32_t >>& linkedMap = linkedData.getMapping ().getMap ();
90110
91111 std::vector<bool > found (numPointsInTarget, false );
92112 std::uint32_t count = 0 ;
@@ -97,13 +117,34 @@ static bool checkSelectionMapping(const mv::Dataset<Points>& source, const mv::D
97117
98118 if (!found[val]) {
99119 found[val] = true ;
100- if (++count == numPointsInTarget)
120+ if (++count == numPointsInTarget)
101121 return true ;
102122 }
103123 }
104124 }
105125
106- return false ; // The previous loop would have returned early if the entire taget set was covered
126+ return false ; // The previous loop would have returned early if the entire taget set was covered
127+ }
128+
129+ // returns whether there is a selection map from source to target that covers all elements in the target
130+ static bool checkSelectionMapping (const mv::Dataset<Points>& colors, const mv::Dataset<Points>& positions) {
131+
132+ // Check if there is a mapping
133+ auto mapping = getSelectionMappingColorsToPositions (colors, positions);
134+ auto numTargetPoints = positions->getNumPoints ();
135+
136+ if (!mapping.has_value () || mapping.value () == nullptr ) {
137+
138+ mapping = getSelectionMappingPositionsToColors (positions, colors);
139+ numTargetPoints = colors->getNumPoints ();
140+
141+ if (!mapping.has_value () || mapping.value () == nullptr )
142+ return false ;
143+ }
144+
145+ const bool mappingCoversData = checkSurjectiveMapping (*(mapping.value ()), numTargetPoints);
146+
147+ return mappingCoversData;
107148}
108149
109150ScatterplotPlugin::ScatterplotPlugin (const PluginFactory* factory) :
@@ -240,7 +281,7 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
240281 // Accept for recoloring:
241282 // 1. data with the same number of points
242283 // 2. data which is derived from a parent that has the same number of points (e.g. for HSNE embeddings), where we can use global indices for mapping
243- // 3. data which has a fully-covering selection mapping, that we can use for setting colors
284+ // 3. data which has a fully-covering selection mapping to the position data (or it's parent) , that we can use for setting colors
244285
245286 // [1. Same number of points]
246287 const auto numPointsCandidate = candidateDataset->getNumPoints ();
@@ -253,6 +294,9 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
253294 /* then*/ _positionDataset->getSourceDataset <Points>()->getFullDataset <Points>()->getNumPoints () == numPointsCandidate :
254295 /* else*/ false ;
255296
297+ const auto & l1 = candidateDataset->getLinkedData ();
298+ const auto & l2 = _positionDataset->getLinkedData ();
299+
256300 // [3. Full selection mapping]
257301 const bool hasSelectionMapping = checkSelectionMapping (candidateDataset, _positionDataset);
258302
@@ -729,18 +773,17 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& pointsColor, const std
729773 if (!pointsColor.isValid ())
730774 return ;
731775
776+ const auto numColorPoints = pointsColor->getNumPoints ();
777+
732778 // Generate point scalars for color mapping
733779 std::vector<float > scalars;
734-
735780 pointsColor->extractDataForDimension (scalars, dimensionIndex);
736781
737- const auto numColorPoints = pointsColor->getNumPoints ();
738-
739- // If number of points do not match, prefer checking for derived data over selection mapping
782+ // If number of points do not match, use a mapping
783+ // prefer global IDs (for derived data) over selection mapping
740784 if (numColorPoints != _numPoints) {
741785
742786 try {
743-
744787 const bool hasSameNumPointsAsFull =
745788 /* if*/ _positionDataset->isDerivedData () ?
746789 /* then*/ _positionSourceDataset->getFullDataset <Points>()->getNumPoints () == numColorPoints :
@@ -758,24 +801,43 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& pointsColor, const std
758801
759802 std::swap (localScalars, scalars);
760803 }
761- else if ( // only get map if derived check failed
762- const auto selectionMapping = getSelectionMapping (pointsColor, _positionDataset);
804+ else if ( // mapping from color data set to position data set
805+ const auto selectionMapping = getSelectionMappingColorsToPositions (pointsColor, _positionDataset);
763806 /* check if valid */ selectionMapping.has_value () && selectionMapping.value () != nullptr
764807 )
765808 {
766- std::vector<float > localScalars (_numPoints, 0 );
809+ std::vector<float > mappedScalars (_numPoints, 0 );
767810
768811 // Map values like selection
769- const mv::SelectionMap::Map& linkedMap = selectionMapping.value ()->getMapping ().getMap ();
770- const std::uint32_t numPointsInTarget = _positionDataset->getNumPoints ();
812+ const mv::SelectionMap::Map& linkedMap = selectionMapping.value ()->getMapping ().getMap ();
771813
772- for (const auto & [fromID, vecOfIDs ] : linkedMap) {
773- for (std::uint32_t toID : vecOfIDs ) {
774- localScalars[toID ] = scalars[fromID ];
814+ for (const auto & [fromColorID, vecOfPositionIDs ] : linkedMap) {
815+ for (std::uint32_t toPositionID : vecOfPositionIDs ) {
816+ mappedScalars[toPositionID ] = scalars[fromColorID ];
775817 }
776818 }
777819
778- std::swap (localScalars, scalars);
820+ std::swap (mappedScalars, scalars);
821+ }
822+ else if ( // mapping from position data set to color data set
823+ const auto selectionMapping = getSelectionMappingPositionsToColors (_positionDataset, pointsColor);
824+ /* check if valid */ selectionMapping.has_value () && selectionMapping.value () != nullptr
825+ )
826+ {
827+ std::vector<float > mappedScalars (_numPoints, std::numeric_limits<float >::lowest ());
828+
829+ // Map values like selection (in reverse, use first value that occurs)
830+ const mv::SelectionMap::Map& linkedMap = selectionMapping.value ()->getMapping ().getMap ();
831+
832+ for (const auto & [fromPositionID, vecOfColorIDs] : linkedMap) {
833+ if (mappedScalars[fromPositionID] != std::numeric_limits<float >::lowest ())
834+ continue ;
835+ for (std::uint32_t toColorID : vecOfColorIDs) {
836+ mappedScalars[fromPositionID] = scalars[toColorID];
837+ }
838+ }
839+
840+ std::swap (mappedScalars, scalars);
779841 }
780842 else {
781843 qWarning (" Number of points used for coloring does not match number of points in data, aborting attempt to color plot" );
0 commit comments