Skip to content

Commit cba1784

Browse files
implemented packed_simd for neon
1 parent c49e465 commit cba1784

File tree

15 files changed

+1211
-802
lines changed

15 files changed

+1211
-802
lines changed

curve25519-dalek/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,6 @@ group-bits = ["group", "ff/bits"]
7272

7373
[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), target_arch = "x86_64"))'.dependencies]
7474
curve25519-dalek-derive = { version = "0.1", path = "../curve25519-dalek-derive" }
75+
76+
[target.'cfg(all(not(curve25519_dalek_backend = "fiat"), not(curve25519_dalek_backend = "serial"), target_arch = "aarch64"))'.dependencies]
77+
curve25519-dalek-derive = { version = "0.1", path = "../curve25519-dalek-derive" }

curve25519-dalek/build.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ fn main() {
6666

6767
// Is the target arch & curve25519_dalek_bits potentially simd capable ?
6868
fn is_capable_simd(arch: &str, bits: DalekBits) -> bool {
69-
arch == "x86_64" && bits == DalekBits::Dalek64
69+
(arch == "x86_64" || arch == "aarch64") && bits == DalekBits::Dalek64
7070
}
7171

7272
// Deterministic cfg(curve25519_dalek_bits) when this is not explicitly set.

curve25519-dalek/src/backend/mod.rs

Lines changed: 58 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,16 +44,22 @@ pub mod vector;
4444

4545
#[derive(Copy, Clone)]
4646
enum BackendKind {
47-
#[cfg(curve25519_dalek_backend = "simd")]
47+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
4848
Avx2,
49-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
49+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
5050
Avx512,
51+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
52+
Neon,
5153
Serial,
5254
}
5355

5456
#[inline]
5557
fn get_selected_backend() -> BackendKind {
56-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
58+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
59+
{
60+
return BackendKind::Neon;
61+
}
62+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
5763
{
5864
cpufeatures::new!(cpuid_avx512, "avx512ifma", "avx512vl");
5965
let token_avx512: cpuid_avx512::InitToken = cpuid_avx512::init();
@@ -62,7 +68,7 @@ fn get_selected_backend() -> BackendKind {
6268
}
6369
}
6470

65-
#[cfg(curve25519_dalek_backend = "simd")]
71+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
6672
{
6773
cpufeatures::new!(cpuid_avx2, "avx2");
6874
let token_avx2: cpuid_avx2::InitToken = cpuid_avx2::init();
@@ -85,25 +91,32 @@ where
8591
use crate::traits::VartimeMultiscalarMul;
8692

8793
match get_selected_backend() {
88-
#[cfg(curve25519_dalek_backend = "simd")]
94+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
8995
BackendKind::Avx2 =>
9096
self::vector::scalar_mul::pippenger::spec_avx2::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
91-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
97+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
9298
BackendKind::Avx512 =>
9399
self::vector::scalar_mul::pippenger::spec_avx512ifma_avx512vl::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
100+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
101+
BackendKind::Neon =>
102+
self::vector::scalar_mul::pippenger::spec_neon::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
94103
BackendKind::Serial =>
95104
self::serial::scalar_mul::pippenger::Pippenger::optional_multiscalar_mul::<I, J>(scalars, points),
96105
}
97106
}
98107

99108
#[cfg(feature = "alloc")]
100109
pub(crate) enum VartimePrecomputedStraus {
101-
#[cfg(curve25519_dalek_backend = "simd")]
110+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
102111
Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus),
103-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
112+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
104113
Avx512ifma(
105114
self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus,
106115
),
116+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
117+
Neon(
118+
self::vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus
119+
),
107120
Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus),
108121
}
109122

