mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Hit plot, closes #21
This commit is contained in:
@@ -34,13 +34,16 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
|||||||
prediction = model.predict(X)
|
prediction = model.predict(X)
|
||||||
print(f'prediction: {prediction}')
|
print(f'prediction: {prediction}')
|
||||||
|
|
||||||
|
# Plot hit map
|
||||||
|
plot_hits(prediction, grid_w, grid_h)
|
||||||
|
|
||||||
# converting to x and y coordinates
|
# converting to x and y coordinates
|
||||||
ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w))
|
ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w))
|
||||||
|
|
||||||
# plotting mps
|
# plotting mps
|
||||||
party_affiliation = data[:,1]
|
party_affiliation = data[:,1]
|
||||||
plot_mps(data[:,0], xs, ys, party_affiliation, randomize_positions=True)
|
plot_mps(data[:,0], xs, ys, party_affiliation, randomize_positions=True)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# calculating party positions based on mps
|
# calculating party positions based on mps
|
||||||
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
|
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
|
||||||
@@ -51,7 +54,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
|
|||||||
plt.figure()
|
plt.figure()
|
||||||
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||||
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||||
plt.imshow(heatmap, interpolation='nearest',zorder=1)
|
plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1)
|
||||||
plt.axis('off')
|
plt.axis('off')
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
|
|
||||||
@@ -238,7 +241,11 @@ def plot_party_distances(distances):
|
|||||||
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
|
ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True)
|
||||||
sn.heatmap(distances, cmap='Oranges', annot=True)
|
sn.heatmap(distances, cmap='Oranges', annot=True)
|
||||||
|
|
||||||
|
def plot_hits(prediction, grid_w, grid_h):
|
||||||
|
hits = (prediction.sum(axis=0)).reshape(grid_w, grid_h)
|
||||||
|
plt.figure()
|
||||||
|
sn.heatmap(hits, annot=True, xticklabels=False, yticklabels=False, cbar=False)
|
||||||
|
|
||||||
def normalize_df(dataframe):
|
def normalize_df(dataframe):
|
||||||
df = dataframe.copy(deep=True)
|
df = dataframe.copy(deep=True)
|
||||||
df = df - np.min(df.to_numpy())
|
df = df - np.min(df.to_numpy())
|
||||||
|
|||||||
Reference in New Issue
Block a user