@@ -77,6 +77,8 @@ const (
77
77
EncryptDekAlgorithm = "encrypt.dek.algorithm"
78
78
// EncryptDekExpiryDays represents dek expiry days
79
79
EncryptDekExpiryDays = "encrypt.dek.expiry.days"
80
+ // EncryptAlternateKmsKeyIDs represents alternate kms key ids
81
+ EncryptAlternateKmsKeyIDs = "encrypt.alternate.kms.key.ids"
80
82
81
83
// Aes128Gcm represents AES128_GCM algorithm
82
84
Aes128Gcm = "AES128_GCM"
@@ -394,10 +396,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int)
394
396
}
395
397
var encryptedDek []byte
396
398
if ! f .Kek .Shared {
397
- primitive , err = getAead (f .Executor .Config , f .Kek )
398
- if err != nil {
399
- return nil , err
400
- }
399
+ primitive = & AeadWrapper {f .Executor .Config , f .Kek , getKmsKeyIDs (f .Executor .Config , f .Kek )}
401
400
// Generate new dek
402
401
keyData , err := registry .NewKeyData (f .Cryptor .KeyTemplate )
403
402
if err != nil {
@@ -431,10 +430,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int)
431
430
}
432
431
if keyBytes == nil {
433
432
if primitive == nil {
434
- primitive , err = getAead (f .Executor .Config , f .Kek )
435
- if err != nil {
436
- return nil , err
437
- }
433
+ primitive = & AeadWrapper {f .Executor .Config , f .Kek , getKmsKeyIDs (f .Executor .Config , f .Kek )}
438
434
}
439
435
encryptedDek , err := f .Executor .Client .GetDekEncryptedKeyMaterialBytes (dek )
440
436
if err != nil {
@@ -629,8 +625,79 @@ func extractVersion(ciphertext []byte) (int, error) {
629
625
return int (version ), nil
630
626
}
631
627
632
- func getAead (config map [string ]string , kek deks.Kek ) (tink.AEAD , error ) {
633
- kekURL := kek .KmsType + "://" + kek .KmsKeyID
628
+ func getKmsKeyIDs (config map [string ]string , kek deks.Kek ) []string {
629
+ kmsKeyIDs := []string {kek .KmsKeyID }
630
+ var alternateKmsKeyIDs []string
631
+ if kek .KmsProps != nil {
632
+ if ids , ok := kek .KmsProps [EncryptAlternateKmsKeyIDs ]; ok {
633
+ alternateKmsKeyIDs = strings .Split (ids , "," )
634
+ }
635
+ }
636
+ if alternateKmsKeyIDs == nil {
637
+ if ids , ok := config [EncryptAlternateKmsKeyIDs ]; ok {
638
+ alternateKmsKeyIDs = strings .Split (ids , "," )
639
+ }
640
+ }
641
+ if alternateKmsKeyIDs != nil {
642
+ for _ , id := range alternateKmsKeyIDs {
643
+ id = strings .TrimSpace (id )
644
+ if len (id ) > 0 {
645
+ kmsKeyIDs = append (kmsKeyIDs , id )
646
+ }
647
+ }
648
+ }
649
+ return kmsKeyIDs
650
+ }
651
+
652
+ // AeadWrapper is a wrapper for AEAD
653
+ type AeadWrapper struct {
654
+ Config map [string ]string
655
+ Kek deks.Kek
656
+ KmsKeyIds []string
657
+ }
658
+
659
+ // Encrypt encrypts plaintext with associatedData as associated data.
660
+ func (a * AeadWrapper ) Encrypt (plaintext , associatedData []byte ) ([]byte , error ) {
661
+ var aead tink.AEAD
662
+ var err error
663
+ var ciphertext []byte
664
+ for _ , kmsKeyID := range a .KmsKeyIds {
665
+ aead , err = getAead (a .Config , a .Kek .KmsType , kmsKeyID )
666
+ if err != nil {
667
+ log .Printf ("WARN: failed to get AEAD with %s: %v\n " , kmsKeyID , err )
668
+ continue
669
+ }
670
+ ciphertext , err = aead .Encrypt (plaintext , associatedData )
671
+ if err == nil {
672
+ return ciphertext , nil
673
+ }
674
+ log .Printf ("WARN: failed to encrypt with %s: %v\n " , kmsKeyID , err )
675
+ }
676
+ return nil , err
677
+ }
678
+
679
+ // Decrypt decrypts ciphertext with associatedData as associated data.
680
+ func (a * AeadWrapper ) Decrypt (ciphertext , associatedData []byte ) ([]byte , error ) {
681
+ var aead tink.AEAD
682
+ var err error
683
+ var plaintext []byte
684
+ for _ , kmsKeyID := range a .KmsKeyIds {
685
+ aead , err = getAead (a .Config , a .Kek .KmsType , kmsKeyID )
686
+ if err != nil {
687
+ log .Printf ("WARN: failed to get AEAD with %s: %v\n " , kmsKeyID , err )
688
+ continue
689
+ }
690
+ plaintext , err = aead .Decrypt (ciphertext , associatedData )
691
+ if err == nil {
692
+ return plaintext , nil
693
+ }
694
+ log .Printf ("WARN: failed to decrypt with %s: %v\n " , kmsKeyID , err )
695
+ }
696
+ return nil , err
697
+ }
698
+
699
+ func getAead (config map [string ]string , kmsType string , kmsKeyID string ) (tink.AEAD , error ) {
700
+ kekURL := kmsType + "://" + kmsKeyID
634
701
kmsClient , err := getKMSClient (config , kekURL )
635
702
if err != nil {
636
703
return nil , err
0 commit comments