diff --git a/dbos/client.go b/dbos/client.go index c04130b..a8c8fa7 100644 --- a/dbos/client.go +++ b/dbos/client.go @@ -147,6 +147,13 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu if params.priority > uint(math.MaxInt) { return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt) } + + // Serialize input before storing in workflow status + encodedInput, err := serialize(params.workflowInput) + if err != nil { + return nil, fmt.Errorf("failed to serialize workflow input: %w", err) + } + status := WorkflowStatus{ Name: params.workflowName, ApplicationVersion: params.applicationVersion, @@ -155,7 +162,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu CreatedAt: time.Now(), Deadline: deadline, Timeout: params.workflowTimeout, - Input: params.workflowInput, + Input: encodedInput, QueueName: queueName, DeduplicationID: params.deduplicationID, Priority: int(params.priority), diff --git a/dbos/queue.go b/dbos/queue.go index f323887..af1d104 100644 --- a/dbos/queue.go +++ b/dbos/queue.go @@ -1,10 +1,7 @@ package dbos import ( - "bytes" "context" - "encoding/base64" - "encoding/gob" "log/slog" "math" "math/rand" @@ -229,18 +226,10 @@ func (qr *queueRunner) run(ctx *dbosContext) { // Deserialize input var input any - if len(workflow.input) > 0 { - inputBytes, err := base64.StdEncoding.DecodeString(workflow.input) - if err != nil { - qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) - continue - } - buf := bytes.NewBuffer(inputBytes) - dec := gob.NewDecoder(buf) - if err := dec.Decode(&input); err != nil { - qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err) - continue - } + input, err = deserialize(workflow.input) + if err != nil { + qr.logger.Error("failed to deserialize input for workflow", "workflow_id", workflow.id, "error", err) + continue } _, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id)) diff --git a/dbos/recovery.go b/dbos/recovery.go index ba51a6d..fefd56d 100644 --- a/dbos/recovery.go +++ b/dbos/recovery.go @@ -1,9 +1,5 @@ package dbos -import ( - "strings" -) - func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) { workflowHandles := make([]WorkflowHandle[any], 0) // List pending workflows for the executors @@ -18,9 +14,17 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow } for _, workflow := range pendingWorkflows { - if inputStr, ok := workflow.Input.(string); ok { - if strings.Contains(inputStr, "Failed to decode") { - ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name) + // Deserialize the workflow input + var decodedInput any + if workflow.Input != nil { + inputString, ok := workflow.Input.(*string) + if !ok { + ctx.logger.Warn("Skipping workflow recovery due to invalid input type", "workflow_id", workflow.ID, "name", workflow.Name, "input_type", workflow.Input) + continue + } + decodedInput, err = deserialize(inputString) + if err != nil { + ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name, "error", err) continue } } @@ -59,7 +63,7 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow WithWorkflowID(workflow.ID), } // Create a workflow context from the executor context - handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...) + handle, err := registeredWorkflow.wrappedFunction(ctx, decodedInput, opts...) if err != nil { return nil, err } diff --git a/dbos/serialization.go b/dbos/serialization.go index c2e8814..7747c88 100644 --- a/dbos/serialization.go +++ b/dbos/serialization.go @@ -9,24 +9,37 @@ import ( "strings" ) -func serialize(data any) (string, error) { - var inputBytes []byte - if data != nil { - var buf bytes.Buffer - enc := gob.NewEncoder(&buf) - if err := enc.Encode(&data); err != nil { - return "", fmt.Errorf("failed to encode data: %w", err) - } - inputBytes = buf.Bytes() +func serialize(data any) (*string, error) { + // Handle nil values specially - return nil pointer which will be stored as NULL in DB + if data == nil { + return nil, nil } - return base64.StdEncoding.EncodeToString(inputBytes), nil + + // Handle empty string specially - return pointer to empty string which will be stored as "" in DB + if str, ok := data.(string); ok && str == "" { + return &str, nil + } + + var buf bytes.Buffer + enc := gob.NewEncoder(&buf) + if err := enc.Encode(&data); err != nil { + return nil, fmt.Errorf("failed to encode data: %w", err) + } + inputBytes := buf.Bytes() + + encoded := base64.StdEncoding.EncodeToString(inputBytes) + return &encoded, nil } func deserialize(data *string) (any, error) { - if data == nil || *data == "" { + if data == nil { return nil, nil } + if *data == "" { + return "", nil + } + dataBytes, err := base64.StdEncoding.DecodeString(*data) if err != nil { return nil, fmt.Errorf("failed to decode data: %w", err) diff --git a/dbos/system_database.go b/dbos/system_database.go index ac1abc4..88a962e 100644 --- a/dbos/system_database.go +++ b/dbos/system_database.go @@ -37,7 +37,7 @@ type systemDatabase interface { insertWorkflowStatus(ctx context.Context, input insertWorkflowStatusDBInput) (*insertWorkflowResult, error) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ([]WorkflowStatus, error) updateWorkflowOutcome(ctx context.Context, input updateWorkflowOutcomeDBInput) error - awaitWorkflowResult(ctx context.Context, workflowID string) (any, error) + awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error) cancelWorkflow(ctx context.Context, workflowID string) error cancelAllBefore(ctx context.Context, cutoffTime time.Time) error resumeWorkflow(ctx context.Context, workflowID string) error @@ -55,7 +55,7 @@ type systemDatabase interface { // Communication (special steps) send(ctx context.Context, input WorkflowSendInput) error - recv(ctx context.Context, input recvInput) (any, error) + recv(ctx context.Context, input recvInput) (*string, error) setEvent(ctx context.Context, input WorkflowSetEventInput) error getEvent(ctx context.Context, input getEventInput) (any, error) @@ -440,11 +440,6 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt timeoutMs = &millis } - inputString, err := serialize(input.status.Input) - if err != nil { - return nil, fmt.Errorf("failed to serialize input: %w", err) - } - // Our DB works with NULL values var applicationVersion *string if len(input.status.ApplicationVersion) > 0 { @@ -516,7 +511,7 @@ func (s *sysDB) insertWorkflowStatus(ctx context.Context, input insertWorkflowSt updatedAt.UnixMilli(), timeoutMs, deadline, - inputString, + input.status.Input, // encoded input deduplicationID, input.status.Priority, WorkflowStatusEnqueued, @@ -791,18 +786,13 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( wf.Error = errors.New(*errorStr) } - wf.Output, err = deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } + // Return output as encoded *string + wf.Output = outputString } - // Handle input only if loadInput is true + // Return input as encoded *string if input.loadInput { - wf.Input, err = deserialize(inputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize input: %w", err) - } + wf.Input = inputString } workflows = append(workflows, wf) @@ -818,7 +808,7 @@ func (s *sysDB) listWorkflows(ctx context.Context, input listWorkflowsDBInput) ( type updateWorkflowOutcomeDBInput struct { workflowID string status WorkflowStatusType - output any + output *string err error tx pgx.Tx } @@ -830,20 +820,16 @@ func (s *sysDB) updateWorkflowOutcome(ctx context.Context, input updateWorkflowO SET status = $1, output = $2, error = $3, updated_at = $4, deduplication_id = NULL WHERE workflow_uuid = $5 AND NOT (status = $6 AND $1 in ($7, $8))`, pgx.Identifier{s.schema}.Sanitize()) - outputString, err := serialize(input.output) - if err != nil { - return fmt.Errorf("failed to serialize output: %w", err) - } - var errorStr string if input.err != nil { errorStr = input.err.Error() } + var err error if input.tx != nil { - _, err = input.tx.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) + _, err = input.tx.Exec(ctx, query, input.status, input.output, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) } else { - _, err = s.pool.Exec(ctx, query, input.status, outputString, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) + _, err = s.pool.Exec(ctx, query, input.status, input.output, errorStr, time.Now().UnixMilli(), input.workflowID, WorkflowStatusCancelled, WorkflowStatusSuccess, WorkflowStatusError) } if err != nil { @@ -1105,11 +1091,6 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st recovery_attempts ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)`, pgx.Identifier{s.schema}.Sanitize()) - inputString, err := serialize(originalWorkflow.Input) - if err != nil { - return "", fmt.Errorf("failed to serialize input: %w", err) - } - // Marshal authenticated roles (slice of strings) to JSON for TEXT column authenticatedRoles, err := json.Marshal(originalWorkflow.AuthenticatedRoles) @@ -1127,7 +1108,7 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st &appVersion, originalWorkflow.ApplicationID, _DBOS_INTERNAL_QUEUE_NAME, - inputString, + originalWorkflow.Input, // encoded input time.Now().UnixMilli(), time.Now().UnixMilli(), 0) @@ -1157,7 +1138,7 @@ func (s *sysDB) forkWorkflow(ctx context.Context, input forkWorkflowDBInput) (st return forkedWorkflowID, nil } -func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any, error) { +func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (*string, error) { query := fmt.Sprintf(`SELECT status, output, error FROM %s.workflow_status WHERE workflow_uuid = $1`, pgx.Identifier{s.schema}.Sanitize()) var status WorkflowStatusType for { @@ -1179,20 +1160,14 @@ func (s *sysDB) awaitWorkflowResult(ctx context.Context, workflowID string) (any return nil, fmt.Errorf("failed to query workflow status: %w", err) } - // Deserialize output from TEXT to bytes then from bytes to R using gob - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - switch status { case WorkflowStatusSuccess, WorkflowStatusError: if errorStr == nil || len(*errorStr) == 0 { - return output, nil + return outputString, nil } - return output, errors.New(*errorStr) + return outputString, errors.New(*errorStr) case WorkflowStatusCancelled: - return output, newAwaitedWorkflowCancelledError(workflowID) + return outputString, newAwaitedWorkflowCancelledError(workflowID) default: time.Sleep(_DB_RETRY_INTERVAL) } @@ -1203,7 +1178,7 @@ type recordOperationResultDBInput struct { workflowID string stepID int stepName string - output any + output *string err error tx pgx.Tx } @@ -1219,16 +1194,12 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation errorString = &e } - outputString, err := serialize(input.output) - if err != nil { - return fmt.Errorf("failed to serialize output: %w", err) - } - + var err error if input.tx != nil { _, err = input.tx.Exec(ctx, query, input.workflowID, input.stepID, - outputString, + input.output, errorString, input.stepName, ) @@ -1236,7 +1207,7 @@ func (s *sysDB) recordOperationResult(ctx context.Context, input recordOperation _, err = s.pool.Exec(ctx, query, input.workflowID, input.stepID, - outputString, + input.output, errorString, input.stepName, ) @@ -1326,7 +1297,7 @@ type recordChildGetResultDBInput struct { parentWorkflowID string childWorkflowID string stepID int - output string + output *string err error } @@ -1361,7 +1332,7 @@ func (s *sysDB) recordChildGetResult(ctx context.Context, input recordChildGetRe /*******************************/ type recordedResult struct { - output any + output *string err error } @@ -1431,17 +1402,12 @@ func (s *sysDB) checkOperationExecution(ctx context.Context, input checkOperatio return nil, newUnexpectedStepError(input.workflowID, input.stepID, input.stepName, recordedFunctionName) } - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - var recordedError error if errorStr != nil && *errorStr != "" { recordedError = errors.New(*errorStr) } result := &recordedResult{ - output: output, + output: outputString, err: recordedError, } return result, nil @@ -1485,13 +1451,9 @@ func (s *sysDB) getWorkflowSteps(ctx context.Context, input getWorkflowStepsInpu return nil, fmt.Errorf("failed to scan step row: %w", err) } - // Deserialize output if present and loadOutput is true - if input.loadOutput && outputString != nil { - output, err := deserialize(outputString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize output: %w", err) - } - step.Output = output + // Return output as encoded string if loadOutput is true + if input.loadOutput { + step.Output = outputString } // Convert error string to error if present @@ -1564,10 +1526,16 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err return 0, fmt.Errorf("no recorded end time for recorded sleep operation") } + // Deserialize the recorded end time + decodedOutput, err := deserialize(recordedResult.output) + if err != nil { + return 0, fmt.Errorf("failed to deserialize sleep end time: %w", err) + } + // The output should be a time.Time representing the end time - endTimeInterface, ok := recordedResult.output.(time.Time) + endTimeInterface, ok := decodedOutput.(time.Time) if !ok { - return 0, fmt.Errorf("recorded output is not a time.Time: %T", recordedResult.output) + return 0, fmt.Errorf("decoded output is not a time.Time: %T", decodedOutput) } endTime = endTimeInterface @@ -1578,12 +1546,18 @@ func (s *sysDB) sleep(ctx context.Context, input sleepInput) (time.Duration, err // First execution: calculate and record the end time endTime = time.Now().Add(input.duration) + // Serialize the end time before recording + encodedEndTime, serErr := serialize(endTime) + if serErr != nil { + return 0, fmt.Errorf("failed to serialize sleep end time: %w", serErr) + } + // Record the operation result with the calculated end time recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, - output: endTime, + output: encodedEndTime, err: nil, } @@ -1730,7 +1704,7 @@ const _DBOS_NULL_TOPIC = "__null__topic__" type WorkflowSendInput struct { DestinationID string - Message any + Message *string Topic string } @@ -1783,14 +1757,8 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { topic = input.Topic } - // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) - if err != nil { - return fmt.Errorf("failed to serialize message: %w", err) - } - insertQuery := fmt.Sprintf(`INSERT INTO %s.notifications (destination_uuid, topic, message) VALUES ($1, $2, $3)`, pgx.Identifier{s.schema}.Sanitize()) - _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, messageString) + _, err = tx.Exec(ctx, insertQuery, input.DestinationID, topic, input.Message) if err != nil { // Check for foreign key violation (destination workflow doesn't exist) if pgErr, ok := err.(*pgconn.PgError); ok && pgErr.Code == _PG_ERROR_FOREIGN_KEY_VIOLATION { @@ -1825,7 +1793,7 @@ func (s *sysDB) send(ctx context.Context, input WorkflowSendInput) error { } // Recv is a special type of step that receives a message destined for a given workflow -func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { +func (s *sysDB) recv(ctx context.Context, input recvInput) (*string, error) { functionName := "DBOS.recv" // Get workflow state from context @@ -1885,7 +1853,7 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { err = s.pool.QueryRow(ctx, query, destinationID, topic).Scan(&exists) if err != nil { cond.L.Unlock() - return false, fmt.Errorf("failed to check message: %w", err) + return nil, fmt.Errorf("failed to check message: %w", err) } if !exists { done := make(chan struct{}) @@ -1939,29 +1907,17 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { var messageString *string err = tx.QueryRow(ctx, query, destinationID, topic).Scan(&messageString) if err != nil { - if err == pgx.ErrNoRows { - // No message found, record nil result - messageString = nil - } else { + if err != pgx.ErrNoRows { return nil, fmt.Errorf("failed to consume message: %w", err) } } - // Deserialize the message - var message any - if messageString != nil { // nil message can happen on the timeout path only - message, err = deserialize(messageString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize message: %w", err) - } - } - - // Record the operation result + // Record the operation result (with encoded message string) recordInput := recordOperationResultDBInput{ workflowID: destinationID, stepID: stepID, stepName: functionName, - output: message, + output: messageString, tx: tx, } err = s.recordOperationResult(ctx, recordInput) @@ -1973,12 +1929,13 @@ func (s *sysDB) recv(ctx context.Context, input recvInput) (any, error) { return nil, fmt.Errorf("failed to commit transaction: %w", err) } - return message, nil + // Return the message string pointer + return messageString, nil } type WorkflowSetEventInput struct { Key string - Message any + Message *string } func (s *sysDB) setEvent(ctx context.Context, input WorkflowSetEventInput) error { @@ -2018,19 +1975,13 @@ func (s *sysDB) setEvent(ctx context.Context, input WorkflowSetEventInput) error return nil } - // Serialize the message. It must have been registered with encoding/gob by the user if not a basic type. - messageString, err := serialize(input.Message) - if err != nil { - return fmt.Errorf("failed to serialize message: %w", err) - } - // Insert or update the event using UPSERT insertQuery := fmt.Sprintf(`INSERT INTO %s.workflow_events (workflow_uuid, key, value) VALUES ($1, $2, $3) ON CONFLICT (workflow_uuid, key) DO UPDATE SET value = EXCLUDED.value`, pgx.Identifier{s.schema}.Sanitize()) - _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, messageString) + _, err = tx.Exec(ctx, insertQuery, wfState.workflowID, input.Key, input.Message) if err != nil { return fmt.Errorf("failed to insert/update workflow event: %w", err) } @@ -2160,22 +2111,13 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) } } - // Deserialize the value if it exists - var value any - if valueString != nil { - value, err = deserialize(valueString) - if err != nil { - return nil, fmt.Errorf("failed to deserialize event value: %w", err) - } - } - // Record the operation result if this is called within a workflow if isInWorkflow { recordInput := recordOperationResultDBInput{ workflowID: wfState.workflowID, stepID: stepID, stepName: functionName, - output: value, + output: valueString, err: nil, } @@ -2185,7 +2127,8 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) } } - return value, nil + // Return the value string pointer + return valueString, nil } /*******************************/ @@ -2195,7 +2138,7 @@ func (s *sysDB) getEvent(ctx context.Context, input getEventInput) (any, error) type dequeuedWorkflow struct { id string name string - input string + input *string } type dequeueWorkflowsInput struct { @@ -2397,21 +2340,16 @@ func (s *sysDB) dequeueWorkflows(ctx context.Context, input dequeueWorkflowsInpu WHERE workflow_uuid = $5 RETURNING name, inputs`, pgx.Identifier{s.schema}.Sanitize()) - var inputString *string err := tx.QueryRow(ctx, updateQuery, WorkflowStatusPending, input.applicationVersion, input.executorID, time.Now().UnixMilli(), - id).Scan(&retWorkflow.name, &inputString) + id).Scan(&retWorkflow.name, &retWorkflow.input) if err != nil { return nil, fmt.Errorf("failed to update workflow %s during dequeue: %w", id, err) } - if inputString != nil && len(*inputString) > 0 { - retWorkflow.input = *inputString - } - retWorkflows = append(retWorkflows, retWorkflow) } diff --git a/dbos/workflow.go b/dbos/workflow.go index 5117933..876f4f5 100644 --- a/dbos/workflow.go +++ b/dbos/workflow.go @@ -249,9 +249,24 @@ func (h *workflowPollingHandle[R]) GetResult(opts ...GetResultOption) (R, error) defer cancel() } - result, err := retryWithResult(ctx, func() (any, error) { + encodedResult, err := retryWithResult(ctx, func() (any, error) { return h.dbosContext.(*dbosContext).systemDB.awaitWorkflowResult(ctx, h.workflowID) }, withRetrierLogger(h.dbosContext.(*dbosContext).logger)) + + // Deserialize the result (but preserve any error from awaitWorkflowResult) + var result any + if encodedResult != nil { + encodedStr, ok := encodedResult.(*string) + if !ok { + return *new(R), newWorkflowUnexpectedResultType(h.workflowID, "string (encoded)", fmt.Sprintf("%T", encodedResult)) + } + var deserErr error + result, deserErr = deserialize(encodedStr) + if deserErr != nil { + return *new(R), fmt.Errorf("failed to deserialize workflow result: %w", deserErr) + } + } + if result != nil { typedResult, ok := result.(R) if !ok { @@ -806,6 +821,13 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt if params.priority > uint(math.MaxInt) { return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt) } + + // Serialize input before storing in workflow status + encodedInput, serErr := serialize(input) + if serErr != nil { + return nil, newWorkflowExecutionError(workflowID, fmt.Errorf("failed to serialize workflow input: %w", serErr)) + } + workflowStatus := WorkflowStatus{ Name: params.workflowName, ApplicationVersion: params.applicationVersion, @@ -815,7 +837,7 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt CreatedAt: time.Now(), Deadline: deadline, Timeout: timeout, - Input: input, + Input: encodedInput, ApplicationID: c.GetApplicationID(), QueueName: params.queueName, DeduplicationID: params.deduplicationID, @@ -937,9 +959,26 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt // Handle DBOS ID conflict errors by waiting workflow result if errors.Is(err, &DBOSError{Code: ConflictingIDError}) { c.logger.Warn("Workflow ID conflict detected. Waiting for existing workflow to complete", "workflow_id", workflowID) - result, err = retryWithResult(c, func() (any, error) { + var encodedResult any + encodedResult, err = retryWithResult(c, func() (any, error) { return c.systemDB.awaitWorkflowResult(uncancellableCtx, workflowID) }, withRetrierLogger(c.logger)) + var deserErr error + encodedResultString, ok := encodedResult.(*string) + if !ok { + c.logger.Error("Unexpected result type when awaiting workflow result after ID conflict", "workflow_id", workflowID, "type", fmt.Sprintf("%T", encodedResult)) + outcomeChan <- workflowOutcome[any]{result: nil, err: fmt.Errorf("unexpected result type when awaiting workflow result after ID conflict: expected string, got %T", encodedResult)} + close(outcomeChan) + return + } + fmt.Println("Encoded result string:", *encodedResultString) + result, deserErr = deserialize(encodedResultString) + if deserErr != nil { + c.logger.Error("Failed to deserialize workflow result after ID conflict", "workflow_id", workflowID, "error", deserErr) + outcomeChan <- workflowOutcome[any]{result: nil, err: fmt.Errorf("failed to deserialize workflow result after ID conflict: %w", deserErr)} + close(outcomeChan) + return + } } else { status := WorkflowStatusSuccess @@ -955,12 +994,21 @@ func (c *dbosContext) RunWorkflow(_ DBOSContext, fn WorkflowFunc, input any, opt status = WorkflowStatusCancelled } + // Serialize the output before recording + encodedOutput, serErr := serialize(result) + if serErr != nil { + c.logger.Error("Failed to serialize workflow output", "workflow_id", workflowID, "error", serErr) + outcomeChan <- workflowOutcome[any]{result: nil, err: fmt.Errorf("failed to serialize output: %w", serErr)} + close(outcomeChan) + return + } + recordErr := retry(c, func() error { return c.systemDB.updateWorkflowOutcome(uncancellableCtx, updateWorkflowOutcomeDBInput{ workflowID: workflowID, status: status, err: err, - output: result, + output: encodedOutput, }) }, withRetrierLogger(c.logger)) if recordErr != nil { @@ -1178,7 +1226,15 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Errorf("checking operation execution: %w", err)) } if recordedOutput != nil { - return recordedOutput.output, recordedOutput.err + // Deserialize the recorded output + var decodedOutput any + if recordedOutput.output != nil { + decodedOutput, err = deserialize(recordedOutput.output) + if err != nil { + return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Errorf("failed to deserialize recorded output: %w", err)) + } + } + return decodedOutput, recordedOutput.err } // Spawn a child DBOSContext with the step state @@ -1228,13 +1284,19 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) } } + // Serialize step output before recording + encodedStepOutput, serErr := serialize(stepOutput) + if serErr != nil { + return nil, newStepExecutionError(stepState.workflowID, stepOpts.stepName, fmt.Errorf("failed to serialize step output: %w", serErr)) + } + // Record the final result dbInput := recordOperationResultDBInput{ workflowID: stepState.workflowID, stepName: stepOpts.stepName, stepID: stepState.stepID, err: stepError, - output: stepOutput, + output: encodedStepOutput, } recErr := retry(c, func() error { return c.systemDB.recordOperationResult(uncancellableCtx, dbInput) @@ -1251,10 +1313,16 @@ func (c *dbosContext) RunAsStep(_ DBOSContext, fn StepFunc, opts ...StepOption) /****************************************/ func (c *dbosContext) Send(_ DBOSContext, destinationID string, message any, topic string) error { + // Serialize the message before sending + encodedMessage, err := serialize(message) + if err != nil { + return fmt.Errorf("failed to serialize message: %w", err) + } + return retry(c, func() error { return c.systemDB.send(c, WorkflowSendInput{ DestinationID: destinationID, - Message: message, + Message: encodedMessage, Topic: topic, }) }, withRetrierLogger(c.logger)) @@ -1292,9 +1360,14 @@ func (c *dbosContext) Recv(_ DBOSContext, topic string, timeout time.Duration) ( Topic: topic, Timeout: timeout, } - return retryWithResult(c, func() (any, error) { + encodedMsg, err := retryWithResult(c, func() (*string, error) { return c.systemDB.recv(c, input) }, withRetrierLogger(c.logger)) + if err != nil { + return nil, err + } + + return deserialize(encodedMsg) } // Recv receives a message sent to this workflow with type safety. @@ -1332,10 +1405,16 @@ func Recv[R any](ctx DBOSContext, topic string, timeout time.Duration) (R, error } func (c *dbosContext) SetEvent(_ DBOSContext, key string, message any) error { + // Serialize the event value before storing + encodedMessage, err := serialize(message) + if err != nil { + return fmt.Errorf("failed to serialize event value: %w", err) + } + return retry(c, func() error { return c.systemDB.setEvent(c, WorkflowSetEventInput{ Key: key, - Message: message, + Message: encodedMessage, }) }, withRetrierLogger(c.logger)) } @@ -1375,9 +1454,22 @@ func (c *dbosContext) GetEvent(_ DBOSContext, targetWorkflowID, key string, time Key: key, Timeout: timeout, } - return retryWithResult(c, func() (any, error) { + encodedValue, err := retryWithResult(c, func() (any, error) { return c.systemDB.getEvent(c, input) }, withRetrierLogger(c.logger)) + if err != nil { + return nil, err + } + + // Deserialize the event value + if encodedValue != nil { + encodedStr, ok := encodedValue.(*string) + if !ok { + return nil, fmt.Errorf("event value must be encoded string, got %T", encodedValue) + } + return deserialize(encodedStr) + } + return nil, nil } // GetEvent retrieves a key-value event from a target workflow with type safety. @@ -1940,15 +2032,50 @@ func (c *dbosContext) ListWorkflows(_ DBOSContext, opts ...ListWorkflowsOption) } // Call the context method to list workflows + var workflows []WorkflowStatus + var err error workflowState, ok := c.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - return RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { + workflows, err = RunAsStep(c, func(ctx context.Context) ([]WorkflowStatus, error) { return c.systemDB.listWorkflows(ctx, dbInput) }, WithStepName("DBOS.listWorkflows")) } else { - return c.systemDB.listWorkflows(c, dbInput) + workflows, err = c.systemDB.listWorkflows(c, dbInput) + } + if err != nil { + return nil, err } + + // Deserialize Input and Output fields if they were loaded + if params.loadInput || params.loadOutput { + for i := range workflows { + if params.loadInput && workflows[i].Input != nil { + encodedInput, ok := workflows[i].Input.(*string) + if !ok { + return nil, fmt.Errorf("workflow input must be encoded string, got %T", workflows[i].Input) + } + decodedInput, err := deserialize(encodedInput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize workflow input for %s: %w", workflows[i].ID, err) + } + workflows[i].Input = decodedInput + } + if params.loadOutput && workflows[i].Output != nil { + encodedOutput, ok := workflows[i].Output.(*string) + if !ok { + return nil, fmt.Errorf("workflow output must be encoded string, got %T", workflows[i].Output) + } + decodedOutput, err := deserialize(encodedOutput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize workflow output for %s: %w", workflows[i].ID, err) + } + workflows[i].Output = decodedOutput + } + } + } + + return workflows, nil } // ListWorkflows retrieves a list of workflows based on the provided filters. @@ -2014,15 +2141,37 @@ func (c *dbosContext) GetWorkflowSteps(_ DBOSContext, workflowID string) ([]Step loadOutput: loadOutput, } + var steps []StepInfo + var err error workflowState, ok := c.Value(workflowStateKey).(*workflowState) isWithinWorkflow := ok && workflowState != nil if isWithinWorkflow { - return RunAsStep(c, func(ctx context.Context) ([]StepInfo, error) { + steps, err = RunAsStep(c, func(ctx context.Context) ([]StepInfo, error) { return c.systemDB.getWorkflowSteps(ctx, getWorkflowStepsInput) }, WithStepName("DBOS.getWorkflowSteps")) } else { - return c.systemDB.getWorkflowSteps(c, getWorkflowStepsInput) + steps, err = c.systemDB.getWorkflowSteps(c, getWorkflowStepsInput) + } + if err != nil { + return nil, err } + + // Deserialize outputs if asked to + if loadOutput { + for i := range steps { + encodedOutput, ok := steps[i].Output.(*string) + if !ok { + return nil, fmt.Errorf("step output must be encoded string, got %T", steps[i].Output) + } + decodedOutput, err := deserialize(encodedOutput) + if err != nil { + return nil, fmt.Errorf("failed to deserialize step output for step %d: %w", steps[i].StepID, err) + } + steps[i].Output = decodedOutput + } + } + + return steps, nil } // GetWorkflowSteps retrieves the execution steps of a workflow.