mirror of
https://github.com/13hannes11/UU_NCML_Project.git
synced 2024-09-03 20:50:59 +02:00
Updated party colors & modified training params
This commit is contained in:
@@ -13,12 +13,12 @@ grid_h = 11 # Grid height
|
|||||||
grid_w = 11 # Grid width
|
grid_w = 11 # Grid width
|
||||||
radius = 2 # Neighbour radius
|
radius = 2 # Neighbour radius
|
||||||
step = 0.5 # Learning step
|
step = 0.5 # Learning step
|
||||||
ep = 300 # No of epochs
|
ep = 500 # No of epochs
|
||||||
|
|
||||||
# Load data
|
# Load data
|
||||||
dataset = ld.load_german_data()
|
dataset = ld.load_german_data()
|
||||||
|
|
||||||
period_to_compass_year = {17:2005, 18:2013, 19:2017}
|
period_to_compass_year = {17:2009, 18:2013, 19:2017}
|
||||||
|
|
||||||
for period, df in dataset.items():
|
for period, df in dataset.items():
|
||||||
|
|
||||||
@@ -31,5 +31,4 @@ for period, df in dataset.items():
|
|||||||
model = va.train_model(X, grid_h, grid_w, radius, step, ep)
|
model = va.train_model(X, grid_h, grid_w, radius, step, ep)
|
||||||
|
|
||||||
# Predict and visualize output
|
# Predict and visualize output
|
||||||
va.predict(model, data, grid_h, grid_w, de_name_color, pc.get_compass_parties(year=period_to_compass_year[period], country='de'))
|
va.predict(model, data, grid_h, grid_w, de_name_color, pc.get_compass_parties(year=period_to_compass_year[period], country='de'))
|
||||||
|
|
||||||
@@ -11,15 +11,12 @@ import numpy as np
|
|||||||
import pandas as pd
|
import pandas as pd
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
grid_h = 30 # Grid height
|
grid_h = 13 # Grid height
|
||||||
grid_w = 30 # Grid width
|
grid_w = 13 # Grid width
|
||||||
radius = 3 # Neighbour radius
|
radius = 2 # Neighbour radius
|
||||||
step = 0.5
|
step = 0.5
|
||||||
ep = 1 # No of epochs
|
ep = 300 # No of epochs
|
||||||
|
|
||||||
|
|
||||||
period_to_compass_year = {'2015_uk':2015, '2017_uk':2017, '2019_uk':2019}
|
period_to_compass_year = {'2015_uk':2015, '2017_uk':2017, '2019_uk':2019}
|
||||||
main_directory = 'uk/csv'
|
main_directory = 'uk/csv'
|
||||||
|
|||||||
@@ -1,13 +1,13 @@
|
|||||||
de_name_color = {
|
de_name_color = {
|
||||||
'AfD': 'blue',
|
'AfD': 'blue',
|
||||||
'BÜ90/GR': 'green',
|
'BÜ90/GR': 'green',
|
||||||
'CDU/CSU': 'black',
|
'CDU/CSU': 'orange',
|
||||||
'DIE LINKE.': 'purple',
|
'DIE LINKE.': 'purple',
|
||||||
'FDP': 'yellow',
|
'DIE LINKE': 'purple',
|
||||||
|
'FDP': 'magenta',
|
||||||
'SPD': 'red',
|
'SPD': 'red',
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
uk_name_color = {
|
uk_name_color = {
|
||||||
'Conservative': 'blue',
|
'Conservative': 'blue',
|
||||||
'Democratic Unionist Party': 'salmon',
|
'Democratic Unionist Party': 'salmon',
|
||||||
@@ -15,9 +15,11 @@ uk_name_color = {
|
|||||||
'Labour': 'red',
|
'Labour': 'red',
|
||||||
'Liberal Democrat': 'darkorange',
|
'Liberal Democrat': 'darkorange',
|
||||||
'Plaid Cymru': 'darkgreen',
|
'Plaid Cymru': 'darkgreen',
|
||||||
'Scottish National Party': 'yellow',
|
'Scottish National Party': 'magenta',
|
||||||
'Sinn Féin': 'yellowgreen',
|
'Sinn Féin': 'yellowgreen',
|
||||||
'Social Democratic & Labour Party': 'cyan',
|
'Social Democratic & Labour Party': 'cyan',
|
||||||
'UK Independence Party': 'purple',
|
'UK Independence Party': 'purple',
|
||||||
'Ulster Unionist Party': 'lightskyblue',
|
'Ulster Unionist Party': 'lightskyblue',
|
||||||
|
'Alba Party': 'black',
|
||||||
|
'Alliance': 'olive',
|
||||||
}
|
}
|
||||||
@@ -81,6 +81,7 @@ def predict(model, data, grid_h, grid_w, party_colors, comparison_data=pd.DataFr
|
|||||||
err = remove_NaN_rows_columns(normalize_df(part_distance_out) - normalize_df(comparison_data_dist))
|
err = remove_NaN_rows_columns(normalize_df(part_distance_out) - normalize_df(comparison_data_dist))
|
||||||
err = err * err
|
err = err * err
|
||||||
plot_party_distances(err)
|
plot_party_distances(err)
|
||||||
|
plt.title(f'MSE={np.nanmean(err.to_numpy()):.2f}')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
def iter_neighbours(weights, hexagon=False):
|
def iter_neighbours(weights, hexagon=False):
|
||||||
@@ -218,7 +219,7 @@ def plot_parties(parties, party_colors, randomize_positions=False, new_plot=True
|
|||||||
|
|
||||||
for i, party in enumerate(party_index_mapping):
|
for i, party in enumerate(party_index_mapping):
|
||||||
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
print("Party ", party, " x = ", xs_disp[i], "y = ", ys_disp[i])
|
||||||
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2, c=colors[i], edgecolors='black')
|
plt.scatter(xs_disp[i], ys_disp[i], label=party, zorder=2, c=colors[i], edgecolors='None')
|
||||||
|
|
||||||
plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left')
|
plt.legend(title='Parties', bbox_to_anchor=(1.3, 1), loc='upper left')
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user