Skip to content

Commit c95d8ab

Browse files
[mypyc] feat: cache len for iterating over immutable types (#19656)
Currently, if a user uses an immutable type as the sequence input for a for loop, the length is checked once at each iteration which, while necessary for some container types such as list and dictionaries, is not necessary for iterating over immutable types tuple, str, and bytes. This PR modifies the codebase such that the length is only checked at the first iteration, and reused from there. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent fb41108 commit c95d8ab

File tree

8 files changed

+846
-41
lines changed

8 files changed

+846
-41
lines changed

mypyc/ir/rtypes.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,19 @@ def is_range_rprimitive(rtype: RType) -> bool:
628628

629629
def is_sequence_rprimitive(rtype: RType) -> bool:
630630
return isinstance(rtype, RPrimitive) and (
631-
is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype)
631+
is_list_rprimitive(rtype)
632+
or is_tuple_rprimitive(rtype)
633+
or is_str_rprimitive(rtype)
634+
or is_bytes_rprimitive(rtype)
635+
)
636+
637+
638+
def is_immutable_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
639+
return (
640+
is_str_rprimitive(rtype)
641+
or is_bytes_rprimitive(rtype)
642+
or is_tuple_rprimitive(rtype)
643+
or is_frozenset_rprimitive(rtype)
632644
)
633645

634646

mypyc/irbuild/builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
RType,
9292
RUnion,
9393
bitmap_rprimitive,
94+
bytes_rprimitive,
9495
c_pyssize_t_rprimitive,
9596
dict_rprimitive,
9697
int_rprimitive,
@@ -962,8 +963,12 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
962963
elif isinstance(target_type, Instance):
963964
if target_type.type.fullname == "builtins.str":
964965
return str_rprimitive
965-
else:
966+
elif target_type.type.fullname == "builtins.bytes":
967+
return bytes_rprimitive
968+
try:
966969
return self.type_to_rtype(target_type.args[0])
970+
except IndexError:
971+
raise ValueError(f"{target_type!r} is not a valid sequence.") from None
967972
# This elif-blocks are needed for iterating over classes derived from NamedTuple.
968973
elif isinstance(target_type, TypeVarLikeType):
969974
return self.get_sequence_type_from_type(target_type.upper_bound)

