mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Merge branch 'main' of https://github.com/13hannes11/UU_NCML_Project
This commit is contained in:
@@ -64,10 +64,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
|
||||
# plotting parties
|
||||
plot_parties(party_pos, randomize_positions=False, new_plot=False)
|
||||
plt.xticks(np.arange(0, grid_w+1, 1.0))
|
||||
plt.yticks(np.arange(0, grid_h+1, 1.0))
|
||||
plt.grid(True)
|
||||
plt.title(f'Learning Radius: {model.learning_radius}, Grid Size: {grid_w}')
|
||||
plt.title('Node distance plot with parties')
|
||||
|
||||
# plotting party distances in output space
|
||||
part_distance_out = calc_party_distances(party_pos)
|
||||
@@ -91,13 +88,6 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
plt.title(f'Normalized Distance Squared Error, with MSE={np.nanmean(err.to_numpy()):.2f}')
|
||||
plt.show()
|
||||
|
||||
|
||||
# # plotting party distances in input space (TODO discard)
|
||||
# party_pos_out = calc_party_pos(X, party_affiliation)
|
||||
# part_distance_in = calc_party_distances(party_pos_out)
|
||||
# plot_party_distances(part_distance_in)
|
||||
# plt.show()
|
||||
|
||||
def iter_neighbours(weights, hexagon=False):
|
||||
_, grid_height, grid_width = weights.shape
|
||||
|
||||
@@ -206,6 +196,7 @@ def calc_party_pos(members_of_parliament, party_affiliation):
|
||||
party_pos /= party_count
|
||||
|
||||
return pd.DataFrame(data=party_pos, index=party_index_mapping)
|
||||
|
||||
def plot_parties(parties, randomize_positions=False, new_plot=True):
|
||||
|
||||
party_index_mapping = parties.index
|
||||
@@ -220,17 +211,12 @@ def plot_parties(parties, randomize_positions=False, new_plot=True):
|
||||
xs_disp = parties[0].to_numpy()
|
||||
ys_disp = parties[1].to_numpy()
|
||||
|
||||
# party_colors=np.array(range(len(party_index_mapping)))
|
||||
# plt.scatter(xs_disp, ys_disp, c=party_colors,cmap=plt.cm.RdYlGn, zorder=2)
|
||||
# offset = 0.01
|
||||
# for x,y, party in zip(xs_disp, ys_disp, party_index_mapping):
|
||||
# plt.text(x + offset, y + offset, party)
|
||||
|
||||
for i, party in enumerate(party_index_mapping):
|
||||
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2)
|
||||
|
||||
plt.legend(title='Parties')
|
||||
plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left')
|
||||
|
||||
def calc_party_distances(parties):
|
||||
distances = np.zeros((parties.shape[0], parties.shape[0]))
|
||||
for i, (_, left_party) in enumerate(parties.iterrows()):
|
||||
@@ -241,7 +227,7 @@ def calc_party_distances(parties):
|
||||
return pd.DataFrame(data=distances, index=party_index_mapping, columns=party_index_mapping)
|
||||
|
||||
def plot_party_distances(distances):
|
||||
fig = plt.figure()
|
||||
plt.figure()
|
||||
ax = plt.gca()
|
||||
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
|
||||
sn.heatmap(distances, cmap='Oranges', annot=True)
|
||||
|
||||
Reference in New Issue
Block a user