Data split based on election period and model is trained on each set, closes #15

This commit is contained in:
Deepthi Pathare
2021-05-11 23:59:08 +02:00
parent 69759ccd3b
commit 82cef720fc
3 changed files with 42 additions and 25 deletions

View File

@@ -3,19 +3,29 @@
import voting_lib.load_data as ld import voting_lib.load_data as ld
import voting_lib.voting_analysis as va import voting_lib.voting_analysis as va
import numpy as np
# Load data # Training Paramters
data = ld.load_german_data().to_numpy() grid_h = 2 # Grid height
X = data[:,2:] grid_w = 2 # Grid width
# Train model
grid_h = 10 # Grid height
grid_w = 10 # Grid width
radius = 2 # Neighbour radius radius = 2 # Neighbour radius
step = 0.5 step = 0.5
ep = 300 # No of epochs ep = 300 # No of epochs
model = va.train_model(X, grid_h, grid_w, radius, step, ep) # Load data
dataset = ld.load_german_data()
for period, df in dataset.items():
print("Election Period ", period)
data = df.to_numpy()
X = data[:,2:]
# Train model
model = va.train_model(X, grid_h, grid_w, radius, step, ep)
# Predict and visualize output
va.predict(model, data, grid_h, grid_w)
# Predict and visualize output
va.predict(model, data, grid_h, grid_w)

View File

@@ -12,8 +12,10 @@ def load_german_data():
""" """
title_file = "filename_to_titles.csv" title_file = "filename_to_titles.csv"
vote_counter = -1 vote_counter = -1
data = pd.DataFrame() #data = pd.DataFrame()
data = {}
period_column_g = 'Wahlperiode'
name_column_g = 'Bezeichnung' name_column_g = 'Bezeichnung'
party_column_g = 'Fraktion/Gruppe' party_column_g = 'Fraktion/Gruppe'
name_column = 'Member' name_column = 'Member'
@@ -25,6 +27,9 @@ def load_german_data():
for dirname, _, filenames in os.walk('./de/csv'): for dirname, _, filenames in os.walk('./de/csv'):
for filename in filenames: for filename in filenames:
if filename != title_file: if filename != title_file:
print(filename)
vote_counter += 1 vote_counter += 1
df = pd.read_csv(os.path.join(dirname, filename)) df = pd.read_csv(os.path.join(dirname, filename))
@@ -41,13 +46,15 @@ def load_german_data():
df=df.rename(columns={name_column_g:name_column,party_column_g:party_column}) df=df.rename(columns={name_column_g:name_column,party_column_g:party_column})
if data.empty: period = df.iloc[0][period_column_g]
# if first file that is loaded set data equal to data from first file
data = df[[name_column, party_column, vote_column_name]] if period in data:
else:
# merge data with already loaded data # merge data with already loaded data
data = data.merge(df[[name_column, vote_column_name]], on=name_column) data[period] = data[period].merge(df[[name_column, vote_column_name]], on=name_column)
else:
# if first file that is loaded set data equal to data from first file
data[period] = df[[name_column, party_column, vote_column_name]]
print(data) print(data)
return data return data

View File

@@ -60,14 +60,14 @@ def predict(model, data, grid_h, grid_w):
plot_party_distances(part_distance_in) plot_party_distances(part_distance_in)
plt.show() plt.show()
# Heatmap of weights ## Heatmap of weights
plt.figure() # plt.figure()
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) # weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
heatmap = compute_heatmap(weight, grid_h, grid_w) # heatmap = compute_heatmap(weight, grid_h, grid_w)
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') # plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
plt.axis('off') # plt.axis('off')
plt.colorbar() # plt.colorbar()
plt.show() # plt.show()
def iter_neighbours(weights, hexagon=False): def iter_neighbours(weights, hexagon=False):
_, grid_height, grid_width = weights.shape _, grid_height, grid_width = weights.shape