diff --git a/roofit/roofitcore/inc/LinkDef.h b/roofit/roofitcore/inc/LinkDef.h
index f805db5d3af8c..ad1c59a92909a 100644
--- a/roofit/roofitcore/inc/LinkDef.h
+++ b/roofit/roofitcore/inc/LinkDef.h
@@ -249,6 +249,7 @@
 #pragma link C++ class RooWorkspace- ;
 #pragma link C++ class RooWorkspace::CodeRepo- ;
 #pragma link C++ class RooWorkspace::WSDir+ ;
+#pragma link C++ class RooWorkspace::EmbeddedHisto+ ;
 #pragma link C++ class RooWorkspaceHandle+;
 #pragma link C++ class std::list+ ;
 #pragma link C++ class std::list+ ;
diff --git a/roofit/roofitcore/inc/RooDataHist.h b/roofit/roofitcore/inc/RooDataHist.h
index 1daea0809fdcb..2c9523e3afb57 100644
--- a/roofit/roofitcore/inc/RooDataHist.h
+++ b/roofit/roofitcore/inc/RooDataHist.h
@@ -265,6 +265,8 @@ class RooDataHist : public RooAbsData, public RooDirItem {
   mutable double _cache_sum{0.}; /// xVals, bool correctForBinSize, bool cdfBoundaries);
   void interpolateLinear(double* output, std::span xVals, bool correctForBinSize, bool cdfBoundaries);
   double weightInterpolated(const RooArgSet& bin, int intOrder, bool correctForBinSize, bool cdfBoundaries);
diff --git a/roofit/roofitcore/inc/RooHistFunc.h b/roofit/roofitcore/inc/RooHistFunc.h
index 8819001aaac78..bca1650eda05d 100644
--- a/roofit/roofitcore/inc/RooHistFunc.h
+++ b/roofit/roofitcore/inc/RooHistFunc.h
@@ -109,6 +109,7 @@ class RooHistFunc : public RooAbsReal {
   double evaluate() const override;
   void doEval(RooFit::EvalContext &) const override;
   friend class RooAbsCachedReal ;
+  friend class RooWorkspace ;
 
   void ioStreamerPass2() override ;
 
diff --git a/roofit/roofitcore/inc/RooHistPdf.h b/roofit/roofitcore/inc/RooHistPdf.h
index bd50afddf3d58..42cffc2bf951e 100644
--- a/roofit/roofitcore/inc/RooHistPdf.h
+++ b/roofit/roofitcore/inc/RooHistPdf.h
@@ -104,9 +104,11 @@ class RooHistPdf : public RooAbsPdf {
 
   double evaluate() const override;
   double totalVolume() const ;
-  friend class RooAbsCachedPdf ;
   double totVolume() const ;
 
+  friend class RooAbsCachedPdf ;
+  friend class RooWorkspace ;
+
   RooArgSet _histObsList; ///< List of observables defining dimensions of histogram
   RooSetProxy _pdfObsList; ///< List of observables mapped onto histogram observables
   RooDataHist* _dataHist = nullptr; ///< Unowned pointer to underlying histogram
diff --git a/roofit/roofitcore/inc/RooWorkspace.h b/roofit/roofitcore/inc/RooWorkspace.h
index dc7f044737191..7a64010246169 100644
--- a/roofit/roofitcore/inc/RooWorkspace.h
+++ b/roofit/roofitcore/inc/RooWorkspace.h
@@ -36,6 +36,8 @@ class RooAbsReal ;
 class RooAbsCategory ;
 class RooFactoryWSTool ;
 class RooAbsStudy ;
+class RooDataHist ;
+class RooAbsBinning ;
 
 #include "TNamed.h"
 #include "TDirectoryFile.h"
@@ -165,6 +167,24 @@ class RooWorkspace : public TNamed {
 
   RooExpensiveObjectCache& expensiveObjectCache() { return _eocache ; }
 
+  /// Internal class that can pack all the information in an embedded
+  /// RooDataHist for smaller workspace sizes on disk.
+  struct EmbeddedHisto : public TObject {
+    std::string name;    ///< Name of the packed RooDataHist.
+    std::string title;   ///< Title of the packed RooDataHist.
+    std::string argName; ///< Name of the workspace node that uses the histogram.
+
+    int arraySize = 0;                 ///< Number of entries in each non-empty array below.
+    std::vector<double> weightArray;   ///< Copy of the histogram's weight array.
+    std::vector<double> wgtErrLoArray; ///< Copy of the low weight-error array (empty if there is none).
+    std::vector<double> wgtErrHiArray; ///< Copy of the high weight-error array (empty if there is none).
+    std::vector<double> sumW2Array;    ///< Copy of the sum-of-weights-squared array (empty if there is none).
+
+    std::vector<std::unique_ptr<RooAbsBinning>> binnings; ///< Clones of the histogram's binnings; entries may be null.
+
+    ClassDefOverride(EmbeddedHisto, 1);
+  };
+
   class CodeRepo : public TObject {
   public:
     CodeRepo(RooWorkspace* wspace=nullptr) : _wspace(wspace), _compiledOK(true) {} ;
@@ -245,8 +265,15 @@ class RooWorkspace : public TNamed {
   friend class RooAbsArg;
   friend class RooAbsPdf;
   friend class RooConstraintSum;
+
   bool defineSetInternal(const char *name, const RooArgSet &aset);
 
+  /// Copy name, title, arrays and binnings of `dataHist` (used by node `arg`) into `packed`.
+  void packEmbeddedHisto(RooAbsArg const &arg, RooDataHist const &dataHist, EmbeddedHisto &packed);
+
+  /// Build a new RooDataHist over a snapshot of `vars` from `packed`; the caller owns the result.
+  RooDataHist *unpackEmbeddedHisto(RooArgSet const &vars, EmbeddedHisto const &packed);
+
   friend class CodeRepo;
   static std::list _classDeclDirList;
   static std::list _classImplDirList;
diff --git a/roofit/roofitcore/src/RooWorkspace.cxx b/roofit/roofitcore/src/RooWorkspace.cxx
index 866e982891b25..70c69ac615442 100644
--- 
a/roofit/roofitcore/src/RooWorkspace.cxx
+++ b/roofit/roofitcore/src/RooWorkspace.cxx
@@ -54,6 +54,8 @@ and try reading again.
 #include
 #include
 #include
+#include
+#include
 #include
 #include
 #include
@@ -2432,6 +2434,112 @@ void RooWorkspace::CodeRepo::Streamer(TBuffer &R__b)
    }
 }
 
+////////////////////////////////////////////////////////////////////////////////
+/// Copy everything needed to rebuild the template histogram `dataHist` used by
+/// the node `arg` into the compact `packed` representation.
+///
+/// \param[in] arg Node that uses the histogram; only its name is stored.
+/// \param[in] dataHist Histogram whose name, title, arrays and binnings are copied.
+/// \param[out] packed Object that receives the copies.
+/// \throws std::runtime_error if the histogram has asymmetric weight errors or
+///         global observables, which are not copied into `packed`.
+void RooWorkspace::packEmbeddedHisto(RooAbsArg const &arg, RooDataHist const &dataHist,
+                                     RooWorkspace::EmbeddedHisto &packed)
+{
+   // We error out if the dataHist is not as we expect it to be for a template.
+   // Which is in case of:
+   //   * asymmetric uncertainties
+   //   * global observables
+   if (dataHist.wgtErrLoArray() || dataHist.wgtErrHiArray() || dataHist.getGlobalObservables()) {
+      throw std::runtime_error("Template histogram for \"" + std::string{arg.GetName()} + "\" can't be serialized!");
+   }
+
+   packed.argName = arg.GetName();
+   packed.name = dataHist.GetName();
+   packed.title = dataHist.GetTitle();
+
+   packed.arraySize = dataHist.arraySize();
+
+   // Copy an optional array of arraySize entries; a null array leaves the vector as it is.
+   auto fillVector = [&](double const *arr, std::vector<double> &vec) {
+      if (arr) {
+         vec.assign(arr, arr + packed.arraySize);
+      }
+   };
+
+   fillVector(dataHist.weightArray(), packed.weightArray);
+   fillVector(dataHist.sumW2Array(), packed.sumW2Array);
+   // The asymmetric error arrays are known to be null here (checked above), so
+   // wgtErrLoArray and wgtErrHiArray are intentionally not filled.
+
+   auto const &binnings = dataHist.getBinnings();
+   packed.binnings.reserve(packed.binnings.size() + binnings.size());
+   for (auto const &binning : binnings) {
+      packed.binnings.emplace_back(binning ? binning->clone() : nullptr);
+   }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+/// Rebuild a RooDataHist from its packed representation.
+///
+/// \param[in] vars Observables of the histogram; they are snapshotted and the
+///            stored binnings are applied to the copies, not to `vars` itself.
+/// \param[in] packed Packed histogram, as filled by packEmbeddedHisto().
+/// \return Newly allocated RooDataHist; the caller takes ownership.
+/// \throws std::runtime_error if the sizes stored in `packed` are inconsistent
+///         with each other, with `vars`, or with the rebuilt histogram.
+RooDataHist *RooWorkspace::unpackEmbeddedHisto(RooArgSet const &vars, RooWorkspace::EmbeddedHisto const &packed)
+{
+   const std::string errPrefix = "Embedded histogram \"" + packed.name + "\" can't be unpacked: ";
+
+   // The packed arrays are indexed below without further checks, so make sure
+   // up front that their sizes are consistent.
+   if (packed.binnings.size() != vars.size()) {
+      throw std::runtime_error(errPrefix + "number of binnings doesn't match number of observables!");
+   }
+   if (packed.arraySize < 0 || packed.weightArray.size() != static_cast<std::size_t>(packed.arraySize)) {
+      throw std::runtime_error(errPrefix + "weight array has unexpected size!");
+   }
+   if (!packed.sumW2Array.empty() && packed.sumW2Array.size() != packed.weightArray.size()) {
+      throw std::runtime_error(errPrefix + "sum of weights squared array has unexpected size!");
+   }
+
+   RooArgSet varsCopy;
+   vars.snapshot(varsCopy);
+
+   for (std::size_t i = 0; i < vars.size(); ++i) {
+      // A stored binning is only applied to observables that really are RooRealVars.
+      auto *realVar = dynamic_cast<RooRealVar *>(varsCopy[i]);
+      if (packed.binnings[i] && realVar) {
+         realVar->setBinning(*packed.binnings[i]);
+      }
+   }
+
+   auto *dataHist = new RooDataHist{packed.name, packed.title, varsCopy};
+
+   int n = packed.arraySize;
+
+   // The loops below address bins 0..n-1 of the new histogram, so it must have
+   // exactly that many bins.
+   if (dataHist->arraySize() != n) {
+      delete dataHist;
+      throw std::runtime_error(errPrefix + "binning doesn't match stored array size!");
+   }
+
+   for (int i = 0; i < n; ++i) {
+      dataHist->set(i, packed.weightArray[i], -1);
+   }
+
+   // Set the sum of weights squared
+   if (!packed.sumW2Array.empty()) {
+      dataHist->_sumw2 = new double[n];
+      std::copy(packed.sumW2Array.begin(), packed.sumW2Array.end(), dataHist->_sumw2);
+   }
+
+   dataHist->registerWeightArraysToDataStore();
+
+   return dataHist;
+}
 
 ////////////////////////////////////////////////////////////////////////////////
 /// Stream an object of class RooWorkspace. This is a standard ROOT streamer for the
@@ -2475,6 +2583,35 @@ void RooWorkspace::Streamer(TBuffer &R__b)
          }
       }
 
+      for (auto *packed : dynamic_range_cast<RooWorkspace::EmbeddedHisto *>(_embeddedDataList)) {
+         // If the type was not RooWorkspace::EmbeddedHisto, it was an old workspace
+         // where the RooDataHists were stored directly. So nothing to do.
+         if (packed) {
+            RooAbsArg *arg = &_allOwnedNodes[packed->argName.c_str()];
+
+            auto *histPdf = dynamic_cast<RooHistPdf *>(arg);
+            auto *histFunc = dynamic_cast<RooHistFunc *>(arg);
+
+            // Without this check, the lookup of the variables below would
+            // dereference a null pointer for any other node type.
+            if (!histPdf && !histFunc) {
+               throw std::runtime_error("Embedded histogram for \"" + packed->argName + "\" can't be restored!");
+            }
+
+            RooArgSet const &vars = histPdf ? histPdf->variables() : histFunc->variables();
+
+            RooDataHist *dataHist = unpackEmbeddedHisto(vars, *packed);
+
+            if (histPdf)
+               histPdf->_dataHist = dataHist;
+            if (histFunc)
+               histFunc->_dataHist = dataHist;
+
+            _embeddedDataList.Replace(packed, dataHist);
+            delete packed;
+         }
+      }
+
    } else {
 
       // Make lists of external clients of WS objects, and remove those links temporarily
@@ -2531,8 +2668,52 @@ void RooWorkspace::Streamer(TBuffer &R__b)
          }
       }
 
+      // Temporary container to hold converted embedded datasets during serialization
+      RooLinkedList embeddedDataList;
+
+      // Pack the embedded RooDataHists
+      std::vector<std::pair<RooAbsArg *, RooDataHist *>> associated;
+      for (RooAbsArg *arg : _allOwnedNodes) {
+         auto *histPdf = dynamic_cast<RooHistPdf *>(arg);
+         auto *histFunc = dynamic_cast<RooHistFunc *>(arg);
+         if (!histPdf && !histFunc) continue;
+         RooDataHist &dataHist = histPdf ? histPdf->dataHist() : histFunc->dataHist();
+
+         // We have to temporarily nullify the dataHists that the
+         // HistFuncs/Pdfs point to, so they don't get serialized.
+         // Remember the association and reset later.
+         if(histPdf) histPdf->_dataHist = nullptr;
+         if(histFunc) histFunc->_dataHist = nullptr;
+         associated.emplace_back(arg, &dataHist);
+
+         auto *dataHistPacked = new RooWorkspace::EmbeddedHisto;
+         packEmbeddedHisto(*arg, dataHist, *dataHistPacked);
+         embeddedDataList.Add(dataHistPacked);
+      }
+
+      // Validate that we have now packed all the embedded RooDataHists
+      if (embeddedDataList.size() != _embeddedDataList.size()) {
+         throw std::runtime_error("There were unexpected embedded datasets!");
+      }
+
+      // Temporarily replace _embeddedDataList with the serialized version for writing
+      std::swap(embeddedDataList, _embeddedDataList);
+
       R__b.WriteClassBuffer(RooWorkspace::Class(), this);
 
+      // Restore original _embeddedDataList after serialization is complete
+      std::swap(embeddedDataList, _embeddedDataList);
+
+      embeddedDataList.Delete();
+
+      // Reset the histFuncs to the RooDataHists/Pdfs
+      for(auto const &item : associated) {
+         auto *histPdf = dynamic_cast<RooHistPdf *>(item.first);
+         auto *histFunc = dynamic_cast<RooHistFunc *>(item.first);
+         if(histPdf) histPdf->_dataHist = item.second;
+         if(histFunc) histFunc->_dataHist = item.second;
+      }
+
       // Reinstate clients here
       for (auto &iterx : extClients) {