Skip to content

Commit ac646c0

Browse files
[mypyc] feat: optimize f-string building from Final values (#19611)
We can do some extra constant folding in cases like this: ```python from typing import Final BASE_URL: Final = "https://example.com" PORT: Final = 1234 def get_url(endpoint: str) -> str: return f"{BASE_URL}:{PORT}/{endpoint}" ``` which should generate the same C code as ```python def get_url(endpoint: str) -> str: return f"https://example.com:1234/{endpoint}" ``` This PR makes it so.
1 parent 8e66cf2 commit ac646c0

File tree

3 files changed

+61
-2
lines changed

3 files changed

+61
-2
lines changed

mypyc/irbuild/format_str_tokenizer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
)
1313
from mypy.errors import Errors
1414
from mypy.messages import MessageBuilder
15-
from mypy.nodes import Context, Expression
15+
from mypy.nodes import Context, Expression, StrExpr
1616
from mypy.options import Options
1717
from mypyc.ir.ops import Integer, Value
1818
from mypyc.ir.rtypes import (
@@ -143,7 +143,9 @@ def convert_format_expr_to_str(
143143
for x, format_op in zip(exprs, format_ops):
144144
node_type = builder.node_type(x)
145145
if format_op == FormatOp.STR:
146-
if is_str_rprimitive(node_type):
146+
if is_str_rprimitive(node_type) or isinstance(
147+
x, StrExpr
148+
): # NOTE: why does mypyc think our fake StrExprs are not str rprimitives?
147149
var_str = builder.accept(x)
148150
elif is_int_rprimitive(node_type) or is_short_int_rprimitive(node_type):
149151
var_str = builder.primitive_op(int_to_str_op, [builder.accept(x)], line)

mypyc/irbuild/specialize.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
RefExpr,
3232
StrExpr,
3333
TupleExpr,
34+
Var,
3435
)
3536
from mypy.types import AnyType, TypeOfAny
3637
from mypyc.ir.ops import (
@@ -710,6 +711,22 @@ def translate_fstring(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Va
710711
format_ops.append(FormatOp.STR)
711712
exprs.append(item.args[0])
712713

714+
def get_literal_str(expr: Expression) -> str | None:
715+
if isinstance(expr, StrExpr):
716+
return expr.value
717+
elif isinstance(expr, RefExpr) and isinstance(expr.node, Var) and expr.node.is_final:
718+
return str(expr.node.final_value)
719+
return None
720+
721+
for i in range(len(exprs) - 1):
722+
while (
723+
len(exprs) >= i + 2
724+
and (first := get_literal_str(exprs[i])) is not None
725+
and (second := get_literal_str(exprs[i + 1])) is not None
726+
):
727+
exprs = [*exprs[:i], StrExpr(first + second), *exprs[i + 2 :]]
728+
format_ops = [*format_ops[:i], FormatOp.STR, *format_ops[i + 2 :]]
729+
713730
substitutions = convert_format_expr_to_str(builder, format_ops, exprs, expr.line)
714731
if substitutions is None:
715732
return None

mypyc/test-data/irbuild-str.test

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,43 @@ L3:
605605
r6 = r7
606606
L4:
607607
return r6
608+
609+
[case testFStringFromConstants]
610+
from typing import Final
611+
string: Final = "abc"
612+
integer: Final = 123
613+
floating: Final = 3.14
614+
boolean: Final = True
615+
616+
def test(x: str) -> str:
617+
return f"{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}"
618+
def test2(x: str) -> str:
619+
return f"{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}"
620+
def test3(x: str) -> str:
621+
return f"{x}{string}{integer}{floating}{boolean}{x}{boolean}{floating}{integer}{string}{x}{string}{integer}{floating}{boolean}{x}"
622+
623+
[out]
624+
def test(x):
625+
x, r0, r1, r2, r3 :: str
626+
L0:
627+
r0 = 'abc1233.14True'
628+
r1 = 'True3.14123abc'
629+
r2 = 'abc1233.14True'
630+
r3 = CPyStr_Build(5, r0, x, r1, x, r2)
631+
return r3
632+
def test2(x):
633+
x, r0, r1, r2, r3 :: str
634+
L0:
635+
r0 = 'abc1233.14True'
636+
r1 = 'True3.14123abc'
637+
r2 = 'abc1233.14True'
638+
r3 = CPyStr_Build(6, r0, x, r1, x, r2, x)
639+
return r3
640+
def test3(x):
641+
x, r0, r1, r2, r3 :: str
642+
L0:
643+
r0 = 'abc1233.14True'
644+
r1 = 'True3.14123abc'
645+
r2 = 'abc1233.14True'
646+
r3 = CPyStr_Build(7, x, r0, x, r1, x, r2, x)
647+
return r3

0 commit comments

Comments
 (0)