Commit 65888b4

zphang and jeswan authored
Benchmark script fixes (#1301)
* fixes recommended by eritain

Co-authored-by: jeswan <[email protected]>
1 parent b4b5de0 commit 65888b4

3 files changed: 8 additions, 8 deletions

guides/benchmarks/glue.md

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 python benchmark_submission_formatter.py \
 --benchmark GLUE \
 --input_base_path $INPUT_BASE_PATH \
---output_path $OUTPUT_BASE PATH
+--output_path $OUTPUT_BASE_PATH
 ```
 
 where `$INPUT_BASE_PATH` contains the task folder(s) output by [runscript.py](https://github.com/jiant-dev/jiant/blob/master/jiant/proj/main/runscript.py). Alternatively, a subset of tasks can be formatted using:
@@ -18,5 +18,5 @@ python benchmark_submission_formatter.py \
 --benchmark GLUE \
 --tasks cola mrpc \
 --input_base_path $INPUT_BASE_PATH \
---output_path $OUTPUT_BASE PATH
+--output_path $OUTPUT_BASE_PATH
 ```

guides/benchmarks/superglue.md

Lines changed: 2 additions & 2 deletions
@@ -6,7 +6,7 @@
 python benchmark_submission_formatter.py \
 --benchmark SUPERGLUE \
 --input_base_path $INPUT_BASE_PATH \
---output_path $OUTPUT_BASE PATH
+--output_path $OUTPUT_BASE_PATH
 ```
 
 where `$INPUT_BASE_PATH` contains the task folder(s) output by [runscript.py](https://github.com/nyu-mll/jiant/blob/master/jiant/proj/main/runscript.py). Alternatively, a subset of tasks can be formatted using:
@@ -16,5 +16,5 @@ python benchmark_submission_formatter.py \
 --benchmark SUPERGLUE \
 --tasks cola mrpc \
 --input_base_path $INPUT_BASE_PATH \
---output_path $OUTPUT_BASE PATH
+--output_path $OUTPUT_BASE_PATH
 ```
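
Note on the guide fixes: `$OUTPUT_BASE PATH` is two shell words, so the script receives either a stray `PATH` argument or, when `$OUTPUT_BASE` is unset, `PATH` itself as the output path. The sketch below illustrates this with a stand-in parser that has only the `--output_path` flag. The token lists imitate what a POSIX shell would pass, and `/tmp/out` is a hypothetical value, not one taken from the guides.

```python
import argparse

# Stand-in parser with only the relevant flag (not the real script's parser).
parser = argparse.ArgumentParser()
parser.add_argument("--output_path", required=True)

# Typo form with OUTPUT_BASE=/tmp/out: the shell passes "PATH" as a separate word.
args, extra = parser.parse_known_args(["--output_path", "/tmp/out", "PATH"])
typo_result = (args.output_path, extra)

# Typo form with OUTPUT_BASE unset: the empty expansion disappears,
# so "PATH" becomes the value of --output_path.
args, extra = parser.parse_known_args(["--output_path", "PATH"])
unset_result = (args.output_path, extra)

# Corrected form with OUTPUT_BASE_PATH=/tmp/out.
args, extra = parser.parse_known_args(["--output_path", "/tmp/out"])
fixed_result = (args.output_path, extra)

print(typo_result)
print(unset_result)
print(fixed_result)
```

`parse_known_args` is used here only so the leftover token can be inspected. A parser that calls `parse_args` would instead exit with an "unrecognized arguments" error in the first case.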

jiant/scripts/benchmarks/benchmark_submission_formatter.py

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 import os
 import argparse
 
-from jiant.scripts.postproc.benchmarks import GlueBenchmark, SuperglueBenchmark
+from jiant.scripts.benchmarks.benchmarks import GlueBenchmark, SuperglueBenchmark
 
 
 SUPPORTED_BENCHMARKS = {"GLUE": GlueBenchmark, "SUPERGLUE": SuperglueBenchmark}
@@ -17,7 +17,7 @@ def main():
     parser.add_argument(
         "--input_base_path",
         required=True,
-        help="base input path of benchmark task predictions (contains the benchmark task folders)",
+        help="base path where per-task folders contain raw prediction files",
     )
     parser.add_argument("--output_path", required=True, help="output path for formatted files")
     parser.add_argument(
@@ -31,15 +31,15 @@ def main():
     benchmark = SUPPORTED_BENCHMARKS[args.benchmark]
 
     if args.tasks:
-        assert args.tasks in benchmark.TASKS
+        assert set(args.tasks) <= benchmark.TASKS
         task_names = args.tasks
     else:
         task_names = benchmark.TASKS
 
     for task_name in task_names:
         input_filepath = os.path.join(args.input_base_path, task_name, "test_preds.p")
         output_filepath = os.path.join(
-            args.output_path, benchmark.BENCHMARK_SUBMISSION_FILENAMES[task_name]
+            os.path.abspath(args.output_path), benchmark.BENCHMARK_SUBMISSION_FILENAMES[task_name]
         )
         benchmark.write_predictions(task_name, input_filepath, output_filepath)
 
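
Note on the `--tasks` check: the old assertion tested whether the whole list of task names was a single member of `benchmark.TASKS`, while the new one tests whether every requested name is in it. The sketch below shows the difference under two assumptions not confirmed by this diff: that `benchmark.TASKS` is a set, and that `--tasks` yields a list of strings. The task names used are placeholders.

```python
# Hypothetical stand-in for benchmark.TASKS (the real values are not shown in the diff).
TASKS = {"cola", "mrpc", "rte"}

# Assumed shape of args.tasks for `--tasks cola mrpc`.
requested = ["cola", "mrpc"]

# Old check: is the list itself an element of TASKS?
try:
    old_check = requested in TASKS
except TypeError:
    # Lists are unhashable, so membership testing against a set raises.
    old_check = None

# New check: is every requested name a member of TASKS?
new_check = set(requested) <= TASKS

# A request containing an unknown task name fails the subset test.
unknown_rejected = not (set(["cola", "bogus"]) <= TASKS)

print(old_check, new_check, unknown_rejected)
```

If `TASKS` were a list instead, the old check would simply evaluate to `False` for any list of names, and the new `<=` comparison would raise a `TypeError`, so the fix relies on `TASKS` being a set.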
