Skip to content
Closed
9 changes: 8 additions & 1 deletion dbos/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,13 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu
if params.priority > uint(math.MaxInt) {
return nil, fmt.Errorf("priority %d exceeds maximum allowed value %d", params.priority, math.MaxInt)
}

// Serialize input before storing in workflow status
encodedInput, err := serialize(params.workflowInput)
if err != nil {
return nil, fmt.Errorf("failed to serialize workflow input: %w", err)
}

status := WorkflowStatus{
Name: params.workflowName,
ApplicationVersion: params.applicationVersion,
Expand All @@ -155,7 +162,7 @@ func (c *client) Enqueue(queueName, workflowName string, input any, opts ...Enqu
CreatedAt: time.Now(),
Deadline: deadline,
Timeout: params.workflowTimeout,
Input: params.workflowInput,
Input: encodedInput,
QueueName: queueName,
DeduplicationID: params.deduplicationID,
Priority: int(params.priority),
Expand Down
19 changes: 4 additions & 15 deletions dbos/queue.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
package dbos

import (
"bytes"
"context"
"encoding/base64"
"encoding/gob"
"log/slog"
"math"
"math/rand"
Expand Down Expand Up @@ -229,18 +226,10 @@ func (qr *queueRunner) run(ctx *dbosContext) {

// Deserialize input
var input any
if len(workflow.input) > 0 {
inputBytes, err := base64.StdEncoding.DecodeString(workflow.input)
if err != nil {
qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err)
continue
}
buf := bytes.NewBuffer(inputBytes)
dec := gob.NewDecoder(buf)
if err := dec.Decode(&input); err != nil {
qr.logger.Error("failed to decode input for workflow", "workflow_id", workflow.id, "error", err)
continue
}
input, err = deserialize(workflow.input)
if err != nil {
qr.logger.Error("failed to deserialize input for workflow", "workflow_id", workflow.id, "error", err)
continue
}

_, err := registeredWorkflow.wrappedFunction(ctx, input, WithWorkflowID(workflow.id))
Expand Down
20 changes: 12 additions & 8 deletions dbos/recovery.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,5 @@
package dbos

import (
"strings"
)

func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]WorkflowHandle[any], error) {
workflowHandles := make([]WorkflowHandle[any], 0)
// List pending workflows for the executors
Expand All @@ -18,9 +14,17 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
}

for _, workflow := range pendingWorkflows {
if inputStr, ok := workflow.Input.(string); ok {
if strings.Contains(inputStr, "Failed to decode") {
ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name)
// Deserialize the workflow input
var decodedInput any
if workflow.Input != nil {
inputString, ok := workflow.Input.(*string)
if !ok {
ctx.logger.Warn("Skipping workflow recovery due to invalid input type", "workflow_id", workflow.ID, "name", workflow.Name, "input_type", workflow.Input)
continue
}
decodedInput, err = deserialize(inputString)
if err != nil {
ctx.logger.Warn("Skipping workflow recovery due to input decoding failure", "workflow_id", workflow.ID, "name", workflow.Name, "error", err)
continue
}
}
Expand Down Expand Up @@ -59,7 +63,7 @@ func recoverPendingWorkflows(ctx *dbosContext, executorIDs []string) ([]Workflow
WithWorkflowID(workflow.ID),
}
// Create a workflow context from the executor context
handle, err := registeredWorkflow.wrappedFunction(ctx, workflow.Input, opts...)
handle, err := registeredWorkflow.wrappedFunction(ctx, decodedInput, opts...)
if err != nil {
return nil, err
}
Expand Down
35 changes: 24 additions & 11 deletions dbos/serialization.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,37 @@ import (
"strings"
)

func serialize(data any) (string, error) {
var inputBytes []byte
if data != nil {
var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(&data); err != nil {
return "", fmt.Errorf("failed to encode data: %w", err)
}
inputBytes = buf.Bytes()
func serialize(data any) (*string, error) {
// Handle nil values specially - return nil pointer which will be stored as NULL in DB
if data == nil {
return nil, nil
}
return base64.StdEncoding.EncodeToString(inputBytes), nil

// Handle empty string specially - return pointer to empty string which will be stored as "" in DB
if str, ok := data.(string); ok && str == "" {
return &str, nil
}

var buf bytes.Buffer
enc := gob.NewEncoder(&buf)
if err := enc.Encode(&data); err != nil {
return nil, fmt.Errorf("failed to encode data: %w", err)
}
inputBytes := buf.Bytes()

encoded := base64.StdEncoding.EncodeToString(inputBytes)
return &encoded, nil
}

func deserialize(data *string) (any, error) {
if data == nil || *data == "" {
if data == nil {
return nil, nil
}

if *data == "" {
return "", nil
}

dataBytes, err := base64.StdEncoding.DecodeString(*data)
if err != nil {
return nil, fmt.Errorf("failed to decode data: %w", err)
Expand Down
Loading
Loading