diff --git a/VotingAnalysis.py b/VotingAnalysis.py index 466d286..732454f 100644 --- a/VotingAnalysis.py +++ b/VotingAnalysis.py @@ -180,8 +180,8 @@ X = data[:,1:] print(X) inp = X.shape[1] # No of features (bill count) -h = 150 # Grid height -w = 150 # Grid width +h = 10 # Grid height +w = 10 # Grid width rad = 2 # Neighbour radius ep = 300 # No of epochs @@ -199,23 +199,28 @@ sofmnet = algorithms.SOFM( sofmnet.train(X, epochs=ep) #Visualizing Output -fig,ax = plt.subplots() +plt.figure() weight = sofmnet.weight.reshape((sofmnet.n_inputs, h, w)) heatmap = compute_heatmap(weight, h, w) plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') -#plt.axis('off') +plt.axis('off') plt.colorbar() +plt.show() +fig,ax = plt.subplots() prediction = sofmnet.predict(X) print(f'prediction: {prediction}') # converting to x and y coordinates -# TODO: verify that actually correct -xs = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=1), axis=1) -ys = np.argmax(np.argmax(prediction.reshape(X.shape[0], h, w), axis=2), axis=1) +ys, xs = np.unravel_index(np.argmax(X, axis=1), (h, w)) + +# add random offset to show points that are in the same location +ys_disp = ys + np.random.rand(ys.shape[0]) +xs_disp = xs + np.random.rand(xs.shape[0]) + # TODO: fix color -plot_mps(fig, ax, xs, ys, data[:,0], np.random.rand(X.shape[0])) +plot_mps(fig, ax, ys_disp, xs_disp, data[:,0], np.random.rand(X.shape[0])) plt.show() #Simple SOFM for UK