add input distances

This commit is contained in:
2021-05-10 20:37:43 +02:00
parent 221903d57d
commit 41a44a76b3

View File

@@ -50,11 +50,17 @@ def predict(model, data, grid_h, grid_w):
plot_parties(party_pos) plot_parties(party_pos)
plt.show() plt.show()
# plotting party distances in outputspace # plotting party distances in output space
part_distance_out = calc_party_distances(party_pos) part_distance_out = calc_party_distances(party_pos)
plot_party_distances(part_distance_out) plot_party_distances(part_distance_out)
plt.show() plt.show()
# plotting party distances in input space
party_pos_out = calc_party_pos(X, party_affiliation)
part_distance_in = calc_party_distances(party_pos_out)
plot_party_distances(part_distance_in)
plt.show()
# Heatmap of weights # Heatmap of weights
plt.figure() plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
@@ -160,11 +166,10 @@ def calc_party_pos(members_of_parliament, party_affiliation):
party_pos = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1])) party_pos = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
party_count = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1])) party_count = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
party_pos
for i, mp in enumerate(members_of_parliament): for i, mp in enumerate(members_of_parliament):
party_index = party_ids[i] party_index = party_ids[i]
party_pos[party_index] += mp party_pos[party_index] = party_pos[party_index] + mp
party_count[party_index] += 1 party_count[party_index] += 1
party_pos /= party_count party_pos /= party_count