Skip to content

Commit 0113690

Browse files
author
Alex Couture-Beil
committed
Use aes256 when payload exceeds 256 bytes
Signed-off-by: Alex Couture-Beil <alex@mofo.ca>
1 parent 953febc commit 0113690

File tree

2 files changed

+143
-16
lines changed

2 files changed

+143
-16
lines changed

build/darwin/amd64/secretshare

-64 Bytes
Binary file not shown.

cmd/secretshare/main.go

Lines changed: 143 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
package main
22

33
import (
4+
"bytes"
5+
"crypto/aes"
6+
"crypto/cipher"
47
"crypto/rand"
58
"crypto/rsa"
69
"crypto/sha256"
710
"crypto/x509"
811
"encoding/base64"
12+
"encoding/binary"
913
"encoding/pem"
1014
"fmt"
1115
"io/ioutil"
@@ -41,8 +45,87 @@ func generateKey() (string, string, error) {
4145
return pubKeyStr, privKeyStr, nil
4246
}
4347

44-
func encrypt(msg, publicKey string) (string, error) {
45-
parsed, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKey))
48+
// encryptAES256 returns a random passphrase and corresponding bytes encrypted with it
49+
func encryptAES256(data []byte) ([]byte, []byte, error) {
50+
key := make([]byte, 32)
51+
if _, err := rand.Read(key); err != nil {
52+
return nil, nil, err
53+
}
54+
55+
n := len(data)
56+
buf := new(bytes.Buffer)
57+
if err := binary.Write(buf, binary.LittleEndian, uint64(n)); err != nil {
58+
return nil, nil, err
59+
}
60+
if _, err := buf.Write(data); err != nil {
61+
return nil, nil, err
62+
}
63+
64+
paddingN := aes.BlockSize - (buf.Len() % aes.BlockSize)
65+
if paddingN > 0 {
66+
padding := make([]byte, paddingN)
67+
if _, err := rand.Read(padding); err != nil {
68+
return nil, nil, err
69+
}
70+
if _, err := buf.Write(padding); err != nil {
71+
return nil, nil, err
72+
}
73+
}
74+
plaintext := buf.Bytes()
75+
76+
block, err := aes.NewCipher(key)
77+
if err != nil {
78+
return nil, nil, err
79+
}
80+
81+
ciphertext := make([]byte, aes.BlockSize+len(plaintext))
82+
iv := ciphertext[:aes.BlockSize]
83+
if _, err := rand.Read(iv); err != nil {
84+
return nil, nil, err
85+
}
86+
87+
mode := cipher.NewCBCEncrypter(block, iv)
88+
mode.CryptBlocks(ciphertext[aes.BlockSize:], plaintext)
89+
90+
return key, ciphertext, nil
91+
}
92+
93+
func decryptAES(key, ciphertext []byte) ([]byte, error) {
94+
block, err := aes.NewCipher(key)
95+
if err != nil {
96+
return nil, err
97+
}
98+
99+
if len(ciphertext) < aes.BlockSize {
100+
panic("ciphertext too short")
101+
}
102+
iv := ciphertext[:aes.BlockSize]
103+
ciphertext = ciphertext[aes.BlockSize:]
104+
105+
if len(ciphertext)%aes.BlockSize != 0 {
106+
panic("ciphertext is not a multiple of the block size")
107+
}
108+
109+
mode := cipher.NewCBCDecrypter(block, iv)
110+
111+
// works inplace when both args are the same
112+
mode.CryptBlocks(ciphertext, ciphertext)
113+
114+
buf := bytes.NewReader(ciphertext)
115+
var n uint64
116+
if err = binary.Read(buf, binary.LittleEndian, &n); err != nil {
117+
return nil, err
118+
}
119+
payload := make([]byte, n)
120+
if _, err = buf.Read(payload); err != nil {
121+
return nil, err
122+
}
123+
124+
return payload, nil
125+
}
126+
127+
func encrypt(msg, publicKey []byte) (string, error) {
128+
parsed, _, _, _, err := ssh.ParseAuthorizedKey(publicKey)
46129
if err != nil {
47130
return "", err
48131
}
@@ -56,35 +139,79 @@ func encrypt(msg, publicKey string) (string, error) {
56139
// Finally, we can convert back to an *rsa.PublicKey
57140
pub := pubCrypto.(*rsa.PublicKey)
58141

142+
if len(msg) <= 256 {
143+
// msg is small enough to only use OAEP encryption; this will result in less bytes to transfer.
144+
encryptedBytes, err := rsa.EncryptOAEP(
145+
sha256.New(),
146+
rand.Reader,
147+
pub,
148+
msg,
149+
nil)
150+
if err != nil {
151+
return "", err
152+
}
153+
if len(encryptedBytes) != 256 {
154+
panic(len(encryptedBytes))
155+
}
156+
return base64.StdEncoding.EncodeToString(encryptedBytes), nil
157+
}
158+
159+
// otherwise, encrypt using AES256
160+
161+
key, ciphertext, err := encryptAES256(msg)
162+
if err != nil {
163+
return "", err
164+
}
165+
59166
encryptedBytes, err := rsa.EncryptOAEP(
60167
sha256.New(),
61168
rand.Reader,
62169
pub,
63-
[]byte(msg),
170+
key,
64171
nil)
65172
if err != nil {
66173
return "", err
67174
}
68-
return base64.StdEncoding.EncodeToString(encryptedBytes), nil
175+
if len(encryptedBytes) != 256 {
176+
panic(len(encryptedBytes))
177+
}
178+
return base64.StdEncoding.EncodeToString(append(encryptedBytes, ciphertext...)), nil
69179
}
70180

71-
func decrypt(data, priv string) (string, error) {
181+
func decrypt(data, priv string) ([]byte, error) {
72182
data2, err := base64.StdEncoding.DecodeString(data)
73183
if err != nil {
74-
return "", err
184+
return nil, err
185+
}
186+
187+
if len(data2) < 256 {
188+
return nil, fmt.Errorf("not enough data to decrypt")
75189
}
76190

77191
block, _ := pem.Decode([]byte(priv))
78192
key, err := x509.ParsePKCS1PrivateKey(block.Bytes)
79193
if err != nil {
80-
return "", err
194+
return nil, err
81195
}
82196

83-
decrypted, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, key, data2, nil)
197+
oaepData := data2[:256]
198+
aesData := data2[256:]
199+
payload, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, key, oaepData, nil)
84200
if err != nil {
85-
return "", err
201+
return nil, err
86202
}
87-
return string(decrypted), nil
203+
204+
if len(aesData) == 0 {
205+
return payload, nil
206+
}
207+
208+
decryptedAESKey := payload
209+
decrypted, err := decryptAES(decryptedAESKey, aesData)
210+
if err != nil {
211+
return nil, err
212+
}
213+
214+
return decrypted, nil
88215
}
89216

