From e368fa2f7b8351b93e1ef816f802e04e51c5d214 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Tue, 5 Aug 2025 15:31:08 -0700 Subject: [PATCH 1/3] First cut --- .../rules/encryption/encrypt_executor.go | 79 ++++++++++++++++--- 1 file changed, 69 insertions(+), 10 deletions(-) diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index 4c34ee920..b6fd6540a 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -77,6 +77,8 @@ const ( EncryptDekAlgorithm = "encrypt.dek.algorithm" // EncryptDekExpiryDays represents dek expiry days EncryptDekExpiryDays = "encrypt.dek.expiry.days" + // EncryptAlternateKmsKeyIDs represents alternate kms key ids + EncryptAlternateKmsKeyIDs = "encrypt.alternate.kms.key.ids" // Aes128Gcm represents AES128_GCM algorithm Aes128Gcm = "AES128_GCM" @@ -394,10 +396,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } var encryptedDek []byte if !f.Kek.Shared { - primitive, err = getAead(f.Executor.Config, f.Kek) - if err != nil { - return nil, err - } + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Kek)} // Generate new dek keyData, err := registry.NewKeyData(f.Cryptor.KeyTemplate) if err != nil { @@ -431,10 +430,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } if keyBytes == nil { if primitive == nil { - primitive, err = getAead(f.Executor.Config, f.Kek) - if err != nil { - return nil, err - } + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Kek)} } encryptedDek, err := f.Executor.Client.GetDekEncryptedKeyMaterialBytes(dek) if err != nil { @@ -629,8 +625,71 @@ func extractVersion(ciphertext []byte) (int, error) { return int(version), nil } -func getAead(config map[string]string, kek deks.Kek) (tink.AEAD, error) { - kekURL := kek.KmsType + "://" + kek.KmsKeyID +func getKmsKeyIDs(kek deks.Kek) []string { + kmsKeyIDs := []string{kek.KmsKeyID} + if kek.KmsProps != nil { + if alternateKmsKeyIDs, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok { + ids := strings.Split(alternateKmsKeyIDs, ",") + for _, id := range ids { + id = strings.TrimSpace(id) + if len(id) > 0 { + kmsKeyIDs = append(kmsKeyIDs, id) + } + } + } + } + return kmsKeyIDs +} + +// AeadWrapper is a wrapper for AEAD +type AeadWrapper struct { + Config map[string]string + Kek deks.Kek + KmsKeyIds []string +} + +// Encrypt encrypts plaintext with associatedData as associated data. +func (a *AeadWrapper) Encrypt(plaintext, associatedData []byte) ([]byte, error) { + var aead tink.AEAD + var err error + var ciphertext []byte + for _, kmsKeyID := range a.KmsKeyIds { + aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID) + if err != nil { + log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err) + continue + } + ciphertext, err = aead.Encrypt(plaintext, associatedData) + if err == nil { + return ciphertext, nil + } + log.Printf("WARN: failed to encrypt with %s: %v\n", kmsKeyID, err) + } + return nil, err +} + +// Decrypt decrypts ciphertext with associatedData as associated data. +func (a *AeadWrapper) Decrypt(ciphertext, associatedData []byte) ([]byte, error) { + var aead tink.AEAD + var err error + var plaintext []byte + for _, kmsKeyID := range a.KmsKeyIds { + aead, err = getAead(a.Config, a.Kek.KmsType, kmsKeyID) + if err != nil { + log.Printf("WARN: failed to get AEAD with %s: %v\n", kmsKeyID, err) + continue + } + plaintext, err = aead.Decrypt(ciphertext, associatedData) + if err == nil { + return plaintext, nil + } + log.Printf("WARN: failed to decrypt with %s: %v\n", kmsKeyID, err) + } + return nil, err +} + +func getAead(config map[string]string, kmsType string, kmsKeyID string) (tink.AEAD, error) { + kekURL := kmsType + "://" + kmsKeyID kmsClient, err := getKMSClient(config, kekURL) if err != nil { return nil, err From 3426bc6d1dd10cf49bb1b95e122b6d74d56de9ad Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Fri, 8 Aug 2025 15:59:29 -0700 Subject: [PATCH 2/3] Add test --- .../rules/encryption/encrypt_executor.go | 26 ++++--- schemaregistry/serde/avrov2/avro_test.go | 70 +++++++++++++++++++ 2 files changed, 87 insertions(+), 9 deletions(-) diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index b6fd6540a..e79967a9d 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -396,7 +396,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } var encryptedDek []byte if !f.Kek.Shared { - primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Kek)} + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)} // Generate new dek keyData, err := registry.NewKeyData(f.Cryptor.KeyTemplate) if err != nil { @@ -430,7 +430,7 @@ func (f *ExecutorTransform) getOrCreateDek(ctx serde.RuleContext, version *int) } if keyBytes == nil { if primitive == nil { - primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Kek)} + primitive = &AeadWrapper{f.Executor.Config, f.Kek, getKmsKeyIDs(f.Executor.Config, f.Kek)} } encryptedDek, err := f.Executor.Client.GetDekEncryptedKeyMaterialBytes(dek) if err != nil { @@ -625,16 +625,24 @@ func extractVersion(ciphertext []byte) (int, error) { return int(version), nil } -func getKmsKeyIDs(kek deks.Kek) []string { +func getKmsKeyIDs(config map[string]string, kek deks.Kek) []string { kmsKeyIDs := []string{kek.KmsKeyID} + var ids []string if kek.KmsProps != nil { if alternateKmsKeyIDs, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok { - ids := strings.Split(alternateKmsKeyIDs, ",") - for _, id := range ids { - id = strings.TrimSpace(id) - if len(id) > 0 { - kmsKeyIDs = append(kmsKeyIDs, id) - } + ids = strings.Split(alternateKmsKeyIDs, ",") + } + } + if ids == nil { + if alternateKmsKeyIDs, ok := config[EncryptAlternateKmsKeyIDs]; ok { + ids = strings.Split(alternateKmsKeyIDs, ",") + } + } + if ids != nil { + for _, id := range ids { + id = strings.TrimSpace(id) + if len(id) > 0 { + kmsKeyIDs = append(kmsKeyIDs, id) } } } diff --git a/schemaregistry/serde/avrov2/avro_test.go b/schemaregistry/serde/avrov2/avro_test.go index fae815a17..c4f3f5cd2 100644 --- a/schemaregistry/serde/avrov2/avro_test.go +++ b/schemaregistry/serde/avrov2/avro_test.go @@ -1654,6 +1654,76 @@ func TestAvroSerdePayloadEncryption(t *testing.T) { serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) } +func TestAvroSerdeEncryptionAlternateKeks(t *testing.T) { + serde.MaybeFail = serde.InitFailFunc(t) + var err error + + conf := schemaregistry.NewConfig("mock://") + + client, err := schemaregistry.NewClient(conf) + serde.MaybeFail("Schema Registry configuration", err) + + serConfig := NewSerializerConfig() + serConfig.AutoRegisterSchemas = false + serConfig.UseLatestVersion = true + serConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + "encrypt.alternate.kms.key.ids": "mykey2,mykey3", + } + ser, err := NewSerializer(client, serde.ValueSerde, serConfig) + serde.MaybeFail("Serializer configuration", err) + + encRule := schemaregistry.Rule{ + Name: "test-encrypt", + Kind: "TRANSFORM", + Mode: "WRITEREAD", + Type: "ENCRYPT_PAYLOAD", + Params: map[string]string{ + "encrypt.kek.name": "kek1", + "encrypt.kms.type": "local-kms", + "encrypt.kms.key.id": "mykey", + }, + OnFailure: "ERROR,NONE", + } + ruleSet := schemaregistry.RuleSet{ + EncodingRules: []schemaregistry.Rule{encRule}, + } + + info := schemaregistry.SchemaInfo{ + Schema: demoSchema, + SchemaType: "AVRO", + RuleSet: &ruleSet, + } + + id, err := client.Register("topic1-value", info, false) + serde.MaybeFail("Schema registration", err) + if id <= 0 { + t.Errorf("Expected valid schema id, found %d", id) + } + + obj := DemoSchema{} + obj.IntField = 123 + obj.DoubleField = 45.67 + obj.StringField = "hi" + obj.BoolField = true + obj.BytesField = []byte{1, 2} + + bytes, err := ser.Serialize("topic1", &obj) + serde.MaybeFail("serialization", err) + + deserConfig := NewDeserializerConfig() + deserConfig.RuleConfig = map[string]string{ + "secret": "mysecret", + } + deser, err := NewDeserializer(client, serde.ValueSerde, deserConfig) + serde.MaybeFail("Deserializer configuration", err) + deser.Client = ser.Client + deser.MessageFactory = testMessageFactory + + newobj, err := deser.Deserialize("topic1", bytes) + serde.MaybeFail("deserialization", err, serde.Expect(newobj, &obj)) +} + func TestAvroSerdeEncryptionDeterministic(t *testing.T) { serde.MaybeFail = serde.InitFailFunc(t) var err error From d979c645a52ae1bad571a854bb1ad716ce374334 Mon Sep 17 00:00:00 2001 From: Robert Yokota Date: Fri, 8 Aug 2025 16:11:25 -0700 Subject: [PATCH 3/3] Add test --- examples/docker_aws_lambda_example/go.sum | 1 + .../rules/encryption/encrypt_executor.go | 16 ++++++++-------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/docker_aws_lambda_example/go.sum b/examples/docker_aws_lambda_example/go.sum index 9be073518..db85de32c 100644 --- a/examples/docker_aws_lambda_example/go.sum +++ b/examples/docker_aws_lambda_example/go.sum @@ -53,6 +53,7 @@ github.com/confluentinc/confluent-kafka-go/v2 v2.4.0 h1:NbOku86JJlsRJPJKE0snNsz6 github.com/confluentinc/confluent-kafka-go/v2 v2.4.0/go.mod h1:E1dEQy50ZLfqs7T9luxz0rLxaeFZJZE92XvApJOr/Rk= github.com/confluentinc/confluent-kafka-go/v2 v2.5.0/go.mod h1:Hyo+IIQ/tmsfkOcRP8T6VlSeOW3T33v0Me8Xvq4u90Y= github.com/confluentinc/confluent-kafka-go/v2 v2.5.3/go.mod h1:QxYLPRKR1MVlkXCCjzjjrpXb0VyFNfVaZXi0obZykJ0= +github.com/confluentinc/confluent-kafka-go/v2 v2.11.0/go.mod h1:hScqtFIGUI1wqHIgM3mjoqEou4VweGGGX7dMpcUKves= github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= github.com/containerd/containerd v1.7.12 h1:+KQsnv4VnzyxWcfO9mlxxELaoztsDEjOuCMPAuPqgU0= diff --git a/schemaregistry/rules/encryption/encrypt_executor.go b/schemaregistry/rules/encryption/encrypt_executor.go index e79967a9d..996260e7a 100644 --- a/schemaregistry/rules/encryption/encrypt_executor.go +++ b/schemaregistry/rules/encryption/encrypt_executor.go @@ -627,19 +627,19 @@ func extractVersion(ciphertext []byte) (int, error) { func getKmsKeyIDs(config map[string]string, kek deks.Kek) []string { kmsKeyIDs := []string{kek.KmsKeyID} - var ids []string + var alternateKmsKeyIDs []string if kek.KmsProps != nil { - if alternateKmsKeyIDs, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok { - ids = strings.Split(alternateKmsKeyIDs, ",") + if ids, ok := kek.KmsProps[EncryptAlternateKmsKeyIDs]; ok { + alternateKmsKeyIDs = strings.Split(ids, ",") } } - if ids == nil { - if alternateKmsKeyIDs, ok := config[EncryptAlternateKmsKeyIDs]; ok { - ids = strings.Split(alternateKmsKeyIDs, ",") + if alternateKmsKeyIDs == nil { + if ids, ok := config[EncryptAlternateKmsKeyIDs]; ok { + alternateKmsKeyIDs = strings.Split(ids, ",") } } - if ids != nil { - for _, id := range ids { + if alternateKmsKeyIDs != nil { + for _, id := range alternateKmsKeyIDs { id = strings.TrimSpace(id) if len(id) > 0 { kmsKeyIDs = append(kmsKeyIDs, id)