Skip to content

[mypyc] feat: new primitives for bytes.rjust and bytes.ljust #19672

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 24 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mypyc/lib-rt/CPy.h
Original file line number Diff line number Diff line change
Expand Up @@ -767,6 +767,10 @@ CPyTagged CPyBytes_GetItem(PyObject *o, CPyTagged index);
PyObject *CPyBytes_Concat(PyObject *a, PyObject *b);
PyObject *CPyBytes_Join(PyObject *sep, PyObject *iter);
CPyTagged CPyBytes_Ord(PyObject *obj);
// bytes.ljust / bytes.rjust primitives (implemented in bytes_ops.c).
// The *DefaultFill variants pad with an ASCII space; the *CustomFill variants
// pad with the single byte held by 'fillbyte'. Each returns a bytes object
// ('self' itself when 'width' does not exceed its length), or NULL with an
// exception set on error.
PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width);
PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width);
PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte);
PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte);


int CPyBytes_Compare(PyObject *left, PyObject *right);
Expand Down
95 changes: 95 additions & 0 deletions mypyc/lib-rt/bytes_ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

#include <Python.h>
#include "CPy.h"
#include <string.h>

// Returns -1 on error, 0 on inequality, 1 on equality.
//
Expand Down Expand Up @@ -162,3 +163,97 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
return CPY_INT_TAG;
}


// bytes.rjust(width): right-justify 'self' in a field of 'width' bytes,
// padding on the left with ASCII spaces.
//
// Returns a new reference, or NULL with an exception set.
PyObject *CPyBytes_RjustDefaultFill(PyObject *self, CPyTagged width) {
    // NOTE(review): assumes a value statically typed as bytes may be a
    // bytearray or a bytes subclass at runtime -- confirm against the cast
    // mypyc emits for bytes.
    if (!PyBytes_Check(self) && !PyByteArray_Check(self)) {
        PyErr_SetString(PyExc_TypeError, "self must be bytes");
        return NULL;
    }
    Py_ssize_t target = CPyTagged_AsSsize_t(width);
    // NOTE(review): relies on CPyTagged_AsSsize_t reporting failure the way
    // PyLong_AsSsize_t does (-1 with an exception set). Without this check a
    // width that does not fit in Py_ssize_t would be treated as -1 and 'self'
    // returned while an exception is still pending.
    if (target == -1 && PyErr_Occurred()) {
        return NULL;
    }
    if (!PyBytes_CheckExact(self)) {
        // Subclasses and bytearray keep their own semantics (result type,
        // overridden methods) by going through the regular method call.
        return PyObject_CallMethod(self, "rjust", "n", target);
    }
    Py_ssize_t len = PyBytes_GET_SIZE(self);
    if (target <= len) {
        // Nothing to pad (this also covers zero and negative widths).
        Py_INCREF(self);
        return self;
    }
    Py_ssize_t pad = target - len;
    PyObject *result = PyBytes_FromStringAndSize(NULL, target);
    if (result == NULL) {
        return NULL;
    }
    char *buf = PyBytes_AS_STRING(result);
    memset(buf, ' ', pad);
    memcpy(buf + pad, PyBytes_AS_STRING(self), len);
    return result;
}


