-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgp_vi.py
More file actions
50 lines (39 loc) · 2.44 KB
/
gp_vi.py
File metadata and controls
50 lines (39 loc) · 2.44 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from typing import List
import jax
from upix.core import *
from upix.infer import VIDCC, Guide, PredicateSelector, MeanfieldNormalGuide, InferenceResult, LogWeightEstimateFromADVI, LogWeightEstimate
from gp import *
from successive_halving import *
class VIConfig(VIDCC):
    """VIDCC variant whose set of active SLPs is pruned by successive halving.

    The ``SuccessiveHalving`` helper is read from
    ``self.config["successive_halving"]`` and decides three things: how many
    phases are run, how many ADVI optimisation steps each phase gets, and
    which SLPs remain active after a phase.
    """

    def __init__(self, model: Model, *ignore, verbose=0, **config_kwargs) -> None:
        """Forward all arguments to ``VIDCC`` and pick up the halving helper.

        Raises:
            KeyError: if ``self.config`` has no ``"successive_halving"`` entry.
        """
        super().__init__(model, *ignore, verbose=verbose, **config_kwargs)
        # NOTE(review): self.config is presumably filled from config_kwargs by
        # the parent constructor - confirm against VIDCC.
        self.successive_halving: SuccessiveHalving = self.config["successive_halving"]

    def get_guide(self, slp: SLP) -> Guide:
        """Return a mean-field normal guide for ``slp``.

        The guide is restricted to addresses that do not end in
        ``"node_type"``.
        """

        def not_node_type(addr):
            return not addr.endswith("node_type")

        # NOTE(review): the meaning of the trailing 0.1 is defined by
        # MeanfieldNormalGuide, not by this file - confirm before changing.
        return MeanfieldNormalGuide(slp, PredicateSelector(not_node_type), 0.1)
        # return FullRankNormalGuide(slp, selector, 0.1)

    def initialise_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], rng_key: jax.Array):
        """Run the parent initialisation, then derive the halving schedule.

        Sets ``self.n_phases`` and ``self.advi_n_iter`` from the number of
        active SLPs left after the parent call and logs both values.
        """
        super().initialise_active_slps(active_slps, inactive_slps, rng_key)

        halving = self.successive_halving
        n_active = len(active_slps)
        self.n_phases = halving.calculate_num_phases(n_active)
        self.advi_n_iter = halving.calculate_num_optimization_steps(n_active)
        # self.advi_n_iter = 1000
        # self.n_phases = 1
        # The f-string below uses the `=` specifier, so the expression text
        # itself is part of the logged message and must stay as written.
        tqdm.write(f"{len(active_slps)=} {self.n_phases=} {self.advi_n_iter=}")

    def update_active_slps(self, active_slps: List[SLP], inactive_slps: List[SLP], inference_results: Dict[SLP, List[InferenceResult]], log_weight_estimates: Dict[SLP, List[LogWeightEstimate]], rng_key: PRNGKey):
        """Retire the current active SLPs and re-activate the selected ones.

        Both list arguments are mutated in place: ``inactive_slps`` is
        replaced by the previously active SLPs and ``active_slps`` is
        rebuilt from whatever ``SuccessiveHalving.select_active_slps``
        returns. ``inference_results`` and ``rng_key`` are not used here.

        Once ``self.iteration_counter`` equals ``self.n_phases`` the method
        returns with ``active_slps`` left empty.
        """
        # Move every active SLP over to the inactive list.
        inactive_slps[:] = active_slps
        del active_slps[:]

        if self.iteration_counter == self.n_phases:
            # Final phase reached: nothing gets re-activated.
            # NOTE(review): presumably an empty active list is what stops the
            # outer loop - confirm against VIDCC.
            return

        # Pair each retired SLP with its most recent log-weight estimate.
        scored: List[Tuple[SLP, float]] = []
        for candidate in inactive_slps:
            newest = log_weight_estimates[candidate][-1]
            assert isinstance(newest, LogWeightEstimateFromADVI)
            scored.append((candidate, newest.get_estimate().item()))

        for kept, weight in self.successive_halving.select_active_slps(scored):
            tqdm.write(f"Keep {kept.formatted()} with {weight}")
            active_slps.append(kept)

        # Fewer SLPs remain, so recompute the per-phase optimisation budget.
        self.advi_n_iter = self.successive_halving.calculate_num_optimization_steps(len(active_slps))
        tqdm.write(f"update active slps {len(active_slps)=} {self.advi_n_iter=}")
        # TODO: estimate_path_log_prob as in SMC?