From 3dafda0ed7ef51d54aea0966a00ceae42f757b73 Mon Sep 17 00:00:00 2001 From: Deepthi Pathare Date: Thu, 13 May 2021 11:29:38 +0200 Subject: [PATCH] Plot node distances, part of #22 --- voting_lib/voting_analysis.py | 50 ++++++++++++++++++++--------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/voting_lib/voting_analysis.py b/voting_lib/voting_analysis.py index cf35293..d00694e 100644 --- a/voting_lib/voting_analysis.py +++ b/voting_lib/voting_analysis.py @@ -15,7 +15,7 @@ def train_model(X, grid_h, grid_w, radius, step, ep): # Create SOFM sofmnet = algorithms.SOFM( n_inputs=inp, - step=0.5, + step=step, show_epoch=100, shuffle_data=True, verbose=True, @@ -45,8 +45,20 @@ def predict(model, data, grid_h, grid_w): # calculating party positions based on mps party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation) + # Plot node distnaces + plt.figure() + weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) + heatmap = compute_heatmap(weight, grid_h, grid_w) + plt.imshow(heatmap, interpolation='nearest',zorder=1) + plt.axis('off') + plt.colorbar() + # plotting parties - plot_parties(party_pos, randomize_positions=False, new_plot=True) + plot_parties(party_pos, randomize_positions=False, new_plot=False) + plt.xticks(np.arange(0, grid_w+1, 1.0)) + plt.yticks(np.arange(0, grid_h+1, 1.0)) + plt.grid(True) + plt.title(f'Learning Radius: {model.learning_radius}, Grid Size: {grid_w}') plt.show() # plotting party distances in output space @@ -54,19 +66,10 @@ def predict(model, data, grid_h, grid_w): plot_party_distances(part_distance_out) plt.show() - # plotting party distances in input space - party_pos_out = calc_party_pos(X, party_affiliation) - part_distance_in = calc_party_distances(party_pos_out) - plot_party_distances(part_distance_in) - plt.show() - - ## Heatmap of weights - # plt.figure() - # weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) - # heatmap = compute_heatmap(weight, grid_h, grid_w) - # plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') - # plt.axis('off') - # plt.colorbar() + # # plotting party distances in input space (TODO discard) + # party_pos_out = calc_party_pos(X, party_affiliation) + # part_distance_in = calc_party_distances(party_pos_out) + # plot_party_distances(part_distance_in) # plt.show() def iter_neighbours(weights, hexagon=False): @@ -147,7 +150,6 @@ def plot_hoverscatter(x, y, labels, colors, cmap = plt.cm.RdYlGn): fig.canvas.draw_idle() fig.canvas.mpl_connect("motion_notify_event", hover) - #plt.show() def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True): # converting parties to numeric format @@ -191,12 +193,18 @@ def plot_parties(parties, randomize_positions=False, new_plot=True): else: xs_disp = parties[0].to_numpy() ys_disp = parties[1].to_numpy() - - fig, ax = plt.subplots() + + # party_colors=np.array(range(len(party_index_mapping))) + # plt.scatter(xs_disp, ys_disp, c=party_colors,cmap=plt.cm.RdYlGn, zorder=2) + # offset = 0.01 + # for x,y, party in zip(xs_disp, ys_disp, party_index_mapping): + # plt.text(x + offset, y + offset, party) + for i, party in enumerate(party_index_mapping): - print("Party", party, " x = ", xs_disp[i], "y = ", ys_disp[i]) - plt.scatter(xs_disp[i], ys_disp[i], label=party) - plt.legend(title='Parties',bbox_to_anchor=(1.05, 1), loc='upper left') + print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i]) + plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2) + + plt.legend(title='Parties',bbox_to_anchor=(1.3, 1), loc='upper left') def calc_party_distances(parties): distances = np.zeros((parties.shape[0], parties.shape[0]))