Skip to content

Commit a4313e4

Browse files
authored
[mypyc] Add efficient primitives for str.strip() etc. (#18742)
Fixes mypyc/mypyc#1090. Copies the CPython implementations of `strip`, `lstrip` and `rstrip` into `str_ops.c`.
1 parent 52907ac commit a4313e4

File tree

6 files changed

+249
-1
lines changed

6 files changed

+249
-1
lines changed

mypyc/lib-rt/CPy.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -717,13 +717,27 @@ static inline char CPyDict_CheckSize(PyObject *dict, CPyTagged size) {
717717

718718
// Str operations
719719

720+
// Strip-type selectors for _CPyStr_Strip(): which end(s) of the string to trim.
// These values are copied from CPython.
721+
// Trim the start of the string only (used by CPyStr_LStrip).
#define LEFTSTRIP 0
722+
// Trim the end of the string only (used by CPyStr_RStrip).
#define RIGHTSTRIP 1
723+
// Trim both ends (used by CPyStr_Strip).
#define BOTHSTRIP 2
720724

721725
PyObject *CPyStr_Build(Py_ssize_t len, ...);
722726
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index);
723727
CPyTagged CPyStr_Find(PyObject *str, PyObject *substr, CPyTagged start, int direction);
724728
CPyTagged CPyStr_FindWithEnd(PyObject *str, PyObject *substr, CPyTagged start, CPyTagged end, int direction);
725729
PyObject *CPyStr_Split(PyObject *str, PyObject *sep, CPyTagged max_split);
726730
PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split);
731+
// Shared implementation behind the str.strip()/lstrip()/rstrip() primitives.
// strip_type is one of LEFTSTRIP, RIGHTSTRIP or BOTHSTRIP. sep is the set of
// characters to remove; NULL or None selects whitespace stripping.
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep);
732+
// str.strip([sep]): trim both ends of self.
static inline PyObject *CPyStr_Strip(PyObject *self, PyObject *sep) {
733+
return _CPyStr_Strip(self, BOTHSTRIP, sep);
734+
}
735+
// str.lstrip([sep]): trim only the start of self.
static inline PyObject *CPyStr_LStrip(PyObject *self, PyObject *sep) {
736+
return _CPyStr_Strip(self, LEFTSTRIP, sep);
737+
}
738+
// str.rstrip([sep]): trim only the end of self.
static inline PyObject *CPyStr_RStrip(PyObject *self, PyObject *sep) {
739+
return _CPyStr_Strip(self, RIGHTSTRIP, sep);
740+
}
727741
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr, PyObject *new_substr, CPyTagged max_replace);
728742
PyObject *CPyStr_Append(PyObject *o1, PyObject *o2);
729743
PyObject *CPyStr_GetSlice(PyObject *obj, CPyTagged start, CPyTagged end);

mypyc/lib-rt/str_ops.c

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,59 @@
55
#include <Python.h>
66
#include "CPy.h"
77

8+
// Bloom-filter helpers used when stripping an explicit character set: a
// one-word bitmask lets the strip loops reject most non-matching characters
// without searching the separator string.
// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.

// Number of mask bits in use, chosen from the width of `long`.
#if LONG_BIT >= 128
#define BLOOM_WIDTH 128
#elif LONG_BIT >= 64
#define BLOOM_WIDTH 64
#elif LONG_BIT >= 32
#define BLOOM_WIDTH 32
#else
#error "LONG_BIT is smaller than 32"
#endif

// Integer type that holds the bitmask.
#define BLOOM_MASK unsigned long

