Skip to content

Commit 26c494f

Browse files
committed
Use packed_simd::shuffle instead of vqtbx1q_u8
1 parent 3ab954c commit 26c494f

File tree

1 file changed

+35
-40
lines changed

1 file changed

+35
-40
lines changed

src/backend/vector/neon/field.rs

Lines changed: 35 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
// - Henry de Valence <hdevalence@hdevalence.ca>
1111
// - Robrecht Blancquaert <Robrecht.Simon.Blancquaert@vub.be>
1212

13-
//! More details on the algorithms can be found in the `avx2`
14-
//! module. Here comments are mostly added only when needed
15-
//! to explain differences between the 'base' avx2 version and
13+
//! More details on the algorithms can be found in the `avx2`
14+
//! module. Here comments are mostly added only when needed
15+
//! to explain differences between the 'base' avx2 version and
1616
//! this re-implementation for arm neon.
1717
1818
//! The most significant difference is the split of one vector of 8
@@ -61,10 +61,10 @@ fn repack_pair(x: (u32x4, u32x4), y: (u32x4, u32x4)) -> (u32x4, u32x4) {
6161
use core::arch::aarch64::vgetq_lane_u32;
6262

6363
(vcombine_u32(
64-
vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1),
65-
vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(),
64+
vset_lane_u32(vgetq_lane_u32(x.0.into_bits(), 2) , vget_low_u32(x.0.into_bits()), 1),
65+
vset_lane_u32(vgetq_lane_u32(y.0.into_bits(), 2) , vget_low_u32(y.0.into_bits()), 1)).into_bits(),
6666
vcombine_u32(
67-
vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1),
67+
vset_lane_u32(vgetq_lane_u32(x.1.into_bits(), 2) , vget_low_u32(x.1.into_bits()), 1),
6868
vset_lane_u32(vgetq_lane_u32(y.1.into_bits(), 2) , vget_low_u32(y.1.into_bits()), 1)).into_bits())
6969
}
7070
}
@@ -100,16 +100,16 @@ macro_rules! lane_shuffle {
100100
unsafe {
101101
use core::arch::aarch64::vgetq_lane_u32;
102102
const c: [i32; 8] = [$l0, $l1, $l2, $l3, $l4, $l5, $l6, $l7];
103-
(u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) },
104-
if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) },
105-
if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) },
103+
(u32x4::new(if c[0] < 4 { vgetq_lane_u32($x.0.into_bits(), c[0]) } else { vgetq_lane_u32($x.1.into_bits(), c[0] - 4) },
104+
if c[1] < 4 { vgetq_lane_u32($x.0.into_bits(), c[1]) } else { vgetq_lane_u32($x.1.into_bits(), c[1] - 4) },
105+
if c[2] < 4 { vgetq_lane_u32($x.0.into_bits(), c[2]) } else { vgetq_lane_u32($x.1.into_bits(), c[2] - 4) },
106106
if c[3] < 4 { vgetq_lane_u32($x.0.into_bits(), c[3]) } else { vgetq_lane_u32($x.1.into_bits(), c[3] - 4) }),
107-
u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) },
108-
if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) },
109-
if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) },
107+
u32x4::new(if c[4] < 4 { vgetq_lane_u32($x.0.into_bits(), c[4]) } else { vgetq_lane_u32($x.1.into_bits(), c[4] - 4) },
108+
if c[5] < 4 { vgetq_lane_u32($x.0.into_bits(), c[5]) } else { vgetq_lane_u32($x.1.into_bits(), c[5] - 4) },
109+
if c[6] < 4 { vgetq_lane_u32($x.0.into_bits(), c[6]) } else { vgetq_lane_u32($x.1.into_bits(), c[6] - 4) },
110110
if c[7] < 4 { vgetq_lane_u32($x.0.into_bits(), c[7]) } else { vgetq_lane_u32($x.1.into_bits(), c[7] - 4) }))
111111
}
112-
112+
113113
}
114114
}
115115

