diff --git a/voting_lib/voting_analysis.py b/voting_lib/voting_analysis.py index 810eb18..73587cf 100644 --- a/voting_lib/voting_analysis.py +++ b/voting_lib/voting_analysis.py @@ -59,7 +59,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()): plt.figure() weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) heatmap = compute_heatmap(weight, grid_h, grid_w) - plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1) + plt.imshow(heatmap, cmap ='Blues', interpolation='nearest',zorder=1, alpha=0.5) plt.axis('off') plt.colorbar()