1+ from typing import Sequence
2+
3+ from pymc import STEP_METHODS
4+ from pytensor .tensor .random .type import RandomGeneratorType
5+
6+ from pytensor .compile .builders import OpFromGraph
7+
8+ from pymc_experimental .sampling .mcmc import posterior_optimization_db
9+ from pymc_experimental .sampling .optimizations .conjugate_sampler import ConjugateRV , ConjugateRVSampler
10+
11+ STEP_METHODS .append (ConjugateRVSampler )
12+
13+ from pytensor .graph .fg import Output
14+ from pytensor .tensor .elemwise import DimShuffle
15+ from pymc .model .fgraph import model_free_rv , ModelValuedVar
16+
17+
18+ from pytensor .graph .basic import Variable
19+ from pytensor .graph .fg import FunctionGraph
20+ from pytensor .graph .rewriting .basic import node_rewriter
21+ from pymc .model .fgraph import ModelFreeRV
22+ from pymc .distributions import Beta , Binomial
23+ from pymc .pytensorf import collect_default_updates
24+
25+
def get_model_var_of_rv(fgraph: FunctionGraph, rv: Variable) -> Variable | None:
    """Return the Model dummy var that wraps the RV.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph whose ``clients`` mapping is searched.
    rv : Variable
        The RV whose wrapping model variable is wanted.

    Returns
    -------
    Variable | None
        The first output of the first client of `rv` whose op is a
        ``ModelValuedVar``, or ``None`` when `rv` has no such client.
    """
    for client, _ in fgraph.clients[rv]:
        if isinstance(client.op, ModelValuedVar):
            return client.outputs[0]
    # Callers test for ``None`` explicitly, so make the "not found" result visible.
    return None
31+
32+
def get_dist_params(rv: Variable) -> Sequence[Variable]:
    """Return the distribution parameters of the node that produced `rv`.

    The lookup is delegated to ``dist_params`` of the op that owns `rv`, which
    is called with the owner node itself.

    Parameters
    ----------
    rv : Variable
        A variable whose ``owner`` op exposes a ``dist_params(node)`` method.

    Returns
    -------
    Sequence[Variable]
        Whatever the op's ``dist_params`` returns for the owner node; callers
        in this module unpack it positionally (e.g. ``a, b = ...``).
    """
    node = rv.owner
    return node.op.dist_params(node)
35+
36+
def rv_used_by(fgraph: FunctionGraph, rv: Variable, used_by_type: type, used_as_arg_idx: int | Sequence[int], strict: bool = True) -> list[Variable]:
    """Collect the variables produced by `used_by_type` operations that consume `rv`.

    `rv` counts as consumed either when it is a direct input of such an
    operation, or when it first goes through a left-expanding ``DimShuffle``
    whose output is then an input of such an operation.

    Parameters
    ----------
    fgraph : FunctionGraph
        The function graph containing the RVs
    rv : Variable
        The RV to check for uses.
    used_by_type : type
        The type of operation that may use the RV.
    used_as_arg_idx : int | Sequence[int]
        The input position(s) at which the RV must appear in the operation.
    strict : bool, default=True
        If True, return no results when the RV is used in an unrecognized way.

    Returns
    -------
    list[Variable]
        The default output of every matching operation, in the order the
        clients are visited. Empty when nothing matches, or when `strict` is
        set and any unrecognized use is found.
    """
    if isinstance(used_as_arg_idx, int):
        accepted_positions = (used_as_arg_idx,)
    else:
        accepted_positions = used_as_arg_idx

    def is_expected_use(apply_node, position):
        # The op must be of the requested type AND take the variable at an accepted position.
        return isinstance(apply_node.op, used_by_type) and position in accepted_positions

    client_map = fgraph.clients
    consumers: list[Variable] = []
    for apply_node, position in client_map[rv]:
        op = apply_node.op

        # Being a graph output is not a "use" we care about.
        if isinstance(op, Output):
            continue

        # Direct use by the requested operation type.
        if is_expected_use(apply_node, position):
            consumers.append(apply_node.default_output())
            continue

        # Use through a left-expanding DimShuffle: look one level further down.
        if isinstance(op, DimShuffle) and op.is_left_expand_dims:
            for inner_node, inner_position in client_map[apply_node.outputs[0]]:
                if is_expected_use(inner_node, inner_position):
                    consumers.append(inner_node.default_output())
                elif strict:
                    # Unrecognized use of the expanded variable: give up entirely.
                    return []
            continue

        # Any other kind of client is an unrecognized use.
        if strict:
            return []

    return consumers
