Skip to content

Commit 7606daa

Browse files
committed
maybe this works for staterror
1 parent ca9849d commit 7606daa

File tree

3 files changed

+27
-6
lines changed

3 files changed

+27
-6
lines changed

src/pyhf/modifiers/staterror.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,22 +6,25 @@
66
from pyhf.exceptions import InvalidModifier
77
from pyhf.parameters import ParamViewer
88
from pyhf.tensor.manager import get_backend
9+
from typing import Optional
910

1011
log = logging.getLogger(__name__)
1112

1213

13-
def required_parset(sigmas, fixed: List[bool]):
14+
def required_parset(sigmas, fixed: List[bool], constraint: Optional[str] = "gaussian"):
1415
n_parameters = len(sigmas)
1516
return {
16-
'paramset_type': 'constrained_by_normal',
17+
'paramset_type': 'constrained_by_normal'
18+
if constraint == "gaussian"
19+
else 'constrained_by_poisson',
1720
'n_parameters': n_parameters,
1821
'is_shared': True,
1922
'is_scalar': False,
2023
'inits': (1.0,) * n_parameters,
2124
'bounds': ((1e-10, 10.0),) * n_parameters,
2225
'fixed': tuple(fixed),
2326
'auxdata': (1.0,) * n_parameters,
24-
'sigmas': tuple(sigmas),
27+
'sigmas' if constraint == "gaussian" else 'factors': tuple(sigmas),
2528
}
2629

2730

@@ -36,11 +39,12 @@ def __init__(self, config):
3639
def collect(self, thismod, nom):
3740
uncrt = thismod['data'] if thismod else [0.0] * len(nom)
3841
mask = [True if thismod else False] * len(nom)
39-
return {'mask': mask, 'nom_data': nom, 'uncrt': uncrt}
42+
constraint = thismod.get('constraint', 'gaussian') if thismod else 'gaussian'
43+
return {'mask': mask, 'nom_data': nom, 'uncrt': uncrt, 'constraint': constraint}
4044

4145
def append(self, key, channel, sample, thismod, defined_samp):
4246
self.builder_data.setdefault(key, {}).setdefault(sample, {}).setdefault(
43-
'data', {'uncrt': [], 'nom_data': [], 'mask': []}
47+
'data', {'uncrt': [], 'nom_data': [], 'mask': [], 'constraint': []}
4448
)
4549
nom = (
4650
defined_samp['data']
@@ -51,6 +55,9 @@ def append(self, key, channel, sample, thismod, defined_samp):
5155
self.builder_data[key][sample]['data']['mask'].append(moddata['mask'])
5256
self.builder_data[key][sample]['data']['uncrt'].append(moddata['uncrt'])
5357
self.builder_data[key][sample]['data']['nom_data'].append(moddata['nom_data'])
58+
self.builder_data[key][sample]['data']['constraint'].append(
59+
moddata['constraint']
60+
)
5461

5562
def finalize(self):
5663
default_backend = pyhf.default_backend
@@ -115,12 +122,22 @@ def finalize(self):
115122

116123
for modifier_data in self.builder_data[modname].values():
117124
modifier_data['data']['mask'] = masks[modname]
125+
118126
sigmas = relerrs[masks[modname]]
119127
# list of bools, consistent with other modifiers (no numpy.bool_)
120128
fixed = default_backend.tolist(sigmas == 0)
121129
# ensures non-Nan constraint term, but in a future PR we need to remove constraints for these
122130
sigmas[fixed] = 1.0
123-
self.required_parsets.setdefault(parname, [required_parset(sigmas, fixed)])
131+
132+
constraint = [
133+
i
134+
for i, v in zip(modifier_data['data']['constraint'], masks[modname])
135+
if v
136+
]
137+
assert all(constraint[0] == element for element in constraint)
138+
self.required_parsets.setdefault(
139+
parname, [required_parset(sigmas, fixed, constraint)]
140+
)
124141
return self.builder_data
125142

126143

src/pyhf/pdf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def _finalize_parameters_specs(user_parameters, _paramsets_requirements):
3535
f"Multiple parameter configurations for {parameter['name']} were found."
3636
)
3737
_paramsets_user_configs[parameter.get('name')] = parameter
38+
3839
_reqs = reduce_paramsets_requirements(
3940
_paramsets_requirements, _paramsets_user_configs
4041
)

src/pyhf/schema/loader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,6 @@ def load_schema(schema_id: str):
3939
schema = json.load(json_schema)
4040
variables.SCHEMA_CACHE[schema['$id']] = schema
4141
return variables.SCHEMA_CACHE[schema['$id']]
42+
43+
44+
load_schema(f'{variables.SCHEMA_VERSION}/defs.json')

0 commit comments

Comments
 (0)