"""Source code for ``cbp.node.var_node``: the variable node of the factor graph."""

import numpy as np
from cbp.utils.np_utils import nd_expand

from .base_node import BaseNode


class VarNode(BaseNode):
    """Variable Node in Factor graph.

    Adds the following attributes on top of ``BaseNode``:

    * ``isconstrained`` -- fixed marginal or not (``True`` when a
      ``constrained_marginal`` was supplied).
    * ``constrained_marginal`` -- the fixed marginal, clipped to be
      ``>= 1e-12``, or ``None`` when the node is unconstrained.
    * ``hat_c_i`` -- see the Norm-Product paper; filled in by
      :meth:`cal_cnp_coef` as ``node_coef`` plus the ``node_coef`` of every
      connected node (``None`` until then).

    Args:
        rv_dim: Number of states of this random variable.
        potential: Optional unary potential of length ``rv_dim``. ``None``
            means an all-ones potential (see ``_check_potential``, which
            ``BaseNode`` presumably calls -- TODO confirm).
        constrained_marginal: Optional fixed marginal of length ``rv_dim``
            that must sum to 1 (within ``1e-6``). Lists/tuples are accepted
            as well as ndarrays.
        node_coef: Node coefficient, forwarded to ``BaseNode.__init__``.
    """

    def __init__(self, rv_dim, potential=None, constrained_marginal=None,
                 node_coef=1):
        # rv_dim is set before BaseNode.__init__ because _check_potential
        # reads it (assumed to be invoked from the base initialiser).
        self.rv_dim = rv_dim
        self.hat_c_i = None
        super().__init__(node_coef, potential)
        if constrained_marginal is None:
            self.isconstrained = False
        else:
            # Accept any array-like; ndarrays pass through unchanged.
            constrained_marginal = np.asarray(constrained_marginal)
            assert constrained_marginal.shape[0] == rv_dim
            assert abs(np.sum(constrained_marginal) - 1) < 1e-6
            self.isconstrained = True
            # Keep entries strictly positive so np.log in message passing
            # (run under errstate(divide='raise')) never sees an exact 0.
            constrained_marginal = np.clip(constrained_marginal, 1e-12, None)
        self.constrained_marginal = constrained_marginal

    def _check_potential(self, potential):
        """Validate a unary potential and return the array to store.

        Args:
            potential: ``None`` or an array-like of length ``rv_dim``.

        Returns:
            ``np.ones(rv_dim)`` (not normalised) when ``potential`` is
            ``None``; otherwise the potential clipped to ``>= 1e-12`` and
            normalised to sum to 1.
        """
        if potential is None:
            return np.ones([self.rv_dim])
        potential = np.asarray(potential)
        assert potential.shape[0] == self.rv_dim
        final_potential = np.clip(potential, 1e-12, None)
        return final_potential / np.sum(final_potential)

    def auto_coef(self, node_map, assign_policy=None):
        """Assign coefficients, then fill in one unset ``i_alpha`` edge.

        After ``BaseNode.auto_coef`` runs, every connected node is asked for
        its ``i_alpha`` toward this node. If some edge is still unset, it is
        given ``node_coef + len(connections) - 1 - sum(set i_alpha)``, so the
        edge values sum to ``node_coef + len(connections) - 1``. If several
        edges are unset, only the last one visited is assigned.

        Args:
            node_map: Forwarded to ``BaseNode.auto_coef``.
            assign_policy: Forwarded to ``BaseNode.auto_coef``.
        """
        super().auto_coef(node_map, assign_policy)
        sum_i_alpha = 0
        unset_edge = None
        for item in self.connected_nodes.values():
            i_alpha = item.get_i_alpha(self.name)
            if i_alpha is not None:
                sum_i_alpha += i_alpha
            else:
                unset_edge = item.name
        # Compare with None explicitly: a falsy node name is still a name.
        if unset_edge is not None:
            new_i_alpha = self.node_coef - \
                (1 - len(self.connections)) - sum_i_alpha
            self.connected_nodes[unset_edge].set_i_alpha(self.name, new_i_alpha)

    def cal_cnp_coef(self):
        """Compute ``hat_c_i`` and mark the coefficients as ready.

        ``hat_c_i = node_coef + sum(node_coef of each connected node)``.
        ``coef_ready`` is only set once ``hat_c_i`` is fully computed, so a
        failed lookup cannot leave a half-built coefficient flagged ready.
        """
        hat_c_i = self.node_coef
        for item in self.connections:
            hat_c_i += self.connected_nodes[item].node_coef
        self.hat_c_i = hat_c_i
        self.coef_ready = True

    def _make_message_first_term(self, recipient_node):
        """Return the per-state factor of the message to ``recipient_node``.

        In log space::

            log_base = c_alpha * (N - log(m_recipient) / hat_c_ialpha)

        where ``m_recipient`` is the latest message received from the
        recipient (clipped to ``>= 1e-12``), and ``N`` is
        ``epsilon * log(constrained_marginal)`` for a constrained node, or
        ``(log(potential) + log(prod of all latest messages)) / hat_c_i``
        otherwise. The product includes the recipient's own message.

        Returns:
            ``np.exp(log_base)``, an array of length ``rv_dim``.

        Raises:
            FloatingPointError: if any ``np.log`` argument is exactly 0
                (``np.errstate(divide='raise')``).
        """
        recipient_index_in_var = self.search_msg_index(self.latest_message,
                                                       recipient_node.name)
        hat_c_ialpha = recipient_node.get_hat_c_ialpha(self.name)
        c_alpha = recipient_node.node_coef
        vals = [message.val for message in self.latest_message]
        with np.errstate(divide='raise'):
            if self.isconstrained:
                # self.epsilon is assumed to come from BaseNode -- confirm.
                log_numerator = self.epsilon * np.log(self.constrained_marginal)
            else:
                potential_part = 1.0 / self.hat_c_i * np.log(self.potential)
                message_part = 1.0 / self.hat_c_i * \
                    np.log(np.clip(np.prod(vals, axis=0), 1e-12, None))
                log_numerator = potential_part + message_part
            clip_base = np.clip(vals[recipient_index_in_var], 1e-12, None)
            log_denominator = 1.0 / hat_c_ialpha * np.log(clip_base)
            log_base = c_alpha * (log_numerator - log_denominator)
        return np.exp(log_base)

    def make_message_bp(self, recipient_node):
        """Build the message from this variable to ``recipient_node``.

        The per-state first term (length ``rv_dim``) is expanded with
        ``nd_expand`` along this variable's axis in the recipient. It is then
        multiplied element-wise by the recipient's extra term, whose shape
        must equal the recipient's potential shape.

        Requires :meth:`cal_cnp_coef` to have been called first.

        Returns:
            An array with the same shape as the recipient's potential
            (assuming ``nd_expand`` broadcasts to ``second_term.shape``).
        """
        assert self.coef_ready,\
            f"{self.name} need to cal_cnp_coef by graph firstly"
        # first_term.shape equals (self.rv_dim,)
        first_term = self._make_message_first_term(recipient_node)
        assert first_term.shape[0] == self.rv_dim
        # second_term shape equals shape of recipient_node
        second_term = recipient_node.get_varnode_extra_term(self.name)
        assert second_term.shape == \
            self.connected_nodes[recipient_node.name].potential.shape
        var_index_in_recipient = recipient_node.search_node_index(self.name)
        expanded_first_term = nd_expand(
            first_term, second_term.shape, var_index_in_recipient)
        return np.multiply(expanded_first_term, second_term)

    def make_message(self, recipient_node):
        """Build the outgoing message; delegates to :meth:`make_message_bp`."""
        return self.make_message_bp(recipient_node)

    def cal_bethe(self, margin):
        """Return this variable's term of the Bethe energy for ``margin``.

        Computes ``-(node_degree - 1) * sum(b * log b) - sum(b * log psi)``.
        Here ``b`` is ``margin`` and ``psi`` is the unary potential. Both
        are clipped to ``>= 1e-12`` inside the logarithms only.
        """
        clip_margin = np.clip(margin, 1e-12, None)
        log_margin = np.log(clip_margin)
        entropy_term = -(self.node_degree - 1) * np.sum(margin * log_margin)
        clip_potential = np.clip(self.potential, 1e-12, None)
        potential_term = -np.sum(margin * np.log(clip_potential))
        return potential_term + entropy_term

    def marginal(self):
        """Return the normalised belief over this variable's states.

        * Constrained node: the stored ``constrained_marginal``.
        * With messages: ``(potential * prod(latest messages)) ** (1 /
          hat_c_i)``, normalised to sum to 1.
        * Otherwise: the uniform distribution over ``rv_dim`` states.
        """
        if self.isconstrained:
            return self.constrained_marginal
        if self.message_inbox:
            vals = [message.val for message in self.latest_message]
            # Evaluate in log space: the raw product of many messages (then
            # raised to 1 / hat_c_i) can underflow every entry to 0, which
            # made the final normalisation 0 / 0 = NaN. Exact zeros still map
            # to -inf -> probability 0, as before.
            with np.errstate(divide='ignore'):
                log_prod = np.log(self.potential) + \
                    np.sum(np.log(vals), axis=0)
            log_belief = log_prod / self.hat_c_i
            # Shifting by the max is cancelled by the normalisation below.
            log_belief = log_belief - np.max(log_belief)
            belief = np.exp(log_belief)
            return belief / np.sum(belief)
        return np.ones(self.rv_dim) / self.rv_dim