diff --git a/german_analysis.py b/de_analysis.py similarity index 83% rename from german_analysis.py rename to de_analysis.py index e0ec30f..f15909d 100755 --- a/german_analysis.py +++ b/de_analysis.py @@ -12,12 +12,12 @@ grid_h = 11 # Grid height grid_w = 11 # Grid width radius = 2 # Neighbour radius step = 0.5 -ep = 300 # No of epochs +ep = 100 # No of epochs # Load data dataset = ld.load_german_data() -years = {17:2005, 18:2013, 19:2017} +period_to_compass_year = {17:2005, 18:2013, 19:2017} for period, df in dataset.items(): @@ -30,5 +30,5 @@ for period, df in dataset.items(): model = va.train_model(X, grid_h, grid_w, radius, step, ep) # Predict and visualize output - va.predict(model, data, grid_h, grid_w, pc.get_compass_parties(year=years[period], country='de')) + va.predict(model, data, grid_h, grid_w, pc.get_compass_parties(year=period_to_compass_year[period], country='de'))