Skip to content

Commit 6dc6981

Browse files
committed
Update ML-based photon cuts to be able to get ML score in tasks. Update initialisation and filling of V0PhotonCandidate.
1 parent c1d708a commit 6dc6981

File tree

3 files changed

+72
-51
lines changed

3 files changed

+72
-51
lines changed

PWGEM/PhotonMeson/Core/V0PhotonCandidate.h

Lines changed: 41 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,19 @@
2323

2424
#include <KFParticle.h>
2525

26+
enum CentType : uint8_t {
27+
CentFT0M = 0,
28+
CentFT0A = 1,
29+
CentFT0C = 2
30+
};
31+
2632
struct V0PhotonCandidate {
2733

2834
public:
29-
// Constructor for photonconversionbuilder
30-
V0PhotonCandidate(const KFParticle& v0, const KFParticle& pos, const KFParticle& ele, const auto& collision, float cospa, float d_bz) : cospa(cospa)
35+
// Empty Constructor
36+
V0PhotonCandidate() = default;
37+
// Set method for photonconversionbuilder
38+
void setPhotonCandidate(const KFParticle& v0, const KFParticle& pos, const KFParticle& ele, const auto& collision, float cospa, float psipair, float phiv, CentType centType)
3139
{
3240
px = v0.GetPx();
3341
py = v0.GetPy();
@@ -57,18 +65,27 @@ struct V0PhotonCandidate {
5765

5866
alpha = v0_alpha(posPx, posPy, posPz, elePx, elePy, elePz);
5967
qt = v0_qt(posPx, posPy, posPz, elePx, elePy, elePz);
60-
int posSign = (pos.GetQ() > 0) - (pos.GetQ() < 0);
61-
int eleSign = (ele.GetQ() > 0) - (ele.GetQ() < 0);
62-
phiv = o2::aod::pwgem::dilepton::utils::pairutil::getPhivPair(posPx, posPy, posPz, elePx, elePy, elePz, posSign, eleSign, d_bz);
63-
psipair = o2::aod::pwgem::dilepton::utils::pairutil::getPsiPair(posPx, posPy, posPz, elePx, elePy, elePz);
64-
65-
centFT0M = collision.centFT0M();
66-
centFT0C = collision.centFT0C();
67-
centFT0A = collision.centFT0A();
68+
69+
this->cospa = cospa;
70+
this->psipair = psipair;
71+
this->phiv = phiv;
72+
this->centType = centType;
73+
74+
switch (centType) {
75+
case CentType::CentFT0A:
76+
cent = collision.centFT0A();
77+
break;
78+
case CentType::CentFT0C:
79+
cent = collision.centFT0C();
80+
break;
81+
case CentType::CentFT0M:
82+
cent = collision.centFT0M();
83+
break;
84+
}
6885
}
6986

70-
// Constructor for V0PhotonCut
71-
V0PhotonCandidate(const auto& v0, const auto& pos, const auto& ele, float centFT0A, float centFT0C, float centFT0M, float d_bz) : centFT0A(centFT0A), centFT0C(centFT0C), centFT0M(centFT0M)
87+
// Set-Method for V0PhotonCut
88+
void setPhoton(const auto& v0, const auto& pos, const auto& ele, float cent, CentType centType)
7289
{
7390
px = v0.px();
7491
py = v0.py();
@@ -93,9 +110,14 @@ struct V0PhotonCandidate {
93110
cospa = v0.cospa();
94111
alpha = v0.alpha();
95112
qt = v0.qtarm();
96-
97-
phiv = o2::aod::pwgem::dilepton::utils::pairutil::getPhivPair(posPx, posPy, posPz, elePx, elePy, elePz, pos.sign(), ele.sign(), d_bz);
98-
psipair = o2::aod::pwgem::dilepton::utils::pairutil::getPsiPair(posPx, posPy, posPz, elePx, elePy, elePz);
113+
psipair = 999.f; // default if V0PhotonPhiVPsi table is not included
114+
phiv = 999.f; // default if V0PhotonPhiVPsi table is not included
115+
if constexpr( requires{ v0.psipair(); v0.phiv(); } ) {
116+
psipair = v0.psipair();
117+
phiv = v0.phiv();
118+
}
119+
this->cent = cent;
120+
this->centType = centType;
99121
}
100122

101123
// Getter functions
@@ -119,10 +141,9 @@ struct V0PhotonCandidate {
119141
float getElePx() const { return elePx; }
120142
float getElePy() const { return elePy; }
121143
float getElePz() const { return elePz; }
122-
float getCentFT0M() const { return centFT0M; }
123-
float getCentFT0C() const { return centFT0C; }
124-
float getCentFT0A() const { return centFT0A; }
144+
float getCent() const { return cent; }
125145
float getPCA() const { return pca; }
146+
CentType getCentType() const { return centType; }
126147

127148
private:
128149
float px;
@@ -145,10 +166,9 @@ struct V0PhotonCandidate {
145166
float psipair;
146167
float cospa;
147168
float chi2ndf;
148-
float centFT0A;
149-
float centFT0C;
150-
float centFT0M;
169+
float cent;
151170
float pca;
171+
CentType centType;
152172
};
153173

154174
#endif // PWGEM_PHOTONMESON_CORE_V0PHOTONCANDIDATE_H_

PWGEM/PhotonMeson/Core/V0PhotonCut.cxx

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -306,20 +306,20 @@ void V0PhotonCut::SetCutsMl(const std::vector<double>& cuts)
306306
void V0PhotonCut::SetNClassesMl(int nClasses)
307307
{
308308
mNClassesMl = nClasses;
309+
mOutputML.reserve(mNClassesMl);
309310
LOG(info) << "V0 Photon Cut, set number of classes ML: " << mNClassesMl;
310311
}
311312

312313
void V0PhotonCut::SetNamesInputFeatures(const std::vector<std::string>& featureNames)
313314
{
314315
mNamesInputFeatures = featureNames;
316+
mMlInputFeatures.reserve(mNamesInputFeatures.size());
315317
LOG(info) << "V0 Photon Cut, set ML input feature names with size:" << mNamesInputFeatures.size();
316318
}
317319

318-
void V0PhotonCut::SetCentrality(float centFT0A, float centFT0C, float centFT0M)
320+
void V0PhotonCut::SetCentrality(float cent)
319321
{
320-
mCentFT0A = centFT0A;
321-
mCentFT0C = centFT0C;
322-
mCentFT0M = centFT0M;
322+
mCent = cent;
323323
}
324324
void V0PhotonCut::SetD_Bz(float d_bz)
325325
{
@@ -332,10 +332,10 @@ void V0PhotonCut::SetCutDirMl(const std::vector<int>& cutDirMl)
332332
LOG(info) << "V0 Photon Cut, set ML cut directions with size:" << mCutDirMl.size();
333333
}
334334

335-
void V0PhotonCut::SetCentralityTypeMl(const std::string& centType)
335+
void V0PhotonCut::SetCentralityTypeMl(CentType centType)
336336
{
337337
mCentralityTypeMl = centType;
338-
LOG(info) << "V0 Photon Cut, set centrality type ML: " << mCentralityTypeMl;
338+
LOG(info) << "V0 Photon Cut, set centrality type ML: " << mCentralityTypeMl << " (0: CentFT0M, 1: CentFT0A, 2: CentFT0C)";
339339
}
340340

341341
void V0PhotonCut::SetLabelsBinsMl(const std::vector<std::string>& labelsBins)

PWGEM/PhotonMeson/Core/V0PhotonCut.h

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
#include <string>
4343
#include <utility>
4444
#include <vector>
45+
#include <span>
4546

4647
namespace o2::analysis
4748
{
@@ -144,7 +145,7 @@ static const std::vector<std::string> labelsCent = {
144145
"Cent bin 10"};
145146

146147
// column labels
147-
static const std::vector<std::string> labelsCutScore = {"score primary photons", "score background"};
148+
static const std::vector<std::string> labelsCutScore = {"score background", "score primary photons"};
148149
} // namespace em_cuts_ml
149150

150151
} // namespace o2::analysis
@@ -573,30 +574,22 @@ class V0PhotonCut : public TNamed
573574
}
574575
}
575576
if (mApplyMlCuts) {
576-
if (!mEmMlResponse) {
577+
if (mEmMlResponse == nullptr) {
577578
LOG(error) << "EM ML Response is not initialized!";
578579
return false;
579580
}
580-
bool mIsSelectedMl = false;
581-
std::vector<float> mOutputML;
582-
V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCentFT0A, mCentFT0C, mCentFT0M, mD_Bz);
583-
std::vector<float> mlInputFeatures = mEmMlResponse->getInputFeatures(v0photoncandidate, pos, ele);
581+
mIsSelectedMl = false;
582+
mV0PhotonForMl.setPhoton(v0, pos, ele, mCent, mCentralityTypeMl);
583+
mMlInputFeatures = mEmMlResponse->getInputFeatures(mV0PhotonForMl, pos, ele);
584584
if (mUse2DBinning) {
585-
if (mCentralityTypeMl == "CentFT0C") {
586-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0C(), mOutputML);
587-
} else if (mCentralityTypeMl == "CentFT0A") {
588-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0A(), mOutputML);
589-
} else if (mCentralityTypeMl == "CentFT0M") {
590-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0M(), mOutputML);
591-
} else {
592-
LOG(fatal) << "Unsupported centTypePCMMl: " << mCentralityTypeMl << " , please choose from CentFT0C, CentFT0A, CentFT0M.";
593-
}
585+
mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mV0PhotonForMl.getCent(), mOutputML);
594586
} else {
595-
mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), mOutputML);
587+
mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mOutputML);
596588
}
597589
if (!mIsSelectedMl) {
598590
return false;
599591
}
592+
mMlBDTScores = std::span<float>(mOutputML.data(), mOutputML.size());
600593
}
601594
if (doQA) {
602595
fillAfterPhotonHistogram(v0, pos, ele, fRegistry);
@@ -845,7 +838,7 @@ class V0PhotonCut : public TNamed
845838

846839
void initV0MlModels(o2::ccdb::CcdbApi& ccdbApi)
847840
{
848-
if (!mEmMlResponse) {
841+
if (mEmMlResponse == nullptr) {
849842
mEmMlResponse = new o2::analysis::EmMlResponsePCM<float>();
850843
}
851844
if (mUse2DBinning) {
@@ -899,6 +892,11 @@ class V0PhotonCut : public TNamed
899892
mEmMlResponse->init();
900893
}
901894

895+
const std::span<float> getBDTValue() const
896+
{
897+
return mMlBDTScores;
898+
}
899+
902900
template <o2::soa::is_iterator TMCPhoton>
903901
bool IsConversionPointInAcceptance(TMCPhoton const& mcphoton, float convRadius) const
904902
{
@@ -968,10 +966,10 @@ class V0PhotonCut : public TNamed
968966
void SetLoadMlModelsFromCCDB(bool flag = true);
969967
void SetNClassesMl(int nClasses);
970968
void SetMlTimestampCCDB(int timestamp);
971-
void SetCentrality(float centFT0A, float centFT0C, float centFT0M);
969+
void SetCentralityTypeMl(CentType centType);
970+
void SetCentrality(float cent);
972971
void SetD_Bz(float d_bz);
973972
void SetCcdbUrl(const std::string& url = "http://alice-ccdb.cern.ch");
974-
void SetCentralityTypeMl(const std::string& centType);
975973
void SetCutDirMl(const std::vector<int>& cutDirMl);
976974
void SetMlModelPathsCCDB(const std::vector<std::string>& modelPaths);
977975
void SetMlOnnxFileNames(const std::vector<std::string>& onnxFileNamesVec);
@@ -1011,22 +1009,25 @@ class V0PhotonCut : public TNamed
10111009
bool mLoadMlModelsFromCCDB{true};
10121010
int mTimestampCCDB{-1};
10131011
int mNClassesMl{static_cast<int>(o2::analysis::em_cuts_ml::NCutScores)};
1014-
float mCentFT0A{0.f};
1015-
float mCentFT0C{0.f};
1016-
float mCentFT0M{0.f};
1012+
float mCent{0.f};
10171013
float mD_Bz{0.f};
10181014
std::string mCcdbUrl{"http://alice-ccdb.cern.ch"};
1019-
std::string mCentralityTypeMl{"CentFT0C"};
10201015
std::vector<int> mCutDirMl{std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}};
10211016
std::vector<std::string> mModelPathsCCDB{std::vector<std::string>{"path_ccdb/BDT_PCM/"}};
10221017
std::vector<std::string> mOnnxFileNames{std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}};
10231018
std::vector<std::string> mNamesInputFeatures{std::vector<std::string>{"feature1", "feature2"}};
10241019
std::vector<std::string> mLabelsBinsMl{std::vector<std::string>{"bin 0", "bin 1"}};
1025-
std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{"score primary photons", "score background"}};
1020+
std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{o2::analysis::em_cuts_ml::labelsCutScore}};
10261021
std::vector<double> mBinsPtMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}};
10271022
std::vector<double> mBinsCentMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}};
10281023
std::vector<double> mCutsMlFlat{std::vector<double>{0.5}};
10291024
o2::analysis::EmMlResponsePCM<float>* mEmMlResponse{nullptr};
1025+
mutable bool mIsSelectedMl{false};
1026+
mutable std::vector<float> mOutputML{};
1027+
mutable std::vector<float> mMlInputFeatures{};
1028+
mutable std::span<float> mMlBDTScores{};
1029+
CentType mCentralityTypeMl{CentType::CentFT0C};
1030+
mutable V0PhotonCandidate mV0PhotonForMl;
10301031

10311032
// pid cuts
10321033
float mMinTPCNsigmaEl{-5}, mMaxTPCNsigmaEl{+5};

0 commit comments

Comments
 (0)