// bytes.rjust(width, fillbyte): right-justify 'self' in a field of 'width'
// bytes, padding on the left with the single byte held by 'fillbyte'.
//
// 'fillbyte' must be a bytes or bytearray object of length 1; anything else
// raises TypeError. Returns a new reference, or NULL with an exception set.
PyObject *CPyBytes_RjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) {
    // NOTE(review): assumes a value statically typed as bytes may be a
    // bytearray or a bytes subclass at runtime -- confirm against the cast
    // mypyc emits for bytes.
    if (!PyBytes_Check(self) && !PyByteArray_Check(self)) {
        PyErr_SetString(PyExc_TypeError, "self must be bytes");
        return NULL;
    }
    // Validate the fill argument up front so a bad fill is reported even when
    // no padding turns out to be needed.
    char fill;
    if (PyBytes_Check(fillbyte) && PyBytes_GET_SIZE(fillbyte) == 1) {
        fill = PyBytes_AS_STRING(fillbyte)[0];
    } else if (PyByteArray_Check(fillbyte) && PyByteArray_GET_SIZE(fillbyte) == 1) {
        fill = PyByteArray_AS_STRING(fillbyte)[0];
    } else {
        PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte");
        return NULL;
    }
    Py_ssize_t target = CPyTagged_AsSsize_t(width);
    // NOTE(review): relies on CPyTagged_AsSsize_t reporting failure the way
    // PyLong_AsSsize_t does (-1 with an exception set). Without this check a
    // width that does not fit in Py_ssize_t would be treated as -1 and 'self'
    // returned while an exception is still pending.
    if (target == -1 && PyErr_Occurred()) {
        return NULL;
    }
    if (!PyBytes_CheckExact(self)) {
        // Subclasses and bytearray keep their own semantics (result type,
        // overridden methods) by going through the regular method call.
        return PyObject_CallMethod(self, "rjust", "nO", target, fillbyte);
    }
    Py_ssize_t len = PyBytes_GET_SIZE(self);
    if (target <= len) {
        // Nothing to pad (this also covers zero and negative widths).
        Py_INCREF(self);
        return self;
    }
    Py_ssize_t pad = target - len;
    PyObject *result = PyBytes_FromStringAndSize(NULL, target);
    if (result == NULL) {
        return NULL;
    }
    char *buf = PyBytes_AS_STRING(result);
    memset(buf, fill, pad);
    memcpy(buf + pad, PyBytes_AS_STRING(self), len);
    return result;
}


// bytes.ljust(width): left-justify 'self' in a field of 'width' bytes,
// padding on the right with ASCII spaces.
//
// Returns a new reference, or NULL with an exception set.
PyObject *CPyBytes_LjustDefaultFill(PyObject *self, CPyTagged width) {
    // NOTE(review): assumes a value statically typed as bytes may be a
    // bytearray or a bytes subclass at runtime -- confirm against the cast
    // mypyc emits for bytes.
    if (!PyBytes_Check(self) && !PyByteArray_Check(self)) {
        PyErr_SetString(PyExc_TypeError, "self must be bytes");
        return NULL;
    }
    Py_ssize_t target = CPyTagged_AsSsize_t(width);
    // NOTE(review): relies on CPyTagged_AsSsize_t reporting failure the way
    // PyLong_AsSsize_t does (-1 with an exception set). Without this check a
    // width that does not fit in Py_ssize_t would be treated as -1 and 'self'
    // returned while an exception is still pending.
    if (target == -1 && PyErr_Occurred()) {
        return NULL;
    }
    if (!PyBytes_CheckExact(self)) {
        // Subclasses and bytearray keep their own semantics (result type,
        // overridden methods) by going through the regular method call.
        return PyObject_CallMethod(self, "ljust", "n", target);
    }
    Py_ssize_t len = PyBytes_GET_SIZE(self);
    if (target <= len) {
        // Nothing to pad (this also covers zero and negative widths).
        Py_INCREF(self);
        return self;
    }
    Py_ssize_t pad = target - len;
    PyObject *result = PyBytes_FromStringAndSize(NULL, target);
    if (result == NULL) {
        return NULL;
    }
    char *buf = PyBytes_AS_STRING(result);
    memcpy(buf, PyBytes_AS_STRING(self), len);
    memset(buf + len, ' ', pad);
    return result;
}


