mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
add input distances
This commit is contained in:
@@ -53,6 +53,12 @@ def predict(model, data, grid_h, grid_w):
|
||||
# plotting party distances in output space
|
||||
part_distance_out = calc_party_distances(party_pos)
|
||||
plot_party_distances(part_distance_out)
|
||||
plt.show()
|
||||
|
||||
# plotting party distances in input space
|
||||
party_pos_out = calc_party_pos(X, party_affiliation)
|
||||
part_distance_in = calc_party_distances(party_pos_out)
|
||||
plot_party_distances(part_distance_in)
|
||||
plt.show()
|
||||
|
||||
# Heatmap of weights
|
||||
@@ -160,11 +166,10 @@ def calc_party_pos(members_of_parliament, party_affiliation):
|
||||
|
||||
party_pos = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
|
||||
party_count = np.zeros((party_index_mapping.shape[0], members_of_parliament.shape[1]))
|
||||
party_pos
|
||||
|
||||
for i, mp in enumerate(members_of_parliament):
|
||||
party_index = party_ids[i]
|
||||
party_pos[party_index] += mp
|
||||
party_pos[party_index] = party_pos[party_index] + mp
|
||||
party_count[party_index] += 1
|
||||
|
||||
party_pos /= party_count
|
||||
|
||||
Reference in New Issue
Block a user