add legend to hoverscatter

This commit is contained in:
2021-05-24 14:30:16 +02:00
parent 339c686449
commit 15f15df19f

View File

@@ -7,6 +7,7 @@ from itertools import product
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import pandas as pd import pandas as pd
import seaborn as sn import seaborn as sn
import re
def train_model(X, grid_h, grid_w, radius, step, ep): def train_model(X, grid_h, grid_w, radius, step, ep):
@@ -131,12 +132,18 @@ def compute_heatmap(weight, grid_height, grid_width):
return heatmap return heatmap
def plot_hoverscatter(x, y, labels, colors, cmap = plt.cm.RdYlGn): def plot_hoverscatter(x, y, categories, hover_labels, colors, cmap = plt.cm.RdYlGn):
fig, ax = plt.subplots() fig, ax = plt.subplots()
ANNOTATION_DISTANCE = 5 ANNOTATION_DISTANCE = 5
TRANSPARENCY = 0.8 TRANSPARENCY = 0.8
scatterplot = plt.scatter(x,y,c=colors, s=5, cmap=cmap) scatterplot = plt.scatter(x,y,c=colors, s=5, cmap=cmap)
handles, labels = scatterplot.legend_elements(prop="colors", alpha=0.6)
print(labels[0])
cat = list(map(lambda l: categories[int(re.sub(r'([^\d]+)', "", l))], labels))
legend = ax.legend(handles, cat, loc="upper right", title="Sizes")
annot = ax.annotate("", xy=(0,0), annot = ax.annotate("", xy=(0,0),
xytext=(ANNOTATION_DISTANCE, ANNOTATION_DISTANCE), xytext=(ANNOTATION_DISTANCE, ANNOTATION_DISTANCE),
textcoords="offset points", textcoords="offset points",
@@ -147,9 +154,8 @@ def plot_hoverscatter(x, y, labels, colors, cmap = plt.cm.RdYlGn):
index = ind["ind"][0] index = ind["ind"][0]
pos = scatterplot.get_offsets()[index] pos = scatterplot.get_offsets()[index]
annot.xy = pos annot.xy = pos
text = f'{labels[index]}' text = f'{hover_labels[index]}'
annot.set_text(text) annot.set_text(text)
annot.get_bbox_patch().set_facecolor(cmap(colors[index]))
annot.get_bbox_patch().set_alpha(TRANSPARENCY) annot.get_bbox_patch().set_alpha(TRANSPARENCY)
def hover(event): def hover(event):
@@ -180,7 +186,7 @@ def plot_mps(names, xs, ys, party_affiliation, randomize_positions=True):
ys_disp = ys ys_disp = ys
parties = party_index_mapping[party_ids] parties = party_index_mapping[party_ids]
plot_hoverscatter(xs_disp, ys_disp, names + " (" + parties + ")", party_ids) plot_hoverscatter(xs_disp, ys_disp, party_index_mapping, names + " (" + parties + ")", party_ids)
def calc_party_pos(members_of_parliament, party_affiliation): def calc_party_pos(members_of_parliament, party_affiliation):
party_index_mapping, party_ids = np.unique(party_affiliation, return_inverse=True) party_index_mapping, party_ids = np.unique(party_affiliation, return_inverse=True)