add political compass for germany including plotting

This commit is contained in:
2021-05-18 15:28:16 +02:00
parent d41a9b3a85
commit ff0a416e26
3 changed files with 70 additions and 4 deletions

View File

@@ -26,7 +26,7 @@ def train_model(X, grid_h, grid_w, radius, step, ep):
sofmnet.train(X, epochs=ep)
return sofmnet
def predict(model, data, grid_h, grid_w):
def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
X = data[:,2:]
@@ -45,6 +45,8 @@ def predict(model, data, grid_h, grid_w):
# calculating party positions based on mps
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
print(party_pos)
# Plot node distnaces
plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
@@ -66,6 +68,22 @@ def predict(model, data, grid_h, grid_w):
plot_party_distances(part_distance_out)
plt.show()
if not comparison_data.empty:
plot_parties(comparison_data, randomize_positions=False, new_plot=True)
plt.title("political compass")
plt.ylabel("libertarian - authoritarian")
plt.xlabel("left < economic > right")
plt.show()
comparison_data_dist = calc_party_distances(comparison_data)
plot_party_distances(comparison_data_dist)
plt.show()
err = normalize_df(part_distance_out) - normalize_df(comparison_data_dist)
err = err * err
plot_party_distances(err)
plt.title(f'distance squared error, with mse={str(np.nanmean(err.to_numpy())):.2}')
plt.show()
# # plotting party distances in input space (TODO discard)
# party_pos_out = calc_party_pos(X, party_affiliation)
# part_distance_in = calc_party_distances(party_pos_out)
@@ -204,8 +222,7 @@ def plot_parties(parties, randomize_positions=False, new_plot=True):
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2)
plt.legend(title='Parties',bbox_to_anchor=(1.3, 1), loc='upper left')
plt.legend(title='Parties')
def calc_party_distances(parties):
distances = np.zeros((parties.shape[0], parties.shape[0]))
for i, (_, left_party) in enumerate(parties.iterrows()):
@@ -220,3 +237,10 @@ def plot_party_distances(distances):
ax = plt.gca()
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
sn.heatmap(distances, cmap='Oranges', annot=True)
def normalize_df(dataframe):
df = dataframe.copy(deep=True)
df = df - np.min(df.to_numpy())
df = df / np.max(df.to_numpy())
return df