82+
83+
def wrap_rv_and_conjugate_rv(fgraph: FunctionGraph, rv: Variable, conjugate_rv: Variable, inputs: Sequence[Variable]) -> Variable:
    """Bundle `rv` and its conjugate posterior RV into a single ConjugateRV node.

    The random number generators reported by ``collect_default_updates`` for
    `conjugate_rv` become extra inputs of the ConjugateRV op (and are added to
    `fgraph` as inputs when missing); their updated counterparts become extra
    outputs after `rv` and `conjugate_rv`.

    Returns
    -------
    Variable
        The first output of the applied ConjugateRV op, i.e. the one that
        corresponds to `rv` in the op's output list.
    """
    rng_updates = collect_default_updates(conjugate_rv, inputs=[rv, *inputs])
    # Split the {rng: next_rng} mapping into two parallel tuples.
    rngs, next_rngs = zip(*rng_updates.items())

    # Make sure every generator the op depends on is an input of the graph.
    for rng in rngs:
        if rng in fgraph.inputs:
            continue
        fgraph.add_input(rng)

    op_inputs = [rv, *inputs, *rngs]
    op_outputs = [rv, conjugate_rv, *next_rngs]
    conjugate_op = ConjugateRV(inputs=op_inputs, outputs=op_outputs)

    wrapped_outputs = conjugate_op(rv, *inputs, *rngs)
    return wrapped_outputs[0]
95+
96+
def create_untransformed_free_rv(fgraph: FunctionGraph, rv: Variable, name: str, dims: Sequence[str | Variable]) -> Variable:
    """Build a model FreeRV for `rv` that carries no transform.

    A fresh value variable of the same type as `rv` is created under `name`,
    added to `fgraph` as an input, and paired with `rv` through
    ``model_free_rv``. The resulting variable is given the same `name`.
    """
    value_var = rv.type(name=name)
    fgraph.add_input(value_var)

    # ``None`` is passed in the transform slot: the free RV is left untransformed.
    untransformed_rv = model_free_rv(rv, value_var, None, *dims)
    untransformed_rv.name = name
    return untransformed_rv
105+
106+
@node_rewriter(tracks=[ModelFreeRV])
def beta_binomial_conjugacy(fgraph: FunctionGraph, node):
    """Rewrite a Beta free RV used as the `p` of one Binomial into its conjugate form.

    For ``p ~ Beta(a, b)`` and ``y ~ Binomial(n, p)`` the posterior of ``p`` is
    ``Beta(a + sum(y), b + sum(n - y))``, where the sums run over the
    observations that share the same ``p``. This applies the equivalence
    (up to a normalizing constant) described in:

    https://mc-stan.org/docs/stan-users-guide/efficiency-tuning.html#exploiting-sufficient-statistics

    Returns
    -------
    list[Variable] | None
        A one-element list with the replacement free RV, or ``None`` when the
        rewrite does not apply (the RV is not a Beta, it is not used as `p` by
        exactly one Binomial, or that Binomial has no model variable).
    """
    [beta_free_rv] = node.outputs
    beta_rv, _, *beta_dims = node.inputs

    if not isinstance(beta_rv.owner.op, Beta):
        return None

    p_arg_idx = 3  # inputs to Binomial are (rng, size, n, p)
    binomial_rvs = rv_used_by(fgraph, beta_free_rv, Binomial, p_arg_idx)

    if len(binomial_rvs) != 1:
        # Question: Can we apply conjugacy when RV is used by more than one binomial?
        return None

    [binomial_rv] = binomial_rvs

    binomial_model_var = get_model_var_of_rv(fgraph, binomial_rv)
    if binomial_model_var is None:
        return None

    # We want to replace free_rv by ConjugateRV()->(free_rv, conjugate_posterior_rv)
    a, b = get_dist_params(beta_rv)
    n, _ = get_dist_params(binomial_rv)

    # Use value of y in new graph to avoid circularity
    y = binomial_model_var.owner.inputs[1]

    # Sufficient statistics of the likelihood: number of successes and failures.
    successes = y
    failures = n - y

    # When the Binomial has more (leading) dimensions than the Beta, several
    # observations share each `p`. Their counts must be summed BEFORE the prior
    # parameters are added; summing `a + y` instead would count the prior once
    # per observation (k * a + sum(y) rather than a + sum(y)).
    # NOTE(review): dimensions where both have the same rank but the Beta has
    # length 1 and the Binomial is longer are not reduced here — confirm
    # whether such models can reach this rewrite.
    extra_dims = tuple(range(binomial_rv.type.ndim - beta_rv.type.ndim))
    if extra_dims:
        successes = successes.sum(extra_dims)
        failures = failures.sum(extra_dims)

    conjugate_a = a + successes
    conjugate_b = b + failures
    conjugate_beta_rv = Beta.dist(conjugate_a, conjugate_b)

    new_beta_rv = wrap_rv_and_conjugate_rv(fgraph, beta_rv, conjugate_beta_rv, [a, b, n, y])
    new_beta_free_rv = create_untransformed_free_rv(fgraph, new_beta_rv, beta_free_rv.name, beta_dims)
    return [new_beta_free_rv]
150+
151+
# Register the Beta-Binomial rewrite in the posterior optimization database as
# an import-time side effect. It is registered under its own function name,
# with "conjugacy" passed as an extra positional argument
# (presumably a tag — confirm against the database's `register` signature).
posterior_optimization_db.register(
    beta_binomial_conjugacy.__name__,
    beta_binomial_conjugacy,
    "conjugacy"
)
0 commit comments