@@ -117,12 +130,15 @@ impl VartimePrecomputedStraus {
117130
use crate::traits::VartimePrecomputedMultiscalarMul;
118131

119132
match get_selected_backend() {
120-
#[cfg(curve25519_dalek_backend = "simd")]
133+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
121134
BackendKind::Avx2 =>
122135
VartimePrecomputedStraus::Avx2(self::vector::scalar_mul::precomputed_straus::spec_avx2::VartimePrecomputedStraus::new(static_points)),
123-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
136+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
124137
BackendKind::Avx512 =>
125138
VartimePrecomputedStraus::Avx512ifma(self::vector::scalar_mul::precomputed_straus::spec_avx512ifma_avx512vl::VartimePrecomputedStraus::new(static_points)),
139+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
140+
BackendKind::Neon =>
141+
VartimePrecomputedStraus::Neon(self::vector::scalar_mul::precomputed_straus::spec_neon::VartimePrecomputedStraus::new(static_points)),
126142
BackendKind::Serial =>
127143
VartimePrecomputedStraus::Scalar(self::serial::scalar_mul::precomputed_straus::VartimePrecomputedStraus::new(static_points))
128144
}
@@ -144,18 +160,24 @@ impl VartimePrecomputedStraus {
144160
use crate::traits::VartimePrecomputedMultiscalarMul;
145161

146162
match self {
147-
#[cfg(curve25519_dalek_backend = "simd")]
163+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
148164
VartimePrecomputedStraus::Avx2(inner) => inner.optional_mixed_multiscalar_mul(
149165
static_scalars,
150166
dynamic_scalars,
151167
dynamic_points,
152168
),
153-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
169+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
154170
VartimePrecomputedStraus::Avx512ifma(inner) => inner.optional_mixed_multiscalar_mul(
155171
static_scalars,
156172
dynamic_scalars,
157173
dynamic_points,
158174
),
175+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
176+
VartimePrecomputedStraus::Neon(inner) => inner.optional_mixed_multiscalar_mul(
177+
static_scalars,
178+
dynamic_scalars,
179+
dynamic_points,
180+
),
159181
VartimePrecomputedStraus::Scalar(inner) => inner.optional_mixed_multiscalar_mul(
160182
static_scalars,
161183
dynamic_scalars,
@@ -177,19 +199,25 @@ where
177199
use crate::traits::MultiscalarMul;
178200

179201
match get_selected_backend() {
180-
#[cfg(curve25519_dalek_backend = "simd")]
202+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
181203
BackendKind::Avx2 => {
182204
self::vector::scalar_mul::straus::spec_avx2::Straus::multiscalar_mul::<I, J>(
183205
scalars, points,
184206
)
185207
}
186-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
208+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
187209
BackendKind::Avx512 => {
188210
self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::multiscalar_mul::<
189211
I,
190212
J,
191213
>(scalars, points)
192214
}
215+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
216+
BackendKind::Neon => {
217+
self::vector::scalar_mul::straus::spec_neon::Straus::multiscalar_mul::<I, J>(
218+
scalars, points,
219+
)
220+
}
193221
BackendKind::Serial => {
194222
self::serial::scalar_mul::straus::Straus::multiscalar_mul::<I, J>(scalars, points)
195223
}
@@ -207,19 +235,25 @@ where
207235
use crate::traits::VartimeMultiscalarMul;
208236

209237
match get_selected_backend() {
210-
#[cfg(curve25519_dalek_backend = "simd")]
238+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
211239
BackendKind::Avx2 => {
212240
self::vector::scalar_mul::straus::spec_avx2::Straus::optional_multiscalar_mul::<I, J>(
213241
scalars, points,
214242
)
215243
}
216-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
244+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
217245
BackendKind::Avx512 => {
218246
self::vector::scalar_mul::straus::spec_avx512ifma_avx512vl::Straus::optional_multiscalar_mul::<
219247
I,
220248
J,
221249
>(scalars, points)
222250
}
251+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
252+
BackendKind::Neon => {
253+
self::vector::scalar_mul::straus::spec_neon::Straus::optional_multiscalar_mul::<I, J>(
254+
scalars, points
255+
)
256+
}
223257
BackendKind::Serial => {
224258
self::serial::scalar_mul::straus::Straus::optional_multiscalar_mul::<I, J>(
225259
scalars, points,
@@ -231,12 +265,14 @@ where
231265
/// Perform constant-time, variable-base scalar multiplication.
232266
pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint {
233267
match get_selected_backend() {
234-
#[cfg(curve25519_dalek_backend = "simd")]
268+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
235269
BackendKind::Avx2 => self::vector::scalar_mul::variable_base::spec_avx2::mul(point, scalar),
236-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
270+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
237271
BackendKind::Avx512 => {
238272
self::vector::scalar_mul::variable_base::spec_avx512ifma_avx512vl::mul(point, scalar)
239273
}
274+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
275+
BackendKind::Neon => self::vector::scalar_mul::variable_base::spec_neon::mul(point, scalar),
240276
BackendKind::Serial => self::serial::scalar_mul::variable_base::mul(point, scalar),
241277
}
242278
}
@@ -245,12 +281,14 @@ pub fn variable_base_mul(point: &EdwardsPoint, scalar: &Scalar) -> EdwardsPoint
245281
#[allow(non_snake_case)]
246282
pub fn vartime_double_base_mul(a: &Scalar, A: &EdwardsPoint, b: &Scalar) -> EdwardsPoint {
247283
match get_selected_backend() {
248-
#[cfg(curve25519_dalek_backend = "simd")]
284+
#[cfg(all(curve25519_dalek_backend = "simd", target_arch="x86_64"))]
249285
BackendKind::Avx2 => self::vector::scalar_mul::vartime_double_base::spec_avx2::mul(a, A, b),
250-
#[cfg(all(curve25519_dalek_backend = "simd", nightly))]
286+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="x86_64"))]
251287
BackendKind::Avx512 => {
252288
self::vector::scalar_mul::vartime_double_base::spec_avx512ifma_avx512vl::mul(a, A, b)
253289
}
290+
#[cfg(all(curve25519_dalek_backend = "simd", nightly, target_arch="aarch64"))]
291+
BackendKind::Neon => self::vector::scalar_mul::vartime_double_base::spec_neon::mul(a, A, b),
254292
BackendKind::Serial => self::serial::scalar_mul::vartime_double_base::mul(a, A, b),
255293
}
256294
}

curve25519-dalek/src/backend/vector/mod.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,17 @@
1212
#![doc = include_str!("../../../docs/parallel-formulas.md")]
1313

1414
#[allow(missing_docs)]
15+
#[cfg(all(target_arch="x86_64"))]
1516
pub mod packed_simd;
1617

18+
19+
#[cfg(all(target_arch="x86_64"))]
1720
pub mod avx2;
1821

19-
#[cfg(nightly)]
22+
#[cfg(all(nightly, target_arch="x86_64"))]
2023
pub mod ifma;
2124

22-
#[cfg(nightly)]
25+
#[cfg(all(nightly, target_arch="aarch64"))]
2326
pub mod neon;
2427

2528
pub mod scalar_mul;

0 commit comments

Comments
 (0)