Skip to content

Commit b2ded6c

Browse files
committed
refactor: further simplification
1 parent e96b0c2 commit b2ded6c

25 files changed

+146
-156
lines changed

src/main/kotlin/AssociatedData.kt

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/main/kotlin/CipherKey.kt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,10 @@
11
package nl.sanderdijkhuis.noise
22

3-
import nl.sanderdijkhuis.noise.Size.Companion.valueSize
4-
53
@JvmInline
6-
value class CipherKey(val value: ByteArray) {
4+
value class CipherKey(val data: Data) {
75

86
init {
9-
require(value.valueSize == SIZE)
7+
require(data.size == SIZE)
108
}
119

1210
companion object {

src/main/kotlin/CipherState.kt

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,21 @@ package nl.sanderdijkhuis.noise
22

33
data class CipherState(val cryptography: Cryptography, val key: CipherKey? = null, val nonce: Nonce = Nonce.zero) {
44

5-
fun encryptWithAssociatedData(associatedData: AssociatedData, plaintext: Plaintext) =
6-
key?.let {
7-
println("Encrypting $key $nonce $associatedData $plaintext")
8-
State(copy(nonce = nonce.increment()), cryptography.encrypt(it, nonce, associatedData, plaintext))
9-
} ?: let {
10-
println("Returning plaintext $plaintext $nonce")
5+
fun encryptWithAssociatedData(associatedData: Data, plaintext: Plaintext) =
6+
if (key == null)
117
State(this, plaintext.ciphertext)
12-
}
8+
else
9+
nonce.increment()?.let {
10+
State(copy(nonce = it), cryptography.encrypt(key, nonce, associatedData, plaintext))
11+
} ?: State(this, plaintext.ciphertext)
1312

14-
fun decryptWithAssociatedData(data: AssociatedData, ciphertext: Ciphertext): State<CipherState, Plaintext>? = let {
15-
println("Decrypting $key $nonce $data $ciphertext")
13+
fun decryptWithAssociatedData(data: Data, ciphertext: Ciphertext): State<CipherState, Plaintext>? =
1614
if (key == null)
1715
State(this, ciphertext.plaintext)
1816
else
19-
cryptography.decrypt(key, nonce, data, ciphertext)?.let {
20-
State(copy(nonce = nonce.increment()), it)
17+
nonce.increment()?.let { n ->
18+
cryptography.decrypt(key, nonce, data, ciphertext)?.let { p ->
19+
State(copy(nonce = n), p)
20+
}
2121
}
22-
}
2322
}

src/main/kotlin/Ciphertext.kt

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package nl.sanderdijkhuis.noise
22

33
@JvmInline
4-
value class Ciphertext(val value: ByteArray) {
4+
value class Ciphertext(val data: Data) {
55

6-
val data get() = Data(value)
7-
8-
val plaintext get() = Plaintext(value)
6+
val plaintext get() = Plaintext(data)
97
}

src/main/kotlin/Cryptography.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ interface Cryptography {
44

55
fun agree(privateKey: PrivateKey, publicKey: PublicKey): SharedSecret
66

7-
fun encrypt(key: CipherKey, nonce: Nonce, associatedData: AssociatedData, plaintext: Plaintext): Ciphertext
7+
fun encrypt(key: CipherKey, nonce: Nonce, associatedData: Data, plaintext: Plaintext): Ciphertext
88

9-
fun decrypt(key: CipherKey, nonce: Nonce, associatedData: AssociatedData, ciphertext: Ciphertext): Plaintext?
9+
fun decrypt(key: CipherKey, nonce: Nonce, associatedData: Data, ciphertext: Ciphertext): Plaintext?
1010

1111
fun hash(data: Data): Digest
1212
}

src/main/kotlin/Data.kt

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@ package nl.sanderdijkhuis.noise
33
import nl.sanderdijkhuis.noise.Size.Companion.valueSize
44
import kotlin.experimental.xor
55

6-
@JvmInline
7-
value class Data(val value: ByteArray) {
6+
data class Data(val value: ByteArray) {
87

98
init {
109
require(value.valueSize <= Size.MAX_MESSAGE)
@@ -14,11 +13,18 @@ value class Data(val value: ByteArray) {
1413

1514
val size get() = value.valueSize
1615

16+
val isEmpty get() = value.isEmpty()
17+
1718
fun xor(that: Data) = Data(let {
1819
require(value.size == that.value.size)
1920
ByteArray(value.size) { this.value[it].xor(that.value[it]) }
2021
})
2122

23+
override fun equals(other: Any?) =
24+
this === other || ((other as? Data)?.let { value.contentEquals(it.value) } ?: false)
25+
26+
override fun hashCode() = value.contentHashCode()
27+
2228
companion object {
2329

2430
val empty get() = Data(ByteArray(0))

src/main/kotlin/Digest.kt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,5 @@ value class Digest(val data: Data) {
77
require(data.size == HashFunction.HASH_SIZE)
88
}
99

10-
val associatedData get() = AssociatedData(data)
11-
1210
val messageAuthenticationKey get() = MessageAuthenticationKey(data)
1311
}

src/main/kotlin/HandshakeHash.kt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
package nl.sanderdijkhuis.noise
2+
3+
@JvmInline
4+
value class HandshakeHash(val digest: Digest)

src/main/kotlin/HandshakePattern.kt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@ data class HandshakePattern(val name: ProtocolName, val preSharedMessagePatterns
66

77
val Noise_XN_25519_ChaChaPoly_SHA256 =
88
HandshakePattern(
9-
ProtocolName("Noise_XN_25519_ChaChaPoly_SHA256".toByteArray()),
9+
ProtocolName(Data("Noise_XN_25519_ChaChaPoly_SHA256".toByteArray())),
1010
listOf(),
1111
listOf(listOf(Token.E), listOf(Token.E, Token.EE), listOf(Token.S, Token.SE))
1212
)
1313

1414
val Noise_NK_25519_ChaChaPoly_SHA256 =
1515
HandshakePattern(
16-
ProtocolName("Noise_NK_25519_ChaChaPoly_SHA256".toByteArray()),
16+
ProtocolName(Data("Noise_NK_25519_ChaChaPoly_SHA256".toByteArray())),
1717
listOf(listOf(), listOf(Token.S)),
1818
listOf(listOf(Token.E, Token.ES), listOf(Token.E, Token.EE))
1919
)

src/main/kotlin/HandshakeState.kt

Lines changed: 52 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ data class HandshakeState(
44
val role: Role,
55
val symmetricState: SymmetricState,
66
val messagePatterns: List<List<Token>>,
7-
val s: KeyPair? = null,
8-
val e: KeyPair? = null,
9-
val rs: PublicKey? = null,
10-
val re: PublicKey? = null,
7+
val localStaticKeyPair: KeyPair? = null,
8+
val localEphemeralKeyPair: KeyPair? = null,
9+
val remoteStaticKey: PublicKey? = null,
10+
val remoteEphemeralKey: PublicKey? = null,
1111
val trustedStaticKeys: Set<PublicKey> = emptySet()
1212
) {
1313

@@ -34,17 +34,17 @@ data class HandshakeState(
3434
}
3535
when {
3636
state == null -> null
37-
token == Token.E && e != null -> state.run(e.public.data) { it.mixHash(e.public.data) }
38-
token == Token.S && s != null -> state.runAndAppendInState {
39-
it.encryptAndHash(s.public.plaintext).map { c -> c.data }
37+
token == Token.E && localEphemeralKeyPair != null -> state.run(localEphemeralKeyPair.public.data) { it.mixHash(localEphemeralKeyPair.public.data) }
38+
token == Token.S && localStaticKeyPair != null -> state.runAndAppendInState {
39+
it.encryptAndHash(localStaticKeyPair.public.plaintext).map { c -> c.data }
4040
}
4141

42-
token == Token.EE -> mixKey(e, re)
43-
token == Token.ES && role == Role.INITIATOR -> mixKey(e, rs)
44-
token == Token.ES && role == Role.RESPONDER -> mixKey(s, re)
45-
token == Token.SE && role == Role.INITIATOR -> mixKey(s, re)
46-
token == Token.SE && role == Role.RESPONDER -> mixKey(e, rs)
47-
token == Token.SS -> mixKey(s, rs)
42+
token == Token.EE -> mixKey(localEphemeralKeyPair, remoteEphemeralKey)
43+
token == Token.ES && role == Role.INITIATOR -> mixKey(localEphemeralKeyPair, remoteStaticKey)
44+
token == Token.ES && role == Role.RESPONDER -> mixKey(localStaticKeyPair, remoteEphemeralKey)
45+
token == Token.SE && role == Role.INITIATOR -> mixKey(localStaticKeyPair, remoteEphemeralKey)
46+
token == Token.SE && role == Role.RESPONDER -> mixKey(localEphemeralKeyPair, remoteStaticKey)
47+
token == Token.SS -> mixKey(localStaticKeyPair, remoteStaticKey)
4848
else -> null
4949
}
5050
}?.runAndAppendInState { it.encryptAndHash(payload.plainText).map { c -> c.data } }
@@ -53,7 +53,14 @@ data class HandshakeState(
5353
when {
5454
state == null -> MessageResult.InsufficientKeyMaterial
5555
rest.isEmpty() -> symmetricState.split()
56-
.let { MessageResult.FinalHandshakeMessage(it.first, it.second, symmetricState.digest, state.result) }
56+
.let {
57+
MessageResult.FinalHandshakeMessage(
58+
it.first,
59+
it.second,
60+
symmetricState.handshakeHash,
61+
state.result
62+
)
63+
}
5764

5865
else -> MessageResult.IntermediateHandshakeMessage(
5966
state.current.copy(messagePatterns = rest),
@@ -74,38 +81,40 @@ data class HandshakeState(
7481
println("Token $token")
7582
when {
7683
state == null -> null
77-
token == Token.E && state.current.re == null ->
84+
token == Token.E && state.current.remoteEphemeralKey == null ->
7885
let {
7986
val re =
8087
PublicKey(
81-
state.result.value.sliceArray(
82-
IntRange(
83-
0,
84-
KeyAgreementConfiguration.SIZE.value - 1
88+
Data(
89+
state.result.value.sliceArray(
90+
IntRange(
91+
0,
92+
SharedSecret.SIZE.value - 1
93+
)
8594
)
8695
)
8796
)
8897
println("E: read $re")
8998
val mixed = state.current.symmetricState.mixHash(re.data)
9099
state.copy(
91-
current = state.current.copy(symmetricState = mixed, re = re),
92-
result = Data(state.result.value.drop(KeyAgreementConfiguration.SIZE.value).toByteArray())
100+
current = state.current.copy(symmetricState = mixed, remoteEphemeralKey = re),
101+
result = Data(state.result.value.drop(SharedSecret.SIZE.value).toByteArray())
93102
)
94103
}
95104

96-
token == Token.S && state.current.rs == null -> let {
105+
token == Token.S && state.current.remoteStaticKey == null -> let {
97106
println("S")
98-
val splitAt = KeyAgreementConfiguration.SIZE.value + 16
107+
val splitAt = SharedSecret.SIZE.value + 16
99108
val temp =
100109
state.result.value.sliceArray(IntRange(0, splitAt - 1))
101-
state.current.symmetricState.decryptAndHash(Ciphertext(temp))?.let {
102-
val publicKey = PublicKey(it.result.value)
110+
state.current.symmetricState.decryptAndHash(Ciphertext(Data(temp)))?.let {
111+
val publicKey = PublicKey(it.result.data)
103112
println("Public key $publicKey")
104113
println("Trusting $trustedStaticKeys")
105114
println("Trusted? ${trustedStaticKeys.contains(publicKey)}")
106115
if (trustedStaticKeys.contains(publicKey))
107116
state.copy(
108-
current = state.current.copy(symmetricState = it.current, rs = publicKey),
117+
current = state.current.copy(symmetricState = it.current, remoteStaticKey = publicKey),
109118
result = Data(
110119
state.result.value.drop(splitAt).toByteArray()
111120
)
@@ -115,26 +124,26 @@ data class HandshakeState(
115124
}
116125

117126
token == Token.EE -> let {
118-
println("EE: Mixing ${state.current.e} ${state.current.re}")
119-
mixKey(state.current.e, state.current.re)
127+
println("EE: Mixing ${state.current.localEphemeralKeyPair} ${state.current.remoteEphemeralKey}")
128+
mixKey(state.current.localEphemeralKeyPair, state.current.remoteEphemeralKey)
120129
}
121130

122-
token == Token.ES && role == Role.INITIATOR -> mixKey(state.current.e, state.current.rs)
123-
token == Token.ES && role == Role.RESPONDER -> mixKey(state.current.s, state.current.re)
131+
token == Token.ES && role == Role.INITIATOR -> mixKey(state.current.localEphemeralKeyPair, state.current.remoteStaticKey)
132+
token == Token.ES && role == Role.RESPONDER -> mixKey(state.current.localStaticKeyPair, state.current.remoteEphemeralKey)
124133
token == Token.SE && role == Role.INITIATOR -> let {
125134
println("SE")
126-
mixKey(state.current.s, state.current.re)
135+
mixKey(state.current.localStaticKeyPair, state.current.remoteEphemeralKey)
127136
}
128137

129-
token == Token.SE && role == Role.RESPONDER -> mixKey(state.current.e, state.current.rs)
130-
token == Token.SS -> mixKey(state.current.s, state.current.rs)
138+
token == Token.SE && role == Role.RESPONDER -> mixKey(state.current.localEphemeralKeyPair, state.current.remoteStaticKey)
139+
token == Token.SS -> mixKey(state.current.localStaticKeyPair, state.current.remoteStaticKey)
131140
else -> null
132141
}
133142
}?.let {
134-
it.current.symmetricState.decryptAndHash(Ciphertext(it.result.value))?.let { decrypted ->
143+
it.current.symmetricState.decryptAndHash(Ciphertext(it.result))?.let { decrypted ->
135144
State(
136145
it.current.copy(symmetricState = decrypted.current), Payload(
137-
Data(decrypted.result.value)
146+
decrypted.result.data
138147
)
139148
)
140149
}
@@ -144,7 +153,14 @@ data class HandshakeState(
144153
when {
145154
state == null -> MessageResult.InsufficientKeyMaterial
146155
rest.isEmpty() -> symmetricState.split()
147-
.let { MessageResult.FinalHandshakeMessage(it.first, it.second, symmetricState.digest, state.result) }
156+
.let {
157+
MessageResult.FinalHandshakeMessage(
158+
it.first,
159+
it.second,
160+
symmetricState.handshakeHash,
161+
state.result
162+
)
163+
}
148164

149165
else -> MessageResult.IntermediateHandshakeMessage(
150166
state.current.copy(messagePatterns = rest),

0 commit comments

Comments
 (0)