@@ -69,6 +69,7 @@ def func_empty_body(node: cst.FunctionDef) -> bool:
69
69
)
70
70
71
71
72
+ # this could've been implemented as part of visitor91x, but /shrug
72
73
@error_class_cst
73
74
class Visitor124 (Flake8AsyncVisitor_cst ):
74
75
error_codes : Mapping [str , str ] = {
@@ -81,12 +82,37 @@ class Visitor124(Flake8AsyncVisitor_cst):
81
82
def __init__ (self , * args : Any , ** kwargs : Any ):
82
83
super ().__init__ (* args , ** kwargs )
83
84
self .has_await = False
85
+ self .in_class = False
84
86
87
+ def visit_ClassDef (self , node : cst .ClassDef ):
88
+ self .save_state (node , "in_class" , copy = False )
89
+ self .in_class = True
90
+
91
+ def leave_ClassDef (
92
+ self , original_node : cst .ClassDef , updated_node : cst .ClassDef
93
+ ) -> cst .ClassDef :
94
+ self .restore_state (original_node )
95
+ return updated_node
96
+
97
+ # await in sync defs are not valid, but handling this will make ASYNC124
98
+ # correctly pop up in parent func as if the child function was async
85
99
def visit_FunctionDef (self , node : cst .FunctionDef ):
86
- # await in sync defs are not valid, but handling this will make ASYNC124
87
- # pop up in parent func as if the child function was async
88
- self .save_state (node , "has_await" , copy = False )
89
- self .has_await = False
100
+ # default values are evaluated in parent scope
101
+ # this visitor has no autofixes, so we can throw away return value
102
+ _ = node .params .visit (self )
103
+
104
+ self .save_state (node , "has_await" , "in_class" , copy = False )
105
+
106
+ # ignore class methods
107
+ self .has_await = self .in_class
108
+
109
+ # but not nested functions
110
+ self .in_class = False
111
+
112
+ _ = node .body .visit (self )
113
+
114
+ # we've manually visited subnodes (that we care about).
115
+ return False
90
116
91
117
def leave_FunctionDef (
92
118
self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
@@ -99,7 +125,7 @@ def leave_FunctionDef(
99
125
# skip functions named 'text_xxx' with params, since they may be relying
100
126
# on async fixtures. This is esp bad as sync funcs relying on async fixtures
101
127
# is not well handled: https://github.com/pytest-dev/pytest/issues/10839
102
- # also skip funcs with @fixtures and params
128
+ # also skip funcs with @fixture and params
103
129
and not (
104
130
original_node .params .params
105
131
and (
@@ -115,11 +141,12 @@ def leave_FunctionDef(
115
141
def visit_Await (self , node : cst .Await ):
116
142
self .has_await = True
117
143
118
- def visit_With (self , node : cst .With | cst .For ):
144
+ def visit_With (self , node : cst .With | cst .For | cst . CompFor ):
119
145
if node .asynchronous is not None :
120
146
self .has_await = True
121
147
122
148
visit_For = visit_With
149
+ visit_CompFor = visit_With
123
150
124
151
125
152
@dataclass
@@ -348,6 +375,9 @@ def __init__(self, *args: Any, **kwargs: Any):
348
375
# --exception-suppress-context-manager
349
376
self .suppress_imported_as : list [str ] = []
350
377
378
+ # used to transfer new body between visit_FunctionDef and leave_FunctionDef
379
+ self .new_body : cst .BaseSuite | None = None
380
+
351
381
def should_autofix (self , node : cst .CSTNode , code : str | None = None ) -> bool :
352
382
if code is None : # pragma: no branch
353
383
code = "ASYNC911" if self .has_yield else "ASYNC910"
@@ -388,6 +418,10 @@ def visit_ImportFrom(self, node: cst.ImportFrom) -> None:
388
418
return
389
419
390
420
def visit_FunctionDef (self , node : cst .FunctionDef ) -> bool :
421
+ # `await` in default values happen in parent scope
422
+ # we also know we don't ever modify parameters so we can ignore the return value
423
+ _ = node .params .visit (self )
424
+
391
425
# don't lint functions whose bodies solely consist of pass or ellipsis
392
426
# @overload functions are also guaranteed to be empty
393
427
# we also ignore pytest fixtures
@@ -417,36 +451,49 @@ def visit_FunctionDef(self, node: cst.FunctionDef) -> bool:
417
451
node .decorators , * self .options .no_checkpoint_warning_decorators
418
452
)
419
453
)
420
- if not self .async_function :
421
- # only visit subnodes if there is an async function defined inside
422
- # this should improve performance on codebases with many sync functions
423
- return any (m .findall (node , m .FunctionDef (asynchronous = m .Asynchronous ())))
454
+ # only visit subnodes if there is an async function defined inside
455
+ # this should improve performance on codebases with many sync functions
456
+ if not self .async_function and not any (
457
+ m .findall (node , m .FunctionDef (asynchronous = m .Asynchronous ()))
458
+ ):
459
+ return False
424
460
425
461
pos = self .get_metadata (PositionProvider , node ).start # type: ignore
426
462
self .uncheckpointed_statements = {
427
463
Statement ("function definition" , pos .line , pos .column ) # type: ignore
428
464
}
429
- return True
465
+
466
+ # visit body
467
+ # we're not gonna get FlattenSentinel or RemovalSentinel
468
+ self .new_body = cast (cst .BaseSuite , node .body .visit (self ))
469
+
470
+ # we know that leave_FunctionDef for this FunctionDef will run immediately after
471
+ # this function exits so we don't need to worry about save_state for new_body
472
+ return False
430
473
431
474
def leave_FunctionDef (
432
475
self , original_node : cst .FunctionDef , updated_node : cst .FunctionDef
433
476
) -> cst .FunctionDef :
434
477
if (
435
- self .async_function
478
+ self .new_body is not None
479
+ and self .async_function
436
480
# updated_node does not have a position, so we must send original_node
437
481
and self .check_function_exit (original_node )
438
482
and self .should_autofix (original_node )
439
- and isinstance (updated_node . body , cst .IndentedBlock )
483
+ and isinstance (self . new_body , cst .IndentedBlock )
440
484
):
441
485
# insert checkpoint at the end of body
442
- new_body = list (updated_node .body .body )
443
- new_body .append (self .checkpoint_statement ())
444
- indentedblock = updated_node .body .with_changes (body = new_body )
445
- updated_node = updated_node .with_changes (body = indentedblock )
486
+ new_body_block = list (self .new_body .body )
487
+ new_body_block .append (self .checkpoint_statement ())
488
+ self .new_body = self .new_body .with_changes (body = new_body_block )
446
489
447
490
self .ensure_imported_library ()
448
491
492
+ if self .new_body is not None :
493
+ updated_node = updated_node .with_changes (body = self .new_body )
449
494
self .restore_state (original_node )
495
+ # reset self.new_body
496
+ self .new_body = None
450
497
return updated_node
451
498
452
499
# error if function exit/return/yields with uncheckpointed statements
0 commit comments