"""Source code for cbp.builder.migr_simulator."""

import numpy as np
from numpy.random import RandomState
from scipy.ndimage import gaussian_filter
from cbp.utils.np_utils import batch_normal_angle

from .migr_visualizer import MigrVisualizer
from .hmm_simulator import HMMSimulator, PotentialType


class MigrSimulator(HMMSimulator):  # pylint: disable=too-many-instance-attributes
    """HMM simulator for movement on a ``d_row`` x ``d_col`` grid.

    Each grid cell corresponds to one index of the ``(state_num, state_num)``
    transition and emission potentials, in row-major order
    (``index = row * d_col + col``). The initial potential puts all of its
    mass on cell ``(0, 0)``; transitions favour short moves, moves along the
    wind direction, and moves towards ``destination``.
    """

    # Weights of the three penalty terms in the transition exponent.
    GOAL_WEIGHT = 5
    WIND_WEIGHT = 3
    DIST_WEIGHT = 1.6
    # Exponent bonus for remaining in the current cell.
    STAY_BONUS = 1

    def __init__(self, time_step, d_col, d_row, random_seed):
        """Set up grid dimensions, RNG, simulation parameters and visualizer.

        Args:
            time_step: number of time steps, forwarded to ``HMMSimulator``.
            d_col: number of grid columns.
            d_row: number of grid rows.
            random_seed: seed forwarded to ``HMMSimulator`` and used to
                create ``self.rng``.
        """
        # One state per grid cell for both size arguments of HMMSimulator
        # (presumably hidden-state count and sensor count -- confirm there).
        super().__init__(  # pylint: disable=too-many-function-args
            time_step, d_col * d_row, d_col * d_row, random_seed)
        self.d_row = d_row
        self.d_col = d_col
        self.rng = RandomState(random_seed)
        self._sim = {
            # direction in radians, same convention as arctan2(d_row, d_col)
            "angle_wind": np.pi / 2,
            # sigma of the Gaussian blur used for the sensor (emission) model
            "sensor_sigma": 4,
            "num_sensor": self.record.state_num,
            # 1-based (row, col) of the target cell: the bottom-right corner
            "destination": (self.d_row, self.d_col),
        }
        self.visualizer = MigrVisualizer(self.d_row, self.d_col)

    def compile(self):
        """Register the transition, emission and initial potentials.

        The initial potential is one-hot on state 0, i.e. grid cell (0, 0).
        """
        self._produce_transition_potential()
        self._produce_sensor_potential()
        init_potential = np.zeros(self.record.state_num)
        init_potential[0] = 1
        self.register_potential(PotentialType.INIT, init_potential)

    def _grid_offsets(self, cur_col, cur_row):
        """Return ``(diff_row, diff_col)`` of every cell relative to the current one.

        Both arrays have shape ``(d_row, d_col)``.
        """
        row, col = np.indices((self.d_row, self.d_col))
        return row - cur_row, col - cur_col

    def _grid_angles(self, cur_col, cur_row):
        """Return ``arctan2(diff_row, diff_col)`` for every cell of the grid.

        The current cell itself gets angle 0 (``arctan2(0, 0)``).
        """
        diff_row, diff_col = self._grid_offsets(cur_col, cur_row)
        return np.arctan2(diff_row, diff_col)

    def _transition_factor_goal(self, cur_col, cur_row):
        """Penalty for moving in a direction other than towards ``destination``.

        Returns a ``(d_row, d_col)`` array equal to ``GOAL_WEIGHT`` times the
        absolute (normalised) angle between each move and the goal direction.
        """
        destination = self._sim["destination"]
        # destination is 1-based, hence the "- 1".
        angle_goal = np.arctan2(
            destination[0] - 1 - cur_row,
            destination[1] - 1 - cur_col)
        angle_matrix = self._grid_angles(cur_col, cur_row)
        # batch_normal_angle presumably wraps angles into a single 2*pi
        # period so abs() gives the angular deviation -- confirm.
        goal_matrix = np.abs(batch_normal_angle(angle_matrix - angle_goal))
        return self.GOAL_WEIGHT * goal_matrix

    def _transition_factor_wind(self, cur_col, cur_row):
        """Penalty for moving in a direction other than ``angle_wind``.

        Returns a ``(d_row, d_col)`` array equal to ``WIND_WEIGHT`` times the
        absolute (normalised) angle between each move and the wind direction.
        """
        angle_matrix = self._grid_angles(cur_col, cur_row)
        wind_matrix = np.abs(
            batch_normal_angle(angle_matrix - self._sim["angle_wind"]))
        return self.WIND_WEIGHT * wind_matrix

    def _transition_factor_dist(self, cur_col, cur_row):
        """Penalty proportional to the Euclidean distance of each move.

        Returns a ``(d_row, d_col)`` array equal to ``DIST_WEIGHT`` times the
        distance from the current cell to every cell.
        """
        diff_row, diff_col = self._grid_offsets(cur_col, cur_row)
        dist_matrix = np.sqrt(np.power(diff_row, 2) + np.power(diff_col, 2))
        return self.DIST_WEIGHT * dist_matrix

    def _transition_logistic_regression(self, cur_col, cur_row):
        """Return the transition distribution out of cell ``(cur_row, cur_col)``.

        The distribution is a softmax over all cells of
        ``-(dist + wind + goal)``, with ``STAY_BONUS`` added for the current
        cell. The result is flattened in row-major order and sums to 1.
        """
        dist_matrix = self._transition_factor_dist(cur_col, cur_row)
        wind_matrix = self._transition_factor_wind(cur_col, cur_row)
        goal_matrix = self._transition_factor_goal(cur_col, cur_row)
        exponent = -dist_matrix - wind_matrix - goal_matrix
        exponent[cur_row, cur_col] += self.STAY_BONUS
        cur_prob = np.exp(exponent)
        return cur_prob.flatten() / np.sum(cur_prob)

    def _produce_transition_potential(self):
        """Build and register the ``(state_num, state_num)`` transition potential.

        Row ``row * d_col + col`` is the distribution of the next cell given
        the current cell ``(row, col)``.
        """
        potential = [
            self._transition_logistic_regression(cur_col, cur_row)
            for cur_row in range(self.d_row)
            for cur_col in range(self.d_col)
        ]
        self.register_potential(
            PotentialType.TRANSITION,
            np.array(potential).reshape(
                self.record.state_num, self.record.state_num))

    def _sensor_row(self, row, col):
        """Return the normalised, flattened Gaussian blur of an impulse at ``(row, col)``."""
        impulse = np.zeros((self.d_row, self.d_col))
        # The amplitude is irrelevant after normalisation (the filter is linear).
        impulse[row, col] = 100.0
        blurred = gaussian_filter(impulse, sigma=self._sim["sensor_sigma"])
        return blurred.flatten() / np.sum(blurred)

    def _produce_sensor_potential(self):
        """Build and register the ``(state_num, state_num)`` emission potential.

        Row ``row * d_col + col`` is the sensor-reading distribution for true
        cell ``(row, col)``. It is a Gaussian blur centred on that cell.
        """
        potential = [
            self._sensor_row(i, j)
            for i in range(self.d_row)
            for j in range(self.d_col)
        ]
        self.register_potential(
            PotentialType.EMISSION,
            np.array(potential).reshape(
                self.record.state_num, self.record.state_num))

    def viz_emission_potential(self):
        """Plot the registered emission potential as a heatmap.

        Assumes ``self.path`` is set, presumably by the base class. Confirm.
        """
        self.visualizer.potential_heatmap(
            self.record[PotentialType.EMISSION],
            title="sensor_potential",
            path=f"{self.path}/sensor_potential")

    def viz_trans_potential(self):
        """Plot the registered transition potential as a heatmap.

        Assumes ``self.path`` is set, presumably by the base class. Confirm.
        """
        self.visualizer.potential_heatmap(
            self.record[PotentialType.TRANSITION],
            title="transition_potential",
            path=f"{self.path}/transition_potential")

    def viz_gt(self):
        """Plot the ground-truth trajectories stored in ``record["traj"]``."""
        self.visualizer.migration(
            self.record["traj"],
            **{"title": "bird traj",
               "path": f"{self.path}/gt",
               "ylabel": True,
               "xlabel": 'Ground Truth'})

    def viz_sensor(self):
        """Plot the sensor readings stored in ``record["sensor"]``."""
        self.visualizer.migration(
            self.record["sensor"],
            **{"title": "bird traj", "path": f"{self.path}/sensor"})

    def viz_estm(self, estimated_marginal):
        """Plot one map per time step from the estimated marginals.

        Args:
            estimated_marginal: 2-D array indexed ``[time, state]`` with at
                least ``record.time_step`` rows.

        Each row is scaled by ``record["num_sample"]`` before plotting.
        """
        for i in range(self.record.time_step):
            bins = self.record["num_sample"] * estimated_marginal[i, :]
            png_name = f"{self.path}/estimated_{i}.png"
            self.visualizer.visualize_map_bins(
                bins, fig_name=png_name, xlabel='Estimated')