forked from AliceO2Group/O2Physics
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsimpleApplyPidOnnxInterface.cxx
More file actions
114 lines (97 loc) · 5.67 KB
/
simpleApplyPidOnnxInterface.cxx
File metadata and controls
114 lines (97 loc) · 5.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright 2019-2020 CERN and copyright holders of ALICE O2.
// See https://alice-o2.web.cern.ch/copyright for details of the copyright holders.
// All rights not expressly granted are reserved.
//
// This software is distributed under the terms of the GNU General Public
// License v3 (GPL Version 3), copied verbatim in the file "COPYING".
//
// In applying this license CERN does not waive the privileges and immunities
// granted to it by virtue of its status as an Intergovernmental Organization
// or submit itself to any jurisdiction.
/// \file simpleApplyPidOnnxInterface.cxx
/// \brief A simple example for using PID obtained from the PID ML ONNX Interface. See README.md for more detailed instructions.
///
/// \author Maja Kabus <mkabus@cern.ch>
#include <string>
#include "Framework/runDataProcessing.h"
#include "Framework/AnalysisTask.h"
#include "CCDB/CcdbApi.h"
#include "Common/DataModel/TrackSelectionTables.h"
#include "Common/DataModel/PIDResponse.h"
#include "Tools/PIDML/pidOnnxInterface.h"
using namespace o2;
using namespace o2::framework;
using namespace o2::framework::expressions;
// Output data model of this example: the MlPidResults table filled by
// SimpleApplyOnnxInterface with one row per (track, tested pid) pair.
namespace o2::aod
{
// Columns of the MlPidResults table.
namespace mlpidresult
{
DECLARE_SOA_INDEX_COLUMN(Track, track); //! Track index
DECLARE_SOA_COLUMN(Pid, pid, int); //! Pid to be tested by the model
DECLARE_SOA_COLUMN(Accepted, accepted, bool); //! Whether the model accepted particle to be of given kind
} // namespace mlpidresult
// Table layout: own row index, index of the tested track, tested pid, model decision.
DECLARE_SOA_TABLE(MlPidResults, "AOD", "MLPIDRESULTS", o2::soa::Index<>, mlpidresult::TrackId, mlpidresult::Pid, mlpidresult::Accepted);
} // namespace o2::aod
/// \brief Example task applying the PID ML ONNX models to every filtered track.
///
/// For each track passing the global-track filter and each pid listed in the
/// "pids" configurable, the task asks PidONNXInterface whether the track is
/// accepted as that particle kind, logs the decision and stores it in the
/// MlPidResults table.
struct SimpleApplyOnnxInterface {
  Configurable<LabeledArray<double>> cfgPTCuts{"pT_cuts", {pidml_pt_cuts::cuts[0], pidml_pt_cuts::nPids, pidml_pt_cuts::nCutVars, pidml_pt_cuts::pidLabels, pidml_pt_cuts::cutVarLabels}, "pT cuts for each output pid and each detector configuration"};
  Configurable<std::vector<int>> cfgPids{"pids", std::vector<int>{pidml_pt_cuts::pids_v}, "PIDs to predict"};
  Configurable<std::vector<double>> cfgCertainties{"certainties", std::vector<double>{pidml_pt_cuts::certainties_v}, "Min certainties of the models to accept given particle to be of given kind"};
  Configurable<bool> cfgAutoMode{"autoMode", true, "Use automatic model matching: default pT cuts and min certainties"};

  Configurable<std::string> cfgPathCCDB{"ccdb-path", "Users/m/mkabus/PIDML", "base path to the CCDB directory with ONNX models"};
  Configurable<std::string> cfgCCDBURL{"ccdb-url", "http://alice-ccdb.cern.ch", "URL of the CCDB repository"};
  Configurable<bool> cfgUseCCDB{"useCCDB", true, "Whether to autofetch ML model from CCDB. If false, local file will be used."};
  Configurable<std::string> cfgPathLocal{"local-path", "/home/mkabus/PIDML/", "base path to the local directory with ONNX models"};

  Configurable<bool> cfgUseFixedTimestamp{"use-fixed-timestamp", false, "Whether to use fixed timestamp from configurable instead of timestamp calculated from the data"};
  Configurable<uint64_t> cfgTimestamp{"timestamp", 1524176895000, "Hardcoded timestamp for tests"};

