Skip to content

Commit 4c815e5

Browse files
authored
DGS-21595 Allow alternate KMS key IDs on a KEK (#1460)
* First cut * Add test * Add test
1 parent b348372 commit 4c815e5

File tree

3 files changed

+148
-10
lines changed

3 files changed

+148
-10
lines changed

examples/docker_aws_lambda_example/go.sum

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ github.com/confluentinc/confluent-kafka-go/v2 v2.4.0 h1:NbOku86JJlsRJPJKE0snNsz6
5353
github.com/confluentinc/confluent-kafka-go/v2 v2.4.0/go.mod h1:E1dEQy50ZLfqs7T9luxz0rLxaeFZJZE92XvApJOr/Rk=
5454
github.com/confluentinc/confluent-kafka-go/v2 v2.5.0/go.mod h1:Hyo+IIQ/tmsfkOcRP8T6VlSeOW3T33v0Me8Xvq4u90Y=
5555
github.com/confluentinc/confluent-kafka-go/v2 v2.5.3/go.mod h1:QxYLPRKR1MVlkXCCjzjjrpXb0VyFNfVaZXi0obZykJ0=
56+
github.com/confluentinc/confluent-kafka-go/v2 v2.11.0/go.mod h1:hScqtFIGUI1wqHIgM3mjoqEou4VweGGGX7dMpcUKves=
5657
github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw=
5758
github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U=
5859
github.com/containerd/containerd v1.7.12 h1:+KQsnv4VnzyxWcfO9mlxxELaoztsDEjOuCMPAuPqgU0=

schemaregistry/rules/encryption/encrypt_executor.go

Lines changed: 77 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,8 @@ const (
7777
EncryptDekAlgorithm = "encrypt.dek.algorithm"
7878
// EncryptDekExpiryDays represents dek expiry days
7979
EncryptDekExpiryDays = "encrypt.dek.expiry.days"
80+
// EncryptAlternateKmsKeyIDs represents alternate kms key ids
81+
EncryptAlternateKmsKeyIDs = "encrypt.alternate.kms.key.ids"
8082

8183
// Aes128Gcm represents AES128_GCM algorithm
8284
Aes128Gcm = "AES128_GCM"
@@ -394,10 +396,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int)
394396
}
395397
var encryptedDek []byte
396398
if !f.Kek.Shared {
397-
primitive, err = getAead(f.Executor.Config, f.Kek)
398-
if err != nil {
399-
return nil, err
400-
}
399+
primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)}
401400
// Generate new dek
402401
keyData, err := registry.NewKeyData(f.Cryptor.KeyTemplate)
403402
if err != nil {
@@ -431,10 +430,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int)
431430
}
432431
if keyBytes == nil {
433432
if primitive == nil {
434-
primitive, err = getAead(f.Executor.Config, f.Kek)
435-
if err != nil {
436-
return nil, err
437-
}
433+
primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)}
438434
}
439435
encryptedDek, err := f.Executor.Client.GetDekEncryptedKeyMaterialBytes(dek)
440436
if err != nil {
@@ -629,8 +625,79 @@ func extractVersion(ciphertext []byte) (int, error) {
629625
return int(version), nil
630626
}
631627

632-
func getAead(config map[string]string, kek deks.Kek) (tink.AEAD, error) {
633-
kekURL := kek.KmsType + "://" + kek.KmsKeyID
628+
func getKmsKeyIDs(config map[string]string, kek deks.Kek) []string {
629+
kmsKeyIDs := []string{kek.KmsKeyID}
630+
var alternateKmsKeyIDs []string
631+
if kek.KmsProps != nil {
632+
if ids, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok {
633+
alternateKmsKeyIDs = strings.Split(ids, ",")
634+
}
635+
}
636+
if alternateKmsKeyIDs == nil {
637+
if ids, ok := config[EncryptAlternateKmsKeyIDs]; ok {
638+
alternateKmsKeyIDs = strings.Split(ids, ",")
639+
}
640+
}
641+
if alternateKmsKeyIDs != nil {
642+
for _, id := range alternateKmsKeyIDs {
643+
id = strings.TrimSpace(id)
644+
if len(id) > 0 {
645+
kmsKeyIDs = append(kmsKeyIDs, id)
646+
}
647+
}
648+
}
649+
return kmsKeyIDs
650+
}
651+
652+
// AeadWrapper is a wrapper for AEAD
653+
type AeadWrapper struct {
654+
Config map[string]string
655+
Kek deks.Kek
656+
KmsKeyIds []string
657+
}
658+
659+
// Encrypt encrypts plaintext with associatedData as associated data.
660+
func (a *AeadWrapper) Encrypt(plaintext, associatedData []byte) ([]byte, error) {
661+
var aead tink.AEAD
662+
var err error
663+
var ciphertext []byte
664+
for _, kmsKeyID := range a.KmsKeyIds {
665+
aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID)
666+
if err != nil {
667+
log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err)
668+
continue
669+
}
670+
ciphertext, err = aead.Encrypt(plaintext, associatedData)
671+
if err == nil {
672+
return ciphertext, nil
673+
}
674+
log.Printf("WARN: failed to encrypt with %s: %v\n", kmsKeyID, err)
675+
}
676+
return nil, err
677+
}
678+
679+
// Decrypt decrypts ciphertext with associatedData as associated data.
680+
func (a *AeadWrapper) Decrypt(ciphertext, associatedData []byte) ([]byte, error) {
681+
var aead tink.AEAD
682+
var err error
683+
var plaintext []byte
684+
for _, kmsKeyID := range a.KmsKeyIds {
685+
aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID)
686+
if err != nil {
687+
log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err)
688+
continue
689+
}
690+
plaintext, err = aead.Decrypt(ciphertext, associatedData)
691+
if err == nil {
692+
return plaintext, nil
693+
}
694+
log.Printf("WARN: failed to decrypt with %s: %v\n", kmsKeyID, err)
695+
}
696+
return nil, err
697+
}
698+
699+
func getAead(config map[string]string, kmsType string, kmsKeyID string) (tink.AEAD, error) {
700+
kekURL := kmsType + "://" + kmsKeyID
634701
kmsClient, err := getKMSClient(config, kekURL)
635702
if err != nil {
636703
return nil, err

schemaregistry/serde/avrov2/avro_test.go

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1654,6 +1654,76 @@ func TestAvroSerdePayloadEncryption(t *testing.T) {
16541654
serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj))
16551655
}
16561656

