# Source code for cbp.graph.base_graph

import warnings
from functools import partial

import numpy as np
from cbp.node import VarNode
from cbp.utils import (Message, diff_max_marginals,
                       engine_loop)
from cbp.utils.np_utils import (nd_expand, nd_multiexpand,
                                reduction_ndarray)
from cbp.configs.base_config import baseconfig
from .coef_policy import bp_policy
from .graph_utils import cal_marginal_from_tensor
try:
    import pygraphviz  # noqa
except BaseException:
    pygraphviz = None


class BaseGraph():  # pylint: disable=too-many-instance-attributes
    """Factor graph built from `~cbp.node.VarNode` and factor nodes.

    Nodes are registered under generated names (``VarNode_000``,
    ``FactorNode_000``, ...) and can be looked up through ``node_recorder``.
    Inference entry points in this class:

    * :meth:`pmf` / :meth:`exact_marginal` -- brute-force joint and marginals
    * :meth:`sinkhorn` -- Sinkhorn scaling towards constrained marginals
    * :meth:`tree_bp` -- classical belief propagation on a tree graph
    """

    def __init__(self, silent=True, epsilon=1, coef_policy=bp_policy,
                 config=baseconfig):
        """Create an empty graph.

        :param silent: forwarded as ``silent=`` to ``engine_loop`` by
            :meth:`sinkhorn`
        :param epsilon: stored as ``self.epsilon``; not read in this class
            (NOTE(review): presumably used by subclasses -- confirm)
        :param coef_policy: policy passed to ``node.auto_coef`` in
            :meth:`bake`
        :param config: stored as ``self.cfg``
        """
        self.varnode_recorder = {}
        self.constrained_names = []
        self.leaf_nodes = []
        self.factornode_recorder = {}
        self.node_recorder = {}
        self.epsilon = epsilon
        self.coef_policy = coef_policy
        self.cnt_varnode = 0
        self.cnt_factornode = 0
        self.cfg = config
        # debug utils
        self.silent = silent

    def add_varnode(self, node):
        """Add one `~cbp.node.VarNode` to this graph.

        Names follow the increasing order ``VarNode_000``, ``VarNode_001``,
        ...; the counter never decreases, so names stay unique even after
        :meth:`delete_node`.

        :param node: one VarNode
        :type node: VarNode
        :return: name of varnode
        :rtype: str
        """
        assert isinstance(node, VarNode)
        varnode_name = f"VarNode_{self.cnt_varnode:03d}"
        node.format_name(varnode_name)
        self.varnode_recorder[varnode_name] = node
        self.node_recorder[varnode_name] = node
        if node.isconstrained:
            self.constrained_names.append(varnode_name)
        self.cnt_varnode += 1
        return varnode_name

    def add_factornode(self, factornode):
        """Add one factor node to the graph.

        Does the following tasks:

        * check the potential against the registered variable nodes
          (via ``factornode.check_potential``)
        * add node to the recorders
        * set connections on the connected variable nodes
        * set parent relations

        :param factornode: one factor node
        :type factornode: FactorNode
        :return: name of factor node
        :rtype: str
        """
        factornode.check_potential(self.varnode_recorder)
        factornode_name = f"FactorNode_{self.cnt_factornode:03d}"
        factornode.format_name(factornode_name)
        self.factornode_recorder[factornode_name] = factornode
        self.node_recorder[factornode_name] = factornode
        self.__register_connection(factornode)
        self.__set_parent(factornode)
        self.cnt_factornode += 1
        return factornode_name

    def __register_connection(self, factornode):
        """Tell every variable node connected to ``factornode`` about it."""
        for varnode_name in factornode.get_connections():
            varnode = self.varnode_recorder[varnode_name]
            varnode.register_connection(factornode.name)

    def __set_parent(self, factornode):
        """The factor's parent is its first variable; the factor becomes the
        parent of every remaining connected variable."""
        connections = factornode.get_connections()
        factornode.parent = self.varnode_recorder[connections[0]]
        for varnode_name in connections[1:]:
            varnode = self.varnode_recorder[varnode_name]
            varnode.parent = factornode

    def pmf(self):
        """Output the probability mass tensor through brute force.

        The joint tensor has one axis per variable node (in insertion order
        of ``varnode_recorder``) and is the normalised product of all factor
        potentials expanded to that shape.

        :return: joint probability mass tensor
        :rtype: ndarray
        """
        varnode_names = list(self.varnode_recorder.keys())
        varnodes = list(self.varnode_recorder.values())
        var_dim = [variable.rv_dim for variable in varnodes]
        # the whole joint is materialised as one ndarray, whose rank numpy
        # caps; this check is kept conservative
        assert len(var_dim) < 32, \
            "max number of vars for brute_force is 31 (numpy matrix dim limit)"
        # name -> axis lookup, built once instead of list.index per factor
        axis_of = {name: axis for axis, name in enumerate(varnode_names)}
        joint_acc = np.ones(var_dim)
        for factor in self.factornode_recorder.values():
            which_dims = [axis_of[v] for v in factor.get_connections()]
            joint_acc *= nd_multiexpand(factor.potential, var_dim, which_dims)
        return joint_acc / np.sum(joint_acc)

    def exact_marginal(self):
        """Compute brute-force marginals and store each one on its variable
        node as ``node.bfmarginal``."""
        varnodes = list(self.varnode_recorder.values())
        prob_tensor = self.pmf()
        marginal_list = cal_marginal_from_tensor(prob_tensor, varnodes)
        for node, marginal in zip(varnodes, marginal_list):
            node.bfmarginal = marginal

    def first_belief_propagation(self):
        """Seed initial messages.

        For every node and every connected recipient whose name is not yet a
        key in the node's ``message_inbox``, build
        ``node.make_init_message(recipient_name)`` and store it in the
        recipient's inbox. Requires ``self.nodes`` (set by
        :meth:`init_node_recorder`).
        """
        for node in self.nodes:
            for recipient_name in node.connections:
                recipient = self.node_recorder[recipient_name]
                if recipient.name not in node.message_inbox:
                    val = node.make_init_message(recipient_name)
                    recipient.store_message(Message(node, val))

    def copy_bp_initialization(self, another_graph):
        """Copy message setup from another graph with the same structure.

        The inbox and latest message objects are shared by reference, not
        copied.

        :param another_graph: another graph which is close to the optimal
            point
        :type another_graph: BaseGraph
        :raises RuntimeError: if a node of this graph is missing from
            ``another_graph``
        """
        # TODO: copy safety!!! (messages are shared, not copied)
        for node in self.nodes:
            if node.name not in another_graph.node_recorder:
                # raising a bare str is a TypeError in Python 3
                raise RuntimeError(f"{node.name} not in another_graph")
            another_node = another_graph.node_recorder[node.name]
            node.message_inbox = another_node.message_inbox
            node.latest_message = another_node.latest_message

    def __init_sinkhorn_node(self):
        """Set up per-constrained-variable Sinkhorn state (axis index, target
        marginal ``mu``, scaling vector ``u``) and reset every variable's
        ``sinkhorn`` estimate to uniform."""
        varnode_names = list(self.varnode_recorder.keys())
        self.sinkhorn_node_coef = {}  # pylint: disable=attribute-defined-outside-init
        for node_name in self.constrained_names:
            node_instance = self.varnode_recorder[node_name]
            self.sinkhorn_node_coef[node_name] = {
                'index': varnode_names.index(node_name),
                'mu': node_instance.constrained_marginal,
                'u': np.ones(node_instance.rv_dim)
            }
        for node in self.varnode_recorder.values():
            node.sinkhorn = np.ones(node.rv_dim) / node.rv_dim

    def __build_big_u(self):
        """Return the product of all constrained ``u`` vectors, each expanded
        by ``nd_expand`` to the joint shape along its own axis, normalised to
        sum to one."""
        varnodes = list(self.varnode_recorder.values())
        var_dim = [variable.rv_dim for variable in varnodes]
        joint_acc = np.ones(var_dim)
        for recorder in self.sinkhorn_node_coef.values():
            constrained_acc = nd_expand(
                recorder['u'], tuple(var_dim), recorder['index'])
            joint_acc *= constrained_acc
        # NOTE(review): originally flagged "this is a bug". sinkhorn_update
        # renormalises big_u * tilde_c, so this scale cancels there;
        # confirm for any other caller.
        return joint_acc / np.sum(joint_acc)

    def sinkhorn_update(self, tilde_c):
        """Run one Sinkhorn sweep over all constrained variables.

        For each constrained variable, rescale its ``u`` by ``mu`` divided by
        the reduction of the normalised ``big_u * tilde_c`` along its axis
        (clipped at 1e-12 to avoid division by zero). Afterwards every
        variable's ``sinkhorn`` attribute is set from the last normalised
        tensor.

        :param tilde_c: joint tensor, as produced by :meth:`pmf`
        """
        # NOTE(review): assumes at least one constrained node; otherwise
        # normalized_denominator is unbound below (sinkhorn() checks this)
        for recorder in self.sinkhorn_node_coef.values():
            big_u = self.__build_big_u()
            weighted = big_u * tilde_c
            normalized_denominator = weighted / np.sum(weighted)
            copy_denominator = reduction_ndarray(
                normalized_denominator, recorder['index'])
            copy_denominator = np.clip(copy_denominator, 1e-12, None)
            recorder['u'] = recorder['u'] * recorder['mu'] / copy_denominator
        varnodes = list(self.varnode_recorder.values())
        marginal_list = cal_marginal_from_tensor(
            normalized_denominator, varnodes)
        for node, marginal in zip(varnodes, marginal_list):
            node.sinkhorn = marginal

    def __check_sinkhorn(self):
        """Raise RuntimeError when there is nothing to constrain."""
        if len(self.constrained_names) == 0:
            raise RuntimeError(
                "There is no constrained nodes, use brutal force")

    def sinkhorn(self, max_iter=5000000, tolerance=1e-5):
        """Run Sinkhorn scaling until ``engine_loop`` stops.

        :param max_iter: maximum number of sweeps
        :param tolerance: stopping tolerance on ``diff_max_marginals``
            between consecutive :meth:`export_sinkhorn` snapshots
        :raises RuntimeError: if the graph has no constrained node
        :return: whatever ``engine_loop`` returns
        """
        self.__check_sinkhorn()
        tilde_c = self.pmf()
        self.__init_sinkhorn_node()
        sinkhorn_func = partial(self.sinkhorn_update, tilde_c)
        return engine_loop(engine_fun=sinkhorn_func,
                           max_iter=max_iter,
                           tolerance=tolerance,
                           error_fun=diff_max_marginals,
                           meassure_fun=self.export_sinkhorn,
                           isoutput=False,
                           silent=self.silent)

    def bake(self):
        """Build ``self.nodes`` and compute node coefficients.

        The traversal starts from the last node (in ``self.nodes`` order)
        that has exactly one connection.

        :raises RuntimeError: if no node has exactly one connection
        """
        self.init_node_recorder()
        root = None
        for node in self.nodes:
            if len(node.connections) == 1:
                root = node
        if root is None:
            # previously surfaced as an UnboundLocalError on `root`
            raise RuntimeError(
                "graph has no leaf node (it may contain a circle)")
        self.__cal_node_coef(root)

    def __cal_node_coef(self, node):
        """Depth-first: compute coefficients of untraversed neighbours before
        calling ``node.auto_coef``; the traversal flag is reset on exit."""
        node.is_traversed = True
        for item in node.connections:
            if not self.node_recorder[item].is_traversed:
                self.__cal_node_coef(self.node_recorder[item])
        node.auto_coef(self.node_recorder, self.coef_policy)
        node.is_traversed = False

    def init_node_recorder(self):
        """Set ``self.nodes`` (factors first, then variables) and
        ``self.leaf_nodes`` (nodes with exactly one connection)."""
        factors = list(self.factornode_recorder.values())
        variables = list(self.varnode_recorder.values())
        # in Norm-Product, run factor message first
        self.nodes = factors + variables  # pylint: disable=attribute-defined-outside-init
        self.leaf_nodes = [
            node for node in self.nodes if len(node.get_connections()) == 1]

    def get_node(self, name_str):
        """Return the node registered under ``name_str``.

        :raises RuntimeError: if no such node exists
        """
        if name_str not in self.node_recorder:
            raise RuntimeError(f"{name_str} is illegal, not in this graph")
        return self.node_recorder[name_str]

    def cal_bethe(self, margin):
        """Calculate bethe energy as the sum of per-node ``cal_bethe`` terms.

        :param margin: node_name : margin
        :type margin: dict
        :return: KL divergence between expert joint dist and p_graph
        :rtype: float
        """
        return np.sum([node.cal_bethe(margin[node.name])
                       for node in self.nodes])

    def delete_node(self, name_str):
        """Delete a node from the graph.

        * judge type, Factor or Variable (deleting a variable only warns:
          connected factor potentials are not changed)
        * detach it from every connected node, clearing ``parent`` links
          that point at it
        * connected nodes left with no connections are deleted as well
        * delete from the various recorders

        :param name_str: name of the node to delete
        :type name_str: str
        :raises RuntimeError: if ``name_str`` is not in this graph
        """
        if name_str not in self.node_recorder:
            raise RuntimeError(f"{name_str} is illegal, not in this graph")
        target_node = self.node_recorder[name_str]
        if isinstance(target_node, VarNode):
            warnings.warn(f"Delete {name_str}, may have a suspend factor node")
        for connected_name in target_node.get_connections():
            connected_node = self.node_recorder[connected_name]
            if connected_node.parent is target_node:
                connected_node.parent = None
            connected_node.get_connections().remove(name_str)
            # clear map
            if len(connected_node.get_connections()) == 0:
                self.__delete_node_recorder(connected_node)
        self.__delete_node_recorder(target_node)

    def __delete_node_recorder(self, node):
        """Remove ``node`` from the type-specific recorder, ``node_recorder``
        and, if present, ``constrained_names``."""
        target_map = self.varnode_recorder if isinstance(
            node, VarNode) else self.factornode_recorder
        del target_map[node.name]
        del self.node_recorder[node.name]
        if node.name in self.constrained_names:
            self.constrained_names.remove(node.name)

    def set_node(self, node_name, potential=None, isconstrained=None):
        """Change node properties.

        1. check whether or not the node is in the recorder
        2. replace the potential if given [this may be a duplicate function]
        3. change ``isconstrained`` if given, keeping ``constrained_names``
           in sync

        :param node_name: name of the node to modify
        :type node_name: str
        :param potential: new potential, defaults to None (unchanged)
        :type potential: ndarray, optional
        :param isconstrained: new constrained flag, defaults to None
            (unchanged)
        :type isconstrained: bool, optional
        :raises RuntimeError: if ``node_name`` is not in this graph
        """
        if node_name not in self.node_recorder:
            raise RuntimeError(f"{node_name} is illegal, not in this graph")
        node = self.node_recorder[node_name]
        if potential is not None:
            # TODO: make potential property check, when do set
            node.potential = potential
        if isconstrained is not None and node.isconstrained != isconstrained:
            node.isconstrained = isconstrained
            if isconstrained:
                self.constrained_names.append(node_name)
            else:
                self.constrained_names.remove(node_name)

    def export_marginals(self):
        """Export the marginal for variable nodes.

        :return: {node.name: node.marginal()}
        :rtype: dict
        """
        return {
            n.name: n.marginal() for n in self.varnode_recorder.values()
        }

    def export_convergence_marginals(self):
        """Export the marginal for variable nodes and factor nodes.

        :return: {node.name: node.marginal()}
        :rtype: dict
        """
        return {n.name: n.marginal() for n in self.nodes}

    def export_sinkhorn(self):
        """Return {varnode_name: node.sinkhorn} for every variable node."""
        return {node_name: node.sinkhorn
                for node_name, node in self.varnode_recorder.items()}

    def plot(self, png_name='file.png'):
        """Plot the graph through graphviz.

        * red (filled): constrained variable node
        * blue (bold): free variable node
        * green: factor

        :param png_name: name of figure, defaults to 'file.png'
        :type png_name: str, optional
        :raises ValueError: if pygraphviz is not installed
        """
        if pygraphviz is None:
            raise ValueError("must have pygraphviz installed for visualization")
        graph = pygraphviz.AGraph(directed=False)
        for varnode_name in self.varnode_recorder:
            if varnode_name in self.constrained_names:
                graph.add_node(varnode_name, color='red', style='filled')
            else:
                graph.add_node(varnode_name, color='blue', style='bold')
        for name, factornode in self.factornode_recorder.items():
            graph.add_node(name, color='green')
            for varnode_name in factornode.get_connections():
                graph.add_edge(name, varnode_name)
        graph.layout(prog='neato')
        graph.draw(png_name)

    def tree_bp(self):
        """Run classical belief propagation on a tree graph (one forward
        pass towards a leaf root, then one backward pass).

        Uses the node attribute ``is_send_forward``: False before the forward
        pass, True after forward and before backward, False after backward.

        :raises RuntimeError: when no leaf node exists to serve as root
            (only works for tree graphs; loopy graphs are not supported)
        """
        assert len(self.constrained_names) == 0
        self.bake()
        self.first_belief_propagation()
        for node in self.nodes:
            node.is_send_forward = False

        tree_root = None
        for node in self.nodes:
            if len(node.connections) == 1:
                tree_root = node
        if tree_root is None:
            raise RuntimeError("graph contains circle")

        self._send_forward(tree_root)
        self._send_backward(tree_root)

    def _send_forward(self, node):
        """Recursively collect messages from the leaves towards ``node``."""
        node.is_send_forward = True
        for cur_node in node.connected_nodes.values():
            if not cur_node.is_send_forward:
                self._send_forward(cur_node)
                cur_node.send_message(node)

    def _send_backward(self, node):
        """Recursively distribute messages from ``node`` back to the leaves."""
        node.is_send_forward = False
        for cur_node in node.connected_nodes.values():
            if cur_node.is_send_forward:
                node.send_message(cur_node)
                self._send_backward(cur_node)