diff --git a/PWGHF/TableProducer/treeCreatorXicToXiPiPi.cxx b/PWGHF/TableProducer/treeCreatorXicToXiPiPi.cxx index 47748b4c30d..a21218a5eed 100644 --- a/PWGHF/TableProducer/treeCreatorXicToXiPiPi.cxx +++ b/PWGHF/TableProducer/treeCreatorXicToXiPiPi.cxx @@ -78,6 +78,9 @@ DECLARE_SOA_COLUMN(PtPi1, ptPi1, float); DECLARE_SOA_COLUMN(ImpactParameterPi1, impactParameterPi1, float); //! Normalised impact parameter of Pi1 (prong2) DECLARE_SOA_COLUMN(ImpactParameterNormalisedPi1, impactParameterNormalisedPi1, float); //! Normalised impact parameter of Pi1 (prong2) DECLARE_SOA_COLUMN(MaxNormalisedDeltaIP, maxNormalisedDeltaIP, float); //! Maximum normalized difference between measured and expected impact parameter of candidate prongs +DECLARE_SOA_COLUMN(MlScoreBkg, mlScoreBkg, float); //! ML score for background class +DECLARE_SOA_COLUMN(MlScorePrompt, mlScorePrompt, float); //! ML score for prompt signal class +DECLARE_SOA_COLUMN(MlScoreNonPrompt, mlScoreNonPrompt, float); //! ML score for non-prompt signal class (3-class model only, -1 otherwise) } // namespace full DECLARE_SOA_TABLE(HfCandXicToXiPiPiLites, "AOD", "HFXICXI2PILITE", @@ -186,6 +189,37 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiLiteKfs, "AOD", "HFXICXI2PILITKF", hf_cand_xic_to_xi_pi_pi::DcaXYPi0Xi, hf_cand_xic_to_xi_pi_pi::DcaXYPi1Xi); +DECLARE_SOA_TABLE(HfCandXicToXiPiPiLiteMLs, "AOD", "HFXICXI2PIMLITE", + full::ParticleFlag, + hf_cand_mc_flag::OriginMcRec, + full::CandidateSelFlag, + full::Y, + full::Eta, + full::Phi, + full::P, + full::Pt, + full::M, + hf_cand_xic_to_xi_pi_pi::InvMassXi, + hf_cand_xic_to_xi_pi_pi::InvMassLambda, + full::DecayLength, + full::DecayLengthXY, + full::Cpa, + full::CpaXY, + hf_cand_xic_to_xi_pi_pi::CpaXi, + hf_cand_xic_to_xi_pi_pi::CpaXYXi, + hf_cand_xic_to_xi_pi_pi::CpaLambda, + hf_cand_xic_to_xi_pi_pi::CpaXYLambda, + full::ImpactParameterXi, + full::ImpactParameterNormalisedXi, + full::ImpactParameterPi0, + full::ImpactParameterNormalisedPi0, + full::ImpactParameterPi1, + full::ImpactParameterNormalisedPi1, + full::MaxNormalisedDeltaIP, + full::MlScoreBkg, + full::MlScorePrompt, + full::MlScoreNonPrompt); + DECLARE_SOA_TABLE(HfCandXicToXiPiPiFulls, "AOD", "HFXICXI2PIFULL", full::ParticleFlag, hf_cand_mc_flag::OriginMcRec, @@ -343,6 +377,7 @@ DECLARE_SOA_TABLE(HfCandXicToXiPiPiFullPs, "AOD", "HFXICXI2PIFULLP", struct HfTreeCreatorXicToXiPiPi { Produces rowCandidateLite; Produces rowCandidateLiteKf; + Produces rowCandidateLiteMl; Produces rowCandidateFull; Produces rowCandidateFullKf; Produces rowCandidateFullParticles; @@ -356,10 +391,14 @@ struct HfTreeCreatorXicToXiPiPi { Configurable downSampleBkgFactor{"downSampleBkgFactor", 1., "Fraction of background candidates to keep for ML trainings"}; Configurable ptMaxForDownSample{"ptMaxForDownSample", 10., "Maximum pt for the application of the downsampling factor"}; + static constexpr int kNumBinaryClasses = 2; + using SelectedCandidates = soa::Filtered>; using SelectedCandidatesKf = soa::Filtered>; + using SelectedCandidatesML = soa::Filtered>; using SelectedCandidatesMc = soa::Filtered>; using SelectedCandidatesKfMc = soa::Filtered>; + using SelectedCandidatesMcML = soa::Filtered>; using MatchedGenXicToXiPiPi = soa::Filtered>; Filter filterSelectCandidates = aod::hf_sel_candidate_xic::isSelXicToXiPiPi >= selectionFlagXic; @@ -372,9 +411,13 @@ struct HfTreeCreatorXicToXiPiPi { void init(InitContext const&) { + std::array doprocess{doprocessData, doprocessDataKf, doprocessDataWithML, doprocessMc, doprocessMcKf, doprocessMcWithML}; + if (std::accumulate(doprocess.begin(), doprocess.end(), 0) != 1) { + LOGP(fatal, "Only one process function can be enabled at a time."); + } } - template + template void fillCandidateTable(const T& candidate) { int8_t particleFlag = candidate.sign(); @@ -383,7 +426,7 @@ struct HfTreeCreatorXicToXiPiPi { particleFlag = candidate.flagMcMatchRec(); originMc = candidate.originMcRec(); } - if constexpr (!DoKf) { + if constexpr (!DoKf && !DoMl) { if (fillCandidateLiteTable) { rowCandidateLite( particleFlag, @@ -484,7 +527,7 @@ struct HfTreeCreatorXicToXiPiPi { candidate.nSigTofPiFromLambda(), candidate.nSigTofPrFromLambda()); } - } else { + } else if constexpr (DoKf) { if (fillCandidateLiteTable) { rowCandidateLiteKf( particleFlag, @@ -636,6 +679,47 @@ struct HfTreeCreatorXicToXiPiPi { candidate.dcaXYPi1Xi()); } } + if constexpr (DoMl) { + float mlScoreBkg = -1.f, mlScorePrompt = -1.f, mlScoreNonPrompt = -1.f; + const int scoreSize = static_cast(candidate.mlProbXicToXiPiPi().size()); + if (scoreSize > 0) { + mlScoreBkg = candidate.mlProbXicToXiPiPi()[0]; + mlScorePrompt = candidate.mlProbXicToXiPiPi()[1]; + if (scoreSize > kNumBinaryClasses) { + mlScoreNonPrompt = candidate.mlProbXicToXiPiPi()[2]; + } + } + rowCandidateLiteMl( + particleFlag, + originMc, + candidate.isSelXicToXiPiPi(), + candidate.y(o2::constants::physics::MassXiCPlus), + candidate.eta(), + candidate.phi(), + candidate.p(), + candidate.pt(), + candidate.invMassXicPlus(), + candidate.invMassXi(), + candidate.invMassLambda(), + candidate.decayLength(), + candidate.decayLengthXY(), + candidate.cpa(), + candidate.cpaXY(), + candidate.cpaXi(), + candidate.cpaXYXi(), + candidate.cpaLambda(), + candidate.cpaXYLambda(), + candidate.impactParameter0(), + candidate.impactParameterNormalised0(), + candidate.impactParameter1(), + candidate.impactParameterNormalised1(), + candidate.impactParameter2(), + candidate.impactParameterNormalised2(), + candidate.maxNormalisedDeltaIP(), + mlScoreBkg, + mlScorePrompt, + mlScoreNonPrompt); + } } void processData(SelectedCandidates const& candidates) @@ -653,10 +737,10 @@ struct HfTreeCreatorXicToXiPiPi { continue; } } - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } - PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processData, "Process data with DCAFitter reconstruction", true); + PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processData, "Process data with DCAFitter reconstruction", false); void processDataKf(SelectedCandidatesKf const& candidates) { @@ -673,11 +757,22 @@ struct HfTreeCreatorXicToXiPiPi { continue; } } - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processDataKf, "Process data with KFParticle reconstruction", false); + void processDataWithML(SelectedCandidatesML const& candidates) + { + // Filling candidate properties + rowCandidateLiteMl.reserve(candidates.size()); + + for (const auto& candidate : candidates) { + fillCandidateTable(candidate); + } + } + PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processDataWithML, "Process data with DCAFitter reconstruction and ML", false); + void processMc(SelectedCandidatesMc const& candidates, MatchedGenXicToXiPiPi const& particles) { @@ -689,7 +784,7 @@ struct HfTreeCreatorXicToXiPiPi { rowCandidateFull.reserve(recSig.size()); } for (const auto& candidate : recSig) { - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } else if (fillOnlyBackground) { if (fillCandidateLiteTable) { @@ -702,7 +797,7 @@ struct HfTreeCreatorXicToXiPiPi { if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) { continue; } - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } else { if (fillCandidateLiteTable) { @@ -711,7 +806,7 @@ struct HfTreeCreatorXicToXiPiPi { rowCandidateFull.reserve(candidates.size()); } for (const auto& candidate : candidates) { - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } @@ -743,7 +838,7 @@ struct HfTreeCreatorXicToXiPiPi { rowCandidateFull.reserve(recSigKf.size()); } for (const auto& candidate : recSigKf) { - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } else if (fillOnlyBackground) { if (fillCandidateLiteTable) { @@ -756,7 +851,7 @@ struct HfTreeCreatorXicToXiPiPi { if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) { continue; } - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } else { if (fillCandidateLiteTable) { @@ -765,7 +860,7 @@ struct HfTreeCreatorXicToXiPiPi { rowCandidateFull.reserve(candidates.size()); } for (const auto& candidate : candidates) { - fillCandidateTable(candidate); + fillCandidateTable(candidate); } } @@ -785,6 +880,52 @@ struct HfTreeCreatorXicToXiPiPi { } } PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processMcKf, "Process MC with KF Particle reconstruction", false); + + void processMcWithML(SelectedCandidatesMcML const& candidates, + MatchedGenXicToXiPiPi const& particles) + { + // Filling candidate properties + rowCandidateLiteMl.reserve(candidates.size()); + if (fillOnlySignal) { + for (const auto& candidate : candidates) { + if (candidate.flagMcMatchRec() == int8_t(0)) { + continue; + } + fillCandidateTable(candidate); + } + } else if (fillOnlyBackground) { + for (const auto& candidate : candidates) { + if (candidate.flagMcMatchRec() != int8_t(0)) { + continue; + } + float const pseudoRndm = candidate.ptProng1() * 1000. - static_cast(candidate.ptProng1() * 1000); + if (candidate.pt() < ptMaxForDownSample && pseudoRndm >= downSampleBkgFactor) { + continue; + } + fillCandidateTable(candidate); + } + } else { + for (const auto& candidate : candidates) { + fillCandidateTable(candidate); + } + } + + if (fillGenParticleTable) { + rowCandidateFullParticles.reserve(particles.size()); + for (const auto& particle : particles) { + rowCandidateFullParticles( + particle.flagMcMatchGen(), + particle.originMcGen(), + particle.pdgBhadMotherPart(), + particle.pt(), + particle.eta(), + particle.phi(), + RecoDecay::y(particle.pVector(), o2::constants::physics::MassXiCPlus), + particle.decayLengthMcGen()); + } + } + } + PROCESS_SWITCH(HfTreeCreatorXicToXiPiPi, processMcWithML, "Process MC with DCAFitter reconstruction and ML", false); }; WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)