Skip to content

Commit 4712ea9

Browse files
Merge branch 'master' into for-map
2 parents fcc221e + 27b9ba0 commit 4712ea9

File tree

8 files changed

+877
-72
lines changed

8 files changed

+877
-72
lines changed

mypyc/ir/rtypes.py

Lines changed: 44 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def may_be_immortal(self) -> bool:
192192
def serialize(self) -> str:
193193
return "void"
194194

195-
def __eq__(self, other: object) -> bool:
195+
def __eq__(self, other: object) -> TypeGuard[RVoid]:
196196
return isinstance(other, RVoid)
197197

198198
def __hash__(self) -> int:
@@ -279,7 +279,7 @@ def serialize(self) -> str:
279279
def __repr__(self) -> str:
280280
return "<RPrimitive %s>" % self.name
281281

282-
def __eq__(self, other: object) -> bool:
282+
def __eq__(self, other: object) -> TypeGuard[RPrimitive]:
283283
return isinstance(other, RPrimitive) and other.name == self.name
284284

285285
def __hash__(self) -> int:
@@ -513,15 +513,15 @@ def __hash__(self) -> int:
513513
range_rprimitive: Final = RPrimitive("builtins.range", is_unboxed=False, is_refcounted=True)
514514

515515

516-
def is_tagged(rtype: RType) -> bool:
516+
def is_tagged(rtype: RType) -> TypeGuard[RPrimitive]:
517517
return rtype is int_rprimitive or rtype is short_int_rprimitive
518518

519519

520-
def is_int_rprimitive(rtype: RType) -> bool:
520+
def is_int_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
521521
return rtype is int_rprimitive
522522

523523

524-
def is_short_int_rprimitive(rtype: RType) -> bool:
524+
def is_short_int_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
525525
return rtype is short_int_rprimitive
526526

527527

@@ -535,7 +535,7 @@ def is_int32_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
535535
)
536536

537537

538-
def is_int64_rprimitive(rtype: RType) -> bool:
538+
def is_int64_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
539539
return rtype is int64_rprimitive or (
540540
rtype is c_pyssize_t_rprimitive and rtype._ctype == "int64_t"
541541
)
@@ -554,81 +554,93 @@ def is_uint8_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
554554
return rtype is uint8_rprimitive
555555

556556

557-
def is_uint32_rprimitive(rtype: RType) -> bool:
557+
def is_uint32_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
558558
return rtype is uint32_rprimitive
559559

560560

561-
def is_uint64_rprimitive(rtype: RType) -> bool:
561+
def is_uint64_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
562562
return rtype is uint64_rprimitive
563563

564564

565-
def is_c_py_ssize_t_rprimitive(rtype: RType) -> bool:
565+
def is_c_py_ssize_t_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
566566
return rtype is c_pyssize_t_rprimitive
567567

568568

569-
def is_pointer_rprimitive(rtype: RType) -> bool:
569+
def is_pointer_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
570570
return rtype is pointer_rprimitive
571571

572572

573-
def is_float_rprimitive(rtype: RType) -> bool:
573+
def is_float_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
574574
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.float"
575575

576576

577-
def is_bool_rprimitive(rtype: RType) -> bool:
577+
def is_bool_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
578578
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.bool"
579579

580580

581-
def is_bit_rprimitive(rtype: RType) -> bool:
581+
def is_bit_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
582582
return isinstance(rtype, RPrimitive) and rtype.name == "bit"
583583

584584

585-
def is_bool_or_bit_rprimitive(rtype: RType) -> bool:
585+
def is_bool_or_bit_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
586586
return is_bool_rprimitive(rtype) or is_bit_rprimitive(rtype)
587587

588588

589-
def is_object_rprimitive(rtype: RType) -> bool:
589+
def is_object_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
590590
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.object"
591591

592592

593-
def is_none_rprimitive(rtype: RType) -> bool:
593+
def is_none_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
594594
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.None"
595595