mypyc/irbuild/for_helpers.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
int_rprimitive,
4949
is_dict_rprimitive,
5050
is_fixed_width_rtype,
51+
is_immutable_rprimitive,
5152
is_list_rprimitive,
5253
is_sequence_rprimitive,
5354
is_short_int_rprimitive,
@@ -205,9 +206,9 @@ def sequence_from_generator_preallocate_helper(
205206
there is no condition list in the generator and only one original sequence with
206207
one index is allowed.
207208
208-
e.g. (1) tuple(f(x) for x in a_list/a_tuple)
209-
(2) list(f(x) for x in a_list/a_tuple)
210-
(3) [f(x) for x in a_list/a_tuple]
209+
e.g. (1) tuple(f(x) for x in a_list/a_tuple/a_str/a_bytes)
210+
(2) list(f(x) for x in a_list/a_tuple/a_str/a_bytes)
211+
(3) [f(x) for x in a_list/a_tuple/a_str/a_bytes]
211212
RTuple as an original sequence is not supported yet.
212213
213214
Args:
@@ -224,7 +225,7 @@ def sequence_from_generator_preallocate_helper(
224225
"""
225226
if len(gen.sequences) == 1 and len(gen.indices) == 1 and len(gen.condlists[0]) == 0:
226227
rtype = builder.node_type(gen.sequences[0])
227-
if is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype):
228+
if is_sequence_rprimitive(rtype):
228229
sequence = builder.accept(gen.sequences[0])
229230
length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True)
230231
target_op = empty_op_llbuilder(length, gen.line)
@@ -785,17 +786,31 @@ class ForSequence(ForGenerator):
785786
Supports iterating in both forward and reverse.
786787
"""
787788

789+
length_reg: Value | AssignmentTarget | None
790+
788791
def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None:
792+
assert is_sequence_rprimitive(expr_reg.type), expr_reg
789793
builder = self.builder
790794
self.reverse = reverse
791795
# Define target to contain the expression, along with the index that will be used
792796
# for the for-loop. If we are inside of a generator function, spill these into the
793797
# environment class.
794798
self.expr_target = builder.maybe_spill(expr_reg)
799+
if is_immutable_rprimitive(expr_reg.type):
800+
# If the expression is an immutable type, we can load the length just once.
801+
self.length_reg = builder.maybe_spill(self.load_len(self.expr_target))
802+
else:
803+
# Otherwise, even if the length is known, we must recalculate the length
804+
# at every iteration for compatibility with python semantics.
805+
self.length_reg = None
795806
if not reverse:
796807
index_reg: Value = Integer(0, c_pyssize_t_rprimitive)
797808
else:
798-
index_reg = builder.builder.int_sub(self.load_len(self.expr_target), 1)
809+
if self.length_reg is not None:
810+
len_val = builder.read(self.length_reg)
811+
else:
812+
len_val = self.load_len(self.expr_target)
813+
index_reg = builder.builder.int_sub(len_val, 1)
799814
self.index_target = builder.maybe_spill_assignable(index_reg)
800815
self.target_type = target_type
801816

@@ -814,9 +829,13 @@ def gen_condition(self) -> None:
814829
second_check = BasicBlock()
815830
builder.add_bool_branch(comparison, second_check, self.loop_exit)
816831
builder.activate_block(second_check)
817-
# For compatibility with python semantics we recalculate the length
818-
# at every iteration.
819-
len_reg = self.load_len(self.expr_target)
832+
if self.length_reg is None:
833+
# For compatibility with python semantics we recalculate the length
834+
# at every iteration.
835+
len_reg = self.load_len(self.expr_target)
836+
else:
837+
# (unless input is immutable type).
838+
len_reg = builder.read(self.length_reg, line)
820839
comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, "<", line)
821840
builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
822841

mypyc/irbuild/specialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def translate_tuple_from_generator_call(
288288
"""Special case for simplest tuple creation from a generator.
289289
290290
For example:
291-
tuple(f(x) for x in some_list/some_tuple/some_str)
291+
tuple(f(x) for x in some_list/some_tuple/some_str/some_bytes)
292292
'translate_safe_generator_call()' would take care of other cases
293293
if this fails.
294294
"""

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,7 @@ def __getitem__(self, i: int) -> int: ...
172172
def __getitem__(self, i: slice) -> bytes: ...
173173
def join(self, x: Iterable[object]) -> bytes: ...
174174
def decode(self, x: str=..., y: str=...) -> str: ...
175+
def __iter__(self) -> Iterator[int]: ...
175176

176177
class bytearray:
177178
@overload

mypyc/test-data/irbuild-generics.test

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -711,18 +711,18 @@ L0:
711711
r0 = __mypyc_self__.__mypyc_env__
712712
r1 = var_object_size args
713713
r2 = PyList_New(r1)
714-
r3 = 0
714+
r3 = var_object_size args
715+
r4 = 0
715716
L1:
716-
r4 = var_object_size args
717-
r5 = r3 < r4 :: signed
717+
r5 = r4 < r3 :: signed
718718
if r5 goto L2 else goto L4 :: bool
719719
L2:
720-
r6 = CPySequenceTuple_GetItemUnsafe(args, r3)
720+
r6 = CPySequenceTuple_GetItemUnsafe(args, r4)
721721
x = r6
722-
CPyList_SetItemUnsafe(r2, r3, x)
722+
CPyList_SetItemUnsafe(r2, r4, x)
723723
L3:
724-
r7 = r3 + 1
725-
r3 = r7
724+
r7 = r4 + 1
725+
r4 = r7
726726
goto L1
727727
L4:
728728
can_listcomp = r2

0 commit comments

Comments
 (0)