diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 75f065b..bddf0ef 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,11 +6,11 @@ repos: - id: check-useless-excludes # - id: identity # Prints all files passed to pre-commits. Debugging. - repo: https://github.com/lyz-code/yamlfix - rev: 1.17.0 + rev: 1.19.0 hooks: - id: yamlfix - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v5.0.0 + rev: v6.0.0 hooks: - id: check-added-large-files args: @@ -44,20 +44,20 @@ repos: rev: v1.37.1 hooks: - id: yamllint - - repo: https://github.com/psf/black - rev: 25.1.0 + - repo: https://github.com/psf/black-pre-commit-mirror + rev: 25.9.0 hooks: - id: black # It is recommended to specify the latest version of Python # supported by your project here language_version: python3.11 - repo: https://github.com/asottile/blacken-docs - rev: 1.19.1 + rev: 1.20.0 hooks: - id: blacken-docs # exclude: docs/source/how_to_guides/optimization/how_to_specify_constraints.md - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.12.7 + rev: v0.14.3 hooks: - id: ruff # args: @@ -75,7 +75,7 @@ repos: - id: nbqa-black - id: nbqa-ruff - repo: https://github.com/executablebooks/mdformat - rev: 0.7.22 + rev: 1.0.0 hooks: - id: mdformat additional_dependencies: diff --git a/benchmark_code/BENCHMARK_README.md b/benchmark_code/BENCHMARK_README.md index 2290e4c..14ddd1c 100644 --- a/benchmark_code/BENCHMARK_README.md +++ b/benchmark_code/BENCHMARK_README.md @@ -1,32 +1,43 @@ # Benchmark Comparison Workflow -This document explains how to compare performance between the main branch and a PR branch with optimizations. +This document explains how to compare performance between the main branch and a PR +branch with optimizations. ## Scripts Overview ### Core Scripts -1. **`benchmark.py`** - Runs comprehensive performance benchmarks across multiple dataset sizes and saves results to JSON -2. **`benchmark_profile.py`** - Runs profiling for a single configuration with detailed memory tracking and timing breakdown -3. **`benchmark_compare.py`** - Compares results from two benchmark runs + +1. **`benchmark.py`** - Runs comprehensive performance benchmarks across multiple + dataset sizes and saves results to JSON +1. **`benchmark_profile.py`** - Runs profiling for a single configuration with detailed + memory tracking and timing breakdown +1. **`benchmark_compare.py`** - Compares results from two benchmark runs ### Supporting Files -4. **`benchmark_setup.py`** - Shared configuration (TT_TARGETS, MAPPER, utilities) used by both main scripts -5. **`benchmark_make_data.py`** - Synthetic data generation for standardized testing - - `make_data(N, scramble_data=False)` - Generate N households with optional data scrambling + +4. **`benchmark_setup.py`** - Shared configuration (TT_TARGETS, MAPPER, utilities) used + by both main scripts +1. **`benchmark_make_data.py`** - Synthetic data generation for standardized testing + - `make_data(N, scramble_data=False)` - Generate N households with optional data + scrambling - By default, data is kept in sorted p_id order for optimal performance - Set `scramble_data=True` to test performance with unsorted data -6. **`benchmark_compare.py`** - Stage-by-stage comparison tool +1. **`benchmark_compare.py`** - Stage-by-stage comparison tool ## Key Features ### 3-Stage Timing Analysis + All scripts break down execution into: + - **Stage 1**: Data preprocessing & DAG creation -- **Stage 2**: Core computation (tax/transfer calculations) +- **Stage 2**: Core computation (tax/transfer calculations) - **Stage 3**: DataFrame formatting (JAX → pandas conversion) ### Memory Tracking -- Both `benchmark.py` and `benchmark_profile.py` now include comprehensive memory tracking + +- Both `benchmark.py` and `benchmark_profile.py` now include comprehensive memory + tracking - Continuous monitoring of peak memory usage during execution - Memory delta reporting (initial → final) @@ -48,7 +59,7 @@ python benchmark.py -scramble # or: benchmark_results_20250819_143022_scrambled.json ``` -### Step 2: Run benchmark on PR branch +### Step 2: Run benchmark on PR branch ```bash # Switch to PR branch (ttsim) @@ -114,7 +125,8 @@ py-spy record -o profile_scrambled.svg -- python benchmark_profile.py -N 32768 - ## Data Generation Options -The `benchmark_make_data.py` module provides the `make_data()` function with the following options: +The `benchmark_make_data.py` module provides the `make_data()` function with the +following options: ```python # Generate sorted data (default - optimal performance) diff --git a/benchmark_code/benchmark.py b/benchmark_code/benchmark.py index 4d4515c..68015b3 100644 --- a/benchmark_code/benchmark.py +++ b/benchmark_code/benchmark.py @@ -1,45 +1,59 @@ """Performance comparison script for numpy vs jax backends.""" -import json + +import argparse import hashlib +import json import time -import argparse from datetime import datetime -from gettsim import main, InputData, MainTarget, TTTargets, Labels, SpecializedEnvironment, RawResults +from benchmark_make_data import make_data # Import shared benchmark configuration and utilities from benchmark_setup import ( - TT_TARGETS, MAPPER, JAX_AVAILABLE, - sync_jax_if_needed, clear_jax_cache, get_memory_usage_mb, MemoryTracker, - force_garbage_collection, reset_session_state, BENCHMARK_HOUSEHOLD_SIZES, BACKENDS + BACKENDS, + BENCHMARK_HOUSEHOLD_SIZES, + MAPPER, + TT_TARGETS, + MemoryTracker, + get_memory_usage_mb, + reset_session_state, + sync_jax_if_needed, +) +from gettsim import ( + InputData, + Labels, + MainTarget, + SpecializedEnvironment, + TTTargets, + main, ) -from benchmark_make_data import make_data def run_benchmark( - N_households, backend, - reset_session=False, - sync_jax=False, - scramble_data=False, - ): + N_households, + backend, + reset_session=False, + sync_jax=False, + scramble_data=False, +): """Run a single benchmark with 3-stage timing as in gettsim_profile_stages.py.""" print(f"Running benchmark: {N_households:,} households, {backend} backend") - + # Reset session state to ensure clean environment if reset_session: reset_session_state(backend) - + # Generate data print(" Generating data...") data = make_data(N_households, scramble_data=scramble_data) - + # Memory tracking setup - always track peak memory for benchmarking tracker = MemoryTracker() - + # Initial memory reading initial_memory = get_memory_usage_mb() tracker.start_monitoring() - + try: # STAGE 1: Data preprocessing and DAG creation print(" Stage 1: Data preprocessing and DAG creation...") @@ -64,7 +78,7 @@ def run_benchmark( include_fail_nodes=True, include_warn_nodes=False, backend=backend, - ) + ) # Force JAX synchronization before recording end time if sync_jax: @@ -74,11 +88,11 @@ def run_benchmark( stage1_time = stage1_end - stage1_start # Generate hash for Stage 1 output (tmp) - stage1_hash = hashlib.md5(str(tmp).encode('utf-8')).hexdigest() + stage1_hash = hashlib.md5(str(tmp).encode("utf-8")).hexdigest() # STAGE 2: Computation only (no data preprocessing) print(" Stage 2: Computation only...") - + stage2_start = time.time() raw_results_stage2 = main( @@ -92,7 +106,9 @@ def run_benchmark( tree=TT_TARGETS, ), processed_data=tmp["processed_data"], - input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + input_data=InputData.flat( + tmp["input_data"]["flat"] + ), # Provide the flat input data from stage 1 labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), tt_function=tmp["tt_function"], # Reuse pre-compiled JAX function include_fail_nodes=False, @@ -108,7 +124,7 @@ def run_benchmark( stage2_time = stage2_end - stage2_start # Generate hash for Stage 2 output (raw_results_stage2) - stage2_hash = hashlib.md5(str(raw_results_stage2).encode('utf-8')).hexdigest() + stage2_hash = hashlib.md5(str(raw_results_stage2).encode("utf-8")).hexdigest() # STAGE 3: Convert raw results to DataFrame (no computation, just formatting) print(" Stage 3: Convert raw results to DataFrame...") @@ -121,7 +137,9 @@ def run_benchmark( tree=TT_TARGETS, ), raw_results=raw_results_stage2["raw_results"], - input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + input_data=InputData.flat( + tmp["input_data"]["flat"] + ), # Provide the flat input data from stage 1 processed_data=tmp["processed_data"], labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), specialized_environment=SpecializedEnvironment( @@ -139,10 +157,10 @@ def run_benchmark( stage3_end = time.time() stage3_time = stage3_end - stage3_start total_time = stage1_time + stage2_time + stage3_time - + # Generate hash for Stage 3 output (result) - stage3_hash = hashlib.md5(str(result).encode('utf-8')).hexdigest() - + stage3_hash = hashlib.md5(str(result).encode("utf-8")).hexdigest() + # Final memory reading final_memory = get_memory_usage_mb() tracker.stop_monitoring() @@ -150,34 +168,36 @@ def run_benchmark( # Calculate memory delta memory_delta = final_memory - initial_memory - + # Results shape - result_shape = result.shape if hasattr(result, 'shape') else 'N/A' - + result_shape = result.shape if hasattr(result, "shape") else "N/A" + print(f" Stage 1: {stage1_time:.4f}s ({stage1_time/total_time*100:.1f}%)") print(f" Stage 2: {stage2_time:.4f}s ({stage2_time/total_time*100:.1f}%)") print(f" Stage 3: {stage3_time:.4f}s ({stage3_time/total_time*100:.1f}%)") print(f" Total: {total_time:.4f}s") print(f" Result shape: {result_shape}") - print(f" Memory: {initial_memory:.1f} -> {final_memory:.1f} MB (Δ{memory_delta:+.1f}, peak: {peak_memory:.1f})") - + print( + f" Memory: {initial_memory:.1f} -> {final_memory:.1f} MB (Δ{memory_delta:+.1f}, peak: {peak_memory:.1f})" + ) + return { - 'stage1_time': stage1_time, - 'stage2_time': stage2_time, - 'stage3_time': stage3_time, - 'execution_time': total_time, - 'stage1_hash': stage1_hash, - 'stage2_hash': stage2_hash, - 'stage3_hash': stage3_hash, - 'initial_memory': initial_memory, - 'final_memory': final_memory, - 'peak_memory': peak_memory, - 'memory_delta': memory_delta, - 'result_shape': result_shape, - 'backend': backend, - 'N_households': N_households, + "stage1_time": stage1_time, + "stage2_time": stage2_time, + "stage3_time": stage3_time, + "execution_time": total_time, + "stage1_hash": stage1_hash, + "stage2_hash": stage2_hash, + "stage3_hash": stage3_hash, + "initial_memory": initial_memory, + "final_memory": final_memory, + "peak_memory": peak_memory, + "memory_delta": memory_delta, + "result_shape": result_shape, + "backend": backend, + "N_households": N_households, } - + except Exception as e: print(f" ERROR: {e}") tracker.stop_monitoring() @@ -186,26 +206,30 @@ def run_benchmark( def main_cli(): """Main function for command line interface.""" - parser = argparse.ArgumentParser(description='Run GETTSIM performance benchmarks') - parser.add_argument('-scramble', '--scramble-data', action='store_true', - help='Scramble data to create unsorted p_id order (default: sorted)') - + parser = argparse.ArgumentParser(description="Run GETTSIM performance benchmarks") + parser.add_argument( + "-scramble", + "--scramble-data", + action="store_true", + help="Scramble data to create unsorted p_id order (default: sorted)", + ) + args = parser.parse_args() - + # Dataset sizes (number of households) household_sizes = BENCHMARK_HOUSEHOLD_SIZES backends = BACKENDS - + results = {} - + # Add metadata results["metadata"] = { "timestamp": datetime.now().isoformat(), "household_sizes": household_sizes, "backends": backends, - "scrambled_data": args.scramble_data + "scrambled_data": args.scramble_data, } - + for backend in backends: print(f"\n{'='*60}") print(f"Testing {backend} backend") @@ -214,41 +238,51 @@ def main_cli(): else: print("Data scrambling: DISABLED (sorted p_id order)") print(f"{'='*60}") - + # Clear all caches and reset session before starting new backend print(f"Preparing environment for {backend} backend...") reset_session_state(backend) - + for N_households in household_sizes: # Add extra session reset for larger datasets to ensure clean state reset_between_sizes = N_households >= 2**18 # Reset for 256k+ households - + result = run_benchmark( - N_households, - backend, - reset_session=False, # reset_between_sizes (no impact on results) + N_households, + backend, + reset_session=False, # reset_between_sizes (no impact on results) sync_jax=True, # Set to True if you want to force JAX synchronization - # Seems necessary for realistic (reported time = wall clock time) JAX timings + # Seems necessary for realistic (reported time = wall clock time) JAX timings scramble_data=args.scramble_data, ) - if result and result.get('execution_time'): + if result and result.get("execution_time"): # Store all stage timing data - results[f"{N_households}_{backend}_stage1_time"] = result['stage1_time'] - results[f"{N_households}_{backend}_stage2_time"] = result['stage2_time'] - results[f"{N_households}_{backend}_stage3_time"] = result['stage3_time'] - results[f"{N_households}_{backend}_time"] = result['execution_time'] # Total time - results[f"{N_households}_{backend}_stage1_hash"] = result['stage1_hash'] - results[f"{N_households}_{backend}_stage2_hash"] = result['stage2_hash'] - results[f"{N_households}_{backend}_stage3_hash"] = result['stage3_hash'] - results[f"{N_households}_{backend}_initial_memory"] = result['initial_memory'] - results[f"{N_households}_{backend}_final_memory"] = result['final_memory'] - results[f"{N_households}_{backend}_memory_delta"] = result['memory_delta'] - results[f"{N_households}_{backend}_peak_memory"] = result['peak_memory'] - results[f"{N_households}_{backend}_result_shape"] = result['result_shape'] + results[f"{N_households}_{backend}_stage1_time"] = result["stage1_time"] + results[f"{N_households}_{backend}_stage2_time"] = result["stage2_time"] + results[f"{N_households}_{backend}_stage3_time"] = result["stage3_time"] + results[f"{N_households}_{backend}_time"] = result[ + "execution_time" + ] # Total time + results[f"{N_households}_{backend}_stage1_hash"] = result["stage1_hash"] + results[f"{N_households}_{backend}_stage2_hash"] = result["stage2_hash"] + results[f"{N_households}_{backend}_stage3_hash"] = result["stage3_hash"] + results[f"{N_households}_{backend}_initial_memory"] = result[ + "initial_memory" + ] + results[f"{N_households}_{backend}_final_memory"] = result[ + "final_memory" + ] + results[f"{N_households}_{backend}_memory_delta"] = result[ + "memory_delta" + ] + results[f"{N_households}_{backend}_peak_memory"] = result["peak_memory"] + results[f"{N_households}_{backend}_result_shape"] = result[ + "result_shape" + ] else: # Store None values for failed runs results[f"{N_households}_{backend}_stage1_time"] = None - results[f"{N_households}_{backend}_stage2_time"] = None + results[f"{N_households}_{backend}_stage2_time"] = None results[f"{N_households}_{backend}_stage3_time"] = None results[f"{N_households}_{backend}_time"] = None results[f"{N_households}_{backend}_hash"] = None @@ -258,19 +292,19 @@ def main_cli(): results[f"{N_households}_{backend}_peak_memory"] = None results[f"{N_households}_{backend}_result_shape"] = None print() - + # Comprehensive cleanup after completing all sizes for this backend print(f"Completing {backend} backend tests...") print(f"{backend} backend tests completed with full cleanup") - + # Save results to JSON file timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") scramble_suffix = "_scrambled" if args.scramble_data else "_sorted" filename = f"benchmark_results_{timestamp}{scramble_suffix}.json" - with open(filename, 'w') as f: + with open(filename, "w") as f: json.dump(results, f, indent=2) print(f"\nResults saved to: {filename}") - + print(f"\n{'='*120}") print("3-STAGE TIMING BREAKDOWN") if args.scramble_data: @@ -278,35 +312,37 @@ def main_cli(): else: print("Data ordering: SORTED (sequential p_id)") print(f"{'='*120}") - + # Print comparison table in the requested format print(f"\n{'='*101}") print("PERFORMANCE COMPARISON NUMPY <-> JAX") print(f"{'='*104}") - print(f"{'Households':<12}{'Stage':<18}{'NUMPY hash':<12}{'JAX hash':<12}{'NUMPY (s)':<12}{'JAX (s)':<12}{'Speedup':<12}") + print( + f"{'Households':<12}{'Stage':<18}{'NUMPY hash':<12}{'JAX hash':<12}{'NUMPY (s)':<12}{'JAX (s)':<12}{'Speedup':<12}" + ) print("-" * 104) - + for N_households in household_sizes: # Get timing data for all stages numpy_s1 = results.get(f"{N_households}_numpy_stage1_time") numpy_s2 = results.get(f"{N_households}_numpy_stage2_time") numpy_s3 = results.get(f"{N_households}_numpy_stage3_time") numpy_total = results.get(f"{N_households}_numpy_time") - + jax_s1 = results.get(f"{N_households}_jax_stage1_time") jax_s2 = results.get(f"{N_households}_jax_stage2_time") jax_s3 = results.get(f"{N_households}_jax_stage3_time") jax_total = results.get(f"{N_households}_jax_time") - + # Get stage-specific hashes numpy_s1_hash = results.get(f"{N_households}_numpy_stage1_hash") numpy_s2_hash = results.get(f"{N_households}_numpy_stage2_hash") numpy_s3_hash = results.get(f"{N_households}_numpy_stage3_hash") - + jax_s1_hash = results.get(f"{N_households}_jax_stage1_hash") jax_s2_hash = results.get(f"{N_households}_jax_stage2_hash") jax_s3_hash = results.get(f"{N_households}_jax_stage3_hash") - + # Truncate hashes for display, handling both successful and failed cases def format_hash_display(hash_value, time_value): """Format hash display based on whether the stage succeeded.""" @@ -316,70 +352,84 @@ def format_hash_display(hash_value, time_value): return hash_value[:8] else: return "N/A" - + numpy_s1_hash_display = format_hash_display(numpy_s1_hash, numpy_s1) numpy_s2_hash_display = format_hash_display(numpy_s2_hash, numpy_s2) numpy_s3_hash_display = format_hash_display(numpy_s3_hash, numpy_s3) - + jax_s1_hash_display = format_hash_display(jax_s1_hash, jax_s1) jax_s2_hash_display = format_hash_display(jax_s2_hash, jax_s2) jax_s3_hash_display = format_hash_display(jax_s3_hash, jax_s3) - + # Helper function to format time display def format_time_display(time_value): """Format time display for successful or failed runs.""" return f"{time_value:.4f}" if time_value is not None else "FAILED" - + # Helper function to calculate speedup def calculate_speedup(numpy_time, jax_time): """Calculate speedup string, handling failed cases.""" if numpy_time is None and jax_time is None: return "FAILED" - elif numpy_time is None: - return "N/A" - elif jax_time is None: + elif numpy_time is None or jax_time is None: return "N/A" elif jax_time > 0: speedup = numpy_time / jax_time - return f"{speedup:.2f}x" if speedup >= 1 else f"1/{jax_time/numpy_time:.2f}x" + return ( + f"{speedup:.2f}x" + if speedup >= 1 + else f"1/{jax_time/numpy_time:.2f}x" + ) else: return "N/A" - + # Determine if we should show stage breakdown or overall FAILED show_stages = (numpy_total is not None) or (jax_total is not None) - + if show_stages: # Show individual stage results - + # Pre-processing row (Stage 1 hashes often unstable due to dict return) s1_speedup_str = calculate_speedup(numpy_s1, jax_s1) - print(f"{N_households:<12,}{'pre-processing':<18}{'-':<12}{'-':<12}{format_time_display(numpy_s1):<12}{format_time_display(jax_s1):<12}{s1_speedup_str:<12}") - + print( + f"{N_households:<12,}{'pre-processing':<18}{'-':<12}{'-':<12}{format_time_display(numpy_s1):<12}{format_time_display(jax_s1):<12}{s1_speedup_str:<12}" + ) + # Computation row (Stage 2 hashes should be stable) s2_speedup_str = calculate_speedup(numpy_s2, jax_s2) - print(f"{'':>12}{'computation':<18}{numpy_s2_hash_display:<12}{jax_s2_hash_display:<12}{format_time_display(numpy_s2):<12}{format_time_display(jax_s2):<12}{s2_speedup_str:<12}") - + print( + f"{'':>12}{'computation':<18}{numpy_s2_hash_display:<12}{jax_s2_hash_display:<12}{format_time_display(numpy_s2):<12}{format_time_display(jax_s2):<12}{s2_speedup_str:<12}" + ) + # Post-processing row (Stage 3 hashes should be stable) s3_speedup_str = calculate_speedup(numpy_s3, jax_s3) - print(f"{'':>12}{'post-processing':<18}{numpy_s3_hash_display:<12}{jax_s3_hash_display:<12}{format_time_display(numpy_s3):<12}{format_time_display(jax_s3):<12}{s3_speedup_str:<12}") - + print( + f"{'':>12}{'post-processing':<18}{numpy_s3_hash_display:<12}{jax_s3_hash_display:<12}{format_time_display(numpy_s3):<12}{format_time_display(jax_s3):<12}{s3_speedup_str:<12}" + ) + # Total time row total_speedup_str = calculate_speedup(numpy_total, jax_total) - print(f"{'':>12}{'total time':<18}{'':>12}{'':>12}{format_time_display(numpy_total):<12}{format_time_display(jax_total):<12}{total_speedup_str:<12}") - + print( + f"{'':>12}{'total time':<18}{'':>12}{'':>12}{format_time_display(numpy_total):<12}{format_time_display(jax_total):<12}{total_speedup_str:<12}" + ) + print("-" * 104) else: # Both backends completely failed - print(f"{N_households:<12,}{'FAILED':<18}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}") + print( + f"{N_households:<12,}{'FAILED':<18}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}" + ) print("-" * 104) - + # Print memory comparison print(f"\n{'='*140}") print("MEMORY USAGE COMPARISON") print(f"{'='*140}") - print(f"{'Households':<12}{'NumPy Init':<12}{'NumPy Final':<13}{'NumPy Δ':<12}{'NumPy Peak':<13}{'JAX Init':<12}{'JAX Final':<12}{'JAX Δ':<12}{'JAX Peak':<12}") + print( + f"{'Households':<12}{'NumPy Init':<12}{'NumPy Final':<13}{'NumPy Δ':<12}{'NumPy Peak':<13}{'JAX Init':<12}{'JAX Final':<12}{'JAX Δ':<12}{'JAX Peak':<12}" + ) print("-" * 140) - + for N_households in household_sizes: numpy_init = results.get(f"{N_households}_numpy_initial_memory") numpy_final = results.get(f"{N_households}_numpy_final_memory") @@ -389,14 +439,16 @@ def calculate_speedup(numpy_time, jax_time): jax_final = results.get(f"{N_households}_jax_final_memory") jax_delta = results.get(f"{N_households}_jax_memory_delta") jax_peak = results.get(f"{N_households}_jax_peak_memory") - + # Helper function to format memory values def format_memory(value): return f"{value:.1f}" if value is not None else "FAILED" - + # Show memory data even if only one backend succeeded - print(f"{N_households:<12,}{format_memory(numpy_init):<12}{format_memory(numpy_final):<13}{format_memory(numpy_delta):<12}{format_memory(numpy_peak):<13}{format_memory(jax_init):<12}{format_memory(jax_final):<12}{format_memory(jax_delta):<12}{format_memory(jax_peak):<12}") - + print( + f"{N_households:<12,}{format_memory(numpy_init):<12}{format_memory(numpy_final):<13}{format_memory(numpy_delta):<12}{format_memory(numpy_peak):<13}{format_memory(jax_init):<12}{format_memory(jax_final):<12}{format_memory(jax_delta):<12}{format_memory(jax_peak):<12}" + ) + print("-" * 140) print("\nLegend:") print(" Stage 1: Data preprocessing & DAG creation") @@ -405,7 +457,7 @@ def format_memory(value): print(" Init/Final: Memory usage before/after execution") print(" Δ: Memory increase during execution") print(" Peak: Maximum memory usage during execution") - + print(f"\n{'='*120}") print("BENCHMARK COMPLETED") print(f"{'='*120}") @@ -414,4 +466,4 @@ def format_memory(value): if __name__ == "__main__": - main_cli() \ No newline at end of file + main_cli() diff --git a/benchmark_code/benchmark_compare.py b/benchmark_code/benchmark_compare.py index 1e081f0..1bb7f34 100644 --- a/benchmark_code/benchmark_compare.py +++ b/benchmark_code/benchmark_compare.py @@ -7,16 +7,16 @@ python benchmark_compare.py main_results.json pr_results.json [--save-comparison] """ +import argparse import json -import os import sys from datetime import datetime -import argparse + def load_benchmark_results(filepath): """Load benchmark results from JSON file.""" try: - with open(filepath, 'r') as f: + with open(filepath) as f: data = json.load(f) return data except FileNotFoundError: @@ -26,11 +26,12 @@ def load_benchmark_results(filepath): print(f"Error: Invalid JSON in file '{filepath}'.") return None + def extract_household_sizes(results): """Extract household sizes from results metadata or data keys.""" if "metadata" in results and "household_sizes" in results["metadata"]: return results["metadata"]["household_sizes"] - + # Fallback: extract from data keys household_sizes = set() for key in results.keys(): @@ -40,184 +41,247 @@ def extract_household_sizes(results): household_sizes.add(size) except ValueError: continue - + return sorted(list(household_sizes)) + def print_jax_comparison_table(main_results, pr_results, household_sizes): """Print comparison table for JAX backend with 3-stage breakdown.""" print(f"\n{'='*140}") print("JAX BACKEND COMPARISON: Main Branch vs PR Branch - 3-STAGE BREAKDOWN") print(f"{'='*140}") - print(f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}") + print( + f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}" + ) print("-" * 140) - + for N_households in household_sizes: # Get timing data for all stages main_s1 = main_results.get(f"{N_households}_jax_stage1_time") main_s2 = main_results.get(f"{N_households}_jax_stage2_time") main_s3 = main_results.get(f"{N_households}_jax_stage3_time") main_total = main_results.get(f"{N_households}_jax_time") - + pr_s1 = pr_results.get(f"{N_households}_jax_stage1_time") pr_s2 = pr_results.get(f"{N_households}_jax_stage2_time") pr_s3 = pr_results.get(f"{N_households}_jax_stage3_time") pr_total = pr_results.get(f"{N_households}_jax_time") - + # Get stage-specific hashes main_s1_hash = main_results.get(f"{N_households}_jax_stage1_hash") main_s2_hash = main_results.get(f"{N_households}_jax_stage2_hash") main_s3_hash = main_results.get(f"{N_households}_jax_stage3_hash") - + pr_s1_hash = pr_results.get(f"{N_households}_jax_stage1_hash") pr_s2_hash = pr_results.get(f"{N_households}_jax_stage2_hash") pr_s3_hash = pr_results.get(f"{N_households}_jax_stage3_hash") - + # Check hash matches for each stage # Stage 1 hashes are intentionally omitted due to unstable dict returns s1_hash_match = "N/A" - s2_hash_match = "✓" if main_s2_hash == pr_s2_hash else "✗" if main_s2_hash and pr_s2_hash else "N/A" - s3_hash_match = "✓" if main_s3_hash == pr_s3_hash else "✗" if main_s3_hash and pr_s3_hash else "N/A" - + s2_hash_match = ( + "✓" + if main_s2_hash == pr_s2_hash + else "✗" if main_s2_hash and pr_s2_hash else "N/A" + ) + s3_hash_match = ( + "✓" + if main_s3_hash == pr_s3_hash + else "✗" if main_s3_hash and pr_s3_hash else "N/A" + ) + # Check if we have valid data if all(x is not None for x in [main_s1, main_s2, main_s3, pr_s1, pr_s2, pr_s3]): # Stage 1 row s1_speedup = main_s1 / pr_s1 - s1_speedup_str = f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x" - print(f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}") - + s1_speedup_str = ( + f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x" + ) + print( + f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}" + ) + # Stage 2 row - s2_speedup = main_s2 / pr_s2 - s2_speedup_str = f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x" - print(f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}") - + s2_speedup = main_s2 / pr_s2 + s2_speedup_str = ( + f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x" + ) + print( + f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}" + ) + # Stage 3 row s3_speedup = main_s3 / pr_s3 - s3_speedup_str = f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x" - print(f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}") - + s3_speedup_str = ( + f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x" + ) + print( + f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}" + ) + # Total row if main_total and pr_total: total_speedup = main_total / pr_total - total_speedup_str = f"{total_speedup:.2f}x" if total_speedup >= 1 else f"1/{pr_total/main_total:.2f}x" - print(f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}") - + total_speedup_str = ( + f"{total_speedup:.2f}x" + if total_speedup >= 1 + else f"1/{pr_total/main_total:.2f}x" + ) + print( + f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}" + ) + print("-" * 140) else: # Handle failed cases main_time_str = f"{main_total:.4f}" if main_total is not None else "FAILED" pr_time_str = f"{pr_total:.4f}" if pr_total is not None else "FAILED" - print(f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}") + print( + f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}" + ) print("-" * 140) + def print_numpy_comparison_table(main_results, pr_results, household_sizes): """Print comparison table for NumPy backend with 3-stage breakdown.""" print(f"\n{'='*140}") print("NUMPY BACKEND COMPARISON: Main Branch vs PR Branch - 3-STAGE BREAKDOWN") print(f"{'='*140}") - print(f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}") + print( + f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}" + ) print("-" * 140) - + for N_households in household_sizes: # Get timing data for all stages main_s1 = main_results.get(f"{N_households}_numpy_stage1_time") main_s2 = main_results.get(f"{N_households}_numpy_stage2_time") main_s3 = main_results.get(f"{N_households}_numpy_stage3_time") main_total = main_results.get(f"{N_households}_numpy_time") - + pr_s1 = pr_results.get(f"{N_households}_numpy_stage1_time") pr_s2 = pr_results.get(f"{N_households}_numpy_stage2_time") pr_s3 = pr_results.get(f"{N_households}_numpy_stage3_time") pr_total = pr_results.get(f"{N_households}_numpy_time") - + # Get stage-specific hashes main_s1_hash = main_results.get(f"{N_households}_numpy_stage1_hash") main_s2_hash = main_results.get(f"{N_households}_numpy_stage2_hash") main_s3_hash = main_results.get(f"{N_households}_numpy_stage3_hash") - + pr_s1_hash = pr_results.get(f"{N_households}_numpy_stage1_hash") pr_s2_hash = pr_results.get(f"{N_households}_numpy_stage2_hash") pr_s3_hash = pr_results.get(f"{N_households}_numpy_stage3_hash") - + # Check hash matches for each stage # Stage 1 hashes are intentionally omitted due to unstable dict returns s1_hash_match = "N/A" - s2_hash_match = "✓" if main_s2_hash == pr_s2_hash else "✗" if main_s2_hash and pr_s2_hash else "N/A" - s3_hash_match = "✓" if main_s3_hash == pr_s3_hash else "✗" if main_s3_hash and pr_s3_hash else "N/A" - + s2_hash_match = ( + "✓" + if main_s2_hash == pr_s2_hash + else "✗" if main_s2_hash and pr_s2_hash else "N/A" + ) + s3_hash_match = ( + "✓" + if main_s3_hash == pr_s3_hash + else "✗" if main_s3_hash and pr_s3_hash else "N/A" + ) + # Check if we have valid data if all(x is not None for x in [main_s1, main_s2, main_s3, pr_s1, pr_s2, pr_s3]): # Stage 1 row s1_speedup = main_s1 / pr_s1 - s1_speedup_str = f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x" - print(f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}") - + s1_speedup_str = ( + f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x" + ) + print( + f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}" + ) + # Stage 2 row - s2_speedup = main_s2 / pr_s2 - s2_speedup_str = f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x" - print(f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}") - + s2_speedup = main_s2 / pr_s2 + s2_speedup_str = ( + f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x" + ) + print( + f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}" + ) + # Stage 3 row s3_speedup = main_s3 / pr_s3 - s3_speedup_str = f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x" - print(f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}") - + s3_speedup_str = ( + f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x" + ) + print( + f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}" + ) + # Total row if main_total and pr_total: total_speedup = main_total / pr_total - total_speedup_str = f"{total_speedup:.2f}x" if total_speedup >= 1 else f"1/{pr_total/main_total:.2f}x" - print(f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}") - + total_speedup_str = ( + f"{total_speedup:.2f}x" + if total_speedup >= 1 + else f"1/{pr_total/main_total:.2f}x" + ) + print( + f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}" + ) + print("-" * 140) else: # Handle failed cases main_time_str = f"{main_total:.4f}" if main_total is not None else "FAILED" pr_time_str = f"{pr_total:.4f}" if pr_total is not None else "FAILED" - print(f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}") + print( + f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}" + ) print("-" * 140) + def print_summary_statistics(main_results, pr_results, household_sizes): """Print summary statistics comparing main vs PR performance with 3-stage breakdown.""" print(f"\n{'='*100}") print("SUMMARY STATISTICS - 3-STAGE BREAKDOWN") print(f"{'='*100}") - + backends = ["numpy", "jax"] stages = ["stage1", "stage2", "stage3", "total"] stage_names = { "stage1": "Stage 1 (preprocessing)", "stage2": "Stage 2 (computation)", "stage3": "Stage 3 (formatting)", - "total": "Total execution" + "total": "Total execution", } - + for backend in backends: print(f"\n{backend.upper()} Backend:") print("-" * 40) - + for stage in stages: if stage == "total": time_suffix = "_time" else: time_suffix = f"_{stage}_time" - + valid_speedups = [] successful_runs = 0 total_runs = len(household_sizes) - + for N_households in household_sizes: main_time = main_results.get(f"{N_households}_{backend}{time_suffix}") pr_time = pr_results.get(f"{N_households}_{backend}{time_suffix}") - + if main_time is not None and pr_time is not None: successful_runs += 1 speedup = main_time / pr_time valid_speedups.append(speedup) - + if valid_speedups: avg_speedup = sum(valid_speedups) / len(valid_speedups) max_speedup = max(valid_speedups) min_speedup = min(valid_speedups) - + print(f" {stage_names[stage]}:") print(f" Average speedup: {avg_speedup:.2f}x") print(f" Maximum speedup: {max_speedup:.2f}x") @@ -225,65 +289,83 @@ def print_summary_statistics(main_results, pr_results, household_sizes): print(f" Successful runs: {successful_runs}/{total_runs}") else: print(f" {stage_names[stage]}: No valid comparisons available") - + # Check hash consistency for all stages stage_hash_results = {} for stage_num in [1, 2, 3]: hash_mismatches = 0 total_comparisons = 0 - + for N_households in household_sizes: - main_hash = main_results.get(f"{N_households}_{backend}_stage{stage_num}_hash") - pr_hash = pr_results.get(f"{N_households}_{backend}_stage{stage_num}_hash") - + main_hash = main_results.get( + f"{N_households}_{backend}_stage{stage_num}_hash" + ) + pr_hash = pr_results.get( + f"{N_households}_{backend}_stage{stage_num}_hash" + ) + if main_hash and pr_hash: total_comparisons += 1 if main_hash != pr_hash: hash_mismatches += 1 - + stage_hash_results[stage_num] = (hash_mismatches, total_comparisons) - + # Print hash verification results meaningful_mismatches = False # Track mismatches in stages 2 and 3 only for stage_num in [1, 2, 3]: hash_mismatches, total_comparisons = stage_hash_results[stage_num] stage_name = {1: "Stage 1", 2: "Stage 2", 3: "Stage 3"}[stage_num] - + if total_comparisons > 0: if stage_num == 1: # Stage 1 hash mismatches are expected due to unstable dict returns - print(f" {stage_name} hash verification: {hash_mismatches}/{total_comparisons} mismatches (expected - unstable dict)") + print( + f" {stage_name} hash verification: {hash_mismatches}/{total_comparisons} mismatches (expected - unstable dict)" + ) else: - print(f" {stage_name} hash verification: {hash_mismatches}/{total_comparisons} mismatches") + print( + f" {stage_name} hash verification: {hash_mismatches}/{total_comparisons} mismatches" + ) if hash_mismatches > 0: meaningful_mismatches = True else: - print(f" {stage_name} hash verification: No valid comparisons available") - + print( + f" {stage_name} hash verification: No valid comparisons available" + ) + # Only show warning for meaningful mismatches (stages 2 and 3) - stages_2_3_have_data = any(total for stage_num, (_, total) in stage_hash_results.items() if stage_num in [2, 3]) - + stages_2_3_have_data = any( + total + for stage_num, (_, total) in stage_hash_results.items() + if stage_num in [2, 3] + ) + if meaningful_mismatches: - print(f" ⚠ Stage 2/3 results differ between main and PR - investigate numerical differences") + print( + " ⚠ Stage 2/3 results differ between main and PR - investigate numerical differences" + ) elif stages_2_3_have_data: - print(f" ✓ All meaningful stage results are numerically identical (Stage 2 & 3)") + print( + " ✓ All meaningful stage results are numerically identical (Stage 2 & 3)" + ) else: - print(f" No valid hash comparisons available for meaningful stages") - + print(" No valid hash comparisons available for meaningful stages") + # Overall comparison print(f"\n{'='*100}") print("OVERALL PERFORMANCE IMPACT") print(f"{'='*100}") - + for backend in backends: total_speedups = [] for N_households in household_sizes: main_total = main_results.get(f"{N_households}_{backend}_time") pr_total = pr_results.get(f"{N_households}_{backend}_time") - + if main_total and pr_total: total_speedups.append(main_total / pr_total) - + if total_speedups: avg_speedup = sum(total_speedups) / len(total_speedups) if avg_speedup > 1.05: @@ -292,69 +374,81 @@ def print_summary_statistics(main_results, pr_results, household_sizes): impact = f"PR is {1/avg_speedup:.1f}x slower (performance regression)" else: impact = "PR has minimal performance impact (±5%)" - + print(f"{backend.upper()}: {impact}") else: print(f"{backend.upper()}: No valid performance comparisons available") + def main(): - parser = argparse.ArgumentParser(description="Compare benchmark results from main branch vs PR branch") - parser.add_argument("main_file", help="Path to benchmark results JSON file from main branch") - parser.add_argument("pr_file", help="Path to benchmark results JSON file from PR branch") - parser.add_argument("--save-comparison", help="Save comparison tables to text file", action="store_true") - + parser = argparse.ArgumentParser( + description="Compare benchmark results from main branch vs PR branch" + ) + parser.add_argument( + "main_file", help="Path to benchmark results JSON file from main branch" + ) + parser.add_argument( + "pr_file", help="Path to benchmark results JSON file from PR branch" + ) + parser.add_argument( + "--save-comparison", + help="Save comparison tables to text file", + action="store_true", + ) + args = parser.parse_args() - + # Load benchmark results print("Loading benchmark results...") main_results = load_benchmark_results(args.main_file) pr_results = load_benchmark_results(args.pr_file) - + if main_results is None or pr_results is None: sys.exit(1) - + # Extract household sizes (use PR results as primary, fallback to main) household_sizes = extract_household_sizes(pr_results) if not household_sizes: household_sizes = extract_household_sizes(main_results) - + if not household_sizes: print("Error: Could not extract household sizes from either file.") sys.exit(1) - + print(f"Found data for household sizes: {household_sizes}") - + # Print comparison tables print_jax_comparison_table(main_results, pr_results, household_sizes) print_numpy_comparison_table(main_results, pr_results, household_sizes) print_summary_statistics(main_results, pr_results, household_sizes) - + # Save to file if requested if args.save_comparison: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_file = f"benchmark_comparison_{timestamp}.txt" - + # Redirect stdout to file original_stdout = sys.stdout - + try: - with open(output_file, 'w') as f: + with open(output_file, "w") as f: sys.stdout = f - print(f"Benchmark Comparison Report") + print("Benchmark Comparison Report") print(f"Generated: {datetime.now().isoformat()}") print(f"Main branch file: {args.main_file}") print(f"PR branch file: {args.pr_file}") - + print_jax_comparison_table(main_results, pr_results, household_sizes) print_numpy_comparison_table(main_results, pr_results, household_sizes) print_summary_statistics(main_results, pr_results, household_sizes) - + sys.stdout = original_stdout print(f"\nComparison saved to: {output_file}") - + except Exception as e: sys.stdout = original_stdout print(f"Error saving comparison: {e}") + if __name__ == "__main__": main() diff --git a/benchmark_code/benchmark_make_data.py b/benchmark_code/benchmark_make_data.py index f43a34a..21d2edb 100644 --- a/benchmark_code/benchmark_make_data.py +++ b/benchmark_code/benchmark_make_data.py @@ -4,118 +4,312 @@ This module provides the make_data function to create standardized synthetic datasets for GETTSIM/TTSIM performance testing. """ + # %% -import pandas as pd -import numpy as np import time +import numpy as np +import pandas as pd + + def make_data(N, scramble_data=False): """ Create a DataFrame with N households, each containing 2 parents and 2 children. Uses vectorized operations for fast data generation. - + Parameters: N (int): Number of households to create scramble_data (bool): Whether to randomly shuffle rows to create unsorted p_id order. Default is False to maintain sorted order for better performance. - + Returns: pd.DataFrame: DataFrame with household data (4*N rows) """ # Total number of people (4 per household: 2 parents + 2 children) total_people = N * 4 - + # Create base template for one household (4 people) - base_template = np.array([ - # Parent 1 - [30, 35, 0, 1995, 0, 0, False, False, 0, 0, 5000, 0, 500, 0, 0, 0, 0, -1, True, 0, 0, True, False, False, 1, -1, -1, False, -1, 1, 4, 360, 2062], - # Parent 2 - [30, 35, 0, 1995, 0, 1, False, False, 0, 0, 4000, 0, 0, 0, 0, 0, 0, -1, True, 0, 0, True, False, False, 0, -1, -1, False, -1, 0, 4, 360, 2062], - # Child 1 - [10, 0, 0, 2015, 0, 2, False, False, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, False, 0, 0, False, False, True, -1, 0, 1, False, 0, -1, -1, 120, 2082], - # Child 2 (twin) - [10, 0, 0, 2015, 0, 3, False, False, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, False, 0, 0, False, False, True, -1, 0, 1, False, 0, -1, -1, 120, 2082] - ]) - + base_template = np.array( + [ + # Parent 1 + [ + 30, + 35, + 0, + 1995, + 0, + 0, + False, + False, + 0, + 0, + 5000, + 0, + 500, + 0, + 0, + 0, + 0, + -1, + True, + 0, + 0, + True, + False, + False, + 1, + -1, + -1, + False, + -1, + 1, + 4, + 360, + 2062, + ], + # Parent 2 + [ + 30, + 35, + 0, + 1995, + 0, + 1, + False, + False, + 0, + 0, + 4000, + 0, + 0, + 0, + 0, + 0, + 0, + -1, + True, + 0, + 0, + True, + False, + False, + 0, + -1, + -1, + False, + -1, + 0, + 4, + 360, + 2062, + ], + # Child 1 + [ + 10, + 0, + 0, + 2015, + 0, + 2, + False, + False, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + False, + 0, + 0, + False, + False, + True, + -1, + 0, + 1, + False, + 0, + -1, + -1, + 120, + 2082, + ], + # Child 2 (twin) + [ + 10, + 0, + 0, + 2015, + 0, + 3, + False, + False, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + False, + 0, + 0, + False, + False, + True, + -1, + 0, + 1, + False, + 0, + -1, + -1, + 120, + 2082, + ], + ] + ) + # Replicate template for all households data_array = np.tile(base_template, (N, 1)) - + # Create household and person IDs using vectorized operations hh_ids = np.repeat(np.arange(N), 4) p_ids = np.arange(total_people) - + # Update IDs in the data array data_array[:, 4] = hh_ids # hh_id column - data_array[:, 5] = p_ids # p_id column - + data_array[:, 5] = p_ids # p_id column + # Update spouse_ids for parents (every 4th person starting from 0 gets spouse_id of next person, and vice versa) spouse_mask_1 = np.arange(total_people) % 4 == 0 # Parent 1 positions spouse_mask_2 = np.arange(total_people) % 4 == 1 # Parent 2 positions data_array[spouse_mask_1, 24] = p_ids[spouse_mask_1] + 1 # Parent 1 -> Parent 2 data_array[spouse_mask_2, 24] = p_ids[spouse_mask_2] - 1 # Parent 2 -> Parent 1 - + # Update bürgergeld__p_id_einstandspartner (identical to spouse_id) data_array[spouse_mask_1, 29] = p_ids[spouse_mask_1] + 1 # Parent 1 -> Parent 2 data_array[spouse_mask_2, 29] = p_ids[spouse_mask_2] - 1 # Parent 2 -> Parent 1 - + # Update parent_ids for children child_mask_1 = np.arange(total_people) % 4 == 2 # Child 1 positions child_mask_2 = np.arange(total_people) % 4 == 3 # Child 2 positions parent1_ids = p_ids[child_mask_1] - 2 # Parent 1 IDs for children parent2_ids = p_ids[child_mask_1] - 1 # Parent 2 IDs for children - + data_array[child_mask_1, 25] = parent1_ids # parent_id_1 for child 1 data_array[child_mask_1, 26] = parent2_ids # parent_id_2 for child 1 - data_array[child_mask_2, 25] = parent1_ids # parent_id_1 for child 2 + data_array[child_mask_2, 25] = parent1_ids # parent_id_1 for child 2 data_array[child_mask_2, 26] = parent2_ids # parent_id_2 for child 2 - + # Update person_that_pays_childcare_expenses and id_recipient_child_allowance for children - data_array[child_mask_1, 17] = parent1_ids # person_that_pays_childcare_expenses for child 1 - data_array[child_mask_2, 17] = parent1_ids # person_that_pays_childcare_expenses for child 2 - data_array[child_mask_1, 28] = parent1_ids # id_recipient_child_allowance for child 1 - data_array[child_mask_2, 28] = parent1_ids # id_recipient_child_allowance for child 2 - + data_array[child_mask_1, 17] = ( + parent1_ids # person_that_pays_childcare_expenses for child 1 + ) + data_array[child_mask_2, 17] = ( + parent1_ids # person_that_pays_childcare_expenses for child 2 + ) + data_array[child_mask_1, 28] = ( + parent1_ids # id_recipient_child_allowance for child 1 + ) + data_array[child_mask_2, 28] = ( + parent1_ids # id_recipient_child_allowance for child 2 + ) + # Column names in the same order as the template columns = [ - "age", "working_hours", "disability_grade", "birth_year", "hh_id", "p_id", - "east_germany", "self_employed", "income_from_self_employment", "income_from_rent", - "income_from_employment", "income_from_forest_and_agriculture", "income_from_capital", - "income_from_other_sources", "pension_income", "contribution_to_private_pension_insurance", - "childcare_expenses", "person_that_pays_childcare_expenses", "joint_taxation", - "amount_private_pension_income", "contribution_private_health_insurance", "has_children", - "single_parent", "is_child", "spouse_id", "parent_id_1", "parent_id_2", "in_training", - "id_recipient_child_allowance", "bürgergeld__p_id_einstandspartner", "lohnsteuer__steuerklasse", - "alter_monate", "jahr_renteneintritt", + "age", + "working_hours", + "disability_grade", + "birth_year", + "hh_id", + "p_id", + "east_germany", + "self_employed", + "income_from_self_employment", + "income_from_rent", + "income_from_employment", + "income_from_forest_and_agriculture", + "income_from_capital", + "income_from_other_sources", + "pension_income", + "contribution_to_private_pension_insurance", + "childcare_expenses", + "person_that_pays_childcare_expenses", + "joint_taxation", + "amount_private_pension_income", + "contribution_private_health_insurance", + "has_children", + "single_parent", + "is_child", + "spouse_id", + "parent_id_1", + "parent_id_2", + "in_training", + "id_recipient_child_allowance", + "bürgergeld__p_id_einstandspartner", + "lohnsteuer__steuerklasse", + "alter_monate", + "jahr_renteneintritt", ] - + # Create DataFrame data = pd.DataFrame(data_array, columns=columns) - + # Convert boolean columns back to bool (they become float during array operations) - bool_columns = ["east_germany", "self_employed", "joint_taxation", "has_children", "single_parent", "is_child", "in_training"] + bool_columns = [ + "east_germany", + "self_employed", + "joint_taxation", + "has_children", + "single_parent", + "is_child", + "in_training", + ] for col in bool_columns: data[col] = data[col].astype(bool) - + # Convert integer columns to int - int_columns = ["age", "working_hours", "disability_grade", "birth_year", "hh_id", "p_id", - "spouse_id", "parent_id_1", "parent_id_2", "person_that_pays_childcare_expenses", - "id_recipient_child_allowance", "bürgergeld__p_id_einstandspartner", "lohnsteuer__steuerklasse", - "alter_monate", "jahr_renteneintritt"] + int_columns = [ + "age", + "working_hours", + "disability_grade", + "birth_year", + "hh_id", + "p_id", + "spouse_id", + "parent_id_1", + "parent_id_2", + "person_that_pays_childcare_expenses", + "id_recipient_child_allowance", + "bürgergeld__p_id_einstandspartner", + "lohnsteuer__steuerklasse", + "alter_monate", + "jahr_renteneintritt", + ] for col in int_columns: data[col] = data[col].astype(int) - + # SCRAMBLE DATA: Optionally shuffle rows to create unsorted p_id order if scramble_data: np.random.seed(42) # Fixed seed for reproducible results scrambled_indices = np.random.permutation(len(data)) data = data.iloc[scrambled_indices].reset_index(drop=True) print(f"Created DataFrame with {len(data)} rows ({len(data) // 4} households)") - print(f"Data scrambled: p_id order is now unsorted") + print("Data scrambled: p_id order is now unsorted") else: print(f"Created DataFrame with {len(data)} rows ({len(data) // 4} households)") - print(f"Data kept sorted: p_id order is sequential") - + print("Data kept sorted: p_id order is sequential") + return data @@ -123,28 +317,30 @@ def main(): """Generate datasets for all required sizes and measure timing.""" # Dataset sizes (number of households) # Each household has 4 people (2 parents + 2 children) - household_sizes = [2**15-1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20] - + household_sizes = [2**15 - 1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20] + print("Generating synthetic datasets for GETTSIM benchmarking...") print("=" * 60) - + timing_results = [] - + for num_households in household_sizes: print(f"\nGenerating dataset for {num_households:,} households...") - + # Time the data creation start_time = time.time() data = make_data(num_households) end_time = time.time() - + creation_time = end_time - start_time timing_results.append((num_households, creation_time)) - + # Calculate memory usage estimation memory_mb = data.memory_usage(deep=True).sum() / (1024 * 1024) - - print(f"✓ Created {num_households:,} households ({len(data):,} people) in {creation_time:.3f} seconds") + + print( + f"✓ Created {num_households:,} households ({len(data):,} people) in {creation_time:.3f} seconds" + ) print(f" Memory usage: {memory_mb:.2f} MB") print(f" Speed: {num_households / creation_time:.0f} households/second") @@ -153,12 +349,14 @@ def main(): print("=" * 60) print(f"{'Households':<12} {'Time (s)':<10} {'Speed (hh/s)':<15} {'People':<10}") print("-" * 60) - + for num_households, creation_time in timing_results: speed = num_households / creation_time people = num_households * 4 - print(f"{num_households:<12,} {creation_time:<10.3f} {speed:<15,.0f} {people:<10,}") - + print( + f"{num_households:<12,} {creation_time:<10.3f} {speed:<15,.0f} {people:<10,}" + ) + print("\nDataset generation completed successfully!") print("Note: Data is held in memory only - no files saved to disk.") @@ -170,4 +368,3 @@ def main(): # For inspection: # example_data = make_data(3) # 3 households = 12 people (sorted) # example_scrambled = make_data(3, scramble_data=True) # 3 households = 12 people (scrambled) - diff --git a/benchmark_code/benchmark_profile.py b/benchmark_code/benchmark_profile.py index d3ab31d..408cc20 100644 --- a/benchmark_code/benchmark_profile.py +++ b/benchmark_code/benchmark_profile.py @@ -10,19 +10,29 @@ """ -import time import argparse import hashlib +import time -from gettsim import main, InputData, MainTarget, TTTargets, Labels, SpecializedEnvironment, RawResults +from benchmark_make_data import make_data # Import shared benchmark configuration and utilities from benchmark_setup import ( - TT_TARGETS, MAPPER, JAX_AVAILABLE, - sync_jax_if_needed, get_memory_usage_mb, MemoryTracker, - PROFILE_HOUSEHOLD_SIZES, BACKENDS + BACKENDS, + MAPPER, + TT_TARGETS, + MemoryTracker, + get_memory_usage_mb, + sync_jax_if_needed, +) +from gettsim import ( + InputData, + Labels, + MainTarget, + SpecializedEnvironment, + TTTargets, + main, ) -from benchmark_make_data import make_data def run_profile(N, backend, scramble_data=False): @@ -30,14 +40,14 @@ def run_profile(N, backend, scramble_data=False): print(f"Generating dataset with {N:,} households...") data = make_data(N, scramble_data=scramble_data) print(f"Dataset created successfully. Shape: {data.shape}") - + print(f"Running GETTSIM with backend: {backend}") - + # Memory tracking setup tracker = MemoryTracker() initial_memory = get_memory_usage_mb() tracker.start_monitoring() - + try: # First stage - preprocessing and DAG creation print("\n=== STAGE 1: Data preprocessing and DAG creation ===") @@ -54,7 +64,7 @@ def run_profile(N, backend, scramble_data=False): MainTarget.processed_data, MainTarget.labels.root_nodes, MainTarget.input_data.flat, # Need this for stage 3 - MainTarget.tt_function, # Use compiled tt_function in stage 2 with JAX backend + MainTarget.tt_function, # Use compiled tt_function in stage 2 with JAX backend ], tt_targets=TTTargets( tree=TT_TARGETS, @@ -62,16 +72,16 @@ def run_profile(N, backend, scramble_data=False): include_fail_nodes=True, include_warn_nodes=False, backend=backend, - ) + ) # Force JAX synchronization before recording end time sync_jax_if_needed(backend) - + stage1_end = time.time() stage1_time = stage1_end - stage1_start - + # Generate hash for Stage 1 output (tmp) - avoid memory issues with large arrays - stage1_hash = hashlib.md5(str(tmp).encode('utf-8')).hexdigest() + stage1_hash = hashlib.md5(str(tmp).encode("utf-8")).hexdigest() print(f"Stage 1 completed in: {stage1_time:.4f} seconds") print(f"Processed data keys: {len(tmp['processed_data'])}") @@ -95,7 +105,9 @@ def run_profile(N, backend, scramble_data=False): tree=TT_TARGETS, ), processed_data=tmp["processed_data"], - input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + input_data=InputData.flat( + tmp["input_data"]["flat"] + ), # Provide the flat input data from stage 1 labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), tt_function=tmp["tt_function"], # Reuse pre-compiled JAX function include_fail_nodes=False, @@ -109,12 +121,14 @@ def run_profile(N, backend, scramble_data=False): stage2_end = time.time() print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Completed Stage 2") stage2_time = stage2_end - stage2_start - + # Generate hash for Stage 2 output - avoid memory issues with large JAX arrays - stage2_hash = hashlib.md5(str(raw_results_stage2).encode('utf-8')).hexdigest() - + stage2_hash = hashlib.md5(str(raw_results_stage2).encode("utf-8")).hexdigest() + print(f"Stage 2 completed in: {stage2_time:.4f} seconds") - print(f"Raw results components: {list(raw_results_stage2['raw_results'].keys())}") + print( + f"Raw results components: {list(raw_results_stage2['raw_results'].keys())}" + ) print(f"Stage 2 hash: {stage2_hash[:16]}...") # Third stage - convert raw results to DataFrame (no computation, just formatting) @@ -129,7 +143,9 @@ def run_profile(N, backend, scramble_data=False): tree=TT_TARGETS, ), raw_results=raw_results_stage2["raw_results"], - input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + input_data=InputData.flat( + tmp["input_data"]["flat"] + ), # Provide the flat input data from stage 1 processed_data=tmp["processed_data"], labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), specialized_environment=SpecializedEnvironment( @@ -147,36 +163,46 @@ def run_profile(N, backend, scramble_data=False): print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Completed Stage 3") stage3_time = stage3_end - stage3_start total_time = stage1_time + stage2_time + stage3_time - + # Generate hash for Stage 3 output - avoid memory issues - stage3_hash = hashlib.md5(str(final_results).encode('utf-8')).hexdigest() - + stage3_hash = hashlib.md5(str(final_results).encode("utf-8")).hexdigest() + # Stop memory tracking and get final readings tracker.stop_monitoring() final_memory = get_memory_usage_mb() peak_memory = tracker.get_peak() memory_delta = final_memory - initial_memory - + print(f"Stage 3 completed in: {stage3_time:.4f} seconds") - print(f"Final DataFrame shape: {final_results.shape if hasattr(final_results, 'shape') else 'N/A'}") + print( + f"Final DataFrame shape: {final_results.shape if hasattr(final_results, 'shape') else 'N/A'}" + ) print(f"Final DataFrame type: {type(final_results)}") print(f"Stage 3 hash: {stage3_hash[:16]}...") print(f"Total execution time: {total_time:.4f} seconds") - print(f"Stage 1 (preprocessing): {stage1_time:.4f}s ({stage1_time/total_time*100:.1f}%)") - print(f"Stage 2 (computation): {stage2_time:.4f}s ({stage2_time/total_time*100:.1f}%)") - print(f"Stage 3 (formatting): {stage3_time:.4f}s ({stage3_time/total_time*100:.1f}%)") + print( + f"Stage 1 (preprocessing): {stage1_time:.4f}s ({stage1_time/total_time*100:.1f}%)" + ) + print( + f"Stage 2 (computation): {stage2_time:.4f}s ({stage2_time/total_time*100:.1f}%)" + ) + print( + f"Stage 3 (formatting): {stage3_time:.4f}s ({stage3_time/total_time*100:.1f}%)" + ) print(f"Backend: {backend}") print(f"Households: {N:,}") print(f"People: {len(data):,}") print(f"Performance: {N / total_time:.0f} households/second") - print(f"Memory: {initial_memory:.1f} -> {final_memory:.1f} MB (Δ{memory_delta:+.1f}, peak: {peak_memory:.1f})") + print( + f"Memory: {initial_memory:.1f} -> {final_memory:.1f} MB (Δ{memory_delta:+.1f}, peak: {peak_memory:.1f})" + ) print("\n=== STAGE HASHES ===") print(f"Stage 1 hash: {stage1_hash[:16]}...") print(f"Stage 2 hash: {stage2_hash[:16]}...") print(f"Stage 3 hash: {stage3_hash[:16]}...") - + return final_results, total_time - + except Exception as e: print(f"ERROR during profiling: {e}") tracker.stop_monitoring() @@ -185,28 +211,42 @@ def run_profile(N, backend, scramble_data=False): def main_cli(): """Main function for command line interface.""" - parser = argparse.ArgumentParser(description='Profile GETTSIM with synthetic data') - parser.add_argument('-N', '--households', type=int, default=32768, - help='Number of households to generate (default: 32768)') - parser.add_argument('-b', '--backend', choices=BACKENDS, default='numpy', - help='Backend to use: numpy or jax (default: numpy)') - parser.add_argument('-scramble', '--scramble-data', action='store_true', - help='Scramble data to create unsorted p_id order (default: sorted)') - + parser = argparse.ArgumentParser(description="Profile GETTSIM with synthetic data") + parser.add_argument( + "-N", + "--households", + type=int, + default=32768, + help="Number of households to generate (default: 32768)", + ) + parser.add_argument( + "-b", + "--backend", + choices=BACKENDS, + default="numpy", + help="Backend to use: numpy or jax (default: numpy)", + ) + parser.add_argument( + "-scramble", + "--scramble-data", + action="store_true", + help="Scramble data to create unsorted p_id order (default: sorted)", + ) + args = parser.parse_args() - + print("GETTSIM Profiling Tool") print("=" * 50) - + result, exec_time = run_profile(args.households, args.backend, args.scramble_data) - + if result is not None: print("\n" + "=" * 50) print("Profiling completed successfully!") else: print("\n" + "=" * 50) print("Profiling failed!") - + return result, exec_time @@ -215,4 +255,4 @@ def main_cli(): # %% # For interactive use - you can also run this directly -# result, exec_time = run_profile(N=32768, backend="numpy", scramble_data=False) \ No newline at end of file +# result, exec_time = run_profile(N=32768, backend="numpy", scramble_data=False) diff --git a/benchmark_code/benchmark_setup.py b/benchmark_code/benchmark_setup.py index 35af5f0..f0000b6 100644 --- a/benchmark_code/benchmark_setup.py +++ b/benchmark_code/benchmark_setup.py @@ -7,14 +7,15 @@ import gc import os -import time import threading +import time + import psutil -from datetime import datetime # JAX-specific imports for cache management try: import jax + JAX_AVAILABLE = True except ImportError: JAX_AVAILABLE = False @@ -50,7 +51,7 @@ "mean_nettoeinkommen_für_bemessungsgrundlage_bei_arbeitslosigkeit_y": "mean_net_income_for_benefit_basis_in_case_of_unemployment_y", "beitrag": { "betrag_versicherter_m": "unemployment_insurance_contribution_m", - } + }, }, "beiträge_gesamt_m": "social_insurance_contributions_total_m", }, @@ -65,7 +66,7 @@ "elterngeld": { "betrag_m": "EG_betrag_m", "anrechenbarer_betrag_m": "EG_anrechenbarer_betrag_m", - "mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m": "EG_mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m" + "mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m": "EG_mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m", }, "unterhalt": { "tatsächlich_erhaltener_betrag_m": "unterhalt_tatsächlich_erhaltener_betrag_m", @@ -211,7 +212,7 @@ "zu_versteuerndes_einkommen_vorjahr_y_sn": 30000.0, "mean_nettoeinkommen_in_12_monaten_vor_geburt_m": 2000.0, "claimed": False, - "bisherige_bezugsmonate": 0 + "bisherige_bezugsmonate": 0, }, "bürgergeld": { # "betrag_m_bg": 0.0, @@ -232,11 +233,13 @@ # JAX UTILITIES # ============================================================================= + def sync_jax_if_needed(backend): """Force JAX synchronization to ensure all operations are complete.""" if backend == "jax" and JAX_AVAILABLE: try: import jax + # Force synchronization of all JAX operations jax.block_until_ready(jax.numpy.array([1.0])) print(" JAX operations synchronized") @@ -251,6 +254,7 @@ def clear_jax_cache(): if JAX_AVAILABLE: try: import jax + # Clear the JIT compilation cache jax.clear_caches() print(" JAX cache cleared") @@ -262,6 +266,7 @@ def clear_jax_cache(): # MEMORY TRACKING # ============================================================================= + def get_memory_usage_mb(): """Get current memory usage in MB.""" process = psutil.Process(os.getpid()) @@ -270,42 +275,42 @@ def get_memory_usage_mb(): class MemoryTracker: """Track peak memory usage during execution with continuous monitoring.""" + def __init__(self): self.peak_memory = 0 self.process = psutil.Process(os.getpid()) self.monitoring = False self.monitor_thread = None - + def start_monitoring(self): """Start continuous memory monitoring in background thread.""" self.monitoring = True self.peak_memory = self.get_current_memory() self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) self.monitor_thread.start() - + def stop_monitoring(self): """Stop continuous memory monitoring.""" self.monitoring = False if self.monitor_thread: self.monitor_thread.join(timeout=1.0) - + def _monitor_loop(self): """Background monitoring loop.""" while self.monitoring: self.update() time.sleep(0.01) # Check every 10ms - + def get_current_memory(self): """Get current memory usage in MB.""" return self.process.memory_info().rss / 1024 / 1024 - + def update(self): """Update peak memory if current usage is higher.""" current = self.get_current_memory() - if current > self.peak_memory: - self.peak_memory = current + self.peak_memory = max(self.peak_memory, current) return current - + def get_peak(self): """Get peak memory usage in MB.""" return self.peak_memory @@ -315,6 +320,7 @@ def get_peak(self): # SESSION MANAGEMENT # ============================================================================= + def force_garbage_collection(): """Force aggressive garbage collection between runs.""" gc.collect() @@ -325,14 +331,14 @@ def force_garbage_collection(): def reset_session_state(backend): """Reset session state between different backend runs.""" print(f" Resetting session state for {backend} backend...") - + # Force garbage collection force_garbage_collection() - + # Clear JAX-specific state if switching to/from JAX if backend == "jax" or JAX_AVAILABLE: clear_jax_cache() - + # Add a small delay to let system settle time.sleep(0.5) @@ -341,6 +347,6 @@ def reset_session_state(backend): # COMMON DATASET SIZES # ============================================================================= -BENCHMARK_HOUSEHOLD_SIZES = [2**15-1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20] +BENCHMARK_HOUSEHOLD_SIZES = [2**15 - 1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20] PROFILE_HOUSEHOLD_SIZES = [2**15] # Default for profiling: 32,768 households -BACKENDS = ["numpy", "jax"] \ No newline at end of file +BACKENDS = ["numpy", "jax"] diff --git a/test_data_conversion_scripts/convert_xlsx_tests_csv_lohnst.py b/test_data_conversion_scripts/convert_xlsx_tests_csv_lohnst.py index 7172d06..14e47f7 100644 --- a/test_data_conversion_scripts/convert_xlsx_tests_csv_lohnst.py +++ b/test_data_conversion_scripts/convert_xlsx_tests_csv_lohnst.py @@ -1,4 +1,3 @@ - from pathlib import Path import pandas as pd