Skip to content

Commit 9a4aaaa

Browse files
authored
[PWGEM] Add and update ML-based photon cuts (#15076)
1 parent 6c8cf86 commit 9a4aaaa

File tree

12 files changed

+592
-112
lines changed

12 files changed

+592
-112
lines changed

PWGEM/PhotonMeson/Core/Pi0EtaToGammaGamma.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,6 @@ struct Pi0EtaToGammaGamma {
147147
o2::framework::Configurable<bool> cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"};
148148
o2::framework::Configurable<int> cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"};
149149
o2::framework::Configurable<int> cfg_nclasses_ml{"cfg_nclasses_ml", static_cast<int>(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"};
150-
o2::framework::Configurable<std::string> cfg_cent_type_ml{"cfg_cent_type_ml", "CentFT0C", "centrality type for 2D ML application: CentFT0C, CentFT0M, or CentFT0A"};
151150
o2::framework::Configurable<std::vector<int>> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"};
152151
o2::framework::Configurable<std::vector<std::string>> cfg_input_feature_names{"cfg_input_feature_names", std::vector<std::string>{"feature1", "feature2"}, "input feature names for ML models"};
153152
o2::framework::Configurable<std::vector<std::string>> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector<std::string>{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"};
@@ -508,7 +507,9 @@ struct Pi0EtaToGammaGamma {
508507
fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml);
509508
fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb);
510509
fV0PhotonCut.SetCcdbUrl(ccdburl);
511-
fV0PhotonCut.SetCentralityTypeMl(pcmcuts.cfg_cent_type_ml);
510+
CentType mCentralityTypeMlEnum;
511+
mCentralityTypeMlEnum = static_cast<CentType>(cfgCentEstimator.value);
512+
fV0PhotonCut.SetCentralityTypeMl(mCentralityTypeMlEnum);
512513
fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml);
513514
fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb);
514515
fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names);
@@ -702,7 +703,6 @@ struct Pi0EtaToGammaGamma {
702703
{
703704
for (const auto& collision : collisions) {
704705
initCCDB(collision);
705-
fV0PhotonCut.SetCentrality(collision.centFT0A(), collision.centFT0C(), collision.centFT0M());
706706
int ndiphoton = 0;
707707
if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) {
708708
continue;
@@ -718,6 +718,7 @@ struct Pi0EtaToGammaGamma {
718718
}
719719

720720
const float centralities[3] = {collision.centFT0M(), collision.centFT0A(), collision.centFT0C()};
721+
fV0PhotonCut.SetCentrality(centralities[cfgCentEstimator]);
721722
if (centralities[cfgCentEstimator] < cfgCentMin || cfgCentMax < centralities[cfgCentEstimator]) {
722723
continue;
723724
}

PWGEM/PhotonMeson/Core/Pi0EtaToGammaGammaMC.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ struct Pi0EtaToGammaGammaMC {
136136
o2::framework::Configurable<bool> cfg_load_ml_models_from_ccdb{"cfg_load_ml_models_from_ccdb", true, "flag to load ML models from CCDB"};
137137
o2::framework::Configurable<int> cfg_timestamp_ccdb{"cfg_timestamp_ccdb", -1, "timestamp for CCDB"};
138138
o2::framework::Configurable<int> cfg_nclasses_ml{"cfg_nclasses_ml", static_cast<int>(o2::analysis::em_cuts_ml::NCutScores), "number of classes for ML"};
139-
o2::framework::Configurable<std::string> cfg_cent_type_ml{"cfg_cent_type_ml", "CentFT0C", "centrality type for 2D ML application: CentFT0C, CentFT0M, or CentFT0A"};
140139
o2::framework::Configurable<std::vector<int>> cfg_cut_dir_ml{"cfg_cut_dir_ml", std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}, "cut direction for ML"};
141140
o2::framework::Configurable<std::vector<std::string>> cfg_input_feature_names{"cfg_input_feature_names", std::vector<std::string>{"feature1", "feature2"}, "input feature names for ML models"};
142141
o2::framework::Configurable<std::vector<std::string>> cfg_model_paths_ccdb{"cfg_model_paths_ccdb", std::vector<std::string>{"path_ccdb/BDT_PCM/"}, "CCDB paths for ML models"};
@@ -348,7 +347,9 @@ struct Pi0EtaToGammaGammaMC {
348347
fV0PhotonCut.SetNClassesMl(pcmcuts.cfg_nclasses_ml);
349348
fV0PhotonCut.SetMlTimestampCCDB(pcmcuts.cfg_timestamp_ccdb);
350349
fV0PhotonCut.SetCcdbUrl(ccdburl);
351-
fV0PhotonCut.SetCentralityTypeMl(pcmcuts.cfg_cent_type_ml);
350+
CentType mCentralityTypeMlEnum;
351+
mCentralityTypeMlEnum = static_cast<CentType>(cfgCentEstimator.value);
352+
fV0PhotonCut.SetCentralityTypeMl(mCentralityTypeMlEnum);
352353
fV0PhotonCut.SetCutDirMl(pcmcuts.cfg_cut_dir_ml);
353354
fV0PhotonCut.SetMlModelPathsCCDB(pcmcuts.cfg_model_paths_ccdb);
354355
fV0PhotonCut.SetMlOnnxFileNames(pcmcuts.cfg_onnx_file_names);
@@ -560,7 +561,6 @@ struct Pi0EtaToGammaGammaMC {
560561
{
561562
for (auto& collision : collisions) {
562563
initCCDB(collision);
563-
fV0PhotonCut.SetCentrality(collision.centFT0A(), collision.centFT0C(), collision.centFT0M());
564564
if ((pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPHOSPHOS || pairtype == o2::aod::pwgem::photonmeson::photonpair::PairType::kPCMPHOS) && !collision.alias_bit(triggerAliases::kTVXinPHOS)) {
565565
continue;
566566
}
@@ -575,6 +575,7 @@ struct Pi0EtaToGammaGammaMC {
575575
}
576576

577577
const float centralities[3] = {collision.centFT0M(), collision.centFT0A(), collision.centFT0C()};
578+
fV0PhotonCut.SetCentrality(centralities[cfgCentEstimator]);
578579
if (centralities[cfgCentEstimator] < cfgCentMin || cfgCentMax < centralities[cfgCentEstimator]) {
579580
continue;
580581
}

PWGEM/PhotonMeson/Core/V0PhotonCandidate.h

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,19 @@
2323

2424
#include <KFParticle.h>
2525

26+
enum CentType : uint8_t {
27+
CentFT0M = 0,
28+
CentFT0A = 1,
29+
CentFT0C = 2
30+
};
31+
2632
struct V0PhotonCandidate {
2733

2834
public:
29-
// Constructor for photonconversionbuilder
30-
V0PhotonCandidate(const KFParticle& v0, const KFParticle& pos, const KFParticle& ele, const auto& collision, float cospa, float d_bz) : cospa(cospa)
35+
// Empty Constructor
36+
V0PhotonCandidate() = default;
37+
// Set method for photonconversionbuilder
38+
void setPhotonCandidate(const KFParticle& v0, const KFParticle& pos, const KFParticle& ele, const auto& collision, float cospa, float psipair, float phiv, CentType centType)
3139
{
3240
px = v0.GetPx();
3341
py = v0.GetPy();
@@ -57,18 +65,27 @@ struct V0PhotonCandidate {
5765

5866
alpha = v0_alpha(posPx, posPy, posPz, elePx, elePy, elePz);
5967
qt = v0_qt(posPx, posPy, posPz, elePx, elePy, elePz);
60-
int posSign = (pos.GetQ() > 0) - (pos.GetQ() < 0);
61-
int eleSign = (ele.GetQ() > 0) - (ele.GetQ() < 0);
62-
phiv = o2::aod::pwgem::dilepton::utils::pairutil::getPhivPair(posPx, posPy, posPz, elePx, elePy, elePz, posSign, eleSign, d_bz);
63-
psipair = o2::aod::pwgem::dilepton::utils::pairutil::getPsiPair(posPx, posPy, posPz, elePx, elePy, elePz);
64-
65-
centFT0M = collision.centFT0M();
66-
centFT0C = collision.centFT0C();
67-
centFT0A = collision.centFT0A();
68+
69+
this->cospa = cospa;
70+
this->psipair = psipair;
71+
this->phiv = phiv;
72+
this->centType = centType;
73+
74+
switch (centType) {
75+
case CentType::CentFT0A:
76+
cent = collision.centFT0A();
77+
break;
78+
case CentType::CentFT0C:
79+
cent = collision.centFT0C();
80+
break;
81+
case CentType::CentFT0M:
82+
cent = collision.centFT0M();
83+
break;
84+
}
6885
}
6986

70-
// Constructor for V0PhotonCut
71-
V0PhotonCandidate(const auto& v0, const auto& pos, const auto& ele, float centFT0A, float centFT0C, float centFT0M, float d_bz) : centFT0A(centFT0A), centFT0C(centFT0C), centFT0M(centFT0M)
87+
// Set-Method for V0PhotonCut
88+
void setPhoton(const auto& v0, const auto& pos, const auto& ele, float cent, CentType centType)
7289
{
7390
px = v0.px();
7491
py = v0.py();
@@ -93,9 +110,14 @@ struct V0PhotonCandidate {
93110
cospa = v0.cospa();
94111
alpha = v0.alpha();
95112
qt = v0.qtarm();
96-
97-
phiv = o2::aod::pwgem::dilepton::utils::pairutil::getPhivPair(posPx, posPy, posPz, elePx, elePy, elePz, pos.sign(), ele.sign(), d_bz);
98-
psipair = o2::aod::pwgem::dilepton::utils::pairutil::getPsiPair(posPx, posPy, posPz, elePx, elePy, elePz);
113+
psipair = 999.f; // default if V0PhotonPhiVPsi table is not included
114+
phiv = 999.f; // default if V0PhotonPhiVPsi table is not included
115+
if constexpr (requires { v0.psipair(); v0.phiv(); }) {
116+
psipair = v0.psipair();
117+
phiv = v0.phiv();
118+
}
119+
this->cent = cent;
120+
this->centType = centType;
99121
}
100122

101123
// Getter functions
@@ -119,10 +141,9 @@ struct V0PhotonCandidate {
119141
float getElePx() const { return elePx; }
120142
float getElePy() const { return elePy; }
121143
float getElePz() const { return elePz; }
122-
float getCentFT0M() const { return centFT0M; }
123-
float getCentFT0C() const { return centFT0C; }
124-
float getCentFT0A() const { return centFT0A; }
144+
float getCent() const { return cent; }
125145
float getPCA() const { return pca; }
146+
CentType getCentType() const { return centType; }
126147

127148
private:
128149
float px;
@@ -145,10 +166,9 @@ struct V0PhotonCandidate {
145166
float psipair;
146167
float cospa;
147168
float chi2ndf;
148-
float centFT0A;
149-
float centFT0C;
150-
float centFT0M;
169+
float cent;
151170
float pca;
171+
CentType centType;
152172
};
153173

154174
#endif // PWGEM_PHOTONMESON_CORE_V0PHOTONCANDIDATE_H_

PWGEM/PhotonMeson/Core/V0PhotonCut.cxx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,20 @@ void V0PhotonCut::SetCutsMl(const std::vector<double>& cuts)
306306
void V0PhotonCut::SetNClassesMl(int nClasses)
307307
{
308308
mNClassesMl = nClasses;
309+
mOutputML.reserve(mNClassesMl);
309310
LOG(info) << "V0 Photon Cut, set number of classes ML: " << mNClassesMl;
310311
}
311312

312313
void V0PhotonCut::SetNamesInputFeatures(const std::vector<std::string>& featureNames)
313314
{
314315
mNamesInputFeatures = featureNames;
316+
mMlInputFeatures.reserve(mNamesInputFeatures.size());
315317
LOG(info) << "V0 Photon Cut, set ML input feature names with size:" << mNamesInputFeatures.size();
316318
}
317319

318-
void V0PhotonCut::SetCentrality(float centFT0A, float centFT0C, float centFT0M)
320+
void V0PhotonCut::SetCentrality(float cent)
319321
{
320-
mCentFT0A = centFT0A;
321-
mCentFT0C = centFT0C;
322-
mCentFT0M = centFT0M;
322+
mCent = cent;
323323
}
324324
void V0PhotonCut::SetD_Bz(float d_bz)
325325
{
@@ -332,10 +332,10 @@ void V0PhotonCut::SetCutDirMl(const std::vector<int>& cutDirMl)
332332
LOG(info) << "V0 Photon Cut, set ML cut directions with size:" << mCutDirMl.size();
333333
}
334334

335-
void V0PhotonCut::SetCentralityTypeMl(const std::string& centType)
335+
void V0PhotonCut::SetCentralityTypeMl(CentType centType)
336336
{
337337
mCentralityTypeMl = centType;
338-
LOG(info) << "V0 Photon Cut, set centrality type ML: " << mCentralityTypeMl;
338+
LOG(info) << "V0 Photon Cut, set centrality type ML: " << mCentralityTypeMl << " (0: CentFT0M, 1: CentFT0A, 2: CentFT0C)";
339339
}
340340

341341
void V0PhotonCut::SetLabelsBinsMl(const std::vector<std::string>& labelsBins)

PWGEM/PhotonMeson/Core/V0PhotonCut.h

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
#include <cstdint>
4141
#include <functional>
4242
#include <set>
43+
#include <span>
4344
#include <string>
4445
#include <utility>
4546
#include <vector>
@@ -145,7 +146,7 @@ static const std::vector<std::string> labelsCent = {
145146
"Cent bin 10"};
146147

147148
// column labels
148-
static const std::vector<std::string> labelsCutScore = {"score primary photons", "score background"};
149+
static const std::vector<std::string> labelsCutScore = {"score background", "score primary photons"};
149150
} // namespace em_cuts_ml
150151

151152
} // namespace o2::analysis
@@ -584,30 +585,22 @@ class V0PhotonCut : public TNamed
584585
}
585586
}
586587
if (mApplyMlCuts) {
587-
if (!mEmMlResponse) {
588+
if (mEmMlResponse == nullptr) {
588589
LOG(error) << "EM ML Response is not initialized!";
589590
return false;
590591
}
591-
bool mIsSelectedMl = false;
592-
std::vector<float> mOutputML;
593-
V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCentFT0A, mCentFT0C, mCentFT0M, mD_Bz);
594-
std::vector<float> mlInputFeatures = mEmMlResponse->getInputFeatures(v0photoncandidate, pos, ele);
592+
mIsSelectedMl = false;
593+
mV0PhotonForMl.setPhoton(v0, pos, ele, mCent, mCentralityTypeMl);
594+
mMlInputFeatures = mEmMlResponse->getInputFeatures(mV0PhotonForMl, pos, ele);
595595
if (mUse2DBinning) {
596-
if (mCentralityTypeMl == "CentFT0C") {
597-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0C(), mOutputML);
598-
} else if (mCentralityTypeMl == "CentFT0A") {
599-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0A(), mOutputML);
600-
} else if (mCentralityTypeMl == "CentFT0M") {
601-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0M(), mOutputML);
602-
} else {
603-
LOG(fatal) << "Unsupported centTypePCMMl: " << mCentralityTypeMl << " , please choose from CentFT0C, CentFT0A, CentFT0M.";
604-
}
596+
mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mV0PhotonForMl.getCent(), mOutputML);
605597
} else {
606-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), mOutputML);
598+
mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mOutputML);
607599
}
608600
if (!mIsSelectedMl) {
609601
return false;
610602
}
603+
mMlBDTScores = std::span<float>(mOutputML.data(), mOutputML.size());
611604
}
612605
if (doQA) {
613606
fillAfterPhotonHistogram(v0, pos, ele, fRegistry);
@@ -860,7 +853,7 @@ class V0PhotonCut : public TNamed
860853

861854
void initV0MlModels(o2::ccdb::CcdbApi& ccdbApi)
862855
{
863-
if (!mEmMlResponse) {
856+
if (mEmMlResponse == nullptr) {
864857
mEmMlResponse = new o2::analysis::EmMlResponsePCM<float>();
865858
}
866859
if (mUse2DBinning) {
@@ -914,6 +907,11 @@ class V0PhotonCut : public TNamed
914907
mEmMlResponse->init();
915908
}
916909

910+
const std::span<float> getBDTValue() const
911+
{
912+
return mMlBDTScores;
913+
}
914+
917915
template <o2::soa::is_iterator TMCPhoton>
918916
bool IsConversionPointInAcceptance(TMCPhoton const& mcphoton, float convRadius) const
919917
{
@@ -983,10 +981,10 @@ class V0PhotonCut : public TNamed
983981
void SetLoadMlModelsFromCCDB(bool flag = true);
984982
void SetNClassesMl(int nClasses);
985983
void SetMlTimestampCCDB(int timestamp);
986-
void SetCentrality(float centFT0A, float centFT0C, float centFT0M);
984+
void SetCentralityTypeMl(CentType centType);
985+
void SetCentrality(float cent);
987986
void SetD_Bz(float d_bz);
988987
void SetCcdbUrl(const std::string& url = "http://alice-ccdb.cern.ch");
989-
void SetCentralityTypeMl(const std::string& centType);
990988
void SetCutDirMl(const std::vector<int>& cutDirMl);
991989
void SetMlModelPathsCCDB(const std::vector<std::string>& modelPaths);
992990
void SetMlOnnxFileNames(const std::vector<std::string>& onnxFileNamesVec);
@@ -1026,22 +1024,25 @@ class V0PhotonCut : public TNamed
10261024
bool mLoadMlModelsFromCCDB{true};
10271025
int mTimestampCCDB{-1};
10281026
int mNClassesMl{static_cast<int>(o2::analysis::em_cuts_ml::NCutScores)};
1029-
float mCentFT0A{0.f};
1030-
float mCentFT0C{0.f};
1031-
float mCentFT0M{0.f};
1027+
float mCent{0.f};
10321028
float mD_Bz{0.f};
10331029
std::string mCcdbUrl{"http://alice-ccdb.cern.ch"};
1034-
std::string mCentralityTypeMl{"CentFT0C"};
10351030
std::vector<int> mCutDirMl{std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}};
10361031
std::vector<std::string> mModelPathsCCDB{std::vector<std::string>{"path_ccdb/BDT_PCM/"}};
10371032
std::vector<std::string> mOnnxFileNames{std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}};
10381033
std::vector<std::string> mNamesInputFeatures{std::vector<std::string>{"feature1", "feature2"}};
10391034
std::vector<std::string> mLabelsBinsMl{std::vector<std::string>{"bin 0", "bin 1"}};
1040-
std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{"score primary photons", "score background"}};
1035+
std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{o2::analysis::em_cuts_ml::labelsCutScore}};
10411036
std::vector<double> mBinsPtMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}};
10421037
std::vector<double> mBinsCentMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}};
10431038
std::vector<double> mCutsMlFlat{std::vector<double>{0.5}};
10441039
o2::analysis::EmMlResponsePCM<float>* mEmMlResponse{nullptr};
1040+
mutable bool mIsSelectedMl{false};
1041+
mutable std::vector<float> mOutputML{};
1042+
mutable std::vector<float> mMlInputFeatures{};
1043+
mutable std::span<float> mMlBDTScores{};
1044+
CentType mCentralityTypeMl{CentType::CentFT0C};
1045+
mutable V0PhotonCandidate mV0PhotonForMl;
10451046

