Skip to content

Commit cf548f3

Browse files
committed
Add optional ML-based photon cuts to QC task
1 parent 6dc6981 commit cf548f3

File tree

2 files changed

+341
-27
lines changed

2 files changed

+341
-27
lines changed

PWGEM/PhotonMeson/Tasks/pcmQC.cxx

Lines changed: 170 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
#include "Framework/AnalysisTask.h"
2626
#include "Framework/runDataProcessing.h"
2727
#include <CommonConstants/MathConstants.h>
28+
#include <DataFormatsParameters/GRPMagField.h>
29+
#include <DataFormatsParameters/GRPObject.h>
2830
#include <Framework/Configurable.h>
2931
#include <Framework/HistogramRegistry.h>
3032
#include <Framework/HistogramSpec.h>
@@ -52,10 +54,18 @@ using MyCollision = MyCollisions::iterator;
5254
using MyV0Photons = soa::Join<aod::V0PhotonsKF, aod::V0KFEMEventIds>;
5355
using MyV0Photon = MyV0Photons::iterator;
5456

57+
using MyV0PhotonsML = soa::Join<aod::V0PhotonsKF, aod::V0PhotonsPhiVPsi, aod::V0KFEMEventIds>;
58+
using MyV0PhotonML = MyV0PhotonsML::iterator;
59+
5560
struct PCMQC {
5661
Configurable<int> cfgCentEstimator{"cfgCentEstimator", 2, "FT0M:0, FT0A:1, FT0C:2"};
5762
Configurable<float> cfgCentMin{"cfgCentMin", 0, "min. centrality"};
5863
Configurable<float> cfgCentMax{"cfgCentMax", 999.f, "max. centrality"};
64+
Configurable<std::string> ccdburl{"ccdb-url", "http://alice-ccdb.cern.ch", "url of the ccdb repository"};
65+
Configurable<std::string> grpPath{"grpPath", "GLO/GRP/GRP", "Path of the grp file"};
66+
Configurable<std::string> grpmagPath{"grpmagPath", "GLO/Config/GRPMagField", "CCDB path of the GRPMagField object"};
67+
Configurable<bool> skipGRPOquery{"skipGRPOquery", true, "skip grpo query"};
68+
Configurable<float> d_bz_input{"d_bz_input", -999, "bz field in kG, -999 is automatic"};
5969

6070
EMPhotonEventCut fEMEventCut;
6171
struct : ConfigurableGroup {
@@ -106,8 +116,28 @@ struct PCMQC {
106116
Configurable<float> cfg_max_TPCNsigmaEl{"cfg_max_TPCNsigmaEl", +3.0, "max. TPC n sigma for electron"};
107117
Configurable<bool> cfg_disable_itsonly_track{"cfg_disable_itsonly_track", false, "flag to disable ITSonly tracks"};
108118
Configurable<bool> cfg_disable_tpconly_track{"cfg_disable_tpconly_track", false, "flag to disable TPConly tracks"};
119+
Configurable<bool> cfg_dEdx_postcalibration{"cfg_dEdx_postcalibration", false, "flag to enable dEdx post calibration"};
120+
// for ML cuts
121+
Configurable<bool> cfg_apply_ml_cuts{"cfg_apply_ml", false, "flag to apply ML cut"};
122+
Configurable<bool> cfg_use_2d_binning{"cfg_use_2d_binning", false, "flag to use 2D binning (pT, cent)"};
123+
Configurable<bool> cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"};
124+
Configurable<int> cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"};
125+
Configurable<int> cfg_nclasses_ml{"cfg_nclasses_ml", static_cast<int>(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"};
126+
Configurable<std::vector<int>> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"};
127+
Configurable<std::vector<std::string>> cfg_input_feature_names{"cfg_input_feature_names", std::vector<std::string>{"feature1", "feature2"}, "input feature names for ML models"};
128+
Configurable<std::vector<std::string>> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector<std::string>{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"};
129+
Configurable<std::vector<std::string>> cfg_onnx_file_names{"cfg_onnx_file_names", std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}, "ONNX file names for ML models"};
130+
Configurable<std::vector<std::string>> cfg_labels_bins_ml{"cfg_labels_bins_ml", std::vector<std::string>{"bin 0", "bin 1"}, "Labels for bins"};
131+
Configurable<std::vector<std::string>> cfg_labels_cut_scores_ml{"cfg_labels_cut_scores_ml", std::vector<std::string>{o2::analysis::em_cuts_ml::labelsCutScore}, "Labels for cut scores"};
132+
Configurable<std::vector<double>> cfg_bins_pt_ml{"cfg_bins_pt_ml", std::vector<double>{0.0, +1e+10}, "pT bin limits for ML application"};
133+
Configurable<std::vector<double>> cfg_bins_cent_ml{"cfg_bins_cent_ml", std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}, "centrality bins for ML"};
134+
Configurable<std::vector<double>> cfg_cuts_ml_flat{"cfg_cuts_ml_flat", {0.5}, "Flattened ML cuts: [bin0_score0, bin0_score1, ..., binN_scoreM]"};
109135
} pcmcuts;
110136

137+
o2::ccdb::CcdbApi ccdbApi;
138+
o2::framework::Service<o2::ccdb::BasicCCDBManager> ccdb;
139+
int mRunNumber;
140+
float d_bz;
111141
static constexpr std::string_view event_types[2] = {"before/", "after/"};
112142
HistogramRegistry fRegistry{"output", {}, OutputObjHandlingPolicy::AnalysisObject, false, false};
113143

@@ -116,6 +146,54 @@ struct PCMQC {
116146
addhistograms();
117147
DefineEMEventCut();
118148
DefinePCMCut();
149+
150+
mRunNumber = 0;
151+
d_bz = 0;
152+
153+
ccdb->setURL(ccdburl);
154+
ccdb->setCaching(true);
155+
ccdb->setLocalObjectValidityChecking();
156+
ccdb->setFatalWhenNull(false);
157+
}
158+
159+
template <typename TCollision>
160+
void initCCDB(TCollision const& collision)
161+
{
162+
if (mRunNumber == collision.runNumber()) {
163+
return;
164+
}
165+
166+
// In case override, don't proceed, please - no CCDB access required
167+
if (d_bz_input > -990) {
168+
d_bz = d_bz_input;
169+
o2::parameters::GRPMagField grpmag;
170+
if (std::fabs(d_bz) > 1e-5) {
171+
grpmag.setL3Current(30000.f / (d_bz / 5.0f));
172+
}
173+
mRunNumber = collision.runNumber();
174+
return;
175+
}
176+
177+
auto run3grp_timestamp = collision.timestamp();
178+
o2::parameters::GRPObject* grpo = 0x0;
179+
o2::parameters::GRPMagField* grpmag = 0x0;
180+
if (!skipGRPOquery)
181+
grpo = ccdb->getForTimeStamp<o2::parameters::GRPObject>(grpPath, run3grp_timestamp);
182+
if (grpo) {
183+
// Fetch magnetic field from ccdb for current collision
184+
d_bz = grpo->getNominalL3Field();
185+
LOG(info) << "Retrieved GRP for timestamp " << run3grp_timestamp << " with magnetic field of " << d_bz << " kZG";
186+
} else {
187+
grpmag = ccdb->getForTimeStamp<o2::parameters::GRPMagField>(grpmagPath, run3grp_timestamp);
188+
if (!grpmag) {
189+
LOG(fatal) << "Got nullptr from CCDB for path " << grpmagPath << " of object GRPMagField and " << grpPath << " of object GRPObject for timestamp " << run3grp_timestamp;
190+
}
191+
// Fetch magnetic field from ccdb for current collision
192+
d_bz = std::lround(5.f * grpmag->getL3Current() / 30000.f);
193+
LOG(info) << "Retrieved GRP for timestamp " << run3grp_timestamp << " with magnetic field of " << d_bz << " kZG";
194+
}
195+
fV0PhotonCut.SetD_Bz(d_bz);
196+
mRunNumber = collision.runNumber();
119197
}
120198

121199
void addhistograms()
@@ -164,6 +242,22 @@ struct PCMQC {
164242
fRegistry.add("V0/hKFChi2vsZ", "KF chi2 vs. conversion point in Z;Z (cm);KF chi2/NDF", kTH2F, {{200, -100.0f, 100.0f}, {100, 0.f, 100.0f}}, false);
165243
fRegistry.add("V0/hsConvPoint", "photon conversion point;r_{xy} (cm);#varphi (rad.);#eta;", kTHnSparseF, {{100, 0.0f, 100}, {90, 0, o2::constants::math::TwoPI}, {80, -2, +2}}, false);
166244
fRegistry.add("V0/hNgamma", "Number of #gamma candidates per collision", kTH1F, {{101, -0.5f, 100.5f}});
245+
246+
if (pcmcuts.cfg_apply_ml_cuts) {
247+
if (pcmcuts.cfg_nclasses_ml == 2) {
248+
fRegistry.add("V0/hBDTBackgroundScoreVsPt", "BDT background score vs pT; pT (GeV/c); BDT background score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
249+
fRegistry.add("V0/hBDTSignalScoreVsPt", "BDT signal score vs pT; pT (GeV/c); BDT signal score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
250+
fRegistry.add("V0/hPhiVPsi", "#varphi vs. #psi angle;#psi (rad.); #varphi (rad.)", kTH2F, {{200, -o2::constants::math::PI, o2::constants::math::PI}, {200, 0, o2::constants::math::TwoPI}}, false);
251+
} else if (pcmcuts.cfg_nclasses_ml == 3) {
252+
fRegistry.add("V0/hBDTBackgroundScoreVsPt", "BDT background score vs pT; pT (GeV/c); BDT background score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
253+
fRegistry.add("V0/hBDTPrimaryPhotonScoreVsPt", "BDT primary photon score vs pT; pT (GeV/c); BDT primary photon score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
254+
fRegistry.add("V0/hBDTSecondaryPhotonScoreVsPt", "BDT secondary photon score vs pT; pT (GeV/c); BDT secondary photon score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
255+
fRegistry.add("V0/hPhiVPsi", "#varphi vs. #psi angle;#psi (rad.); #varphi (rad.)", kTH2F, {{200, -o2::constants::math::PI, o2::constants::math::PI}, {200, 0, o2::constants::math::TwoPI}}, false);
256+
} else {
257+
fRegistry.add("V0/hBDTScoreVsPt", "BDT score vs pT; pT (GeV/c); BDT score", {HistType::kTH2F, {{1000, 0.0f, 20.0f}, {1000, 0.0f, 1.0f}}});
258+
fRegistry.add("V0/hPhiVPsi", "#varphi vs. #psi angle;#psi (rad.); #varphi (rad.)", kTH2F, {{200, -o2::constants::math::PI, o2::constants::math::PI}, {200, 0, o2::constants::math::TwoPI}}, false);
259+
}
260+
}
167261

168262
// v0leg info
169263
fRegistry.add("V0Leg/hPt", "pT;p_{T,e} (GeV/c)", kTH1F, {{1000, 0.0f, 10}}, false);
@@ -183,7 +277,11 @@ struct PCMQC {
183277
fRegistry.add("V0Leg/hChi2ITS", "chi2/number of ITS clusters", kTH1F, {{100, 0, 10}}, false);
184278
fRegistry.add("V0Leg/hITSClusterMap", "ITS cluster map", kTH1F, {{128, -0.5, 127.5}}, false);
185279
fRegistry.add("V0Leg/hMeanClusterSizeITS", "mean cluster size ITS;<cluster size> on ITS #times cos(#lambda)", kTH2F, {{1000, 0, 10}, {160, 0, 16}}, false);
186-
// fRegistry.add("V0Leg/hXY", "X vs. Y;X (cm);Y (cm)", kTH2F, {{100, 0, 100}, {80, -20, 20}}, false);
280+
if (pcmcuts.cfg_dEdx_postcalibration) {
281+
fRegistry.add("V0Leg/hPvsConvPointvsTPCNsigmaElvsEta_Pos", "momentum of pos leg vs. conversion point of V0 vs. TPC n sigma pos vs. eta of pos leg; p (GeV/c); r_{xy} (cm); n #sigma_{e}^{TPC}; #eta", kTHnSparseF, {{200, 0, 20}, {100, 0, 100}, {500, -5, 5}, {200, -1, +1}}, false);
282+
fRegistry.add("V0Leg/hPvsConvPointvsTPCNsigmaElvsEta_Ele", "momentum of neg leg vs. conversion point of V0 vs. TPC n sigma el vs. eta of neg leg; p (GeV/c); r_{xy} (cm); n #sigma_{e}^{TPC}; #eta", kTHnSparseF, {{200, 0, 20}, {100, 0, 100}, {500, -5, 5}, {200, -1, +1}}, false);
283+
}
284+
// fRegistry.add("V0Leg/hXY", "X vs. Y;X (cm);Y (cm)", kTH2F, {{100, 0, 100}, {80, -20, 20}}, false);
187285
// fRegistry.add("V0Leg/hZX", "Z vs. X;Z (cm);X (cm)", kTH2F, {{200, -100, 100}, {100, 0, 100}}, false);
188286
// fRegistry.add("V0Leg/hZY", "Z vs. Y;Z (cm);Y (cm)", kTH2F, {{200, -100, 100}, {80, -20, 20}}, false);
189287
}
@@ -235,6 +333,31 @@ struct PCMQC {
235333
fV0PhotonCut.SetRequireITSTPC(pcmcuts.cfg_require_v0_with_itstpc);
236334
fV0PhotonCut.SetRequireITSonly(pcmcuts.cfg_require_v0_with_itsonly);
237335
fV0PhotonCut.SetRequireTPConly(pcmcuts.cfg_require_v0_with_tpconly);
336+
337+
// for ML
338+
fV0PhotonCut.SetApplyMlCuts(pcmcuts.cfg_apply_ml_cuts);
339+
fV0PhotonCut.SetUse2DBinning(pcmcuts.cfg_use_2d_binning);
340+
fV0PhotonCut.SetLoadMlModelsFromCCDB(pcmcuts.cfg_load_ml_models_from_ccdb);
341+
fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml);
342+
fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb);
343+
fV0PhotonCut.SetCcdbUrl(ccdburl);
344+
CentType mCentralityTypeMlEnum;
345+
mCentralityTypeMlEnum = static_cast<CentType>(cfgCentEstimator.value);
346+
fV0PhotonCut.SetCentralityTypeMl(mCentralityTypeMlEnum);
347+
fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml);
348+
fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb);
349+
fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names);
350+
fV0PhotonCut.SetBinsPtMl(pcmcuts.cfg_bins_pt_ml);
351+
fV0PhotonCut.SetBinsCentMl(pcmcuts.cfg_bins_cent_ml);
352+
fV0PhotonCut.SetCutsMl(pcmcuts.cfg_cuts_ml_flat);
353+
fV0PhotonCut.SetNamesInputFeatures(pcmcuts.cfg_input_feature_names);
354+
fV0PhotonCut.SetLabelsBinsMl(pcmcuts.cfg_labels_bins_ml);
355+
fV0PhotonCut.SetLabelsCutScoresMl(pcmcuts.cfg_labels_cut_scores_ml);
356+
fV0PhotonCut.SetD_Bz(0.0f); // dummy value -> only for psi_pair calculation
357+
358+
if (pcmcuts.cfg_apply_ml_cuts) {
359+
fV0PhotonCut.initV0MlModels(ccdbApi);
360+
}
238361
}
239362

240363
template <const int ev_id, typename TCollision>
@@ -302,6 +425,28 @@ struct PCMQC {
302425
o2::math_utils::bringTo02Pi(phi_cp);
303426
float eta_cp = std::atanh(v0.vz() / std::sqrt(std::pow(v0.vx(), 2) + std::pow(v0.vy(), 2) + std::pow(v0.vz(), 2)));
304427
fRegistry.fill(HIST("V0/hsConvPoint"), v0.v0radius(), phi_cp, eta_cp);
428+
429+
// BDT response histogram can be filled here when apply BDT is true
430+
if (pcmcuts.cfg_apply_ml_cuts) {
431+
const std::span<const float>& bdtValue = fV0PhotonCut.getBDTValue();
432+
float psipair = 999.f;
433+
float phiv = 999.f;
434+
if constexpr( requires{ v0.psipair(); v0.phiv(); } ) {
435+
psipair = v0.psipair();
436+
phiv = v0.phiv();
437+
}
438+
fRegistry.fill(HIST("V0/hPhiVPsi"), psipair, phiv);
439+
if (pcmcuts.cfg_nclasses_ml == 2 && bdtValue.size() == 2) {
440+
fRegistry.fill(HIST("V0/hBDTBackgroundScoreVsPt"), v0.pt(), bdtValue[0]);
441+
fRegistry.fill(HIST("V0/hBDTSignalScoreVsPt"), v0.pt(), bdtValue[1]);
442+
} else if (pcmcuts.cfg_nclasses_ml == 3 && bdtValue.size() == 3) {
443+
fRegistry.fill(HIST("V0/hBDTBackgroundScoreVsPt"), v0.pt(), bdtValue[0]);
444+
fRegistry.fill(HIST("V0/hBDTPrimaryPhotonScoreVsPt"), v0.pt(), bdtValue[1]);
445+
fRegistry.fill(HIST("V0/hBDTSecondaryPhotonScoreVsPt"), v0.pt(), bdtValue[2]);
446+
} else if (bdtValue.size() == 1) {
447+
fRegistry.fill(HIST("V0/hBDTCutVsPt"), v0.pt(), bdtValue[0]);
448+
}
449+
}
305450
}
306451

307452
template <typename TLeg>
@@ -331,15 +476,17 @@ struct PCMQC {
331476
// fRegistry.fill(HIST("V0Leg/hZY"), leg.z(), leg.y());
332477
}
333478

334-
Preslice<MyV0Photons> perCollision = aod::v0photonkf::emeventId;
479+
o2::framework::SliceCache v0cache;
335480
Filter collisionFilter_centrality = (cfgCentMin < o2::aod::cent::centFT0M && o2::aod::cent::centFT0M < cfgCentMax) || (cfgCentMin < o2::aod::cent::centFT0A && o2::aod::cent::centFT0A < cfgCentMax) || (cfgCentMin < o2::aod::cent::centFT0C && o2::aod::cent::centFT0C < cfgCentMax);
336481
Filter collisionFilter_occupancy_track = eventcuts.cfgTrackOccupancyMin <= o2::aod::evsel::trackOccupancyInTimeRange && o2::aod::evsel::trackOccupancyInTimeRange < eventcuts.cfgTrackOccupancyMax;
337482
Filter collisionFilter_occupancy_ft0c = eventcuts.cfgFT0COccupancyMin <= o2::aod::evsel::ft0cOccupancyInTimeRange && o2::aod::evsel::ft0cOccupancyInTimeRange < eventcuts.cfgFT0COccupancyMax;
338483
using FilteredMyCollisions = soa::Filtered<MyCollisions>;
339484

340-
void processQC(FilteredMyCollisions const& collisions, MyV0Photons const& v0photons, aod::V0Legs const&)
485+
template <typename TV0Photon>
486+
void process(FilteredMyCollisions const& collisions, TV0Photon const& v0photons, aod::V0Legs const& v0legs)
341487
{
342-
for (auto& collision : collisions) {
488+
for (const auto& collision : collisions) {
489+
initCCDB(collision);
343490
const float centralities[3] = {collision.centFT0M(), collision.centFT0A(), collision.centFT0C()};
344491
if (centralities[cfgCentEstimator] < cfgCentMin || cfgCentMax < centralities[cfgCentEstimator]) {
345492
continue;
@@ -353,11 +500,12 @@ struct PCMQC {
353500
fRegistry.fill(HIST("Event/before/hCollisionCounter"), 10.0); // accepted
354501
fRegistry.fill(HIST("Event/after/hCollisionCounter"), 10.0); // accepted
355502

503+
fV0PhotonCut.SetCentrality(centralities[cfgCentEstimator]);
356504
int nv0 = 0;
357-
auto v0photons_coll = v0photons.sliceBy(perCollision, collision.globalIndex());
358-
for (auto& v0 : v0photons_coll) {
359-
auto pos = v0.posTrack_as<aod::V0Legs>();
360-
auto ele = v0.negTrack_as<aod::V0Legs>();
505+
auto v0photons_coll = v0photons.sliceByCached(aod::v0photonkf::emeventId, collision.globalIndex(), v0cache);
506+
for (const auto& v0 : v0photons_coll) {
507+
auto pos = v0.template posTrack_as<aod::V0Legs>();
508+
auto ele = v0.template negTrack_as<aod::V0Legs>();
361509

362510
if (!fV0PhotonCut.IsSelected<decltype(v0), aod::V0Legs>(v0)) {
363511
continue;
@@ -366,15 +514,29 @@ struct PCMQC {
366514
for (auto& leg : {pos, ele}) {
367515
fillV0LegInfo(leg);
368516
}
517+
if (pcmcuts.cfg_dEdx_postcalibration) {
518+
fRegistry.fill(HIST("V0Leg/hPvsConvPointvsTPCNsigmaElvsEta_Pos"), pos.p(), v0.v0radius(), pos.tpcNSigmaEl(), pos.eta());
519+
fRegistry.fill(HIST("V0Leg/hPvsConvPointvsTPCNsigmaElvsEta_Ele"), ele.p(), v0.v0radius(), ele.tpcNSigmaEl(), ele.eta());
520+
}
369521
nv0++;
370522
} // end of v0 loop
371523
fRegistry.fill(HIST("V0/hNgamma"), nv0);
372524
} // end of collision loop
525+
}
526+
/// Process function for the standard PCM QC (enabled by default via PROCESS_SWITCH).
/// Forwards the MyV0Photons join to the shared templated implementation.
void processQC(FilteredMyCollisions const& collisions, MyV0Photons const& v0photons, aod::V0Legs const& v0legs)
{
  process<MyV0Photons>(collisions, v0photons, v0legs);
} // end of process
374530

531+
/// Process function for the PCM QC with ML (disabled by default via PROCESS_SWITCH).
/// Forwards the MyV0PhotonsML join, which additionally carries the
/// V0PhotonsPhiVPsi columns, to the shared templated implementation.
void processQCML(FilteredMyCollisions const& collisions, MyV0PhotonsML const& v0photonsML, aod::V0Legs const& v0legs)
{
  process<MyV0PhotonsML>(collisions, v0photonsML, v0legs);
} // end of ML process
535+
375536
// No-op process function: takes the collision table and does nothing with it.
void processDummy(MyCollisions const&) {}
376537

377538
PROCESS_SWITCH(PCMQC, processQC, "run PCM QC", true);
539+
PROCESS_SWITCH(PCMQC, processQCML, "run PCM QC with ML", false);
378540
PROCESS_SWITCH(PCMQC, processDummy, "Dummy function", false);
379541
};
380542

0 commit comments

Comments
 (0)