mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
fix plotting of MPs #7
This commit is contained in:
@@ -180,8 +180,8 @@ X = data[:,1:]
|
||||
print(X)
|
||||
|
||||
inp = X.shape[1] # No of features (bill count)
|
||||
h = 150 # Grid height
|
||||
w = 150 # Grid width
|
||||
h = 10 # Grid height
|
||||
w = 10 # Grid width
|
||||
rad = 2 # Neighbour radius
|
||||
ep = 300 # No of epochs
|
||||
|
||||
@@ -199,23 +199,28 @@ sofmnet = algorithms.SOFM(
|
||||
sofmnet.train(X, epochs=ep)
|
||||
|
||||
#Visualizing Output
|
||||
fig,ax = plt.subplots()
|
||||
plt.figure()
|
||||
|
||||
weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w))
|
||||
heatmap = compute_heatmap(weight, h, w)
|
||||
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
||||
#plt.axis('off')
|
||||
plt.axis('off')
|
||||
plt.colorbar()
|
||||
plt.show()
|
||||
|
||||
fig,ax = plt.subplots()
|
||||
prediction = sofmnet.predict(X)
|
||||
print(f'prediction: {prediction}')
|
||||
# converting to x and y coordinates
|
||||
# TODO: verify that actually correct
|
||||
xs = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=1), axis=1)
|
||||
ys = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=2), axis=1)
|
||||
ys, xs = np.unravel_index(np.argmax(X, axis=1), (h, w))
|
||||
|
||||
# add random offset to show points that are in the same location
|
||||
ys_disp = ys + np.random.rand(ys.shape[0])
|
||||
xs_disp = xs + np.random.rand(xs.shape[0])
|
||||
|
||||
|
||||
# TODO: fix color
|
||||
plot_mps(fig, ax, xs, ys, data[:,0], np.random.rand(X.shape[0]))
|
||||
plot_mps(fig, ax, ys_disp, xs_disp, data[:,0], np.random.rand(X.shape[0]))
|
||||
plt.show()
|
||||
|
||||
#Simple SOFM for UK
|
||||
|
||||
Reference in New Issue
Block a user