Skip to content

Commit 26e8a53

Browse files
committed
[mypyc] feat: new primitives for bytes.rjust and bytes.ljust
1 parent 3fcfcb8 commit 26e8a53

File tree

4 files changed

+171
-0
lines changed

4 files changed

+171
-0
lines changed

mypyc/lib-rt/bytes_ops.c

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <Python.h>
66
#include "CPy.h"
7+
#include <string.h>
78

89
// Returns -1 on error, 0 on inequality, 1 on equality.
910
//
@@ -162,3 +163,51 @@ CPyTagged CPyBytes_Ord(PyObject *obj) {
162163
PyErr_SetString(PyExc_TypeError, "ord() expects a character");
163164
return CPY_INT_TAG;
164165
}
166+
167+
PyObject *CPyBytes_Rjust(PyObject *self, Py_ssize_t width, PyObject *fillbyte) {
168+
if (!PyBytes_Check(self)) {
169+
PyErr_SetString(PyExc_TypeError, "self must be bytes");
170+
return NULL;
171+
}
172+
if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) {
173+
PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte");
174+
return NULL;
175+
}
176+
Py_ssize_t len = PyBytes_Size(self);
177+
if (width <= len) {
178+
return PyBytes_FromStringAndSize(PyBytes_AsString(self), len);
179+
}
180+
char fill = PyBytes_AsString(fillbyte)[0];
181+
Py_ssize_t pad = width - len;
182+
PyObject *result = PyBytes_FromStringAndSize(NULL, width);
183+
if (!result) return NULL;
184+
char *res_buf = PyBytes_AsString(result);
185+
memset(res_buf, fill, pad);
186+
memcpy(res_buf + pad, PyBytes_AsString(self), len);
187+
return result;
188+
}
189+
190+
PyObject *CPyBytes_Ljust(PyObject *self, Py_ssize_t width, PyObject *fillbyte) {
191+
if (!PyBytes_Check(self)) {
192+
PyErr_SetString(PyExc_TypeError, "self must be bytes");
193+
return NULL;
194+
}
195+
if (!PyBytes_Check(fillbyte) || PyBytes_Size(fillbyte) != 1) {
196+
PyErr_SetString(PyExc_TypeError, "fillbyte must be a single byte");
197+
return NULL;
198+
}
199+
Py_ssize_t len = PyBytes_Size(self);
200+
if (width <= len) {
201+
return PyBytes_FromStringAndSize(PyBytes_AsString(self), len);
202+
}
203+
char fill = PyBytes_AsString(fillbyte)[0];
204+
Py_ssize_t pad = width - len;
205+
PyObject *result = PyBytes_FromStringAndSize(NULL, width);
206+
if (!result) return NULL;
207+
char *res_buf = PyBytes_AsString(result);
208+
memcpy(res_buf, PyBytes_AsString(self), len);
209+
memset(res_buf + len, fill, pad);
210+
return result;
211+
}
212+
213+
// ... existing code ...

mypyc/primitives/bytes_ops.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,3 +126,21 @@
126126
c_function_name="CPyBytes_Ord",
127127
error_kind=ERR_MAGIC,
128128
)
129+
130+
# bytes.rjust(width, fillbyte=b' ')
131+
method_op(
132+
name="rjust",
133+
arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive],
134+
return_type=bytes_rprimitive,
135+
c_function_name="CPyBytes_Rjust",
136+
error_kind=ERR_MAGIC,
137+
)
138+
139+
# bytes.ljust(width, fillbyte=b' ')
140+
method_op(
141+
name="ljust",
142+
arg_types=[bytes_rprimitive, int_rprimitive, bytes_rprimitive],
143+
return_type=bytes_rprimitive,
144+
c_function_name="CPyBytes_Ljust",
145+
error_kind=ERR_MAGIC,
146+
)

