Skip to content

[RF] Compress the embedded RooDataHists in workspaces when serializing #19459

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions roofit/roofitcore/inc/LinkDef.h
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@
#pragma link C++ class RooWorkspace- ;
#pragma link C++ class RooWorkspace::CodeRepo- ;
#pragma link C++ class RooWorkspace::WSDir+ ;
#pragma link C++ class RooWorkspace::EmbeddedHisto+ ;
#pragma link C++ class RooWorkspaceHandle+;
#pragma link C++ class std::list<TObject*>+ ;
#pragma link C++ class std::list<RooAbsData*>+ ;
Expand Down
2 changes: 2 additions & 0 deletions roofit/roofitcore/inc/RooDataHist.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,8 @@ class RooDataHist : public RooAbsData, public RooDirItem {
mutable double _cache_sum{0.}; ///<! Cache for sum of entries ;

private:
friend class RooWorkspace ;

void interpolateQuadratic(double* output, std::span<const double> xVals, bool correctForBinSize, bool cdfBoundaries);
void interpolateLinear(double* output, std::span<const double> xVals, bool correctForBinSize, bool cdfBoundaries);
double weightInterpolated(const RooArgSet& bin, int intOrder, bool correctForBinSize, bool cdfBoundaries);
Expand Down
1 change: 1 addition & 0 deletions roofit/roofitcore/inc/RooHistFunc.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class RooHistFunc : public RooAbsReal {
double evaluate() const override;
void doEval(RooFit::EvalContext &) const override;
friend class RooAbsCachedReal ;
friend class RooWorkspace ;

void ioStreamerPass2() override ;

Expand Down
4 changes: 3 additions & 1 deletion roofit/roofitcore/inc/RooHistPdf.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,11 @@ class RooHistPdf : public RooAbsPdf {

double evaluate() const override;
double totalVolume() const ;
friend class RooAbsCachedPdf ;
double totVolume() const ;

friend class RooAbsCachedPdf ;
friend class RooWorkspace ;

RooArgSet _histObsList; ///< List of observables defining dimensions of histogram
RooSetProxy _pdfObsList; ///< List of observables mapped onto histogram observables
RooDataHist* _dataHist = nullptr; ///< Unowned pointer to underlying histogram
Expand Down
25 changes: 25 additions & 0 deletions roofit/roofitcore/inc/RooWorkspace.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class RooAbsReal ;
class RooAbsCategory ;
class RooFactoryWSTool ;
class RooAbsStudy ;
class RooDataHist ;
class RooAbsBinning ;

#include "TNamed.h"
#include "TDirectoryFile.h"
Expand Down Expand Up @@ -165,6 +167,24 @@ class RooWorkspace : public TNamed {

RooExpensiveObjectCache& expensiveObjectCache() { return _eocache ; }

/// Internal class that can pack all the information in an embedded
/// RooDataHist for smaller workspace sizes on disk.
struct EmbeddedHisto : public TObject {
std::string name;
std::string title;
std::string argName;

int arraySize;
std::vector<double> weightArray;
std::vector<double> wgtErrLoArray;
std::vector<double> wgtErrHiArray;
std::vector<double> sumW2Array;

std::vector<std::unique_ptr<RooAbsBinning>> binnings;

ClassDefOverride(EmbeddedHisto, 1);
};

class CodeRepo : public TObject {
public:
CodeRepo(RooWorkspace* wspace=nullptr) : _wspace(wspace), _compiledOK(true) {} ;
Expand Down Expand Up @@ -245,8 +265,13 @@ class RooWorkspace : public TNamed {
friend class RooAbsArg;
friend class RooAbsPdf;
friend class RooConstraintSum;

bool defineSetInternal(const char *name, const RooArgSet &aset);

void packEmbeddedHisto(RooAbsArg const &arg, RooDataHist const &dataHist, EmbeddedHisto &packed);

RooDataHist *unpackEmbeddedHisto(RooArgSet const &vars, EmbeddedHisto const &packed);

friend class CodeRepo;
static std::list<std::string> _classDeclDirList;
static std::list<std::string> _classImplDirList;
Expand Down
132 changes: 132 additions & 0 deletions roofit/roofitcore/src/RooWorkspace.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ and try reading again.
#include <RooCmdConfig.h>
#include <RooConstVar.h>
#include <RooFactoryWSTool.h>
#include <RooHistFunc.h>
#include <RooHistPdf.h>
#include <RooLinkedListIter.h>
#include <RooMsgService.h>
#include <RooPlot.h>
Expand Down Expand Up @@ -2432,6 +2434,69 @@ void RooWorkspace::CodeRepo::Streamer(TBuffer &R__b)
}
}

void RooWorkspace::packEmbeddedHisto(RooAbsArg const &arg, RooDataHist const &dataHist,
RooWorkspace::EmbeddedHisto &packed)
{
// We error out if the dataHist is not as we expect it to be for a template.
// Which is in case of:
// * asymmetric uncertainties
// * global observables
if (dataHist.wgtErrLoArray() || dataHist.wgtErrHiArray() || dataHist.getGlobalObservables()) {
throw std::runtime_error("Template histogram for \"" + std::string{arg.GetName()} + "\" can't be serialized!");
}

packed.argName = arg.GetName();
packed.name = dataHist.GetName();
packed.title = dataHist.GetTitle();

packed.arraySize = dataHist.arraySize();

auto fillVector = [&](double const *arr, std::vector<double> &vec) {
if (arr) {
vec.resize(packed.arraySize);
std::copy(arr, arr + vec.size(), vec.begin());
}
};

fillVector(dataHist.weightArray(), packed.weightArray);
fillVector(dataHist.sumW2Array(), packed.sumW2Array);
fillVector(dataHist.wgtErrLoArray(), packed.wgtErrLoArray);
fillVector(dataHist.wgtErrHiArray(), packed.wgtErrHiArray);

for (auto const &binning : dataHist.getBinnings()) {
packed.binnings.emplace_back(binning ? binning->clone() : nullptr);
}
}

RooDataHist *RooWorkspace::unpackEmbeddedHisto(RooArgSet const &vars, RooWorkspace::EmbeddedHisto const &packed)
{
RooArgSet varsCopy;
vars.snapshot(varsCopy);

for (std::size_t i = 0; i < vars.size(); ++i) {
if (packed.binnings[i]) {
static_cast<RooRealVar *>(varsCopy[i])->setBinning(*packed.binnings[i]);
}
}

auto *dataHist = new RooDataHist{packed.name, packed.title, varsCopy};

int n = packed.arraySize;

for (int i = 0; i < n; ++i) {
dataHist->set(i, packed.weightArray[i], -1);
}

// Set the sum of weights squared
if (!packed.sumW2Array.empty()) {
dataHist->_sumw2 = new double[n];
std::copy(packed.sumW2Array.begin(), packed.sumW2Array.end(), dataHist->_sumw2);
}

dataHist->registerWeightArraysToDataStore();

return dataHist;
}

////////////////////////////////////////////////////////////////////////////////
/// Stream an object of class RooWorkspace. This is a standard ROOT streamer for the
Expand Down Expand Up @@ -2475,6 +2540,29 @@ void RooWorkspace::Streamer(TBuffer &R__b)
}
}

for (auto *packed : dynamic_range_cast<RooWorkspace::EmbeddedHisto const*>(_embeddedDataList)) {
// If the type was not RooWorkspace::EmbeddedHisto, it was an old workspace
// where the RooDataHists were stored directly. So nothing to do.
if (packed) {
RooAbsArg *arg = &_allOwnedNodes[packed->argName.c_str()];

auto *histPdf = dynamic_cast<RooHistPdf *>(arg);
auto *histFunc = dynamic_cast<RooHistFunc *>(arg);

RooArgSet const &vars = histPdf ? histPdf->variables() : histFunc->variables();

RooDataHist *dataHist = unpackEmbeddedHisto(vars, *packed);

if (histPdf)
histPdf->_dataHist = dataHist;
if (histFunc)
histFunc->_dataHist = dataHist;

_embeddedDataList.Replace(packed, dataHist);
delete packed;
}
}

} else {

// Make lists of external clients of WS objects, and remove those links temporarily
Expand Down Expand Up @@ -2531,8 +2619,52 @@ void RooWorkspace::Streamer(TBuffer &R__b)
}
}

