diff --git a/.github/workflows/on-push-pr.yml b/.github/workflows/on-push-pr.yml index b74323b..4fb2ef1 100644 --- a/.github/workflows/on-push-pr.yml +++ b/.github/workflows/on-push-pr.yml @@ -83,3 +83,45 @@ jobs: with: name: release-zips path: "*.zip" + mysql-test: + runs-on: 'ubuntu-latest' + needs: format-build-test + services: + mysql: + image: mysql:8.0 + env: + MYSQL_RANDOM_ROOT_PASSWORD: yes + MYSQL_DATABASE: nanocmd + MYSQL_USER: nanocmd + MYSQL_PASSWORD: nanocmd + ports: + - 3800:3306 + options: --health-cmd="mysqladmin ping" --health-interval=5s --health-timeout=2s --health-retries=3 + defaults: + run: + shell: bash + env: + MYSQL_PWD: nanocmd + PORT: 3800 + steps: + - uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.1.0 + + - uses: actions/setup-go@6edd4406fa81c3da01a34fa6f6343087c207a568 # v3.5.0 + with: + go-version: '1.19.x' + + - name: verify mysql + run: | + while ! mysqladmin ping --host=localhost --port=$PORT --protocol=TCP --silent; do + sleep 1 + done + + - name: mysql schema + run: | + mysql --version + mysql --user=nanocmd --host=localhost --port=$PORT --protocol=TCP nanocmd < ./engine/storage/mysql/schema.sql + + - name: setup test dsn + run: echo "NANOCMD_MYSQL_STORAGE_TEST_DSN=nanocmd:nanocmd@tcp(localhost:$PORT)/nanocmd" >> $GITHUB_ENV + + - run: go test -v ./engine/storage/mysql diff --git a/cmd/nanocmd/storage.go b/cmd/nanocmd/storage.go index b002f3b..d4d8a9f 100644 --- a/cmd/nanocmd/storage.go +++ b/cmd/nanocmd/storage.go @@ -6,6 +6,7 @@ import ( storageeng "github.com/micromdm/nanocmd/engine/storage" storageengdiskv "github.com/micromdm/nanocmd/engine/storage/diskv" storageenginmem "github.com/micromdm/nanocmd/engine/storage/inmem" + storageengmysql "github.com/micromdm/nanocmd/engine/storage/mysql" storagecmdplan "github.com/micromdm/nanocmd/subsystem/cmdplan/storage" storagecmdplandiskv "github.com/micromdm/nanocmd/subsystem/cmdplan/storage/diskv" storagecmdplaninmem "github.com/micromdm/nanocmd/subsystem/cmdplan/storage/inmem" @@ -19,6 +20,8 @@ import ( storageprof "github.com/micromdm/nanocmd/subsystem/profile/storage" storageprofdiskv "github.com/micromdm/nanocmd/subsystem/profile/storage/diskv" storageprofinmem "github.com/micromdm/nanocmd/subsystem/profile/storage/inmem" + + _ "github.com/go-sql-driver/mysql" ) type storageConfig struct { @@ -65,6 +68,24 @@ func parseStorage(name, dsn string) (*storageConfig, error) { event: eng, filevault: fv, }, nil + case "mysql": + inv := storageinvinmem.New() + fv, err := storagefvinmem.New(storagefvinvprk.NewInvPRK(inv)) + if err != nil { + return nil, fmt.Errorf("creating filevault inmem storage: %w", err) + } + eng, err := storageengmysql.New(storageengmysql.WithDSN(dsn)) + if err != nil { + return nil, err + } + return &storageConfig{ + engine: eng, + inventory: inv, + profile: storageprofinmem.New(), + cmdplan: storagecmdplaninmem.New(), + event: eng, + filevault: fv, + }, nil } return nil, fmt.Errorf("unknown storage: %s", name) } diff --git a/docs/operations-guide.md b/docs/operations-guide.md index cc082d5..01ef433 100644 --- a/docs/operations-guide.md +++ b/docs/operations-guide.md @@ -88,6 +88,17 @@ Configures the `inmem` storage backend. Data is stored entirely in-memory and is *Example:* `-storage inmem` +##### mysql storage backend + +* `-storage mysql` + +Configures the MySQL storage backend. The `-storage-dsn` flag should be in the [format the SQL driver expects](https://github.com/go-sql-driver/mysql#dsn-data-source-name). +Be sure to create the storage tables with the [schema.sql](../storage/mysql/schema.sql) file. MySQL 8.0.19 or later is required. + +**WARNING:** The MySQL backend currently only implements storage for the workflow *engine*. When running NanoCMD the *subsystem* storage is completely in-memory as if you supplied `-storage inmem`. The practical effect is that subsystem storage is volatile and no data will be persisted for them. + +*Example:* `-storage mysql -dsn nanocmd:nanocmd/mycmddb` + #### -version * print version diff --git a/engine/convert.go b/engine/convert.go index 4337b95..1e35ac5 100644 --- a/engine/convert.go +++ b/engine/convert.go @@ -129,7 +129,7 @@ func storageStepCommandFromRawResponse(reqType string, rawResp []byte) (*storage ResultReport: rawResp, Completed: genResp.Status != "" && genResp.Status != "NotNow", } - return sc, response, nil + return sc, response, sc.Validate() } // workflowCommandResponseFromRawResponse converts a raw XML plist of a command response to a workflow response. diff --git a/engine/storage/kv/kv.go b/engine/storage/kv/kv.go index 12cc565..9daf0b9 100644 --- a/engine/storage/kv/kv.go +++ b/engine/storage/kv/kv.go @@ -3,6 +3,7 @@ package kv import ( "context" + "errors" "fmt" "sync" "time" @@ -33,6 +34,9 @@ func New(stepStore kv.TraversingBucket, idCmdStore kv.TraversingBucket, eventSto // RetrieveCommandRequestType implements the storage interface method. func (s *KV) RetrieveCommandRequestType(ctx context.Context, id string, cmdUUID string) (string, bool, error) { + if id == "" || cmdUUID == "" { + return "", false, errors.New("empty id or command uuid") + } s.mu.RLock() defer s.mu.RUnlock() // first check if we have a valid command @@ -157,6 +161,23 @@ func (s *KV) StoreStep(ctx context.Context, step *storage.StepEnqueuingWithConfi // fabricate a unique ID to track this unique step stepID := s.ider.ID() + if step != nil { + idCmdUUIDs := make(map[string]struct{}) + for _, sc := range step.Commands { + for _, id := range step.IDs { + if _, ok := idCmdUUIDs[id+sc.CommandUUID]; ok { + return fmt.Errorf("duplicate command (id=%s, uuid=%s)", id, sc.CommandUUID) + } + idCmdUUIDs[id+sc.CommandUUID] = struct{}{} + if ok, err := kvIDCmdExists(ctx, s.idCmdStore, id, sc.CommandUUID); err != nil { + return fmt.Errorf("checking duplicate commands: %w", err) + } else if ok { + return fmt.Errorf("duplicate command (id=%s, uuid=%s)", id, sc.CommandUUID) + } + } + } + } + err := kvSetStep(ctx, s.stepStore, stepID, step) if err != nil { return fmt.Errorf("setting step record: %w", err) diff --git a/engine/storage/mysql/event.go b/engine/storage/mysql/event.go new file mode 100644 index 0000000..0f5a90b --- /dev/null +++ b/engine/storage/mysql/event.go @@ -0,0 +1,74 @@ +package mysql + +import ( + "context" + "fmt" + + "github.com/micromdm/nanocmd/engine/storage" + "github.com/micromdm/nanocmd/workflow" +) + +// RetrieveEventSubscriptions retrieves event subscriptions by names. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveEventSubscriptions(ctx context.Context, names []string) (map[string]*storage.EventSubscription, error) { + events, err := s.q.GetEventsByNames(ctx, names) + if err != nil { + return nil, fmt.Errorf("get events by name: %w", err) + } + retEvents := make(map[string]*storage.EventSubscription) + for _, event := range events { + retEvents[event.EventName] = &storage.EventSubscription{ + Event: event.EventType, + Workflow: event.WorkflowName, + Context: event.Context.String, + } + } + return retEvents, nil +} + +// RetrieveEventSubscriptionsByEvent retrieves event subscriptions by event flag. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveEventSubscriptionsByEvent(ctx context.Context, f workflow.EventFlag) ([]*storage.EventSubscription, error) { + events, err := s.q.GetEventsByType(ctx, f.String()) + if err != nil { + return nil, fmt.Errorf("get events by type: %w", err) + } + var retEvents []*storage.EventSubscription + for _, event := range events { + retEvents = append(retEvents, &storage.EventSubscription{ + Event: event.EventType, + Workflow: event.WorkflowName, + Context: event.Context.String, + }) + } + return retEvents, nil +} + +// StoreEventSubscription stores an event subscription. +// See the storage interface type for further docs. +func (s *MySQLStorage) StoreEventSubscription(ctx context.Context, name string, es *storage.EventSubscription) error { + _, err := s.db.ExecContext( + ctx, + ` +INSERT INTO wf_events + (event_name, event_type, workflow_name, context) +VALUES + (?, ?, ?, ?) AS new +ON DUPLICATE KEY +UPDATE + workflow_name = new.workflow_name, + event_type = new.event_type, + context = new.context;`, + name, + es.Event, + es.Workflow, + sqlNullString(es.Context), + ) + return err +} + +// DeleteEventSubscription removes an event subscription. +// See the storage interface type for further docs. +func (s *MySQLStorage) DeleteEventSubscription(ctx context.Context, name string) error { + return s.q.RemoveEvent(ctx, name) +} diff --git a/engine/storage/mysql/generate.go b/engine/storage/mysql/generate.go new file mode 100644 index 0000000..973bbcf --- /dev/null +++ b/engine/storage/mysql/generate.go @@ -0,0 +1,3 @@ +package mysql + +//go:generate sqlc generate diff --git a/engine/storage/mysql/mysql.go b/engine/storage/mysql/mysql.go new file mode 100644 index 0000000..8212aca --- /dev/null +++ b/engine/storage/mysql/mysql.go @@ -0,0 +1,108 @@ +package mysql + +import ( + "context" + "database/sql" + "fmt" + "math/rand" + "sync" + "time" + + "github.com/micromdm/nanocmd/engine/storage/mysql/sqlc" +) + +// MySQLStorage implements a storage.AllStorage using MySQL. +type MySQLStorage struct { + db *sql.DB + q *sqlc.Queries + + randMu sync.Mutex + rand *rand.Rand +} + +type config struct { + driver string + dsn string + db *sql.DB +} + +// Option allows configuring a MySQLStorage. +type Option func(*config) + +// WithDSN sets the storage MySQL data source name. +func WithDSN(dsn string) Option { + return func(c *config) { + c.dsn = dsn + } +} + +// WithDriver sets a custom MySQL driver for the storage. +// Default driver is "mysql" but is ignored if WithDB is used. +func WithDriver(driver string) Option { + return func(c *config) { + c.driver = driver + } +} + +// WithDB sets a custom MySQL *sql.DB to the storage. +// If set, driver passed via WithDriver is ignored. +func WithDB(db *sql.DB) Option { + return func(c *config) { + c.db = db + } +} + +// New creates and returns a new MySQL. +func New(opts ...Option) (*MySQLStorage, error) { + cfg := &config{driver: "mysql"} + for _, opt := range opts { + opt(cfg) + } + var err error + if cfg.db == nil { + cfg.db, err = sql.Open(cfg.driver, cfg.dsn) + if err != nil { + return nil, err + } + } + if err = cfg.db.Ping(); err != nil { + return nil, err + } + return &MySQLStorage{ + db: cfg.db, + q: sqlc.New(cfg.db), + rand: rand.New(rand.NewSource(time.Now().UnixNano())), + }, nil +} + +// sqlNullString sets Valid to true of the return value of s is not empty. +func sqlNullString(s string) sql.NullString { + return sql.NullString{String: s, Valid: s != ""} +} + +// sqlNullTime sets Valid to true of the return value of t is not zero. +func sqlNullTime(t time.Time) sql.NullTime { + return sql.NullTime{Valid: !t.IsZero(), Time: t} +} + +// txcb executes SQL within transactions when wrapped in tx(). +type txcb func(ctx context.Context, tx *sql.Tx, qtx *sqlc.Queries) error + +// tx wraps g in transactions using db. +// If g returns an err the transaction will be rolled back; otherwise committed. +func tx(ctx context.Context, db *sql.DB, q *sqlc.Queries, g txcb) error { + tx, err := db.BeginTx(ctx, nil) + if err != nil { + return fmt.Errorf("tx begin: %w", err) + } + if err = g(ctx, tx, q.WithTx(tx)); err != nil { + if rbErr := tx.Rollback(); rbErr != nil { + return fmt.Errorf("tx rollback: %w; while trying to handle error: %v", rbErr, err) + } + return fmt.Errorf("tx rolled back: %w", err) + } + if err = tx.Commit(); err != nil { + return fmt.Errorf("tx commit: %w", err) + } + return nil +} diff --git a/engine/storage/mysql/mysql_test.go b/engine/storage/mysql/mysql_test.go new file mode 100644 index 0000000..ac2b7c2 --- /dev/null +++ b/engine/storage/mysql/mysql_test.go @@ -0,0 +1,25 @@ +package mysql + +import ( + "os" + "testing" + + "github.com/micromdm/nanocmd/engine/storage" + "github.com/micromdm/nanocmd/engine/storage/test" + + _ "github.com/go-sql-driver/mysql" +) + +func TestMySQLStorage(t *testing.T) { + testDSN := os.Getenv("NANOCMD_MYSQL_STORAGE_TEST_DSN") + if testDSN == "" { + t.Skip("NANOCMD_MYSQL_STORAGE_TEST_DSN not set") + } + + s, err := New(WithDSN(testDSN)) + if err != nil { + t.Fatal(err) + } + + test.TestEngineStorage(t, func() storage.AllStorage { return s }) +} diff --git a/engine/storage/mysql/query.sql b/engine/storage/mysql/query.sql new file mode 100644 index 0000000..b3d6036 --- /dev/null +++ b/engine/storage/mysql/query.sql @@ -0,0 +1,155 @@ +-- name: GetRequestType :one +SELECT + request_type +FROM + id_commands +WHERE + enrollment_id = ? AND + command_uuid = ?; + +-- name: CreateStep :execlastid +INSERT INTO steps + (workflow_name, instance_id, step_name, context, not_until, timeout) +VALUES + (?, ?, ?, ?, ?, ?); + +-- name: CreateIDCommand :exec +INSERT INTO id_commands + (enrollment_id, command_uuid, step_id, request_type, last_push) +VALUES + (?, ?, ?, ?, ?); + +-- name: DeleteIDCommandByWorkflow :exec +DELETE + c +FROM + id_commands c + INNER JOIN steps s + ON c.step_id = s.id +WHERE + c.enrollment_id = ? AND + s.workflow_name = ?; + +-- name: DeleteIDCommands :exec +DELETE FROM + id_commands +WHERE + enrollment_id = ?; + +-- name: DeleteUnusedStepCommands :exec +DELETE + sc +FROM + step_commands sc + LEFT JOIN id_commands c + ON sc.command_uuid = c.command_uuid +WHERE + c.command_uuid IS NULL; + +-- name: DeleteWorkflowStepHavingNoCommands :exec +DELETE + s +FROM + steps s + LEFT JOIN id_commands c + ON s.id = c.step_id +WHERE + c.step_id IS NULL; + +-- name: DeleteWorkflowStepHavingNoCommandsByWorkflowName :exec +DELETE + s +FROM + steps s + LEFT JOIN id_commands c + ON s.id = c.step_id +WHERE + c.step_id IS NULL AND + s.workflow_name = ?; + +-- name: UpdateIDCommandTimestamp :exec +UPDATE + id_commands +SET + updated_at = CURRENT_TIMESTAMP +WHERE + enrollment_id = ? AND + command_uuid = ? +LIMIT 1; + +-- name: UpdateIDCommand :exec +UPDATE + id_commands +SET + completed = ?, + result = ? +WHERE + enrollment_id = ? AND + command_uuid = ? +LIMIT 1; + +-- name: CountOutstandingIDWorkflowStepCommands :one +SELECT + COUNT(*), + c1.step_id +FROM + id_commands c1 + JOIN id_commands c2 + ON c1.step_id = c2.step_id +WHERE + c1.enrollment_id = ? AND + c1.completed = 0 AND + c2.enrollment_id = c1.enrollment_id AND + c2.command_uuid = ? +GROUP BY + c1.step_id +LIMIT 1; + +-- name: GetStepByID :one +SELECT + workflow_name, + instance_id, + step_name, + context +FROM + steps +WHERE + id = ?; + +-- name: GetIDCommandsByStepID :many +SELECT + command_uuid, + request_type, + result +FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ? AND + completed != 0; + +-- name: RemoveIDCommandsByStepID :exec +DELETE FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ?; + +-- name: CreateStepCommand :exec +INSERT INTO step_commands + (step_id, command_uuid, request_type, command) +VALUES + (?, ?, ?, ?); + +-- name: GetOutstandingIDs :many +SELECT DISTINCT + c.enrollment_id +FROM + id_commands c + JOIN steps s + ON s.id = c.step_id +WHERE + c.enrollment_id IN (sqlc.slice('ids')) AND + c.completed = 0 AND + s.workflow_name = ?; + diff --git a/engine/storage/mysql/query_event.sql b/engine/storage/mysql/query_event.sql new file mode 100644 index 0000000..ff161e8 --- /dev/null +++ b/engine/storage/mysql/query_event.sql @@ -0,0 +1,24 @@ +-- name: GetEventsByNames :many +SELECT + event_name, + context, + workflow_name, + event_type +FROM + wf_events +WHERE + event_name IN (sqlc.slice('names')); + +-- name: GetEventsByType :many +SELECT + context, + workflow_name, + event_type +FROM + wf_events +WHERE + event_type = ?; + +-- name: RemoveEvent :exec +DELETE FROM wf_events WHERE event_name = ?; + diff --git a/engine/storage/mysql/query_worker.sql b/engine/storage/mysql/query_worker.sql new file mode 100644 index 0000000..1187d80 --- /dev/null +++ b/engine/storage/mysql/query_worker.sql @@ -0,0 +1,131 @@ +-- name: UpdateStepAfterNotUntil :exec +UPDATE + steps +SET + process_id = ? +WHERE + process_id IS NULL AND + not_until < ?; + +-- name: GetStepsByProcessID :many +SELECT + id, + workflow_name, + instance_id, + step_name +FROM + steps +WHERE + process_id = ?; + +-- name: GetStepCommandsByProcessID :many +SELECT + sc.step_id, + sc.command_uuid, + sc.request_type, + sc.command +FROM + step_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ?; + +-- name: GetIDCommandIDsByProcessID :many +SELECT + step_id, + enrollment_id +FROM + id_commands c + JOIN steps s + ON c.step_id = s.id +WHERE + s.process_id = ?; + +-- name: RemoveStepCommandsByProcessID :exec +DELETE sc FROM + step_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ?; + +-- name: UpdateLastPushByProcessID :exec +UPDATE + id_commands c + JOIN steps s + ON c.step_id = s.id +SET + c.last_push = CURRENT_TIMESTAMP +WHERE + s.process_id = ?; + +-- name: UpdateStepAfterTimeout :exec +UPDATE + steps +SET + process_id = ? +WHERE + process_id IS NULL AND + timeout <= ?; + +-- name: GetStepsWithContextByProcessID :many +SELECT + id, + workflow_name, + instance_id, + step_name, + context +FROM + steps +WHERE + process_id = ?; + +-- name: GetIDCommandDetailsByProcessID :many +SELECT + step_id, + enrollment_id, + command_uuid, + request_type, + completed, + result +FROM + id_commands c + JOIN steps s + ON c.step_id = s.id +WHERE + s.process_id = ? +ORDER BY + step_id, enrollment_id; + +-- name: RemoveIDCommandsByProcessID :exec +DELETE sc FROM + id_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ?; + +-- name: RemoveStepsByProcessID :exec +DELETE FROM + steps +WHERE + process_id = ?; + +-- name: GetRePushIDs :many +SELECT DISTINCT + enrollment_id +FROM + id_commands +WHERE + last_push IS NOT NULL AND + last_push < sqlc.arg(before); + +-- name: UpdateRePushIDs :exec +UPDATE + id_commands +SET + last_push = ? +WHERE + last_push IS NOT NULL AND + last_push < sqlc.arg(before); diff --git a/engine/storage/mysql/schema.sql b/engine/storage/mysql/schema.sql new file mode 100644 index 0000000..96213c4 --- /dev/null +++ b/engine/storage/mysql/schema.sql @@ -0,0 +1,80 @@ +CREATE TABLE steps ( + id BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY, + + workflow_name VARCHAR(255) NOT NULL, + instance_id VARCHAR(255) NOT NULL, + step_name VARCHAR(255) NULL, + + context MEDIUMTEXT NULL, + + not_until TIMESTAMP NULL, + timeout TIMESTAMP NULL, + + process_id CHAR(45) NULL, + + INDEX (workflow_name), + INDEX (not_until), + INDEX (timeout), + INDEX (process_id), + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP +); + +CREATE TABLE id_commands ( + enrollment_id VARCHAR(255) NOT NULL, + command_uuid VARCHAR(127) NOT NULL, + + step_id BIGINT NOT NULL, + + request_type VARCHAR(63) NOT NULL, + completed BOOLEAN NOT NULL DEFAULT 0, + result MEDIUMTEXT NULL, + + last_push TIMESTAMP NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + INDEX (enrollment_id), + INDEX (command_uuid), + INDEX (step_id), + INDEX (completed), + INDEX (last_push), + + FOREIGN KEY (step_id) + REFERENCES steps (id), + + PRIMARY KEY (enrollment_id, command_uuid) +); + +CREATE TABLE step_commands ( + command_uuid VARCHAR(127) NOT NULL, + step_id BIGINT NOT NULL, + + command MEDIUMTEXT NOT NULL, + request_type VARCHAR(63) NOT NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + FOREIGN KEY (step_id) + REFERENCES steps (id), + + PRIMARY KEY (command_uuid, step_id) +); + +CREATE TABLE wf_events ( + event_name VARCHAR(255) NOT NULL, + + context MEDIUMTEXT NULL, + workflow_name VARCHAR(255) NOT NULL, + event_type VARCHAR(63) NOT NULL, + + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, + + INDEX (event_type), + + PRIMARY KEY (event_name) +); \ No newline at end of file diff --git a/engine/storage/mysql/sqlc.yaml b/engine/storage/mysql/sqlc.yaml new file mode 100644 index 0000000..ac70b4f --- /dev/null +++ b/engine/storage/mysql/sqlc.yaml @@ -0,0 +1,25 @@ +version: 2 +sql: + - engine: "mysql" + queries: + - "query.sql" + - "query_event.sql" + - "query_worker.sql" + schema: "schema.sql" + gen: + go: + package: "sqlc" + out: "sqlc" + overrides: + - column: "steps.context" + go_type: + type: "byte" + slice: true + - column: "step_commands.command" + go_type: + type: "byte" + slice: true + - column: "id_commands.result" + go_type: + type: "byte" + slice: true diff --git a/engine/storage/mysql/sqlc/db.go b/engine/storage/mysql/sqlc/db.go new file mode 100644 index 0000000..6a77d41 --- /dev/null +++ b/engine/storage/mysql/sqlc/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 + +package sqlc + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/engine/storage/mysql/sqlc/models.go b/engine/storage/mysql/sqlc/models.go new file mode 100644 index 0000000..af8de69 --- /dev/null +++ b/engine/storage/mysql/sqlc/models.go @@ -0,0 +1,52 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 + +package sqlc + +import ( + "database/sql" +) + +type IDCommand struct { + EnrollmentID string + CommandUuid string + StepID int64 + RequestType string + Completed bool + Result []byte + LastPush sql.NullTime + CreatedAt sql.NullTime + UpdatedAt sql.NullTime +} + +type Step struct { + ID int64 + WorkflowName string + InstanceID string + StepName sql.NullString + Context []byte + NotUntil sql.NullTime + Timeout sql.NullTime + ProcessID sql.NullString + CreatedAt sql.NullTime + UpdatedAt sql.NullTime +} + +type StepCommand struct { + CommandUuid string + StepID int64 + Command []byte + RequestType string + CreatedAt sql.NullTime + UpdatedAt sql.NullTime +} + +type WfEvent struct { + EventName string + Context sql.NullString + WorkflowName string + EventType string + CreatedAt sql.NullTime + UpdatedAt sql.NullTime +} diff --git a/engine/storage/mysql/sqlc/query.sql.go b/engine/storage/mysql/sqlc/query.sql.go new file mode 100644 index 0000000..cd7e2aa --- /dev/null +++ b/engine/storage/mysql/sqlc/query.sql.go @@ -0,0 +1,431 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 +// source: query.sql + +package sqlc + +import ( + "context" + "database/sql" + "strings" +) + +const countOutstandingIDWorkflowStepCommands = `-- name: CountOutstandingIDWorkflowStepCommands :one +SELECT + COUNT(*), + c1.step_id +FROM + id_commands c1 + JOIN id_commands c2 + ON c1.step_id = c2.step_id +WHERE + c1.enrollment_id = ? AND + c1.completed = 0 AND + c2.enrollment_id = c1.enrollment_id AND + c2.command_uuid = ? +GROUP BY + c1.step_id +LIMIT 1 +` + +type CountOutstandingIDWorkflowStepCommandsParams struct { + EnrollmentID string + CommandUuid string +} + +type CountOutstandingIDWorkflowStepCommandsRow struct { + Count int64 + StepID int64 +} + +func (q *Queries) CountOutstandingIDWorkflowStepCommands(ctx context.Context, arg CountOutstandingIDWorkflowStepCommandsParams) (CountOutstandingIDWorkflowStepCommandsRow, error) { + row := q.db.QueryRowContext(ctx, countOutstandingIDWorkflowStepCommands, arg.EnrollmentID, arg.CommandUuid) + var i CountOutstandingIDWorkflowStepCommandsRow + err := row.Scan(&i.Count, &i.StepID) + return i, err +} + +const createIDCommand = `-- name: CreateIDCommand :exec +INSERT INTO id_commands + (enrollment_id, command_uuid, step_id, request_type, last_push) +VALUES + (?, ?, ?, ?, ?) +` + +type CreateIDCommandParams struct { + EnrollmentID string + CommandUuid string + StepID int64 + RequestType string + LastPush sql.NullTime +} + +func (q *Queries) CreateIDCommand(ctx context.Context, arg CreateIDCommandParams) error { + _, err := q.db.ExecContext(ctx, createIDCommand, + arg.EnrollmentID, + arg.CommandUuid, + arg.StepID, + arg.RequestType, + arg.LastPush, + ) + return err +} + +const createStep = `-- name: CreateStep :execlastid +INSERT INTO steps + (workflow_name, instance_id, step_name, context, not_until, timeout) +VALUES + (?, ?, ?, ?, ?, ?) +` + +type CreateStepParams struct { + WorkflowName string + InstanceID string + StepName sql.NullString + Context []byte + NotUntil sql.NullTime + Timeout sql.NullTime +} + +func (q *Queries) CreateStep(ctx context.Context, arg CreateStepParams) (int64, error) { + result, err := q.db.ExecContext(ctx, createStep, + arg.WorkflowName, + arg.InstanceID, + arg.StepName, + arg.Context, + arg.NotUntil, + arg.Timeout, + ) + if err != nil { + return 0, err + } + return result.LastInsertId() +} + +const createStepCommand = `-- name: CreateStepCommand :exec +INSERT INTO step_commands + (step_id, command_uuid, request_type, command) +VALUES + (?, ?, ?, ?) +` + +type CreateStepCommandParams struct { + StepID int64 + CommandUuid string + RequestType string + Command []byte +} + +func (q *Queries) CreateStepCommand(ctx context.Context, arg CreateStepCommandParams) error { + _, err := q.db.ExecContext(ctx, createStepCommand, + arg.StepID, + arg.CommandUuid, + arg.RequestType, + arg.Command, + ) + return err +} + +const deleteIDCommandByWorkflow = `-- name: DeleteIDCommandByWorkflow :exec +DELETE + c +FROM + id_commands c + INNER JOIN steps s + ON c.step_id = s.id +WHERE + c.enrollment_id = ? AND + s.workflow_name = ? +` + +type DeleteIDCommandByWorkflowParams struct { + EnrollmentID string + WorkflowName string +} + +func (q *Queries) DeleteIDCommandByWorkflow(ctx context.Context, arg DeleteIDCommandByWorkflowParams) error { + _, err := q.db.ExecContext(ctx, deleteIDCommandByWorkflow, arg.EnrollmentID, arg.WorkflowName) + return err +} + +const deleteIDCommands = `-- name: DeleteIDCommands :exec +DELETE FROM + id_commands +WHERE + enrollment_id = ? +` + +func (q *Queries) DeleteIDCommands(ctx context.Context, enrollmentID string) error { + _, err := q.db.ExecContext(ctx, deleteIDCommands, enrollmentID) + return err +} + +const deleteUnusedStepCommands = `-- name: DeleteUnusedStepCommands :exec +DELETE + sc +FROM + step_commands sc + LEFT JOIN id_commands c + ON sc.command_uuid = c.command_uuid +WHERE + c.command_uuid IS NULL +` + +func (q *Queries) DeleteUnusedStepCommands(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, deleteUnusedStepCommands) + return err +} + +const deleteWorkflowStepHavingNoCommands = `-- name: DeleteWorkflowStepHavingNoCommands :exec +DELETE + s +FROM + steps s + LEFT JOIN id_commands c + ON s.id = c.step_id +WHERE + c.step_id IS NULL +` + +func (q *Queries) DeleteWorkflowStepHavingNoCommands(ctx context.Context) error { + _, err := q.db.ExecContext(ctx, deleteWorkflowStepHavingNoCommands) + return err +} + +const deleteWorkflowStepHavingNoCommandsByWorkflowName = `-- name: DeleteWorkflowStepHavingNoCommandsByWorkflowName :exec +DELETE + s +FROM + steps s + LEFT JOIN id_commands c + ON s.id = c.step_id +WHERE + c.step_id IS NULL AND + s.workflow_name = ? +` + +func (q *Queries) DeleteWorkflowStepHavingNoCommandsByWorkflowName(ctx context.Context, workflowName string) error { + _, err := q.db.ExecContext(ctx, deleteWorkflowStepHavingNoCommandsByWorkflowName, workflowName) + return err +} + +const getIDCommandsByStepID = `-- name: GetIDCommandsByStepID :many +SELECT + command_uuid, + request_type, + result +FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ? AND + completed != 0 +` + +type GetIDCommandsByStepIDParams struct { + EnrollmentID string + StepID int64 +} + +type GetIDCommandsByStepIDRow struct { + CommandUuid string + RequestType string + Result []byte +} + +func (q *Queries) GetIDCommandsByStepID(ctx context.Context, arg GetIDCommandsByStepIDParams) ([]GetIDCommandsByStepIDRow, error) { + rows, err := q.db.QueryContext(ctx, getIDCommandsByStepID, arg.EnrollmentID, arg.StepID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetIDCommandsByStepIDRow + for rows.Next() { + var i GetIDCommandsByStepIDRow + if err := rows.Scan(&i.CommandUuid, &i.RequestType, &i.Result); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getOutstandingIDs = `-- name: GetOutstandingIDs :many +SELECT DISTINCT + c.enrollment_id +FROM + id_commands c + JOIN steps s + ON s.id = c.step_id +WHERE + c.enrollment_id IN (/*SLICE:ids*/?) AND + c.completed = 0 AND + s.workflow_name = ? +` + +type GetOutstandingIDsParams struct { + Ids []string + WorkflowName string +} + +func (q *Queries) GetOutstandingIDs(ctx context.Context, arg GetOutstandingIDsParams) ([]string, error) { + query := getOutstandingIDs + var queryParams []interface{} + if len(arg.Ids) > 0 { + for _, v := range arg.Ids { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:ids*/?", strings.Repeat(",?", len(arg.Ids))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:ids*/?", "NULL", 1) + } + queryParams = append(queryParams, arg.WorkflowName) + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var enrollment_id string + if err := rows.Scan(&enrollment_id); err != nil { + return nil, err + } + items = append(items, enrollment_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getRequestType = `-- name: GetRequestType :one +SELECT + request_type +FROM + id_commands +WHERE + enrollment_id = ? AND + command_uuid = ? +` + +type GetRequestTypeParams struct { + EnrollmentID string + CommandUuid string +} + +func (q *Queries) GetRequestType(ctx context.Context, arg GetRequestTypeParams) (string, error) { + row := q.db.QueryRowContext(ctx, getRequestType, arg.EnrollmentID, arg.CommandUuid) + var request_type string + err := row.Scan(&request_type) + return request_type, err +} + +const getStepByID = `-- name: GetStepByID :one +SELECT + workflow_name, + instance_id, + step_name, + context +FROM + steps +WHERE + id = ? +` + +type GetStepByIDRow struct { + WorkflowName string + InstanceID string + StepName sql.NullString + Context []byte +} + +func (q *Queries) GetStepByID(ctx context.Context, id int64) (GetStepByIDRow, error) { + row := q.db.QueryRowContext(ctx, getStepByID, id) + var i GetStepByIDRow + err := row.Scan( + &i.WorkflowName, + &i.InstanceID, + &i.StepName, + &i.Context, + ) + return i, err +} + +const removeIDCommandsByStepID = `-- name: RemoveIDCommandsByStepID :exec +DELETE FROM + id_commands +WHERE + enrollment_id = ? AND + step_id = ? +` + +type RemoveIDCommandsByStepIDParams struct { + EnrollmentID string + StepID int64 +} + +func (q *Queries) RemoveIDCommandsByStepID(ctx context.Context, arg RemoveIDCommandsByStepIDParams) error { + _, err := q.db.ExecContext(ctx, removeIDCommandsByStepID, arg.EnrollmentID, arg.StepID) + return err +} + +const updateIDCommand = `-- name: UpdateIDCommand :exec +UPDATE + id_commands +SET + completed = ?, + result = ? +WHERE + enrollment_id = ? AND + command_uuid = ? +LIMIT 1 +` + +type UpdateIDCommandParams struct { + Completed bool + Result []byte + EnrollmentID string + CommandUuid string +} + +func (q *Queries) UpdateIDCommand(ctx context.Context, arg UpdateIDCommandParams) error { + _, err := q.db.ExecContext(ctx, updateIDCommand, + arg.Completed, + arg.Result, + arg.EnrollmentID, + arg.CommandUuid, + ) + return err +} + +const updateIDCommandTimestamp = `-- name: UpdateIDCommandTimestamp :exec +UPDATE + id_commands +SET + updated_at = CURRENT_TIMESTAMP +WHERE + enrollment_id = ? AND + command_uuid = ? +LIMIT 1 +` + +type UpdateIDCommandTimestampParams struct { + EnrollmentID string + CommandUuid string +} + +func (q *Queries) UpdateIDCommandTimestamp(ctx context.Context, arg UpdateIDCommandTimestampParams) error { + _, err := q.db.ExecContext(ctx, updateIDCommandTimestamp, arg.EnrollmentID, arg.CommandUuid) + return err +} diff --git a/engine/storage/mysql/sqlc/query_event.sql.go b/engine/storage/mysql/sqlc/query_event.sql.go new file mode 100644 index 0000000..37ed716 --- /dev/null +++ b/engine/storage/mysql/sqlc/query_event.sql.go @@ -0,0 +1,118 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 +// source: query_event.sql + +package sqlc + +import ( + "context" + "database/sql" + "strings" +) + +const getEventsByNames = `-- name: GetEventsByNames :many +SELECT + event_name, + context, + workflow_name, + event_type +FROM + wf_events +WHERE + event_name IN (/*SLICE:names*/?) +` + +type GetEventsByNamesRow struct { + EventName string + Context sql.NullString + WorkflowName string + EventType string +} + +func (q *Queries) GetEventsByNames(ctx context.Context, names []string) ([]GetEventsByNamesRow, error) { + query := getEventsByNames + var queryParams []interface{} + if len(names) > 0 { + for _, v := range names { + queryParams = append(queryParams, v) + } + query = strings.Replace(query, "/*SLICE:names*/?", strings.Repeat(",?", len(names))[1:], 1) + } else { + query = strings.Replace(query, "/*SLICE:names*/?", "NULL", 1) + } + rows, err := q.db.QueryContext(ctx, query, queryParams...) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetEventsByNamesRow + for rows.Next() { + var i GetEventsByNamesRow + if err := rows.Scan( + &i.EventName, + &i.Context, + &i.WorkflowName, + &i.EventType, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getEventsByType = `-- name: GetEventsByType :many +SELECT + context, + workflow_name, + event_type +FROM + wf_events +WHERE + event_type = ? +` + +type GetEventsByTypeRow struct { + Context sql.NullString + WorkflowName string + EventType string +} + +func (q *Queries) GetEventsByType(ctx context.Context, eventType string) ([]GetEventsByTypeRow, error) { + rows, err := q.db.QueryContext(ctx, getEventsByType, eventType) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetEventsByTypeRow + for rows.Next() { + var i GetEventsByTypeRow + if err := rows.Scan(&i.Context, &i.WorkflowName, &i.EventType); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeEvent = `-- name: RemoveEvent :exec +DELETE FROM wf_events WHERE event_name = ? +` + +func (q *Queries) RemoveEvent(ctx context.Context, eventName string) error { + _, err := q.db.ExecContext(ctx, removeEvent, eventName) + return err +} diff --git a/engine/storage/mysql/sqlc/query_worker.sql.go b/engine/storage/mysql/sqlc/query_worker.sql.go new file mode 100644 index 0000000..785e25f --- /dev/null +++ b/engine/storage/mysql/sqlc/query_worker.sql.go @@ -0,0 +1,403 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.21.0 +// source: query_worker.sql + +package sqlc + +import ( + "context" + "database/sql" +) + +const getIDCommandDetailsByProcessID = `-- name: GetIDCommandDetailsByProcessID :many +SELECT + step_id, + enrollment_id, + command_uuid, + request_type, + completed, + result +FROM + id_commands c + JOIN steps s + ON c.step_id = s.id +WHERE + s.process_id = ? +ORDER BY + step_id, enrollment_id +` + +type GetIDCommandDetailsByProcessIDRow struct { + StepID int64 + EnrollmentID string + CommandUuid string + RequestType string + Completed bool + Result []byte +} + +func (q *Queries) GetIDCommandDetailsByProcessID(ctx context.Context, processID sql.NullString) ([]GetIDCommandDetailsByProcessIDRow, error) { + rows, err := q.db.QueryContext(ctx, getIDCommandDetailsByProcessID, processID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetIDCommandDetailsByProcessIDRow + for rows.Next() { + var i GetIDCommandDetailsByProcessIDRow + if err := rows.Scan( + &i.StepID, + &i.EnrollmentID, + &i.CommandUuid, + &i.RequestType, + &i.Completed, + &i.Result, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getIDCommandIDsByProcessID = `-- name: GetIDCommandIDsByProcessID :many +SELECT + step_id, + enrollment_id +FROM + id_commands c + JOIN steps s + ON c.step_id = s.id +WHERE + s.process_id = ? +` + +type GetIDCommandIDsByProcessIDRow struct { + StepID int64 + EnrollmentID string +} + +func (q *Queries) GetIDCommandIDsByProcessID(ctx context.Context, processID sql.NullString) ([]GetIDCommandIDsByProcessIDRow, error) { + rows, err := q.db.QueryContext(ctx, getIDCommandIDsByProcessID, processID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetIDCommandIDsByProcessIDRow + for rows.Next() { + var i GetIDCommandIDsByProcessIDRow + if err := rows.Scan(&i.StepID, &i.EnrollmentID); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getRePushIDs = `-- name: GetRePushIDs :many +SELECT DISTINCT + enrollment_id +FROM + id_commands +WHERE + last_push IS NOT NULL AND + last_push < ? +` + +func (q *Queries) GetRePushIDs(ctx context.Context, before sql.NullTime) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getRePushIDs, before) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var enrollment_id string + if err := rows.Scan(&enrollment_id); err != nil { + return nil, err + } + items = append(items, enrollment_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getStepCommandsByProcessID = `-- name: GetStepCommandsByProcessID :many +SELECT + sc.step_id, + sc.command_uuid, + sc.request_type, + sc.command +FROM + step_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ? +` + +type GetStepCommandsByProcessIDRow struct { + StepID int64 + CommandUuid string + RequestType string + Command []byte +} + +func (q *Queries) GetStepCommandsByProcessID(ctx context.Context, processID sql.NullString) ([]GetStepCommandsByProcessIDRow, error) { + rows, err := q.db.QueryContext(ctx, getStepCommandsByProcessID, processID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetStepCommandsByProcessIDRow + for rows.Next() { + var i GetStepCommandsByProcessIDRow + if err := rows.Scan( + &i.StepID, + &i.CommandUuid, + &i.RequestType, + &i.Command, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getStepsByProcessID = `-- name: GetStepsByProcessID :many +SELECT + id, + workflow_name, + instance_id, + step_name +FROM + steps +WHERE + process_id = ? +` + +type GetStepsByProcessIDRow struct { + ID int64 + WorkflowName string + InstanceID string + StepName sql.NullString +} + +func (q *Queries) GetStepsByProcessID(ctx context.Context, processID sql.NullString) ([]GetStepsByProcessIDRow, error) { + rows, err := q.db.QueryContext(ctx, getStepsByProcessID, processID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetStepsByProcessIDRow + for rows.Next() { + var i GetStepsByProcessIDRow + if err := rows.Scan( + &i.ID, + &i.WorkflowName, + &i.InstanceID, + &i.StepName, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getStepsWithContextByProcessID = `-- name: GetStepsWithContextByProcessID :many +SELECT + id, + workflow_name, + instance_id, + step_name, + context +FROM + steps +WHERE + process_id = ? +` + +type GetStepsWithContextByProcessIDRow struct { + ID int64 + WorkflowName string + InstanceID string + StepName sql.NullString + Context []byte +} + +func (q *Queries) GetStepsWithContextByProcessID(ctx context.Context, processID sql.NullString) ([]GetStepsWithContextByProcessIDRow, error) { + rows, err := q.db.QueryContext(ctx, getStepsWithContextByProcessID, processID) + if err != nil { + return nil, err + } + defer rows.Close() + var items []GetStepsWithContextByProcessIDRow + for rows.Next() { + var i GetStepsWithContextByProcessIDRow + if err := rows.Scan( + &i.ID, + &i.WorkflowName, + &i.InstanceID, + &i.StepName, + &i.Context, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const removeIDCommandsByProcessID = `-- name: RemoveIDCommandsByProcessID :exec +DELETE sc FROM + id_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ? +` + +func (q *Queries) RemoveIDCommandsByProcessID(ctx context.Context, processID sql.NullString) error { + _, err := q.db.ExecContext(ctx, removeIDCommandsByProcessID, processID) + return err +} + +const removeStepCommandsByProcessID = `-- name: RemoveStepCommandsByProcessID :exec +DELETE sc FROM + step_commands sc + JOIN steps s + ON sc.step_id = s.id +WHERE + s.process_id = ? +` + +func (q *Queries) RemoveStepCommandsByProcessID(ctx context.Context, processID sql.NullString) error { + _, err := q.db.ExecContext(ctx, removeStepCommandsByProcessID, processID) + return err +} + +const removeStepsByProcessID = `-- name: RemoveStepsByProcessID :exec +DELETE FROM + steps +WHERE + process_id = ? +` + +func (q *Queries) RemoveStepsByProcessID(ctx context.Context, processID sql.NullString) error { + _, err := q.db.ExecContext(ctx, removeStepsByProcessID, processID) + return err +} + +const updateLastPushByProcessID = `-- name: UpdateLastPushByProcessID :exec +UPDATE + id_commands c + JOIN steps s + ON c.step_id = s.id +SET + c.last_push = CURRENT_TIMESTAMP +WHERE + s.process_id = ? +` + +func (q *Queries) UpdateLastPushByProcessID(ctx context.Context, processID sql.NullString) error { + _, err := q.db.ExecContext(ctx, updateLastPushByProcessID, processID) + return err +} + +const updateRePushIDs = `-- name: UpdateRePushIDs :exec +UPDATE + id_commands +SET + last_push = ? +WHERE + last_push IS NOT NULL AND + last_push < ? +` + +type UpdateRePushIDsParams struct { + LastPush sql.NullTime + Before sql.NullTime +} + +func (q *Queries) UpdateRePushIDs(ctx context.Context, arg UpdateRePushIDsParams) error { + _, err := q.db.ExecContext(ctx, updateRePushIDs, arg.LastPush, arg.Before) + return err +} + +const updateStepAfterNotUntil = `-- name: UpdateStepAfterNotUntil :exec +UPDATE + steps +SET + process_id = ? +WHERE + process_id IS NULL AND + not_until < ? +` + +type UpdateStepAfterNotUntilParams struct { + ProcessID sql.NullString + NotUntil sql.NullTime +} + +func (q *Queries) UpdateStepAfterNotUntil(ctx context.Context, arg UpdateStepAfterNotUntilParams) error { + _, err := q.db.ExecContext(ctx, updateStepAfterNotUntil, arg.ProcessID, arg.NotUntil) + return err +} + +const updateStepAfterTimeout = `-- name: UpdateStepAfterTimeout :exec +UPDATE + steps +SET + process_id = ? +WHERE + process_id IS NULL AND + timeout <= ? +` + +type UpdateStepAfterTimeoutParams struct { + ProcessID sql.NullString + Timeout sql.NullTime +} + +func (q *Queries) UpdateStepAfterTimeout(ctx context.Context, arg UpdateStepAfterTimeoutParams) error { + _, err := q.db.ExecContext(ctx, updateStepAfterTimeout, arg.ProcessID, arg.Timeout) + return err +} diff --git a/engine/storage/mysql/storage.go b/engine/storage/mysql/storage.go new file mode 100644 index 0000000..0ce8271 --- /dev/null +++ b/engine/storage/mysql/storage.go @@ -0,0 +1,243 @@ +package mysql + +import ( + "context" + "database/sql" + "errors" + "fmt" + "time" + + "github.com/micromdm/nanocmd/engine/storage" + "github.com/micromdm/nanocmd/engine/storage/mysql/sqlc" +) + +// RetrieveCommandRequestType retrieves a command request type given id and uuid. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveCommandRequestType(ctx context.Context, id string, uuid string) (string, bool, error) { + if id == "" || uuid == "" { + return "", false, errors.New("empty id or command uuid") + } + reqType, err := s.q.GetRequestType(ctx, sqlc.GetRequestTypeParams{EnrollmentID: id, CommandUuid: uuid}) + if errors.Is(err, sql.ErrNoRows) { + return "", false, nil + } + return reqType, reqType != "", err +} + +// StoreCommandResponseAndRetrieveCompletedStep stores a command response and returns the completed step for the id. +// See the storage interface type for further docs. +func (s *MySQLStorage) StoreCommandResponseAndRetrieveCompletedStep(ctx context.Context, id string, sc *storage.StepCommandResult) (*storage.StepResult, error) { + if sc == nil { + return nil, errors.New("nil storage command") + } + if !sc.Completed { + // if this command is not completed (i.e. NotNow) then the step cannot be completed, either. + err := s.q.UpdateIDCommandTimestamp(ctx, sqlc.UpdateIDCommandTimestampParams{ + EnrollmentID: id, + CommandUuid: sc.CommandUUID, + }) + if err != nil { + err = fmt.Errorf("updating id command timestamp: %w", err) + } + return nil, err + } + + cmdCt, err := s.q.CountOutstandingIDWorkflowStepCommands( + ctx, + sqlc.CountOutstandingIDWorkflowStepCommandsParams{ + EnrollmentID: id, + CommandUuid: sc.CommandUUID, + }, + ) + if err != nil { + return nil, fmt.Errorf("counting outstanding id workflow steps: %w", err) + } + if cmdCt.StepID < 1 { + return nil, fmt.Errorf("no step ID found (id=%s, uuid=%s)", id, sc.CommandUUID) + } + + if cmdCt.Count > 1 { + // if there are other uncompleted commands for us for this step + // then just update this commands results for another command + // to come in. + err = s.q.UpdateIDCommand(ctx, sqlc.UpdateIDCommandParams{ + Completed: sc.Completed, + Result: sc.ResultReport, + // where + EnrollmentID: id, + CommandUuid: sc.CommandUUID, + }) + if err != nil { + return nil, fmt.Errorf("updating id command: %w", err) + } + return nil, nil + } + + // reaching here implies this is the last command to be completed + // for the workflow step, for this instance ID for this enrollment ID. + + var ret *storage.StepResult + + err = tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + sd, err := qtx.GetStepByID(ctx, cmdCt.StepID) + if err != nil { + return fmt.Errorf("get step by id (%d): %w", cmdCt.StepID, err) + } + + ret = &storage.StepResult{ + IDs: []string{id}, + StepContext: storage.StepContext{ + WorkflowName: sd.WorkflowName, + InstanceID: sd.InstanceID, + Name: sd.StepName.String, + Context: sd.Context, + }, + // this command result + Commands: []storage.StepCommandResult{*sc}, + } + + // TODO: select ... for update on id commands? + + cmdR, err := qtx.GetIDCommandsByStepID(ctx, sqlc.GetIDCommandsByStepIDParams{ + EnrollmentID: id, + StepID: cmdCt.StepID, + }) + if err != nil { + return fmt.Errorf("get id commands by step by id (%d): %w", cmdCt.StepID, err) + } + + for _, dbSC := range cmdR { + ret.Commands = append(ret.Commands, storage.StepCommandResult{ + RequestType: dbSC.RequestType, + CommandUUID: dbSC.CommandUuid, + ResultReport: dbSC.Result, + Completed: true, + }) + } + + err = qtx.RemoveIDCommandsByStepID(ctx, sqlc.RemoveIDCommandsByStepIDParams{ + EnrollmentID: id, + StepID: cmdCt.StepID, + }) + if err != nil { + return fmt.Errorf("remove id commands by step by id (%d): %w", cmdCt.StepID, err) + } + + err = qtx.DeleteWorkflowStepHavingNoCommandsByWorkflowName(ctx, sd.WorkflowName) + if err != nil { + return fmt.Errorf("delete workflow with no commands (%s): %w", sd.WorkflowName, err) + } + + return nil + }) + if err != nil { + return ret, fmt.Errorf("tx step completed: %w", err) + } + return ret, nil +} + +// StoreStep stores a step and its commands for later state tracking. +// See the storage interface type for further docs. +func (s *MySQLStorage) StoreStep(ctx context.Context, step *storage.StepEnqueuingWithConfig, pushTime time.Time) error { + err := step.Validate() + if err != nil { + return fmt.Errorf("validating step: %w", err) + } + return tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + params := sqlc.CreateStepParams{ + WorkflowName: step.WorkflowName, + InstanceID: step.InstanceID, + StepName: sqlNullString(step.Name), + NotUntil: sqlNullTime(step.NotUntil), + Timeout: sqlNullTime(step.Timeout), + } + stepID, err := qtx.CreateStep(ctx, params) + if err != nil { + return fmt.Errorf("creating step: %w", err) + } + + for _, sc := range step.Commands { + if !step.NotUntil.IsZero() { + err = qtx.CreateStepCommand(ctx, sqlc.CreateStepCommandParams{ + StepID: stepID, + CommandUuid: sc.CommandUUID, + RequestType: sc.RequestType, + Command: sc.Command, + }) + if err != nil { + return fmt.Errorf("creating step command: %w", err) + } + } + for _, id := range step.IDs { + params := sqlc.CreateIDCommandParams{ + EnrollmentID: id, + CommandUuid: sc.CommandUUID, + RequestType: sc.RequestType, + StepID: stepID, + } + if step.NotUntil.IsZero() { + // assume we've successfully pushed + params.LastPush = sql.NullTime{Valid: true, Time: pushTime} + } + if err := qtx.CreateIDCommand(ctx, params); err != nil { + return fmt.Errorf("creating id command: %w", err) + } + } + } + return nil + }) +} + +// RetrieveOutstandingWorkflowStates finds enrollment IDs with an outstanding workflow step from a given set. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveOutstandingWorkflowStatus(ctx context.Context, workflowName string, ids []string) (outstandingIDs []string, err error) { + outstandingIDs, err = s.q.GetOutstandingIDs(ctx, sqlc.GetOutstandingIDsParams{ + Ids: ids, + WorkflowName: workflowName, + }) + if err != nil { + err = fmt.Errorf("getting outstanding ids (%d): %w", len(ids), err) + } + return +} + +// CancelSteps cancels workflow steps for id. +// See the storage interface type for further docs. +func (s *MySQLStorage) CancelSteps(ctx context.Context, id, workflowName string) error { + if id == "" { + return errors.New("must supply both id and workflow name") + } + return tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + if workflowName != "" { + err := qtx.DeleteIDCommandByWorkflow(ctx, sqlc.DeleteIDCommandByWorkflowParams{ + EnrollmentID: id, + WorkflowName: workflowName, + }) + if err != nil { + return fmt.Errorf("delete id command by workflow (%s, %s): %w", id, workflowName, err) + } + } else { + err := qtx.DeleteIDCommands(ctx, id) + if err != nil { + return fmt.Errorf("delete id command (%s): %w", id, err) + } + } + + err := qtx.DeleteUnusedStepCommands(ctx) + if err != nil { + return fmt.Errorf("delete unused step commands: %w", err) + } + + if workflowName != "" { + err = qtx.DeleteWorkflowStepHavingNoCommandsByWorkflowName(ctx, workflowName) + if err != nil { + return fmt.Errorf("delete workflow step having no commands (%s): %w", workflowName, err) + } + } else { + if err = qtx.DeleteWorkflowStepHavingNoCommands(ctx); err != nil { + return fmt.Errorf("delete workflow step having no commands (%s): %w", workflowName, err) + } + } + return nil + }) +} diff --git a/engine/storage/mysql/worker.go b/engine/storage/mysql/worker.go new file mode 100644 index 0000000..652b66f --- /dev/null +++ b/engine/storage/mysql/worker.go @@ -0,0 +1,224 @@ +package mysql + +import ( + "context" + "database/sql" + "encoding/hex" + "errors" + "fmt" + "time" + + "github.com/micromdm/nanocmd/engine/storage" + "github.com/micromdm/nanocmd/engine/storage/mysql/sqlc" +) + +// randHexString generates 40-character string of hex-encoded random data. +func (s *MySQLStorage) randHexString(prefix string) sql.NullString { + p := make([]byte, 20) + s.randMu.Lock() + defer s.randMu.Unlock() + s.rand.Read(p) + return sql.NullString{String: prefix + "." + hex.EncodeToString(p), Valid: true} +} + +// RetrieveStepsToEnqueue fetches steps to be enqueued that were enqueued "later" with NotUntil. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveStepsToEnqueue(ctx context.Context, pushTime time.Time) ([]*storage.StepEnqueueing, error) { + if pushTime.IsZero() { + return nil, errors.New("empty push time") + } + + var ret []*storage.StepEnqueueing + err := tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + + // this smells like a bad SQL paradigm + notUntilProcVal := s.randHexString("notu") + + err := qtx.UpdateStepAfterNotUntil(ctx, sqlc.UpdateStepAfterNotUntilParams{ + ProcessID: notUntilProcVal, + NotUntil: sql.NullTime{Valid: true, Time: pushTime}, + }) + if err != nil { + return fmt.Errorf("update step with not until proc (%s): %w", notUntilProcVal.String, err) + } + + steps, err := qtx.GetStepsByProcessID(ctx, notUntilProcVal) + if err != nil { + return fmt.Errorf("get step: %w", err) + } + + seID := make(map[int64]*storage.StepEnqueueing) + for _, se := range steps { + seID[se.ID] = &storage.StepEnqueueing{ + StepContext: storage.StepContext{ + Name: se.StepName.String, + WorkflowName: se.WorkflowName, + InstanceID: se.InstanceID, + }, + } + } + + cmdIDs, err := qtx.GetIDCommandIDsByProcessID(ctx, notUntilProcVal) + if err != nil { + return fmt.Errorf("get command ids: %w", err) + } + + for _, cmdID := range cmdIDs { + se, ok := seID[cmdID.StepID] + if !ok || se == nil { + // TODO: mismatch of step, should we error here? + continue + } + se.IDs = append(se.IDs, cmdID.EnrollmentID) + } + + cmds, err := qtx.GetStepCommandsByProcessID(ctx, notUntilProcVal) + if err != nil { + return fmt.Errorf("get step commands: %w", err) + } + + for _, cmd := range cmds { + se, ok := seID[cmd.StepID] + if !ok || se == nil { + // TODO: mismatch of step, should we error here? + continue + } + se.Commands = append(se.Commands, storage.StepCommandRaw{ + CommandUUID: cmd.CommandUuid, + RequestType: cmd.RequestType, + Command: cmd.Command, + }) + } + + for _, v := range seID { + ret = append(ret, v) + } + + err = qtx.RemoveStepCommandsByProcessID(ctx, notUntilProcVal) + if err != nil { + return fmt.Errorf("remove step commands by not until proc (%s): %w", notUntilProcVal.String, err) + } + + err = qtx.UpdateLastPushByProcessID(ctx, notUntilProcVal) + if err != nil { + return fmt.Errorf("update last push by not until proc (%s): %w", notUntilProcVal.String, err) + } + + return nil + }) + return ret, err +} + +// RetrieveTimedOutSteps fetches steps that have timed out. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveTimedOutSteps(ctx context.Context) ([]*storage.StepResult, error) { + var ret []*storage.StepResult + + now := time.Now() + + err := tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + + // this smells like a bad SQL paradigm + timeoutProcVal := s.randHexString("tout") + + err := qtx.UpdateStepAfterTimeout(ctx, sqlc.UpdateStepAfterTimeoutParams{ + ProcessID: timeoutProcVal, + Timeout: sql.NullTime{Valid: true, Time: now}, + }) + if err != nil { + return fmt.Errorf("update step with not until proc (%s): %w", timeoutProcVal.String, err) + } + + steps, err := qtx.GetStepsWithContextByProcessID(ctx, timeoutProcVal) + if err != nil { + return fmt.Errorf("get step: %w", err) + } + + scID := make(map[int64]*storage.StepContext) + for _, se := range steps { + scID[se.ID] = &storage.StepContext{ + Name: se.StepName.String, + WorkflowName: se.WorkflowName, + InstanceID: se.InstanceID, + Context: se.Context, + } + } + + cmdIDs, err := qtx.GetIDCommandDetailsByProcessID(ctx, timeoutProcVal) + if err != nil { + return fmt.Errorf("get command ids: %w", err) + } + + rID := make(map[string]*storage.StepResult) + for _, cmdID := range cmdIDs { + sc, ok := scID[cmdID.StepID] + if !ok || sc == nil { + // TODO: mismatch of step, should we error here? + continue + } + sr, ok := rID[cmdID.EnrollmentID] + if !ok || sr == nil { + sr = &storage.StepResult{ + IDs: []string{cmdID.EnrollmentID}, + StepContext: *sc, + } + rID[cmdID.EnrollmentID] = sr + } + scr := storage.StepCommandResult{ + CommandUUID: cmdID.CommandUuid, + RequestType: cmdID.RequestType, + ResultReport: cmdID.Result, + Completed: cmdID.Completed, + } + sr.Commands = append(sr.Commands, scr) + } + + for _, v := range rID { + ret = append(ret, v) + } + + err = qtx.RemoveStepCommandsByProcessID(ctx, timeoutProcVal) + if err != nil { + return fmt.Errorf("remove step commands by timeout proc (%s): %w", timeoutProcVal.String, err) + } + + err = qtx.RemoveIDCommandsByProcessID(ctx, timeoutProcVal) + if err != nil { + return fmt.Errorf("remove id commands by timeout proc (%s): %w", timeoutProcVal.String, err) + } + + err = qtx.RemoveStepsByProcessID(ctx, timeoutProcVal) + if err != nil { + return fmt.Errorf("remove steps by timeout proc (%s): %w", timeoutProcVal.String, err) + } + + return nil + }) + + return ret, err +} + +// RetrieveAndMarkRePushed retrieves a set of IDs that need to have APNs re-pushes sent. +// See the storage interface type for further docs. +func (s *MySQLStorage) RetrieveAndMarkRePushed(ctx context.Context, ifBefore time.Time, pushTime time.Time) ([]string, error) { + var ids []string + ifBeforeTime := sqlNullTime(ifBefore) + err := tx(ctx, s.db, s.q, func(ctx context.Context, _ *sql.Tx, qtx *sqlc.Queries) error { + var err error + ids, err = qtx.GetRePushIDs(ctx, ifBeforeTime) + if err != nil { + return fmt.Errorf("get repush ids: %w", err) + } + + err = qtx.UpdateRePushIDs(ctx, sqlc.UpdateRePushIDsParams{ + LastPush: sqlNullTime(pushTime), + Before: ifBeforeTime, + }) + if err != nil { + return fmt.Errorf("update repush ids: %w", err) + } + + return nil + }) + return ids, err +} diff --git a/engine/storage/storage.go b/engine/storage/storage.go index 4d6cc51..642c264 100644 --- a/engine/storage/storage.go +++ b/engine/storage/storage.go @@ -55,6 +55,21 @@ type StepCommandResult struct { Completed bool // whether this specific command did *not* have a NotNow status } +var ( + ErrEmptyStepCommandResult = errors.New("empty step command result") + ErrEmptyResultReport = errors.New("empty result report") +) + +// Validate checks sc for issues. +func (sc *StepCommandResult) Validate() error { + if sc == nil { + return ErrEmptyStepCommandResult + } else if len(sc.ResultReport) < 1 { + return ErrEmptyResultReport + } + return nil +} + // StepCommandRaw is a raw command, its UUID, and request type. // An approximately serialized form of a workflow step command. type StepCommandRaw struct { @@ -187,6 +202,7 @@ type WorkerStorage interface { type AllStorage interface { Storage WorkerStorage + EventSubscriptionStorage } // EventSubscription is a user-configured subscription for starting workflows with optional context. diff --git a/engine/storage/test/event.go b/engine/storage/test/event.go new file mode 100644 index 0000000..881c8ee --- /dev/null +++ b/engine/storage/test/event.go @@ -0,0 +1,77 @@ +package test + +import ( + "context" + "testing" + + "github.com/micromdm/nanocmd/engine/storage" + "github.com/micromdm/nanocmd/workflow" +) + +func TestEventStorage(t *testing.T, store storage.EventSubscriptionStorage) { + ctx := context.Background() + + evTest := &storage.EventSubscription{ + Event: "Enrollment", + Workflow: "wf", + Context: "ctx", + } + + testEventData := func(t *testing.T, es *storage.EventSubscription) { + if es == nil { + t.Fatal("nil event subscription") + } + + err := es.Validate() + if err != nil { + t.Fatalf("invalid test data") + } + + if have, want := es.Event, evTest.Event; have != want { + t.Errorf("[event] have: %v, want: %v", have, want) + } + + if have, want := es.Workflow, evTest.Workflow; have != want { + t.Errorf("[workflow] have: %v, want: %v", have, want) + } + + if have, want := es.Context, evTest.Context; have != want { + t.Errorf("[context] have: %v, want: %v", have, want) + } + } + + t.Run("testdata", func(t *testing.T) { + testEventData(t, evTest) + }) + + err := store.StoreEventSubscription(ctx, "test", evTest) + if err != nil { + t.Fatal(err) + } + + events, err := store.RetrieveEventSubscriptions(ctx, []string{"test"}) + if err != nil { + t.Fatal(err) + } + + if have, want := len(events), 1; have != want { + t.Fatalf("have: %v, want: %v", have, want) + } + + t.Run("retrieve-by-name", func(t *testing.T) { + testEventData(t, events["test"]) + }) + + eventsList, err := store.RetrieveEventSubscriptionsByEvent(ctx, workflow.EventEnrollment) + if err != nil { + t.Fatal(err) + } + + if have, want := len(eventsList), 1; have != want { + t.Fatalf("have: %v, want: %v", have, want) + } + + t.Run("retrieve-by-event", func(t *testing.T) { + testEventData(t, eventsList[0]) + }) +} diff --git a/engine/storage/test/test.go b/engine/storage/test/test.go index 260ffae..2234dbe 100644 --- a/engine/storage/test/test.go +++ b/engine/storage/test/test.go @@ -38,6 +38,10 @@ func TestEngineStorage(t *testing.T, newStorage func() storage.AllStorage) { t.Run("testOutstanding", func(t *testing.T) { testOutstanding(t, s) }) + + t.Run("testEvent", func(t *testing.T) { + TestEventStorage(t, s) + }) } func mainTest(t *testing.T, s storage.AllStorage) { @@ -76,35 +80,6 @@ func mainTest(t *testing.T, s storage.AllStorage) { true, nil, }, - { - "missing_command_ReportResults", - &storage.StepEnqueuingWithConfig{ - StepEnqueueing: storage.StepEnqueueing{ - IDs: []string{fakeID}, - StepContext: storage.StepContext{ - WorkflowName: "workflow.name.test1", - InstanceID: "A", - }, - Commands: []storage.StepCommandRaw{ - { - CommandUUID: "UUID-1", - RequestType: "DeviceInformation", - }, - }, - }, - }, - false, - []responseTest{{ - testName: "missing_ReportResults", - resp: &storage.StepCommandResult{ - CommandUUID: "UUID-1", - Completed: true, - // missing ReportResults (should error) - }, - shouldBeCompleted: false, - shouldError: true, - }}, - }, { "normal_test1_command_multi_id", &storage.StepEnqueuingWithConfig{ @@ -257,47 +232,41 @@ func mainTest(t *testing.T, s storage.AllStorage) { }, Commands: []storage.StepCommandRaw{ { - CommandUUID: "UUID-1", + CommandUUID: "W-UUID-1", RequestType: "DeviceInformation", }, { - CommandUUID: "UUID-1", + CommandUUID: "W-UUID-1", RequestType: "SecurityInfo", }, }, }, }, - false, + true, []responseTest{ { testName: "resp1", resp: &storage.StepCommandResult{ - CommandUUID: "UUID-1", + CommandUUID: "W-UUID-1", Completed: true, - ResultReport: []byte("UUID-1"), - }, - shouldBeCompleted: true, - shouldError: false, - skipReqType: true, - skipCmdLen: true, - id: "CCC222", - }, - { - testName: "resp2", - resp: &storage.StepCommandResult{ - CommandUUID: "UUID-1", - Completed: true, - ResultReport: []byte("UUID-1"), + RequestType: "DeviceInformation", + ResultReport: []byte("W-UUID-1"), }, shouldBeCompleted: false, shouldError: true, skipReqType: true, + skipCmdLen: true, + reqType: "DeviceInformation", id: "CCC222", }, }, }, } { t.Run("step-"+tStep.testName, func(t *testing.T) { + // if err := tStep.step.Validate(); err != nil { + // t.Fatalf("invalid test data: step enqueueing with config: %v", err) + // } + err := s.StoreStep(ctx, tStep.step, time.Now()) if tStep.shouldError && err == nil { t.Fatalf("StoreStep: expected error; step=%v", tStep.step) @@ -316,6 +285,10 @@ func mainTest(t *testing.T, s storage.AllStorage) { t.Errorf("request type does not match; have: %s, want: %s", have, want) } + if err = tRespStep.resp.Validate(); err != nil { + t.Fatalf("invalid test data: step command result: %v", err) + } + completedStep, err := s.StoreCommandResponseAndRetrieveCompletedStep(ctx, tRespStep.id, tRespStep.resp) if tRespStep.shouldError && err == nil { diff --git a/engine/storage/test/worker.go b/engine/storage/test/worker.go index 16c342e..1af6f3b 100644 --- a/engine/storage/test/worker.go +++ b/engine/storage/test/worker.go @@ -94,6 +94,16 @@ func testEngineStorageNotUntil(t *testing.T, s storage.AllStorage) { if have, want := len(steps), test.stepsWanted2; have != want { t.Fatalf("expected steps (2nd): have %v, want %v", have, want) } + + // cancel these steps so they don't hang around for other tests to get confused on + for _, step := range test.steps { + for _, id := range step.IDs { + if err = s.CancelSteps(ctx, id, step.WorkflowName); err != nil { + t.Errorf("cancelling steps (%s, %s): %v", id, step.WorkflowName, err) + } + } + } + }) } } @@ -129,8 +139,9 @@ func testEngineStepTimeout(t *testing.T, s storage.AllStorage) { }, Commands: []storage.StepCommandRaw{ { - CommandUUID: "UUID-1", + CommandUUID: "Y-UUID-1", RequestType: "DeviceInformation", + Command: []byte("Y-UUID-1"), }, }, }, @@ -141,7 +152,7 @@ func testEngineStepTimeout(t *testing.T, s storage.AllStorage) { { id: "EnrollmentID-1", sc: storage.StepCommandResult{ - CommandUUID: "UUID-1", + CommandUUID: "Y-UUID-1", RequestType: "DeviceInformation", ResultReport: []byte("Command-1"), Completed: true, @@ -220,13 +231,13 @@ func testRepush(t *testing.T, s storage.AllStorage) { }, Commands: []storage.StepCommandRaw{ { - CommandUUID: "UUID-1", + CommandUUID: "W-UUID-1", RequestType: "DeviceInformation", Command: []byte("Command-1"), }, }, }, - // NotUntil: not setting NotUntil to sure these are simulated to be sent pushes "now" + // NotUntil: not setting NotUntil to be sure these are simulated to be sent pushes "now" } now := time.Now() @@ -238,7 +249,7 @@ func testRepush(t *testing.T, s storage.AllStorage) { // complete one of the commands _, err = s.StoreCommandResponseAndRetrieveCompletedStep(ctx, enq.IDs[0], &storage.StepCommandResult{ - CommandUUID: "UUID-1", + CommandUUID: "W-UUID-1", RequestType: "DeviceInformation", ResultReport: []byte("Result-1"), Completed: true, @@ -247,8 +258,8 @@ func testRepush(t *testing.T, s storage.AllStorage) { t.Fatal(err) } - ifBefore := now.Add(time.Second) - now = ifBefore.Add(time.Second) + ifBefore := now.Add(time.Second * 2) + now = ifBefore.Add(time.Second * 2) ids, err := s.RetrieveAndMarkRePushed(ctx, ifBefore, now) if err != nil { @@ -282,7 +293,7 @@ func testRepush(t *testing.T, s storage.AllStorage) { }, }, }, - NotUntil: time.Now().Add(-time.Minute), + NotUntil: now.Add(-time.Minute), } err = s.StoreStep(ctx, enq2, now) @@ -307,8 +318,8 @@ func testRepush(t *testing.T, s storage.AllStorage) { t.Fatal(err) } - ifBefore = now.Add(time.Second) - now = ifBefore.Add(time.Second) + ifBefore = now.Add(time.Second * 2) + now = ifBefore.Add(time.Second * 2) ids, err = s.RetrieveAndMarkRePushed(ctx, ifBefore, now) if err != nil { diff --git a/go.mod b/go.mod index aba944f..973597c 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.19 require ( github.com/alexedwards/flow v0.0.0-20220806114457-cf11be9e0e03 + github.com/go-sql-driver/mysql v1.7.1 github.com/google/uuid v1.3.1 github.com/groob/plist v0.0.0-20220217120414-63fa881b19a5 github.com/jessepeterson/mdmcommands v0.0.0-20230517161100-c5ca4128e1e3 diff --git a/go.sum b/go.sum index 5890124..d32a691 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/alexedwards/flow v0.0.0-20220806114457-cf11be9e0e03 h1:r07xZN3ENBWdxGuU/feCsnpsgHJ7+3uLm7cq9S0sqoI= github.com/alexedwards/flow v0.0.0-20220806114457-cf11be9e0e03/go.mod h1:1rjOQiOqQlmMdUMuvlJFjldqTnE/tQULE7qPIu4aq3U= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.1.2 h1:xf4v41cLI2Z6FxbKm+8Bu+m8ifhj15JuZ9sa0jZCMUU= github.com/google/btree v1.1.2/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= diff --git a/workflow/event.go b/workflow/event.go index d9a60b0..3a34e43 100644 --- a/workflow/event.go +++ b/workflow/event.go @@ -11,8 +11,8 @@ var ErrEventsNotSupported = errors.New("events not supported for this workflow") // EventFlag is a bitmask of event types. type EventFlag uint -// Storage backends (persistent storage) are likely to use these numeric -// values. Treat these as append-only: Order and position matter. +// Storage backends (persistent storage) may use these numeric values. +// So treat these as append-only: order and position matter. const ( EventAllCommandResponse EventFlag = 1 << iota EventAuthenticate