diff --git a/voting_lib/voting_analysis.py b/voting_lib/voting_analysis.py index 254a0ff..624c388 100644 --- a/voting_lib/voting_analysis.py +++ b/voting_lib/voting_analysis.py @@ -34,13 +34,16 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()): prediction = model.predict(X) print(f'prediction: {prediction}') + # Plot hit map + plot_hits(prediction, grid_w, grid_h) + # converting to x and y coordinates ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (grid_h, grid_w)) # plotting mps party_affiliation = data[:,1] plot_mps(data[:,0], xs, ys, party_affiliation, randomize_positions=True) - plt.show() + plt.show() # calculating party positions based on mps party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation) @@ -51,7 +54,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()): plt.figure() weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) heatmap = compute_heatmap(weight, grid_h, grid_w) - plt.imshow(heatmap, interpolation='nearest',zorder=1) + plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1) plt.axis('off') plt.colorbar() @@ -238,7 +241,11 @@ def plot_party_distances(distances): ax.tick_params(axis="x", bottom=False, top=True, labelbottom=False, labeltop=True) sn.heatmap(distances, cmap='Oranges', annot=True) - +def plot_hits(prediction, grid_w, grid_h): + hits = (prediction.sum(axis=0)).reshape(grid_w, grid_h) + plt.figure() + sn.heatmap(hits, annot=True, xticklabels=False, yticklabels=False, cbar=False) + def normalize_df(dataframe): df = dataframe.copy(deep=True) df = df - np.min(df.to_numpy())