mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Plot node distances, part of #22
This commit is contained in:
@@ -15,7 +15,7 @@ def train_model(X, grid_h, grid_w, radius, step, ep):
|
|||||||
# Create SOFM
|
# Create SOFM
|
||||||
sofmnet = algorithms.SOFM(
|
sofmnet = algorithms.SOFM(
|
||||||
n_inputs=inp,
|
n_inputs=inp,
|
||||||
step=0.5,
|
step=step,
|
||||||
show_epoch=100,
|
show_epoch=100,
|
||||||
shuffle_data=True,
|
shuffle_data=True,
|
||||||
verbose=True,
|
verbose=True,
|
||||||
@@ -45,8 +45,20 @@ def predict(model, data, grid_h, grid_w):
|
|||||||
# calculating party positions based on mps
|
# calculating party positions based on mps
|
||||||
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
|
party_pos = calc_party_pos(np.column_stack((xs, ys)), party_affiliation)
|
||||||
|
|
||||||
|
# Plot node distnaces
|
||||||
|
plt.figure()
|
||||||
|
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||||
|
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||||
|
plt.imshow(heatmap, interpolation='nearest',zorder=1)
|
||||||
|
plt.axis('off')
|
||||||
|
plt.colorbar()
|
||||||
|
|
||||||
# plotting parties
|
# plotting parties
|
||||||
plot_parties(party_pos, randomize_positions=False, new_plot=True)
|
plot_parties(party_pos, randomize_positions=False, new_plot=False)
|
||||||
|
plt.xticks(np.arange(0, grid_w+1, 1.0))
|
||||||
|
plt.yticks(np.arange(0, grid_h+1, 1.0))
|
||||||
|
plt.grid(True)
|
||||||
|
plt.title(f'Learning Radius: {model.learning_radius}, Grid Size: {grid_w}')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# plotting party distances in output space
|
# plotting party distances in output space
|
||||||
@@ -54,19 +66,10 @@ def predict(model, data, grid_h, grid_w):
|
|||||||
plot_party_distances(part_distance_out)
|
plot_party_distances(part_distance_out)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# plotting party distances in input space
|
# # plotting party distances in input space (TODO discard)
|
||||||
party_pos_out = calc_party_pos(X, party_affiliation)
|
# party_pos_out = calc_party_pos(X, party_affiliation)
|
||||||
part_distance_in = calc_party_distances(party_pos_out)
|
# part_distance_in = calc_party_distances(party_pos_out)
|
||||||
plot_party_distances(part_distance_in)
|
# plot_party_distances(part_distance_in)
|
||||||
plt.show()
|
|
||||||
|
|
||||||
## Heatmap of weights
|
|
||||||
# plt.figure()
|
|
||||||
# weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
|
||||||
# heatmap = compute_heatmap(weight, grid_h, grid_w)
|
|
||||||
# plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
|
||||||
# plt.axis('off')
|
|
||||||
# plt.colorbar()
|
|
||||||
# plt.show()
|
# plt.show()
|
||||||
|
|
||||||
def iter_neighbours(weights, hexagon=False):
|
def iter_neighbours(weights, hexagon=False):
|
||||||
@@ -147,7 +150,6 @@ def plot_hoverscatter(x, y, labels, colors, cmap = plt.cm.RdYlGn):
|
|||||||
fig.canvas.draw_idle()
|
fig.canvas.draw_idle()
|
||||||
|
|
||||||
fig.canvas.mpl_connect("motion_notify_event", hover)
|
fig.canvas.mpl_connect("motion_notify_event", hover)
|
||||||
#plt.show()
|
|
||||||
|
|
||||||
def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
|
def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
|
||||||
# converting parties to numeric format
|
# converting parties to numeric format
|
||||||
@@ -191,12 +193,18 @@ def plot_parties(parties, randomize_positions=False, new_plot=True):
|
|||||||
else:
|
else:
|
||||||
xs_disp = parties[0].to_numpy()
|
xs_disp = parties[0].to_numpy()
|
||||||
ys_disp = parties[1].to_numpy()
|
ys_disp = parties[1].to_numpy()
|
||||||
|
|
||||||
fig, ax = plt.subplots()
|
# party_colors=np.array(range(len(party_index_mapping)))
|
||||||
|
# plt.scatter(xs_disp, ys_disp, c=party_colors,cmap=plt.cm.RdYlGn, zorder=2)
|
||||||
|
# offset = 0.01
|
||||||
|
# for x,y, party in zip(xs_disp, ys_disp, party_index_mapping):
|
||||||
|
# plt.text(x + offset, y + offset, party)
|
||||||
|
|
||||||
for i, party in enumerate(party_index_mapping):
|
for i, party in enumerate(party_index_mapping):
|
||||||
print("Party", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
||||||
plt.scatter(xs_disp[i], ys_disp[i], label=party)
|
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2)
|
||||||
plt.legend(title='Parties',bbox_to_anchor=(1.05, 1), loc='upper left')
|
|
||||||
|
plt.legend(title='Parties',bbox_to_anchor=(1.3, 1), loc='upper left')
|
||||||
|
|
||||||
def calc_party_distances(parties):
|
def calc_party_distances(parties):
|
||||||
distances = np.zeros((parties.shape[0], parties.shape[0]))
|
distances = np.zeros((parties.shape[0], parties.shape[0]))
|
||||||
|
|||||||
Reference in New Issue
Block a user