Skip to content

Commit 2b5c09f

Browse files
committed
improve filtering of datasets in benchmarks
1 parent 6337f60 commit 2b5c09f

File tree

1 file changed

+29
-17
lines changed

1 file changed

+29
-17
lines changed

docs/benchmarks/ebm-benchmark.ipynb

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -913,7 +913,8 @@
913913
"\n",
914914
"# Optionally filter out results we want to replace\n",
915915
"#results_df = results_df[results_df['method'] != 'ebm']\n",
916-
"#results_df = results_df[~((results_df['method'] == 'ebm') & (results_df['meta'] == '{}'))]\n",
916+
"#results_df = results_df[(results_df['method'] != 'ebm') | (results_df['meta'] != '{}')]\n",
917+
"#results_df = results_df[(results_df['method'] != 'ebm') | (results_df['meta'] != '{\"interactions\": 0}')]\n",
917918
"print(f'Results (post-filtered) count: {results_df.shape[0]}')"
918919
]
919920
},
@@ -925,21 +926,26 @@
925926
"outputs": [],
926927
"source": [
927928
"# Fill in results from previous runs if desired.\n",
928-
"filler_df = pd.DataFrame(columns=results_df.columns)\n",
929-
"#filler_df = pd.read_csv(\"prev.csv\")\n",
930-
"\n",
931-
"# Optionally filter out results from the filter\n",
932-
"#filler_df = filler_df[filler_df['meta'] == \"{'interactions': 0\"]\n",
933-
"\n",
934-
"key_columns = ['task', 'method', 'meta', 'replicate_num', 'name', 'seq_num']\n",
935-
"filler_df = filler_df[~filler_df.set_index(key_columns).index.isin(results_df.set_index(key_columns).index)]\n",
936-
"if 0 < filler_df.shape[0]:\n",
937-
" results_df = pd.concat([results_df, filler_df], ignore_index=True)\n",
938-
" results_df = results_df.sort_values(by=[\"task\", \"method\", \"meta\", \"replicate_num\", \"name\", \"seq_num\"])\n",
939-
" results_df.to_csv(\"merged.csv\", index=None)\n",
940-
"print(f'Filter count: {filler_df.shape[0]}')\n",
941-
"print(f'Results count: {results_df.shape[0]}')\n",
942-
"#print(filler_df.to_string())"
929+
"basefile = 'base.csv'\n",
930+
"import os\n",
931+
"if os.path.exists(basefile):\n",
932+
" filler_df = pd.DataFrame(columns=results_df.columns)\n",
933+
" filler_df = pd.read_csv(basefile)\n",
934+
" \n",
935+
" # Optionally filter out results from the filter\n",
936+
" filler_df = filler_df[filler_df['method'] != 'ebm']\n",
937+
" #filler_df = filler_df[(filler_df['method'] != 'ebm') | (filler_df['meta'] != '{}')]\n",
938+
" #filler_df = filler_df[(filler_df['method'] != 'ebm') | (filler_df['meta'] != '{\"interactions\": 0}')]\n",
939+
" \n",
940+
" key_columns = ['task', 'method', 'meta', 'replicate_num', 'name', 'seq_num']\n",
941+
" filler_df = filler_df[~filler_df.set_index(key_columns).index.isin(results_df.set_index(key_columns).index)]\n",
942+
" if 0 < filler_df.shape[0]:\n",
943+
" results_df = pd.concat([results_df, filler_df], ignore_index=True)\n",
944+
" results_df = results_df.sort_values(by=[\"task\", \"method\", \"meta\", \"replicate_num\", \"name\", \"seq_num\"])\n",
945+
" results_df.to_csv(\"merged.csv\", index=None)\n",
946+
" print(f'Filter count: {filler_df.shape[0]}')\n",
947+
" print(f'Results count: {results_df.shape[0]}')\n",
948+
" #print(filler_df.to_string())"
943949
]
944950
},
945951
{
@@ -960,7 +966,13 @@
960966
"\n",
961967
"# Optionally filter out any incomplete datasets\n",
962968
"#results_df = results_df[results_df['task'] != 'Devnagari-Script']\n",
963-
"#results_df = results_df[results_df['type'] == 'regression']\n",
969+
"#results_df = results_df[results_df['task'] != 'CIFAR_10']\n",
970+
"#results_df = results_df[results_df['task'] != 'isolet']\n",
971+
"#results_df = results_df[results_df['task'] != 'mnist_784']\n",
972+
"#results_df = results_df[results_df['task'] != 'Airlines_DepDelay_10M']\n",
973+
"#results_df = results_df[results_df['type'] != 'binary']\n",
974+
"#results_df = results_df[results_df['type'] != 'multiclass']\n",
975+
"#results_df = results_df[results_df['type'] != 'regression']\n",
964976
"print(f'Final count: {results_df.shape[0]}')"
965977
]
966978
},

0 commit comments

Comments
 (0)