diff --git a/cmd/round/main.go b/cmd/round/main.go
index f1e2192f..dbd5e2a8 100644
--- a/cmd/round/main.go
+++ b/cmd/round/main.go
@@ -184,7 +184,7 @@ func processRoundDecimalFile(inputPath string) (err error) {
 		if err != nil {
 			return err
 		}
-		bufWriter = bufio.NewWriter(outputFile)
+		bufWriter = bufio.NewWriterSize(outputFile, 8192)
 		colCount = len(cols)
 	} else {
 		// no decimal column, quick exit
diff --git a/stage/map.go b/stage/map.go
index c08e99ac..a583a5ad 100644
--- a/stage/map.go
+++ b/stage/map.go
@@ -108,6 +108,13 @@ func ParseStage(stage *Stage, stages Map) (*Stage, error) {
 		}
 	}
 	stages[stage.Id] = stage
+
+	// Set the seed for this stage (only relevant for non-stream stages).
+	// Stream stages have their seeds set in runAsMultipleStreams.
+	if stage.StreamCount == nil || *stage.StreamCount <= 1 {
+		stage.seed = stage.States.RandSeed
+	}
+
 	for _, nextStagePath := range stage.NextStagePaths {
 		if nextStage, err := ParseStageFromFile(nextStagePath, stages); err != nil {
 			return nil, err
diff --git a/stage/mysql_run_recorder.go b/stage/mysql_run_recorder.go
index 7c56aca5..e7707367 100644
--- a/stage/mysql_run_recorder.go
+++ b/stage/mysql_run_recorder.go
@@ -4,9 +4,10 @@ import (
 	"context"
 	"database/sql"
 	_ "embed"
-	_ "github.com/go-sql-driver/mysql"
 	"pbench/log"
 	"pbench/utils"
+
+	_ "github.com/go-sql-driver/mysql"
 )
 
 var (
@@ -65,7 +66,7 @@ VALUES (?, ?, ?, 0, 0, 0, ?)`
 
 func (m *MySQLRunRecorder) RecordQuery(_ context.Context, s *Stage, result *QueryResult) {
 	recordNewQuery := `INSERT INTO pbench_queries (run_id, stage_id, query_file, query_index, query_id, sequence_no,
-cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
+cold_run, succeeded, start_time, end_time, row_count, expected_row_count, duration_ms, info_url, seed) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
 	var queryFile string
 	if result.Query.File != nil {
 		queryFile = *result.Query.File
@@ -83,11 +84,15 @@ cold_run, succeeded, start_time, end_time, row_count, expected_row_count, durati
 		result.RowCount, sql.NullInt32{
 			Int32: int32(result.Query.ExpectedRowCount),
 			Valid: result.Query.ExpectedRowCount >= 0,
-		}, result.Duration.Milliseconds(), result.InfoUrl)
+		}, result.Duration.Milliseconds(), result.InfoUrl, result.Seed)
 	if err != nil {
 		log.Error().EmbedObject(result).Err(err).Msg("failed to send query summary to MySQL")
+	} else {
+		// Only claim success after Exec returned without an error.
+		log.Info().Str("stage_id", result.StageId).Stringer("start_time", result.StartTime).Stringer("end_time", result.EndTime).
+			Str("info_url", result.InfoUrl).Int64("seed", result.Seed).Msg("recorded query result to MySQL")
 	}
 	updateRunInfo := `UPDATE pbench_runs SET start_time = ?, queries_ran = queries_ran + 1, failed = ?, mismatch = ? WHERE run_id = ?`
 	res, err := m.db.Exec(updateRunInfo, s.States.RunStartTime, m.failed, m.mismatch, m.runId)
 	if err != nil {
 		log.Error().Err(err).Str("run_name", s.States.RunName).Int64("run_id", m.runId).
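Note: the new `seed` value is now persisted with every query record. The matching column is added to the `pbench_queries` DDL below, but that DDL only runs `create table if not exists`, so a deployment that already has the table would need a one-off migration along these lines (a sketch, not part of this change; the `default 0` backfill for historical rows is an assumption):

```sql
alter table pbench_queries
    add column seed bigint not null default 0;
```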
diff --git a/stage/pbench_queries_ddl.sql b/stage/pbench_queries_ddl.sql
index ecfa0f08..382b4c5f 100644
--- a/stage/pbench_queries_ddl.sql
+++ b/stage/pbench_queries_ddl.sql
@@ -14,6 +14,7 @@ create table if not exists pbench_queries
     expected_row_count int null,
     duration_ms int not null,
     info_url varchar(255) not null,
+    seed bigint not null,
     primary key (run_id, stage_id, query_file, query_index, sequence_no)
 )
     partition by hash (`run_id`) partitions 16;
diff --git a/stage/result.go b/stage/result.go
index d615829d..1bb5cd10 100644
--- a/stage/result.go
+++ b/stage/result.go
@@ -1,13 +1,15 @@
 package stage
 
 import (
-	"github.com/rs/zerolog"
 	"pbench/log"
 	"time"
+
+	"github.com/rs/zerolog"
 )
 
 type QueryResult struct {
 	StageId string
+	Seed    int64
 	Query   *Query
 	QueryId string
 	InfoUrl string
diff --git a/stage/stage.go b/stage/stage.go
index f0ec5192..4990f8f4 100644
--- a/stage/stage.go
+++ b/stage/stage.go
@@ -61,6 +61,10 @@ type Stage struct {
 	// Use RandomlyExecuteUntil to specify a duration like "1h" or an integer as the number of queries should be executed
 	// before exiting.
 	RandomlyExecuteUntil *string `json:"randomly_execute_until,omitempty"`
+	// If NoRandomDuplicates is set to true, queries will not be repeated during random execution
+	// until all queries have been executed once. After that, the selection pool resets if more
+	// executions are needed.
+	NoRandomDuplicates *bool `json:"no_random_duplicates,omitempty"`
 	// If not set, the default is 1. The default value is set when the stage is run.
 	ColdRuns *int `json:"cold_runs,omitempty" validate:"omitempty,gte=0"`
 	// If not set, the default is 0.
@@ -87,6 +91,14 @@ type Stage struct {
 	// knob was not set to true.
 	SaveJson       *bool    `json:"save_json,omitempty"`
 	NextStagePaths []string `json:"next,omitempty"`
+	// StreamCount specifies how many parallel instances of this stage should run.
+	// Each stream executes the same queries with a different seed for reproducible randomization.
+	// If not set, the stage runs once with the default seed.
+	StreamCount *int `json:"stream_count,omitempty" validate:"omitempty,gte=1"`
+	// Seeds specifies custom seeds for the stream instances. Its length must be either 1 (a base
+	// seed from which per-stream seeds are derived with offsets) or stream_count (individual seeds).
+	// If empty and stream_count > 1, seeds are auto-generated from States.RandSeed with offsets.
+	Seeds []int64 `json:"seeds,omitempty"`
 	// BaseDir is set to the directory path of this stage's location. It is used to locate the descendant stages when
 	// their locations are specified using relative paths. It is not possible to set this in a stage definition json file.
@@ -101,6 +113,11 @@ type Stage struct {
 	// Client is by default passed down to descendant stages.
 	Client *presto.Client `json:"-"`
 
+	// seed is this stage instance's custom seed: zero until assigned (at parse time for regular
+	// stages, or per stream instance in runAsMultipleStreams). Descendant stages will **NOT**
+	// inherit this value from their parents, so it is declared as a value, not a pointer.
+	seed int64
+
 	// Convenient access to the expected row count array under the current schema.
 	expectedRowCountInCurrentSchema []int
 	// Convenient access to the catalog, schema, and timezone
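Note: to show how the new knobs compose, here is a hypothetical stage definition json. Three streams replay the same workload; `seeds: [42]` is a single base seed, so the streams get 42, 1042, and 2042 per `getSeedForStream` below; `no_random_duplicates` forces a full pass over all queries before any repeat. Only the keys shown in the struct tags above are confirmed; the `random_execution` and `query_files` keys are inferred from the Go field names and are assumptions here:

```json
{
  "query_files": ["q01.sql", "q02.sql", "q03.sql"],
  "random_execution": true,
  "randomly_execute_until": "1h",
  "no_random_duplicates": true,
  "stream_count": 3,
  "seeds": [42]
}
```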
@@ -150,6 +167,7 @@ func (s *Stage) Run(ctx context.Context) int {
 
 	go func() {
 		s.States.wgExitMainStage.Wait()
+		close(s.States.resultChan)
 		// wgExitMainStage goes down to 0 after all the goroutines finish. Then we exit the driver by
 		// closing the timeToExit channel, which will trigger the graceful shutdown process
 		// (flushing the log file, writing the final time log summary, etc.).
@@ -174,20 +192,28 @@ func (s *Stage) Run(ctx context.Context) int {
 	for {
 		select {
-		case result := <-s.States.resultChan:
+		case result, ok := <-s.States.resultChan:
+			if !ok {
+				// resultChan closed: all results are in. Finalize the run and exit.
+				s.States.RunFinishTime = time.Now()
+				for _, recorder := range s.States.runRecorders {
+					recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
+				}
+				return int(s.States.exitCode.Load())
+			}
 			results = append(results, result)
 			for _, recorder := range s.States.runRecorders {
 				recorder.RecordQuery(utils.GetCtxWithTimeout(time.Second*5), s, result)
 			}
-		case sig := <-timeToExit:
-			if sig != nil {
-				// Cancel the context and wait for the goroutines to exit.
-				s.States.AbortAll(fmt.Errorf(sig.String()))
+		case sig, ok := <-timeToExit:
+			if !ok {
+				// timeToExit closed: no more signals will arrive. Disable this case (a receive
+				// from a nil channel blocks forever) and keep draining resultChan.
+				timeToExit = nil
 				continue
 			}
-			s.States.RunFinishTime = time.Now()
-			for _, recorder := range s.States.runRecorders {
-				recorder.RecordRun(utils.GetCtxWithTimeout(time.Second*5), s, results)
-			}
-			return int(s.States.exitCode.Load())
+			// Got a shutdown signal: abort ongoing queries, then keep receiving results
+			// until resultChan is closed so nothing is dropped.
+			log.Info().Msgf("Shutdown signal received: %v. Aborting queries...", sig)
+			s.States.AbortAll(fmt.Errorf("%s", sig.String()))
 		}
 	}
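Note: the loop above is the standard Go completion-signaling pattern: the producer side closes `resultChan` when all workers are done, the consumer finalizes on the closed-channel read, and a closed signal channel is set to `nil` so its case blocks forever instead of spinning. A minimal self-contained sketch of the same pattern (names are illustrative, not pbench APIs):

```go
package main

import "fmt"

func main() {
	results := make(chan int)
	signals := make(chan string, 1)

	go func() {
		for i := 0; i < 3; i++ {
			results <- i
		}
		close(results) // closing the channel is the "all work done" signal
	}()
	close(signals) // pretend the signal source has already shut down

	for {
		select {
		case r, ok := <-results:
			if !ok {
				fmt.Println("all results received, finalizing")
				return
			}
			fmt.Println("got result", r)
		case sig, ok := <-signals:
			if !ok {
				// Disable this case: a receive from a nil channel blocks forever,
				// so the loop keeps draining results without busy-spinning.
				signals = nil
				continue
			}
			fmt.Println("abort requested:", sig)
		}
	}
}
```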
+ Msg("stream instance started") + + if err := s.runStreamInstance(ctx, index, seed); err != nil { + log.Error(). + Err(err). + Str("stage_id", s.Id). + Int("stream_instance", index). + Msg("stream instance failed") + errChan <- fmt.Errorf("stream %d failed: %w", index, err) + } else { + log.Info(). + Str("stage_id", s.Id). + Int("stream_instance", index). + Msg("stream instance completed successfully") + } + }(streamIndex, streamSeed) + } + + wg.Wait() + close(errChan) + + // Check if any stream reported an error + for err := range errChan { + if err != nil { + return err + } + } + + log.Info(). + Str("stage_id", s.Id). + Int("stream_count", streamCount). + Msg("all stream instances completed successfully") + + return nil +} + +// validateStreamSeeds validates the seeds configuration for streams +func (s *Stage) validateStreamSeeds(streamCount int) error { + if len(s.Seeds) == 0 { + // No seeds specified - will auto-generate from States.RandSeed + return nil + } + + if len(s.Seeds) == 1 { + // Single base seed - all streams will derive from it with offsets + return nil + } + + if len(s.Seeds) == streamCount { + // Individual seed for each stream - perfect + return nil + } + + return fmt.Errorf("seeds array length (%d) must be either 1 (base seed) or equal to stream_count (%d), got %d seeds", + len(s.Seeds), streamCount, len(s.Seeds)) +} + +// getSeedForStream returns the appropriate seed for a given stream index (0-based) +func (s *Stage) getSeedForStream(streamIndex int) int64 { + if len(s.Seeds) == 0 { + // No custom seeds: generate from base RandSeed + offset + return s.States.RandSeed + int64(streamIndex)*1000 + } + + if len(s.Seeds) == 1 { + // Single base seed: use it plus offset + return s.Seeds[0] + int64(streamIndex)*1000 + } + + // Individual seeds: use the specific seed for this instance + return s.Seeds[streamIndex] +} + +func (s *Stage) runStreamInstance(ctx context.Context, streamIndex int, seed int64) error { + // Set the seed for this stream instance + // This affects the randomization in runRandomly() + s.seed = seed + + // Execute queries based on execution mode + if *s.RandomExecution { + return s.runRandomly(ctx) + } + return s.runSequentially(ctx) +} + func (s *Stage) runSequentially(ctx context.Context) (returnErr error) { // Try to match an array of expected row counts keyToMatch := s.currentCatalog + "." + s.currentSchema @@ -343,21 +507,57 @@ func (s *Stage) runRandomly(ctx context.Context) error { return nil } } - r := rand.New(rand.NewSource(s.States.RandSeed)) + + r := rand.New(rand.NewSource(s.seed)) + log.Info().Str("stream_id", s.Id).Int64("custom_seed", s.seed).Msg("initialized with seed") s.States.RandSeedUsed = true - log.Info().Int64("seed", s.States.RandSeed).Msg("random source seeded") - randIndexUpperBound := len(s.Queries) + len(s.QueryFiles) - for i := 1; continueExecution(i); i++ { - idx := r.Intn(randIndexUpperBound) - if i <= s.States.RandSkip { - if i == s.States.RandSkip { - log.Info().Msgf("skipped %d random selections", i) + + totalQueries := len(s.Queries) + len(s.QueryFiles) + + // refreshIndices generates a new set of random indices for selecting queries. + // If NoRandomDuplicates is set to true, it generates a shuffled list of all indices. + // Otherwise, it generates a list of random indices with possible duplicates. 
@@ -343,21 +511,62 @@ func (s *Stage) runRandomly(ctx context.Context) error {
 			return nil
 		}
 	}
-	r := rand.New(rand.NewSource(s.States.RandSeed))
-	s.States.RandSeedUsed = true
-	log.Info().Int64("seed", s.States.RandSeed).Msg("random source seeded")
-	randIndexUpperBound := len(s.Queries) + len(s.QueryFiles)
-	for i := 1; continueExecution(i); i++ {
-		idx := r.Intn(randIndexUpperBound)
-		if i <= s.States.RandSkip {
-			if i == s.States.RandSkip {
-				log.Info().Msgf("skipped %d random selections", i)
+
+	r := rand.New(rand.NewSource(s.seed))
+	log.Info().Str("stage_id", s.Id).Int64("seed", s.seed).Msg("random source seeded")
+	if s.seed == s.States.RandSeed {
+		// Only mark the base rand_seed as used when this run was actually seeded from it;
+		// stream instances may run with derived or custom seeds.
+		s.States.RandSeedUsed = true
+	}
+
+	totalQueries := len(s.Queries) + len(s.QueryFiles)
+
+	// refreshIndices generates a new batch of query indices. If NoRandomDuplicates is true, it
+	// returns a shuffled permutation of all indices, so every query runs once before any repeats.
+	// Otherwise, it returns a batch of random indices that may contain duplicates.
+	refreshIndices := func() []int {
+		indices := make([]int, totalQueries)
+		if s.NoRandomDuplicates != nil && *s.NoRandomDuplicates {
+			for i := 0; i < totalQueries; i++ {
+				indices[i] = i
+			}
+			r.Shuffle(len(indices), func(i, j int) {
+				indices[i], indices[j] = indices[j], indices[i]
+			})
+		} else {
+			for i := 0; i < totalQueries; i++ {
+				indices[i] = r.Intn(totalQueries)
+			}
+		}
+		return indices
+	}
+
+	executionCount := 1
+	var currentIndices []int
+	var indexPosition int
+
+	for continueExecution(executionCount) {
+		// Refresh the index batch when the current one is exhausted.
+		if currentIndices == nil || indexPosition >= len(currentIndices) {
+			currentIndices = refreshIndices()
+			indexPosition = 0
+		}
+
+		idx := currentIndices[indexPosition]
+		indexPosition++
+
+		if executionCount <= s.States.RandSkip {
+			if executionCount == s.States.RandSkip {
+				log.Info().Msgf("skipped %d random selections", executionCount)
 			}
+			executionCount++
 			continue
 		}
+
 		if idx < len(s.Queries) {
-			// Run query embedded in the json file.
-			pseudoFileName := fmt.Sprintf("rand_%d", i)
+			// Run a query embedded in the json file.
+			pseudoFileName := fmt.Sprintf("rand_%d", executionCount)
 			if err := s.runQueries(ctx, s.Queries[idx:idx+1], &pseudoFileName, 0); err != nil {
 				return err
 			}
@@ -367,11 +576,12 @@ func (s *Stage) runRandomly(ctx context.Context) error {
 			if relPath, relErr := filepath.Rel(s.BaseDir, queryFile); relErr == nil {
 				fileAlias = relPath
 			}
-			fileAlias = fmt.Sprintf("rand_%d_%s", i, fileAlias)
+			fileAlias = fmt.Sprintf("rand_%d_%s", executionCount, fileAlias)
 			if err := s.runQueryFile(ctx, queryFile, nil, &fileAlias); err != nil {
 				return err
 			}
 		}
+		executionCount++
 	}
 	log.Info().Msg("random execution concluded.")
 	return nil
@@ -476,6 +686,7 @@ func (s *Stage) runQuery(ctx context.Context, query *Query) (result *QueryResult
 
 	result = &QueryResult{
 		StageId:   s.Id,
+		Seed:      s.seed,
 		Query:     query,
 		StartTime: time.Now(),
 	}
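Note: the `no_random_duplicates` path is a shuffled-epoch scheme: each pass draws a fresh Fisher-Yates permutation of all query indices, so within a pass every query is selected exactly once and repeats can only occur across passes. A minimal reproduction of the batching (illustrative only):

```go
package main

import (
	"fmt"
	"math/rand"
)

func main() {
	const totalQueries = 4
	r := rand.New(rand.NewSource(42)) // fixed seed: the selection order is reproducible

	var pool []int
	pos := 0
	for n := 1; n <= 10; n++ {
		// Refresh the pool with a fresh permutation once it is exhausted.
		if pool == nil || pos >= len(pool) {
			pool = make([]int, totalQueries)
			for i := range pool {
				pool[i] = i
			}
			r.Shuffle(len(pool), func(i, j int) { pool[i], pool[j] = pool[j], pool[i] })
			pos = 0
		}
		fmt.Printf("selection %d -> query index %d\n", n, pool[pos])
		pos++
	}
}
```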
diff --git a/stage/stage_utils.go b/stage/stage_utils.go
index de4782bb..e571eb7d 100644
--- a/stage/stage_utils.go
+++ b/stage/stage_utils.go
@@ -68,6 +68,9 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
 	if other.RandomExecution != nil {
 		s.RandomExecution = other.RandomExecution
 	}
+	if other.NoRandomDuplicates != nil {
+		s.NoRandomDuplicates = other.NoRandomDuplicates
+	}
 	if other.RandomlyExecuteUntil != nil {
 		s.RandomlyExecuteUntil = other.RandomlyExecuteUntil
 	}
@@ -93,6 +96,14 @@ func (s *Stage) MergeWith(other *Stage) *Stage {
 	s.NextStagePaths = append(s.NextStagePaths, other.NextStagePaths...)
 	s.BaseDir = other.BaseDir
 
+	// Stream configuration: take the other stage's stream count if set, and append its seeds.
+	if other.StreamCount != nil {
+		s.StreamCount = other.StreamCount
+	}
+	if len(other.Seeds) > 0 {
+		s.Seeds = append(s.Seeds, other.Seeds...)
+	}
+
 	s.PreStageShellScripts = append(s.PreStageShellScripts, other.PreStageShellScripts...)
 	s.PostQueryShellScripts = append(s.PostQueryShellScripts, other.PostQueryShellScripts...)
 	s.PostStageShellScripts = append(s.PostStageShellScripts, other.PostStageShellScripts...)
@@ -194,6 +205,9 @@ func (s *Stage) setDefaults() {
 	if s.RandomExecution == nil {
 		s.RandomExecution = &falseValue
 	}
+	if s.NoRandomDuplicates == nil {
+		s.NoRandomDuplicates = &falseValue
+	}
 	if s.AbortOnError == nil {
 		s.AbortOnError = &falseValue
 	}
@@ -235,6 +249,9 @@ func (s *Stage) propagateStates() {
 	if nextStage.RandomExecution == nil {
 		nextStage.RandomExecution = s.RandomExecution
 	}
+	if nextStage.NoRandomDuplicates == nil {
+		nextStage.NoRandomDuplicates = s.NoRandomDuplicates
+	}
 	if nextStage.RandomlyExecuteUntil == nil {
 		nextStage.RandomlyExecuteUntil = s.RandomlyExecuteUntil
 	}
diff --git a/stage/streams.go b/stage/streams.go
new file mode 100644
index 00000000..0ac354bb
--- /dev/null
+++ b/stage/streams.go
@@ -0,0 +1,61 @@
+package stage
+
+import (
+	"fmt"
+	"os"
+	"path/filepath"
+)
+
+// Streams defines the configuration for stream-based execution.
+type Streams struct {
+	StreamPath  string  `json:"stream_file_path"`
+	StreamCount int     `json:"stream_count"`
+	Seeds       []int64 `json:"seeds,omitempty"`
+}
+
+// Validate checks if the Streams configuration is valid.
+func (s *Streams) Validate() error {
+	if s.StreamCount <= 0 {
+		return fmt.Errorf("stream_count must be positive, got %d for stream %s", s.StreamCount, s.StreamPath)
+	}
+
+	if len(s.Seeds) > 0 {
+		if len(s.Seeds) != 1 && len(s.Seeds) != s.StreamCount {
+			return fmt.Errorf("seeds array length (%d) must be either 1 or equal to stream_count (%d) for stream %s",
+				len(s.Seeds), s.StreamCount, s.StreamPath)
+		}
+	}
+
+	return nil
+}
+
+// GetValidatedPath returns the absolute path to the stream file and validates that it exists.
+func (s *Streams) GetValidatedPath(baseDir string) (string, error) {
+	streamPath := s.StreamPath
+	if !filepath.IsAbs(streamPath) {
+		streamPath = filepath.Join(baseDir, streamPath)
+	}
+
+	if _, err := os.Stat(streamPath); err != nil {
+		return "", fmt.Errorf("stream file %s does not exist: %w", streamPath, err)
+	}
+
+	return streamPath, nil
+}
+
+// GetSeedForInstance returns the appropriate seed for a stream instance (0-based index).
+func (s *Streams) GetSeedForInstance(instanceIndex int) (int64, bool) {
+	if len(s.Seeds) == 0 {
+		return 0, false
+	}
+
+	if len(s.Seeds) == 1 {
+		return s.Seeds[0] + int64(instanceIndex), true
+	}
+
+	if instanceIndex < len(s.Seeds) {
+		return s.Seeds[instanceIndex], true
+	}
+
+	return 0, false
+}
diff --git a/utils/utils.go b/utils/utils.go
index 122e5f4f..bf68e40c 100644
--- a/utils/utils.go
+++ b/utils/utils.go
@@ -59,7 +59,7 @@ func InitLogFile(logPath string) (finalizer func()) {
 		// In this case, the global logger is not changed. Log messages are still printed to stderr.
 		return func() {}
 	} else {
-		bufWriter := bufio.NewWriter(logFile)
+		bufWriter := bufio.NewWriterSize(logFile, 8192)
 		log.SetGlobalLogger(zerolog.New(io.MultiWriter(os.Stderr, bufWriter)).With().Timestamp().Stack().Logger())
 		log.Info().Str("log_path", logPath).Msg("log file will be saved to this path")
 		return func() {
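Note: `Streams` is not wired into `Stage` anywhere in this diff, so the following is only a hedged illustration of how its helpers behave; the file path and seed values are made up, and the `pbench/stage` import path is inferred from the `pbench/log` and `pbench/utils` imports above. Observe that with a single base seed, `GetSeedForInstance` offsets by 1 per instance, unlike the 1000 offset used by `(*Stage).getSeedForStream`:

```go
package main

import (
	"fmt"

	"pbench/stage"
)

func main() {
	s := &stage.Streams{
		StreamPath:  "streams/throughput.json", // hypothetical path
		StreamCount: 3,
		Seeds:       []int64{42}, // single base seed: instances get 42, 43, 44
	}
	if err := s.Validate(); err != nil {
		panic(err)
	}
	for i := 0; i < s.StreamCount; i++ {
		if seed, ok := s.GetSeedForInstance(i); ok {
			fmt.Printf("instance %d -> seed %d\n", i, seed)
		}
	}
}
```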