from cbp.utils import (compare_marginals, diff_max_marginals,
engine_loop)
from cbp.configs.base_config import baseconfig
from .base_graph import BaseGraph
from .coef_policy import bp_policy
from .graph_utils import itsbp_inner_loop, find_link
class GraphModel(BaseGraph):
    """Factor-graph model with norm-product and iterative-scaling inference.

    Builds on ``BaseGraph`` and adds two inference entry points:

    * :meth:`run_cnp` -- bakes the graph (including each node's
      ``cal_cnp_coef``) and runs :meth:`norm_product_bp`;
    * :meth:`run_bp` -- switches to the ``bp_policy`` coefficient policy,
      re-bakes and runs :meth:`itsbp` (Sinkhorn / iterative scaling).

    Both entry points return the ``(epsilons, step)`` pair produced by
    :meth:`engine_loop`.
    """

    def __init__(self, silent=True, epsilon=1,
                 coef_policy=bp_policy, config=baseconfig):
        """Create the model.

        :param silent: verbosity flag forwarded to ``BaseGraph``; it is read
            back as ``self.silent`` and passed to the engine loop and to the
            message-passing routines (presumably stored by ``BaseGraph``)
        :param epsilon: forwarded unchanged to ``BaseGraph`` (its meaning is
            defined there)
        :param coef_policy: coefficient policy forwarded to ``BaseGraph``;
            defaults to ``bp_policy``
        :param config: configuration forwarded to ``BaseGraph``; presumably
            exposed as ``self.cfg`` (used by :meth:`its_next_looplink`) --
            confirm in ``BaseGraph``
        """
        super().__init__(config=config, silent=silent,
                         epsilon=epsilon, coef_policy=coef_policy)
        # Index into ``self.leaf_nodes`` of the next ITS-BP target node.
        self.itsbp_outer_cnt = 0

    def bake(self):
        """Bake the graph, then compute the CNP coefficients of every node.

        Calls ``BaseGraph.bake`` first and afterwards ``cal_cnp_coef`` on
        each node in ``self.nodes``, in that order.
        """
        super().bake()
        for node in self.nodes:
            node.cal_cnp_coef()

    def run_cnp(self):
        """Bake the graph and run norm-product belief propagation.

        :return: ``(epsilons, step)`` as returned by :meth:`norm_product_bp`
        """
        self.bake()
        return self.norm_product_bp()

    def run_bp(self):
        """Run belief propagation with ``bp_policy`` via :meth:`itsbp`.

        Side effect: ``self.coef_policy`` is switched to ``bp_policy`` and
        remains so after the call.

        :return: ``(epsilons, step)`` as returned by :meth:`itsbp`
        """
        # Policies are callables, so compare them by identity.
        if self.coef_policy is not bp_policy:
            self.coef_policy = bp_policy
        self.bake()
        return self.itsbp()

    def norm_product_bp(self, max_iter=5000000, tolerance=1e-5,
                        error_fun=None):
        """Run norm-product BP by iterating :meth:`parallel_message`.

        Performs an initial ``first_belief_propagation`` pass, then drives
        :meth:`parallel_message` through :meth:`engine_loop`.

        :param max_iter: iteration cap forwarded to the engine loop
        :param tolerance: convergence tolerance forwarded to the engine loop
        :param error_fun: error measure forwarded to the engine loop;
            ``None`` selects ``diff_max_marginals``
        :return: ``(epsilons, step)`` from :meth:`engine_loop`
        """
        if error_fun is None:
            error_fun = diff_max_marginals
        self.first_belief_propagation()
        return self.engine_loop(
            max_iter=max_iter,
            engine_fun=self.parallel_message,
            tolerance=tolerance,
            error_fun=error_fun,
            isoutput=False)

    def engine_loop(  # pylint: disable= too-many-arguments
            self,
            engine_fun,
            max_iter=5000000,
            tolerance=1e-2,
            error_fun=None,
            isoutput=False):
        """Drive ``engine_fun`` with the generic ``cbp.utils.engine_loop``.

        ``self.export_convergence_marginals`` is supplied as the measure
        function.

        :param engine_fun: callable performing one update sweep (presumably
            invoked with no arguments by ``cbp.utils.engine_loop``)
        :param max_iter: iteration cap forwarded to the utility
        :param tolerance: convergence tolerance forwarded to the utility
        :param error_fun: error measure; ``None`` selects
            ``compare_marginals``
        :param isoutput: forwarded unchanged to the utility
        :return: ``(epsilons, step)``; the third value returned by
            ``cbp.utils.engine_loop`` is discarded
        """
        if error_fun is None:
            error_fun = compare_marginals
        # The bare name below is the module-level ``cbp.utils.engine_loop``
        # imported at the top of the file, not this method: class attributes
        # are not part of a method body's name-lookup scope.
        epsilons, step, _ = engine_loop(
            engine_fun=engine_fun,
            max_iter=max_iter,
            tolerance=tolerance,
            error_fun=error_fun,
            meassure_fun=self.export_convergence_marginals,
            isoutput=isoutput,
            silent=self.silent
        )
        return epsilons, step

    def itsbp(self, max_iter=5000000, tolerance=1e-4, error_fun=None):
        """Run Sinkhorn / iterative-scaling inference (ITS-BP).

        Performs an initial ``first_belief_propagation`` pass, then drives
        :meth:`itsbp_outer_loop` through :meth:`engine_loop`. The defaults
        reproduce the previous hard-coded settings.

        :param max_iter: iteration cap forwarded to the engine loop
        :param tolerance: convergence tolerance forwarded to the engine loop
        :param error_fun: error measure forwarded to the engine loop;
            ``None`` selects ``diff_max_marginals``
        :return: ``(epsilons, step)`` from :meth:`engine_loop`
        :rtype: tuple
        """
        if error_fun is None:
            error_fun = diff_max_marginals
        self.first_belief_propagation()
        return self.engine_loop(
            engine_fun=self.itsbp_outer_loop,
            max_iter=max_iter,
            tolerance=tolerance,
            error_fun=error_fun,
            isoutput=False)

    def its_next_looplink(self):
        """Pick the current ITS-BP target leaf and the link to the next leaf.

        The target is ``self.leaf_nodes[self.itsbp_outer_cnt]``; the counter
        is then advanced with ``self.cfg.itsbp_schedule``.

        :return: ``(target_node, link)`` where ``link`` is
            ``find_link(target_node, next_node)`` and ``next_node`` is the
            leaf following the target in cyclic order
        """
        target_node = self.leaf_nodes[self.itsbp_outer_cnt]
        next_node = self.leaf_nodes[
            (self.itsbp_outer_cnt + 1) % len(self.leaf_nodes)]
        # NOTE(review): next_node always follows cyclic order (cnt + 1); if
        # itsbp_schedule advances the counter differently, the returned link
        # may not end at the next target -- confirm this is intended.
        self.itsbp_outer_cnt = self.cfg.itsbp_schedule(
            self.itsbp_outer_cnt, self.leaf_nodes)
        return target_node, find_link(target_node, next_node)

    def itsbp_outer_loop(self):
        """Perform one ITS-BP sweep: one inner loop per leaf node."""
        for _ in range(len(self.leaf_nodes)):
            _, loop_link = self.its_next_looplink()
            itsbp_inner_loop(loop_link, self.silent)

    def parallel_message(self, run_constrained=True):
        """Perform one message-update sweep over all variable nodes.

        Every node in ``self.varnode_recorder`` first receives its incoming
        messages and then sends its outgoing ones.

        :param run_constrained: when ``False``, constrained variable nodes
            (``isconstrained`` true) only receive messages and do not send
            them out
        """
        for target_var in self.varnode_recorder.values():
            # sending in messages from factors
            target_var.sendin_message(self.silent)
            if run_constrained or not target_var.isconstrained:
                target_var.sendout_message(self.silent)