10461047
// pid cuts
10471048
float mMinTPCNsigmaEl{-5}, mMaxTPCNsigmaEl{+5};

PWGEM/PhotonMeson/DataModel/gammaTables.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -416,15 +416,16 @@ using EMPrimaryElectronsFromDalitz = EMPrimaryElectronsFromDalitz_001;
416416
// iterators
417417
using EMPrimaryElectronFromDalitz = EMPrimaryElectronsFromDalitz::iterator;
418418

419-
namespace v0photonsphiv
419+
namespace v0photonsphivpsi
420420
{
421421
DECLARE_SOA_INDEX_COLUMN(EMEvent, emevent); //!
422422
DECLARE_SOA_COLUMN(PhiV, phiv, float); //!
423-
} // namespace v0photonsphiv
424-
DECLARE_SOA_TABLE(V0PhotonsPhiV, "AOD", "V0PHOTONPHIV", //!
425-
o2::soa::Index<>, v0photonsphiv::PhiV);
423+
DECLARE_SOA_COLUMN(PsiPair, psipair, float);
424+
} // namespace v0photonsphivpsi
425+
DECLARE_SOA_TABLE(V0PhotonsPhiVPsi, "AOD", "V0PHOTONPHIVPSI", //!
426+
o2::soa::Index<>, v0photonsphivpsi::PhiV, v0photonsphivpsi::PsiPair);
426427
// iterators
427-
using V0PhotonsPhiV = V0PhotonsPhiV;
428+
using V0PhotonsPhiVPsi = V0PhotonsPhiVPsi;
428429

429430
namespace dalitzee
430431
{

0 commit comments

Comments
 (0)