diff --git a/voting_lib/voting_analysis.py b/voting_lib/voting_analysis.py index 7a58905..dbd2286 100644 --- a/voting_lib/voting_analysis.py +++ b/voting_lib/voting_analysis.py @@ -64,10 +64,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()): # plotting parties plot_parties(party_pos, randomize_positions=False, new_plot=False) - plt.xticks(np.arange(0, grid_w+1, 1.0)) - plt.yticks(np.arange(0, grid_h+1, 1.0)) - plt.grid(True) - plt.title(f'Learning Radius: {model.learning_radius}, Grid Size: {grid_w}') + plt.title('Node distance plot with parties') # plotting party distances in output space part_distance_out = calc_party_distances(party_pos) @@ -91,13 +88,6 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()): plt.title(f'Normalized Distance Squared Error, with MSE={np.nanmean(err.to_numpy()):.2f}') plt.show() - - # # plotting party distances in input space (TODO discard) - # party_pos_out = calc_party_pos(X, party_affiliation) - # part_distance_in = calc_party_distances(party_pos_out) - # plot_party_distances(part_distance_in) - # plt.show() - def iter_neighbours(weights, hexagon=False): _, grid_height, grid_width = weights.shape @@ -206,6 +196,7 @@ def calc_party_pos(members_of_parliament, party_affiliation): party_pos /= party_count return pd.DataFrame(data=party_pos, index=party_index_mapping) + def plot_parties(parties, randomize_positions=False, new_plot=True): party_index_mapping = parties.index @@ -219,18 +210,13 @@ def plot_parties(parties, randomize_positions=False, new_plot=True): else: xs_disp = parties[0].to_numpy() ys_disp = parties[1].to_numpy() - - # party_colors=np.array(range(len(party_index_mapping))) - # plt.scatter(xs_disp, ys_disp, c=party_colors,cmap=plt.cm.RdYlGn, zorder=2) - # offset = 0.01 - # for x,y, party in zip(xs_disp, ys_disp, party_index_mapping): - # plt.text(x + offset, y + offset, party) for i, party in enumerate(party_index_mapping): print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i]) plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2) - plt.legend(title='Parties') + plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left') + def calc_party_distances(parties): distances = np.zeros((parties.shape[0], parties.shape[0])) for i, (_, left_party) in enumerate(parties.iterrows()): @@ -241,7 +227,7 @@ def calc_party_distances(parties): return pd.DataFrame(data=distances, index=party_index_mapping, columns=party_index_mapping) def plot_party_distances(distances): - fig = plt.figure() + plt.figure() ax = plt.gca() ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True) sn.heatmap(distances, cmap='Oranges', annot=True)