diff --git a/VotingAnalysis.py b/VotingAnalysis.py index b85083c..466d286 100644 --- a/VotingAnalysis.py +++ b/VotingAnalysis.py @@ -135,6 +135,42 @@ def compute_heatmap(weight, grid_height, grid_width): return heatmap +def plot_mps(fig, ax, x, y, labels, colors, cmap = plt.cm.RdYlGn): + ANNOTATION_DISTANCE = 5 + TRANSPARENCY = 0.8 + scatterplot = plt.scatter(x,y,c=colors, s=5, cmap=cmap) + + annot = ax.annotate("", xy=(0,0), + xytext=(ANNOTATION_DISTANCE, ANNOTATION_DISTANCE), + textcoords="offset points", + bbox=dict(boxstyle="Square")) + annot.set_visible(False) + + def update_annot(ind): + index = ind["ind"][0] + pos = scatterplot.get_offsets()[index] + annot.xy = pos + text = f'{labels[index]}' + annot.set_text(text) + annot.get_bbox_patch().set_facecolor(cmap(colors[index])) + annot.get_bbox_patch().set_alpha(TRANSPARENCY) + + def hover(event): + vis = annot.get_visible() + if event.inaxes == ax: + cont, ind = scatterplot.contains(event) + if cont: + update_annot(ind) + annot.set_visible(True) + fig.canvas.draw_idle() + else: + if vis: + annot.set_visible(False) + fig.canvas.draw_idle() + + fig.canvas.mpl_connect("motion_notify_event", hover) + #plt.show() + #Simple SOFM for German plt.style.use('ggplot') @@ -163,12 +199,23 @@ sofmnet = algorithms.SOFM( sofmnet.train(X, epochs=ep) #Visualizing Output -plt.figure() +fig,ax = plt.subplots() + weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w)) heatmap = compute_heatmap(weight, h, w) plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') -plt.axis('off') +#plt.axis('off') plt.colorbar() + +prediction = sofmnet.predict(X) +print(f'prediction: {prediction}') +# converting to x and y coordinates +# TODO: verify that actually correct +xs = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=1), axis=1) +ys = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=2), axis=1) + +# TODO: fix color +plot_mps(fig, ax, xs, ys, data[:,0], np.random.rand(X.shape[0])) plt.show() #Simple SOFM for UK @@ -207,4 +254,6 @@ ax.scatter3D(*X.T, label='Input'); ax.set_xlabel('vote_0') ax.set_ylabel('vote_1') ax.set_zlabel('vote_2') -ax.legend() \ No newline at end of file +ax.legend() + +plt.show()