Hit plot, closes #21

This commit is contained in:
Deepthi Pathare
2021-05-18 18:27:31 +02:00
parent 43c77768ee
commit c5cf10ff65

View File

@@ -34,13 +34,16 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
prediction = model.predict(X)
print(f'prediction: {prediction}')
# Plot hit map
plot_hits(prediction, grid_w, grid_h)
# converting to x and y coordinates
ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w))
# plotting mps
party_affiliation = data[:,1]
plot_mps(data[:,0], xs, ys, party_affiliation, randomize_positions=True)
plt.show()
plt.show()
# calculating party positions based on mps
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
@@ -51,7 +54,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
heatmap = compute_heatmap(weight, grid_h, grid_w)
plt.imshow(heatmap, interpolation='nearest',zorder=1)
plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1)
plt.axis('off')
plt.colorbar()
@@ -238,7 +241,11 @@ def plot_party_distances(distances):
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
sn.heatmap(distances, cmap='Oranges', annot=True)
def plot_hits(prediction, grid_w, grid_h):
hits = (prediction.sum(axis=0)).reshape(grid_w, grid_h)
plt.figure()
sn.heatmap(hits, annot=True, xticklabels=False, yticklabels=False, cbar=False)
def normalize_df(dataframe):
df = dataframe.copy(deep=True)
df = df - np.min(df.to_numpy())