// Non-zero when the bit selected by `ch` is set in `mask`. Different
// characters can share a bit, so a non-zero result is only "maybe present";
// zero means "definitely absent". Both arguments are parenthesized so that
// expression arguments such as `a | b` are evaluated as a unit.
#define BLOOM(mask, ch) ((mask) & (1UL << ((ch) & (BLOOM_WIDTH - 1))))
20+
21+
// Copied from cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
22+
// This is needed for str.strip("...").
23+
static inline BLOOM_MASK
24+
make_bloom_mask(int kind, const void* ptr, Py_ssize_t len)
25+
{
26+
#define BLOOM_UPDATE(TYPE, MASK, PTR, LEN) \
27+
do { \
28+
TYPE *data = (TYPE *)PTR; \
29+
TYPE *end = data + LEN; \
30+
Py_UCS4 ch; \
31+
for (; data != end; data++) { \
32+
ch = *data; \
33+
MASK |= (1UL << (ch & (BLOOM_WIDTH - 1))); \
34+
} \
35+
break; \
36+
} while (0)
37+
38+
/* calculate simple bloom-style bitmask for a given unicode string */
39+
40+
BLOOM_MASK mask;
41+
42+
mask = 0;
43+
switch (kind) {
44+
case PyUnicode_1BYTE_KIND:
45+
BLOOM_UPDATE(Py_UCS1, mask, ptr, len);
46+
break;
47+
case PyUnicode_2BYTE_KIND:
48+
BLOOM_UPDATE(Py_UCS2, mask, ptr, len);
49+
break;
50+
case PyUnicode_4BYTE_KIND:
51+
BLOOM_UPDATE(Py_UCS4, mask, ptr, len);
52+
break;
53+
default:
54+
Py_UNREACHABLE();
55+
}
56+
return mask;
57+
58+
#undef BLOOM_UPDATE
59+
}
60+
861
PyObject *CPyStr_GetItem(PyObject *str, CPyTagged index) {
962
if (PyUnicode_READY(str) != -1) {
1063
if (CPyTagged_CheckShort(index)) {
@@ -174,6 +227,124 @@ PyObject *CPyStr_RSplit(PyObject *str, PyObject *sep, CPyTagged max_split) {
174227
return PyUnicode_RSplit(str, sep, temp_max_split);
175228
}
176229

230+
// This function has been copied from _PyUnicode_XStrip in cpython.git:Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
231+
static PyObject *_PyStr_XStrip(PyObject *self, int striptype, PyObject *sepobj) {
232+
const void *data;
233+
int kind;
234+
Py_ssize_t i, j, len;
235+
BLOOM_MASK sepmask;
236+
Py_ssize_t seplen;
237+
238+
// This check is needed from Python 3.9 and earlier.
239+
if (PyUnicode_READY(self) == -1 || PyUnicode_READY(sepobj) == -1)
240+
return NULL;
241+
242+
kind = PyUnicode_KIND(self);
243+
data = PyUnicode_DATA(self);
244+
len = PyUnicode_GET_LENGTH(self);
245+
seplen = PyUnicode_GET_LENGTH(sepobj);
246+
sepmask = make_bloom_mask(PyUnicode_KIND(sepobj),
247+
PyUnicode_DATA(sepobj),
248+
seplen);
249+
250+
i = 0;
251+
if (striptype != RIGHTSTRIP) {
252+
while (i < len) {
253+
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
254+
if (!BLOOM(sepmask, ch))
255+
break;
256+
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
257+
break;
258+
i++;
259+
}
260+
}
261+
262+
j = len;
263+
if (striptype != LEFTSTRIP) {
264+
j--;
265+
while (j >= i) {
266+
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
267+
if (!BLOOM(sepmask, ch))
268+
break;
269+
if (PyUnicode_FindChar(sepobj, ch, 0, seplen, 1) < 0)
270+
break;
271+
j--;
272+
}
273+
274+
j++;
275+
}
276+
277+
return PyUnicode_Substring(self, i, j);
278+
}
279+
280+
// Copied from do_strip function in cpython.git/Objects/unicodeobject.c@0ef4ffeefd1737c18dc9326133c7894d58108c2e.
281+
PyObject *_CPyStr_Strip(PyObject *self, int strip_type, PyObject *sep) {
282+
if (sep == NULL || sep == Py_None) {
283+
Py_ssize_t len, i, j;
284+
285+
// This check is needed from Python 3.9 and earlier.
286+
if (PyUnicode_READY(self) == -1)
287+
return NULL;
288+
289+
len = PyUnicode_GET_LENGTH(self);
290+
291+
if (PyUnicode_IS_ASCII(self)) {
292+
const Py_UCS1 *data = PyUnicode_1BYTE_DATA(self);
293+
294+
i = 0;
295+
if (strip_type != RIGHTSTRIP) {
296+
while (i < len) {
297+
Py_UCS1 ch = data[i];
298+
if (!_Py_ascii_whitespace[ch])
299+
break;
300+
i++;
301+
}
302+
}
303+
304+
j = len;
305+
if (strip_type != LEFTSTRIP) {
306+
j--;
307+
while (j >= i) {
308+
Py_UCS1 ch = data[j];
309+
if (!_Py_ascii_whitespace[ch])
310+
break;
311+
j--;
312+
}
313+
j++;
314+
}
315+
}
316+
else {
317+
int kind = PyUnicode_KIND(self);
318+
const void *data = PyUnicode_DATA(self);
319+
320+
i = 0;
321+
if (strip_type != RIGHTSTRIP) {
322+
while (i < len) {
323+
Py_UCS4 ch = PyUnicode_READ(kind, data, i);
324+
if (!Py_UNICODE_ISSPACE(ch))
325+
break;
326+
i++;
327+
}
328+
}
329+
330+
j = len;
331+
if (strip_type != LEFTSTRIP) {
332+
j--;
333+
while (j >= i) {
334+
Py_UCS4 ch = PyUnicode_READ(kind, data, j);
335+
if (!Py_UNICODE_ISSPACE(ch))
336+
break;
337+
j--;
338+
}
339+
j++;
340+
}
341+
}
342+
343+
return PyUnicode_Substring(self, i, j);
344+
}
345+
return _PyStr_XStrip(self, strip_type, sep);
346+
}
347+
177348
PyObject *CPyStr_Replace(PyObject *str, PyObject *old_substr,
178349
PyObject *new_substr, CPyTagged max_replace) {
179350
Py_ssize_t temp_max_replace = CPyTagged_AsSsize_t(max_replace);

mypyc/primitives/str_ops.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,25 @@
135135
var_arg_type=str_rprimitive,
136136
)
137137

138+
# str.strip, str.lstrip, str.rstrip
#
# Each prefix registers two primitives that share one C function:
#   "l" -> lstrip / CPyStr_LStrip, "r" -> rstrip / CPyStr_RStrip,
#   ""  -> strip  / CPyStr_Strip.
# NOTE(review): the C implementation (_CPyStr_Strip / _PyStr_XStrip) contains
# `return NULL` paths, while both primitives are registered with ERR_NEVER.
# Confirm whether ERR_MAGIC is the intended error kind.
139+
for strip_prefix in ["l", "r", ""]:
140+
# Form with an explicit character set, e.g. s.strip("xy").
method_op(
141+
name=f"{strip_prefix}strip",
142+
arg_types=[str_rprimitive, str_rprimitive],
143+
return_type=str_rprimitive,
144+
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
145+
error_kind=ERR_NEVER,
146+
)
147+
# Form without arguments, e.g. s.strip().
method_op(
148+
name=f"{strip_prefix}strip",
149+
arg_types=[str_rprimitive],
150+
return_type=str_rprimitive,
151+
c_function_name=f"CPyStr_{strip_prefix.upper()}Strip",
152+
# No separator was given: the constant 0 is passed as the second C
# argument, where it is implicitly treated as a NULL `sep`.
153+
extra_int_constants=[(0, c_int_rprimitive)],
154+
error_kind=ERR_NEVER,
155+
)
156+
138157
# str.startswith(str)
139158
method_op(
140159
name="startswith",

mypyc/test-data/fixtures/ir.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ def rfind(self, sub: str, start: Optional[int] = None, end: Optional[int] = None
107107
def split(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
108108
def rsplit(self, sep: Optional[str] = None, maxsplit: int = -1) -> List[str]: pass
109109
def splitlines(self, keepends: bool = False) -> List[str]: ...
110-
def strip (self, item: str) -> str: pass
110+
# Fixture stubs for the strip family. `item` defaults to None so that both
# s.strip() and s.strip("xy") type-check.
def strip(self, item: Optional[str] = None) -> str: pass
def lstrip(self, item: Optional[str] = None) -> str: pass
def rstrip(self, item: Optional[str] = None) -> str: pass
111113
def join(self, x: Iterable[str]) -> str: pass
112114
def format(self, *args: Any, **kwargs: Any) -> str: ...
113115
def upper(self) -> str: ...

mypyc/test-data/irbuild-str.test

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -481,3 +481,26 @@ L0:
481481
keep_alive x
482482
r6 = unbox(int, r5)
483483
return r6
484+
485+
[case testStrip]
486+
def do_strip(s: str) -> None:
487+
s.lstrip("x")
488+
s.strip("y")
489+
s.rstrip("z")
490+
s.lstrip()
491+
s.strip()
492+
s.rstrip()
493+
[out]
494+
def do_strip(s):
495+
s, r0, r1, r2, r3, r4, r5, r6, r7, r8 :: str
496+
L0:
497+
r0 = 'x'
498+
r1 = CPyStr_LStrip(s, r0)
499+
r2 = 'y'
500+
r3 = CPyStr_Strip(s, r2)
501+
r4 = 'z'
502+
r5 = CPyStr_RStrip(s, r4)
503+
r6 = CPyStr_LStrip(s, 0)
504+
r7 = CPyStr_Strip(s, 0)
505+
r8 = CPyStr_RStrip(s, 0)
506+
return 1

mypyc/test-data/run-strings.test

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -774,3 +774,22 @@ def test_surrogate() -> None:
774774
assert ord(f()) == 0xd800
775775
assert ord("\udfff") == 0xdfff
776776
assert repr("foobar\x00\xab\ud912\U00012345") == r"'foobar\x00«\ud912𒍅'"
777+
778+
[case testStrip]
779+
def test_all_strips_default() -> None:
780+
s = " a1\t"
781+
assert s.lstrip() == "a1\t"
782+
assert s.strip() == "a1"
783+
assert s.rstrip() == " a1"
784+
def test_all_strips() -> None:
785+
s = "xxb2yy"
786+
assert s.lstrip("xy") == "b2yy"
787+
assert s.strip("xy") == "b2"
788+
assert s.rstrip("xy") == "xxb2"
789+
def test_unicode_whitespace() -> None:
790+
assert "\u200A\u000D\u2009\u2020\u000Dtt\u0085\u000A".strip() == "\u2020\u000Dtt"
791+
def test_unicode_range() -> None:
792+
assert "\u2029 \U00107581 ".lstrip() == "\U00107581 "
793+
assert "\u2029 \U0010AAAA\U00104444B\u205F ".strip() == "\U0010AAAA\U00104444B"
794+
assert " \u3000\u205F ".strip() == ""
795+
assert "\u2029 \U00102865\u205F ".rstrip() == "\u2029 \U00102865"

0 commit comments

Comments
 (0)