Skip to content

Commit cdbac88

Browse files
committed
feat!: adds context cancellation propagation to plugin operations
BREAKING_CHANGE: policy.Provider interface method updates to accept context Signed-off-by: Jennifer Power <barnabei.jennifer@gmail.com>
1 parent d11ee79 commit cdbac88

File tree

17 files changed

+175
-134
lines changed

17 files changed

+175
-134
lines changed

cmd/c2pcli/cli/subcommands/config.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"os"
1313
"path/filepath"
14+
"time"
1415

1516
oscalTypes "github.com/defenseunicorns/go-oscal/src/types/oscal-1-1-3"
1617
"github.com/oscal-compass/oscal-sdk-go/models"
@@ -23,6 +24,9 @@ import (
2324
"github.com/oscal-compass/compliance-to-policy-go/v2/framework/actions"
2425
)
2526

27+
// Plugin running times might be highly variable this is maximum timeout value.
28+
var pluginTimeout = 5 * time.Minute
29+
2630
// Config returns a populated C2PConfig for the CLI to use.
2731
func Config(option *Options) (*framework.C2PConfig, error) {
2832
c2pConfig := framework.DefaultConfig()

cmd/c2pcli/cli/subcommands/oscal2policy.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,17 @@ func runOSCAL2Policy(ctx context.Context, option *Options) error {
7676
var configSelections framework.PluginConfig = func(pluginID plugin.ID) map[string]string {
7777
return option.Plugins[pluginID.String()]
7878
}
79-
launchedPlugins, err := manager.LaunchPolicyPlugins(foundPlugins, configSelections)
79+
launchedPlugins, err := manager.LaunchPolicyPlugins(ctx, foundPlugins, configSelections)
8080
// Defer clean before returning an error to avoid unterminated processes
8181
defer manager.Clean()
8282
if err != nil {
8383
return err
8484
}
8585

86-
err = actions.GeneratePolicy(ctx, inputContext, launchedPlugins)
86+
pluginCtx, cancel := context.WithTimeout(ctx, pluginTimeout)
87+
defer cancel()
88+
89+
err = actions.GeneratePolicy(pluginCtx, inputContext, launchedPlugins)
8790
if err != nil {
8891
return err
8992
}

cmd/c2pcli/cli/subcommands/result2oscal.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,14 +83,17 @@ func runResult2Policy(ctx context.Context, option *Options) error {
8383
var configSelections framework.PluginConfig = func(pluginID plugin.ID) map[string]string {
8484
return option.Plugins[pluginID.String()]
8585
}
86-
launchedPlugins, err := manager.LaunchPolicyPlugins(foundPlugins, configSelections)
86+
launchedPlugins, err := manager.LaunchPolicyPlugins(ctx, foundPlugins, configSelections)
8787
// Defer clean before returning an error to avoid unterminated processes
8888
defer manager.Clean()
8989
if err != nil {
9090
return err
9191
}
9292

93-
results, err := actions.AggregateResults(ctx, inputContext, launchedPlugins)
93+
pluginCtx, cancel := context.WithTimeout(ctx, pluginTimeout)
94+
defer cancel()
95+
96+
results, err := actions.AggregateResults(pluginCtx, inputContext, launchedPlugins)
9497
if err != nil {
9598
return err
9699
}

cmd/c2pcli/main.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,20 @@ limitations under the License.
1717
package main
1818

1919
import (
20+
"context"
2021
"os"
22+
"os/signal"
23+
"syscall"
2124

2225
"github.com/oscal-compass/compliance-to-policy-go/v2/cmd/c2pcli/cli"
2326
)
2427

2528
func main() {
2629
command := cli.New()
27-
err := command.Execute()
30+
ctx, cancel := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM)
31+
defer cancel()
32+
33+
err := command.ExecuteContext(ctx)
2834
if err != nil {
2935
os.Exit(1)
3036
}

cmd/kyverno-plugin/server/server.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package server
77

88
import (
9+
"context"
910
"errors"
1011
"fmt"
1112

@@ -34,14 +35,14 @@ func NewPlugin() *Plugin {
3435
return &Plugin{}
3536
}
3637

37-
func (p *Plugin) Configure(m map[string]string) error {
38+
func (p *Plugin) Configure(_ context.Context, m map[string]string) error {
3839
if err := mapstructure.Decode(m, &p.config); err != nil {
3940
return errors.New("error decoding configuration")
4041
}
4142
return p.config.Validate()
4243
}
4344

44-
func (p *Plugin) Generate(pl policy.Policy) error {
45+
func (p *Plugin) Generate(_ context.Context, pl policy.Policy) error {
4546
logger.Debug(fmt.Sprintf("Using resources from %s", p.config.PoliciesDir))
4647
tmpdir := utils.NewTempDirectory(p.config.TempDir)
4748
composer := NewOscal2Policy(p.config.PoliciesDir, tmpdir)
@@ -58,7 +59,7 @@ func (p *Plugin) Generate(pl policy.Policy) error {
5859
return nil
5960
}
6061

61-
func (p *Plugin) GetResults(pl policy.Policy) (policy.PVPResult, error) {
62+
func (p *Plugin) GetResults(_ context.Context, pl policy.Policy) (policy.PVPResult, error) {
6263
results := NewResultToOscal(pl, p.config.PolicyResultsDir)
6364
return results.GenerateResults()
6465
}

cmd/kyverno-plugin/server/server_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,12 @@ func TestConfigure(t *testing.T) {
4040
configuration := map[string]string{
4141
"policy-dir": "not-exist",
4242
}
43-
err := plugin.Configure(configuration)
43+
err := plugin.Configure(context.Background(), configuration)
4444
require.EqualError(t, err, "path \"not-exist\": stat not-exist: no such file or directory")
4545

4646
policyDir := utils.PathFromInternalDirectory("./testdata/kyverno/policy-resources")
4747
configuration["policy-dir"] = policyDir
48-
err = plugin.Configure(configuration)
48+
err = plugin.Configure(context.Background(), configuration)
4949
require.NoError(t, err)
5050
}
5151

cmd/ocm-plugin/server/server.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
package server
77

88
import (
9+
"context"
910
"errors"
1011
"os"
1112
"strings"
@@ -38,14 +39,14 @@ func NewPlugin() *Plugin {
3839
}
3940
}
4041

41-
func (p *Plugin) Configure(m map[string]string) error {
42+
func (p *Plugin) Configure(_ context.Context, m map[string]string) error {
4243
if err := mapstructure.Decode(m, &p.config); err != nil {
4344
return errors.New("error decoding configuration")
4445
}
4546
return p.config.Validate()
4647
}
4748

48-
func (p *Plugin) Generate(pl policy.Policy) error {
49+
func (p *Plugin) Generate(_ context.Context, pl policy.Policy) error {
4950
tmpdir := utils.NewTempDirectory(p.config.TempDir)
5051
composer := NewComposerByTempDirectory(p.config.PoliciesDir, tmpdir)
5152
if err := composer.ComposeByPolicies(pl, p.config); err != nil {
@@ -79,7 +80,7 @@ func (p *Plugin) Generate(pl policy.Policy) error {
7980
return nil
8081
}
8182

82-
func (p *Plugin) GetResults(pl policy.Policy) (policy.PVPResult, error) {
83+
func (p *Plugin) GetResults(_ context.Context, pl policy.Policy) (policy.PVPResult, error) {
8384
results := NewResultToOscal(pl, p.config.PolicyResultsDir, p.config.Namespace, p.config.PolicySetName)
8485
return results.GenerateResults()
8586
}

cmd/ocm-plugin/server/server_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func TestOscal2Policy(t *testing.T) {
4141
plugin.config.TempDir = tempDir.GetTempDir()
4242
plugin.config.OutputDir = tmpOutputDir
4343
plugin.config.PolicyResultsDir = tmpOutputDir
44-
require.NoError(t, plugin.Generate(testPolicy))
44+
require.NoError(t, plugin.Generate(context.Background(), testPolicy))
4545
}
4646

4747
func TestResult2Oscal(t *testing.T) {
@@ -154,17 +154,17 @@ func TestConfigure(t *testing.T) {
154154
configuration := map[string]string{
155155
"policy-dir": policyDir,
156156
}
157-
err := plugin.Configure(configuration)
157+
err := plugin.Configure(context.Background(), configuration)
158158
require.EqualError(t, err, "policy set name must be set")
159159

160160
configuration["policy-set-name"] = "set"
161161

162162
configuration["policy-dir"] = "not-exist"
163-
err = plugin.Configure(configuration)
163+
err = plugin.Configure(context.Background(), configuration)
164164
require.EqualError(t, err, "path \"not-exist\": stat not-exist: no such file or directory")
165165

166166
configuration["policy-dir"] = policyDir
167-
err = plugin.Configure(configuration)
167+
err = plugin.Configure(context.Background(), configuration)
168168
require.NoError(t, err)
169169
}
170170

framework/actions/aggregate.go

Lines changed: 32 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,9 @@ import (
99
"context"
1010
"errors"
1111
"fmt"
12-
"sync"
1312

1413
"github.com/oscal-compass/oscal-sdk-go/settings"
15-
"golang.org/x/sync/semaphore"
14+
"golang.org/x/sync/errgroup"
1615

1716
"github.com/oscal-compass/compliance-to-policy-go/v2/logging"
1817
"github.com/oscal-compass/compliance-to-policy-go/v2/plugin"
@@ -27,65 +26,53 @@ func AggregateResults(ctx context.Context, inputContext *InputContext, pluginSet
2726
log := logging.GetLogger("aggregator")
2827

2928
var allResults []policy.PVPResult
30-
sem := semaphore.NewWeighted(inputContext.MaxConcurrentWeight)
31-
var wg sync.WaitGroup
32-
errorCh := make(chan error, len(pluginSet))
3329
resultChan := make(chan policy.PVPResult, len(pluginSet))
3430

31+
eg, egCtx := errgroup.WithContext(ctx)
32+
eg.SetLimit(inputContext.MaxConcurrency)
3533
for providerId, policyPlugin := range pluginSet {
36-
37-
wg.Add(1)
38-
39-
go func(providerId plugin.ID, plugin policy.Provider) {
40-
defer wg.Done()
41-
42-
if err := sem.Acquire(ctx, 1); err != nil {
43-
errorCh <- fmt.Errorf("%s failed to acquire semaphore: %w", providerId.String(), err)
44-
return
45-
}
46-
defer sem.Release(1)
47-
48-
componentTitle, err := inputContext.ProviderTitle(providerId)
49-
if err != nil {
50-
if errors.Is(err, ErrMissingProvider) {
51-
log.Warn(fmt.Sprintf("skipping %s provider: missing validation component", providerId))
52-
return
34+
func(providerId plugin.ID, plugin policy.Provider) {
35+
eg.Go(func() error {
36+
select {
37+
case <-egCtx.Done():
38+
return fmt.Errorf("%s skipped due to context cancellation/timeout: %w", providerId.String(), egCtx.Err())
39+
default:
5340
}
54-
errorCh <- err
55-
return
5641

57-
}
58-
log.Debug(fmt.Sprintf("Aggregating results for provider %s", providerId))
42+
componentTitle, err := inputContext.ProviderTitle(providerId)
43+
if err != nil {
44+
if errors.Is(err, ErrMissingProvider) {
45+
log.Warn(fmt.Sprintf("skipping %s provider: missing validation component", providerId))
46+
return nil
47+
}
48+
return err
49+
}
50+
log.Debug(fmt.Sprintf("Aggregating results for provider %s", providerId))
5951

60-
appliedRuleSet, err := settings.ApplyToComponent(ctx, componentTitle, inputContext.Store(), inputContext.Settings)
61-
if err != nil {
62-
errorCh <- fmt.Errorf("failed to get rule sets for component %s: %w", componentTitle, err)
63-
return
64-
}
52+
appliedRuleSet, err := settings.ApplyToComponent(egCtx, componentTitle, inputContext.Store(), inputContext.Settings)
53+
if err != nil {
54+
return fmt.Errorf("failed to get rule sets for component %s: %w", componentTitle, err)
55+
}
6556

66-
pluginResults, err := policyPlugin.GetResults(appliedRuleSet)
67-
if err != nil {
68-
errorCh <- fmt.Errorf("plugin %s: %w", providerId, err)
69-
return
70-
}
71-
resultChan <- pluginResults
57+
pluginResults, err := policyPlugin.GetResults(egCtx, appliedRuleSet)
58+
if err != nil {
59+
return err
60+
}
61+
resultChan <- pluginResults
62+
return nil
63+
})
7264
}(providerId, policyPlugin)
7365
}
7466

67+
var err error
7568
go func() {
76-
wg.Wait()
77-
close(errorCh)
69+
err = eg.Wait()
7870
close(resultChan)
7971
}()
8072

81-
var errs []error
82-
for err := range errorCh {
83-
errs = append(errs, err)
84-
}
85-
8673
for result := range resultChan {
8774
allResults = append(allResults, result)
8875
}
8976

90-
return allResults, errors.Join(errs...)
77+
return allResults, err
9178
}

framework/actions/aggregate_test.go

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ package actions
77

88
import (
99
"context"
10+
"errors"
1011
"os"
1112
"sort"
1213
"testing"
14+
"time"
1315

1416
"github.com/oscal-compass/oscal-sdk-go/extensions"
1517
"github.com/oscal-compass/oscal-sdk-go/models"
@@ -165,30 +167,77 @@ func TestAggregateResults_Multi(t *testing.T) {
165167
providerTestObj.AssertExpectations(t)
166168
providerTestObj2.AssertExpectations(t)
167169
require.Len(t, gotResults, 2)
170+
171+
// Test with error
172+
providerTestObj3 := new(policyProvider)
173+
providerTestObj3.On("GetResults", policy.Policy{kyvernoRule}).Return(policy.PVPResult{}, errors.New("failed"))
174+
pluginSet = map[plugin.ID]policy.Provider{
175+
"ocm": providerTestObj,
176+
"kyverno": providerTestObj3,
177+
}
178+
179+
gotResults, err = AggregateResults(context.Background(), inputContext, pluginSet)
180+
require.EqualError(t, err, "failed")
181+
providerTestObj.AssertExpectations(t)
182+
providerTestObj3.AssertExpectations(t)
183+
require.Len(t, gotResults, 1)
184+
185+
// Test with cancellation
186+
done := make(chan struct{})
187+
ctx, cancel := context.WithCancel(context.Background())
188+
189+
providerTestObj3.delay = 500 * time.Millisecond
190+
191+
go func() {
192+
gotResults, err = AggregateResults(ctx, inputContext, pluginSet)
193+
close(done)
194+
}()
195+
196+
// Wait for a short period to allow some goroutines to start
197+
time.Sleep(100 * time.Millisecond)
198+
199+
// Now, cancel.
200+
cancel()
201+
202+
select {
203+
case <-done:
204+
require.EqualError(t, err, "context canceled")
205+
require.Len(t, gotResults, 1)
206+
case <-time.After(2 * time.Second):
207+
t.Fatal("error: did not after cancellation signal within timeout")
208+
}
209+
168210
}
169211

170212
// policyProvider is a mocked implementation of policy.Provider.
171213
type policyProvider struct {
172214
mock.Mock
215+
delay time.Duration
173216
}
174217

175-
func (p *policyProvider) Configure(option map[string]string) error {
218+
func (p *policyProvider) Configure(_ context.Context, option map[string]string) error {
176219
args := p.Called(option)
177220
return args.Error(0)
178221
}
179222

180-
func (p *policyProvider) Generate(policyRules policy.Policy) error {
223+
func (p *policyProvider) Generate(_ context.Context, policyRules policy.Policy) error {
181224
sort.SliceStable(policyRules, func(i, j int) bool {
182225
return policyRules[i].Rule.ID > policyRules[j].Rule.ID
183226
})
184227
args := p.Called(policyRules)
185228
return args.Error(0)
186229
}
187230

188-
func (p *policyProvider) GetResults(policyRules policy.Policy) (policy.PVPResult, error) {
231+
func (p *policyProvider) GetResults(ctx context.Context, policyRules policy.Policy) (policy.PVPResult, error) {
189232
sort.SliceStable(policyRules, func(i, j int) bool {
190233
return policyRules[i].Rule.ID > policyRules[j].Rule.ID
191234
})
192235
args := p.Called(policyRules)
193-
return args.Get(0).(policy.PVPResult), args.Error(1)
236+
237+
select {
238+
case <-ctx.Done():
239+
return policy.PVPResult{}, ctx.Err()
240+
case <-time.After(p.delay): // Simulate completing work
241+
return args.Get(0).(policy.PVPResult), args.Error(1)
242+
}
194243
}

0 commit comments

Comments
 (0)