mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Data split based on election period and model is trained on each set, closes #15
This commit is contained in:
@@ -3,19 +3,29 @@
|
|||||||
|
|
||||||
import voting_lib.load_data as ld
|
import voting_lib.load_data as ld
|
||||||
import voting_lib.voting_analysis as va
|
import voting_lib.voting_analysis as va
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
# Load data
|
# Training Paramters
|
||||||
data = ld.load_german_data().to_numpy()
|
grid_h = 2 # Grid height
|
||||||
X = data[:,2:]
|
grid_w = 2 # Grid width
|
||||||
|
|
||||||
# Train model
|
|
||||||
grid_h = 10 # Grid height
|
|
||||||
grid_w = 10 # Grid width
|
|
||||||
radius = 2 # Neighbour radius
|
radius = 2 # Neighbour radius
|
||||||
step = 0.5
|
step = 0.5
|
||||||
ep = 300 # No of epochs
|
ep = 300 # No of epochs
|
||||||
|
|
||||||
|
# Load data
|
||||||
|
dataset = ld.load_german_data()
|
||||||
|
|
||||||
|
|
||||||
|
for period, df in dataset.items():
|
||||||
|
|
||||||
|
print("Election Period ", period)
|
||||||
|
data = df.to_numpy()
|
||||||
|
|
||||||
|
X = data[:,2:]
|
||||||
|
|
||||||
|
# Train model
|
||||||
model = va.train_model(X, grid_h, grid_w, radius, step, ep)
|
model = va.train_model(X, grid_h, grid_w, radius, step, ep)
|
||||||
|
|
||||||
# Predict and visualize output
|
# Predict and visualize output
|
||||||
va.predict(model, data, grid_h, grid_w)
|
va.predict(model, data, grid_h, grid_w)
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,10 @@ def load_german_data():
|
|||||||
"""
|
"""
|
||||||
title_file = "filename_to_titles.csv"
|
title_file = "filename_to_titles.csv"
|
||||||
vote_counter = -1
|
vote_counter = -1
|
||||||
data = pd.DataFrame()
|
#data = pd.DataFrame()
|
||||||
|
data = {}
|
||||||
|
|
||||||
|
period_column_g = 'Wahlperiode'
|
||||||
name_column_g = 'Bezeichnung'
|
name_column_g = 'Bezeichnung'
|
||||||
party_column_g = 'Fraktion/Gruppe'
|
party_column_g = 'Fraktion/Gruppe'
|
||||||
name_column = 'Member'
|
name_column = 'Member'
|
||||||
@@ -25,6 +27,9 @@ def load_german_data():
|
|||||||
for dirname, _, filenames in os.walk('./de/csv'):
|
for dirname, _, filenames in os.walk('./de/csv'):
|
||||||
for filename in filenames:
|
for filename in filenames:
|
||||||
if filename != title_file:
|
if filename != title_file:
|
||||||
|
|
||||||
|
print(filename)
|
||||||
|
|
||||||
vote_counter += 1
|
vote_counter += 1
|
||||||
df = pd.read_csv(os.path.join(dirname, filename))
|
df = pd.read_csv(os.path.join(dirname, filename))
|
||||||
|
|
||||||
@@ -41,12 +46,14 @@ def load_german_data():
|
|||||||
|
|
||||||
df=df.rename(columns={name_column_g:name_column,party_column_g:party_column})
|
df=df.rename(columns={name_column_g:name_column,party_column_g:party_column})
|
||||||
|
|
||||||
if data.empty:
|
period = df.iloc[0][period_column_g]
|
||||||
# if first file that is loaded set data equal to data from first file
|
|
||||||
data = df[[name_column, party_column, vote_column_name]]
|
if period in data:
|
||||||
else:
|
|
||||||
# merge data with already loaded data
|
# merge data with already loaded data
|
||||||
data = data.merge(df[[name_column, vote_column_name]], on=name_column)
|
data[period] = data[period].merge(df[[name_column, vote_column_name]], on=name_column)
|
||||||
|
else:
|
||||||
|
# if first file that is loaded set data equal to data from first file
|
||||||
|
data[period] = df[[name_column, party_column, vote_column_name]]
|
||||||
|
|
||||||
print(data)
|
print(data)
|
||||||
return data
|
return data
|
||||||
|
|||||||
@@ -60,14 +60,14 @@ def predict(model, data, grid_h, grid_w):
|
|||||||
plot_party_distances(part_distance_in)
|
plot_party_distances(part_distance_in)
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# Heatmap of weights
|
## Heatmap of weights
|
||||||
plt.figure()
|
# plt.figure()
|
||||||
weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
# weight = model.weight.reshape((model.n_inputs, grid_h, grid_w))
|
||||||
heatmap = compute_heatmap(weight, grid_h, grid_w)
|
# heatmap = compute_heatmap(weight, grid_h, grid_w)
|
||||||
plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
# plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest')
|
||||||
plt.axis('off')
|
# plt.axis('off')
|
||||||
plt.colorbar()
|
# plt.colorbar()
|
||||||
plt.show()
|
# plt.show()
|
||||||
|
|
||||||
def iter_neighbours(weights, hexagon=False):
|
def iter_neighbours(weights, hexagon=False):
|
||||||
_, grid_height, grid_width = weights.shape
|
_, grid_height, grid_width = weights.shape
|
||||||
|
|||||||
Reference in New Issue
Block a user