mypyc/test-data/irbuild-bytes.test

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,3 +185,69 @@ L0:
185185
r10 = CPyBytes_Build(2, var, r9)
186186
b4 = r10
187187
return 1
188+
189+
[case testBytesRjustDefault]
190+
def f(b: bytes) -> bytes:
191+
return b.rjust(6)
192+
[out]
193+
def f(b):
194+
b :: bytes
195+
r0 :: bytes
196+
L0:
197+
r0 = b.rjust(6, b' ')
198+
return r0
199+
200+
[case testBytesRjustCustom]
201+
def f(b: bytes) -> bytes:
202+
return b.rjust(8, b'0')
203+
[out]
204+
def f(b):
205+
b :: bytes
206+
r0 :: bytes
207+
L0:
208+
r0 = b.rjust(8, b'0')
209+
return r0
210+
211+
[case testBytesLjustDefault]
212+
def f(b: bytes) -> bytes:
213+
return b.ljust(7)
214+
[out]
215+
def f(b):
216+
b :: bytes
217+
r0 :: bytes
218+
L0:
219+
r0 = b.ljust(7, b' ')
220+
return r0
221+
222+
[case testBytesLjustCustom]
223+
def f(b: bytes) -> bytes:
224+
return b.ljust(10, b'_')
225+
[out]
226+
def f(b):
227+
b :: bytes
228+
r0 :: bytes
229+
L0:
230+
r0 = b.ljust(10, b'_')
231+
return r0
232+
233+
[case testBytesRjustNoPad]
234+
def f(b: bytes) -> bytes:
235+
return b.rjust(2)
236+
[out]
237+
def f(b):
238+
b :: bytes
239+
r0 :: bytes
240+
L0:
241+
r0 = b.rjust(2, b' ')
242+
return r0
243+
244+
[case testBytesLjustNoPad]
245+
def f(b: bytes) -> bytes:
246+
return b.ljust(1)
247+
[out]
248+
def f(b):
249+
b :: bytes
250+
r0 :: bytes
251+
L0:
252+
r0 = b.ljust(1, b' ')
253+
return r0

mypyc/test-data/run-bytes.test

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -374,3 +374,41 @@ class subbytearray(bytearray):
374374
[file userdefinedbytes.py]
375375
class bytes:
376376
pass
377+
378+
[case testBytesRjustLjust]
379+
def rjust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes:
380+
return b.rjust(width, fill)
381+
382+
def ljust_bytes(b: bytes, width: int, fill: bytes = b' ') -> bytes:
383+
return b.ljust(width, fill)
384+
385+
def test_rjust_with_default_fill() -> None:
386+
assert rjust_bytes(b'abc', 6) == b' abc'
387+
assert rjust_bytes(b'abc', 3) == b'abc'
388+
assert rjust_bytes(b'abc', 2) == b'abc'
389+
assert rjust_bytes(b'', 4) == b' '
390+
391+
def test_rjust_with_custom_fill() -> None:
392+
assert rjust_bytes(b'abc', 6, b'0') == b'000abc'
393+
assert rjust_bytes(b'abc', 5, b'_') == b'__abc'
394+
assert rjust_bytes(b'abc', 3, b'X') == b'abc'
395+
396+
def test_ljust_with_default_fill() -> None:
397+
assert ljust_bytes(b'abc', 6) == b'abc '
398+
assert ljust_bytes(b'abc', 3) == b'abc'
399+
assert ljust_bytes(b'abc', 2) == b'abc'
400+
assert ljust_bytes(b'', 4) == b' '
401+
402+
def test_ljust_with_custom_fill() -> None:
403+
assert ljust_bytes(b'abc', 6, b'0') == b'abc000'
404+
assert ljust_bytes(b'abc', 5, b'_') == b'abc__'
405+
assert ljust_bytes(b'abc', 3, b'X') == b'abc'
406+
407+
def test_edge_cases() -> None:
408+
assert rjust_bytes(b'abc', 0) == b'abc'
409+
assert ljust_bytes(b'abc', 0) == b'abc'
410+
# fillbyte must be length 1
411+
with assertRaises(TypeError):
412+
rjust_bytes(b'abc', 5, b'')
413+
with assertRaises(TypeError):
414+
ljust_bytes(b'abc', 5, b'12')

0 commit comments

Comments
 (0)