diff --git a/badgerdriver/driver.go b/badgerdriver/driver.go index d81adb2..a23bc45 100644 --- a/badgerdriver/driver.go +++ b/badgerdriver/driver.go @@ -310,7 +310,7 @@ func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { for { if err := d.db.Update(func(txn *badger.Txn) error { - for i, subCmd0 := range cmd.Commands { + for _, subCmd0 := range cmd.Commands { nextRev, err := getRev() if err != nil { return fmt.Errorf("get next sequence: %w", err) @@ -335,9 +335,7 @@ func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { return err } if stateCtx.Committed.Rev != commitedRev { - conflictErr := &flowstate.ErrRevMismatch{} - conflictErr.Add(fmt.Sprintf("%T", cmd.Commands[i]), stateCtx.Current.ID, nil) - return conflictErr + return &flowstate.ErrRevMismatch{IDS: []flowstate.StateID{stateCtx.Current.ID}} } commitedState := stateCtx.Current.CopyTo(&flowstate.State{}) diff --git a/delay.go b/delay.go index 7b70eff..22a0ba5 100644 --- a/delay.go +++ b/delay.go @@ -173,7 +173,10 @@ func NewDelayer(e Engine, l *slog.Logger) (*Delayer, error) { } else if err != nil { return nil, fmt.Errorf("commit meta state: %w", err) } + } else if err != nil { + return nil, fmt.Errorf("get meta state: %w", err) } + d.metaStateCtx = metaStateCtx d.commitSince, d.commitOffset = getDelayerMetaState(metaStateCtx) d.since, d.offset = d.commitSince, d.commitOffset @@ -346,7 +349,7 @@ func getDelayerMetaState(metaStateCtx *StateCtx) (time.Time, int64) { offset0 := metaStateCtx.Current.Annotations[`flowstate.delayer.offset`] offset, err := strconv.ParseInt(offset0, 10, 64) if err != nil { - panic(fmt.Errorf("cannot parse flowstate.delayer.offset=%s into int64: %w", offset0, err)) + panic(fmt.Errorf("cannot parse flowstate.delayer.offset='%s' into int64: %w", offset0, err)) } since0 := metaStateCtx.Current.Annotations[`flowstate.delayer.since`] diff --git a/errors.go b/errors.go index 27c62b1..3d3b5f4 100644 --- a/errors.go +++ b/errors.go @@ -1,12 +1,12 @@ package flowstate -import "errors" +import ( + "errors" +) // ErrRevMismatch is an error that indicates a revision mismatch during a commit operation. type ErrRevMismatch struct { - cmds []string - stateIDs []StateID - errs []error + IDS []StateID } func (err ErrRevMismatch) As(target interface{}) bool { @@ -19,30 +19,28 @@ func (err ErrRevMismatch) As(target interface{}) bool { } func (err ErrRevMismatch) Error() string { - msg := "conflict;" - for i := range err.cmds { - msg += " cmd: " + err.cmds[i] + " sid: " + string(err.stateIDs[i]) + ";" - if err.errs[i] != nil { - msg += " err: " + err.errs[i].Error() + ";" + msg := "rev mismatch: " + for i, id := range err.IDS { + if i > 0 { + msg += ", " } - } + msg += string(id) + } return msg } -func (err *ErrRevMismatch) Add(cmd string, sID StateID, cmdErr error) { - err.cmds = append(err.cmds, cmd) - err.stateIDs = append(err.stateIDs, sID) - err.errs = append(err.errs, cmdErr) +func (err *ErrRevMismatch) Add(id StateID) { + err.IDS = append(err.IDS, id) } -func (err *ErrRevMismatch) TaskIDs() []StateID { - return err.stateIDs +func (err *ErrRevMismatch) All() []StateID { + return err.IDS } -func (err *ErrRevMismatch) Contains(sID StateID) bool { - for _, s := range err.stateIDs { - if s == sID { +func (err *ErrRevMismatch) Contains(id StateID) bool { + for _, s := range err.IDS { + if s == id { return true } } diff --git a/memdriver/driver.go b/memdriver/driver.go index 783d3b3..3e7b273 100644 --- a/memdriver/driver.go +++ b/memdriver/driver.go @@ -192,9 +192,7 @@ func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { } if _, rev := d.stateLog.GetLatestByID(stateCtx.Current.ID); rev != stateCtx.Committed.Rev { - conflictErr := &flowstate.ErrRevMismatch{} - conflictErr.Add(fmt.Sprintf("%T", cmd), stateCtx.Current.ID, fmt.Errorf("rev mismatch")) - return conflictErr + return &flowstate.ErrRevMismatch{IDS: []flowstate.StateID{stateCtx.Current.ID}} } d.stateLog.Append(stateCtx) diff --git a/netdriver/driver.go b/netdriver/driver.go new file mode 100644 index 0000000..1f06466 --- /dev/null +++ b/netdriver/driver.go @@ -0,0 +1,295 @@ +package netdriver + +import ( + "bytes" + "fmt" + "io" + "net/http" + "strings" + + "github.com/makasim/flowstate" +) + +var _ flowstate.Driver = (*Driver)(nil) + +type Driver struct { + flowstate.FlowRegistry + + httpHost string + + c *http.Client +} + +func New(httpHost string) *Driver { + return &Driver{ + httpHost: httpHost, + + c: &http.Client{}, + } +} + +func (d *Driver) Init(_ flowstate.Engine) error { + return nil +} + +func (d *Driver) GetStateByID(cmd *flowstate.GetStateByIDCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/GetStateByID") +} + +func (d *Driver) GetStateByLabels(cmd *flowstate.GetStateByLabelsCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/GetStateByLabels") +} + +func (d *Driver) GetStates(cmd *flowstate.GetStatesCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/GetStates") +} + +func (d *Driver) GetDelayedStates(cmd *flowstate.GetDelayedStatesCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/GetDelayedStates") +} + +func (d *Driver) Delay(cmd *flowstate.DelayCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/Delay") +} + +func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/Commit") +} + +func (d *Driver) GetData(cmd *flowstate.GetDataCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/GetData") +} + +func (d *Driver) StoreData(cmd *flowstate.StoreDataCommand) error { + return d.do(cmd, "/flowstate.v1.Driver/StoreData") +} + +func (d *Driver) do(cmd flowstate.Command, path string) error { + b := flowstate.MarshalCommand(cmd, nil) + req, err := http.NewRequest(`POST`, strings.TrimRight(d.httpHost, `/`)+path, bytes.NewBuffer(b)) + if err != nil { + return fmt.Errorf("new request: %w", err) + } + + resp, err := d.c.Do(req) + if err != nil { + return fmt.Errorf("do request: %w", err) + } + defer resp.Body.Close() + + b, err = io.ReadAll(resp.Body) + if err != nil { + return fmt.Errorf("read response body: %w", err) + } + + if http.StatusOK != resp.StatusCode { + code, message, err := unmarshalError(b) + if err != nil { + return fmt.Errorf("response status code: %d; unmarshal error: %s", resp.StatusCode, err) + } + + switch { + case code == "not_found" && strings.Contains(message, flowstate.ErrNotFound.Error()): + return flowstate.ErrNotFound + case code == "aborted" && strings.HasPrefix(message, "rev mismatch: "): + _, idsStr, _ := strings.Cut(message, "rev mismatch: ") + splitIds := strings.Split(idsStr, ",") + ids := make([]flowstate.StateID, 0, len(splitIds)) + for i := range splitIds { + id := strings.TrimSpace(splitIds[i]) + if id == "" { + continue + } + + ids = append(ids, flowstate.StateID(id)) + } + + return flowstate.ErrRevMismatch{IDS: ids} + } + + return fmt.Errorf("%s: %s", code, message) + } + + resCmd, err := flowstate.UnmarshalCommand(b) + if err != nil { + return fmt.Errorf("unmarshal response: %w", err) + } + + if err := syncResult(cmd, resCmd); err != nil { + return fmt.Errorf("sync result: %w", err) + } + + return nil +} + +func syncResult(inCmd0, resCmd0 flowstate.Command) error { + switch inCmd := inCmd0.(type) { + case *flowstate.TransitCommand: + resCmd, ok := resCmd0.(*flowstate.TransitCommand) + if !ok { + return fmt.Errorf("resCmd is not a TransitCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.PauseCommand: + resCmd, ok := resCmd0.(*flowstate.PauseCommand) + if !ok { + return fmt.Errorf("resCmd is not a PauseCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.ResumeCommand: + resCmd, ok := resCmd0.(*flowstate.ResumeCommand) + if !ok { + return fmt.Errorf("resCmd is not a ResumeCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.EndCommand: + resCmd, ok := resCmd0.(*flowstate.EndCommand) + if !ok { + return fmt.Errorf("resCmd is not a EndCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.ExecuteCommand: + resCmd, ok := resCmd0.(*flowstate.ExecuteCommand) + if !ok { + return fmt.Errorf("resCmd is not a ExecuteCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.DelayCommand: + resCmd, ok := resCmd0.(*flowstate.DelayCommand) + if !ok { + return fmt.Errorf("resCmd is not a DelayCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + resCmd.DelayingState.CopyTo(&inCmd.DelayingState) + return nil + case *flowstate.CommitCommand: + resCmd, ok := resCmd0.(*flowstate.CommitCommand) + if !ok { + return fmt.Errorf("resCmd is not a CommitCommand") + } + if len(resCmd.Commands) != len(inCmd.Commands) { + return fmt.Errorf("resCmd.Commands length mismatch: got %d, want %d", len(resCmd.Commands), len(inCmd.Commands)) + } + + for i := range inCmd.Commands { + if err := syncResult(inCmd.Commands[i], resCmd.Commands[i]); err != nil { + return fmt.Errorf("%d# command %T: %w", i, inCmd.Commands[i], err) + } + } + + return nil + case *flowstate.NoopCommand: + resCmd, ok := resCmd0.(*flowstate.ExecuteCommand) + if !ok { + return fmt.Errorf("resCmd is not a ExecuteCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.SerializeCommand: + resCmd, ok := resCmd0.(*flowstate.SerializeCommand) + if !ok { + return fmt.Errorf("resCmd is not a SerializeCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + resCmd.SerializableStateCtx.CopyTo(inCmd.SerializableStateCtx) + return nil + case *flowstate.DeserializeCommand: + resCmd, ok := resCmd0.(*flowstate.DeserializeCommand) + if !ok { + return fmt.Errorf("resCmd is not a DeserializeCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + resCmd.DeserializedStateCtx.CopyTo(inCmd.DeserializedStateCtx) + return nil + case *flowstate.GetDataCommand: + resCmd, ok := resCmd0.(*flowstate.GetDataCommand) + if !ok { + return fmt.Errorf("resCmd is not a GetDataCommand") + } + + resCmd.Data.CopyTo(inCmd.Data) + return nil + case *flowstate.StoreDataCommand: + resCmd, ok := resCmd0.(*flowstate.StoreDataCommand) + if !ok { + return fmt.Errorf("resCmd is not a StoreDataCommand") + } + + resCmd.Data.CopyTo(inCmd.Data) + return nil + case *flowstate.ReferenceDataCommand: + resCmd, ok := resCmd0.(*flowstate.ReferenceDataCommand) + if !ok { + return fmt.Errorf("resCmd is not a ReferenceDataCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + resCmd.Data.CopyTo(inCmd.Data) + return nil + case *flowstate.DereferenceDataCommand: + resCmd, ok := resCmd0.(*flowstate.DereferenceDataCommand) + if !ok { + return fmt.Errorf("resCmd is not a DereferenceDataCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + resCmd.Data.CopyTo(inCmd.Data) + return nil + case *flowstate.GetStateByIDCommand: + resCmd, ok := resCmd0.(*flowstate.GetStateByIDCommand) + if !ok { + return fmt.Errorf("resCmd is not a GetStateByIDCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.GetStateByLabelsCommand: + resCmd, ok := resCmd0.(*flowstate.GetStateByLabelsCommand) + if !ok { + return fmt.Errorf("resCmd is not a GetStateByLabelsCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + case *flowstate.GetStatesCommand: + resCmd, ok := resCmd0.(*flowstate.GetStatesCommand) + if !ok { + return fmt.Errorf("resCmd is not a GetStatesCommand") + } + + inCmd.Result = resCmd.Result + return nil + case *flowstate.GetDelayedStatesCommand: + resCmd, ok := resCmd0.(*flowstate.GetDelayedStatesCommand) + if !ok { + return fmt.Errorf("resCmd is not a GetDelayedStatesCommand") + } + + inCmd.Result = resCmd.Result + return nil + case *flowstate.CommitStateCtxCommand: + resCmd, ok := resCmd0.(*flowstate.CommitStateCtxCommand) + if !ok { + return fmt.Errorf("resCmd is not a CommitStateCtxCommand") + } + + resCmd.StateCtx.CopyTo(inCmd.StateCtx) + return nil + default: + return fmt.Errorf("unknown inCmd: %T", inCmd0) + } +} diff --git a/netdriver/server.go b/netdriver/server.go new file mode 100644 index 0000000..8894246 --- /dev/null +++ b/netdriver/server.go @@ -0,0 +1,309 @@ +package netdriver + +import ( + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/VictoriaMetrics/easyproto" + "github.com/makasim/flowstate" +) + +func HandleAll(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if HandleGetStateByID(rw, r, d) { + return true + } + if HandleGetStateByLabels(rw, r, d) { + return true + } + if HandleCommit(rw, r, d) { + return true + } + if HandleGetStates(rw, r, d) { + return true + } + if HandleGetDelayedStates(rw, r, d) { + return true + } + if HandleDelay(rw, r, d) { + return true + } + if HandleGetData(rw, r, d) { + return true + } + if HandleStoreData(rw, r, d) { + return true + } + + return false +} + +func HandleGetStateByID(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/GetStateByID" { + return false + } + + cmd, err := readCmd[*flowstate.GetStateByIDCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.GetStateByID(cmd); errors.Is(err, flowstate.ErrNotFound) { + writeNotFoundError(rw, err.Error()) + return true + } else if err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleGetStateByLabels(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/GetStateByLabels" { + return false + } + + cmd, err := readCmd[*flowstate.GetStateByLabelsCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.GetStateByLabels(cmd); errors.Is(err, flowstate.ErrNotFound) { + writeNotFoundError(rw, err.Error()) + return true + } else if err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleGetStates(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/GetStates" { + return false + } + + cmd, err := readCmd[*flowstate.GetStatesCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.GetStates(cmd); err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleGetDelayedStates(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/GetDelayedStates" { + return false + } + + cmd, err := readCmd[*flowstate.GetDelayedStatesCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.GetDelayedStates(cmd); err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleDelay(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/Delay" { + return false + } + + cmd, err := readCmd[*flowstate.DelayCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.Delay(cmd); err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleCommit(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/Commit" { + return false + } + + cmd, err := readCmd[*flowstate.CommitCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.Commit(cmd); flowstate.IsErrRevMismatch(err) { + writeAbortedError(rw, err.Error()) + return true + } else if errors.Is(err, flowstate.ErrNotFound) { + writeNotFoundError(rw, err.Error()) + return true + } else if err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleGetData(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/GetData" { + return false + } + + cmd, err := readCmd[*flowstate.GetDataCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.GetData(cmd); err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func HandleStoreData(rw http.ResponseWriter, r *http.Request, d flowstate.Driver) bool { + if r.URL.Path != "/flowstate.v1.Driver/StoreData" { + return false + } + + cmd, err := readCmd[*flowstate.StoreDataCommand](r) + if err != nil { + writeInvalidArgumentError(rw, err.Error()) + return true + } + + if err := d.StoreData(cmd); err != nil { + writeUnknownError(rw, err.Error()) + return true + } + + writeCmd(rw, cmd) + return true +} + +func readCmd[T flowstate.Command](r *http.Request) (T, error) { + var defCmd T + + reqBody, err := io.ReadAll(r.Body) + if err != nil { + return defCmd, fmt.Errorf("failed to read request body: %w", err) + } + + cmd0, err := flowstate.UnmarshalCommand(reqBody) + if err != nil { + return defCmd, fmt.Errorf("failed to unmarshal command: %w", err) + } + + cmd, ok := cmd0.(T) + if !ok { + return defCmd, fmt.Errorf("invalid command type: expected %T; got: %T", defCmd, cmd0) + } + + return cmd, nil +} + +func writeCmd(rw http.ResponseWriter, cmd flowstate.Command) { + rw.Header().Set("Content-Type", "application/proto") + rw.WriteHeader(http.StatusOK) + + _, _ = rw.Write(flowstate.MarshalCommand(cmd, nil)) +} + +func writeInvalidArgumentError(rw http.ResponseWriter, message string) { + rw.Header().Set("Content-Type", "application/proto") + rw.WriteHeader(http.StatusBadRequest) + + _, _ = rw.Write(marshalError("invalid_argument", message)) +} + +func writeUnknownError(rw http.ResponseWriter, message string) { + rw.Header().Set("Content-Type", "application/proto") + rw.WriteHeader(http.StatusInternalServerError) + + _, _ = rw.Write(marshalError("unknown", message)) +} + +func writeNotFoundError(rw http.ResponseWriter, message string) { + rw.Header().Set("Content-Type", "application/proto") + rw.WriteHeader(http.StatusNotFound) + + _, _ = rw.Write(marshalError("not_found", message)) +} + +func writeAbortedError(rw http.ResponseWriter, message string) { + rw.Header().Set("Content-Type", "application/proto") + rw.WriteHeader(http.StatusConflict) + + _, _ = rw.Write(marshalError("aborted", message)) +} + +func marshalError(code, message string) []byte { + m := &easyproto.Marshaler{} + mm := m.MessageMarshaler() + + if code != "" { + mm.AppendString(1, code) + } + if message != "" { + mm.AppendString(2, message) + } + + return m.Marshal(nil) +} + +func unmarshalError(src []byte) (code, message string, err error) { + var fc easyproto.FieldContext + for len(src) > 0 { + src, err = fc.NextField(src) + if err != nil { + return "", "", fmt.Errorf("cannot read next field") + } + switch fc.FieldNum { + case 1: + v, ok := fc.String() + if !ok { + return "", "", fmt.Errorf("cannot read code field") + } + code = strings.Clone(v) + case 2: + v, ok := fc.String() + if !ok { + return "", "", fmt.Errorf("cannot read message field") + } + message = strings.Clone(v) + } + } + + return code, message, nil +} diff --git a/netdriver/server_test.go b/netdriver/server_test.go new file mode 100644 index 0000000..3894542 --- /dev/null +++ b/netdriver/server_test.go @@ -0,0 +1,31 @@ +package netdriver + +import ( + "testing" +) + +func TestMarshalUnmarshalError(t *testing.T) { + b := marshalError("", "") + actCode, actMessage, err := unmarshalError(b) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if actCode != "" { + t.Errorf("expected code to be empty, got: %s", actCode) + } + if actMessage != "" { + t.Errorf("expected message to be empty, got: %s", actMessage) + } + + b = marshalError("theCode", "theMessage") + actCode, actMessage, err = unmarshalError(b) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if actCode != "theCode" { + t.Errorf("expected code to be 'theCode', got: %s", actCode) + } + if actMessage != "theMessage" { + t.Errorf("expected message to be 'theMessage', got: %s", actMessage) + } +} diff --git a/netdriver/suite_test.go b/netdriver/suite_test.go new file mode 100644 index 0000000..02f2a3d --- /dev/null +++ b/netdriver/suite_test.go @@ -0,0 +1,68 @@ +package netdriver + +import ( + "fmt" + "log/slog" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/makasim/flowstate" + "github.com/makasim/flowstate/memdriver" + "github.com/makasim/flowstate/testcases" +) + +func TestSuite(t *testing.T) { + s := testcases.Get(func(t *testing.T) flowstate.Driver { + l, _ := testcases.NewTestLogger(t) + + srv := startSrv(t, l) + + return New(srv.URL) + }) + + //s.SetUpDelayer = false + //s.DisableGoleak() + s.Test(t) +} + +func startSrv(t *testing.T, l *slog.Logger) *httptest.Server { + d := memdriver.New(l) + + srv := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) { + if r.URL.Path == `/health` { + rw.WriteHeader(http.StatusOK) + return + } + if HandleAll(rw, r, d) { + return + } + + writeNotFoundError(rw, fmt.Sprintf("path %s not found", r.URL.Path)) + })) + + t.Cleanup(srv.Close) + + timeoutT := time.NewTimer(time.Second) + defer timeoutT.Stop() + readyT := time.NewTicker(time.Millisecond * 50) + defer readyT.Stop() + +loop: + for { + select { + case <-timeoutT.C: + t.Fatalf("app not ready within %s", time.Second) + case <-readyT.C: + + resp, err := http.Get(srv.URL + `/health`) + if err != nil { + continue loop + } + resp.Body.Close() + + return srv + } + } +} diff --git a/pgdriver/driver.go b/pgdriver/driver.go index 01db892..ca2f21f 100644 --- a/pgdriver/driver.go +++ b/pgdriver/driver.go @@ -163,14 +163,14 @@ func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { if committableStateCtx.Committed.Rev > 0 { if err := d.q.UpdateState(context.Background(), tx, &nextState); isRevMismatchErr(err) { - revMismatchErr.Add(fmt.Sprintf("%T", cmd), committableStateCtx.Current.ID, err) + revMismatchErr.Add(committableStateCtx.Current.ID) return revMismatchErr } else if err != nil { return fmt.Errorf("update state: %w", err) } } else { if err := d.q.InsertState(context.Background(), tx, &nextState); isRevMismatchErr(err) { - revMismatchErr.Add(fmt.Sprintf("%T", cmd), committableStateCtx.Current.ID, err) + revMismatchErr.Add(committableStateCtx.Current.ID) return revMismatchErr } else if err != nil { return fmt.Errorf("insert state: %w", err) @@ -182,7 +182,7 @@ func (d *Driver) Commit(cmd *flowstate.CommitCommand) error { committableStateCtx.Transitions = committableStateCtx.Transitions[:0] } - if len(revMismatchErr.TaskIDs()) > 0 { + if len(revMismatchErr.All()) > 0 { return revMismatchErr } diff --git a/proto/flowstate/v1/messages.proto b/proto/flowstate/v1/messages.proto index 7637b8e..12151d8 100644 --- a/proto/flowstate/v1/messages.proto +++ b/proto/flowstate/v1/messages.proto @@ -156,34 +156,34 @@ message GetStatesCommand { map labels = 1; } - message Result { - repeated State states = 1; - bool more = 2; - } - int64 since_rev = 1; int64 since_time_usec = 2; repeated Labels labels = 3; bool latest_only = 4; int64 limit = 5; - Result result = 6; + GetStatesResult result = 6; } -message GetDelayedStatesCommand { - message Result { - repeated DelayedState delayed_states = 1; - bool more = 2; - } +message GetStatesResult { + repeated State states = 1; + bool more = 2; +} +message GetDelayedStatesCommand { int64 since_time_sec = 1; int64 until_time_sec = 2; int64 offset = 3; int64 limit = 4; - Result result = 5; + GetDelayedStatesResult result = 5; +} + +message GetDelayedStatesResult { + repeated DelayedState delayed_states = 1; + bool more = 2; } message CommitStateCtxCommand { StateRef state_ref = 1; -} \ No newline at end of file +} diff --git a/proto/flowstate/v1/services.proto b/proto/flowstate/v1/services.proto new file mode 100644 index 0000000..e24fa39 --- /dev/null +++ b/proto/flowstate/v1/services.proto @@ -0,0 +1,22 @@ +syntax = "proto3"; + +package flowstate.v1; + +import "flowstate/v1/messages.proto"; + +service Driver { + rpc Do(Command) returns (Command) {} + + rpc GetStateByID(Command) returns (Command) {} + rpc GetStateByLabels(Command) returns (Command) {} + rpc GetStates(Command) returns (GetStatesResult) {} + rpc GetDelayedStates(Command) returns (GetDelayedStatesResult) {} + rpc Delay(Command) returns (Command) {} + rpc Commit(Command) returns (Command) {} + rpc GetData(Command) returns (Command) {} + rpc StoreData(Command) returns (Command) {} +} + +service Flow { + rpc Execute(StateCtx) returns (Command) {} +} diff --git a/protobuf.go b/protobuf.go index 48e5831..ae1da98 100644 --- a/protobuf.go +++ b/protobuf.go @@ -2119,7 +2119,7 @@ func commandStateCtxs(cmd0 Command) []*StateCtx { return nil } - slices.CompactFunc(stateCtxs, func(l, r *StateCtx) bool { + stateCtxs = slices.CompactFunc(stateCtxs, func(l, r *StateCtx) bool { return l.Current.ID == r.Current.ID && l.Current.Rev == r.Current.Rev }) @@ -2154,7 +2154,7 @@ func commandDatas(cmd0 Command) []*Data { return nil } - slices.CompactFunc(datas, func(l, r *Data) bool { + datas = slices.CompactFunc(datas, func(l, r *Data) bool { return l.ID == r.ID && l.Rev == r.Rev }) diff --git a/protobuf_test.go b/protobuf_test.go index af80387..6b05139 100644 --- a/protobuf_test.go +++ b/protobuf_test.go @@ -337,6 +337,23 @@ func TestMarshalUnmarshalCommand(t *testing.T) { f(&flowstate.CommitCommand{}) + stateCtx := &flowstate.StateCtx{ + Current: flowstate.State{ + ID: "theTransitID", + Rev: 123, + }, + } + f(&flowstate.CommitCommand{ + Commands: []flowstate.Command{ + &flowstate.TransitCommand{ + StateCtx: stateCtx, + }, + &flowstate.PauseCommand{ + StateCtx: stateCtx, + }, + }, + }) + f(&flowstate.CommitCommand{ Commands: []flowstate.Command{ &flowstate.TransitCommand{ @@ -359,6 +376,29 @@ func TestMarshalUnmarshalCommand(t *testing.T) { }, }) + data := &flowstate.Data{ID: "theDataID", Rev: 123, B: []byte("theDataValue")} + f(&flowstate.CommitCommand{ + Commands: []flowstate.Command{ + &flowstate.StoreDataCommand{ + Data: data, + }, + &flowstate.GetDataCommand{ + Data: data, + }, + }, + }) + + f(&flowstate.CommitCommand{ + Commands: []flowstate.Command{ + &flowstate.StoreDataCommand{ + Data: &flowstate.Data{ID: "theDataID", Rev: 123, B: []byte("theDataValue")}, + }, + &flowstate.StoreDataCommand{ + Data: &flowstate.Data{ID: "theOtherDataID", Rev: 234}, + }, + }, + }) + f(&flowstate.NoopCommand{}) f(&flowstate.NoopCommand{