|
40 | 40 | #include <cstdint> |
41 | 41 | #include <functional> |
42 | 42 | #include <set> |
| 43 | +#include <span> |
43 | 44 | #include <string> |
44 | 45 | #include <utility> |
45 | 46 | #include <vector> |
@@ -145,7 +146,7 @@ static const std::vector<std::string> labelsCent = { |
145 | 146 | "Cent bin 10"}; |
146 | 147 |
|
147 | 148 | // column labels |
148 | | -static const std::vector<std::string> labelsCutScore = {"score primary photons", "score background"}; |
| 149 | +static const std::vector<std::string> labelsCutScore = {"score background", "score primary photons"}; |
149 | 150 | } // namespace em_cuts_ml |
150 | 151 |
|
151 | 152 | } // namespace o2::analysis |
@@ -584,30 +585,22 @@ class V0PhotonCut : public TNamed |
584 | 585 | } |
585 | 586 | } |
586 | 587 | if (mApplyMlCuts) { |
587 | | - if (!mEmMlResponse) { |
| 588 | + if (mEmMlResponse == nullptr) { |
588 | 589 | LOG(error) << "EM ML Response is not initialized!"; |
589 | 590 | return false; |
590 | 591 | } |
591 | | - bool mIsSelectedMl = false; |
592 | | - std::vector<float> mOutputML; |
593 | | - V0PhotonCandidate v0photoncandidate(v0, pos, ele, mCentFT0A, mCentFT0C, mCentFT0M, mD_Bz); |
594 | | - std::vector<float> mlInputFeatures = mEmMlResponse->getInputFeatures(v0photoncandidate, pos, ele); |
| 592 | + mIsSelectedMl = false; |
| 593 | + mV0PhotonForMl.setPhoton(v0, pos, ele, mCent, mCentralityTypeMl); |
| 594 | + mMlInputFeatures = mEmMlResponse->getInputFeatures(mV0PhotonForMl, pos, ele); |
595 | 595 | if (mUse2DBinning) { |
596 | | - if (mCentralityTypeMl == "CentFT0C") { |
597 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0C(), mOutputML); |
598 | | - } else if (mCentralityTypeMl == "CentFT0A") { |
599 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0A(), mOutputML); |
600 | | - } else if (mCentralityTypeMl == "CentFT0M") { |
601 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), v0photoncandidate.getCentFT0M(), mOutputML); |
602 | | - } else { |
603 | | - LOG(fatal) << "Unsupported centTypePCMMl: " << mCentralityTypeMl << " , please choose from CentFT0C, CentFT0A, CentFT0M."; |
604 | | - } |
| 596 | + mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mV0PhotonForMl.getCent(), mOutputML); |
605 | 597 | } else { |
606 | | - mIsSelectedMl = mEmMlResponse->isSelectedMl(mlInputFeatures, v0photoncandidate.getPt(), mOutputML); |
| 598 | + mIsSelectedMl = mEmMlResponse->isSelectedMl(mMlInputFeatures, mV0PhotonForMl.getPt(), mOutputML); |
607 | 599 | } |
608 | 600 | if (!mIsSelectedMl) { |
609 | 601 | return false; |
610 | 602 | } |
| 603 | + mMlBDTScores = std::span<float>(mOutputML.data(), mOutputML.size()); |
611 | 604 | } |
612 | 605 | if (doQA) { |
613 | 606 | fillAfterPhotonHistogram(v0, pos, ele, fRegistry); |
@@ -860,7 +853,7 @@ class V0PhotonCut : public TNamed |
860 | 853 |
|
861 | 854 | void initV0MlModels(o2::ccdb::CcdbApi& ccdbApi) |
862 | 855 | { |
863 | | - if (!mEmMlResponse) { |
| 856 | + if (mEmMlResponse == nullptr) { |
864 | 857 | mEmMlResponse = new o2::analysis::EmMlResponsePCM<float>(); |
865 | 858 | } |
866 | 859 | if (mUse2DBinning) { |
@@ -914,6 +907,11 @@ class V0PhotonCut : public TNamed |
914 | 907 | mEmMlResponse->init(); |
915 | 908 | } |
916 | 909 |
|
| 910 | + const std::span<float> getBDTValue() const |
| 911 | + { |
| 912 | + return mMlBDTScores; |
| 913 | + } |
| 914 | + |
917 | 915 | template <o2::soa::is_iterator TMCPhoton> |
918 | 916 | bool IsConversionPointInAcceptance(TMCPhoton const& mcphoton, float convRadius) const |
919 | 917 | { |
@@ -983,10 +981,10 @@ class V0PhotonCut : public TNamed |
983 | 981 | void SetLoadMlModelsFromCCDB(bool flag = true); |
984 | 982 | void SetNClassesMl(int nClasses); |
985 | 983 | void SetMlTimestampCCDB(int timestamp); |
986 | | - void SetCentrality(float centFT0A, float centFT0C, float centFT0M); |
| 984 | + void SetCentralityTypeMl(CentType centType); |
| 985 | + void SetCentrality(float cent); |
987 | 986 | void SetD_Bz(float d_bz); |
988 | 987 | void SetCcdbUrl(const std::string& url = "http://alice-ccdb.cern.ch"); |
989 | | - void SetCentralityTypeMl(const std::string& centType); |
990 | 988 | void SetCutDirMl(const std::vector<int>& cutDirMl); |
991 | 989 | void SetMlModelPathsCCDB(const std::vector<std::string>& modelPaths); |
992 | 990 | void SetMlOnnxFileNames(const std::vector<std::string>& onnxFileNamesVec); |
@@ -1026,22 +1024,25 @@ class V0PhotonCut : public TNamed |
1026 | 1024 | bool mLoadMlModelsFromCCDB{true}; |
1027 | 1025 | int mTimestampCCDB{-1}; |
1028 | 1026 | int mNClassesMl{static_cast<int>(o2::analysis::em_cuts_ml::NCutScores)}; |
1029 | | - float mCentFT0A{0.f}; |
1030 | | - float mCentFT0C{0.f}; |
1031 | | - float mCentFT0M{0.f}; |
| 1027 | + float mCent{0.f}; |
1032 | 1028 | float mD_Bz{0.f}; |
1033 | 1029 | std::string mCcdbUrl{"http://alice-ccdb.cern.ch"}; |
1034 | | - std::string mCentralityTypeMl{"CentFT0C"}; |
1035 | 1030 | std::vector<int> mCutDirMl{std::vector<int>{o2::analysis::em_cuts_ml::vecCutDir}}; |
1036 | 1031 | std::vector<std::string> mModelPathsCCDB{std::vector<std::string>{"path_ccdb/BDT_PCM/"}}; |
1037 | 1032 | std::vector<std::string> mOnnxFileNames{std::vector<std::string>{"ModelHandler_onnx_PCM.onnx"}}; |
1038 | 1033 | std::vector<std::string> mNamesInputFeatures{std::vector<std::string>{"feature1", "feature2"}}; |
1039 | 1034 | std::vector<std::string> mLabelsBinsMl{std::vector<std::string>{"bin 0", "bin 1"}}; |
1040 | | - std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{"score primary photons", "score background"}}; |
| 1035 | + std::vector<std::string> mLabelsCutScoresMl{std::vector<std::string>{o2::analysis::em_cuts_ml::labelsCutScore}}; |
1041 | 1036 | std::vector<double> mBinsPtMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsPt}}; |
1042 | 1037 | std::vector<double> mBinsCentMl{std::vector<double>{o2::analysis::em_cuts_ml::vecBinsCent}}; |
1043 | 1038 | std::vector<double> mCutsMlFlat{std::vector<double>{0.5}}; |
1044 | 1039 | o2::analysis::EmMlResponsePCM<float>* mEmMlResponse{nullptr}; |
| 1040 | + mutable bool mIsSelectedMl{false}; |
| 1041 | + mutable std::vector<float> mOutputML{}; |
| 1042 | + mutable std::vector<float> mMlInputFeatures{}; |
| 1043 | + mutable std::span<float> mMlBDTScores{}; |
| 1044 | + CentType mCentralityTypeMl{CentType::CentFT0C}; |
| 1045 | + mutable V0PhotonCandidate mV0PhotonForMl; |
1045 | 1046 |
|
1046 | 1047 | // pid cuts |
1047 | 1048 | float mMinTPCNsigmaEl{-5}, mMaxTPCNsigmaEl{+5}; |
|
0 commit comments