Skip to content

Commit 8edc497

Browse files
committed
[RF] Use global workspace array in the generated code
This will make it easier to manage intermediate results.
1 parent 4ee5e77 commit 8edc497

6 files changed

Lines changed: 140 additions & 148 deletions

File tree

roofit/codegen/src/CodegenImpl.cxx

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,7 @@
6262

6363
#include <TInterpreter.h>
6464

65-
namespace RooFit {
66-
namespace Experimental {
65+
namespace RooFit::Experimental {
6766

6867
namespace {
6968

@@ -103,7 +102,7 @@ void rooHistTranslateImpl(RooAbsArg const &arg, CodegenContext &ctx, int intOrde
103102
}
104103

105104
std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, RooArgList const &funcList,
106-
RooArgList const &coefList, bool normalize)
105+
RooArgList const &coefList, bool normalize, bool forceScopeIndependent)
107106
{
108107
bool noLastCoeff = funcList.size() != coefList.size();
109108

@@ -113,7 +112,12 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R
113112

114113
std::string sum = ctx.getTmpVarName();
115114
std::string coeffSum = ctx.getTmpVarName();
116-
ctx.addToCodeBody(&arg, "double " + sum + " = 0;\ndouble " + coeffSum + "= 0;\n");
115+
std::string code1 = "double " + sum + " = 0;\ndouble " + coeffSum + "= 0;\n";
116+
117+
if (forceScopeIndependent)
118+
ctx.addToCodeBody(code1, true);
119+
else
120+
ctx.addToCodeBody(&arg, code1);
117121

118122
std::string iterator = "i_" + ctx.getTmpVarName();
119123
std::string subscriptExpr = "[" + iterator + "]";
@@ -128,7 +132,10 @@ std::string realSumPdfTranslateImpl(CodegenContext &ctx, RooAbsArg const &arg, R
128132
} else if (normalize) {
129133
code += sum + " /= " + coeffSum + ";\n";
130134
}
131-
ctx.addToCodeBody(&arg, code);
135+
if (forceScopeIndependent)
136+
ctx.addToCodeBody(code, true);
137+
else
138+
ctx.addToCodeBody(&arg, code);
132139

133140
return sum;
134141
}
@@ -240,7 +247,7 @@ void codegenImpl(RooAbsArg &arg, CodegenContext &ctx)
240247

241248
void codegenImpl(RooAddPdf &arg, CodegenContext &ctx)
242249
{
243-
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.pdfList(), arg.coefList(), true));
250+
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.pdfList(), arg.coefList(), true, false));
244251
}
245252

246253
void codegenImpl(RooMultiVarGaussian &arg, CodegenContext &ctx)
@@ -261,30 +268,25 @@ void codegenImpl(RooMultiPdf &arg, CodegenContext &ctx)
261268
// indices MathFunc call becomes more efficient.
262269
if (numPdfs > 2) {
263270
ctx.addResult(&arg, ctx.buildCall(mathFunc("multipdf"), arg.indexCategory(), arg.getPdfList()));
271+
return;
272+
}
273+
// Ternary nested expression
274+
std::string indexExpr = ctx.getResult(arg.indexCategory());
264275

265-
std::cout << "MathFunc call used\n";
266-
267-
} else {
268-
269-
// Ternary nested expression
270-
std::string indexExpr = ctx.getResult(arg.indexCategory());
271-
272-
// int numPdfs = arg.getNumPdfs();
273-
std::string expr;
276+
// int numPdfs = arg.getNumPdfs();
277+
std::string expr;
274278

275-
for (int i = 0; i < numPdfs; ++i) {
276-
RooAbsPdf *pdf = arg.getPdf(i);
277-
std::string pdfExpr = ctx.getResult(*pdf);
279+
for (int i = 0; i < numPdfs; ++i) {
280+
RooAbsPdf *pdf = arg.getPdf(i);
281+
std::string pdfExpr = ctx.getResult(*pdf);
278282

279-
expr += "(" + indexExpr + " == " + std::to_string(i) + " ? (" + pdfExpr + ") : ";
280-
}
283+
expr += "(" + indexExpr + " == " + std::to_string(i) + " ? (" + pdfExpr + ") : ";
284+
}
281285

282-
expr += "0.0";
283-
expr += std::string(numPdfs, ')'); // Close all ternary operators
286+
expr += "0.0";
287+
expr += std::string(numPdfs, ')'); // Close all ternary operators
284288

285-
ctx.addResult(&arg, expr);
286-
std::cout << "Ternary expression call used \n";
287-
}
289+
ctx.addResult(&arg, expr);
288290
}
289291

