1
1
package main
2
2
3
3
import (
4
+ "bytes"
5
+ "crypto/aes"
6
+ "crypto/cipher"
4
7
"crypto/rand"
5
8
"crypto/rsa"
6
9
"crypto/sha256"
7
10
"crypto/x509"
8
11
"encoding/base64"
12
+ "encoding/binary"
9
13
"encoding/pem"
10
14
"fmt"
11
15
"io/ioutil"
@@ -41,8 +45,87 @@ func generateKey() (string, string, error) {
41
45
return pubKeyStr , privKeyStr , nil
42
46
}
43
47
44
- func encrypt (msg , publicKey string ) (string , error ) {
45
- parsed , _ , _ , _ , err := ssh .ParseAuthorizedKey ([]byte (publicKey ))
48
+ // encryptAES256 returns a random passphrase and corresponding bytes encrypted with it
49
+ func encryptAES256 (data []byte ) ([]byte , []byte , error ) {
50
+ key := make ([]byte , 32 )
51
+ if _ , err := rand .Read (key ); err != nil {
52
+ return nil , nil , err
53
+ }
54
+
55
+ n := len (data )
56
+ buf := new (bytes.Buffer )
57
+ if err := binary .Write (buf , binary .LittleEndian , uint64 (n )); err != nil {
58
+ return nil , nil , err
59
+ }
60
+ if _ , err := buf .Write (data ); err != nil {
61
+ return nil , nil , err
62
+ }
63
+
64
+ paddingN := aes .BlockSize - (buf .Len () % aes .BlockSize )
65
+ if paddingN > 0 {
66
+ padding := make ([]byte , paddingN )
67
+ if _ , err := rand .Read (padding ); err != nil {
68
+ return nil , nil , err
69
+ }
70
+ if _ , err := buf .Write (padding ); err != nil {
71
+ return nil , nil , err
72
+ }
73
+ }
74
+ plaintext := buf .Bytes ()
75
+
76
+ block , err := aes .NewCipher (key )
77
+ if err != nil {
78
+ return nil , nil , err
79
+ }
80
+
81
+ ciphertext := make ([]byte , aes .BlockSize + len (plaintext ))
82
+ iv := ciphertext [:aes .BlockSize ]
83
+ if _ , err := rand .Read (iv ); err != nil {
84
+ return nil , nil , err
85
+ }
86
+
87
+ mode := cipher .NewCBCEncrypter (block , iv )
88
+ mode .CryptBlocks (ciphertext [aes .BlockSize :], plaintext )
89
+
90
+ return key , ciphertext , nil
91
+ }
92
+
93
+ func decryptAES (key , ciphertext []byte ) ([]byte , error ) {
94
+ block , err := aes .NewCipher (key )
95
+ if err != nil {
96
+ return nil , err
97
+ }
98
+
99
+ if len (ciphertext ) < aes .BlockSize {
100
+ panic ("ciphertext too short" )
101
+ }
102
+ iv := ciphertext [:aes .BlockSize ]
103
+ ciphertext = ciphertext [aes .BlockSize :]
104
+
105
+ if len (ciphertext )% aes .BlockSize != 0 {
106
+ panic ("ciphertext is not a multiple of the block size" )
107
+ }
108
+
109
+ mode := cipher .NewCBCDecrypter (block , iv )
110
+
111
+ // works inplace when both args are the same
112
+ mode .CryptBlocks (ciphertext , ciphertext )
113
+
114
+ buf := bytes .NewReader (ciphertext )
115
+ var n uint64
116
+ if err = binary .Read (buf , binary .LittleEndian , & n ); err != nil {
117
+ return nil , err
118
+ }
119
+ payload := make ([]byte , n )
120
+ if _ , err = buf .Read (payload ); err != nil {
121
+ return nil , err
122
+ }
123
+
124
+ return payload , nil
125
+ }
126
+
127
+ func encrypt (msg , publicKey []byte ) (string , error ) {
128
+ parsed , _ , _ , _ , err := ssh .ParseAuthorizedKey (publicKey )
46
129
if err != nil {
47
130
return "" , err
48
131
}
@@ -56,35 +139,79 @@ func encrypt(msg, publicKey string) (string, error) {
56
139
// Finally, we can convert back to an *rsa.PublicKey
57
140
pub := pubCrypto .(* rsa.PublicKey )
58
141
142
+ if len (msg ) <= 256 {
143
+ // msg is small enough to only use OAEP encryption; this will result in less bytes to transfer.
144
+ encryptedBytes , err := rsa .EncryptOAEP (
145
+ sha256 .New (),
146
+ rand .Reader ,
147
+ pub ,
148
+ msg ,
149
+ nil )
150
+ if err != nil {
151
+ return "" , err
152
+ }
153
+ if len (encryptedBytes ) != 256 {
154
+ panic (len (encryptedBytes ))
155
+ }
156
+ return base64 .StdEncoding .EncodeToString (encryptedBytes ), nil
157
+ }
158
+
159
+ // otherwise, encrypt using AES256
160
+
161
+ key , ciphertext , err := encryptAES256 (msg )
162
+ if err != nil {
163
+ return "" , err
164
+ }
165
+
59
166
encryptedBytes , err := rsa .EncryptOAEP (
60
167
sha256 .New (),
61
168
rand .Reader ,
62
169
pub ,
63
- [] byte ( msg ) ,
170
+ key ,
64
171
nil )
65
172
if err != nil {
66
173
return "" , err
67
174
}
68
- return base64 .StdEncoding .EncodeToString (encryptedBytes ), nil
175
+ if len (encryptedBytes ) != 256 {
176
+ panic (len (encryptedBytes ))
177
+ }
178
+ return base64 .StdEncoding .EncodeToString (append (encryptedBytes , ciphertext ... )), nil
69
179
}
70
180
71
- func decrypt (data , priv string ) (string , error ) {
181
+ func decrypt (data , priv string ) ([] byte , error ) {
72
182
data2 , err := base64 .StdEncoding .DecodeString (data )
73
183
if err != nil {
74
- return "" , err
184
+ return nil , err
185
+ }
186
+
187
+ if len (data2 ) < 256 {
188
+ return nil , fmt .Errorf ("not enough data to decrypt" )
75
189
}
76
190
77
191
block , _ := pem .Decode ([]byte (priv ))
78
192
key , err := x509 .ParsePKCS1PrivateKey (block .Bytes )
79
193
if err != nil {
80
- return "" , err
194
+ return nil , err
81
195
}
82
196
83
- decrypted , err := rsa .DecryptOAEP (sha256 .New (), rand .Reader , key , data2 , nil )
197
+ oaepData := data2 [:256 ]
198
+ aesData := data2 [256 :]
199
+ payload , err := rsa .DecryptOAEP (sha256 .New (), rand .Reader , key , oaepData , nil )
84
200
if err != nil {
85
- return "" , err
201
+ return nil , err
86
202
}
87
- return string (decrypted ), nil
203
+
204
+ if len (aesData ) == 0 {
205
+ return payload , nil
206
+ }
207
+
208
+ decryptedAESKey := payload
209
+ decrypted , err := decryptAES (decryptedAESKey , aesData )
210
+ if err != nil {
211
+ return nil , err
212
+ }
213
+
214
+ return decrypted , nil
88
215
}
89
216
90
217
func test () {
@@ -94,8 +221,8 @@ func test() {
94
221
}
95
222
pub = strings .TrimPrefix (pub , "ssh-rsa " )
96
223
97
- data := "hello test"
98
- encrypted , err := encrypt (data , "ssh-rsa " + pub )
224
+ data := [] byte ( "hello test" )
225
+ encrypted , err := encrypt (data , [] byte ( "ssh-rsa " + pub ) )
99
226
if err != nil {
100
227
panic (err )
101
228
}
@@ -105,7 +232,7 @@ func test() {
105
232
panic (err )
106
233
}
107
234
108
- if data != data2 {
235
+ if ! bytes . Equal ( data , data2 ) {
109
236
panic ("missmatch" )
110
237
}
111
238
}
@@ -178,21 +305,21 @@ func main() {
178
305
179
306
data , err := ioutil .ReadAll (os .Stdin )
180
307
if err != nil {
181
- fmt .Fprintf (os .Stderr , "failed while reading from stdin: %s" , err .Error ())
308
+ fmt .Fprintf (os .Stderr , "failed while reading from stdin: %s\n " , err .Error ())
182
309
os .Exit (1 )
183
310
}
184
311
185
312
if arg == "decrypt" {
186
313
data2 , err := decrypt (string (data ), priv )
187
314
if err != nil {
188
- fmt .Fprintf (os .Stderr , "failed while decrypting: %s" , err .Error ())
315
+ fmt .Fprintf (os .Stderr , "failed while decrypting: %s\n " , err .Error ())
189
316
os .Exit (1 )
190
317
}
191
318
fmt .Printf ("%s" , data2 )
192
319
return
193
320
}
194
321
195
- encrypted , err := encrypt (string ( data ), "ssh-rsa " + arg )
322
+ encrypted , err := encrypt (data , [] byte ( "ssh-rsa " + arg ) )
196
323
if err != nil {
197
324
fmt .Fprintf (os .Stderr , "failed while encrypting: %s" , err .Error ())
198
325
os .Exit (1 )
0 commit comments