// bytes.ljust(width, fillbyte): left-justify 'self' in a field of 'width'
// bytes, padding on the right with the single byte held by 'fillbyte'.
//
// 'fillbyte' must be a bytes or bytearray object of length 1; anything else
// raises TypeError. Returns a new reference, or NULL with an exception set.
PyObject *CPyBytes_LjustCustomFill(PyObject *self, CPyTagged width, PyObject *fillbyte) {
    // NOTE(review): assumes a value statically typed as bytes may be a
    // bytearray or a bytes subclass at runtime -- confirm against the cast
    // mypyc emits for bytes.
    if (!PyBytes_Check(self) && !PyByteArray_Check(self)) {
        PyErr_SetString(PyExc_TypeError, "self must be bytes");
        return NULL;
    }
    // Validate the fill argument up front so a bad fill is reported even when
    // no padding turns out to be needed.
    char fill;
    if (PyBytes_Check(fillbyte) && PyBytes_GET_SIZE(fillbyte) == 1) {
        fill = PyBytes_AS_STRING(fillbyte)[0];
    } else if (PyByteArray_Check(fillbyte) && PyByteArray_GET_SIZE(fillbyte) == 1) {
        fill = PyByteArray_AS_STRING(fillbyte)[0];
    } else {
        PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte");
        return NULL;
    }
    Py_ssize_t target = CPyTagged_AsSsize_t(width);
    // NOTE(review): relies on CPyTagged_AsSsize_t reporting failure the way
    // PyLong_AsSsize_t does (-1 with an exception set). Without this check a
    // width that does not fit in Py_ssize_t would be treated as -1 and 'self'
    // returned while an exception is still pending.
    if (target == -1 && PyErr_Occurred()) {
        return NULL;
    }
    if (!PyBytes_CheckExact(self)) {
        // Subclasses and bytearray keep their own semantics (result type,
        // overridden methods) by going through the regular method call.
        return PyObject_CallMethod(self, "ljust", "nO", target, fillbyte);
    }
    Py_ssize_t len = PyBytes_GET_SIZE(self);
    if (target <= len) {
        // Nothing to pad (this also covers zero and negative widths).
        Py_INCREF(self);
        return self;
    }
    Py_ssize_t pad = target - len;
    PyObject *result = PyBytes_FromStringAndSize(NULL, target);
    if (result == NULL) {
        return NULL;
    }
    char *buf = PyBytes_AS_STRING(result);
    memcpy(buf, PyBytes_AS_STRING(self), len);
    memset(buf + len, fill, pad);
    return result;
}
36 changes: 36 additions & 0 deletions mypyc/primitives/bytes_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,39 @@
c_function_name="CPyBytes_Ord",
error_kind=ERR_MAGIC,
)

