fix plotting of MPs #7

This commit is contained in:
2021-05-04 10:26:49 +02:00
parent 35317f8d67
commit b8c9f51063

View File

@@ -180,8 +180,8 @@ X = data[:,1:]
print(X) print(X)
inp = X.shape[1] # No of features (bill count) inp = X.shape[1] # No of features (bill count)
h = 150 # Grid height h = 10 # Grid height
w = 150 # Grid width w = 10 # Grid width
rad = 2 # Neighbour radius rad = 2 # Neighbour radius
ep = 300 # No of epochs ep = 300 # No of epochs
@@ -199,23 +199,28 @@ sofmnet = algorithms.SOFM(
sofmnet.train(X, epochs=ep) sofmnet.train(X, epochs=ep)
#Visualizing Output #Visualizing Output
fig,ax = plt.subplots() plt.figure()
weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w)) weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w))
heatmap = compute_heatmap(weight, h, w) heatmap = compute_heatmap(weight, h, w)
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
#plt.axis('off') plt.axis('off')
plt.colorbar() plt.colorbar()
plt.show()
fig,ax = plt.subplots()
prediction = sofmnet.predict(X) prediction = sofmnet.predict(X)
print(f'prediction: {prediction}') print(f'prediction: {prediction}')
# converting to x and y coordinates # converting to x and y coordinates
# TODO: verify that actually correct ys, xs = np.unravel_index(np.argmax(X, axis=1), (h, w))
xs = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=1), axis=1)
ys = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=2), axis=1) # add random offset to show points that are in the same location
ys_disp = ys + np.random.rand(ys.shape[0])
xs_disp = xs + np.random.rand(xs.shape[0])
# TODO: fix color # TODO: fix color
plot_mps(fig, ax, xs, ys, data[:,0], np.random.rand(X.shape[0])) plot_mps(fig, ax, ys_disp, xs_disp, data[:,0], np.random.rand(X.shape[0]))
plt.show() plt.show()
#Simple SOFM for UK #Simple SOFM for UK