Source code for cbp.graph.graph_model

from cbp.utils import (compare_marginals, diff_max_marginals,
                       engine_loop)
from cbp.configs.base_config import baseconfig

from .base_graph import BaseGraph
from .coef_policy import bp_policy
from .graph_utils import itsbp_inner_loop, find_link


class GraphModel(BaseGraph):
    """Graph model driving norm-product and iterative-scaling propagation.

    Extends ``BaseGraph`` with two inference drivers:

    * :meth:`run_cnp` / :meth:`norm_product_bp` -- bake, seed messages with
      ``first_belief_propagation`` and repeat :meth:`parallel_message`.
    * :meth:`run_bp` / :meth:`itsbp` -- switch to ``bp_policy`` if needed,
      seed messages and repeat :meth:`itsbp_outer_loop`.

    Both drivers iterate through :meth:`engine_loop` and return
    ``(epsilons, step)``.
    """

    def __init__(self, silent=True, epsilon=1, coef_policy=bp_policy,
                 config=baseconfig):
        """Create the model; all arguments are forwarded to ``BaseGraph``.

        :param silent: forwarded to ``BaseGraph``; the methods below read
            ``self.silent`` and pass it to message-passing helpers
            (presumably set by ``BaseGraph`` -- confirm).
        :param epsilon: forwarded to ``BaseGraph``.
        :param coef_policy: forwarded to ``BaseGraph``; :meth:`run_bp`
            replaces it with ``bp_policy``.
        :param config: forwarded to ``BaseGraph``.
        """
        super().__init__(config=config, silent=silent, epsilon=epsilon,
                         coef_policy=coef_policy)
        # Initialised here but not updated by any method of this class.
        # NOTE(review): looks like an outer-iteration counter -- confirm use.
        self.itsbp_outer_cnt = 0

    def bake(self):
        """Bake the base graph, then call ``cal_cnp_coef`` on every node."""
        super().bake()
        for node in self.nodes:
            node.cal_cnp_coef()

    def run_cnp(self):
        """Bake the graph and run :meth:`norm_product_bp` with its defaults.

        :return: ``(epsilons, step)`` from :meth:`norm_product_bp`.
        """
        self.bake()
        return self.norm_product_bp()

    def run_bp(self):
        """Run :meth:`itsbp` under ``bp_policy``.

        The graph is re-baked only when the coefficient policy has to be
        switched to ``bp_policy``; otherwise the current state is reused.

        :return: ``(epsilons, step)`` from :meth:`itsbp`.
        """
        if self.coef_policy != bp_policy:  # pylint: disable=comparison-with-callable
            self.coef_policy = bp_policy
            self.bake()
        return self.itsbp()

    def norm_product_bp(self, max_iter=5000000, tolerance=1e-5,
                        error_fun=None):
        """Seed messages, then iterate :meth:`parallel_message`.

        :param max_iter: iteration cap passed to :meth:`engine_loop`.
        :param tolerance: convergence tolerance passed to :meth:`engine_loop`.
        :param error_fun: error measure passed to :meth:`engine_loop`;
            defaults to ``diff_max_marginals``.
        :return: ``(epsilons, step)`` from :meth:`engine_loop`.
        """
        if error_fun is None:
            error_fun = diff_max_marginals
        self.first_belief_propagation()
        return self.engine_loop(
            max_iter=max_iter,
            engine_fun=self.parallel_message,
            tolerance=tolerance,
            error_fun=error_fun,
            isoutput=False)

    def engine_loop(  # pylint: disable= too-many-arguments
            self,
            engine_fun,
            max_iter=5000000,
            tolerance=1e-2,
            error_fun=None,
            isoutput=False):
        """Run ``cbp.utils.engine_loop`` with this graph's marginal exporter.

        :param engine_fun: zero-argument callable performing one iteration.
        :param max_iter: iteration cap forwarded to ``cbp.utils.engine_loop``.
        :param tolerance: tolerance forwarded to ``cbp.utils.engine_loop``.
        :param error_fun: error measure forwarded to
            ``cbp.utils.engine_loop``; defaults to ``compare_marginals``.
        :param isoutput: forwarded to ``cbp.utils.engine_loop``.
        :return: ``(epsilons, step)``, i.e. the first two values returned by
            ``cbp.utils.engine_loop`` (its third value is discarded).
        """
        if error_fun is None:
            error_fun = compare_marginals
        # Inside this method the bare name refers to the imported
        # module-level function, not to this method.
        epsilons, step, _ = engine_loop(
            engine_fun=engine_fun,
            max_iter=max_iter,
            tolerance=tolerance,
            error_fun=error_fun,
            # keyword spelling is dictated by cbp.utils.engine_loop
            meassure_fun=self.export_convergence_marginals,
            isoutput=isoutput,
            silent=self.silent
        )
        return epsilons, step

    def itsbp(self, max_iter=5000000, tolerance=1e-4, error_fun=None):
        """Run Sinkhorn / iterative-scaling inference.

        Seeds messages with ``first_belief_propagation``, then repeats
        :meth:`itsbp_outer_loop` via :meth:`engine_loop`. The parameters
        mirror :meth:`norm_product_bp`. The defaults reproduce the previous
        hard-coded behaviour (cap 5000000, tolerance 1e-4,
        ``diff_max_marginals``).

        :param max_iter: iteration cap passed to :meth:`engine_loop`.
        :param tolerance: convergence tolerance passed to :meth:`engine_loop`.
        :param error_fun: error measure passed to :meth:`engine_loop`;
            defaults to ``diff_max_marginals``.
        :return: ``(epsilons, step)`` from :meth:`engine_loop`.
        """
        if error_fun is None:
            error_fun = diff_max_marginals
        self.first_belief_propagation()
        return self.engine_loop(
            self.itsbp_outer_loop,
            max_iter=max_iter,
            tolerance=tolerance,
            error_fun=error_fun,
            isoutput=False)

    def itsbp_outer_loop(self):
        """Perform one outer sweep of iterative scaling.

        Runs ``len(self.leaf_nodes)`` inner updates. Each one fetches the
        next loop link from ``its_next_looplink`` and passes it to
        ``itsbp_inner_loop``.
        """
        for _ in range(len(self.leaf_nodes)):
            _, loop_link = self.its_next_looplink()
            itsbp_inner_loop(loop_link, self.silent)

    def parallel_message(self, run_constrained=True):
        """Run one round of message passing over every variable node.

        :param run_constrained: when False, variable nodes whose
            ``isconstrained`` flag is set receive messages but do not send
            any out.
        """
        for target_var in self.varnode_recorder.values():
            # sending in messages from factors
            target_var.sendin_message(self.silent)
            if run_constrained or (not target_var.isconstrained):
                target_var.sendout_message(self.silent)