# bytes.rjust(width) -- pads on the left with the default fill byte (space)
method_op(
    name="rjust",
    c_function_name="CPyBytes_RjustDefaultFill",
    arg_types=[bytes_rprimitive, int_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)

# bytes.rjust(width, fillbyte) -- pads on the left with the given single byte
method_op(
    name="rjust",
    c_function_name="CPyBytes_RjustCustomFill",
    arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)

# bytes.ljust(width) -- pads on the right with the default fill byte (space)
method_op(
    name="ljust",
    c_function_name="CPyBytes_LjustDefaultFill",
    arg_types=[bytes_rprimitive, int_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)

# bytes.ljust(width, fillbyte) -- pads on the right with the given single byte
method_op(
    name="ljust",
    c_function_name="CPyBytes_LjustCustomFill",
    arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive],
    return_type=bytes_rprimitive,
    error_kind=ERR_MAGIC,
)
2 changes: 2 additions & 0 deletions mypyc/test-data/fixtures/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ def __getitem__(self, i: slice) -> bytes: ...
def join(self, x: Iterable[object]) -> bytes: ...
def decode(self, x: str=..., y: str=...) -> str: ...
def __iter__(self) -> Iterator[int]: ...
# Stub: left-justify within 'width' bytes, padding on the right with 'fillchar'.
def ljust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ...
# Stub: right-justify within 'width' bytes, padding on the left with 'fillchar'.
def rjust(self, width: int, fillchar: bytes | bytearray = b" ") -> bytes: ...

class bytearray:
@overload
Expand Down
42 changes: 42 additions & 0 deletions mypyc/test-data/irbuild-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -185,3 +185,45 @@ L0:
r10 = CPyBytes_Build(2, var, r9)
b4 = r10
return 1

[case testBytesRjustDefault]
def f(b: bytes) -> bytes:
return b.rjust(6)
[out]
def f(b):
b, r0 :: bytes
L0:
r0 = CPyBytes_RjustDefaultFill(b, 12)
return r0

[case testBytesRjustCustom]
def f(b: bytes) -> bytes:
return b.rjust(8, b'0')
[out]
def f(b):
b, r0, r1 :: bytes
L0:
r0 = b'0'
r1 = CPyBytes_RjustCustomFill(b, 16, r0)
return r1

[case testBytesLjustDefault]
def f(b: bytes) -> bytes:
return b.ljust(7)
[out]
def f(b):
b, r0 :: bytes
L0:
r0 = CPyBytes_LjustDefaultFill(b, 14)
return r0

[case testBytesLjustCustom]
def f(b: bytes) -> bytes:
return b.ljust(10, b'_')
[out]
def f(b):
b, r0, r1 :: bytes
L0:
r0 = b'_'
r1 = CPyBytes_LjustCustomFill(b, 20, r0)
return r1
40 changes: 40 additions & 0 deletions mypyc/test-data/run-bytes.test
Original file line number Diff line number Diff line change
Expand Up @@ -374,3 +374,43 @@ class subbytearray(bytearray):
[file userdefinedbytes.py]
class bytes:
pass

[case testBytesRjustLjust]
from testutil import assertRaises

def rjust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes:
return b.rjust(width, fill)

def ljust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes:
return b.ljust(width, fill)

def test_rjust_with_default_fill() -> None:
assert rjust_bytes(b'abc', 6) == b' abc', rjust_bytes(b'abc', 6)
assert rjust_bytes(b'abc', 3) == b'abc', rjust_bytes(b'abc', 3)
assert rjust_bytes(b'abc', 2) == b'abc', rjust_bytes(b'abc', 2)
assert rjust_bytes(b'', 4) == b' ', rjust_bytes(b'', 4)

def test_rjust_with_custom_fill() -> None:
assert rjust_bytes(b'abc', 6, b'0') == b'000abc', rjust_bytes(b'abc', 6, b'0')
assert rjust_bytes(b'abc', 5, b'_') == b'__abc', rjust_bytes(b'abc', 5, b'_')
assert rjust_bytes(b'abc', 3, b'X') == b'abc', rjust_bytes(b'abc', 3, b'X')

def test_ljust_with_default_fill() -> None:
assert ljust_bytes(b'abc', 6) == b'abc ', ljust_bytes(b'abc', 6)
assert ljust_bytes(b'abc', 3) == b'abc', ljust_bytes(b'abc', 3)
assert ljust_bytes(b'abc', 2) == b'abc', ljust_bytes(b'abc', 2)
assert ljust_bytes(b'', 4) == b' ', ljust_bytes(b'', 4)

def test_ljust_with_custom_fill() -> None:
assert ljust_bytes(b'abc', 6, b'0') == b'abc000', ljust_bytes(b'abc', 6, b'0')
assert ljust_bytes(b'abc', 5, b'_') == b'abc__', ljust_bytes(b'abc', 5, b'_')
assert ljust_bytes(b'abc', 3, b'X') == b'abc', ljust_bytes(b'abc', 3, b'X')

def test_edge_cases() -> None:
assert rjust_bytes(b'abc', 0) == b'abc', rjust_bytes(b'abc', 0)
assert ljust_bytes(b'abc', 0) == b'abc', ljust_bytes(b'abc', 0)
# fillbyte must be length 1
with assertRaises(TypeError):
rjust_bytes(b'abc', 5, b'')
with assertRaises(TypeError):
ljust_bytes(b'abc', 5, b'12')
Loading