mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Updated party colors & modified training params
This commit is contained in:
@@ -81,6 +81,7 @@ def predict(model, data, grid_h, grid_w, party_colors, comparison_data=pd.DataFr
|
||||
err = remove_NaN_rows_columns(normalize_df(part_distance_out) - normalize_df(comparison_data_dist))
|
||||
err = err * err
|
||||
plot_party_distances(err)
|
||||
plt.title(f'MSE={np.nanmean(err.to_numpy()):.2f}')
|
||||
plt.show()
|
||||
|
||||
def iter_neighbours(weights, hexagon=False):
|
||||
@@ -218,7 +219,7 @@ def plot_parties(parties, party_colors, randomize_positions=False, new_plot=True
|
||||
|
||||
for i, party in enumerate(party_index_mapping):
|
||||
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2, c=colors[i], edgecolors='black')
|
||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2, c=colors[i], edgecolors='None')
|
||||
|
||||
plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user