// Temporary container to hold converted embedded datasets during serialization
RooLinkedList embeddedDataList;

// Pack the embedded RooDataHists
std::vector<std::pair<RooAbsArg*, RooDataHist*>> associated;
for (RooAbsArg *arg : _allOwnedNodes) {
auto *histPdf = dynamic_cast<RooHistPdf *>(arg);
auto *histFunc = dynamic_cast<RooHistFunc *>(arg);
if (!histPdf && !histFunc) continue;
RooDataHist &dataHist = histPdf ? histPdf->dataHist() : histFunc->dataHist();

// We have to temporarily nullify the dataHists that the
// HistFuncs/Pdfs point to, so they don't get serialized.
// Remember the association and reset later.
if(histPdf) histPdf->_dataHist = nullptr;
if(histFunc) histFunc->_dataHist = nullptr;
associated.emplace_back(arg, &dataHist);

auto *dataHistPacked = new RooWorkspace::EmbeddedHisto;
packEmbeddedHisto(*arg, dataHist, *dataHistPacked);
embeddedDataList.Add(dataHistPacked);
}

// Validate that we have now packed all the embedded RooDataHists
if (embeddedDataList.size() != _embeddedDataList.size()) {
throw std::runtime_error("There were unexpected embedded datasets!");
}

// Temporarily replace _embeddedDataList with the serialized version for writing
std::swap(embeddedDataList, _embeddedDataList);

R__b.WriteClassBuffer(RooWorkspace::Class(), this);

// Restore original _embeddedDataList after serialization is complete
std::swap(embeddedDataList, _embeddedDataList);

embeddedDataList.Delete();

// Reset the histFuncs to the RooDataHists/Pdfs
for(auto const &item : associated) {
auto *histPdf = dynamic_cast<RooHistPdf *>(item.first);
auto *histFunc = dynamic_cast<RooHistFunc *>(item.first);
if(histPdf) histPdf->_dataHist = item.second;
if(histFunc) histFunc->_dataHist = item.second;
}

// Reinstate clients here

for (auto &iterx : extClients) {
Expand Down
Loading