mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
add plotting for mps #7
This commit is contained in:
@@ -135,6 +135,42 @@ def compute_heatmap(weight, grid_height, grid_width):
|
|||||||
|
|
||||||
return heatmap
|
return heatmap
|
||||||
|
|
||||||
|
def plot_mps(fig, ax, x, y, labels, colors, cmap = plt.cm.RdYlGn):
|
||||||
|
ANNOTATION_DISTANCE = 5
|
||||||
|
TRANSPARENCY = 0.8
|
||||||
|
scatterplot = plt.scatter(x,y,c=colors, s=5, cmap=cmap)
|
||||||
|
|
||||||
|
annot = ax.annotate("", xy=(0,0),
|
||||||
|
xytext=(ANNOTATION_DISTANCE, ANNOTATION_DISTANCE),
|
||||||
|
textcoords="offset points",
|
||||||
|
bbox=dict(boxstyle="Square"))
|
||||||
|
annot.set_visible(False)
|
||||||
|
|
||||||
|
def update_annot(ind):
|
||||||
|
index = ind["ind"][0]
|
||||||
|
pos = scatterplot.get_offsets()[index]
|
||||||
|
annot.xy = pos
|
||||||
|
text = f'{labels[index]}'
|
||||||
|
annot.set_text(text)
|
||||||
|
annot.get_bbox_patch().set_facecolor(cmap(colors[index]))
|
||||||
|
annot.get_bbox_patch().set_alpha(TRANSPARENCY)
|
||||||
|
|
||||||
|
def hover(event):
|
||||||
|
vis = annot.get_visible()
|
||||||
|
if event.inaxes == ax:
|
||||||
|
cont, ind = scatterplot.contains(event)
|
||||||
|
if cont:
|
||||||
|
update_annot(ind)
|
||||||
|
annot.set_visible(True)
|
||||||
|
fig.canvas.draw_idle()
|
||||||
|
else:
|
||||||
|
if vis:
|
||||||
|
annot.set_visible(False)
|
||||||
|
fig.canvas.draw_idle()
|
||||||
|
|
||||||
|
fig.canvas.mpl_connect("motion_notify_event", hover)
|
||||||
|
#plt.show()
|
||||||
|
|
||||||
#Simple SOFM for German
|
#Simple SOFM for German
|
||||||
plt.style.use('ggplot')
|
plt.style.use('ggplot')
|
||||||
|
|
||||||
@@ -163,12 +199,23 @@ sofmnet = algorithms.SOFM(
|
|||||||
sofmnet.train(X, epochs=ep)
|
sofmnet.train(X, epochs=ep)
|
||||||
|
|
||||||
#Visualizing Output
|
#Visualizing Output
|
||||||
plt.figure()
|
fig,ax = plt.subplots()
|
||||||
|
|
||||||
weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w))
|
weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w))
|
||||||
heatmap = compute_heatmap(weight, h, w)
|
heatmap = compute_heatmap(weight, h, w)
|
||||||
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
||||||
plt.axis('off')
|
#plt.axis('off')
|
||||||
plt.colorbar()
|
plt.colorbar()
|
||||||
|
|
||||||
|
prediction = sofmnet.predict(X)
|
||||||
|
print(f'prediction: {prediction}')
|
||||||
|
# converting to x and y coordinates
|
||||||
|
# TODO: verify that actually correct
|
||||||
|
xs = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=1), axis=1)
|
||||||
|
ys = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=2), axis=1)
|
||||||
|
|
||||||
|
# TODO: fix color
|
||||||
|
plot_mps(fig, ax, xs, ys, data[:,0], np.random.rand(X.shape[0]))
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
#Simple SOFM for UK
|
#Simple SOFM for UK
|
||||||
@@ -207,4 +254,6 @@ ax.scatter3D(*X.T, label='Input');
|
|||||||
ax.set_xlabel('vote_0')
|
ax.set_xlabel('vote_0')
|
||||||
ax.set_ylabel('vote_1')
|
ax.set_ylabel('vote_1')
|
||||||
ax.set_zlabel('vote_2')
|
ax.set_zlabel('vote_2')
|
||||||
ax.legend()
|
ax.legend()
|
||||||
|
|
||||||
|
plt.show()
|
||||||
|
|||||||
Reference in New Issue
Block a user