1import pymc3 as pm
2from .. import types
3
# Short module-level aliases for the PyMC3 variational-inference building
# blocks subclassed / referenced by the classes below.
Operator, Inference, MeanField = (
    pm.operators.Operator,
    pm.Inference,
    pm.MeanField,
)
7
8
class KLThermal(Operator):
    """Kullback-Leibler divergence operator with finite temperature.

    The objective built by ``apply`` is::

        temperature * logq_norm(z) - logp_norm(z)

    so with ``temperature == 1`` it reduces to ``logq_norm(z) - logp_norm(z)``.
    """
    def __init__(self,
                 approx: pm.approximations.Approximation,
                 temperature: types.TensorSharedVariable):
        """Initializer.

        Args:
            approx: an instance of PyMC3 approximation
            temperature: a scalar shared theano tensor variable; it is stored
                as-is and multiplied into the symbolic objective in ``apply``

        Raises:
            ValueError: if ``temperature`` is None.
        """
        # Validate explicitly (not with ``assert``, which ``python -O`` strips),
        # and do it before initializing the base operator so we fail fast.
        if temperature is None:
            raise ValueError("Temperature (a scalar theano shared tensor) is not provided")
        super().__init__(approx)
        self.temperature = temperature

    def apply(self, f):
        """Build the temperature-scaled KL objective.

        Args:
            f: not used here; accepted for compatibility with the base-class
                ``apply`` signature (presumably the test function — confirm
                against the PyMC3 version in use)

        Returns:
            symbolic expression ``temperature * logq_norm(z) - logp_norm(z)``
            with ``z = self.input``.
        """
        z = self.input
        return self.temperature * self.logq_norm(z) - self.logp_norm(z)
27
28
class ADVIDeterministicAnnealing(Inference):
    """ADVI with deterministic annealing functionality.

    Combines the ``KLThermal`` operator with a ``MeanField`` approximation, so
    the optimized objective is ``temperature * logq - logp`` instead of the
    plain KL divergence.

    Note:
        Temperature is not updated automatically by this class. This task is delegated to the ADVI step
        function. This can be done by including a temperature update in `more_updates`; refer to
        `pymc3.opvi.ObjectiveFunction.step_function` for more information.

    """
    def __init__(self,
                 local_rv=None,
                 model=None,
                 cost_part_grad_scale=1,
                 scale_cost_to_minibatch=False,
                 random_seed=None, start=None,
                 temperature=None):
        """Initializer.

        Args:
            local_rv: forwarded unchanged to ``pymc3.Inference``
            model: forwarded unchanged to ``pymc3.Inference``
            cost_part_grad_scale: forwarded unchanged to ``pymc3.Inference``
            scale_cost_to_minibatch: forwarded unchanged to ``pymc3.Inference``
            random_seed: forwarded unchanged to ``pymc3.Inference``
            start: forwarded unchanged to ``pymc3.Inference``
            temperature: a scalar theano shared tensor, passed to ``KLThermal``
                via ``op_kwargs``. Despite the ``None`` default (kept for
                keyword-argument compatibility) it is required.

        Raises:
            ValueError: if ``temperature`` is None.
        """
        # Explicit check instead of ``assert`` so validation survives ``python -O``.
        if temperature is None:
            raise ValueError("Temperature (a scalar theano shared tensor) is not provided")
        super().__init__(
            # NOTE(review): the positional ``None`` presumably fills the
            # test-function slot of ``pymc3.Inference`` — confirm for the
            # PyMC3 version in use.
            KLThermal, MeanField, None,
            local_rv=local_rv,
            model=model,
            cost_part_grad_scale=cost_part_grad_scale,
            scale_cost_to_minibatch=scale_cost_to_minibatch,
            random_seed=random_seed,
            start=start,
            op_kwargs={'temperature': temperature})
56