1657+
func TestAvroSerdeEncryptionAlternateKeks(t *testing.T) {
1658+
serde.MaybeFail = serde.InitFailFunc(t)
1659+
var err error
1660+
1661+
conf := schemaregistry.NewConfig("mock://")
1662+
1663+
client, err := schemaregistry.NewClient(conf)
1664+
serde.MaybeFail("Schema Registry configuration", err)
1665+
1666+
serConfig := NewSerializerConfig()
1667+
serConfig.AutoRegisterSchemas = false
1668+
serConfig.UseLatestVersion = true
1669+
serConfig.RuleConfig = map[string]string{
1670+
"secret": "mysecret",
1671+
"encrypt.alternate.kms.key.ids": "mykey2,mykey3",
1672+
}
1673+
ser, err := NewSerializer(client, serde.ValueSerde, serConfig)
1674+
serde.MaybeFail("Serializer configuration", err)
1675+
1676+
encRule := schemaregistry.Rule{
1677+
Name: "test-encrypt",
1678+
Kind: "TRANSFORM",
1679+
Mode: "WRITEREAD",
1680+
Type: "ENCRYPT_PAYLOAD",
1681+
Params: map[string]string{
1682+
"encrypt.kek.name": "kek1",
1683+
"encrypt.kms.type": "local-kms",
1684+
"encrypt.kms.key.id": "mykey",
1685+
},
1686+
OnFailure: "ERROR,NONE",
1687+
}
1688+
ruleSet := schemaregistry.RuleSet{
1689+
EncodingRules: []schemaregistry.Rule{encRule},
1690+
}
1691+
1692+
info := schemaregistry.SchemaInfo{
1693+
Schema: demoSchema,
1694+
SchemaType: "AVRO",
1695+
RuleSet: &ruleSet,
1696+
}
1697+
1698+
id, err := client.Register("topic1-value", info, false)
1699+
serde.MaybeFail("Schema registration", err)
1700+
if id <= 0 {
1701+
t.Errorf("Expected valid schema id, found %d", id)
1702+
}
1703+
1704+
obj := DemoSchema{}
1705+
obj.IntField = 123
1706+
obj.DoubleField = 45.67
1707+
obj.StringField = "hi"
1708+
obj.BoolField = true
1709+
obj.BytesField = []byte{1, 2}
1710+
1711+
bytes, err := ser.Serialize("topic1", &obj)
1712+
serde.MaybeFail("serialization", err)
1713+
1714+
deserConfig := NewDeserializerConfig()
1715+
deserConfig.RuleConfig = map[string]string{
1716+
"secret": "mysecret",
1717+
}
1718+
deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig)
1719+
serde.MaybeFail("Deserializer configuration", err)
1720+
deser.Client = ser.Client
1721+
deser.MessageFactory = testMessageFactory
1722+
1723+
newobj, err := deser.Deserialize("topic1", bytes)
1724+
serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj))
1725+
}
1726+
16571727
func TestAvroSerdeEncryptionDeterministic(t *testing.T) {
16581728
serde.MaybeFail = serde.InitFailFunc(t)
16591729
var err error

0 commit comments

Comments
 (0)