mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Hit plot, closes #21
This commit is contained in:
@@ -34,6 +34,9 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
prediction = model.predict(X)
|
||||
print(f'prediction: {prediction}')
|
||||
|
||||
# Plot hit map
|
||||
plot_hits(prediction, grid_w, grid_h)
|
||||
|
||||
# converting to x and y coordinates
|
||||
ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w))
|
||||
|
||||
@@ -51,7 +54,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
||||
plt.figure()
|
||||
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||
plt.imshow(heatmap, interpolation='nearest',zorder=1)
|
||||
plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1)
|
||||
plt.axis('off')
|
||||
plt.colorbar()
|
||||
|
||||
@@ -238,6 +241,10 @@ def plot_party_distances(distances):
|
||||
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
|
||||
sn.heatmap(distances, cmap='Oranges', annot=True)
|
||||
|
||||
def plot_hits(prediction, grid_w, grid_h):
|
||||
hits = (prediction.sum(axis=0)).reshape(grid_w, grid_h)
|
||||
plt.figure()
|
||||
sn.heatmap(hits, annot=True, xticklabels=False, yticklabels=False, cbar=False)
|
||||
|
||||
def normalize_df(dataframe):
|
||||
df = dataframe.copy(deep=True)
|
||||
|
||||
Reference in New Issue
Block a user