From 82cef720fcd4507fa5d0fe5b0305f5a0a157911a Mon Sep 17 00:00:00 2001 From: Deepthi Pathare Date: Tue, 11 May 2021 23:59:08 +0200 Subject: [PATCH] Data split based on election period and model is trained on each set, closes #15 --- german_analysis.py | 30 ++++++++++++++++++++---------- voting_lib/load_data.py | 21 ++++++++++++++------- voting_lib/voting_analysis.py | 16 ++++++++-------- 3 files changed, 42 insertions(+), 25 deletions(-) diff --git a/german_analysis.py b/german_analysis.py index 2233ec0..66579cc 100755 --- a/german_analysis.py +++ b/german_analysis.py @@ -3,19 +3,29 @@ import voting_lib.load_data as ld import voting_lib.voting_analysis as va +import numpy as np -# Load data -data = ld.load_german_data().to_numpy() -X = data[:,2:] - -# Train model -grid_h = 10 # Grid height -grid_w = 10 # Grid width +# Training Parameters +grid_h = 2 # Grid height +grid_w = 2 # Grid width radius = 2 # Neighbour radius step = 0.5 ep = 300 # No of epochs -model = va.train_model(X, grid_h, grid_w, radius, step, ep) +# Load data +dataset = ld.load_german_data() + + +for period, df in dataset.items(): + + print("Election Period ", period) + data = df.to_numpy() + + X = data[:,2:] + + # Train model + model = va.train_model(X, grid_h, grid_w, radius, step, ep) + + # Predict and visualize output + va.predict(model, data, grid_h, grid_w) -# Predict and visualize output -va.predict(model, data, grid_h, grid_w) \ No newline at end of file diff --git a/voting_lib/load_data.py b/voting_lib/load_data.py index a8ac9fd..d410f0d 100755 --- a/voting_lib/load_data.py +++ b/voting_lib/load_data.py @@ -12,8 +12,10 @@ def load_german_data(): """ title_file = "filename_to_titles.csv" vote_counter = -1 - data = pd.DataFrame() + #data = pd.DataFrame() + data = {} + period_column_g = 'Wahlperiode' name_column_g = 'Bezeichnung' party_column_g = 'Fraktion/Gruppe' name_column = 'Member' @@ -25,6 +27,9 @@ def load_german_data(): for dirname, _, filenames in os.walk('./de/csv'): for filename in 
filenames: if filename != title_file: + + print(filename) + vote_counter += 1 df = pd.read_csv(os.path.join(dirname, filename)) @@ -41,13 +46,15 @@ def load_german_data(): df=df.rename(columns={name_column_g:name_column,party_column_g:party_column}) - if data.empty: - # if first file that is loaded set data equal to data from first file - data = df[[name_column, party_column, vote_column_name]] - else: + period = df.iloc[0][period_column_g] + + if period in data: # merge data with already loaded data - data = data.merge(df[[name_column, vote_column_name]], on=name_column) - + data[period] = data[period].merge(df[[name_column, vote_column_name]], on=name_column) + else: + # first file for this election period: initialise the period entry with its name, party and vote columns + data[period] = df[[name_column, party_column, vote_column_name]] + print(data) return data diff --git a/voting_lib/voting_analysis.py b/voting_lib/voting_analysis.py index 45c5978..ff628c0 100644 --- a/voting_lib/voting_analysis.py +++ b/voting_lib/voting_analysis.py @@ -60,14 +60,14 @@ def predict(model, data, grid_h, grid_w): plot_party_distances(part_distance_in) plt.show() - # Heatmap of weights - plt.figure() - weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) - heatmap = compute_heatmap(weight, grid_h, grid_w) - plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') - plt.axis('off') - plt.colorbar() - plt.show() + ## Heatmap of weights + # plt.figure() + # weight = model.weight.reshape((model.n_inputs, grid_h, grid_w)) + # heatmap = compute_heatmap(weight, grid_h, grid_w) + # plt.imshow(heatmap, cmap='Greys_r', interpolation='nearest') + # plt.axis('off') + # plt.colorbar() + # plt.show() def iter_neighbours(weights, hexagon=False): _, grid_height, grid_width = weights.shape