add code to remove NaN rows and columns

This commit is contained in:
2021-05-20 12:10:41 +02:00
parent 2dc57693a0
commit 287d87b668

View File

@@ -80,7 +80,7 @@ def predict(model, data, grid_h, grid_w, comparison_data=pd.DataFrame()):
comparison_data_dist = calc_party_distances(comparison_data)
plot_party_distances(comparison_data_dist)
plt.show()
err = normalize_df(part_distance_out) - normalize_df(comparison_data_dist)
err = remove_NaN_rows_columns(normalize_df(part_distance_out) - normalize_df(comparison_data_dist))
err = err * err
plot_party_distances(err)
plt.title(f'distance squared error, with mse={np.nanmean(err.to_numpy()):.2f}')
@@ -251,3 +251,9 @@ def normalize_df(dataframe):
df = df - np.min(df.to_numpy())
df = df / np.max(df.to_numpy())
return df
def remove_NaN_rows_columns(dataframe):
df = dataframe.copy(deep=True)
df = df.dropna(axis=0, how='all', thresh=None, subset=None, inplace=False)
df = df.dropna(axis=1, how='all', thresh=None, subset=None, inplace=False)
return df