Skip to content

Commit 237f4d9

Browse files
authored
Optimize rails_xss_friendly_size and rails_friendly_size with ARM Neon instructions. (#966)
* Optimize rails_xss_friendly_size with ARM Neon instructions. * Fixed a typo in a comment. * Optimize rails_friendly_size with ARM Neon instructions. Additionally add tests. * Formatting. * Comment out code which forces testing rails_friendly_size. * Added even longer tests. * Try to run include active_support again to cover rails_friendly_size in the tests. * Commenting out tests that fails in Rails CI.
1 parent 1c2cbe5 commit 237f4d9

File tree

3 files changed

+344
-11
lines changed

3 files changed

+344
-11
lines changed

ext/oj/dump.c

Lines changed: 98 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,23 +163,41 @@ inline static uint8x16x4_t load_uint8x16_4(const unsigned char *table) {
163163
}
164164

165165
static uint8x16x4_t hibit_friendly_chars_neon[2];
166+
static uint8x16x4_t rails_friendly_chars_neon[2];
167+
static uint8x16x4_t rails_xss_friendly_chars_neon[4];
166168

167169
void initialize_neon(void) {
170+
// We only need the first 128 bytes of the hibit friendly chars table. Everything above 127 is
171+
// set to 1. If that ever changes, the code will need to be updated.
168172
hibit_friendly_chars_neon[0] = load_uint8x16_4((const unsigned char *)hibit_friendly_chars);
169173
hibit_friendly_chars_neon[1] = load_uint8x16_4((const unsigned char *)hibit_friendly_chars + 64);
170174

175+
// rails_friendly_chars is the same as hibit_friendly_chars. Only the first 128 bytes have values
176+
// that are not '1'. If that ever changes, the code will need to be updated.
177+
rails_friendly_chars_neon[0] = load_uint8x16_4((const unsigned char *)rails_friendly_chars);
178+
rails_friendly_chars_neon[1] = load_uint8x16_4((const unsigned char *)rails_friendly_chars + 64);
179+
180+
rails_xss_friendly_chars_neon[0] = load_uint8x16_4((const unsigned char *)rails_xss_friendly_chars);
181+
rails_xss_friendly_chars_neon[1] = load_uint8x16_4((const unsigned char *)rails_xss_friendly_chars + 64);
182+
rails_xss_friendly_chars_neon[2] = load_uint8x16_4((const unsigned char *)rails_xss_friendly_chars + 128);
183+
rails_xss_friendly_chars_neon[3] = load_uint8x16_4((const unsigned char *)rails_xss_friendly_chars + 192);
184+
171185
// All bytes should be 0 except for those that need more than 1 byte of output. This will allow the
172186
// code to limit the lookups to the first 128 bytes (values 0 - 127). Bytes above 127 will result
173187
// in 0 with the vqtbl4q_u8 instruction.
174-
hibit_friendly_chars_neon[0].val[0] = vsubq_u8(hibit_friendly_chars_neon[0].val[0], vdupq_n_u8('1'));
175-
hibit_friendly_chars_neon[0].val[1] = vsubq_u8(hibit_friendly_chars_neon[0].val[1], vdupq_n_u8('1'));
176-
hibit_friendly_chars_neon[0].val[2] = vsubq_u8(hibit_friendly_chars_neon[0].val[2], vdupq_n_u8('1'));
177-
hibit_friendly_chars_neon[0].val[3] = vsubq_u8(hibit_friendly_chars_neon[0].val[3], vdupq_n_u8('1'));
178-
179-
hibit_friendly_chars_neon[1].val[0] = vsubq_u8(hibit_friendly_chars_neon[1].val[0], vdupq_n_u8('1'));
180-
hibit_friendly_chars_neon[1].val[1] = vsubq_u8(hibit_friendly_chars_neon[1].val[1], vdupq_n_u8('1'));
181-
hibit_friendly_chars_neon[1].val[2] = vsubq_u8(hibit_friendly_chars_neon[1].val[2], vdupq_n_u8('1'));
182-
hibit_friendly_chars_neon[1].val[3] = vsubq_u8(hibit_friendly_chars_neon[1].val[3], vdupq_n_u8('1'));
188+
uint8x16_t one = vdupq_n_u8('1');
189+
for (int i = 0; i < 2; i++) {
190+
for (int j = 0; j < 4; j++) {
191+
hibit_friendly_chars_neon[i].val[j] = vsubq_u8(hibit_friendly_chars_neon[i].val[j], one);
192+
rails_friendly_chars_neon[i].val[j] = vsubq_u8(rails_friendly_chars_neon[i].val[j], one);
193+
}
194+
}
195+
196+
for (int i = 0; i < 4; i++) {
197+
for (int j = 0; j < 4; j++) {
198+
rails_xss_friendly_chars_neon[i].val[j] = vsubq_u8(rails_xss_friendly_chars_neon[i].val[j], one);
199+
}
200+
}
183201
}
184202
#endif
185203

@@ -235,9 +253,43 @@ inline static size_t hixss_friendly_size(const uint8_t *str, size_t len) {
235253

236254
inline static long rails_xss_friendly_size(const uint8_t *str, size_t len) {
237255
long size = 0;
238-
size_t i = len;
239256
uint8_t hi = 0;
240257

258+
#ifdef HAVE_SIMD_NEON
259+
size_t i = 0;
260+
261+
uint8x16_t has_some_hibit = vdupq_n_u8(0);
262+
uint8x16_t hibit = vdupq_n_u8(0x80);
263+
for (; i + sizeof(uint8x16_t) < len; i += sizeof(uint8x16_t), str += sizeof(uint8x16_t)) {
264+
size += sizeof(uint8x16_t);
265+
266+
uint8x16_t chunk = vld1q_u8(str);
267+
268+
// Check to see if any of these bytes have the high bit set.
269+
has_some_hibit = vorrq_u8(has_some_hibit, vandq_u8(chunk, hibit));
270+
271+
uint8x16_t tmp1 = vqtbl4q_u8(rails_xss_friendly_chars_neon[0], chunk);
272+
uint8x16_t tmp2 = vqtbl4q_u8(rails_xss_friendly_chars_neon[1], veorq_u8(chunk, vdupq_n_u8(0x40)));
273+
uint8x16_t tmp3 = vqtbl4q_u8(rails_xss_friendly_chars_neon[2], veorq_u8(chunk, vdupq_n_u8(0x80)));
274+
uint8x16_t tmp4 = vqtbl4q_u8(rails_xss_friendly_chars_neon[3], veorq_u8(chunk, vdupq_n_u8(0xc0)));
275+
uint8x16_t result = vorrq_u8(tmp4, vorrq_u8(tmp3, vorrq_u8(tmp1, tmp2)));
276+
uint8_t tmp = vaddvq_u8(result);
277+
size += tmp;
278+
}
279+
280+
// 'hi' should be set if any of the bytes we processed have the high bit set. It doesn't matter which ones.
281+
hi = vmaxvq_u8(has_some_hibit) != 0;
282+
283+
for (; i < len; str++, i++) {
284+
size += rails_xss_friendly_chars[*str] - '0';
285+
hi |= *str & 0x80;
286+
}
287+
if (0 == hi) {
288+
return size;
289+
}
290+
return -(size);
291+
#else
292+
size_t i = len;
241293
for (; 0 < i; str++, i--) {
242294
size += rails_xss_friendly_chars[*str];
243295
hi |= *str & 0x80;
@@ -246,13 +298,47 @@ inline static long rails_xss_friendly_size(const uint8_t *str, size_t len) {
246298
return size - len * (size_t)'0';
247299
}
248300
return -(size - len * (size_t)'0');
301+
#endif /* HAVE_SIMD_NEON */
249302
}
250303

251304
inline static size_t rails_friendly_size(const uint8_t *str, size_t len) {
252305
long size = 0;
253-
size_t i = len;
254306
uint8_t hi = 0;
307+
#ifdef HAVE_SIMD_NEON
308+
size_t i = 0;
255309

310+
uint8x16_t has_some_hibit = vdupq_n_u8(0);
311+
uint8x16_t hibit = vdupq_n_u8(0x80);
312+
313+
for (; i + sizeof(uint8x16_t) < len; i += sizeof(uint8x16_t), str += sizeof(uint8x16_t)) {
314+
size += sizeof(uint8x16_t);
315+
316+
// See https://lemire.me/blog/2019/07/23/arbitrary-byte-to-byte-maps-using-arm-neon/
317+
uint8x16_t chunk = vld1q_u8(str);
318+
319+
// Check to see if any of these bytes have the high bit set.
320+
has_some_hibit = vorrq_u8(has_some_hibit, vandq_u8(chunk, hibit));
321+
322+
uint8x16_t tmp1 = vqtbl4q_u8(rails_friendly_chars_neon[0], chunk);
323+
uint8x16_t tmp2 = vqtbl4q_u8(rails_friendly_chars_neon[1], veorq_u8(chunk, vdupq_n_u8(0x40)));
324+
uint8x16_t result = vorrq_u8(tmp1, tmp2);
325+
uint8_t tmp = vaddvq_u8(result);
326+
size += tmp;
327+
}
328+
329+
// 'hi' should be set if any of the bytes we processed have the high bit set. It doesn't matter which ones.
330+
hi = vmaxvq_u8(has_some_hibit) != 0;
331+
332+
for (; i < len; str++, i++) {
333+
size += rails_friendly_chars[*str] - '0';
334+
hi |= *str & 0x80;
335+
}
336+
if (0 == hi) {
337+
return size;
338+
}
339+
return -(size);
340+
#else
341+
size_t i = len;
256342
for (; 0 < i; str++, i--) {
257343
size += rails_friendly_chars[*str];
258344
hi |= *str & 0x80;
@@ -261,6 +347,7 @@ inline static size_t rails_friendly_size(const uint8_t *str, size_t len) {
261347
return size - len * (size_t)'0';
262348
}
263349
return -(size - len * (size_t)'0');
350+
#endif /* HAVE_SIMD_NEON */
264351
}
265352

266353
const char *oj_nan_str(VALUE obj, int opt, int mode, bool plus, int *lenp) {

test/test_long_strings.rb

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,245 @@
1+
#!/usr/bin/env ruby
2+
# frozen_string_literal: true
3+
4+
$LOAD_PATH << __dir__
5+
6+
require 'helper'
7+
8+
# # This is to force testing 'rails_friendly_size'.
9+
# begin
10+
# require 'active_support'
11+
# HAS_RAILS = true
12+
# rescue LoadError
13+
# puts 'ActiveSupport not found. Skipping ActiveSupport tests.'
14+
# HAS_RAILS = false
15+
# end
16+
17+
# The tests in this file are to specifically handle testing the ARM Neon code
18+
# that is used to speed up the dumping of long strings. The tests are likely
19+
# redundant with respect to correctness. However, they are designed specifically
20+
# to exercise the code paths that operate on vectors of 16 bytes. Additionally,
21+
# we need to ensure that not all tests are exactly multiples of 16 bytes in length
22+
# so that we can test the code that handles the remainder of the string.
23+
class LongStringsTest < Minitest::Test
24+
25+
def test_escapes
26+
run_basic_tests(:compat)
27+
run_basic_tests(:rails)
28+
29+
# if HAS_RAILS
30+
# Oj.optimize_rails()
31+
# ActiveSupport::JSON::Encoding.escape_html_entities_in_json = false
32+
# run_basic_tests(:rails)
33+
# end
34+
end
35+
36+
def run_basic_tests(mode)
37+
str = '\n'*15
38+
expected = "\"#{'\\\\n'*15}\""
39+
out = Oj.dump(str, mode: mode)
40+
assert_equal(expected, out)
41+
42+
str = '\n'*16
43+
expected = "\"#{'\\\\n'*16}\""
44+
out = Oj.dump(str, mode: mode)
45+
assert_equal(expected, out)
46+
47+
str = '\n'*17
48+
expected = "\"#{'\\\\n'*17}\""
49+
out = Oj.dump(str, mode: mode)
50+
assert_equal(expected, out)
51+
52+
str = '\n'*1700
53+
expected = "\"#{'\\\\n'*1700}\""
54+
out = Oj.dump(str, mode: mode)
55+
assert_equal(expected, out)
56+
57+
str = '\n'*32
58+
expected = "\"#{'\\\\n'*32}\""
59+
out = Oj.dump(str, mode: mode)
60+
assert_equal(expected, out)
61+
62+
str = '\f'*63
63+
expected = "\"#{'\\\\f'*63}\""
64+
out = Oj.dump(str, mode: mode)
65+
assert_equal(expected, out)
66+
67+
str = '\t'*127
68+
expected = "\"#{'\\\\t'*127}\""
69+
out = Oj.dump(str, mode: mode)
70+
assert_equal(expected, out)
71+
72+
str = '\t'*500
73+
expected = "\"#{'\\\\t'*500}\""
74+
out = Oj.dump(str, mode: mode)
75+
assert_equal(expected, out)
76+
77+
str = "\u0001" * 16
78+
out = Oj.dump(str, mode: mode)
79+
expected = "\"#{"\\u0001" * 16}\""
80+
assert_equal(expected, out)
81+
82+
str = "\u0001" * 1024
83+
out = Oj.dump(str, mode: mode)
84+
expected = "\"#{"\\u0001" * 1024}\""
85+
assert_equal(expected, out)
86+
87+
str = "\u0001\u0002" * 8
88+
out = Oj.dump(str, mode: mode)
89+
expected = "\"#{"\\u0001\\u0002" * 8}\""
90+
assert_equal(expected, out)
91+
92+
str = "\u0001\u0002" * 2000
93+
out = Oj.dump(str, mode: mode)
94+
expected = "\"#{"\\u0001\\u0002" * 2000}\""
95+
assert_equal(expected, out)
96+
97+
str = "\u0001a" * 8
98+
out = Oj.dump(str, mode: mode)
99+
expected = "\"#{"\\u0001a" * 8}\""
100+
assert_equal(expected, out)
101+
102+
str = "abc\u0010" * 4
103+
out = Oj.dump(str, mode: mode)
104+
expected = "\"#{"abc\\u0010" * 4}\""
105+
assert_equal(expected, out)
106+
107+
str = "abc\u0010" * 5
108+
out = Oj.dump(str, mode: mode)
109+
expected = "\"#{"abc\\u0010" * 5}\""
110+
assert_equal(expected, out)
111+
112+
str = "\u0001\u0002" * 9
113+
out = Oj.dump(str, mode: mode)
114+
expected = "\"#{"\\u0001\\u0002" * 9}\""
115+
assert_equal(expected, out)
116+
117+
str = "\u0001\u0002" * 2048
118+
out = Oj.dump(str, mode: mode)
119+
expected = "\"#{"\\u0001\\u0002" * 2048}\""
120+
assert_equal(expected, out)
121+
122+
str = '\"'
123+
out = Oj.dump(str, mode: mode)
124+
expected = "\"\\\\\\\"\""
125+
assert_equal(expected, out)
126+
127+
str = '"'*16
128+
out = Oj.dump(str, mode: mode)
129+
expected = '"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\""'
130+
assert_equal(expected, out)
131+
132+
str = '"'*20
133+
out = Oj.dump(str, mode: mode)
134+
expected = '"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\"\\""'
135+
assert_equal(expected, out)
136+
end
137+
138+
def test_dump_long_str_no_escapes
139+
str = 'This is a test of the emergency broadcast system. This is only a test.'
140+
out = Oj.dump(str)
141+
assert_equal(%|"#{str}"|, out)
142+
end
143+
144+
def test_dump_long_str_with_escapes
145+
str = 'This is a\ntest of the emergency broadcast system. This is only a test.'
146+
out = Oj.dump(str)
147+
expected = %|"This is a\\\\ntest of the emergency broadcast system. This is only a test."|
148+
assert_equal(expected, out)
149+
end
150+
151+
def test_dump_long_str_with_quotes
152+
str = 'This is a "test" of the emergency broadcast system. This is only a "test".'
153+
out = Oj.dump(str)
154+
expected = %|"This is a \\\"test\\\" of the emergency broadcast system. This is only a \\\"test\\\"."|
155+
assert_equal(expected, out)
156+
end
157+
158+
def test_dump_long_str_no_escapes_rails
159+
str = 'This is a test of the emergency broadcast system. This is only a test.'
160+
out = Oj.dump(str, mode: :rails)
161+
assert_equal(%|"#{str}"|, out)
162+
end
163+
164+
def test_dump_long_str_with_escapes_rails
165+
str = 'This is a\ntest of the emergency broadcast system. This is only a test.'
166+
out = Oj.dump(str, mode: :rails)
167+
expected = %|"This is a\\\\ntest of the emergency broadcast system. This is only a test."|
168+
assert_equal(expected, out)
169+
end
170+
171+
def test_dump_long_str_with_quotes_rails
172+
str = 'This is a "test" of the emergency broadcast system. This is only a "test".'
173+
out = Oj.dump(str, mode: :rails)
174+
expected = %|"This is a \\\"test\\\" of the emergency broadcast system. This is only a \\\"test\\\"."|
175+
assert_equal(expected, out)
176+
end
177+
178+
def test_long_string_with_high_byte_set
179+
str = 'This item will cost €1000.00. I hope you have a great day!'
180+
out = Oj.dump(str)
181+
expected = %["This item will cost €1000.00. I hope you have a great day!"]
182+
assert_equal(expected, out)
183+
184+
out = Oj.dump(str, mode: :rails)
185+
assert_equal(expected, out)
186+
end
187+
188+
def test_high_byte_set
189+
str = "€"*15
190+
out = Oj.dump(str)
191+
expected = %["#{"€"*15}"]
192+
assert_equal(expected, out)
193+
out = Oj.dump(str, mode: :rails)
194+
assert_equal(expected, out)
195+
196+
str = "€"*16
197+
out = Oj.dump(str)
198+
expected = %["#{"€"*16}"]
199+
assert_equal(expected, out)
200+
out = Oj.dump(str, mode: :rails)
201+
assert_equal(expected, out)
202+
203+
str = "€"*17
204+
out = Oj.dump(str)
205+
expected = %["#{"€"*17}"]
206+
assert_equal(expected, out)
207+
out = Oj.dump(str, mode: :rails)
208+
assert_equal(expected, out)
209+
210+
str = "€"*1700
211+
out = Oj.dump(str)
212+
expected = %["#{"€"*1700}"]
213+
assert_equal(expected, out)
214+
out = Oj.dump(str, mode: :rails)
215+
assert_equal(expected, out)
216+
217+
str = "€abcdefghijklmnop"
218+
out = Oj.dump(str)
219+
expected = %["#{"€abcdefghijklmnop"}"]
220+
assert_equal(expected, out)
221+
out = Oj.dump(str, mode: :rails)
222+
assert_equal(expected, out)
223+
224+
str = "€" + "abcdefghijklmnop"*1000
225+
out = Oj.dump(str)
226+
expected = %["#{"€" + "abcdefghijklmnop"*1000}"]
227+
assert_equal(expected, out)
228+
out = Oj.dump(str, mode: :rails)
229+
assert_equal(expected, out)
230+
231+
str = "€" + "abcdefghijklmnop" * 2000 + "€"
232+
out = Oj.dump(str)
233+
expected = %["#{"€" + "abcdefghijklmnop" * 2000 + "€"}"]
234+
assert_equal(expected, out)
235+
out = Oj.dump(str, mode: :rails)
236+
assert_equal(expected, out)
237+
238+
str = "abcdefghijklmnop"*3000 + "€"
239+
out = Oj.dump(str)
240+
expected = %["#{"abcdefghijklmnop"*3000 + "€"}"]
241+
assert_equal(expected, out)
242+
out = Oj.dump(str, mode: :rails)
243+
assert_equal(expected, out)
244+
end
245+
end

0 commit comments

Comments
 (0)