Skip to content

Fix Ctrl-C handling in REPL #3306

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 105 additions & 4 deletions src/trio/_repl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

import ast
import contextlib
import ctypes
import inspect
import os
import sys
import types
import warnings
Expand All @@ -19,15 +21,38 @@
class TrioInteractiveConsole(InteractiveConsole):
def __init__(self, repl_locals: dict[str, object] | None = None) -> None:
super().__init__(locals=repl_locals)
self.code_to_run = None
self.compile.compiler.flags |= ast.PyCF_ALLOW_TOP_LEVEL_AWAIT

readline = sys.modules.get("readline")
if readline is not None:
self.readline = readline
if hasattr(readline, "__file__"):
self.rl = ctypes.CDLL(readline.__file__)
else:
self.rl = ctypes.pythonapi
if hasattr(self.rl, "rl_catch_signals"):
ctypes.c_int.in_dll(self.rl, "rl_catch_signals").value = 0
self.rlcallbacktype = ctypes.CFUNCTYPE(None, ctypes.c_char_p)
self.rl.rl_callback_handler_install.argtypes = [
ctypes.c_char_p,
self.rlcallbacktype,
]
else:
self.rl = None
self.linebuffer = ""

def runcode(self, code: types.CodeType) -> None:
self.code_to_run = code

async def actually_run_code(self) -> None:
# https://github.com/python/typeshed/issues/13768
func = types.FunctionType(code, self.locals) # type: ignore[arg-type]
func = types.FunctionType(self.code_to_run, self.locals) # type: ignore[arg-type]
self.code_to_run = None
if inspect.iscoroutinefunction(func):
result = trio.from_thread.run(outcome.acapture, func)
result = await outcome.acapture(func)
else:
result = trio.from_thread.run_sync(outcome.capture, func)
result = outcome.capture(func)
if isinstance(result, outcome.Error):
# If it is SystemExit, quit the repl. Otherwise, print the traceback.
# If there is a SystemExit inside a BaseExceptionGroup, it probably isn't
Expand All @@ -50,6 +75,78 @@ def runcode(self, code: types.CodeType) -> None:
# This means that overriding self.write also does nothing to tbs.
sys.excepthook(sys.last_type, sys.last_value, sys.last_traceback)

async def ainteract(self, banner):
try:
sys.ps1
except AttributeError:
sys.ps1 = ">>> "
try:
sys.ps2
except AttributeError:
sys.ps2 = "... "

self.write("%s\n" % str(banner))
more = 0

while True:
try:
if more:
prompt = sys.ps2
else:
prompt = sys.ps1
try:
line = await self.async_input(prompt)
except EOFError:
self.write("\n")
break
else:
more = self.push(line)
if more == 0:
await self.actually_run_code()
except KeyboardInterrupt:
self.write("\nKeyboardInterrupt\n")
self.resetbuffer()
more = 0

async def async_input(self, prompt=""):
if self.rl:
line = b""

@self.rlcallbacktype
def callback(text):
nonlocal line
line = text

try:
self.rl.rl_callback_handler_install(prompt.encode(), callback)
while line == b"":
await trio.lowlevel.wait_readable(0)
self.rl.rl_callback_read_char()
Copy link
Contributor

@A5rocks A5rocks Aug 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So just to understand I'm reading this correctly, this rl_callback_read_char would fail a ctrl+c check, except it's guaranteed to have a character as stdin is readable?

except KeyboardInterrupt:
self.rl.rl_free_line_state()
raise
finally:
self.rl.rl_callback_handler_remove()
if line is None:
raise EOFError
self.readline.add_history(line.decode())
return line.decode()
else:
line = ""
print(prompt, file=sys.stderr, end="")
sys.stderr.flush()
while True:
await trio.lowlevel.wait_readable(0)
new = os.read(0, 1024).decode()
if new == "":
raise EOFError
self.linebuffer += new
line, nl, buffer = self.linebuffer.partition("\n")
if nl:
self.linebuffer = buffer
return line
return line


async def run_repl(console: TrioInteractiveConsole) -> None:
banner = (
Expand All @@ -60,7 +157,7 @@ async def run_repl(console: TrioInteractiveConsole) -> None:
f'{getattr(sys, "ps1", ">>> ")}import trio'
)
try:
await trio.to_thread.run_sync(console.interact, banner)
await console.ainteract(banner)
finally:
warnings.filterwarnings(
"ignore",
Expand All @@ -86,3 +183,7 @@ def main(original_locals: dict[str, object]) -> None:

console = TrioInteractiveConsole(repl_locals)
trio.run(run_repl, console)


if __name__ == "__main__":
main(locals())
Loading