290292
// RooCategory index added.
@@ -294,7 +296,7 @@ void codegenImpl(RooCategory &arg, CodegenContext &ctx)
294296
if (idx < 0) {
295297

296298
idx = 1;
297-
ctx.addVecObs(arg.GetName(), idx);
299+
ctx.addVecObs(arg.GetName(), idx, 1);
298300
}
299301

300302
std::string result = std::to_string(arg.getCurrentIndex());
@@ -305,6 +307,7 @@ void codegenImpl(RooAddition &arg, CodegenContext &ctx)
305307
{
306308
if (arg.list().empty()) {
307309
ctx.addResult(&arg, "0.0");
310+
return;
308311
}
309312
std::string result;
310313
if (arg.list().size() > 1)
@@ -469,7 +472,6 @@ void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx)
469472

470473
std::string weightSumName = RooFit::Detail::makeValidVarName(arg.GetName()) + "WeightSum";
471474
std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
472-
ctx.addResult(&arg, resName);
473475
ctx.addToGlobalScope("double " + weightSumName + " = 0.0;\n");
474476
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");
475477

@@ -496,6 +498,8 @@ void codegenImpl(RooFit::Detail::RooNLLVarNew &arg, CodegenContext &ctx)
496498
std::string expected = ctx.getResult(*arg.expectedEvents());
497499
ctx.addToCodeBody(resName + " += " + expected + " - " + weightSumName + " * std::log(" + expected + ");\n");
498500
}
501+
502+
ctx.addResult(&arg, resName);
499503
}
500504

501505
void codegenImpl(RooFit::Detail::RooNormalizedPdf &arg, CodegenContext &ctx)
@@ -609,17 +613,17 @@ void codegenImpl(RooRealIntegral &arg, CodegenContext &ctx)
609613
auto &intVar = static_cast<RooAbsRealLValue &>(*arg.numIntRealVars()[0]);
610614

611615
std::string obsName = ctx.getTmpVarName();
612-
std::string oldIntVarResult = ctx.getResult(intVar);
613-
ctx.addResult(&intVar, "obs[0]");
614616

617+
auto oldVecObsInfo = ctx._vecObsIndices[intVar.namePtr()];
618+
ctx.addVecObs(intVar.GetName(), 0, 1);
615619
std::string funcName = ctx.buildFunction(arg.integrand(), {});
620+
ctx._vecObsIndices[intVar.namePtr()] = oldVecObsInfo;
616621

617622
std::stringstream ss;
618623

619624
ss << "double " << obsName << "[1];\n";
620625

621626
std::string resName = RooFit::Detail::makeValidVarName(arg.GetName()) + "Result";
622-
ctx.addResult(&arg, resName);
623627
ctx.addToGlobalScope("double " + resName + " = 0.0;\n");
624628

625629
// TODO: once Clad has support for higher-order functions (follow also the
@@ -640,24 +644,21 @@ void codegenImpl(RooRealIntegral &arg, CodegenContext &ctx)
640644

641645
ctx.addToGlobalScope(ss.str());
642646

643-
ctx.addResult(&intVar, oldIntVarResult);
647+
ctx.addResult(&arg, resName);
644648
}
645649

646650
void codegenImpl(RooRealSumFunc &arg, CodegenContext &ctx)
647651
{
648-
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false));
652+
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false, false));
649653
}
650654

651655
void codegenImpl(RooRealSumPdf &arg, CodegenContext &ctx)
652656
{
653-
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false));
657+
ctx.addResult(&arg, realSumPdfTranslateImpl(ctx, arg, arg.funcList(), arg.coefList(), false, false));
654658
}
655659

656660
void codegenImpl(RooRealVar &arg, CodegenContext &ctx)
657661
{
658-
if (!arg.isConstant()) {
659-
ctx.addResult(&arg, arg.GetName());
660-
}
661662
ctx.addResult(&arg, doubleToString(arg.getVal()));
662663
}
663664

