Skip to content

Commit 04a8753

Browse files
authored
accept both string and sequence types for reduce_ops parameter in reduce_events() (#45)
- add parameterized test for str and list equivalence of reduce_ops in `test_reduce_events`
- bump pre-commit hooks
- remove duplicate typos pre-commit hook
1 parent b593a59 commit 04a8753

File tree

3 files changed

+44
-14
lines changed

3 files changed

+44
-14
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ default_install_hook_types: [pre-commit, commit-msg]
88

99
repos:
1010
- repo: https://github.com/astral-sh/ruff-pre-commit
11-
rev: v0.11.8
11+
rev: v0.11.12
1212
hooks:
1313
- id: ruff
1414
args: [--fix]
@@ -48,12 +48,6 @@ repos:
4848
- id: nbstripout
4949
args: [--drop-empty-cells, --keep-output]
5050

51-
- repo: https://github.com/crate-ci/typos
52-
rev: v1.32.0
53-
hooks:
54-
- id: typos
55-
exclude_types: [bib]
56-
5751
- repo: local
5852
hooks:
5953
- id: ty

tensorboard_reducer/reduce.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
def reduce_events(
1212
events_dict: dict[str, pd.DataFrame],
13-
reduce_ops: Sequence[str],
13+
reduce_ops: str | Sequence[str],
1414
*,
1515
verbose: bool = False,
1616
) -> dict[str, dict[str, pd.DataFrame]]:
@@ -22,14 +22,19 @@ def reduce_events(
2222
2323
Args:
2424
events_dict (dict[str, pd.DataFrame]): Dict of arrays to reduce.
25-
reduce_ops (list[str]): Names of numpy reduce ops. E.g. mean, std, min, max, ...
25+
reduce_ops (str | list[str]): Names of numpy reduce ops. E.g. mean, std, min,
26+
max, ... Can be a single string or a sequence of strings.
2627
verbose (bool, optional): Whether to print progress. Defaults to False.
2728
2829
Returns:
2930
dict[str, dict[str, pd.DataFrame]]: Dict of dicts where each subdict holds one
3031
reduced array for each of the specified reduce ops, e.g.
3132
{"loss": {"mean": arr.mean(-1), "std": arr.std(-1)}}.
3233
"""
34+
# Handle case where reduce_ops is a single string
35+
if isinstance(reduce_ops, str):
36+
reduce_ops = [reduce_ops]
37+
3338
reductions: dict[str, dict[str, pd.DataFrame]] = {}
3439

3540
for op in reduce_ops:

tests/test_reduce.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
def generate_sample_data(
1616
n_tags: int = 1, n_runs: int = 10, n_steps: int = 5
1717
) -> dict[str, pd.DataFrame]:
18+
"""Generate sample test data for testing reduce operations."""
1819
events_dict = {}
1920
rng = np.random.default_rng()
2021
for idx in range(n_tags):
@@ -32,6 +33,7 @@ def test_reduce_events(
3233
verbose: bool,
3334
capsys: pytest.CaptureFixture[str],
3435
) -> None:
36+
"""Test reduce_events with sequence of operations."""
3537
reduced_events = reduce_events(events_dict, reduce_ops, verbose=verbose)
3638

3739
out_keys = list(reduced_events)
@@ -41,17 +43,17 @@ def test_reduce_events(
4143
)
4244

4345
# loop over reduce operations
44-
for (op, out_dict), in_arr in zip(
45-
reduced_events.items(), events_dict.values(), strict=True
46-
):
47-
n_steps = len(in_arr) # length of TB logs
46+
for op in reduce_ops:
47+
out_dict = reduced_events[op]
4848

49-
# loop over event tags (only 'strict/foo' here)
49+
# loop over event tags (e.g., 'strict/foo')
5050
for tag, out_arr in out_dict.items():
5151
assert tag in events_dict, (
5252
f"unexpected key {tag} in reduced event dict[{op}] = {list(out_dict)}"
5353
)
5454

55+
in_arr = events_dict[tag]
56+
n_steps = len(in_arr) # length of TB logs
5557
out_steps = len(out_arr)
5658

5759
assert n_steps == out_steps, (
@@ -76,8 +78,36 @@ def test_reduce_events(
7678
assert stdout == ""
7779

7880

81+
@pytest.mark.parametrize("reduce_op", ["mean", "std", "max"])
82+
def test_reduce_events_reduce_op_str_list_equivalence(reduce_op: str) -> None:
83+
"""Test string input (fixes issue #44) and equivalence with list input."""
84+
events_dict = generate_sample_data(n_tags=2, n_runs=3, n_steps=4)
85+
86+
# Test string input
87+
result_string = reduce_events(events_dict, reduce_op)
88+
89+
# Test list input
90+
result_list = reduce_events(events_dict, [reduce_op])
91+
92+
# Verify string input structure and correctness
93+
assert len(result_string) == 1
94+
assert reduce_op in result_string
95+
for tag, df in events_dict.items():
96+
assert tag in result_string[reduce_op]
97+
expected = getattr(df, reduce_op)(axis=1)
98+
pd.testing.assert_series_equal(result_string[reduce_op][tag], expected)
99+
100+
# Verify string and list inputs produce identical results
101+
assert result_string.keys() == result_list.keys()
102+
for tag in events_dict:
103+
pd.testing.assert_series_equal(
104+
result_string[reduce_op][tag], result_list[reduce_op][tag]
105+
)
106+
107+
79108
@pytest.mark.parametrize("n_tags, n_runs, n_steps", [(1, 10, 5), (2, 5, 3), (3, 3, 10)])
80109
def test_reduce_events_dimensions(n_tags: int, n_runs: int, n_steps: int) -> None:
110+
"""Test reduce_events with different data dimensions."""
81111
events_dict = generate_sample_data(n_tags=n_tags, n_runs=n_runs, n_steps=n_steps)
82112
reduce_ops = ["mean", "std", "max", "min"]
83113
reduced_events = reduce_events(events_dict, reduce_ops)
@@ -97,5 +127,6 @@ def test_reduce_events_dimensions(n_tags: int, n_runs: int, n_steps: int) -> Non
97127

98128
@pytest.mark.parametrize("reduce_ops", [["mean"], ["max", "min"], ["std", "median"]])
99129
def test_reduce_events_empty_input(reduce_ops: Sequence[str]) -> None:
130+
"""Test reduce_events with empty input dictionary."""
100131
reduced_events = reduce_events({}, reduce_ops)
101132
assert reduced_events == {op: {} for op in reduce_ops}

0 commit comments

Comments
 (0)