  o2::ccdb::CcdbApi ccdbApi;
  // Run number for which pidInterface was last built from CCDB; -1 until the first build.
  int currentRunNumber = -1;

  Produces<o2::aod::MlPidResults> pidMLResults;

  Filter trackFilter = requireGlobalTrackInFilter();

  // Minimum table requirements for sample model:
  // TPC signal (FullTracks), TOF signal (TOFSignal), TOF beta (pidTOFbeta), dcaXY and dcaZ (TracksDCA)
  // Filter on isGlobalTrack (TrackSelection)
  using BigTracks = soa::Filtered<soa::Join<aod::FullTracks, aod::TracksDCA, aod::pidTOFbeta, aod::TrackSelection, aod::TOFSignal>>;
  PidONNXInterface<BigTracks> pidInterface; // One instance to manage all needed ONNX models

  /// Prepares model access: initialises the CCDB API when CCDB is used,
  /// otherwise builds the interface immediately from the local path.
  void init(InitContext const&)
  {
    if (cfgUseCCDB) {
      // The interface itself is built later, in processCollisions, where a timestamp is available.
      ccdbApi.init(cfgCCDBURL);
    } else {
      // No data timestamp is known at this point; -1 is passed in its place.
      pidInterface = PidONNXInterface<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, -1, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
    }
  }

  /// Runs every configured pid against every track, logs the decision and
  /// appends one MlPidResults row per (track, pid) pair.
  /// \param tracks filtered tracks to be tested
  void applyModelsAndFill(BigTracks const& tracks)
  {
    for (auto const& track : tracks) {
      for (int pid : cfgPids.value) {
        bool accepted = pidInterface.applyModelBoolean(track, pid);
        LOGF(info, "collision id: %d track id: %d pid: %d accepted: %d p: %.3f; x: %.3f, y: %.3f, z: %.3f",
             track.collisionId(), track.index(), pid, accepted, track.p(), track.x(), track.y(), track.z());
        pidMLResults(track.index(), pid, accepted);
      }
    }
  }

  /// Process function with collisions and BCs, so that models can be fetched from CCDB.
  /// When CCDB is used, the interface is rebuilt whenever the run number of the
  /// first collision's BC differs from the run the current interface was built for.
  /// \param collisions collisions of the processed input; only the first one is inspected
  /// \param tracks filtered tracks to be tested
  void processCollisions(aod::Collisions const& collisions, BigTracks const& tracks, aod::BCsWithTimestamps const&)
  {
    // Without any collision there is no BC to read the run number or timestamp from,
    // so the currently held interface is kept as it is.
    if (cfgUseCCDB && collisions.size() > 0) {
      auto bc = collisions.iteratorAt(0).bc_as<aod::BCsWithTimestamps>();
      if (bc.runNumber() != currentRunNumber) {
        uint64_t timestamp = cfgUseFixedTimestamp ? cfgTimestamp.value : bc.timestamp();
        pidInterface = PidONNXInterface<BigTracks>(cfgPathLocal.value, cfgPathCCDB.value, cfgUseCCDB.value, ccdbApi, timestamp, cfgPids.value, cfgPTCuts.value, cfgCertainties.value, cfgAutoMode.value);
        // Remember the run; otherwise the interface would be rebuilt on every call.
        currentRunNumber = bc.runNumber();
      }
    }

    applyModelsAndFill(tracks);
  }
  PROCESS_SWITCH(SimpleApplyOnnxInterface, processCollisions, "Process with collisions and bcs for CCDB", true);

  /// Process function subscribing to tracks only; it never builds the interface itself.
  /// NOTE(review): the interface is built in init() only when useCCDB is false, so this
  /// path appears to require useCCDB=false -- confirm.
  /// \param tracks filtered tracks to be tested
  void processTracksOnly(BigTracks const& tracks)
  {
    applyModelsAndFill(tracks);
  }
  PROCESS_SWITCH(SimpleApplyOnnxInterface, processTracksOnly, "Process with tracks only -- faster but no CCDB", false);
};
/// Builds the workflow for this executable: a single SimpleApplyOnnxInterface task.
/// \param cfgc configuration context handed to the task adaptor
/// \return workflow specification holding the adapted task
WorkflowSpec defineDataProcessing(ConfigContext const& cfgc)
{
  WorkflowSpec workflow{adaptAnalysisTask<SimpleApplyOnnxInterface>(cfgc)};
  return workflow;
}