Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ It is therefore recommended to use a stacklevel of 2 or greater to provide more

**B030**: Except handlers should only be exception classes or tuples of exception classes.

**B031**: Using the generator returned from `itertools.groupby()` more than once will do nothing on the
second usage. Save the result to a list if the result is needed multiple times.

Opinionated warnings
~~~~~~~~~~~~~~~~~~~~

Expand Down
65 changes: 65 additions & 0 deletions bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ def children_in_scope(node):
yield from children_in_scope(child)


def walk_list(nodes):
for node in nodes:
yield from ast.walk(node)


def _typesafe_issubclass(cls, class_or_tuple):
try:
return issubclass(cls, class_or_tuple)
Expand Down Expand Up @@ -401,6 +406,7 @@ def visit_For(self, node):
self.check_for_b007(node)
self.check_for_b020(node)
self.check_for_b023(node)
self.check_for_b031(node)
self.generic_visit(node)

def visit_AsyncFor(self, node):
Expand Down Expand Up @@ -793,6 +799,56 @@ def check_for_b026(self, call: ast.Call):
):
self.errors.append(B026(starred.lineno, starred.col_offset))

def check_for_b031(self, loop_node): # noqa: C901
"""Check that `itertools.groupby` isn't iterated over more than once.

We emit a warning when the generator returned by `groupby()` is used
more than once inside a loop body or when it's used in a nested loop.
"""
# for <loop_node.target> in <loop_node.iter>: ...
if isinstance(loop_node.iter, ast.Call):
node = loop_node.iter
if (isinstance(node.func, ast.Name) and node.func.id in ("groupby",)) or (
isinstance(node.func, ast.Attribute)
and node.func.attr == "groupby"
and isinstance(node.func.value, ast.Name)
and node.func.value.id == "itertools"
):
# We have an invocation of groupby which is a simple unpacking
if isinstance(loop_node.target, ast.Tuple) and isinstance(
loop_node.target.elts[1], ast.Name
):
group_name = loop_node.target.elts[1].id
else:
# Ignore any `groupby()` invocation that isn't unpacked
return

num_usages = 0
for node in walk_list(loop_node.body):
# Handled nested loops
if isinstance(node, ast.For):
for nested_node in walk_list(node.body):
assert nested_node != node
if (
isinstance(nested_node, ast.Name)
and nested_node.id == group_name
):
self.errors.append(
B031(
nested_node.lineno,
nested_node.col_offset,
vars=(nested_node.id,),
)
)

# Handle multiple uses
if isinstance(node, ast.Name) and node.id == group_name:
num_usages += 1
if num_usages > 1:
self.errors.append(
B031(node.lineno, node.col_offset, vars=(node.id,))
)

def _get_assigned_names(self, loop_node):
loop_targets = (ast.For, ast.AsyncFor, ast.comprehension)
for node in children_in_scope(loop_node):
Expand Down Expand Up @@ -1558,8 +1614,17 @@ def visit_Lambda(self, node):
"anything. Add exceptions to handle."
)
)

B030 = Error(message="B030 Except handlers should only be names of exception classes")

B031 = Error(
message=(
"B031 Using the generator returned from `itertools.groupby()` more than once"
" will do nothing on the second usage. Save the result to a list, if the"
" result is needed multiple times."
)
)

# Warnings disabled by default.
B901 = Error(
message=(
Expand Down
64 changes: 64 additions & 0 deletions tests/b031.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Should emit:
B030 - on lines 29, 33, 43
"""
import itertools
from itertools import groupby

shoppers = ["Jane", "Joe", "Sarah"]
items = [
("lettuce", "greens"),
("tomatoes", "greens"),
("cucumber", "greens"),
("chicken breast", "meats & fish"),
("salmon", "meats & fish"),
("ice cream", "frozen items"),
]

carts = {shopper: [] for shopper in shoppers}


def collect_shop_items(shopper, items):
# Imagine this an expensive database query or calculation that is
# advantageous to batch.
carts[shopper] += items


# Group by shopping section
for _section, section_items in groupby(items, key=lambda p: p[1]):
for shopper in shoppers:
collect_shop_items(shopper, section_items)

for _section, section_items in groupby(items, key=lambda p: p[1]):
collect_shop_items("Jane", section_items)
collect_shop_items("Joe", section_items)


for _section, section_items in groupby(items, key=lambda p: p[1]):
# This is ok
collect_shop_items("Jane", section_items)

for _section, section_items in itertools.groupby(items, key=lambda p: p[1]):
for shopper in shoppers:
collect_shop_items(shopper, section_items)

for group in groupby(items, key=lambda p: p[1]):
# This is bad, but not detected currently
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])


# Make sure we ignore - but don't fail on more complicated invocations
for _key, (_value1, _value2) in groupby(
[("a", (1, 2)), ("b", (3, 4)), ("a", (5, 6))], key=lambda p: p[1]
):
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])

# Make sure we ignore - but don't fail on more complicated invocations
for (_key1, _key2), (_value1, _value2) in groupby(
[(("a", "a"), (1, 2)), (("b", "b"), (3, 4)), (("a", "a"), (5, 6))],
key=lambda p: p[1],
):
collect_shop_items("Jane", group[1])
collect_shop_items("Joe", group[1])
12 changes: 12 additions & 0 deletions tests/test_bugbear.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
B028,
B029,
B030,
B031,
B901,
B902,
B903,
Expand Down Expand Up @@ -459,6 +460,17 @@ def test_b030(self):
)
self.assertEqual(errors, expected)

def test_b031(self):
filename = Path(__file__).absolute().parent / "b031.py"
bbc = BugBearChecker(filename=str(filename))
errors = list(bbc.run())
expected = self.errors(
B031(30, 36, vars=("section_items",)),
B031(34, 30, vars=("section_items",)),
B031(43, 36, vars=("section_items",)),
)
self.assertEqual(errors, expected)

@unittest.skipIf(sys.version_info < (3, 8), "not implemented for <3.8")
def test_b907(self):
filename = Path(__file__).absolute().parent / "b907.py"
Expand Down