Skip to content

Commit 5fc04ca

Browse files
committed
Add Reverse mapping as well
1 parent ecd5b2e commit 5fc04ca

1 file changed

Lines changed: 105 additions & 43 deletions

File tree

src/ScatterplotPlugin.cpp

Lines changed: 105 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include <algorithm>
3333
#include <cassert>
3434
#include <exception>
35+
#include <functional>
3536
#include <map>
3637
#include <optional>
3738
#include <ranges>
@@ -45,48 +46,67 @@ Q_PLUGIN_METADATA(IID "studio.manivault.ScatterplotPlugin")
4546
using namespace mv;
4647
using namespace mv::util;
4748

48-
// returns a selection map between source and target
49-
static std::optional<const mv::LinkedData*> getSelectionMapping(const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
49+
// This only checks the immedeate parent and is deliberately not recursive
50+
// We might consider the latter in the future, but might need to cover edge cases
51+
static bool parentHasSameNumPoints(const mv::Dataset<DatasetImpl> data, const mv::Dataset<Points>& other) {
52+
if (data->isDerivedData()) {
53+
const auto parent = data->getParent();
54+
if (parent->getDataType() == PointType) {
55+
const auto parentPoints = mv::Dataset<Points>(parent);
56+
return parentPoints->getNumPoints() == other->getNumPoints();
57+
}
58+
}
59+
return false;
60+
}
61+
62+
using CheckFunc = std::function<bool(const mv::LinkedData& linkedData, const mv::Dataset<Points>& target)>;
63+
64+
static std::optional<const mv::LinkedData*> getSelectionMapping(const mv::Dataset<Points>& source, const mv::Dataset<Points>& target, CheckFunc checkMapping) {
5065
const std::vector<mv::LinkedData>& linkedDatas = source->getLinkedData();
5166

52-
// find linked data between source and target OR source and target's parent, if target is derived and they have the same number of points
53-
const auto it = std::ranges::find_if(linkedDatas, [&target](const mv::LinkedData& obj) {
54-
55-
// This only checks the immedeate parent and is deliberately not recursive
56-
// We might consider the latter in the future, but would need to cover more edge cases
57-
auto isParentOf = [&target](const mv::Dataset<Points>& linkedTarget) -> bool {
58-
if (target->isDerivedData()) {
59-
const auto parent = target->getParent();
60-
if (parent->getDataType() == PointType) {
61-
const auto parentPoints = mv::Dataset<Points>(parent);
62-
return parentPoints->getNumPoints() == target->getNumPoints();
63-
}
64-
}
65-
return false;
66-
};
67+
if (linkedDatas.empty())
68+
return std::nullopt;
6769

68-
return obj.getTargetDataset() == target || isParentOf(obj.getTargetDataset());
70+
// find linked data between source and target OR source and target's parent, if target is derived and they have the same number of points
71+
const auto it = std::ranges::find_if(linkedDatas, [&target, &checkMapping](const mv::LinkedData& linkedData) -> bool {
72+
return checkMapping(linkedData, target);
6973
});
7074

7175
if (it != linkedDatas.end()) {
7276
return &(*it); // return pointer to the found object
7377
}
7478

7579
return std::nullopt; // nothing found
80+
7681
}
7782

78-
// returns whether there is a selection map from source to target that covers all elements in the target
79-
static bool checkSelectionMapping(const mv::Dataset<Points>& source, const mv::Dataset<Points>& target) {
83+
static std::optional<const mv::LinkedData*> getSelectionMappingColorsToPositions(const mv::Dataset<Points>& colors, const mv::Dataset<Points>& positions) {
84+
auto testTargetAndParent = [](const mv::LinkedData& linkedData, const mv::Dataset<Points>& positions) -> bool {
85+
const Dataset<DatasetImpl> mapTargetData = linkedData.getTargetDataset();
86+
return mapTargetData == positions || parentHasSameNumPoints(mapTargetData, positions);
87+
};
88+
89+
return getSelectionMapping(colors, positions, testTargetAndParent);
90+
}
8091

81-
// Check if there is a mapping
82-
const auto it = getSelectionMapping(source, target);
92+
static std::optional<const mv::LinkedData*> getSelectionMappingPositionsToColors(const mv::Dataset<Points>& positions, const mv::Dataset<Points>& colors) {
93+
94+
auto testTarget = [](const mv::LinkedData& linkedData, const mv::Dataset<Points>& colors) -> bool {
95+
return linkedData.getTargetDataset() == colors;
96+
};
8397

84-
if (!it.has_value() || it.value() == nullptr)
85-
return false;
98+
auto mapping = getSelectionMapping(positions, colors, testTarget);
8699

100+
if (!mapping.has_value() && parentHasSameNumPoints(positions, colors)) {
101+
mapping = getSelectionMapping(positions->getParent<Points>(), colors, testTarget);
102+
}
103+
104+
return mapping;
105+
}
106+
107+
static bool checkSurjectiveMapping(const mv::LinkedData& linkedData, const std::uint32_t numPointsInTarget) {
87108
// Check if the mapping is surjective, i.e. hits all elements in the target
88-
const std::map<std::uint32_t, std::vector<std::uint32_t>>& linkedMap = it.value()->getMapping().getMap();
89-
const std::uint32_t numPointsInTarget = target->getNumPoints();
109+
const std::map<std::uint32_t, std::vector<std::uint32_t>>& linkedMap = linkedData.getMapping().getMap();
90110

91111
std::vector<bool> found(numPointsInTarget, false);
92112
std::uint32_t count = 0;
@@ -97,13 +117,34 @@ static bool checkSelectionMapping(const mv::Dataset<Points>& source, const mv::D
97117

98118
if (!found[val]) {
99119
found[val] = true;
100-
if (++count == numPointsInTarget)
120+
if (++count == numPointsInTarget)
101121
return true;
102122
}
103123
}
104124
}
105125

106-
return false; // The previous loop would have returned early if the entire taget set was covered
126+
return false; // The previous loop would have returned early if the entire taget set was covered
127+
}
128+
129+
// returns whether there is a selection map from source to target that covers all elements in the target
130+
static bool checkSelectionMapping(const mv::Dataset<Points>& colors, const mv::Dataset<Points>& positions) {
131+
132+
// Check if there is a mapping
133+
auto mapping = getSelectionMappingColorsToPositions(colors, positions);
134+
auto numTargetPoints = positions->getNumPoints();
135+
136+
if (!mapping.has_value() || mapping.value() == nullptr) {
137+
138+
mapping = getSelectionMappingPositionsToColors(positions, colors);
139+
numTargetPoints = colors->getNumPoints();
140+
141+
if (!mapping.has_value() || mapping.value() == nullptr)
142+
return false;
143+
}
144+
145+
const bool mappingCoversData = checkSurjectiveMapping(*(mapping.value()), numTargetPoints);
146+
147+
return mappingCoversData;
107148
}
108149

109150
ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
@@ -240,7 +281,7 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
240281
// Accept for recoloring:
241282
// 1. data with the same number of points
242283
// 2. data which is derived from a parent that has the same number of points (e.g. for HSNE embeddings), where we can use global indices for mapping
243-
// 3. data which has a fully-covering selection mapping, that we can use for setting colors
284+
// 3. data which has a fully-covering selection mapping to the position data (or it's parent), that we can use for setting colors
244285

245286
// [1. Same number of points]
246287
const auto numPointsCandidate = candidateDataset->getNumPoints();
@@ -253,6 +294,9 @@ ScatterplotPlugin::ScatterplotPlugin(const PluginFactory* factory) :
253294
/*then*/ _positionDataset->getSourceDataset<Points>()->getFullDataset<Points>()->getNumPoints() == numPointsCandidate :
254295
/*else*/ false;
255296

297+
const auto& l1 = candidateDataset->getLinkedData();
298+
const auto& l2 = _positionDataset->getLinkedData();
299+
256300
// [3. Full selection mapping]
257301
const bool hasSelectionMapping = checkSelectionMapping(candidateDataset, _positionDataset);
258302

@@ -729,18 +773,17 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& pointsColor, const std
729773
if (!pointsColor.isValid())
730774
return;
731775

776+
const auto numColorPoints = pointsColor->getNumPoints();
777+
732778
// Generate point scalars for color mapping
733779
std::vector<float> scalars;
734-
735780
pointsColor->extractDataForDimension(scalars, dimensionIndex);
736781

737-
const auto numColorPoints = pointsColor->getNumPoints();
738-
739-
// If number of points do not match, prefer checking for derived data over selection mapping
782+
// If number of points do not match, use a mapping
783+
// prefer global IDs (for derived data) over selection mapping
740784
if (numColorPoints != _numPoints) {
741785

742786
try {
743-
744787
const bool hasSameNumPointsAsFull =
745788
/*if*/ _positionDataset->isDerivedData() ?
746789
/*then*/ _positionSourceDataset->getFullDataset<Points>()->getNumPoints() == numColorPoints :
@@ -758,24 +801,43 @@ void ScatterplotPlugin::loadColors(const Dataset<Points>& pointsColor, const std
758801

759802
std::swap(localScalars, scalars);
760803
}
761-
else if ( // only get map if derived check failed
762-
const auto selectionMapping = getSelectionMapping(pointsColor, _positionDataset);
804+
else if ( // mapping from color data set to position data set
805+
const auto selectionMapping = getSelectionMappingColorsToPositions(pointsColor, _positionDataset);
763806
/* check if valid */ selectionMapping.has_value() && selectionMapping.value() != nullptr
764807
)
765808
{
766-
std::vector<float> localScalars(_numPoints, 0);
809+
std::vector<float> mappedScalars(_numPoints, 0);
767810

768811
// Map values like selection
769-
const mv::SelectionMap::Map& linkedMap = selectionMapping.value()->getMapping().getMap();
770-
const std::uint32_t numPointsInTarget = _positionDataset->getNumPoints();
812+
const mv::SelectionMap::Map& linkedMap = selectionMapping.value()->getMapping().getMap();
771813

772-
for (const auto& [fromID, vecOfIDs] : linkedMap) {
773-
for (std::uint32_t toID : vecOfIDs) {
774-
localScalars[toID] = scalars[fromID];
814+
for (const auto& [fromColorID, vecOfPositionIDs] : linkedMap) {
815+
for (std::uint32_t toPositionID : vecOfPositionIDs) {
816+
mappedScalars[toPositionID] = scalars[fromColorID];
775817
}
776818
}
777819

778-
std::swap(localScalars, scalars);
820+
std::swap(mappedScalars, scalars);
821+
}
822+
else if ( // mapping from position data set to color data set
823+
const auto selectionMapping = getSelectionMappingPositionsToColors(_positionDataset, pointsColor);
824+
/* check if valid */ selectionMapping.has_value() && selectionMapping.value() != nullptr
825+
)
826+
{
827+
std::vector<float> mappedScalars(_numPoints, std::numeric_limits<float>::lowest());
828+
829+
// Map values like selection (in reverse, use first value that occurs)
830+
const mv::SelectionMap::Map& linkedMap = selectionMapping.value()->getMapping().getMap();
831+
832+
for (const auto& [fromPositionID, vecOfColorIDs] : linkedMap) {
833+
if (mappedScalars[fromPositionID] != std::numeric_limits<float>::lowest())
834+
continue;
835+
for (std::uint32_t toColorID : vecOfColorIDs) {
836+
mappedScalars[fromPositionID] = scalars[toColorID];
837+
}
838+
}
839+
840+
std::swap(mappedScalars, scalars);
779841
}
780842
else {
781843
qWarning("Number of points used for coloring does not match number of points in data, aborting attempt to color plot");

0 commit comments

Comments
 (0)