596596

597-
def is_list_rprimitive(rtype: RType) -> bool:
597+
def is_list_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
598598
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.list"
599599

600600

601-
def is_dict_rprimitive(rtype: RType) -> bool:
601+
def is_dict_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
602602
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.dict"
603603

604604

605-
def is_set_rprimitive(rtype: RType) -> bool:
605+
def is_set_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
606606
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.set"
607607

608608

609-
def is_frozenset_rprimitive(rtype: RType) -> bool:
609+
def is_frozenset_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
610610
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.frozenset"
611611

612612

613-
def is_str_rprimitive(rtype: RType) -> bool:
613+
def is_str_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
614614
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.str"
615615

616616

617-
def is_bytes_rprimitive(rtype: RType) -> bool:
617+
def is_bytes_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
618618
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.bytes"
619619

620620

621-
def is_tuple_rprimitive(rtype: RType) -> bool:
621+
def is_tuple_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
622622
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.tuple"
623623

624624

625-
def is_range_rprimitive(rtype: RType) -> bool:
625+
def is_range_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
626626
return isinstance(rtype, RPrimitive) and rtype.name == "builtins.range"
627627

628628

629-
def is_sequence_rprimitive(rtype: RType) -> bool:
629+
def is_sequence_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
630630
return isinstance(rtype, RPrimitive) and (
631-
is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype)
631+
is_list_rprimitive(rtype)
632+
or is_tuple_rprimitive(rtype)
633+
or is_str_rprimitive(rtype)
634+
or is_bytes_rprimitive(rtype)
635+
)
636+
637+
638+
def is_immutable_rprimitive(rtype: RType) -> TypeGuard[RPrimitive]:
639+
return (
640+
is_str_rprimitive(rtype)
641+
or is_bytes_rprimitive(rtype)
642+
or is_tuple_rprimitive(rtype)
643+
or is_frozenset_rprimitive(rtype)
632644
)
633645

634646

@@ -717,7 +729,7 @@ def __str__(self) -> str:
717729
def __repr__(self) -> str:
718730
return "<RTuple %s>" % ", ".join(repr(typ) for typ in self.types)
719731

720-
def __eq__(self, other: object) -> bool:
732+
def __eq__(self, other: object) -> TypeGuard[RTuple]:
721733
return isinstance(other, RTuple) and self.types == other.types
722734

723735
def __hash__(self) -> int:
@@ -850,7 +862,7 @@ def __repr__(self) -> str:
850862
", ".join(name + ":" + repr(typ) for name, typ in zip(self.names, self.types)),
851863
)
852864

