add input distances

This commit is contained in:
2021-05-10 20:37:43 +02:00
parent 221903d57d
commit 41a44a76b3

View File

@@ -53,6 +53,12 @@ def predict(model, data, grid_h, grid_w):
# plotting party distances in output space # plotting party distances in output space
part_distance_out = calc_party_distances(party_pos) part_distance_out = calc_party_distances(party_pos)
plot_party_distances(part_distance_out) plot_party_distances(part_distance_out)
plt.show()
# plotting party distances in input space
party_pos_out = calc_party_pos(X, party_affiliation)
part_distance_in = calc_party_distances(party_pos_out)
plot_party_distances(part_distance_in)
plt.show() plt.show()
# Heatmap of weights # Heatmap of weights
@@ -160,11 +166,10 @@ def calc_party_pos(members_of_parliament, party_affiliation):
party_pos = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1])) party_pos = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
party_count = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1])) party_count = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
party_pos
for i, mp in enumerate(members_of_parliament): for i, mp in enumerate(members_of_parliament):
party_index = party_ids[i] party_index = party_ids[i]
party_pos[party_index] += mp party_pos[party_index] = party_pos[party_index] + mp
party_count[party_index] += 1 party_count[party_index] += 1
party_pos /= party_count party_pos /= party_count