diff --git a/mypyc/lib-rt/CPy.h b/mypyc/lib-rt/CPy.h index 1881aa97f308..cf096b1e2324 100644 --- a/mypyc/lib-rt/CPy.h +++ b/mypyc/lib-rt/CPy.h @@ -767,6 +767,10 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index); PyObject *CPyBytes_Concat(PyObject *a, PyObject *b); PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter); CPyTagged CPyBytes_Ord(PyObject *obj); +PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width); +PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width); +PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte); +PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte); int CPyBytes_Compare(PyObject *left, PyObject *right); diff --git a/mypyc/lib-rt/bytes_ops.c b/mypyc/lib-rt/bytes_ops.c index 6ff34b021a9a..7d1dd363203b 100644 --- a/mypyc/lib-rt/bytes_ops.c +++ b/mypyc/lib-rt/bytes_ops.c @@ -4,6 +4,7 @@ #include #include "CPy.h" +#include // Returns -1 on error, 0 on inequality, 1 on equality. // @@ -162,3 +163,97 @@ CPyTagged CPyBytes_Ord(PyObject *obj) { PyErr_SetString(PyExc_TypeError, "ord() expects a character"); return CPY_INT_TAG; } + + +PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memset(res_buf, ' ', pad); + memcpy(res_buf + pad, PyBytes_AsString(self), len); + return result; +} + + +PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) { + PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + char fill = PyBytes_AsString(fillbyte)[0]; + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memset(res_buf, fill, pad); + memcpy(res_buf + pad, PyBytes_AsString(self), len); + return result; +} + + +PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memcpy(res_buf, PyBytes_AsString(self), len); + memset(res_buf + len, ' ', pad); + return result; +} + + +PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) { + if (!PyBytes_Check(self)) { + PyErr_SetString(PyExc_TypeError, "self must be bytes"); + return NULL; + } + if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) { + PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte"); + return NULL; + } + Py_ssize_t width_size_t = CPyTagged_AsSsize_t(width); + Py_ssize_t len = PyBytes_Size(self); + if (width_size_t <= len) { + Py_INCREF(self); + return self; + } + char fill = PyBytes_AsString(fillbyte)[0]; + Py_ssize_t pad = width_size_t - len; + PyObject *result = PyBytes_FromStringAndSize(NULL, width_size_t); + if (!result) return NULL; + char *res_buf = PyBytes_AsString(result); + memcpy(res_buf, PyBytes_AsString(self), len); + memset(res_buf + len, fill, pad); + return result; +} diff --git a/mypyc/primitives/bytes_ops.py b/mypyc/primitives/bytes_ops.py index c88e89d1a2ba..b4fe692fb4cd 100644 --- a/mypyc/primitives/bytes_ops.py +++ b/mypyc/primitives/bytes_ops.py @@ -126,3 +126,39 @@ c_function_name="CPyBytes_Ord", error_kind=ERR_MAGIC, ) + +# bytes.rjust(width) +method_op( + name="rjust", + arg_types=[bytes_rprimitive, int_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_RjustDefaultFill", + error_kind=ERR_MAGIC, +) + +# bytes.rjust(width, fillbyte) +method_op( + name="rjust", + arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_RjustCustomFill", + error_kind=ERR_MAGIC, +) + +# bytes.ljust(width) +method_op( + name="ljust", + arg_types=[bytes_rprimitive, int_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_LjustDefaultFill", + error_kind=ERR_MAGIC, +) + +# bytes.ljust(width, fillbyte) +method_op( + name="ljust", + arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive], + return_type=bytes_rprimitive, + c_function_name="CPyBytes_LjustCustomFill", + error_kind=ERR_MAGIC, +) diff --git a/mypyc/test-data/fixtures/ir.py b/mypyc/test-data/fixtures/ir.py index 661ae50fd5f3..fb0f277fa39a 100644 --- a/mypyc/test-data/fixtures/ir.py +++ b/mypyc/test-data/fixtures/ir.py @@ -173,6 +173,8 @@ def __getitem__(self, i: slice) -> bytes: ... def join(self, x: Iterable[object]) -> bytes: ... def decode(self, x: str=..., y: str=...) -> str: ... def __iter__(self) -> Iterator[int]: ... + def ljust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ... + def rjust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ... class bytearray: @overload diff --git a/mypyc/test-data/irbuild-bytes.test b/mypyc/test-data/irbuild-bytes.test index 476c5ac59f48..e5ef0d6cb004 100644 --- a/mypyc/test-data/irbuild-bytes.test +++ b/mypyc/test-data/irbuild-bytes.test @@ -185,3 +185,45 @@ L0: r10 = CPyBytes_Build(2, var, r9) b4 = r10 return 1 + +[case testBytesRjustDefault] +def f(b: bytes) -> bytes: + return b.rjust(6) +[out] +def f(b): + b, r0 :: bytes +L0: + r0 = CPyBytes_RjustDefaultFill(b, 12) + return r0 + +[case testBytesRjustCustom] +def f(b: bytes) -> bytes: + return b.rjust(8, b'0') +[out] +def f(b): + b, r0, r1 :: bytes +L0: + r0 = b'0' + r1 = CPyBytes_RjustCustomFill(b, 16, r0) + return r1 + +[case testBytesLjustDefault] +def f(b: bytes) -> bytes: + return b.ljust(7) +[out] +def f(b): + b, r0 :: bytes +L0: + r0 = CPyBytes_LjustDefaultFill(b, 14) + return r0 + +[case testBytesLjustCustom] +def f(b: bytes) -> bytes: + return b.ljust(10, b'_') +[out] +def f(b): + b, r0, r1 :: bytes +L0: + r0 = b'_' + r1 = CPyBytes_LjustCustomFill(b, 20, r0) + return r1 diff --git a/mypyc/test-data/run-bytes.test b/mypyc/test-data/run-bytes.test index 5a285320c849..3b81814d056d 100644 --- a/mypyc/test-data/run-bytes.test +++ b/mypyc/test-data/run-bytes.test @@ -374,3 +374,43 @@ class subbytearray(bytearray): [file userdefinedbytes.py] class bytes: pass + +[case testBytesRjustLjust] +from testutil import assertRaises + +def rjust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes: + return b.rjust(width, fill) + +def ljust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes: + return b.ljust(width, fill) + +def test_rjust_with_default_fill() -> None: + assert rjust_bytes(b'abc', 6) == b' abc', rjust_bytes(b'abc', 6) + assert rjust_bytes(b'abc', 3) == b'abc', rjust_bytes(b'abc', 3) + assert rjust_bytes(b'abc', 2) == b'abc', rjust_bytes(b'abc', 2) + assert rjust_bytes(b'', 4) == b' ', rjust_bytes(b'', 4) + +def test_rjust_with_custom_fill() -> None: + assert rjust_bytes(b'abc', 6, b'0') == b'000abc', rjust_bytes(b'abc', 6, b'0') + assert rjust_bytes(b'abc', 5, b'_') == b'__abc', rjust_bytes(b'abc', 5, b'_') + assert rjust_bytes(b'abc', 3, b'X') == b'abc', rjust_bytes(b'abc', 3, b'X') + +def test_ljust_with_default_fill() -> None: + assert ljust_bytes(b'abc', 6) == b'abc ', ljust_bytes(b'abc', 6) + assert ljust_bytes(b'abc', 3) == b'abc', ljust_bytes(b'abc', 3) + assert ljust_bytes(b'abc', 2) == b'abc', ljust_bytes(b'abc', 2) + assert ljust_bytes(b'', 4) == b' ', ljust_bytes(b'', 4) + +def test_ljust_with_custom_fill() -> None: + assert ljust_bytes(b'abc', 6, b'0') == b'abc000', ljust_bytes(b'abc', 6, b'0') + assert ljust_bytes(b'abc', 5, b'_') == b'abc__', ljust_bytes(b'abc', 5, b'_') + assert ljust_bytes(b'abc', 3, b'X') == b'abc', ljust_bytes(b'abc', 3, b'X') + +def test_edge_cases() -> None: + assert rjust_bytes(b'abc', 0) == b'abc', rjust_bytes(b'abc', 0) + assert ljust_bytes(b'abc', 0) == b'abc', ljust_bytes(b'abc', 0) + # fillbyte must be length 1 + with assertRaises(TypeError): + rjust_bytes(b'abc', 5, b'') + with assertRaises(TypeError): + ljust_bytes(b'abc', 5, b'12')