90217
func test() {
@@ -94,8 +221,8 @@ func test() {
94221
}
95222
pub = strings.TrimPrefix(pub, "ssh-rsa ")
96223

97-
data := "hello test"
98-
encrypted, err := encrypt(data, "ssh-rsa "+pub)
224+
data := []byte("hello test")
225+
encrypted, err := encrypt(data, []byte("ssh-rsa "+pub))
99226
if err != nil {
100227
panic(err)
101228
}
@@ -105,7 +232,7 @@ func test() {
105232
panic(err)
106233
}
107234

108-
if data != data2 {
235+
if !bytes.Equal(data, data2) {
109236
panic("missmatch")
110237
}
111238
}
@@ -178,21 +305,21 @@ func main() {
178305

179306
data, err := ioutil.ReadAll(os.Stdin)
180307
if err != nil {
181-
fmt.Fprintf(os.Stderr, "failed while reading from stdin: %s", err.Error())
308+
fmt.Fprintf(os.Stderr, "failed while reading from stdin: %s\n", err.Error())
182309
os.Exit(1)
183310
}
184311

185312
if arg == "decrypt" {
186313
data2, err := decrypt(string(data), priv)
187314
if err != nil {
188-
fmt.Fprintf(os.Stderr, "failed while decrypting: %s", err.Error())
315+
fmt.Fprintf(os.Stderr, "failed while decrypting: %s\n", err.Error())
189316
os.Exit(1)
190317
}
191318
fmt.Printf("%s", data2)
192319
return
193320
}
194321

195-
encrypted, err := encrypt(string(data), "ssh-rsa "+arg)
322+
encrypted, err := encrypt(data, []byte("ssh-rsa "+arg))
196323
if err != nil {
197324
fmt.Fprintf(os.Stderr, "failed while encrypting: %s", err.Error())
198325
os.Exit(1)

0 commit comments

Comments
 (0)