mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
introduce party color mapping to make colors more consistent
This commit is contained in:
@@ -8,6 +8,7 @@ import matplotlib.pyplot as plt
|
||||
import pandas as pd
|
||||
import seaborn as sn
|
||||
import re
|
||||
from matplotlib.colors import ListedColormap
|
||||
|
||||
def train_model(X, grid_h, grid_w, radius, step, ep):
|
||||
|
||||
@@ -27,7 +28,7 @@ def train_model(X, grid_h, grid_w, radius, step, ep):
|
||||
sofmnet.train(X, epochs=ep)
|
||||
return sofmnet
|
||||
|
||||
def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
def predict(model, data, grid_h, grid_w, party_colors, comparison_data=pd.DataFrame()):
|
||||
# default tight layout
|
||||
plt.rcParams["figure.autolayout"] = True
|
||||
|
||||
@@ -46,7 +47,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
|
||||
# plotting mps
|
||||
party_affiliation = data[:,1]
|
||||
plot_mps(data[:,0], xs, ys, party_affiliation, randomize_positions=True)
|
||||
plot_mps(data[:,0], xs, ys, party_affiliation, party_colors, randomize_positions=True)
|
||||
plt.title("Members of Parliament")
|
||||
plt.show()
|
||||
|
||||
@@ -64,7 +65,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
plt.colorbar()
|
||||
|
||||
# plotting parties
|
||||
plot_parties(party_pos, randomize_positions=False, new_plot=False)
|
||||
plot_parties(party_pos, party_colors, randomize_positions=False, new_plot=False)
|
||||
plt.title('Node distance plot with parties')
|
||||
|
||||
# plotting party distances in output space
|
||||
@@ -74,7 +75,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
plt.show()
|
||||
|
||||
if not comparison_data.empty:
|
||||
plot_parties(comparison_data, randomize_positions=False, new_plot=True)
|
||||
plot_parties(comparison_data, party_colors, randomize_positions=False, new_plot=True)
|
||||
plt.title("Political Compass")
|
||||
plt.ylabel("libertarian - authoritarian")
|
||||
plt.xlabel("left < economic > right")
|
||||
@@ -173,7 +174,7 @@ def plot_hoverscatter(x, y, categories, hover_labels, colors, cmap = plt.cm.RdYl
|
||||
|
||||
fig.canvas.mpl_connect("motion_notify_event", hover)
|
||||
|
||||
def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
|
||||
def plot_mps(names, xs, ys, party_affiliation, party_colors, randomize_positions=True):
|
||||
# converting parties to numeric format
|
||||
party_index_mapping, party_ids = np.unique(party_affiliation, return_inverse=True)
|
||||
|
||||
@@ -186,7 +187,9 @@ def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
|
||||
ys_disp = ys
|
||||
|
||||
parties = party_index_mapping[party_ids]
|
||||
plot_hoverscatter(xs_disp, ys_disp, party_index_mapping, names + " (" + parties + ")", party_ids)
|
||||
|
||||
colormap = ListedColormap(list(map(lambda x: party_colors[x], party_index_mapping)))
|
||||
plot_hoverscatter(xs_disp, ys_disp, party_index_mapping, names + " (" + parties + ")", party_ids, cmap=colormap)
|
||||
|
||||
def calc_party_pos(members_of_parliament, party_affiliation):
|
||||
party_index_mapping, party_ids = np.unique(party_affiliation, return_inverse=True)
|
||||
@@ -203,9 +206,12 @@ def calc_party_pos(members_of_parliament, party_affiliation):
|
||||
|
||||
return pd.DataFrame(data=party_pos, index=party_index_mapping)
|
||||
|
||||
def plot_parties(parties, randomize_positions=False, new_plot=True):
|
||||
def plot_parties(parties, party_colors, randomize_positions=False, new_plot=True):
|
||||
|
||||
party_index_mapping = parties.index
|
||||
|
||||
colors = list(map(lambda x: party_colors[x], party_index_mapping))
|
||||
|
||||
|
||||
if new_plot:
|
||||
plt.figure()
|
||||
@@ -219,7 +225,7 @@ def plot_parties(parties, randomize_positions=False, new_plot=True):
|
||||
|
||||
for i, party in enumerate(party_index_mapping):
|
||||
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2)
|
||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2, c=colors[i], edgecolors='black')
|
||||
|
||||
plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left')
|
||||
|
||||
|
||||
Reference in New Issue
Block a user