@@ -898,7 +899,7 @@ std::string codegenIntegralImpl(RooPolynomial &arg, int, const char *rangeName,
898899
std::string codegenIntegralImpl(RooRealSumPdf &arg, int code, const char *rangeName, CodegenContext &ctx)
899900
{
900901
// Re-use translate, since integration is linear.
901-
return realSumPdfTranslateImpl(ctx, arg, arg.funcIntListFromCache(code, rangeName), arg.coefList(), false);
902+
return realSumPdfTranslateImpl(ctx, arg, arg.funcIntListFromCache(code, rangeName), arg.coefList(), false, true);
902903
}
903904

904905
std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName, CodegenContext &)
@@ -908,5 +909,4 @@ std::string codegenIntegralImpl(RooUniform &arg, int code, const char *rangeName
908909
return doubleToString(arg.analyticalIntegral(code, rangeName));
909910
}
910911

911-
} // namespace Experimental
912-
} // namespace RooFit
912+
} // namespace RooFit::Experimental

roofit/roofitcore/inc/RooFit/CodegenContext.h

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
template <class T>
3131
class RooTemplateProxy;
3232

33-
namespace RooFit {
34-
namespace Experimental {
33+
namespace RooFit::Experimental {
3534

3635
template <int P>
3736
struct Prio {
@@ -46,12 +45,11 @@ using PrioLowest = Prio<10>;
4645
class CodegenContext {
4746
public:
4847
void addResult(RooAbsArg const *key, std::string const &value);
49-
void addResult(const char *key, std::string const &value);
5048

51-
std::string const &getResult(RooAbsArg const &arg);
49+
std::string getResult(RooAbsArg const &arg);
5250

5351
template <class T>
54-
std::string const &getResult(RooTemplateProxy<T> const &key)
52+
std::string getResult(RooTemplateProxy<T> const &key)
5553
{
5654
return getResult(key.arg());
5755
}
@@ -69,7 +67,8 @@ class CodegenContext {
6967
}
7068

7169
void addToGlobalScope(std::string const &str);
72-
void addVecObs(const char *key, int idx);
70+
void addVecObs(const char *key, int idx, std::size_t size);
71+
void addParam(const RooAbsArg *key, int idx);
7372
int observableIndexOf(const RooAbsArg &arg) const;
7473

7574
void addToCodeBody(RooAbsArg const *klass, std::string const &in);
@@ -135,6 +134,13 @@ class CodegenContext {
135134
};
136135
ScopeRAII OutputScopeRangeComment(RooAbsArg const *arg) { return {arg, *this}; }
137136

137+
/// @brief Map of node names to a std::size_t per node (presumably the node's slot in the workspace array, cf. _nWksp -- TODO confirm).
138+
std::unordered_map<const TNamed *, std::size_t> _nodeNames;
139+
std::size_t _nWksp = 0;
140+
std::unordered_map<const RooAbsArg *, int> _paramIndices;
141+
/// @brief A map to keep track of the observable indices if they are non scalar.
142+
std::unordered_map<const TNamed *, std::pair<int, std::size_t>> _vecObsIndices;
143+
138144
private:
139145
void pushScope();
140146
void popScope();
@@ -145,8 +151,6 @@ class CodegenContext {
145151

146152
void endLoop(LoopScope const &scope);
147153

148-
void addResult(TNamed const *key, std::string const &value);
149-
150154
template <class T, typename std::enable_if<std::is_floating_point<T>{}, bool>::type = true>
151155
std::string buildArg(T x)
152156
{
@@ -191,10 +195,6 @@ class CodegenContext {
191195
template <class T>
192196
std::string typeName() const;
193197

194-
/// @brief Map of node names to their result strings.
195-
std::unordered_map<const TNamed *, std::string> _nodeNames;
196-
/// @brief A map to keep track of the observable indices if they are non scalar.
197-
std::unordered_map<const TNamed *, int> _vecObsIndices;
198198
/// @brief Map of node output sizes.
199199
std::map<RooFit::Detail::DataKey, std::size_t> _nodeOutputSizes;
200200
/// @brief The code layered by lexical scopes used as a stack.
@@ -203,8 +203,6 @@ class CodegenContext {
203203
unsigned _indent = 0;
204204
/// @brief Index to get unique names for temporary variables.
205205
mutable int _tmpVarIdx = 0;
206-
/// @brief A map to keep track of list names as assigned by addResult.
207-
std::unordered_map<RooFit::UniqueId<RooAbsCollection>::Value_t, std::string> _listNames;
208206
std::vector<double> _xlArr;
209207
std::vector<std::string> _collectedFunctions;
210208
};
@@ -242,7 +240,6 @@ void declareDispatcherCode(std::string const &funcName);
242240

243241
void codegen(RooAbsArg &arg, CodegenContext &ctx);
244242

245-
} // namespace Experimental
246-
} // namespace RooFit
243+
} // namespace RooFit::Experimental
247244

248245
#endif

roofit/roofitcore/inc/RooFit/Detail/MathFuncs.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,23 @@ double stepFunctionIntegral(double xmin, double xmax, std::size_t nBins, DoubleA
812812

813813
} // namespace RooFit::Detail::MathFuncs
814814

815+
/// @brief Accumulate selected workspace entries into an output array.
///
/// For every i in [0, n), the value wksp[idx[i]] is added to out[i]. The
/// update uses `+=`, so existing contents of `out` are added to rather than
/// overwritten.
///
/// @param out  Array that is accumulated into; elements 0..n-1 are written.
/// @param n    Number of elements to process.
/// @param wksp Array that is read at the positions given by `idx`.
/// @param idx  Array of n positions into `wksp`, stored as doubles.
inline void fillFromWorkspace(double *out, std::size_t n, double const *wksp, double const *idx)
816+
{
817+
for (std::size_t i = 0; i < n; ++i) {
818+
// The positions arrive as doubles and are converted to int before indexing.
out[i] += wksp[static_cast<int>(idx[i])];
819+
}
820+
}
821+
815822
namespace clad::custom_derivatives {
823+
824+
/// @brief Hand-written pullback living in clad::custom_derivatives; presumably
/// the custom derivative Clad picks up for fillFromWorkspace (TODO confirm).
///
/// For every i in [0, n), d_out[i] is added to d_wksp[idx[i]]. All unnamed
/// parameters are ignored by this implementation.
///
/// @param n      Number of elements to process.
/// @param idx    Array of n positions into `d_wksp`, stored as doubles.
/// @param d_out  Array that is read at elements 0..n-1.
/// @param d_wksp Array that is accumulated into at the positions given by `idx`.
inline void fillFromWorkspace_pullback(double *, std::size_t n, double const *, double const *idx, double *d_out,
825+
std::size_t *, double *d_wksp, double *)
826+
{
827+
for (std::size_t i = 0; i < n; ++i) {
828+
// The positions arrive as doubles and are converted to int before indexing.
d_wksp[static_cast<int>(idx[i])] += d_out[i];
829+
}
830+
}
831+
816832
namespace RooFit::Detail::MathFuncs {
817833

818834
// Clad can't generate the pullback for binNumber because of the
@@ -826,6 +842,7 @@ void binNumber_pullback(Types...)
826842
}
827843

828844
} // namespace RooFit::Detail::MathFuncs
845+
829846
} // namespace clad::custom_derivatives
830847

831848
#endif

roofit/roofitcore/src/RooEvaluatorWrapper.cxx

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -239,21 +239,13 @@ RooFuncWrapper::RooFuncWrapper(RooAbsReal &obj, const RooAbsData *data, RooSimul
239239
// First update the result variable of params in the compute graph to in[<position>].
240240
int idx = 0;
241241
for (RooAbsArg *param : _params) {
242-
ctx.addResult(param, "params[" + std::to_string(idx) + "]");
242+
ctx.addParam(param, idx);
243243
idx++;
244244
}
245245

246246
for (auto const &item : _obsInfos) {
247247
const char *obsName = item.first->GetName();
248-
// If the observable is scalar, set name to the start idx. else, store
249-
// the start idx and later set the the name to obs[start_idx + curr_idx],
250-
// here curr_idx is defined by a loop producing parent node.
251-
if (item.second.size == 1) {
252-
ctx.addResult(obsName, "obs[" + std::to_string(item.second.idx) + "]");
253-
} else {
254-
ctx.addResult(obsName, "obs");
255-
ctx.addVecObs(obsName, item.second.idx);
256-
}
248+
ctx.addVecObs(obsName, item.second.idx, item.second.size);
257249
}
258250

259251
gInterpreter->Declare("#pragma cling optimize(2)");

0 commit comments

Comments
 (0)