Hit plot, closes #21

This commit is contained in:
Deepthi Pathare
2021-05-18 18:27:31 +02:00
parent 43c77768ee
commit c5cf10ff65

View File

@@ -34,6 +34,9 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
prediction = model.predict(X) prediction = model.predict(X)
print(f'prediction: {prediction}') print(f'prediction: {prediction}')
# Plot hit map
plot_hits(prediction, grid_w, grid_h)
# converting to x and y coordinates # converting to x and y coordinates
ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w)) ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w))
@@ -51,7 +54,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
plt.figure() plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
heatmap = compute_heatmap(weight, grid_h, grid_w) heatmap = compute_heatmap(weight, grid_h, grid_w)
plt.imshow(heatmap, interpolation='nearest',zorder=1) plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1)
plt.axis('off') plt.axis('off')
plt.colorbar() plt.colorbar()
@@ -238,6 +241,10 @@ def plot_party_distances(distances):
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True) ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
sn.heatmap(distances, cmap='Oranges', annot=True) sn.heatmap(distances, cmap='Oranges', annot=True)
def plot_hits(prediction, grid_w, grid_h):
hits = (prediction.sum(axis=0)).reshape(grid_w, grid_h)
plt.figure()
sn.heatmap(hits, annot=True, xticklabels=False, yticklabels=False, cbar=False)
def normalize_df(dataframe): def normalize_df(dataframe):
df = dataframe.copy(deep=True) df = dataframe.copy(deep=True)