@@ -161,14 +161,14 @@ impl FieldElement2625x4 {
161161
pub fn split(&self) -> [FieldElement51; 4] {
162162
let mut out = [FieldElement51::zero(); 4];
163163
for i in 0..5 {
164-
let a_2i = self.0[i].0.extract(0) as u64;
165-
let b_2i = self.0[i].0.extract(1) as u64;
166-
let a_2i_1 = self.0[i].0.extract(2) as u64;
164+
let a_2i = self.0[i].0.extract(0) as u64;
165+
let b_2i = self.0[i].0.extract(1) as u64;
166+
let a_2i_1 = self.0[i].0.extract(2) as u64;
167167
let b_2i_1 = self.0[i].0.extract(3) as u64;
168168
let c_2i = self.0[i].1.extract(0) as u64;
169-
let d_2i = self.0[i].1.extract(1) as u64;
170-
let c_2i_1 = self.0[i].1.extract(2) as u64;
171-
let d_2i_1 = self.0[i].1.extract(3) as u64;
169+
let d_2i = self.0[i].1.extract(1) as u64;
170+
let c_2i_1 = self.0[i].1.extract(2) as u64;
171+
let d_2i_1 = self.0[i].1.extract(3) as u64;
172172

173173
out[0].0[i] = a_2i + (a_2i_1 << 26);
174174
out[1].0[i] = b_2i + (b_2i_1 << 26);
@@ -212,33 +212,28 @@ impl FieldElement2625x4 {
212212
#[inline(always)]
213213
fn blend_lanes(x: (u32x4, u32x4), y: (u32x4, u32x4), control: Lanes) -> (u32x4, u32x4) {
214214
unsafe {
215-
use core::arch::aarch64::vqtbx1q_u8;
215+
use packed_simd::shuffle;
216216
match control {
217217
Lanes::C => {
218-
(x.0,
219-
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits())
218+
(x.0, shuffle!(y.1, x.1, [0, 5, 2, 7]))
220219
}
221220
Lanes::D => {
222-
(x.0,
223-
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits()).into_bits())
221+
(x.0, shuffle!(y.1, x.1, [4, 1, 6, 3]))
224222
}
225223
Lanes::AD => {
226-
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits(),
227-
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits())
224+
(shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [4, 1, 6, 3]))
228225
}
229226
Lanes::AB => {
230227
(y.0, x.1)
231228
}
232229
Lanes::AC => {
233-
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits(),
234-
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits()).into_bits())
230+
(shuffle!(y.0, x.0, [0, 5, 2, 7]), shuffle!(y.1, x.1, [0, 5, 2, 7]))
235231
}
236232
Lanes::CD => {
237-
(x.0, y.1)
233+
(x.0, y.1)
238234
}
239235
Lanes::BC => {
240-
(vqtbx1q_u8(x.0.into_bits(), y.0.into_bits(), u8x16::new(16, 16, 16, 16, 4, 5, 6, 7, 16, 16, 16, 16, 12, 13, 14, 15).into_bits() ).into_bits(),
241-
vqtbx1q_u8(x.1.into_bits(), y.1.into_bits(), u8x16::new( 0, 1, 2, 3, 16, 16, 16, 16, 8, 9, 10, 11, 16, 16, 16, 16).into_bits() ).into_bits())
236+
(shuffle!(y.0, x.0, [4, 1, 6, 3]), shuffle!(y.1, x.1, [0, 5, 2, 7]))
242237
}
243238
Lanes::ABCD => {
244239
y
@@ -333,7 +328,7 @@ impl FieldElement2625x4 {
333328
use core::arch::aarch64::vget_high_u32;
334329
use core::arch::aarch64::vcombine_u32;
335330

336-
let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(),
331+
let c: (u32x4, u32x4) = (vqshlq_u32(v.0.into_bits(), shifts.0.into_bits()).into_bits(),
337332
vqshlq_u32(v.1.into_bits(), shifts.1.into_bits()).into_bits());
338333
(vcombine_u32(vget_high_u32(c.0.into_bits()), vget_low_u32(c.0.into_bits())).into_bits(),
339334
vcombine_u32(vget_high_u32(c.1.into_bits()), vget_low_u32(c.1.into_bits())).into_bits())
@@ -377,7 +372,7 @@ impl FieldElement2625x4 {
377372
use core::arch::aarch64::vmulq_n_u32;
378373
use core::arch::aarch64::vget_low_u32;
379374
use core::arch::aarch64::vcombine_u32;
380-
375+
381376
let c9_19_spread: (u32x4, u32x4) = (vmulq_n_u32(c98.0.into_bits(), 19).into_bits(), vmulq_n_u32(c98.1.into_bits(), 19).into_bits());
382377

383378
(vcombine_u32(vget_low_u32(c9_19_spread.0.into_bits()), u32x2::splat(0).into_bits()).into_bits(),
@@ -423,9 +418,9 @@ impl FieldElement2625x4 {
423418
unsafe {
424419
use core::arch::aarch64::vmulq_n_u32;
425420

426-
c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(),
421+
c0 = (vmulq_n_u32(c0.0.into_bits(), 19).into_bits(),
427422
vmulq_n_u32(c0.1.into_bits(), 19).into_bits());
428-
c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(),
423+
c1 = (vmulq_n_u32(c1.0.into_bits(), 19).into_bits(),
429424
vmulq_n_u32(c1.1.into_bits(), 19).into_bits());
430425
}
431426

@@ -457,8 +452,8 @@ impl FieldElement2625x4 {
457452
#[inline(always)]
458453
fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) {
459454
use core::arch::aarch64::vmull_u32;
460-
unsafe {
461-
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
455+
unsafe {
456+
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
462457
vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits());
463458
(u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2)))
464459
}
@@ -497,7 +492,7 @@ impl FieldElement2625x4 {
497492
let mut z7 = m(x0_2,x7) + m(x1_2,x6) + m(x2_2,x5) + m(x3_2,x4) + ((m(x8,x9_19)) << 1);
498493
let mut z8 = m(x0_2,x8) + m(x1_2,x7_2) + m(x2_2,x6) + m(x3_2,x5_2) + m(x4,x4) + ((m(x9,x9_19)) << 1);
499494
let mut z9 = m(x0_2,x9) + m(x1_2,x8) + m(x2_2,x7) + m(x3_2,x6) + m(x4_2,x5);
500-
495+
501496

502497
let low__p37 = u64x4::splat(0x3ffffed << 37);
503498
let even_p37 = u64x4::splat(0x3ffffff << 37);
@@ -609,8 +604,8 @@ impl<'a, 'b> Mul<&'b FieldElement2625x4> for &'a FieldElement2625x4 {
609604
#[inline(always)]
610605
fn m_lo(x: (u32x2, u32x2), y: (u32x2, u32x2)) -> (u32x2, u32x2) {
611606
use core::arch::aarch64::vmull_u32;
612-
unsafe {
613-
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
607+
unsafe {
608+
let x: (u32x4, u32x4) = (vmull_u32(x.0.into_bits(), y.0.into_bits()).into_bits(),
614609
vmull_u32(x.1.into_bits(), y.1.into_bits()).into_bits());
615610
(u32x2::new(x.0.extract(0), x.0.extract(2)), u32x2::new(x.1.extract(0), x.1.extract(2)))
616611
}

0 commit comments

Comments
 (0)