Plot node distances, part of #22

This commit is contained in:
Deepthi Pathare
2021-05-13 11:29:38 +02:00
parent bf6c2e4f40
commit 3dafda0ed7

View File

@@ -15,7 +15,7 @@ def train_model(X, grid_h, grid_w, radius, step, ep):
# Create SOFM # Create SOFM
sofmnet = algorithms.SOFM( sofmnet = algorithms.SOFM(
n_inputs=inp, n_inputs=inp,
step=0.5, step=step,
show_epoch=100, show_epoch=100,
shuffle_data=True, shuffle_data=True,
verbose=True, verbose=True,
@@ -45,8 +45,20 @@ def predict(model, data, grid_h, grid_w):
# calculating party positions based on mps # calculating party positions based on mps
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation) party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
# Plot node distnaces
plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
heatmap = compute_heatmap(weight, grid_h, grid_w)
plt.imshow(heatmap, interpolation='nearest',zorder=1)
plt.axis('off')
plt.colorbar()
# plotting parties # plotting parties
plot_parties(party_pos, randomize_positions=False, new_plot=True) plot_parties(party_pos, randomize_positions=False, new_plot=False)
plt.xticks(np.arange(0, grid_w+1, 1.0))
plt.yticks(np.arange(0, grid_h+1, 1.0))
plt.grid(True)
plt.title(f'Learning Radius: {model.learning_radius}, Grid Size: {grid_w}')
plt.show() plt.show()
# plotting party distances in output space # plotting party distances in output space
@@ -54,19 +66,10 @@ def predict(model, data, grid_h, grid_w):
plot_party_distances(part_distance_out) plot_party_distances(part_distance_out)
plt.show() plt.show()
# plotting party distances in input space # # plotting party distances in input space (TODO discard)
party_pos_out = calc_party_pos(X, party_affiliation) # party_pos_out = calc_party_pos(X, party_affiliation)
part_distance_in = calc_party_distances(party_pos_out) # part_distance_in = calc_party_distances(party_pos_out)
plot_party_distances(part_distance_in) # plot_party_distances(part_distance_in)
plt.show()
## Heatmap of weights
# plt.figure()
# weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
# heatmap = compute_heatmap(weight, grid_h, grid_w)
# plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
# plt.axis('off')
# plt.colorbar()
# plt.show() # plt.show()
def iter_neighbours(weights, hexagon=False): def iter_neighbours(weights, hexagon=False):
@@ -147,7 +150,6 @@ def plot_hoverscatter(x, y, labels, colors, cmap = plt.cm.RdYlGn):
fig.canvas.draw_idle() fig.canvas.draw_idle()
fig.canvas.mpl_connect("motion_notify_event", hover) fig.canvas.mpl_connect("motion_notify_event", hover)
#plt.show()
def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True): def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
# converting parties to numeric format # converting parties to numeric format
@@ -191,12 +193,18 @@ def plot_parties(parties, randomize_positions=False, new_plot=True):
else: else:
xs_disp = parties[0].to_numpy() xs_disp = parties[0].to_numpy()
ys_disp = parties[1].to_numpy() ys_disp = parties[1].to_numpy()
fig, ax = plt.subplots() # party_colors=np.array(range(len(party_index_mapping)))
# plt.scatter(xs_disp, ys_disp, c=party_colors,cmap=plt.cm.RdYlGn, zorder=2)
# offset = 0.01
# for x,y, party in zip(xs_disp, ys_disp, party_index_mapping):
# plt.text(x + offset, y + offset, party)
for i, party in enumerate(party_index_mapping): for i, party in enumerate(party_index_mapping):
print("Party", party, " x = ", xs_disp[i], "y = ", ys_disp[i]) print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
plt.scatter(xs_disp[i], ys_disp[i], label=party) plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2)
plt.legend(title='Parties',bbox_to_anchor=(1.05, 1), loc='upper left')
plt.legend(title='Parties',bbox_to_anchor=(1.3, 1), loc='upper left')
def calc_party_distances(parties): def calc_party_distances(parties):
distances = np.zeros((parties.shape[0], parties.shape[0])) distances = np.zeros((parties.shape[0], parties.shape[0]))