Skip to content

Commit 4f63055

Browse files
committed
add async124 yield-in-asynccm-not-in-try
1 parent b94c04c commit 4f63055

File tree

5 files changed

+83
-0
lines changed

5 files changed

+83
-0
lines changed

docs/rules.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,9 @@ _`ASYNC123`: bad-exception-group-flattening
9494
Dropping this information makes diagnosing errors much more difficult.
9595
We recommend ``raise SomeNewError(...) from group`` if possible; or consider using `copy.copy` to shallow-copy the exception before re-raising (for copyable types), or re-raising the error from outside the `except` block.
9696

97+
_`ASYNC124`: yield-in-asynccm-not-in-try
98+
`yield` in ``@asynccontextmanager`` should usually be in a ``try:`` with cleanup code in ``finally:`` so cleanup runs if the yielded code raises an exception.
99+
97100
Blocking sync calls in async functions
98101
======================================
99102

flake8_async/visitors/visitor102.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,6 +161,9 @@ def visit_Try(self, node: ast.Try):
161161
self._critical_scope = Statement("try/finally", node.lineno, node.col_offset)
162162
self.visit_nodes(node.finalbody)
163163

164+
# don't revisit children
165+
self.novisit = True
166+
164167
def visit_ExceptHandler(self, node: ast.ExceptHandler):
165168
# if we're inside a critical scope, a nested except should never override that
166169
if self._critical_scope is not None and self._critical_scope.name != "except":

flake8_async/visitors/visitors.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,48 @@ def visit_Call(self, node: ast.Call):
437437
self.error(node, f"{match[2]}.{match[1]}")
438438

439439

440+
@error_class
441+
class Visitor124(Flake8AsyncVisitor):
442+
error_codes: Mapping[str, str] = {
443+
"ASYNC124": (
444+
"yield in @asynccontextmanager should usually be in a `try:`"
445+
" with cleanup in `finally` to ensure cleanup is run."
446+
)
447+
}
448+
449+
def __init__(self, *args: Any, **kwargs: Any):
450+
super().__init__(*args, **kwargs)
451+
self.in_asynccontextmanager: bool = False
452+
self.in_try: bool = False
453+
454+
def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
455+
self.save_state(node, "in_asynccontextmanager")
456+
457+
# TODO: should it also error on @contextmanager? If so it feels like
458+
# flake8-bugbear might be more appropriate.
459+
if has_decorator(node, "asynccontextmanager"):
460+
self.in_asynccontextmanager = True
461+
462+
# ast.TryStar added in py311, we run mypy on py39
463+
def visit_Try(self, node: ast.Try | ast.TryStar): # type: ignore[name-defined]
464+
old_in_try = self.in_try
465+
self.in_try = True
466+
467+
self.visit_nodes(node.body)
468+
self.in_try = old_in_try
469+
470+
self.visit_nodes(node.handlers, node.orelse, node.finalbody)
471+
472+
# don't revisit children
473+
self.novisit = True
474+
475+
visit_TryStar = visit_Try
476+
477+
def visit_Yield(self, node: ast.Yield):
478+
if self.in_asynccontextmanager and not self.in_try:
479+
self.error(node)
480+
481+
440482
@error_class_cst
441483
class Visitor300(Flake8AsyncVisitor_cst):
442484
error_codes: Mapping[str, str] = {

tests/eval_files/async124.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
from contextlib import asynccontextmanager
2+
3+
4+
@asynccontextmanager
5+
async def foo():
6+
try:
7+
# TODO: should it error if there is no finally?
8+
yield
9+
except:
10+
...
11+
12+
13+
@asynccontextmanager
14+
async def foo2():
15+
try:
16+
...
17+
except:
18+
yield # error: 8

tests/eval_files/async124_py311.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from contextlib import asynccontextmanager
2+
3+
4+
@asynccontextmanager
5+
async def foo():
6+
try:
7+
yield
8+
except* Exception:
9+
...
10+
11+
12+
@asynccontextmanager
13+
async def bar():
14+
try:
15+
...
16+
except* Exception:
17+
yield # error: 8

0 commit comments

Comments
 (0)