853-
def __eq__(self, other: object) -> bool:
865+
def __eq__(self, other: object) -> TypeGuard[RStruct]:
854866
return (
855867
isinstance(other, RStruct)
856868
and self.name == other.name
@@ -920,7 +932,7 @@ def attr_type(self, name: str) -> RType:
920932
def __repr__(self) -> str:
921933
return "<RInstance %s>" % self.name
922934

923-
def __eq__(self, other: object) -> bool:
935+
def __eq__(self, other: object) -> TypeGuard[RInstance]:
924936
return isinstance(other, RInstance) and other.name == self.name
925937

926938
def __hash__(self) -> int:
@@ -974,7 +986,7 @@ def __str__(self) -> str:
974986
return "union[%s]" % ", ".join(str(item) for item in self.items)
975987

976988
# We compare based on the set because order in a union doesn't matter
977-
def __eq__(self, other: object) -> bool:
989+
def __eq__(self, other: object) -> TypeGuard[RUnion]:
978990
return isinstance(other, RUnion) and self.items_set == other.items_set
979991

980992
def __hash__(self) -> int:
@@ -1016,7 +1028,7 @@ def optional_value_type(rtype: RType) -> RType | None:
10161028
return None
10171029

10181030

1019-
def is_optional_type(rtype: RType) -> bool:
1031+
def is_optional_type(rtype: RType) -> TypeGuard[RUnion]:
10201032
"""Is rtype an optional type with exactly two union items?"""
10211033
return optional_value_type(rtype) is not None
10221034

@@ -1048,7 +1060,7 @@ def __str__(self) -> str:
10481060
def __repr__(self) -> str:
10491061
return f"<RArray {self.item_type!r}[{self.length}]>"
10501062

1051-
def __eq__(self, other: object) -> bool:
1063+
def __eq__(self, other: object) -> TypeGuard[RArray]:
10521064
return (
10531065
isinstance(other, RArray)
10541066
and self.item_type == other.item_type

mypyc/irbuild/builder.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@
9191
RType,
9292
RUnion,
9393
bitmap_rprimitive,
94+
bytes_rprimitive,
9495
c_pyssize_t_rprimitive,
9596
dict_rprimitive,
9697
int_rprimitive,
@@ -962,8 +963,12 @@ def get_sequence_type_from_type(self, target_type: Type) -> RType:
962963
elif isinstance(target_type, Instance):
963964
if target_type.type.fullname == "builtins.str":
964965
return str_rprimitive
965-
else:
966+
elif target_type.type.fullname == "builtins.bytes":
967+
return bytes_rprimitive
968+
try:
966969
return self.type_to_rtype(target_type.args[0])
970+
except IndexError:
971+
raise ValueError(f"{target_type!r} is not a valid sequence.") from None
967972
# This elif-blocks are needed for iterating over classes derived from NamedTuple.
968973
elif isinstance(target_type, TypeVarLikeType):
969974
return self.get_sequence_type_from_type(target_type.upper_bound)

mypyc/irbuild/for_helpers.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
int_rprimitive,
5151
is_dict_rprimitive,
5252
is_fixed_width_rtype,
53+
is_immutable_rprimitive,
5354
is_list_rprimitive,
5455
is_sequence_rprimitive,
5556
is_short_int_rprimitive,
@@ -207,9 +208,9 @@ def sequence_from_generator_preallocate_helper(
207208
there is no condition list in the generator and only one original sequence with
208209
one index is allowed.
209210
210-
e.g. (1) tuple(f(x) for x in a_list/a_tuple)
211-
(2) list(f(x) for x in a_list/a_tuple)
212-
(3) [f(x) for x in a_list/a_tuple]
211+
e.g. (1) tuple(f(x) for x in a_list/a_tuple/a_str/a_bytes)
212+
(2) list(f(x) for x in a_list/a_tuple/a_str/a_bytes)
213+
(3) [f(x) for x in a_list/a_tuple/a_str/a_bytes]
213214
RTuple as an original sequence is not supported yet.
214215
215216
Args:
@@ -226,7 +227,7 @@ def sequence_from_generator_preallocate_helper(
226227
"""
227228
if len(gen.sequences) == 1 and len(gen.indices) == 1 and len(gen.condlists[0]) == 0:
228229
rtype = builder.node_type(gen.sequences[0])
229-
if is_list_rprimitive(rtype) or is_tuple_rprimitive(rtype) or is_str_rprimitive(rtype):
230+
if is_sequence_rprimitive(rtype):
230231
sequence = builder.accept(gen.sequences[0])
231232
length = builder.builder.builtin_len(sequence, gen.line, use_pyssize_t=True)
232233
target_op = empty_op_llbuilder(length, gen.line)
@@ -797,17 +798,31 @@ class ForSequence(ForGenerator):
797798
Supports iterating in both forward and reverse.
798799
"""
799800

801+
length_reg: Value | AssignmentTarget | None
802+
800803
def init(self, expr_reg: Value, target_type: RType, reverse: bool) -> None:
804+
assert is_sequence_rprimitive(expr_reg.type), expr_reg
801805
builder = self.builder
802806
self.reverse = reverse
803807
# Define target to contain the expression, along with the index that will be used
804808
# for the for-loop. If we are inside of a generator function, spill these into the
805809
# environment class.
806810
self.expr_target = builder.maybe_spill(expr_reg)
811+
if is_immutable_rprimitive(expr_reg.type):
812+
# If the expression is an immutable type, we can load the length just once.
813+
self.length_reg = builder.maybe_spill(self.load_len(self.expr_target))
814+
else:
815+
# Otherwise, even if the length is known, we must recalculate the length
816+
# at every iteration for compatibility with python semantics.
817+
self.length_reg = None
807818
if not reverse:
808819
index_reg: Value = Integer(0, c_pyssize_t_rprimitive)
809820
else:
810-
index_reg = builder.builder.int_sub(self.load_len(self.expr_target), 1)
821+
if self.length_reg is not None:
822+
len_val = builder.read(self.length_reg)
823+
else:
824+
len_val = self.load_len(self.expr_target)
825+
index_reg = builder.builder.int_sub(len_val, 1)
811826
self.index_target = builder.maybe_spill_assignable(index_reg)
812827
self.target_type = target_type
813828

@@ -826,9 +841,13 @@ def gen_condition(self) -> None:
826841
second_check = BasicBlock()
827842
builder.add_bool_branch(comparison, second_check, self.loop_exit)
828843
builder.activate_block(second_check)
829-
# For compatibility with python semantics we recalculate the length
830-
# at every iteration.
831-
len_reg = self.load_len(self.expr_target)
844+
if self.length_reg is None:
845+
# For compatibility with python semantics we recalculate the length
846+
# at every iteration.
847+
len_reg = self.load_len(self.expr_target)
848+
else:
849+
# (unless input is immutable type).
850+
len_reg = builder.read(self.length_reg, line)
832851
comparison = builder.binary_op(builder.read(self.index_target, line), len_reg, "<", line)
833852
builder.add_bool_branch(comparison, self.body_block, self.loop_exit)
834853

mypyc/irbuild/specialize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ def translate_tuple_from_generator_call(
288288
"""Special case for simplest tuple creation from a generator.
289289
290290
For example:
291-
tuple(f(x) for x in some_list/some_tuple/some_str)
291+
tuple(f(x) for x in some_list/some_tuple/some_str/some_bytes)
292292
'translate_safe_generator_call()' would take care of other cases
293293
if this fails.
294294
"""

mypyc/test-data/fixtures/ir.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,7 @@ def __getitem__(self, i: int) -> int: ...
174174
def __getitem__(self, i: slice) -> bytes: ...
175175
def join(self, x: Iterable[object]) -> bytes: ...
176176
def decode(self, x: str=..., y: str=...) -> str: ...
177+
def __iter__(self) -> Iterator[int]: ...
177178

178179
class bytearray:
179180
@overload

mypyc/test-data/irbuild-generics.test

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -711,18 +711,18 @@ L0:
711711
r0 = __mypyc_self__.__mypyc_env__
712712
r1 = var_object_size args
713713
r2 = PyList_New(r1)
714-
r3 = 0
714+
r3 = var_object_size args
715+
r4 = 0
715716
L1:
716-
r4 = var_object_size args
717-
r5 = r3 < r4 :: signed
717+
r5 = r4 < r3 :: signed
718718
if r5 goto L2 else goto L4 :: bool
719719
L2:
720-
r6 = CPySequenceTuple_GetItemUnsafe(args, r3)
720+
r6 = CPySequenceTuple_GetItemUnsafe(args, r4)
721721
x = r6
722-
CPyList_SetItemUnsafe(r2, r3, x)
722+
CPyList_SetItemUnsafe(r2, r4, x)
723723
L3:
724-
r7 = r3 + 1
725-
r3 = r7
724+
r7 = r4 + 1
725+
r4 = r7
726726
goto L1
727727
L4:
728728
can_listcomp = r2

0 commit comments

Comments
 (0)