diff --git a/VotingAnalysis.py b/VotingAnalysis.py index 60cb343..68a7640 100644 --- a/VotingAnalysis.py +++ b/VotingAnalysis.py @@ -250,7 +250,7 @@ prediction = sofmnet.predict(X) print(f'prediction: {prediction}') # converting to x and y coordinates -ys, xs = np.unravel_index(np.argmax(X, axis=1), (h, w)) +ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (h, w)) # plotting mps plot_mps(data[:,0], xs, ys, data[:,1]) @@ -305,7 +305,7 @@ prediction = sofmnet.predict(X) print(f'prediction: {prediction}') # converting to x and y coordinates -ys, xs = np.unravel_index(np.argmax(X, axis=1), (h, w)) +ys, xs = np.unravel_index(np.argmax(prediction, axis=1), (h, w)) # plotting mps plot_mps(data[:,0], xs, ys, data[:,1])