# Source code for cbp.node.base_node

import uuid
from abc import ABC, abstractmethod
import numpy as np
from cbp.utils.message import Message


class BaseNode(ABC):
    """Abstract base class that every kind of node must inherit
    (:class:`~cbp.node.BaseNode`).

    Subclasses must implement :meth:`_check_potential`,
    :meth:`make_message` and :meth:`cal_bethe`; :meth:`cal_cnp_coef`
    raises ``NotImplementedError`` here and is meant to be overridden.
    """
    # pylint: disable=too-many-instance-attributes

    def __init__(self, node_coef, potential) -> None:
        """Initialize the attributes shared by all nodes.

        Every node gets a ``name`` (a random ``uuid4`` string) that
        identifies it in a node map; :meth:`format_name` replaces it.

        :param node_coef: node coefficient used by the norm-product algorithm
        :type node_coef: float
        :param potential: node potential, validated/converted by
            :meth:`_check_potential` before being stored
        :type potential: ndarray or list
        """
        self.name = str(uuid.uuid4())
        self.node_coef = node_coef
        # The ``potential`` setter writes ``_potential``; declare it first.
        self._potential = None
        self.potential = potential
        self.epsilon = 1
        # Checked by make_init_message; presumably set True by a subclass's
        # cal_cnp_coef (per that method's error message) - confirm.
        self.coef_ready = False
        # Not used in this class; presumably maintained by traversal code
        # elsewhere - TODO confirm.
        self.is_traversed = False
        self.parent = None
        # Incremented once per register_connection call.
        self.node_degree = 0
        # Names of connected nodes, in registration order.
        self.connections = []
        # Most recent message from each sender, keyed by sender name.
        self.message_inbox = {}
        # List snapshot of message_inbox.values(), refreshed by store_message.
        self.latest_message = []
        # Connected node objects keyed by name, filled by register_nodes.
        self.connected_nodes = {}

    def __str__(self) -> str:
        return f"{self.name}"

    def __repr__(self) -> str:
        return self.__str__()

    @property
    def potential(self):
        """The node potential, as returned by :meth:`_check_potential`."""
        return self._potential

    @potential.setter
    def potential(self, potential):
        self._potential = self._check_potential(potential)

    @abstractmethod
    def _check_potential(self, potential) -> np.ndarray:
        """Validate a potential before it is stored on the node.

        :param potential: input potential
        :type potential: np.ndarray
        :return: the value to store as the node potential
        :rtype: np.ndarray
        """

    def format_name(self, name) -> None:
        """Replace the node's ``name`` with ``name``."""
        self.name = name

    def reset_node_coef(self, coef) -> None:
        """Set ``node_coef`` to ``coef``."""
        self.node_coef = coef

    def auto_coef(self, node_map, assign_policy=None) -> None:
        """Assign ``node_coef`` and register the connected nodes.

        Without ``assign_policy`` the coefficient is ``1 / len(node_map)``
        (an empty ``node_map`` raises ``ZeroDivisionError``); otherwise it
        is ``assign_policy(self, node_map)``.

        :param node_map: mapping from node name to node object
        :type node_map: dict
        :param assign_policy: optional callable ``(node, node_map) -> coef``
        :raises IOError: if a connection is missing from ``node_map``
            (raised by :meth:`register_nodes`)
        """
        if assign_policy is None:
            self.node_coef = 1.0 / len(node_map)
        else:
            self.node_coef = assign_policy(self, node_map)
        self.register_nodes(node_map)

    # TODO: should remove node_map parameter
    def cal_cnp_coef(self):
        """Compute the node coefficients; subclasses must override this.

        :raises NotImplementedError: always, in the base class
        """
        raise NotImplementedError(
            f"{self.__class__.__name__} is an abstract class")

    def check_before_run(self, node_map) -> None:
        """Check that every registered connection appears in ``node_map``.

        :param node_map: mapping from node name to node object
        :type node_map: dict
        :raises AssertionError: on the first connection name that is not a
            key of ``node_map``
        """
        for item in self.connections:
            if item not in node_map:
                # Explicit raise instead of ``assert`` so the check is not
                # stripped under ``python -O``; exception type is unchanged.
                raise AssertionError(
                    f"{self.name} has a connection {item}, "
                    "which is not in node_map")

    def make_init_message(self, recipient_node_name) -> np.ndarray:
        """Build an all-ones initial message for a connected node.

        :param recipient_node_name: name of a node in ``connected_nodes``
        :type recipient_node_name: str
        :return: ``np.ones`` shaped like the recipient's potential
        :rtype: np.ndarray
        :raises RuntimeError: if ``coef_ready`` is False
        :raises KeyError: if the name is not in ``connected_nodes``
        """
        if not self.coef_ready:
            raise RuntimeError(
                f"Need to call cal_cnp_coef first for {self.name}")
        recipient_node = self.connected_nodes[recipient_node_name]
        return np.ones(recipient_node.potential.shape)

    # Keeping every message looks ugly, but it is convenient for debugging
    # and the memory it occupies is small.
    def store_message(self, message) -> None:
        """Store ``message`` in the inbox under its sender's name.

        An earlier message from the same sender is replaced, and
        ``latest_message`` is refreshed from the inbox.

        :param message: object exposing ``sender.name``
        """
        sender_name = message.sender.name
        self.message_inbox[sender_name] = message
        self.latest_message = list(self.message_inbox.values())

    def reset(self) -> None:
        """Empty the message inbox."""
        # NOTE(review): latest_message is NOT cleared here and keeps the
        # snapshot from the last store_message call - confirm intended.
        self.message_inbox.clear()

    # TODO: fix API name
    @abstractmethod
    def make_message(self, recipient_node) -> np.ndarray:
        """Produce the value of the message from this node to recipient_node.

        :param recipient_node: target node
        :type recipient_node: BaseNode
        :return: content of the message
        :rtype: np.ndarray
        """

    @abstractmethod
    def cal_bethe(self, margin) -> float:
        """Calculate the Bethe energy.

        :param margin: marginal passed in by the caller; its exact form is
            defined by the subclass implementation
        :return: Bethe energy on this node
        :rtype: float
        """

    def send_message(self, recipient_node, is_silent=True) -> None:
        """Compute a message for ``recipient_node`` and deliver it.

        The value from :meth:`make_message` is wrapped as
        ``Message(self, val)`` and handed to ``recipient_node.store_message``.

        :param recipient_node: node that receives the message
        :param is_silent: when False, print ``sender->recipient`` and the
            message value
        """
        val = self.make_message(recipient_node)
        message = Message(self, val)
        recipient_node.store_message(message)
        if not is_silent:
            print(self.name + '->' + recipient_node.name)
            print(message.val)

    def sendin_message(self, is_silent=True) -> None:
        """Ask every connected node to send its message to this node."""
        for connected_node in self.connected_nodes.values():
            connected_node.send_message(self, is_silent)

    def sendout_message(self, is_silent=True) -> None:
        """Send a message from this node to every connected node."""
        for connected_node in self.connected_nodes.values():
            self.send_message(connected_node, is_silent)

    def register_connection(self, node_name) -> None:
        """Record a connection to ``node_name`` and bump ``node_degree``."""
        self.node_degree += 1
        self.connections.append(node_name)

    def register_nodes(self, node_map) -> None:
        """Resolve connection names to node objects via ``node_map``.

        Connections are processed in order; those before a missing one are
        already stored in ``connected_nodes`` when the error is raised.

        :param node_map: mapping from node name to node object
        :type node_map: dict
        :raises IOError: if a connection name is not a key of ``node_map``
        """
        for item in self.connections:
            if item not in node_map:
                # NOTE(review): IOError is an odd type for a missing key; kept
                # for compatibility with existing callers.
                raise IOError(f"connection of {item} of {self.name} "
                              "do not appear in the node_map")
            self.connected_nodes[item] = node_map[item]

    def get_connections(self):
        """Return the list of connection names (the list itself, not a copy)."""
        return self.connections

    def search_node_index(self, node_name):
        """Return the position of ``node_name`` in ``connections``.

        :raises ValueError: if ``node_name`` is not a connection
        """
        return self.connections.index(node_name)

    def search_msg_index(self, message_list, node_name):
        """Return the index of the first message sent by ``node_name``.

        :param message_list: sequence of objects exposing ``sender.name``
        :param node_name: sender name to look for
        :raises RuntimeError: if no message in the list is from ``node_name``
        """
        for index, message in enumerate(message_list):
            if message.sender.name == node_name:
                return index
        raise RuntimeError(
            f"{node_name} do not appear in {self.name} message")