diff --git a/benchmark_code/BENCHMARK_README.md b/benchmark_code/BENCHMARK_README.md
new file mode 100644
index 0000000..4ce3b9c
--- /dev/null
+++ b/benchmark_code/BENCHMARK_README.md
@@ -0,0 +1,61 @@
+# Benchmark Comparison Workflow
+
+This document explains how to compare performance between the main branch and a PR branch with optimizations.
+
+## Scripts Overview
+
+1. **`benchmark.py`** - Runs performance benchmarks and saves results to JSON
+2. **`benchmark_compare.py`** - Compares the results of two benchmark runs
+
+## Workflow
+
+### Step 1: Run benchmark on main branch
+
+```bash
+# Switch to the main branch (in the ttsim repository)
+git checkout main
+
+# Run benchmark
+python benchmark.py
+
+# This creates a file like: benchmark_results_20250806_143022.json
+```
+
+### Step 2: Run benchmark on PR branch
+
+```bash
+# Switch to the PR branch (in the ttsim repository)
+git checkout JW/dev/speedup-JAX
+
+# Run benchmark
+python benchmark.py
+
+# This creates another file like: benchmark_results_20250806_145133.json
+```
+
+### Step 3: Compare results
+
+```bash
+# Compare the two result files (first file = main branch results, second file = PR branch results)
+python benchmark_compare.py benchmark_results_20250806_143022.json benchmark_results_20250806_145133.json
+
+# Optional: Save comparison to file
+python benchmark_compare.py benchmark_results_20250806_143022.json benchmark_results_20250806_145133.json --save-comparison
+```
+
+## Interpreting Results
+
+- **Speedup > 1.0**: PR branch is faster than main branch (speedup = main branch time / PR branch time)
+- **Identical hashes**: Optimizations maintain numerical accuracy
+- **Hash mismatches**: Potential numerical differences (investigate!)
+
+## Dataset Sizes Tested
+
+- 32,767 households (2^15 - 1)
+- 32,768 households (2^15)
+- 65,536 households (2^16)
+- 131,072 households (2^17)
+- 262,144 households (2^18)
+- 524,288 households (2^19)
+- 1,048,576 households (2^20)
+- 2,097,152 households (2^21) - optional; enable it in `household_sizes` in `benchmark.py`
diff --git a/benchmark_code/benchmark.py b/benchmark_code/benchmark.py
new file mode 100644
index 0000000..e3f728e
--- /dev/null
+++ b/benchmark_code/benchmark.py
@@ -0,0 +1,728 @@
+"""Performance comparison script for numpy vs jax backends."""
+import pandas as pd
+from gettsim import InputData, MainTarget, TTTargets, Labels, SpecializedEnvironment, RawResults
+
+# Hack: Override GETTSIM main to make all TTSIM parameters of main available in GETTSIM.
+# Necessary because of GETTSIM issue #1075.
+# When resolved, this can be removed and gettsim.main can be used directly.
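+
+# A minimal usage sketch of the wrapper defined below (hypothetical values;
+# the real invocations appear in run_benchmark):
+#
+#     result = main(
+#         policy_date_str="2025-01-01",
+#         input_data=InputData.df_and_mapper(df=data, mapper=MAPPER),
+#         main_target=MainTarget.results.df_with_mapper,
+#         tt_targets=TTTargets(tree=TT_TARGETS),
+#         backend="numpy",
+#     )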
+from gettsim import germany
+import ttsim
+from ttsim.main_args import OrigPolicyObjects
+
+def main(**kwargs):
+    """Wrapper around ttsim.main that automatically sets the German root path and supports tt_function."""
+    # Set German tax system as default if no orig_policy_objects provided
+    if kwargs.get('orig_policy_objects') is None:
+        kwargs['orig_policy_objects'] = OrigPolicyObjects(root=germany.ROOT_PATH)
+
+    return ttsim.main(**kwargs)
+
+import time
+import hashlib
+import json
+import os
+import psutil
+import gc
+import threading
+from datetime import datetime
+from make_data import make_data
+
+# JAX-specific imports for cache management
+try:
+    import jax
+    JAX_AVAILABLE = True
+except ImportError:
+    JAX_AVAILABLE = False
+
+# %%
+TT_TARGETS = {
+    "einkommensteuer": {
+        "betrag_m_sn": "income_tax_m",
+        "zu_versteuerndes_einkommen_y_sn": "taxable_income_y_sn",
+    },
+    "sozialversicherung": {
+        "pflege": {
+            "beitrag": {
+                "betrag_versicherter_m": "long_term_care_insurance_contribution_m",
+                "betrag_gesamt_in_gleitzone_m": "long_term_care_insurance_contribution_total_in_transition_zone_m",
+            },
+        },
+        "kranken": {
+            "beitrag": {"betrag_versicherter_m": "health_insurance_contribution_m"},
+        },
+        "rente": {
+            "beitrag": {"betrag_versicherter_m": "pension_insurance_contribution_m"},
+            "entgeltpunkte_updated": "pension_entitlement_points_updated",
+            "grundrente": {
+                "gesamteinnahmen_aus_renten_für_einkommensberechnung_im_folgejahr_m": "pension_total_income_for_income_calculation_next_year_m",
+            },
+            "wartezeit_15_jahre_erfüllt": "pension_waiting_period_15_years_fulfilled",
+        },
+        "arbeitslosen": {
+            "mean_nettoeinkommen_für_bemessungsgrundlage_bei_arbeitslosigkeit_y": "mean_net_income_for_benefit_basis_in_case_of_unemployment_y",
+            "beitrag": {
+                "betrag_versicherter_m": "unemployment_insurance_contribution_m",
+            }
+        },
+        "beiträge_gesamt_m": "social_insurance_contributions_total_m",
+    },
+    "kindergeld": {"betrag_m": "KG_betrag_m"},
+    "bürgergeld": {"betrag_m_bg": "BG_betrag_m_bg"},
+    "grundsicherung": {"im_alter": {"betrag_m_eg": "GS_betrag_m_eg"}},
+    "wohngeld": {"betrag_m_wthh": "WG_betrag_m_wthh"},
+    "kinderzuschlag": {
+        "betrag_m_bg": "KiZ_betrag_m_bg",
+    },
+    "familie": {"alleinerziehend_fg": "single_parent_fg"},
+    "elterngeld": {
+        "betrag_m": "EG_betrag_m",
+        "anrechenbarer_betrag_m": "EG_anrechenbarer_betrag_m",
+        "mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m": "EG_mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m"
+    },
+    "unterhalt": {
+        "tatsächlich_erhaltener_betrag_m": "unterhalt_tatsächlich_erhaltener_betrag_m",
+        "kind_festgelegter_zahlbetrag_m": "unterhalt_kind_festgelegter_zahlbetrag_m",
+    },
+    "unterhaltsvorschuss": {
+        "an_elternteil_auszuzahlender_betrag_m": "unterhaltsvorschuss_an_elternteil_auszuzahlender_betrag_m",
+    },
+}
+
+
+# %%
+MAPPER = {
+    "alter": "age",
+    "alter_monate": "alter_monate",
+    "geburtsmonat": 1,
+    "arbeitsstunden_w": "working_hours",
+    "behinderungsgrad": "disability_grade",
+    "schwerbehindert_grad_g": False,
+    "geburtsjahr": "birth_year",
+    "hh_id": "hh_id",
+    "p_id": "p_id",
+    "wohnort_ost_hh": "east_germany",
+    "einnahmen": {
+        "bruttolohn_m": 2000.0,
+        "kapitalerträge_y": 0.0,
+        "renten": {
+            "betriebliche_altersvorsorge_m": 0.0,
+            "geförderte_private_vorsorge_m": 0.0,
+            "gesetzliche_m": 0.0,
+            "sonstige_private_vorsorge_m": 0.0,
+        },
+    },
+    "einkommensteuer": {
+        "einkünfte": {
+            "ist_hauptberuflich_selbstständig": False,
+
"ist_selbstständig": "self_employed", + "aus_gewerbebetrieb": {"betrag_m": "income_from_self_employment"}, + "aus_vermietung_und_verpachtung": {"betrag_m": "income_from_rent"}, + "aus_nichtselbstständiger_arbeit": { + "bruttolohn_m": "income_from_employment" + }, + "aus_forst_und_landwirtschaft": { + "betrag_m": "income_from_forest_and_agriculture" + }, + "aus_selbstständiger_arbeit": {"betrag_m": "income_from_self_employment"}, + "aus_kapitalvermögen": {"kapitalerträge_m": "income_from_capital"}, + "sonstige": { + "alle_weiteren_y": 0.0, + "ohne_renten_m": "income_from_other_sources", + # "rente": {"ertragsanteil": 0.0}, + "renteneinkünfte_m": "pension_income", + }, + }, + "abzüge": { + "beitrag_private_rentenversicherung_m": "contribution_to_private_pension_insurance", # noqa: E501 + "kinderbetreuungskosten_m": "childcare_expenses", + "p_id_kinderbetreuungskostenträger": "person_that_pays_childcare_expenses", + }, + "gemeinsam_veranlagt": "joint_taxation", + }, + "lohnsteuer": {"steuerklasse": "lohnsteuer__steuerklasse"}, + "sozialversicherung": { + "arbeitslosen": { + # "betrag_m": 0.0 + "mean_nettoeinkommen_in_12_monaten_vor_arbeitslosigkeit_m": 2000.0, + "arbeitssuchend": False, + "monate_beitragspflichtig_versichert_in_letzten_30_monaten": 30, + "monate_sozialversicherungspflichtiger_beschäftigung_in_letzten_5_jahren": 60, + "monate_durchgängigen_bezugs_von_arbeitslosengeld": 0, + }, + "rente": { + "monat_renteneintritt": 1, + "jahr_renteneintritt": "jahr_renteneintritt", + "private_rente_betrag_m": "amount_private_pension_income", + "monate_in_arbeitsunfähigkeit": 0, + "krankheitszeiten_ab_16_bis_24_monate": 0.0, + "monate_in_mutterschutz": 0, + "monate_in_arbeitslosigkeit": 0, + "monate_in_ausbildungssuche": 0, + "monate_in_schulausbildung": 0, + "monate_mit_bezug_entgeltersatzleistungen_wegen_arbeitslosigkeit": 0, + "monate_geringfügiger_beschäftigung": 0, + "kinderberücksichtigungszeiten_monate": 0, + "pflegeberücksichtigungszeiten_monate": 0, + "erwerbsminderung": { + "voll_erwerbsgemindert": False, + "teilweise_erwerbsgemindert": False, + }, + "altersrente": { + # "betrag_m": 0.0, + }, + "grundrente": { + "grundrentenzeiten_monate": 0, + "bewertungszeiten_monate": 0, + "gesamteinnahmen_aus_renten_vorjahr_m": 0.0, + "mean_entgeltpunkte": 0.0, + "bruttolohn_vorjahr_y": 20000.0, + "einnahmen_aus_renten_vorjahr_y": 0.0, + "einnahmen_aus_kapitalvermögen_vorvorjahr_y": 0.0, + "einnahmen_aus_selbstständiger_arbeit_vorvorjahr_y": 0.0, + "einnahmen_aus_vermietung_und_verpachtung_vorvorjahr_y": 0.0, + }, + "bezieht_rente": False, + "entgeltpunkte": 0.0, + "pflichtbeitragsmonate": 0, + "freiwillige_beitragsmonate": 0, + "ersatzzeiten_monate": 0, + }, + "kranken": { + "beitrag": {"privat_versichert": "contribution_private_health_insurance"} + }, + "pflege": {"beitrag": {"hat_kinder": "has_children"}}, + }, + "familie": { + "alleinerziehend": "single_parent", + "kind": "is_child", + "p_id_ehepartner": "spouse_id", + "p_id_elternteil_1": "parent_id_1", + "p_id_elternteil_2": "parent_id_2", + }, + "wohnen": { + "bewohnt_eigentum_hh": False, + "bruttokaltmiete_m_hh": 900.0, + "heizkosten_m_hh": 150.0, + "wohnfläche_hh": 80.0, + }, + "kindergeld": { + "in_ausbildung": "in_training", + "p_id_empfänger": "id_recipient_child_allowance", + }, + "vermögen": 0.0, + "unterhalt": { + "tatsächlich_erhaltener_betrag_m": 0.0, + "anspruch_m": 0.0, + }, + "elterngeld": { + # "betrag_m": 0.0, + # "anrechenbarer_betrag_m": 0.0, + "zu_versteuerndes_einkommen_vorjahr_y_sn": 30000.0, + 
"mean_nettoeinkommen_in_12_monaten_vor_geburt_m": 2000.0, + "claimed": False, + "bisherige_bezugsmonate": 0 + }, + "bürgergeld": { + # "betrag_m_bg": 0.0, + "p_id_einstandspartner": "bürgergeld__p_id_einstandspartner", + "bezug_im_vorjahr": False, + }, + "wohngeld": { + # "betrag_m_wthh": 0.0, + "mietstufe_hh": 3, + }, + "kinderzuschlag": { + # "betrag_m_bg": 0.0, + }, +} + + +def sync_jax_if_needed(backend): + """Force JAX synchronization to ensure all operations are complete.""" + if backend == "jax": + try: + import jax + # Force synchronization of all JAX operations + jax.block_until_ready(jax.numpy.array([1.0])) + print(" JAX operations synchronized") + except ImportError: + pass # JAX not available, skip synchronization + +def clear_jax_cache(): + """Clear JAX compilation cache to ensure clean state.""" + if JAX_AVAILABLE: + try: + # Import jax locally to avoid issues + import jax as jax_local + # Clear all JAX caches + jax_local.clear_caches() + print(" JAX compilation cache cleared") + except Exception as e: + print(f" Warning: Could not clear JAX cache: {e}") + +def force_garbage_collection(): + """Force aggressive garbage collection between runs.""" + gc.collect() + gc.collect() # Run twice for good measure + print(" Garbage collection completed") + +def reset_session_state(backend): + """Reset session state between different backend runs.""" + print(f" Resetting session state for {backend} backend...") + + # Force garbage collection + force_garbage_collection() + + # Clear JAX-specific state if switching to/from JAX + if backend == "jax" or JAX_AVAILABLE: + clear_jax_cache() + + # Add a small delay to let system settle + time.sleep(0.5) + +def get_memory_usage_mb(): + """Get current memory usage in MB.""" + process = psutil.Process(os.getpid()) + return process.memory_info().rss / 1024 / 1024 + +class MemoryTracker: + """Track peak memory usage during execution with continuous monitoring.""" + def __init__(self): + self.peak_memory = 0 + self.process = psutil.Process(os.getpid()) + self.monitoring = False + self.monitor_thread = None + + def start_monitoring(self): + """Start continuous memory monitoring in background thread.""" + self.monitoring = True + self.peak_memory = self.get_current_memory() + self.monitor_thread = threading.Thread(target=self._monitor_loop, daemon=True) + self.monitor_thread.start() + + def stop_monitoring(self): + """Stop continuous memory monitoring.""" + self.monitoring = False + if self.monitor_thread: + self.monitor_thread.join(timeout=1.0) + + def _monitor_loop(self): + """Background monitoring loop.""" + while self.monitoring: + current = self.get_current_memory() + if current > self.peak_memory: + self.peak_memory = current + time.sleep(0.01) # Check every 10ms + + def get_current_memory(self): + """Get current memory usage in MB.""" + return self.process.memory_info().rss / 1024 / 1024 + + def update(self): + """Update peak memory if current usage is higher.""" + current = self.get_current_memory() + if current > self.peak_memory: + self.peak_memory = current + return current + + def get_peak(self): + """Get peak memory usage in MB.""" + return self.peak_memory + + +def run_benchmark( + N_households, backend, + save_memory_profile=False, + reset_session=False, + sync_jax=False, + ): + """Run a single benchmark with 3-stage timing as in gettsim_profile_stages.py.""" + print(f"Running benchmark: {N_households:,} households, {backend} backend") + + # Reset session state to ensure clean environment + if reset_session: + reset_session_state(backend) 
+ + # Generate data + print(" Generating data...") + data = make_data(N_households) + + # Memory tracking setup + tracker = MemoryTracker() if save_memory_profile else None + + # Initial memory reading + initial_memory = get_memory_usage_mb() + if tracker: + tracker.start_monitoring() + + try: + # STAGE 1: Data preprocessing and DAG creation + print(" Stage 1: Data preprocessing and DAG creation...") + stage1_start = time.time() + + tmp = main( + policy_date_str="2025-01-01", + input_data=InputData.df_and_mapper( + df=data, + mapper=MAPPER, + ), + main_targets=[ + MainTarget.specialized_environment.tt_dag, + MainTarget.processed_data, + MainTarget.labels.root_nodes, + MainTarget.input_data.flat, # Need this for stage 3 + MainTarget.tt_function, + ], + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + include_fail_nodes=False, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + if sync_jax: + sync_jax_if_needed(backend) + + stage1_end = time.time() + stage1_time = stage1_end - stage1_start + + # Generate hash for Stage 1 output (tmp) + stage1_hash = hashlib.md5(str(tmp).encode('utf-8')).hexdigest() + + # STAGE 2: Computation only (no data preprocessing) + print(" Stage 2: Computation only...") + + stage2_start = time.time() + + raw_results__columns = main( + policy_date_str="2025-01-01", + main_target=MainTarget.raw_results.columns, + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + processed_data=tmp["processed_data"], + labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), + tt_function=tmp["tt_function"], # Reuse pre-compiled JAX function + include_fail_nodes=False, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + if sync_jax: + sync_jax_if_needed(backend) + + stage2_end = time.time() + stage2_time = stage2_end - stage2_start + + # Generate hash for Stage 2 output (raw_results__columns) + stage2_hash = hashlib.md5(str(raw_results__columns).encode('utf-8')).hexdigest() + + # STAGE 3: Convert raw results to DataFrame (no computation, just formatting) + print(" Stage 3: Convert raw results to DataFrame...") + stage3_start = time.time() + + result = main( + policy_date_str="2025-01-01", + main_target=MainTarget.results.df_with_mapper, + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + raw_results=RawResults.columns(raw_results__columns), + input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + processed_data=tmp["processed_data"], + labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), + specialized_environment=SpecializedEnvironment( + tt_dag=tmp["specialized_environment"]["tt_dag"] + ), + include_fail_nodes=False, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + if sync_jax: + sync_jax_if_needed(backend) + + stage3_end = time.time() + stage3_time = stage3_end - stage3_start + total_time = stage1_time + stage2_time + stage3_time + + # Generate hash for Stage 3 output (result) + stage3_hash = hashlib.md5(str(result).encode('utf-8')).hexdigest() + + # Final memory reading + final_memory = get_memory_usage_mb() + if tracker: + tracker.stop_monitoring() + + # Determine result shape and type + if hasattr(result, 'shape'): + result_shape = result.shape + else: + result_shape = getattr(result, 'shape', None) + + print(f" ✓ Stage 1 (pre-processing): {stage1_time:.4f}s ({stage1_time/total_time*100:.1f}%)") + print(f" ✓ Stage 2 (computation): {stage2_time:.4f}s 
({stage2_time/total_time*100:.1f}%)")
+        print(f" ✓ Stage 3 (post-processing): {stage3_time:.4f}s ({stage3_time/total_time*100:.1f}%)")
+        print(f" ✓ Total time: {total_time:.4f} seconds")
+        if result_shape:
+            print(f" Result shape: {result_shape}")
+        else:
+            print(f" Result type: {type(result)}")
+        print(f" Memory usage: {initial_memory:.1f} MB → {final_memory:.1f} MB (Δ{final_memory-initial_memory:+.1f} MB)")
+        print(f" Stage 1 hash: {stage1_hash[:16]}...")
+        print(f" Stage 2 hash: {stage2_hash[:16]}...")
+        print(f" Stage 3 hash: {stage3_hash[:16]}...")
+
+        return {
+            'stage1_time': stage1_time,
+            'stage2_time': stage2_time,
+            'stage3_time': stage3_time,
+            'execution_time': total_time,  # Keep for backwards compatibility
+            'stage1_hash': stage1_hash,
+            'stage2_hash': stage2_hash,
+            'stage3_hash': stage3_hash,
+            'initial_memory': initial_memory,
+            'final_memory': final_memory,
+            'memory_delta': final_memory - initial_memory,
+            'result_shape': result_shape,
+            'memory_tracker': tracker,
+            'peak_memory': tracker.get_peak() if tracker else final_memory
+        }
+
+    except Exception as e:
+        print(f" ✗ Failed: {str(e)}")
+        if tracker:
+            tracker.stop_monitoring()
+        return {
+            'stage1_time': None,
+            'stage2_time': None,
+            'stage3_time': None,
+            'execution_time': None,
+            'stage1_hash': None,
+            'stage2_hash': None,
+            'stage3_hash': None,
+            'initial_memory': initial_memory,
+            'final_memory': get_memory_usage_mb(),
+            'memory_delta': None,
+            'result_shape': None,
+            'memory_tracker': tracker,
+            'peak_memory': tracker.get_peak() if tracker else get_memory_usage_mb(),
+            'error': str(e)
+        }
+
+if __name__ == "__main__":
+    # Dataset sizes (number of households)
+    household_sizes = [2**15-1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20]
+    # household_sizes = [2**21]  # for testing purposes
+    backends = ["numpy", "jax"]
+    # backends = ["numpy"]
+
+    results = {}
+
+    # Add metadata
+    results["metadata"] = {
+        "timestamp": datetime.now().isoformat(),
+        "household_sizes": household_sizes,
+        "backends": backends
+    }
+
+    for backend in backends:
+        print(f"\n{'='*60}")
+        print(f"Testing {backend} backend")
+        print(f"{'='*60}")
+
+        # Clear all caches and reset session before starting new backend
+        print(f"Preparing environment for {backend} backend...")
+        reset_session_state(backend)
+
+        for N_households in household_sizes:
+            # Add extra session reset for larger datasets to ensure clean state
+            reset_between_sizes = N_households >= 2**18  # Reset for 256k+ households
+
+            result = run_benchmark(
+                N_households,
+                backend,
+                reset_session=False,  # reset_between_sizes (no impact on results)
+                sync_jax=True,  # Force JAX synchronization; seems necessary for
+                # realistic (reported time = wall clock time) JAX timings
+            )
+            if result and result.get('execution_time'):
+                # Store all stage timing data
+                results[f"{N_households}_{backend}_stage1_time"] = result['stage1_time']
+                results[f"{N_households}_{backend}_stage2_time"] = result['stage2_time']
+                results[f"{N_households}_{backend}_stage3_time"] = result['stage3_time']
+                results[f"{N_households}_{backend}_time"] = result['execution_time']  # Total time
+                results[f"{N_households}_{backend}_stage1_hash"] = result['stage1_hash']
+                results[f"{N_households}_{backend}_stage2_hash"] = result['stage2_hash']
+                results[f"{N_households}_{backend}_stage3_hash"] = result['stage3_hash']
+                results[f"{N_households}_{backend}_initial_memory"] = result['initial_memory']
+                results[f"{N_households}_{backend}_final_memory"] = result['final_memory']
+                results[f"{N_households}_{backend}_memory_delta"] = result['memory_delta']
results[f"{N_households}_{backend}_peak_memory"] = result['peak_memory'] + results[f"{N_households}_{backend}_result_shape"] = result['result_shape'] + else: + # Store None values for failed runs + results[f"{N_households}_{backend}_stage1_time"] = None + results[f"{N_households}_{backend}_stage2_time"] = None + results[f"{N_households}_{backend}_stage3_time"] = None + results[f"{N_households}_{backend}_time"] = None + results[f"{N_households}_{backend}_hash"] = None + results[f"{N_households}_{backend}_initial_memory"] = None + results[f"{N_households}_{backend}_final_memory"] = None + results[f"{N_households}_{backend}_memory_delta"] = None + results[f"{N_households}_{backend}_peak_memory"] = None + results[f"{N_households}_{backend}_result_shape"] = None + print() + + # Comprehensive cleanup after completing all sizes for this backend + print(f"Completing {backend} backend tests...") + # reset_session_state(backend) + print(f"{backend} backend tests completed with full cleanup") + + # Save results to JSON file + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + filename = f"benchmark_results_{timestamp}.json" + with open(filename, 'w') as f: + json.dump(results, f, indent=2) + print(f"\nResults saved to: {filename}") + + print(f"\n{'='*120}") + print("3-STAGE TIMING BREAKDOWN") + print(f"{'='*120}") + + # Print comparison table in the requested format + print(f"\n{'='*101}") + print("PERFORMANCE COMPARISON NUMPY <-> JAX") + print(f"{'='*104}") + print(f"{'Households':<12}{'Stage':<18}{'NUMPY hash':<12}{'JAX hash':<12}{'NUMPY (s)':<12}{'JAX (s)':<12}{'Speedup':<12}") + print("-" * 104) + + for N_households in household_sizes: + # Get timing data for all stages + numpy_s1 = results.get(f"{N_households}_numpy_stage1_time") + numpy_s2 = results.get(f"{N_households}_numpy_stage2_time") + numpy_s3 = results.get(f"{N_households}_numpy_stage3_time") + numpy_total = results.get(f"{N_households}_numpy_time") + + jax_s1 = results.get(f"{N_households}_jax_stage1_time") + jax_s2 = results.get(f"{N_households}_jax_stage2_time") + jax_s3 = results.get(f"{N_households}_jax_stage3_time") + jax_total = results.get(f"{N_households}_jax_time") + + numpy_hash = results.get(f"{N_households}_numpy_hash") + jax_hash = results.get(f"{N_households}_jax_hash") + + # Get stage-specific hashes + numpy_s1_hash = results.get(f"{N_households}_numpy_stage1_hash") + numpy_s2_hash = results.get(f"{N_households}_numpy_stage2_hash") + numpy_s3_hash = results.get(f"{N_households}_numpy_stage3_hash") + + jax_s1_hash = results.get(f"{N_households}_jax_stage1_hash") + jax_s2_hash = results.get(f"{N_households}_jax_stage2_hash") + jax_s3_hash = results.get(f"{N_households}_jax_stage3_hash") + + # Truncate hashes for display, handling both successful and failed cases + def format_hash_display(hash_value, time_value): + """Format hash display based on whether the stage succeeded.""" + if time_value is None: + return "FAILED" + elif hash_value: + return hash_value[:8] + else: + return "N/A" + + numpy_s1_hash_display = format_hash_display(numpy_s1_hash, numpy_s1) + numpy_s2_hash_display = format_hash_display(numpy_s2_hash, numpy_s2) + numpy_s3_hash_display = format_hash_display(numpy_s3_hash, numpy_s3) + + jax_s1_hash_display = format_hash_display(jax_s1_hash, jax_s1) + jax_s2_hash_display = format_hash_display(jax_s2_hash, jax_s2) + jax_s3_hash_display = format_hash_display(jax_s3_hash, jax_s3) + + # Helper function to format time display + def format_time_display(time_value): + """Format time display for successful or failed 
runs.""" + return f"{time_value:.4f}" if time_value is not None else "FAILED" + + # Helper function to calculate speedup + def calculate_speedup(numpy_time, jax_time): + """Calculate speedup string, handling failed cases.""" + if numpy_time is None and jax_time is None: + return "FAILED" + elif numpy_time is None: + return "N/A" + elif jax_time is None: + return "N/A" + elif jax_time > 0: + speedup = numpy_time / jax_time + return f"{speedup:.2f}x" if speedup >= 1 else f"1/{jax_time/numpy_time:.2f}x" + else: + return "N/A" + + # Determine if we should show stage breakdown or overall FAILED + show_stages = (numpy_total is not None) or (jax_total is not None) + + if show_stages: + # Show individual stage results + + # Pre-processing row (Stage 1 hashes often unstable due to dict return) + s1_speedup_str = calculate_speedup(numpy_s1, jax_s1) + print(f"{N_households:<12,}{'pre-processing':<18}{'-':<12}{'-':<12}{format_time_display(numpy_s1):<12}{format_time_display(jax_s1):<12}{s1_speedup_str:<12}") + + # Computation row (Stage 2 hashes should be stable) + s2_speedup_str = calculate_speedup(numpy_s2, jax_s2) + print(f"{'':>12}{'computation':<18}{numpy_s2_hash_display:<12}{jax_s2_hash_display:<12}{format_time_display(numpy_s2):<12}{format_time_display(jax_s2):<12}{s2_speedup_str:<12}") + + # Post-processing row (Stage 3 hashes should be stable) + s3_speedup_str = calculate_speedup(numpy_s3, jax_s3) + print(f"{'':>12}{'post-processing':<18}{numpy_s3_hash_display:<12}{jax_s3_hash_display:<12}{format_time_display(numpy_s3):<12}{format_time_display(jax_s3):<12}{s3_speedup_str:<12}") + + # Total time row + total_speedup_str = calculate_speedup(numpy_total, jax_total) + print(f"{'':>12}{'total time':<18}{'':>12}{'':>12}{format_time_display(numpy_total):<12}{format_time_display(jax_total):<12}{total_speedup_str:<12}") + + print("-" * 104) + else: + # Both backends completely failed + print(f"{N_households:<12,}{'FAILED':<18}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}{'FAILED':<12}") + print("-" * 104) + + # Print memory comparison + print(f"\n{'='*120}") + print("MEMORY USAGE COMPARISON") + print(f"{'='*120}") + print(f"{'Households':<12}{'NumPy Init':<12}{'NumPy Final':<12}{'JAX Init':<12}{'JAX Final':<12}{'NumPy Δ':<12}{'JAX Δ':<12}") + print("-" * 120) + + for N_households in household_sizes: + numpy_init = results.get(f"{N_households}_numpy_initial_memory") + numpy_final = results.get(f"{N_households}_numpy_final_memory") + jax_init = results.get(f"{N_households}_jax_initial_memory") + jax_final = results.get(f"{N_households}_jax_final_memory") + numpy_delta = results.get(f"{N_households}_numpy_memory_delta") + jax_delta = results.get(f"{N_households}_jax_memory_delta") + + # Helper function to format memory values + def format_memory(value): + return f"{value:.1f}" if value is not None else "FAILED" + + # Show memory data even if only one backend succeeded + print(f"{N_households:<12,}{format_memory(numpy_init):<12}{format_memory(numpy_final):<12}{format_memory(jax_init):<12}{format_memory(jax_final):<12}{format_memory(numpy_delta):<12}{format_memory(jax_delta):<12}") + + print("-" * 120) + print("\nLegend:") + print(" Stage 1: Data preprocessing & DAG creation") + print(" Stage 2: Core computation (tax/transfer calculations)") + print(" Stage 3: DataFrame formatting (JAX → pandas conversion)") + print(" Init/Final: Memory usage before/after execution") + print(" Δ: Memory increase during execution") + print(" ✓/✗: Hash verification (results match/differ)") + + print(f"\n{'='*120}") 
+ print("BENCHMARK COMPLETED") + print(f"{'='*120}") + print(f"Results saved to: {filename}") + print(f"Generated at: {datetime.now().isoformat()}") diff --git a/benchmark_code/benchmark_compare.py b/benchmark_code/benchmark_compare.py new file mode 100644 index 0000000..6dacba6 --- /dev/null +++ b/benchmark_code/benchmark_compare.py @@ -0,0 +1,354 @@ +""" +Script to compare benchmark results from main branch vs PR branch. +This script loads two JSON files from benchmark_stages.py runs and creates +comparison tables showing the impact of optimizations with 3-stage breakdown. + +Usage: + python benchmark_compare.py main_results.json pr_results.json [--save-comparison] +""" + +import json +import os +import sys +from datetime import datetime +import argparse + +def load_benchmark_results(filepath): + """Load benchmark results from JSON file.""" + try: + with open(filepath, 'r') as f: + data = json.load(f) + return data + except FileNotFoundError: + print(f"Error: File '{filepath}' not found.") + return None + except json.JSONDecodeError: + print(f"Error: Invalid JSON in file '{filepath}'.") + return None + +def extract_household_sizes(results): + """Extract household sizes from results metadata or data keys.""" + if "metadata" in results and "household_sizes" in results["metadata"]: + return results["metadata"]["household_sizes"] + + # Fallback: extract from data keys + household_sizes = set() + for key in results.keys(): + if key.endswith("_numpy_time") or key.endswith("_jax_time"): + try: + size = int(key.split("_")[0]) + household_sizes.add(size) + except ValueError: + continue + + return sorted(list(household_sizes)) + +def print_jax_comparison_table(main_results, pr_results, household_sizes): + """Print comparison table for JAX backend with 3-stage breakdown.""" + print(f"\n{'='*140}") + print("JAX BACKEND COMPARISON: Main Branch vs PR Branch - 3-STAGE BREAKDOWN") + print(f"{'='*140}") + print(f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}") + print("-" * 140) + + for N_households in household_sizes: + # Get timing data for all stages + main_s1 = main_results.get(f"{N_households}_jax_stage1_time") + main_s2 = main_results.get(f"{N_households}_jax_stage2_time") + main_s3 = main_results.get(f"{N_households}_jax_stage3_time") + main_total = main_results.get(f"{N_households}_jax_time") + + pr_s1 = pr_results.get(f"{N_households}_jax_stage1_time") + pr_s2 = pr_results.get(f"{N_households}_jax_stage2_time") + pr_s3 = pr_results.get(f"{N_households}_jax_stage3_time") + pr_total = pr_results.get(f"{N_households}_jax_time") + + # Get stage-specific hashes + main_s1_hash = main_results.get(f"{N_households}_jax_stage1_hash") + main_s2_hash = main_results.get(f"{N_households}_jax_stage2_hash") + main_s3_hash = main_results.get(f"{N_households}_jax_stage3_hash") + + pr_s1_hash = pr_results.get(f"{N_households}_jax_stage1_hash") + pr_s2_hash = pr_results.get(f"{N_households}_jax_stage2_hash") + pr_s3_hash = pr_results.get(f"{N_households}_jax_stage3_hash") + + # Check hash matches for each stage + # Stage 1 hashes are intentionally omitted due to instability, so show empty + s1_hash_match = "" + s2_hash_match = "✓" if main_s2_hash == pr_s2_hash else "✗" if main_s2_hash and pr_s2_hash else "N/A" + s3_hash_match = "✓" if main_s3_hash == pr_s3_hash else "✗" if main_s3_hash and pr_s3_hash else "N/A" + + # Check if we have valid data + if all(x is not None for x in [main_s1, main_s2, main_s3, pr_s1, pr_s2, pr_s3]): + # Stage 1 row + 
+            s1_speedup = main_s1 / pr_s1
+            s1_speedup_str = f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x"
+            print(f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}")
+
+            # Stage 2 row
+            s2_speedup = main_s2 / pr_s2
+            s2_speedup_str = f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x"
+            print(f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}")
+
+            # Stage 3 row
+            s3_speedup = main_s3 / pr_s3
+            s3_speedup_str = f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x"
+            print(f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}")
+
+            # Total row
+            if main_total and pr_total:
+                total_speedup = main_total / pr_total
+                total_speedup_str = f"{total_speedup:.2f}x" if total_speedup >= 1 else f"1/{pr_total/main_total:.2f}x"
+                print(f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}")
+
+            print("-" * 140)
+        else:
+            # Handle failed cases
+            main_time_str = f"{main_total:.4f}" if main_total is not None else "FAILED"
+            pr_time_str = f"{pr_total:.4f}" if pr_total is not None else "FAILED"
+            print(f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}")
+            print("-" * 140)
+
+def print_numpy_comparison_table(main_results, pr_results, household_sizes):
+    """Print comparison table for NumPy backend with 3-stage breakdown."""
+    print(f"\n{'='*140}")
+    print("NUMPY BACKEND COMPARISON: Main Branch vs PR Branch - 3-STAGE BREAKDOWN")
+    print(f"{'='*140}")
+    print(f"{'Households':<12}{'Stage':<18}{'Main (s)':<12}{'PR (s)':<12}{'Speedup':<12}{'Description':<25}{'Hash Match':<12}")
+    print("-" * 140)
+
+    for N_households in household_sizes:
+        # Get timing data for all stages
+        main_s1 = main_results.get(f"{N_households}_numpy_stage1_time")
+        main_s2 = main_results.get(f"{N_households}_numpy_stage2_time")
+        main_s3 = main_results.get(f"{N_households}_numpy_stage3_time")
+        main_total = main_results.get(f"{N_households}_numpy_time")
+
+        pr_s1 = pr_results.get(f"{N_households}_numpy_stage1_time")
+        pr_s2 = pr_results.get(f"{N_households}_numpy_stage2_time")
+        pr_s3 = pr_results.get(f"{N_households}_numpy_stage3_time")
+        pr_total = pr_results.get(f"{N_households}_numpy_time")
+
+        # Get stage-specific hashes
+        main_s1_hash = main_results.get(f"{N_households}_numpy_stage1_hash")
+        main_s2_hash = main_results.get(f"{N_households}_numpy_stage2_hash")
+        main_s3_hash = main_results.get(f"{N_households}_numpy_stage3_hash")
+
+        pr_s1_hash = pr_results.get(f"{N_households}_numpy_stage1_hash")
+        pr_s2_hash = pr_results.get(f"{N_households}_numpy_stage2_hash")
+        pr_s3_hash = pr_results.get(f"{N_households}_numpy_stage3_hash")
+
+        # Check hash matches for each stage; compare only when both hashes exist
+        # Stage 1 hashes are intentionally omitted due to instability, so show empty
+        s1_hash_match = ""
+        s2_hash_match = ("✓" if main_s2_hash == pr_s2_hash else "✗") if (main_s2_hash and pr_s2_hash) else "N/A"
+        s3_hash_match = ("✓" if main_s3_hash == pr_s3_hash else "✗") if (main_s3_hash and pr_s3_hash) else "N/A"
+
+        # Check if we have valid data
+        if all(x is not None for x in [main_s1, main_s2, main_s3, pr_s1, pr_s2, pr_s3]):
+            # Stage 1 row
+            s1_speedup = main_s1 / pr_s1
+            s1_speedup_str = f"{s1_speedup:.2f}x" if s1_speedup >= 1 else f"1/{pr_s1/main_s1:.2f}x"
+
print(f"{N_households:<12,}{'pre-processing':<18}{main_s1:<12.4f}{pr_s1:<12.4f}{s1_speedup_str:<12}{'Data preprocessing':<25}{s1_hash_match:<12}") + + # Stage 2 row + s2_speedup = main_s2 / pr_s2 + s2_speedup_str = f"{s2_speedup:.2f}x" if s2_speedup >= 1 else f"1/{pr_s2/main_s2:.2f}x" + print(f"{'':>12}{'computation':<18}{main_s2:<12.4f}{pr_s2:<12.4f}{s2_speedup_str:<12}{'Core computation':<25}{s2_hash_match:<12}") + + # Stage 3 row + s3_speedup = main_s3 / pr_s3 + s3_speedup_str = f"{s3_speedup:.2f}x" if s3_speedup >= 1 else f"1/{pr_s3/main_s3:.2f}x" + print(f"{'':>12}{'post-processing':<18}{main_s3:<12.4f}{pr_s3:<12.4f}{s3_speedup_str:<12}{'DataFrame formatting':<25}{s3_hash_match:<12}") + + # Total row + if main_total and pr_total: + total_speedup = main_total / pr_total + total_speedup_str = f"{total_speedup:.2f}x" if total_speedup >= 1 else f"1/{pr_total/main_total:.2f}x" + print(f"{'':>12}{'total time':<18}{main_total:<12.4f}{pr_total:<12.4f}{total_speedup_str:<12}{'Complete execution':<25}{'':<12}") + + print("-" * 140) + else: + # Handle failed cases + main_time_str = f"{main_total:.4f}" if main_total is not None else "FAILED" + pr_time_str = f"{pr_total:.4f}" if pr_total is not None else "FAILED" + print(f"{N_households:<12,}{'FAILED':<18}{main_time_str:<12}{pr_time_str:<12}{'N/A':<12}{'Benchmark failed':<25}{'N/A':<12}") + print("-" * 140) + +def print_summary_statistics(main_results, pr_results, household_sizes): + """Print summary statistics comparing main vs PR performance with 3-stage breakdown.""" + print(f"\n{'='*100}") + print("SUMMARY STATISTICS - 3-STAGE BREAKDOWN") + print(f"{'='*100}") + + backends = ["numpy", "jax"] + stages = ["stage1", "stage2", "stage3", "total"] + stage_names = { + "stage1": "Stage 1 (preprocessing)", + "stage2": "Stage 2 (computation)", + "stage3": "Stage 3 (formatting)", + "total": "Total execution" + } + + for backend in backends: + print(f"\n{backend.upper()} Backend:") + print("-" * 40) + + for stage in stages: + if stage == "total": + time_suffix = "_time" + else: + time_suffix = f"_{stage}_time" + + valid_speedups = [] + successful_runs = 0 + total_runs = len(household_sizes) + + for N_households in household_sizes: + main_time = main_results.get(f"{N_households}_{backend}{time_suffix}") + pr_time = pr_results.get(f"{N_households}_{backend}{time_suffix}") + + if main_time is not None and pr_time is not None: + successful_runs += 1 + speedup = main_time / pr_time + valid_speedups.append(speedup) + + if valid_speedups: + avg_speedup = sum(valid_speedups) / len(valid_speedups) + max_speedup = max(valid_speedups) + min_speedup = min(valid_speedups) + + print(f" {stage_names[stage]}:") + print(f" Average speedup: {avg_speedup:.2f}x") + print(f" Maximum speedup: {max_speedup:.2f}x") + print(f" Minimum speedup: {min_speedup:.2f}x") + print(f" Successful runs: {successful_runs}/{total_runs}") + else: + print(f" {stage_names[stage]}: No valid comparisons available") + + # Check hash consistency for all stages + stage_hash_results = {} + for stage_num in [1, 2, 3]: + hash_mismatches = 0 + total_comparisons = 0 + + for N_households in household_sizes: + main_hash = main_results.get(f"{N_households}_{backend}_stage{stage_num}_hash") + pr_hash = pr_results.get(f"{N_households}_{backend}_stage{stage_num}_hash") + + if main_hash and pr_hash: + total_comparisons += 1 + if main_hash != pr_hash: + hash_mismatches += 1 + + stage_hash_results[stage_num] = (hash_mismatches, total_comparisons) + + # Print hash verification results + all_stages_perfect = True + 
for stage_num in [1, 2, 3]: + hash_mismatches, total_comparisons = stage_hash_results[stage_num] + stage_name = {1: "Stage 1", 2: "Stage 2", 3: "Stage 3"}[stage_num] + + if total_comparisons > 0: + print(f" {stage_name} hash verification: {hash_mismatches}/{total_comparisons} mismatches") + if hash_mismatches > 0: + all_stages_perfect = False + else: + print(f" {stage_name} hash verification: No valid comparisons available") + all_stages_perfect = False + + if all_stages_perfect and any(total for _, total in stage_hash_results.values()): + print(f" ✓ All stage results are numerically identical") + elif any(mismatches for mismatches, _ in stage_hash_results.values()): + print(f" ⚠ Some stage results differ between main and PR") + else: + print(f" No valid hash comparisons available") + + # Overall comparison + print(f"\n{'='*100}") + print("OVERALL PERFORMANCE IMPACT") + print(f"{'='*100}") + + for backend in backends: + total_speedups = [] + for N_households in household_sizes: + main_total = main_results.get(f"{N_households}_{backend}_time") + pr_total = pr_results.get(f"{N_households}_{backend}_time") + + if main_total and pr_total: + total_speedups.append(main_total / pr_total) + + if total_speedups: + avg_speedup = sum(total_speedups) / len(total_speedups) + if avg_speedup > 1.05: + impact = f"PR is {avg_speedup:.1f}x faster (significant improvement)" + elif avg_speedup < 0.95: + impact = f"PR is {1/avg_speedup:.1f}x slower (performance regression)" + else: + impact = "PR has minimal performance impact (±5%)" + + print(f"{backend.upper()}: {impact}") + else: + print(f"{backend.upper()}: No valid performance comparisons available") + +def main(): + parser = argparse.ArgumentParser(description="Compare benchmark results from main branch vs PR branch") + parser.add_argument("main_file", help="Path to benchmark results JSON file from main branch") + parser.add_argument("pr_file", help="Path to benchmark results JSON file from PR branch") + parser.add_argument("--save-comparison", help="Save comparison tables to text file", action="store_true") + + args = parser.parse_args() + + # Load benchmark results + print("Loading benchmark results...") + main_results = load_benchmark_results(args.main_file) + pr_results = load_benchmark_results(args.pr_file) + + if main_results is None or pr_results is None: + sys.exit(1) + + # Extract household sizes (use PR results as primary, fallback to main) + household_sizes = extract_household_sizes(pr_results) + if not household_sizes: + household_sizes = extract_household_sizes(main_results) + + if not household_sizes: + print("Error: Could not extract household sizes from either file.") + sys.exit(1) + + print(f"Found data for household sizes: {household_sizes}") + + # Print comparison tables + print_jax_comparison_table(main_results, pr_results, household_sizes) + print_numpy_comparison_table(main_results, pr_results, household_sizes) + print_summary_statistics(main_results, pr_results, household_sizes) + + # Save to file if requested + if args.save_comparison: + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + output_file = f"benchmark_comparison_{timestamp}.txt" + + # Redirect stdout to file + original_stdout = sys.stdout + + try: + with open(output_file, 'w') as f: + sys.stdout = f + print(f"Benchmark Comparison Report") + print(f"Generated: {datetime.now().isoformat()}") + print(f"Main branch file: {args.main_file}") + print(f"PR branch file: {args.pr_file}") + + print_jax_comparison_table(main_results, pr_results, household_sizes) + 
                print_numpy_comparison_table(main_results, pr_results, household_sizes)
+                print_summary_statistics(main_results, pr_results, household_sizes)
+
+            sys.stdout = original_stdout
+            print(f"\nComparison saved to: {output_file}")
+
+        except Exception as e:
+            sys.stdout = original_stdout
+            print(f"Error saving comparison: {e}")
+
+if __name__ == "__main__":
+    main()
diff --git a/benchmark_code/gettsim_profile.py b/benchmark_code/gettsim_profile.py
new file mode 100644
index 0000000..8454498
--- /dev/null
+++ b/benchmark_code/gettsim_profile.py
@@ -0,0 +1,418 @@
+"""
+GETTSIM Profiling Script
+
+This script profiles GETTSIM/TTSIM with synthetic data.
+It supports both JAX and NumPy backends.
+
+Usage:
+    python gettsim_profile.py -N 32768 -b numpy                                  (without profiling)
+    py-spy record -o profile.svg -- python gettsim_profile.py -N 32768 -b numpy  (with profiling)
+
+"""
+
+
+# %%
+import pandas as pd
+import time
+import argparse
+import hashlib
+from gettsim import InputData, MainTarget, TTTargets, Labels, SpecializedEnvironment, RawResults
+
+# Hack: Override GETTSIM main to make all TTSIM parameters of main available in GETTSIM.
+# Necessary because of GETTSIM issue #1075.
+# When resolved, this can be removed and gettsim.main can be used directly.
+from gettsim import germany
+import ttsim
+from ttsim.main_args import OrigPolicyObjects
+
+def main(**kwargs):
+    """Wrapper around ttsim.main that automatically sets the German root path and supports tt_function."""
+    # Set German tax system as default if no orig_policy_objects provided
+    if kwargs.get('orig_policy_objects') is None:
+        kwargs['orig_policy_objects'] = OrigPolicyObjects(root=germany.ROOT_PATH)
+
+    return ttsim.main(**kwargs)
+
+from make_data import make_data
+
+
+# %%
+TT_TARGETS = {
+    "einkommensteuer": {
+        "betrag_m_sn": "income_tax_m",
+        "zu_versteuerndes_einkommen_y_sn": "taxable_income_y_sn",
+    },
+    "sozialversicherung": {
+        "pflege": {
+            "beitrag": {
+                "betrag_versicherter_m": "long_term_care_insurance_contribution_m",
+                "betrag_gesamt_in_gleitzone_m": "long_term_care_insurance_contribution_total_in_transition_zone_m",
+            },
+        },
+        "kranken": {
+            "beitrag": {"betrag_versicherter_m": "health_insurance_contribution_m"},
+        },
+        "rente": {
+            "beitrag": {"betrag_versicherter_m": "pension_insurance_contribution_m"},
+            "entgeltpunkte_updated": "pension_entitlement_points_updated",
+            "grundrente": {
+                "gesamteinnahmen_aus_renten_für_einkommensberechnung_im_folgejahr_m": "pension_total_income_for_income_calculation_next_year_m",
+            },
+            "wartezeit_15_jahre_erfüllt": "pension_waiting_period_15_years_fulfilled",
+        },
+        "arbeitslosen": {
+            "mean_nettoeinkommen_für_bemessungsgrundlage_bei_arbeitslosigkeit_y": "mean_net_income_for_benefit_basis_in_case_of_unemployment_y",
+            "beitrag": {
+                "betrag_versicherter_m": "unemployment_insurance_contribution_m",
+            }
+        },
+        "beiträge_gesamt_m": "social_insurance_contributions_total_m",
+    },
+    "kindergeld": {"betrag_m": "KG_betrag_m"},
+    "bürgergeld": {"betrag_m_bg": "BG_betrag_m_bg"},
+    "grundsicherung": {"im_alter": {"betrag_m_eg": "GS_betrag_m_eg"}},
+    "wohngeld": {"betrag_m_wthh": "WG_betrag_m_wthh"},
+    "kinderzuschlag": {
+        "betrag_m_bg": "KiZ_betrag_m_bg",
+    },
+    "familie": {"alleinerziehend_fg": "single_parent_fg"},
+    "elterngeld": {
+        "betrag_m": "EG_betrag_m",
+        "anrechenbarer_betrag_m": "EG_anrechenbarer_betrag_m",
+        "mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m":
"EG_mean_nettoeinkommen_für_bemessungsgrundlage_nach_geburt_m" + }, + "unterhalt": { + "tatsächlich_erhaltener_betrag_m": "unterhalt_tatsächlich_erhaltener_betrag_m", + "kind_festgelegter_zahlbetrag_m": "unterhalt_kind_festgelegter_zahlbetrag_m", + }, + "unterhaltsvorschuss": { + "an_elternteil_auszuzahlender_betrag_m": "unterhaltsvorschuss_an_elternteil_auszuzahlender_betrag_m", + }, +} + + +# %% +MAPPER = { + "alter": "age", + "alter_monate": "alter_monate", + "geburtsmonat": 1, + "arbeitsstunden_w": "working_hours", + "behinderungsgrad": "disability_grade", + "schwerbehindert_grad_g": False, + "geburtsjahr": "birth_year", + "hh_id": "hh_id", + "p_id": "p_id", + "wohnort_ost_hh": "east_germany", + "einnahmen": { + "bruttolohn_m": 2000.0, + "kapitalerträge_y": 0.0, + "renten": { + "betriebliche_altersvorsorge_m": 0.0, + "geförderte_private_vorsorge_m": 0.0, + "gesetzliche_m": 0.0, + "sonstige_private_vorsorge_m": 0.0, + }, + }, + "einkommensteuer": { + "einkünfte": { + "ist_hauptberuflich_selbstständig": False, + "ist_selbstständig": "self_employed", + "aus_gewerbebetrieb": {"betrag_m": "income_from_self_employment"}, + "aus_vermietung_und_verpachtung": {"betrag_m": "income_from_rent"}, + "aus_nichtselbstständiger_arbeit": { + "bruttolohn_m": "income_from_employment" + }, + "aus_forst_und_landwirtschaft": { + "betrag_m": "income_from_forest_and_agriculture" + }, + "aus_selbstständiger_arbeit": {"betrag_m": "income_from_self_employment"}, + "aus_kapitalvermögen": {"kapitalerträge_m": "income_from_capital"}, + "sonstige": { + "alle_weiteren_y": 0.0, + "ohne_renten_m": "income_from_other_sources", + # "rente": {"ertragsanteil": 0.0}, + "renteneinkünfte_m": "pension_income", + }, + }, + "abzüge": { + "beitrag_private_rentenversicherung_m": "contribution_to_private_pension_insurance", # noqa: E501 + "kinderbetreuungskosten_m": "childcare_expenses", + "p_id_kinderbetreuungskostenträger": "person_that_pays_childcare_expenses", + }, + "gemeinsam_veranlagt": "joint_taxation", + }, + "lohnsteuer": {"steuerklasse": "lohnsteuer__steuerklasse"}, + "sozialversicherung": { + "arbeitslosen": { + # "betrag_m": 0.0 + "mean_nettoeinkommen_in_12_monaten_vor_arbeitslosigkeit_m": 2000.0, + "arbeitssuchend": False, + "monate_beitragspflichtig_versichert_in_letzten_30_monaten": 30, + "monate_sozialversicherungspflichtiger_beschäftigung_in_letzten_5_jahren": 60, + "monate_durchgängigen_bezugs_von_arbeitslosengeld": 0, + }, + "rente": { + "monat_renteneintritt": 1, + "jahr_renteneintritt": "jahr_renteneintritt", + "private_rente_betrag_m": "amount_private_pension_income", + "monate_in_arbeitsunfähigkeit": 0, + "krankheitszeiten_ab_16_bis_24_monate": 0.0, + "monate_in_mutterschutz": 0, + "monate_in_arbeitslosigkeit": 0, + "monate_in_ausbildungssuche": 0, + "monate_in_schulausbildung": 0, + "monate_mit_bezug_entgeltersatzleistungen_wegen_arbeitslosigkeit": 0, + "monate_geringfügiger_beschäftigung": 0, + "kinderberücksichtigungszeiten_monate": 0, + "pflegeberücksichtigungszeiten_monate": 0, + "erwerbsminderung": { + "voll_erwerbsgemindert": False, + "teilweise_erwerbsgemindert": False, + }, + "altersrente": { + # "betrag_m": 0.0, + }, + "grundrente": { + "grundrentenzeiten_monate": 0, + "bewertungszeiten_monate": 0, + "gesamteinnahmen_aus_renten_vorjahr_m": 0.0, + "mean_entgeltpunkte": 0.0, + "bruttolohn_vorjahr_y": 20000.0, + "einnahmen_aus_renten_vorjahr_y": 0.0, + "einnahmen_aus_kapitalvermögen_vorvorjahr_y": 0.0, + "einnahmen_aus_selbstständiger_arbeit_vorvorjahr_y": 0.0, + 
"einnahmen_aus_vermietung_und_verpachtung_vorvorjahr_y": 0.0, + }, + "bezieht_rente": False, + "entgeltpunkte": 0.0, + "pflichtbeitragsmonate": 0, + "freiwillige_beitragsmonate": 0, + "ersatzzeiten_monate": 0, + }, + "kranken": { + "beitrag": {"privat_versichert": "contribution_private_health_insurance"} + }, + "pflege": {"beitrag": {"hat_kinder": "has_children"}}, + }, + "familie": { + "alleinerziehend": "single_parent", + "kind": "is_child", + "p_id_ehepartner": "spouse_id", + "p_id_elternteil_1": "parent_id_1", + "p_id_elternteil_2": "parent_id_2", + }, + "wohnen": { + "bewohnt_eigentum_hh": False, + "bruttokaltmiete_m_hh": 900.0, + "heizkosten_m_hh": 150.0, + "wohnfläche_hh": 80.0, + }, + "kindergeld": { + "in_ausbildung": "in_training", + "p_id_empfänger": "id_recipient_child_allowance", + }, + "vermögen": 0.0, + "unterhalt": { + "tatsächlich_erhaltener_betrag_m": 0.0, + "anspruch_m": 0.0, + }, + "elterngeld": { + # "betrag_m": 0.0, + # "anrechenbarer_betrag_m": 0.0, + "zu_versteuerndes_einkommen_vorjahr_y_sn": 30000.0, + "mean_nettoeinkommen_in_12_monaten_vor_geburt_m": 2000.0, + "claimed": False, + "bisherige_bezugsmonate": 0 + }, + "bürgergeld": { + # "betrag_m_bg": 0.0, + "p_id_einstandspartner": "bürgergeld__p_id_einstandspartner", + "bezug_im_vorjahr": False, + }, + "wohngeld": { + # "betrag_m_wthh": 0.0, + "mietstufe_hh": 3, + }, + "kinderzuschlag": { + # "betrag_m_bg": 0.0, + }, +} + +def sync_jax_if_needed(backend): + """Force JAX synchronization to ensure all operations are complete.""" + if backend == "jax": + try: + import jax + # Force synchronization of all JAX operations + jax.block_until_ready(jax.numpy.array([1.0])) + print(" JAX operations synchronized") + except ImportError: + pass + except Exception as e: + print(f" Warning: JAX sync failed: {e}") + + +def run_profile(N, backend): + """Run GETTSIM profiling with specified parameters.""" + print(f"Generating dataset with {N:,} households...") + data = make_data(N) + print(f"Dataset created successfully. 
Shape: {data.shape}") + + print(f"Running GETTSIM with backend: {backend}") + + # First stage - preprocessing and DAG creation + print("\n=== STAGE 1: Data preprocessing and DAG creation ===") + stage1_start = time.time() + + tmp = main( + policy_date_str="2025-01-01", + input_data=InputData.df_and_mapper( + df=data, + mapper=MAPPER, + ), + main_targets=[ + MainTarget.specialized_environment.tt_dag, + MainTarget.processed_data, + MainTarget.labels.root_nodes, + MainTarget.input_data.flat, # Need this for stage 3 + MainTarget.tt_function, # Use compiled tt_function in stage 2 with JAX backend + ], + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + include_fail_nodes=True, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + sync_jax_if_needed(backend) + + stage1_end = time.time() + stage1_time = stage1_end - stage1_start + + # Generate hash for Stage 1 output (tmp) - avoid memory issues with large arrays + stage1_hash = hashlib.md5(str(tmp).encode('utf-8')).hexdigest() + + print(f"Stage 1 completed in: {stage1_time:.4f} seconds") + print(f"Processed data keys: {len(tmp['processed_data'])}") + print(f"DAG nodes: {len(tmp['specialized_environment']['tt_dag'])}") + print(f"Stage 1 hash: {stage1_hash[:16]}...") + + # Second stage - computation only (no data preprocessing) + print("\n=== STAGE 2: Computation only (no preprocessing) ===") + print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Starting Stage 2") + stage2_start = time.time() + + raw_results__columns = main( + policy_date_str="2025-01-01", + main_target=MainTarget.raw_results.columns, + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + processed_data=tmp["processed_data"], + labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), + tt_function=tmp["tt_function"], # Reuse pre-compiled JAX function + include_fail_nodes=True, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + sync_jax_if_needed(backend) + + stage2_end = time.time() + print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Completed Stage 2") + stage2_time = stage2_end - stage2_start + + # Generate hash for Stage 2 output - avoid memory issues with large JAX arrays + stage2_hash = hashlib.md5(str(raw_results__columns).encode('utf-8')).hexdigest() + + print(f"Stage 2 completed in: {stage2_time:.4f} seconds") + print(f"Raw results keys: {len(raw_results__columns)}") + print(f"Stage 2 hash: {stage2_hash[:16]}...") + + # Third stage - convert raw results to DataFrame (no computation, just formatting) + print("\n=== STAGE 3: Convert raw results to DataFrame ===") + print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Starting Stage 3") + stage3_start = time.time() + + final_results = main( + policy_date_str="2025-01-01", + main_target=MainTarget.results.df_with_mapper, + tt_targets=TTTargets( + tree=TT_TARGETS, + ), + raw_results=RawResults.columns(raw_results__columns), + input_data=InputData.flat(tmp["input_data"]["flat"]), # Provide the flat input data from stage 1 + processed_data=tmp["processed_data"], + labels=Labels(root_nodes=tmp["labels"]["root_nodes"]), + specialized_environment=SpecializedEnvironment( + tt_dag=tmp["specialized_environment"]["tt_dag"] + ), + include_fail_nodes=True, + include_warn_nodes=False, + backend=backend, + ) + + # Force JAX synchronization before recording end time + sync_jax_if_needed(backend) + + stage3_end = time.time() + print(f"Wall clock time: {time.strftime('%H:%M:%S')} - Completed Stage 3") + stage3_time = stage3_end 
- stage3_start + total_time = stage1_time + stage2_time + stage3_time + + # Generate hash for Stage 3 output - avoid memory issues + stage3_hash = hashlib.md5(str(final_results).encode('utf-8')).hexdigest() + + print(f"Stage 3 completed in: {stage3_time:.4f} seconds") + print(f"Final DataFrame shape: {final_results.shape if hasattr(final_results, 'shape') else 'N/A'}") + print(f"Final DataFrame type: {type(final_results)}") + print(f"Stage 3 hash: {stage3_hash[:16]}...") + print(f"Total execution time: {total_time:.4f} seconds") + print(f"Stage 1 (preprocessing): {stage1_time:.4f}s ({stage1_time/total_time*100:.1f}%)") + print(f"Stage 2 (computation): {stage2_time:.4f}s ({stage2_time/total_time*100:.1f}%)") + print(f"Stage 3 (formatting): {stage3_time:.4f}s ({stage3_time/total_time*100:.1f}%)") + print(f"Backend: {backend}") + print(f"Households: {N:,}") + print(f"People: {len(data):,}") + print(f"Performance: {N / total_time:.0f} households/second") + print("\n=== STAGE HASHES ===") + print(f"Stage 1 hash: {stage1_hash[:16]}...") + print(f"Stage 2 hash: {stage2_hash[:16]}...") + print(f"Stage 3 hash: {stage3_hash[:16]}...") + + return final_results, total_time + + +def main_cli(): + """Main function for command line interface.""" + parser = argparse.ArgumentParser(description='Profile GETTSIM with synthetic data') + parser.add_argument('-N', '--households', type=int, default=32768, + help='Number of households to generate (default: 32768)') + parser.add_argument('-b', '--backend', choices=['numpy', 'jax'], default='numpy', + help='Backend to use: numpy or jax (default: numpy)') + + args = parser.parse_args() + + print("GETTSIM Profiling Tool") + print("=" * 50) + + result, exec_time = run_profile(args.households, args.backend) + + print("\n" + "=" * 50) + print("Profiling completed successfully!") + + return result, exec_time + + +if __name__ == "__main__": + main_cli() + +# %% +# For interactive use - you can also run this directly +# result, exec_time = run_profile(N=32768, backend="numpy") diff --git a/benchmark_code/make_data.py b/benchmark_code/make_data.py new file mode 100644 index 0000000..3918b7f --- /dev/null +++ b/benchmark_code/make_data.py @@ -0,0 +1,158 @@ +""" +Script to generate synthetic datasets for GETTSIM benchmarking/profiling. + +""" +# %% + +import pandas as pd +import numpy as np +import time + +def make_data(N): + """ + Create a DataFrame with N households, each containing 2 parents and 2 children. + Uses vectorized operations for fast data generation. 
+ + Parameters: + N (int): Number of households to create + + Returns: + pd.DataFrame: DataFrame with household data (4*N rows) + """ + # Total number of people (4 per household: 2 parents + 2 children) + total_people = N * 4 + + # Create base template for one household (4 people) + base_template = np.array([ + # Parent 1 + [30, 35, 0, 1995, 0, 0, False, False, 0, 0, 5000, 0, 500, 0, 0, 0, 0, -1, True, 0, 0, True, False, False, 1, -1, -1, False, -1, 1, 4, 360, 2062], + # Parent 2 + [30, 35, 0, 1995, 0, 1, False, False, 0, 0, 4000, 0, 0, 0, 0, 0, 0, -1, True, 0, 0, True, False, False, 0, -1, -1, False, -1, 0, 4, 360, 2062], + # Child 1 + [10, 0, 0, 2015, 0, 2, False, False, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, False, 0, 0, False, False, True, -1, 0, 1, False, 0, -1, -1, 120, 2082], + # Child 2 (twin) + [10, 0, 0, 2015, 0, 3, False, False, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, False, 0, 0, False, False, True, -1, 0, 1, False, 0, -1, -1, 120, 2082] + ]) + + # Replicate template for all households + data_array = np.tile(base_template, (N, 1)) + + # Create household and person IDs using vectorized operations + hh_ids = np.repeat(np.arange(N), 4) + p_ids = np.arange(total_people) + + # Update IDs in the data array + data_array[:, 4] = hh_ids # hh_id column + data_array[:, 5] = p_ids # p_id column + + # Update spouse_ids for parents (every 4th person starting from 0 gets spouse_id of next person, and vice versa) + spouse_mask_1 = np.arange(total_people) % 4 == 0 # Parent 1 positions + spouse_mask_2 = np.arange(total_people) % 4 == 1 # Parent 2 positions + data_array[spouse_mask_1, 24] = p_ids[spouse_mask_1] + 1 # Parent 1 -> Parent 2 + data_array[spouse_mask_2, 24] = p_ids[spouse_mask_2] - 1 # Parent 2 -> Parent 1 + + # Update bürgergeld__p_id_einstandspartner (identical to spouse_id) + data_array[spouse_mask_1, 29] = p_ids[spouse_mask_1] + 1 # Parent 1 -> Parent 2 + data_array[spouse_mask_2, 29] = p_ids[spouse_mask_2] - 1 # Parent 2 -> Parent 1 + + # Update parent_ids for children + child_mask_1 = np.arange(total_people) % 4 == 2 # Child 1 positions + child_mask_2 = np.arange(total_people) % 4 == 3 # Child 2 positions + parent1_ids = p_ids[child_mask_1] - 2 # Parent 1 IDs for children + parent2_ids = p_ids[child_mask_1] - 1 # Parent 2 IDs for children + + data_array[child_mask_1, 25] = parent1_ids # parent_id_1 for child 1 + data_array[child_mask_1, 26] = parent2_ids # parent_id_2 for child 1 + data_array[child_mask_2, 25] = parent1_ids # parent_id_1 for child 2 + data_array[child_mask_2, 26] = parent2_ids # parent_id_2 for child 2 + + # Update person_that_pays_childcare_expenses and id_recipient_child_allowance for children + data_array[child_mask_1, 17] = parent1_ids # person_that_pays_childcare_expenses for child 1 + data_array[child_mask_2, 17] = parent1_ids # person_that_pays_childcare_expenses for child 2 + data_array[child_mask_1, 28] = parent1_ids # id_recipient_child_allowance for child 1 + data_array[child_mask_2, 28] = parent1_ids # id_recipient_child_allowance for child 2 + + # Column names in the same order as the template + columns = [ + "age", "working_hours", "disability_grade", "birth_year", "hh_id", "p_id", + "east_germany", "self_employed", "income_from_self_employment", "income_from_rent", + "income_from_employment", "income_from_forest_and_agriculture", "income_from_capital", + "income_from_other_sources", "pension_income", "contribution_to_private_pension_insurance", + "childcare_expenses", "person_that_pays_childcare_expenses", "joint_taxation", + "amount_private_pension_income", 
"contribution_private_health_insurance", "has_children", + "single_parent", "is_child", "spouse_id", "parent_id_1", "parent_id_2", "in_training", + "id_recipient_child_allowance", "bürgergeld__p_id_einstandspartner", "lohnsteuer__steuerklasse", + "alter_monate", "jahr_renteneintritt", + ] + + # Create DataFrame + data = pd.DataFrame(data_array, columns=columns) + + # Convert boolean columns back to bool (they become float during array operations) + bool_columns = ["east_germany", "self_employed", "joint_taxation", "has_children", "single_parent", "is_child", "in_training"] + for col in bool_columns: + data[col] = data[col].astype(bool) + + # Convert integer columns to int + int_columns = ["age", "working_hours", "disability_grade", "birth_year", "hh_id", "p_id", + "spouse_id", "parent_id_1", "parent_id_2", "person_that_pays_childcare_expenses", + "id_recipient_child_allowance", "bürgergeld__p_id_einstandspartner", "lohnsteuer__steuerklasse", + "alter_monate", "jahr_renteneintritt"] + for col in int_columns: + data[col] = data[col].astype(int) + + print(f"Created DataFrame with {len(data)} rows ({len(data) // 4} households)") + return data + + +def main(): + """Generate datasets for all required sizes and measure timing.""" + # Dataset sizes (number of households) + # Each household has 4 people (2 parents + 2 children) + household_sizes = [2**15-1, 2**15, 2**16, 2**17, 2**18, 2**19, 2**20] + + print("Generating synthetic datasets for GETTSIM benchmarking...") + print("=" * 60) + + timing_results = [] + + for num_households in household_sizes: + print(f"\nGenerating dataset for {num_households:,} households...") + + # Time the data creation + start_time = time.time() + data = make_data(num_households) + end_time = time.time() + + creation_time = end_time - start_time + timing_results.append((num_households, creation_time)) + + # Calculate memory usage estimation + memory_mb = data.memory_usage(deep=True).sum() / (1024 * 1024) + + print(f"✓ Created {num_households:,} households ({len(data):,} people) in {creation_time:.3f} seconds") + print(f" Memory usage: {memory_mb:.2f} MB") + print(f" Speed: {num_households / creation_time:.0f} households/second") + + print("\n" + "=" * 60) + print("Performance Summary:") + print("=" * 60) + print(f"{'Households':<12} {'Time (s)':<10} {'Speed (hh/s)':<15} {'People':<10}") + print("-" * 60) + + for num_households, creation_time in timing_results: + speed = num_households / creation_time + people = num_households * 4 + print(f"{num_households:<12,} {creation_time:<10.3f} {speed:<15,.0f} {people:<10,}") + + print("\nDataset generation completed successfully!") + print("Note: Data is held in memory only - no files saved to disk.") + + +if __name__ == "__main__": + main() + +# %% +# For inspection: +# example_data = make_data(3) # 3 households = 12 people +