[1]:
import numpy as np
import cbp
1. Construct nodes
1.1 Construct variable nodes
[2]:
rv_dim = 3
# Plain variable node: no marginal constraint is passed.
varnode_1 = cbp.node.VarNode(rv_dim)
# Uniform target marginal sized from rv_dim (not a hard-coded 3) so the
# constraint always has the same length as the node it is attached to.
fix_marginal = np.ones(rv_dim) / rv_dim
varnode_2 = cbp.node.VarNode(rv_dim, constrained_marginal=fix_marginal)
1.2 Construct a factor node
[3]:
# The factor refers to its variables by the names the graph assigns when they
# are added: varnode_1 -> "VarNode_000", varnode_2 -> "VarNode_001".
connected_var = ["VarNode_000", "VarNode_001"]
# All-ones potential over the joint (rv_dim x rv_dim) table, sized from rv_dim
# so it stays consistent with the variable nodes above.
factor_node = cbp.node.FactorNode(connected_var, np.ones((rv_dim, rv_dim)))
2. Construct the graph
[4]:
graph = cbp.graph.GraphModel()
# GraphModel names nodes with a simple counter as they are added:
# varnode_1.name == "VarNode_000", varnode_2.name == "VarNode_001",
# and the factor node becomes "FactorNode_000".
graph.add_varnode(varnode_1)
graph.add_varnode(varnode_2)
# factor_node was built against hard-coded variable names; verify them
# before wiring it in, so a naming mismatch fails loudly here.
assert [varnode_1.name, varnode_2.name] == connected_var
graph.add_factornode(factor_node)
[4]:
'FactorNode_000'
3. Run inference
[5]:
# run_bp() runs iterative scaling. The displayed result looks like
# (per-iteration convergence errors, number of iterations) -- TODO confirm.
# Alternative solver: call graph.run_cnp() instead to run CNP.
graph.run_bp()
[5]:
([0.0], 1)
4. Access the marginals
[6]:
# Look the node up by its graph-assigned name; it is the same object as varnode_1.
varnode = graph.get_node('VarNode_000')
# Print the variable marginal first, then the factor's joint marginal.
for node in (varnode, factor_node):
    print(node.marginal())
[0.33333333 0.33333333 0.33333333]
[[0.11111111 0.11111111 0.11111111]
[0.11111111 0.11111111 0.11111111]
[0.11111111 0.11111111 0.11111111]]
[7]:
# Mapping from node name to that node's marginal array after inference.
convergence_marginals = graph.export_convergence_marginals()
print(convergence_marginals)
{'FactorNode_000': array([[0.11111111, 0.11111111, 0.11111111],
[0.11111111, 0.11111111, 0.11111111],
[0.11111111, 0.11111111, 0.11111111]]), 'VarNode_000': array([0.33333333, 0.33333333, 0.33333333]), 'VarNode_001': array([0.33333333, 0.33333333, 0.33333333])}
5. Construct an HMM
[8]:
# The potentials and observation marginals below are random Dirichlet draws;
# seed the global NumPy RNG so Restart & Run All reproduces the same model.
np.random.seed(0)

num_hidden_state = 3
num_obser_state = 4
T = 3  # number of time steps
hmm = cbp.graph.GraphModel()

# 5.1 Hidden state nodes (unconstrained). They are added first so they get the
# names VarNode_000 .. VarNode_{T-1}. The observation nodes added next then
# become VarNode_{T} .. VarNode_{2T-1}, which is what the transition
# (hidden x hidden) and emission (hidden x observed) factor cells rely on.
for _ in range(T):
    hmm.add_varnode(cbp.node.VarNode(num_hidden_state))
5.2 Construct observation state nodes
[10]:
# One observed node per time step, each with a fixed (constrained) marginal
# drawn from a flat Dirichlet over the num_obser_state values. The marginal is
# passed by keyword, matching how constrained nodes are built in section 1.1.
for _ in range(T):
    obs_marginal = np.random.dirichlet([1] * num_obser_state)
    hmm.add_varnode(cbp.node.VarNode(num_obser_state, constrained_marginal=obs_marginal))
5.3 Construct transition factor nodes
[11]:
# Transition factors join VarNode_t and VarNode_{t+1} for t = 0 .. T-2.
# The potential is (num_hidden_state x num_hidden_state); each of its rows is
# an independent flat-Dirichlet draw.
for t in range(T - 1):
    left, right = f"VarNode_{t:03d}", f"VarNode_{t + 1:03d}"
    transition = np.random.dirichlet([1] * num_hidden_state, size=num_hidden_state)
    hmm.add_factornode(cbp.node.FactorNode([left, right], potential=transition))
5.4 Construct emission factor nodes
[12]:
# Emission factors join VarNode_t with VarNode_{t+T} for t = 0 .. T-1.
# The potential is (num_hidden_state x num_obser_state); each row is an
# independent flat-Dirichlet draw over the observation values.
for t in range(T):
    hidden, observed = f"VarNode_{t:03d}", f"VarNode_{t + T:03d}"
    emission = np.random.dirichlet([1] * num_obser_state, size=num_hidden_state)
    hmm.add_factornode(cbp.node.FactorNode([hidden, observed], potential=emission))
5.5 Visualize the HMM graph
[13]:
# Draw the HMM graph built above (variable nodes and the factors joining them).
hmm.plot()
5.6 Run inference
[14]:
# Run inference on the HMM. The displayed result presumably is
# (per-iteration convergence errors, number of iterations) -- TODO confirm.
hmm.run_bp()
[14]:
([0.9803173953701733, 0.05627442398462251, 7.303185833862358e-16], 3)