Skip to content

Commit f032849

Browse files
Hamsajjmostafa
andauthored
feat(464): add support for queueing async actions in background (#544)
* feat(464): add support for queueing async actions in background * Tidy deps * Fix bug in using goroutines for running async actions * Add log message to see if Redis is enabled for publishing async tasks * Add default values for the config options --------- Co-authored-by: Mostafa Moradian <[email protected]>
1 parent bd11be5 commit f032849

File tree

13 files changed

+442
-346
lines changed

13 files changed

+442
-346
lines changed

.golangci.yaml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ linters-settings:
7474
- "golang.org/x/text/cases"
7575
- "golang.org/x/text/language"
7676
- "gopkg.in/yaml.v2"
77+
- "github.com/redis/go-redis/v9"
7778
test:
7879
files:
7980
- $test
@@ -92,6 +93,8 @@ linters-settings:
9293
- "github.com/knadh/koanf"
9394
- "github.com/spf13/cast"
9495
- "github.com/jackc/pgx/v5/pgproto3"
96+
- "github.com/testcontainers/testcontainers-go"
97+
- "github.com/redis/go-redis/v9"
9598
tagalign:
9699
align: false
97100
sort: false

act/act_helpers_test.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
package act
22

33
import (
4+
"context"
5+
"testing"
46
"time"
57

68
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
9+
"github.com/stretchr/testify/assert"
10+
"github.com/testcontainers/testcontainers-go"
11+
"github.com/testcontainers/testcontainers-go/modules/redis"
712
)
813

914
func createWaitActEntities(async bool) (
@@ -49,3 +54,20 @@ func createWaitActEntities(async bool) (
4954

5055
return name, actions, signals, policy
5156
}
57+
58+
func createTestRedis(t *testing.T) string {
59+
t.Helper()
60+
ctx := context.Background()
61+
62+
redisContainer, err := redis.RunContainer(ctx, testcontainers.WithImage("redis:6"))
63+
64+
assert.NoError(t, err)
65+
t.Cleanup(func() {
66+
assert.NoError(t, redisContainer.Terminate(ctx))
67+
})
68+
host, err := redisContainer.Host(ctx)
69+
assert.NoError(t, err)
70+
port, err := redisContainer.MappedPort(ctx, "6379/tcp")
71+
assert.NoError(t, err)
72+
return host + ":" + port.Port()
73+
}

act/publisher.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
package act
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/redis/go-redis/v9"
8+
"github.com/rs/zerolog"
9+
)
10+
11+
type IPublisher interface {
12+
Publish(ctx context.Context, payload []byte) error
13+
}
14+
15+
var _ IPublisher = (*Publisher)(nil)
16+
17+
type Publisher struct {
18+
Logger zerolog.Logger
19+
RedisDB redis.Cmdable
20+
ChannelName string
21+
}
22+
23+
func NewPublisher(publisher Publisher) (*Publisher, error) {
24+
if err := publisher.RedisDB.Ping(context.Background()).Err(); err != nil {
25+
publisher.Logger.Err(err).Msg("failed to connect redis")
26+
}
27+
return &Publisher{
28+
Logger: publisher.Logger,
29+
RedisDB: publisher.RedisDB,
30+
ChannelName: publisher.ChannelName,
31+
}, nil
32+
}
33+
34+
func (p *Publisher) Publish(ctx context.Context, payload []byte) error {
35+
if err := p.RedisDB.Publish(ctx, p.ChannelName, payload).Err(); err != nil {
36+
p.Logger.Err(err).Str("ChannelName", p.ChannelName).Msg("failed to publish task to redis")
37+
return fmt.Errorf("failed to publish task to redis: %w", err)
38+
}
39+
return nil
40+
}

act/registry.go

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@ package act
22

33
import (
44
"context"
5+
"encoding/json"
56
"errors"
7+
"fmt"
68
"slices"
79
"time"
810

@@ -26,6 +28,10 @@ type Registry struct {
2628
// Default timeout for running actions
2729
DefaultActionTimeout time.Duration
2830

31+
// TaskPublisher is the publisher for async actions.
32+
// if not given, will invoke simple goroutine to run async actions
33+
TaskPublisher *Publisher
34+
2935
Signals map[string]*sdkAct.Signal
3036
Policies map[string]*sdkAct.Policy
3137
Actions map[string]*sdkAct.Action
@@ -34,6 +40,27 @@ type Registry struct {
3440
DefaultSignal *sdkAct.Signal
3541
}
3642

43+
type AsyncActionMessage struct {
44+
Output *sdkAct.Output
45+
Params []sdkAct.Parameter
46+
}
47+
48+
// Encode marshals the AsyncActionMessage struct to JSON bytes.
49+
func (msg *AsyncActionMessage) Encode() ([]byte, error) {
50+
marshaled, err := json.Marshal(msg)
51+
if err != nil {
52+
return nil, fmt.Errorf("error encoding JSON: %w", err)
53+
}
54+
return marshaled, nil
55+
}
56+
57+
func (msg *AsyncActionMessage) Decode(data []byte) error {
58+
if err := json.Unmarshal(data, msg); err != nil {
59+
return fmt.Errorf("error decoding JSON: %w", err)
60+
}
61+
return nil
62+
}
63+
3764
var _ IRegistry = (*Registry)(nil)
3865

3966
// NewActRegistry creates a new act registry with the specified default policy and timeout
@@ -88,6 +115,7 @@ func NewActRegistry(
88115
Actions: registry.Actions,
89116
DefaultPolicy: registry.Policies[registry.DefaultPolicyName],
90117
DefaultSignal: registry.Signals[registry.DefaultPolicyName],
118+
TaskPublisher: registry.TaskPublisher,
91119
}
92120
}
93121

@@ -234,6 +262,18 @@ func (r *Registry) Run(
234262
if action.Timeout > 0 {
235263
timeout = time.Duration(action.Timeout) * time.Second
236264
}
265+
266+
// if task is async and publisher is configured, publish it and do not run it
267+
if r.TaskPublisher != nil && !action.Sync {
268+
err := r.publishTask(output, params)
269+
if err != nil {
270+
r.Logger.Error().Err(err).Msg("Error publishing async action")
271+
return nil, gerr.ErrPublishingAsyncAction
272+
}
273+
return nil, gerr.ErrAsyncAction
274+
}
275+
276+
// no publisher, or sync action. run the action
237277
var ctx context.Context
238278
var cancel context.CancelFunc
239279
// if timeout is zero, then the context should not have timeout
@@ -248,14 +288,83 @@ func (r *Registry) Run(
248288
return runActionWithTimeout(ctx, action, output, params, r.Logger)
249289
}
250290

251-
// Run the action asynchronously.
291+
// If the action is asynchronous, run it in a goroutine and return the sentinel error.
252292
go func() {
253293
defer cancel()
254294
_, _ = runActionWithTimeout(ctx, action, output, params, r.Logger)
255295
}()
296+
256297
return nil, gerr.ErrAsyncAction
257298
}
258299

300+
func (r *Registry) publishTask(output *sdkAct.Output, params []sdkAct.Parameter) error {
301+
r.Logger.Debug().Msg("Publishing async action")
302+
task := AsyncActionMessage{
303+
Output: output,
304+
Params: params,
305+
}
306+
payload, err := task.Encode()
307+
if err != nil {
308+
return err
309+
}
310+
if err := r.TaskPublisher.Publish(context.Background(), payload); err != nil {
311+
return fmt.Errorf("error publishing task: %w", err)
312+
}
313+
return nil
314+
}
315+
316+
func (r *Registry) runAsyncActionFn(ctx context.Context, message []byte) error {
317+
msg := &AsyncActionMessage{}
318+
if err := msg.Decode(message); err != nil {
319+
r.Logger.Error().Err(err).Msg("Error decoding message")
320+
return err
321+
}
322+
output := msg.Output
323+
params := msg.Params
324+
325+
// In certain cases, the output may be nil, for example, if the policy
326+
// evaluation fails. In this case, the run is aborted.
327+
if output == nil {
328+
// This should never happen, since the output is always set by the registry
329+
// to be the default policy if no signals are provided.
330+
r.Logger.Debug().Msg("Output is nil, run aborted")
331+
return gerr.ErrNilPointer
332+
}
333+
334+
action, ok := r.Actions[output.MatchedPolicy]
335+
if !ok {
336+
r.Logger.Warn().Str("matchedPolicy", output.MatchedPolicy).Msg(
337+
"Action does not exist, run aborted")
338+
return gerr.ErrActionNotExist
339+
}
340+
341+
// Prepend the logger to the parameters if needed.
342+
if len(params) == 0 || params[0].Key != LoggerKey {
343+
params = append([]sdkAct.Parameter{WithLogger(r.Logger)}, params...)
344+
} else {
345+
params[0] = WithLogger(r.Logger)
346+
}
347+
348+
timeout := r.DefaultActionTimeout
349+
if action.Timeout > 0 {
350+
timeout = time.Duration(action.Timeout) * time.Second
351+
}
352+
var ctxWithTimeout context.Context
353+
var cancel context.CancelFunc
354+
// if timeout is zero, then the context should not have timeout
355+
if timeout > 0 {
356+
ctxWithTimeout, cancel = context.WithTimeout(ctx, timeout)
357+
} else {
358+
ctxWithTimeout, cancel = context.WithCancel(ctx)
359+
}
360+
// If the action is synchronous, run it and return the result immediately.
361+
defer cancel()
362+
if _, err := runActionWithTimeout(ctxWithTimeout, action, output, params, r.Logger); err != nil {
363+
return err
364+
}
365+
return nil
366+
}
367+
259368
func runActionWithTimeout(
260369
ctx context.Context,
261370
action *sdkAct.Action,
@@ -293,7 +402,7 @@ func runActionWithTimeout(
293402
}
294403
}
295404

296-
// WithLogger returns a parameter with the logger to be used by the action.
405+
// WithLogger returns a parameter with the Logger to be used by the action.
297406
// This is automatically prepended to the parameters when running an action.
298407
func WithLogger(logger zerolog.Logger) sdkAct.Parameter {
299408
return sdkAct.Parameter{

act/registry_test.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,16 @@ package act
22

33
import (
44
"bytes"
5+
"context"
6+
"sync"
57
"testing"
68
"time"
79

810
sdkAct "github.com/gatewayd-io/gatewayd-plugin-sdk/act"
911
"github.com/gatewayd-io/gatewayd/config"
1012
gerr "github.com/gatewayd-io/gatewayd/errors"
13+
"github.com/hashicorp/go-hclog"
14+
"github.com/redis/go-redis/v9"
1115
"github.com/rs/zerolog"
1216
"github.com/spf13/cast"
1317
"github.com/stretchr/testify/assert"
@@ -705,6 +709,88 @@ func Test_Run_Async(t *testing.T) {
705709
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
706710
}
707711

712+
// Test_Run_Async tests the Run function of the act registry with an asynchronous action.
713+
func Test_Run_Async_Redis(t *testing.T) {
714+
out := bytes.Buffer{}
715+
logger := zerolog.New(&out)
716+
hclogger := hclog.New(&hclog.LoggerOptions{
717+
Output: &out,
718+
Level: hclog.Debug,
719+
JSONFormat: true,
720+
})
721+
722+
rdbAddr := createTestRedis(t)
723+
rdb := redis.NewClient(&redis.Options{
724+
Addr: rdbAddr,
725+
})
726+
publisher, err := NewPublisher(Publisher{
727+
Logger: logger,
728+
RedisDB: rdb,
729+
ChannelName: "test-async-chan",
730+
})
731+
require.NoError(t, err)
732+
733+
var waitGroup sync.WaitGroup
734+
actRegistry := NewActRegistry(
735+
Registry{
736+
Signals: BuiltinSignals(),
737+
Policies: BuiltinPolicies(),
738+
Actions: BuiltinActions(),
739+
DefaultPolicyName: config.DefaultPolicy,
740+
PolicyTimeout: config.DefaultPolicyTimeout,
741+
DefaultActionTimeout: config.DefaultActionTimeout,
742+
Logger: logger,
743+
TaskPublisher: publisher,
744+
})
745+
assert.NotNil(t, actRegistry)
746+
747+
consumer, err := sdkAct.NewConsumer(hclogger, rdb, 5, "test-async-chan")
748+
require.NoError(t, err)
749+
750+
require.NoError(t, consumer.Subscribe(context.Background(), func(ctx context.Context, task []byte) error {
751+
err := actRegistry.runAsyncActionFn(ctx, task)
752+
waitGroup.Done()
753+
return err
754+
}))
755+
756+
outputs := actRegistry.Apply([]sdkAct.Signal{
757+
*sdkAct.Log("info", "test", map[string]any{"async": true}),
758+
}, sdkAct.Hook{
759+
Name: "HOOK_NAME_ON_TRAFFIC_FROM_CLIENT",
760+
Priority: 1000,
761+
Params: map[string]any{},
762+
Result: map[string]any{},
763+
})
764+
assert.NotNil(t, outputs)
765+
assert.Equal(t, "log", outputs[0].MatchedPolicy)
766+
assert.Equal(t,
767+
map[string]interface{}{
768+
"async": true,
769+
"level": "info",
770+
"log": true,
771+
"message": "test",
772+
},
773+
outputs[0].Metadata,
774+
)
775+
assert.False(t, outputs[0].Sync)
776+
assert.True(t, cast.ToBool(outputs[0].Verdict))
777+
assert.False(t, outputs[0].Terminal)
778+
waitGroup.Add(1)
779+
result, err := actRegistry.Run(outputs[0], WithResult(map[string]any{"key": "value"}))
780+
waitGroup.Wait()
781+
assert.Equal(t, err, gerr.ErrAsyncAction, "expected async action sentinel error")
782+
assert.Nil(t, result, "expected nil result")
783+
784+
time.Sleep(time.Millisecond) // wait for async action to complete
785+
786+
// The following is the expected log output from running the async action.
787+
assert.Contains(t, out.String(), "{\"level\":\"debug\",\"action\":\"log\",\"executionMode\":\"async\",\"message\":\"Running action\"}") //nolint:lll
788+
// The following is the expected log output from the run function of the async action.
789+
assert.Contains(t, out.String(), "{\"level\":\"info\",\"async\":true,\"message\":\"test\"}")
790+
// The following is expected log from consumer in hclog format
791+
assert.Contains(t, out.String(), "\"@level\":\"debug\",\"@message\":\"async redis task processed successfully\"")
792+
}
793+
708794
// Test_Run_NilRegistry tests the Run function of the action with a nil output object.
709795
func Test_Run_NilOutput(t *testing.T) {
710796
buf := bytes.Buffer{}

0 commit comments

Comments
 (0)