import numpy as np
from cbp.utils.np_utils import nd_expand
from .base_node import BaseNode
class FactorNode(BaseNode):
    """Factor Node in factor graph

    Add new attr:

    * ``isconstrained`` Fixed marginal or not
    * ``hat_c_ialpha`` See Norm-Product paper
    * ``last_innerparenthese_msg`` See Norm-Product paper

    Also kept per instance: ``connections`` (names of the connected variable
    nodes, one per potential axis) and ``i_alpha`` (per-edge coefficient,
    ``None`` until assigned).
    """

    def __init__(self, connections, potential, coef=1):
        """Create a factor attached to ``connections``.

        :param connections: names of the connected variable nodes. The text
            after the first 8 characters of each name is parsed as an integer
            and these integers must be in non-decreasing order.
            (presumably names look like ``VarNode_003`` -- TODO confirm)
        :param potential: factor potential, forwarded to ``BaseNode``
        :param coef: node coefficient, forwarded to ``BaseNode``
        :raises RuntimeError: if the connections are not given in order
        :raises ValueError: if a name's suffix is not an integer
        """
        super().__init__(coef, potential)
        self.connections = connections
        self.last_innerparenthese_msg = {}
        self.hat_c_ialpha = {}
        self.i_alpha = {}

        num_connectednode = []
        for item in self.connections:
            # Edge coefficients start unset; see ``auto_coef``/``set_i_alpha``
            self.i_alpha[item] = None
            num_connectednode.append(int(item[8:]))

        # The i-th connection is matched against the i-th potential axis, so
        # the caller is required to list the connections in index order.
        if any(i > j for i, j in zip(num_connectednode, num_connectednode[1:])):
            raise RuntimeError('Set the connection of factor in order')

    def check_before_run(self, node_map):
        """Run the base-class checks, then validate the potential's shape.

        :param node_map: mapping from node name to node object
        """
        super().check_before_run(node_map)
        self.check_potential(node_map)

    def check_potential(self, node_map):
        """Check each potential axis against its variable node's ``rv_dim``.

        As a side effect, ``last_innerparenthese_msg`` is reset to an array of
        ones with the potential's shape for every connected variable.

        :param node_map: mapping from node name to node object
        :raises AssertionError: if an axis length differs from ``rv_dim``
        """
        for i, varnode_name in enumerate(self.connections):
            varnode = node_map[varnode_name]
            # Implicit concatenation keeps source indentation out of the text
            assert self.potential.shape[i] == varnode.rv_dim, (
                f"Dimention mismatch! At {i:02d} axis in Factor:{self.name} "
                f"rv_dim:{varnode.rv_dim:02d}, "
                f"potential: {self.potential.shape[i]}"
            )
            self.last_innerparenthese_msg[varnode_name] = np.ones(
                self.potential.shape)

    def _check_potential(self, potential):
        """Return ``potential`` rescaled so that its entries sum to one.

        NOTE(review): presumably a hook invoked by ``BaseNode`` when the
        potential is assigned -- confirm against the base class.
        """
        return potential / np.sum(potential)

    def auto_coef(self, node_map, assign_policy=None):
        """Fill in a missing edge coefficient after the base-class pass.

        The unset edge receives ``1 - node_coef - sum(other i_alpha)``.

        NOTE(review): if several edges are unset, only the last one in
        ``connections`` is assigned; the others stay ``None``.

        :param node_map: mapping from node name to node object
        :param assign_policy: forwarded unchanged to ``BaseNode.auto_coef``
        """
        super().auto_coef(node_map, assign_policy)
        sum_i_alpha = 0
        unset_edge = None
        for connected_var in self.connections:
            i_alpha = self.get_i_alpha(connected_var)
            if i_alpha is not None:
                sum_i_alpha += i_alpha
            else:
                unset_edge = connected_var

        # Compare against the sentinel explicitly instead of by truthiness
        if unset_edge is not None:
            new_i_alpha = 1 - self.node_coef - sum_i_alpha
            self.set_i_alpha(unset_edge, new_i_alpha)

    def get_i_alpha(self, connection_name):
        """Return the edge coefficient stored for ``connection_name``.

        :raises KeyError: if ``connection_name`` is not a connection
        """
        return self.i_alpha[connection_name]

    def set_i_alpha(self, connection_name, value):
        """Store ``value`` as the edge coefficient for ``connection_name``."""
        self.i_alpha[connection_name] = value

    def cal_cnp_coef(self):
        """Compute ``hat_c_ialpha = node_coef + i_alpha`` for every edge.

        ``coef_ready`` is cleared first and only set once all edges have been
        processed, so a failure part-way through never leaves the node marked
        ready with an incomplete coefficient table.

        :raises AssertionError: if a resulting coefficient is zero
        """
        self.coef_ready = False
        for item in self.connections:
            hat_c_ialpha = self.node_coef + self.i_alpha[item]
            # The value is used as a divisor in ``make_message``
            assert hat_c_ialpha != 0
            self.hat_c_ialpha[item] = hat_c_ialpha
        self.coef_ready = True

    def get_hat_c_ialpha(self, node_name):
        """Return ``hat_c_ialpha`` for ``node_name``.

        :return: the coefficient, or ``None`` while ``coef_ready`` is false
        """
        if self.coef_ready:
            return self.hat_c_ialpha[node_name]
        return None

    def make_message(self, recipient_node):
        """Build the message sent from this factor to ``recipient_node``.

        With a single connection the potential itself is recorded as the
        inner-parentheses term and summed onto the recipient's axis.
        Otherwise the inner-parentheses product is raised to
        ``1 / hat_c_ialpha``, summed over every other axis, and the result is
        raised to ``hat_c_ialpha``.

        :param recipient_node: connected variable node (needs ``name`` and
            ``rv_dim``)
        :return: array indexed by the recipient's states
        :raises AssertionError: if the recipient is not connected
        :raises FloatingPointError: on a NumPy divide error in the powered path
        """
        assert recipient_node.name in self.connections
        if len(self.connections) == 1:
            self.last_innerparenthese_msg[recipient_node.name] = self.potential
            return self.summation(self.potential, recipient_node)

        product_out = self.cal_inner_parentheses(recipient_node)
        with np.errstate(divide='raise'):
            hat_c_ialpha = self.hat_c_ialpha[recipient_node.name]
            # Power computed as exp(log(x) / c); the clip avoids log(0)
            log_media = 1.0 / hat_c_ialpha * \
                np.log(np.clip(product_out, 1e-12, None))
            product_out_power = np.exp(log_media)
            return np.power(
                self.summation(
                    product_out_power,
                    recipient_node),
                hat_c_ialpha)

    def cal_bethe(self, margin):
        """Return ``sum(margin * log(margin / potential))``.

        Both arguments of the logarithm are floored at ``1e-12`` so entries
        with zero marginal contribute 0 instead of ``0 * -inf = NaN``.

        :param margin: array with the same shape as the potential
        :return: scalar value of the sum
        """
        clip_potential = np.clip(self.potential, 1e-12, None)
        clip_margin = np.clip(margin, 1e-12, None)
        return np.sum(margin * np.log(clip_margin / clip_potential))

    def marginal(self):
        """Return the normalized factor marginal.

        The potential is multiplied by all entries of ``latest_message``,
        raised to ``1 / node_coef`` and rescaled to sum to one.
        """
        message_val = np.array([message.val for message in self.latest_message])
        prod_messages = np.prod(message_val, axis=0)
        product_out = np.multiply(self.potential, prod_messages)
        unormalized = np.power(product_out, 1.0 / self.node_coef)
        return unormalized / np.sum(unormalized)

    def cal_inner_parentheses(self, recipient_node):
        """Multiply the potential by every message not sent by the recipient.

        The product is also cached in ``last_innerparenthese_msg`` under the
        recipient's name.

        :param recipient_node: node whose own message is left out
        :return: the product array
        """
        latest_message = self.latest_message
        filtered_message = [message for message in latest_message
                            if not message.sender.name == recipient_node.name]
        message_val = np.array([message.val for message in filtered_message])
        prod_messages = np.prod(message_val, axis=0)
        product_out = np.multiply(self.potential, prod_messages)
        self.last_innerparenthese_msg[recipient_node.name] = product_out
        return product_out

    def store_message(self, message):
        """Store ``message`` after checking its shape against the potential.

        :param message: incoming message exposing ``val`` and ``sender``
        :raises AssertionError: if ``message.val`` has a different shape
        """
        # Implicit concatenation keeps source indentation out of the text
        assert message.val.shape == self.potential.shape, (
            f"From {message.sender.name} to {self.name} shape mismatch, "
            f"expected {self.potential.shape}, received {message.val.shape}"
        )
        super().store_message(message)

    def summation(self, potential, node):
        """Sum ``potential`` over every axis except the one owned by ``node``.

        :param potential: array with one axis per connection
        :param node: connected variable node (needs ``name`` and ``rv_dim``)
        :return: 1-D array of length ``node.rv_dim``
        :raises ValueError: if ``node.name`` is not in ``connections``
        :raises AssertionError: if the node's axis length is not ``rv_dim``
        """
        potential_dim = potential.shape
        node_index = self.connections.index(node.name)
        assert potential_dim[node_index] == node.rv_dim
        return potential.sum(
            tuple(j for j in range(potential.ndim) if j != node_index))