|
42 | 42 | #include <string> |
43 | 43 | #include <utility> |
44 | 44 | #include <vector> |
| 45 | +#include <span> |
45 | 46 |
|
46 | 47 | namespace o2::analysis |
47 | 48 | { |
@@ -144,7 +145,7 @@ static const std::vector<std::string> labelsCent = { |
144 | 145 | "Cent bin 10"}; |
145 | 146 |
|
146 | 147 | // column labels |
147 | | -static const std::vector<std::string> labelsCutScore = {"score primary photons", "score background"}; |
| 148 | +static const std::vector<std::string> labelsCutScore = {"score background", "score primary photons"}; |
148 | 149 | } // namespace em_cuts_ml |
149 | 150 |
|
150 | 151 | } // namespace o2::analysis |
@@ -573,30 +574,22 @@ class V0PhotonCut : public TNamed |
573 | 574 | } |
574 | 575 | } |
575 | 576 | if (mApplyMlCuts) { |
576 | | - if (!mEmMlResponse) { |
| 577 | + if (mEmMlResponse == nullptr) { |
577 | 578 | LOG(error) << "EM ML Response is not initialized!"; |
578 | 579 | return false; |
579 | 580 | } |
580 | | - bool mIsSelectedMl = false; |
581 | | - std::vector<float> mOutputML; |
582 | | - V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCentFT0A, mCentFT0C, mCentFT0M, mD_Bz); |
583 | | - std::vector<float> mlInputFeatures = mEmMlResponse->getInputFeatures(v0photoncandidate, pos, ele); |
| 581 | + mIsSelectedMl = false; |
| 582 | + mV0PhotonForMl.setPhoton(v0, pos, ele, mCent, mCentralityTypeMl); |
| 583 | + mMlInputFeatures = mEmMlResponse->getInputFeatures(mV0PhotonForMl, pos, ele); |
584 | 584 | if (mUse2DBinning) { |
585 | | - if (mCentralityTypeMl == "CentFT0C") { |
586 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0C(), mOutputML); |
587 | | - } else if (mCentralityTypeMl == "CentFT0A") { |
588 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0A(), mOutputML); |
589 | | - } else if (mCentralityTypeMl == "CentFT0M") { |
590 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0M(), mOutputML); |
591 | | - } else { |
592 | | - LOG(fatal) << "Unsupported centTypePCMMl: " << mCentralityTypeMl << " , please choose from CentFT0C, CentFT0A, CentFT0M."; |
593 | | - } |
| 585 | + mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mV0PhotonForMl.getCent(), mOutputML); |
594 | 586 | } else { |
595 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), mOutputML); |
| 587 | + mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mOutputML); |
596 | 588 | } |
597 | 589 | if (!mIsSelectedMl) { |
598 | 590 | return false; |
599 | 591 | } |
| 592 | + mMlBDTScores = std::span<float>(mOutputML.data(), mOutputML.size()); |
600 | 593 | } |
601 | 594 | if (doQA) { |
602 | 595 | fillAfterPhotonHistogram(v0, pos, ele, fRegistry); |
@@ -845,7 +838,7 @@ class V0PhotonCut : public TNamed |
845 | 838 |
|
846 | 839 | void initV0MlModels(o2::ccdb::CcdbApi& ccdbApi) |
847 | 840 | { |
848 | | - if (!mEmMlResponse) { |
| 841 | + if (mEmMlResponse == nullptr) { |
849 | 842 | mEmMlResponse = new o2::analysis::EmMlResponsePCM<float>(); |
850 | 843 | } |
851 | 844 | if (mUse2DBinning) { |
@@ -899,6 +892,11 @@ class V0PhotonCut : public TNamed |
899 | 892 | mEmMlResponse->init(); |
900 | 893 | } |
901 | 894 |
|
| 895 | + const std::span<float> getBDTValue() const |
| 896 | + { |
| 897 | + return mMlBDTScores; |
| 898 | + } |
| 899 | + |
902 | 900 | template <o2::soa::is_iterator TMCPhoton> |
903 | 901 | bool IsConversionPointInAcceptance(TMCPhoton const& mcphoton, float convRadius) const |
904 | 902 | { |
@@ -968,10 +966,10 @@ class V0PhotonCut : public TNamed |
968 | 966 | void SetLoadMlModelsFromCCDB(bool flag = true); |
969 | 967 | void SetNClassesMl(int nClasses); |
970 | 968 | void SetMlTimestampCCDB(int timestamp); |
971 | | - void SetCentrality(float centFT0A, float centFT0C, float centFT0M); |
| 969 | + void SetCentralityTypeMl(CentType centType); |
| 970 | + void SetCentrality(float cent); |
972 | 971 | void SetD_Bz(float d_bz); |
973 | 972 | void SetCcdbUrl(const std::string& url = "http://alice-ccdb.cern.ch"); |
974 | | - void SetCentralityTypeMl(const std::string& centType); |
975 | 973 | void SetCutDirMl(const std::vector<int>& cutDirMl); |
976 | 974 | void SetMlModelPathsCCDB(const std::vector<std::string>& modelPaths); |
977 | 975 | void SetMlOnnxFileNames(const std::vector<std::string>& onnxFileNamesVec); |
@@ -1011,22 +1009,25 @@ class V0PhotonCut : public TNamed |
1011 | 1009 | bool mLoadMlModelsFromCCDB{true}; |
1012 | 1010 | int mTimestampCCDB{-1}; |
1013 | 1011 | int mNClassesMl{static_cast<int>(o2::analysis::em_cuts_ml::NCutScores)}; |
1014 | | - float mCentFT0A{0.f}; |
1015 | | - float mCentFT0C{0.f}; |
1016 | | - float mCentFT0M{0.f}; |
| 1012 | + float mCent{0.f}; |
1017 | 1013 | float mD_Bz{0.f}; |
1018 | 1014 | std::string mCcdbUrl{"http://alice-ccdb.cern.ch"}; |
1019 | | - std::string mCentralityTypeMl{"CentFT0C"}; |
1020 | 1015 | std::vector<int> mCutDirMl{std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}}; |
1021 | 1016 | std::vector<std::string> mModelPathsCCDB{std::vector<std::string>{"path_ccdb/BDT_PCM/"}}; |
1022 | 1017 | std::vector<std::string> mOnnxFileNames{std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}}; |
1023 | 1018 | std::vector<std::string> mNamesInputFeatures{std::vector<std::string>{"feature1", "feature2"}}; |
1024 | 1019 | std::vector<std::string> mLabelsBinsMl{std::vector<std::string>{"bin 0", "bin 1"}}; |
1025 | | - std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{"score primary photons", "score background"}}; |
| 1020 | + std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{o2::analysis::em_cuts_ml::labelsCutScore}}; |
1026 | 1021 | std::vector<double> mBinsPtMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}}; |
1027 | 1022 | std::vector<double> mBinsCentMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}}; |
1028 | 1023 | std::vector<double> mCutsMlFlat{std::vector<double>{0.5}}; |
1029 | 1024 | o2::analysis::EmMlResponsePCM<float>* mEmMlResponse{nullptr}; |
| 1025 | + mutable bool mIsSelectedMl{false}; |
| 1026 | + mutable std::vector<float> mOutputML{}; |
| 1027 | + mutable std::vector<float> mMlInputFeatures{}; |
| 1028 | + mutable std::span<float> mMlBDTScores{}; |
| 1029 | + CentType mCentralityTypeMl{CentType::CentFT0C}; |
| 1030 | + mutable V0PhotonCandidate mV0PhotonForMl; |
1030 | 1031 |
|
1031 | 1032 | // pid cuts |
1032 | 1033 | float mMinTPCNsigmaEl{-5}, mMaxTPCNsigmaEl{+5}; |
|
0 commit comments