mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Data split based on election period and model is trained on each set, closes #15
This commit is contained in:
@@ -60,14 +60,14 @@ def predict(model, data, grid_h, grid_w):
|
||||
plot_party_distances(part_distance_in)
|
||||
plt.show()
|
||||
|
||||
# Heatmap of weights
|
||||
plt.figure()
|
||||
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
||||
plt.axis('off')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
## Heatmap of weights
|
||||
# plt.figure()
|
||||
# weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||
# heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||
# plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
||||
# plt.axis('off')
|
||||
# plt.colorbar()
|
||||
# plt.show()
|
||||
|
||||
def iter_neighbours(weights, hexagon=False):
|
||||
_, grid_height, grid_width = weights.shape
|
||||
|
||||
Reference in New Issue
Block a user