Skip to content

Commit 60e30d6

Browse files
committed
async124 now ignores class methods. fixed two bugs in async91x
1 parent 42971f1 commit 60e30d6

File tree

6 files changed

+138
-18
lines changed

6 files changed

+138
-18
lines changed

docs/changelog.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ Changelog
77
24.10.3
88
=======
99
- Add :ref:`ASYNC124 <async124>` async-function-could-be-sync
10+
- :ref:`ASYNC91x <ASYNC910>` now correctly handles ``await()`` in parameter lists.
11+
- Fixed a bug with :ref:`ASYNC91x <ASYNC910>` and nested empty functions.
1012

1113
24.10.2
1214
=======

flake8_async/visitors/visitor91x.py

Lines changed: 64 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ def func_empty_body(node: cst.FunctionDef) -> bool:
6969
)
7070

7171

72+
# this could've been implemented as part of visitor91x, but /shrug
7273
@error_class_cst
7374
class Visitor124(Flake8AsyncVisitor_cst):
7475
error_codes: Mapping[str, str] = {
@@ -81,12 +82,37 @@ class Visitor124(Flake8AsyncVisitor_cst):
8182
def __init__(self, *args: Any, **kwargs: Any):
8283
super().__init__(*args, **kwargs)
8384
self.has_await = False
85+
self.in_class = False
8486

87+
def visit_ClassDef(self, node: cst.ClassDef):
88+
self.save_state(node, "in_class", copy=False)
89+
self.in_class = True
90+
91+
def leave_ClassDef(
92+
self, original_node: cst.ClassDef, updated_node: cst.ClassDef
93+
) -> cst.ClassDef:
94+
self.restore_state(original_node)
95+
return updated_node
96+
97+
# await in sync defs are not valid, but handling this will make ASYNC124
98+
# correctly pop up in parent func as if the child function was async
8599
def visit_FunctionDef(self, node: cst.FunctionDef):
86-
# await in sync defs are not valid, but handling this will make ASYNC124
87-
# pop up in parent func as if the child function was async
88-
self.save_state(node, "has_await", copy=False)
89-
self.has_await = False
100+
# default values are evaluated in parent scope
101+
# this visitor has no autofixes, so we can throw away return value
102+
_ = node.params.visit(self)
103+
104+
self.save_state(node, "has_await", "in_class", copy=False)
105+
106+
# ignore class methods
107+
self.has_await = self.in_class
108+
109+
# but not nested functions
110+
self.in_class = False
111+
112+
_ = node.body.visit(self)
113+
114+
# we've manually visited subnodes (that we care about).
115+
return False
90116

91117
def leave_FunctionDef(
92118
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
@@ -99,7 +125,7 @@ def leave_FunctionDef(
99125
# skip functions named 'text_xxx' with params, since they may be relying
100126
# on async fixtures. This is esp bad as sync funcs relying on async fixtures
101127
# is not well handled: https://github.com/pytest-dev/pytest/issues/10839
102-
# also skip funcs with @fixtures and params
128+
# also skip funcs with @fixture and params
103129
and not (
104130
original_node.params.params
105131
and (
@@ -115,11 +141,12 @@ def leave_FunctionDef(
115141
def visit_Await(self, node: cst.Await):
116142
self.has_await = True
117143

118-
def visit_With(self, node: cst.With | cst.For):
144+
def visit_With(self, node: cst.With | cst.For | cst.CompFor):
119145
if node.asynchronous is not None:
120146
self.has_await = True
121147

122148
visit_For = visit_With
149+
visit_CompFor = visit_With
123150

124151

125152
@dataclass
@@ -348,6 +375,9 @@ def __init__(self, *args: Any, **kwargs: Any):
348375
# --exception-suppress-context-manager
349376
self.suppress_imported_as: list[str] = []
350377

378+
# used to transfer new body between visit_FunctionDef and leave_FunctionDef
379+
self.new_body: cst.BaseSuite | None = None
380+
351381
def should_autofix(self, node: cst.CSTNode, code: str | None = None) -> bool:
352382
if code is None: # pragma: no branch
353383
code = "ASYNC911" if self.has_yield else "ASYNC910"
@@ -388,6 +418,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
388418
return
389419

390420
def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
421+
# `await` in default values happen in parent scope
422+
# we also know we don't ever modify parameters so we can ignore the return value
423+
_ = node.params.visit(self)
424+
391425
# don't lint functions whose bodies solely consist of pass or ellipsis
392426
# @overload functions are also guaranteed to be empty
393427
# we also ignore pytest fixtures
@@ -417,36 +451,49 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
417451
node.decorators, *self.options.no_checkpoint_warning_decorators
418452
)
419453
)
420-
if not self.async_function:
421-
# only visit subnodes if there is an async function defined inside
422-
# this should improve performance on codebases with many sync functions
423-
return any(m.findall(node, m.FunctionDef(asynchronous=m.Asynchronous())))
454+
# only visit subnodes if there is an async function defined inside
455+
# this should improve performance on codebases with many sync functions
456+
if not self.async_function and not any(
457+
m.findall(node, m.FunctionDef(asynchronous=m.Asynchronous()))
458+
):
459+
return False
424460

425461
pos = self.get_metadata(PositionProvider, node).start # type: ignore
426462
self.uncheckpointed_statements = {
427463
Statement("function definition", pos.line, pos.column) # type: ignore
428464
}
429-
return True
465+
466+
# visit body
467+
# we're not gonna get FlattenSentinel or RemovalSentinel
468+
self.new_body = cast(cst.BaseSuite, node.body.visit(self))
469+
470+
# we know that leave_FunctionDef for this FunctionDef will run immediately after
471+
# this function exits so we don't need to worry about save_state for new_body
472+
return False
430473

431474
def leave_FunctionDef(
432475
self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef
433476
) -> cst.FunctionDef:
434477
if (
435-
self.async_function
478+
self.new_body is not None
479+
and self.async_function
436480
# updated_node does not have a position, so we must send original_node
437481
and self.check_function_exit(original_node)
438482
and self.should_autofix(original_node)
439-
and isinstance(updated_node.body, cst.IndentedBlock)
483+
and isinstance(self.new_body, cst.IndentedBlock)
440484
):
441485
# insert checkpoint at the end of body
442-
new_body = list(updated_node.body.body)
443-
new_body.append(self.checkpoint_statement())
444-
indentedblock = updated_node.body.with_changes(body=new_body)
445-
updated_node = updated_node.with_changes(body=indentedblock)
486+
new_body_block = list(self.new_body.body)
487+
new_body_block.append(self.checkpoint_statement())
488+
self.new_body = self.new_body.with_changes(body=new_body_block)
446489

447490
self.ensure_imported_library()
448491

492+
if self.new_body is not None:
493+
updated_node = updated_node.with_changes(body=self.new_body)
449494
self.restore_state(original_node)
495+
# reset self.new_body
496+
self.new_body = None
450497
return updated_node
451498

452499
# error if function exit/return/yields with uncheckpointed statements

tests/autofix_files/async910.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,3 +619,16 @@ async def fn_226(): # error: 0, "exit", Statement("function definition", lineno
619619
except Exception:
620620
pass
621621
await trio.lowlevel.checkpoint()
622+
623+
# the await() is evaluated in the parent scope
624+
async def foo_default_value_await():
625+
async def bar(arg=await foo()): # error: 4, "exit", Statement("function definition", lineno)
626+
print()
627+
await trio.lowlevel.checkpoint()
628+
629+
630+
async def foo_nested_empty_async():
631+
# this previously errored because leave_FunctionDef assumed a non-empty body
632+
async def bar():
633+
...
634+
await foo()

tests/autofix_files/async910.py.diff

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,17 @@
213213

214214

215215
# Issue #226
216-
@@ x,3 x,4 @@
216+
@@ x,11 x,13 @@
217217
pass
218218
except Exception:
219219
pass
220220
+ await trio.lowlevel.checkpoint()
221+
222+
# the await() is evaluated in the parent scope
223+
async def foo_default_value_await():
224+
async def bar(arg=await foo()): # error: 4, "exit", Statement("function definition", lineno)
225+
print()
226+
+ await trio.lowlevel.checkpoint()
227+
228+
229+
async def foo_nested_empty_async():

tests/eval_files/async124.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,37 @@ async def foo_fix(my_async_fixture):
9999
@fixture
100100
async def foo_fix_no_subfix(): # ASYNC124: 0
101101
print("blah")
102+
103+
104+
async def default_value():
105+
def foo(arg=await foo()): ...
106+
107+
108+
# 124 doesn't care if you evaluate the comprehension or not
109+
# 910 is stingy
110+
async def foo_async_gen():
111+
return (
112+
a async for a in foo_gen()
113+
) # ASYNC910: 4, "return", Statement("function definition", lineno-1)
114+
115+
116+
async def foo_async_for_comprehension():
117+
return [a async for a in foo_gen()]
118+
119+
120+
class Foo:
121+
# async124 ignores class methods
122+
async def bar(
123+
self,
124+
): # ASYNC910: 4, "exit", Statement("function definition", lineno)
125+
async def bee(): # ASYNC124: 8 # ASYNC910: 8, "exit", Statement("function definition", lineno)
126+
print("blah")
127+
128+
async def later_in_class(
129+
self,
130+
): # ASYNC910: 4, "exit", Statement("function definition", lineno)
131+
print()
132+
133+
134+
async def after_class(): # ASYNC124: 0 # ASYNC910: 0, "exit", Statement("function definition", lineno)
135+
print()

tests/eval_files/async910.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,3 +589,18 @@ async def fn_226(): # error: 0, "exit", Statement("function definition", lineno
589589
pass
590590
except Exception:
591591
pass
592+
593+
594+
# the await() is evaluated in the parent scope
595+
async def foo_default_value_await():
596+
async def bar(
597+
arg=await foo(),
598+
): # error: 4, "exit", Statement("function definition", lineno)
599+
print()
600+
601+
602+
async def foo_nested_empty_async():
603+
# this previously errored because leave_FunctionDef assumed a non-empty body
604+
async def bar(): ...
605+
606+
await foo()

0 commit comments

Comments
 (0)