Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions internal/replay/replay_http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"unicode"

"github.com/google/test-server/internal/config"
Expand All @@ -35,11 +36,13 @@ import (
)

type ReplayHTTPServer struct {
mu sync.Mutex // Mutex to protect fields below
prevRequestSHA string
seenFiles map[string]struct{}
config *config.EndpointConfig
recordingDir string
redactor *redact.Redact

config *config.EndpointConfig
recordingDir string
redactor *redact.Redact
}

func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *ReplayHTTPServer {
Expand Down Expand Up @@ -70,9 +73,14 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
return
}

redactedReq, err := r.createRedactedRequest(req)
// Safely get the current prevRequestSHA
r.mu.Lock()
currentPrevSHA := r.prevRequestSHA
r.mu.Unlock()

redactedReq, err := r.createRedactedRequest(req, currentPrevSHA)
if err != nil {
fmt.Printf("Error processing request")
fmt.Printf("Error processing request: %v\n", err)
http.Error(w, fmt.Sprintf("Error processing request: %v", err), http.StatusInternalServerError)
return
}
Expand All @@ -83,10 +91,15 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
return
}

// Safely check seenFiles and update redactedReq.PreviousRequest if needed
r.mu.Lock()
if _, ok := r.seenFiles[fileName]; !ok {
// Reset to HeadSHA when first time seen request from the given file.
redactedReq.PreviousRequest = store.HeadSHA
}
r.mu.Unlock()

if req.Header.Get("Upgrade") == "websocket" {
fmt.Printf("Upgrading connection to websocket...\n")

Expand Down Expand Up @@ -114,14 +127,18 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
fmt.Printf("Error writing response: %v\n", err)
panic(err)
}

// Safely update prevRequestSHA and seenFiles
r.mu.Lock()
if fileName != shaSum {
r.prevRequestSHA = shaSum
}
r.seenFiles[fileName] = struct{}{}
r.mu.Unlock()
}

func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.RecordedRequest, error) {
recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config)
func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request, prevSHA string) (*store.RecordedRequest, error) {
recordedRequest, err := store.NewRecordedRequest(req, prevSHA, *r.config)
if err != nil {
return nil, err
}
Expand Down
Loading