Optimizations: tt.shared.join, several ttsim.interface_dag_elements.fail_if functions #41

Merged
hmgaudecker merged 4 commits into main from JW/dev/optimizations on Aug 14, 2025

Conversation

JuergenWiemers
Collaborator

@JuergenWiemers JuergenWiemers commented Aug 12, 2025

  • I noticed that in my previous benchmarks (Optimize aggregation_numpy.sum_by_p_id ... #40) I had chosen Bürgergeld, Wohngeld, and Kinderzuschlag as targets, but I had overwritten them in the mapper. Removing the overrides revealed new bottlenecks in the benchmarks.

  • The crucial change is in tt.shared.join. It looks like we currently use an NxN lookup table; we really shouldn't do that. 😅 With this PR, runtime complexity is $O(\log n)$ instead of $O(n)$ and memory usage is $O(n)$ instead of $O(n^2)$.

  • The optimizations in the fail_if functions mostly help the "pre-processing stage".

  • I can't even show a before <-> after comparison table because, with the current tt.shared.join, even as few as N=32,768 households is too much for my laptop.

  • For NumPy, timings look fine with this PR. Note that the timings already include the effects of another optimization PR in the GETTSIM repository, which I will open in a minute (Optimize bürgergeld__in_anderer_bg_als_kindergeldempfänger gettsim#1076). Without this "twin PR", timings look much worse.

  • However, there is another problem. I can run JAX only up to 32,768 households; after that, JAX's memory usage seems to increase quadratically in N; with N=1,048,576 JAX wants to allocate 35TB (!). (@mj023, do you have an idea what is going on? You can find the updated benchmark scripts below.)

    Output for JAX

    ============================================================
    Testing jax backend
    ============================================================
    Preparing environment for jax backend...
      Resetting session state for jax backend...
      Garbage collection completed
      JAX compilation cache cleared
    Running benchmark: 32,767 households, jax backend
      Generating data...
    Created DataFrame with 131068 rows (32767 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Stage 3: Convert raw results to DataFrame...
        JAX operations synchronized
      ✓ Stage 1 (pre-processing): 4.7734s (10.7%)
      ✓ Stage 2 (computation): 37.5159s (84.3%)
      ✓ Stage 3 (post-processing): 2.2136s (5.0%)
      ✓ Total time: 44.5029 seconds
      Result shape: (131068, 8)
      Memory usage: 169.2 MB → 121.2 MB (Δ-48.0 MB)
      Stage 1 hash: e1ef8c7680e3221b...
      Stage 2 hash: c4185b3248934f48...
      Stage 3 hash: 52f96b14596cceef...
    
    Running benchmark: 32,768 households, jax backend
      Generating data...
    Created DataFrame with 131072 rows (32768 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Stage 3: Convert raw results to DataFrame...
        JAX operations synchronized
      ✓ Stage 1 (pre-processing): 2.2594s (6.8%)
      ✓ Stage 2 (computation): 28.9499s (87.5%)
      ✓ Stage 3 (post-processing): 1.8655s (5.6%)
      ✓ Total time: 33.0747 seconds
      Result shape: (131072, 8)
      Memory usage: 162.1 MB → 105.6 MB (Δ-56.5 MB)
      Stage 1 hash: 3f044e76bf532912...
      Stage 2 hash: c4185b3248934f48...
      Stage 3 hash: 625ff853d37a8962...
    
    Running benchmark: 65,536 households, jax backend
      Generating data...
    Created DataFrame with 262144 rows (65536 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Failed: INTERNAL: Buffer Definition Event: Error preparing computation: Out of memory allocating 139639391848 bytes.
    
    Running benchmark: 131,072 households, jax backend
      Generating data...
    Created DataFrame with 524288 rows (131072 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Failed: INTERNAL: Buffer Definition Event: Error preparing computation: Out of memory allocating 558451656296 bytes.
    
    Running benchmark: 262,144 households, jax backend
      Generating data...
    Created DataFrame with 1048576 rows (262144 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Failed: INTERNAL: Buffer Definition Event: Error preparing computation: Out of memory allocating 2233594807912 bytes.
    
    Running benchmark: 524,288 households, jax backend
      Generating data...
    Created DataFrame with 2097152 rows (524288 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Failed: INTERNAL: Buffer Definition Event: Error preparing computation: Out of memory allocating 8933947213416 bytes.
    
    Running benchmark: 1,048,576 households, jax backend
      Generating data...
    Created DataFrame with 4194304 rows (1048576 households)
      Stage 1: Data preprocessing and DAG creation...
        JAX operations synchronized
      Stage 2: Computation only...
        JAX operations synchronized
      Failed: INTERNAL: Buffer Definition Event: Error preparing computation: Out of memory allocating 35734958376552 bytes.

  • Benchmark scripts: Here are the benchmark_scripts.zip
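
To illustrate the tt.shared.join bullet above, here is a minimal NumPy sketch of the two lookup strategies. It is not the code from this PR: `join_dense`, `join_sorted`, and `fill_value` are hypothetical names, and the sketch assumes a non-empty primary key with unique entries. The dense variant builds an N×N boolean table, while the sorted variant sorts the primary key once and then resolves each foreign key with a binary search, so memory stays linear in N.

```python
import numpy as np


def join_dense(foreign_key, primary_key, target, fill_value):
    """Hypothetical N x N variant: compare every foreign key with every primary key."""
    # Boolean table of shape (len(foreign_key), len(primary_key)): O(n^2) memory.
    matches = foreign_key[:, None] == primary_key[None, :]
    found = matches.any(axis=1)
    # argmax gives the column of the first match (0 if none; masked out below).
    idx = matches.argmax(axis=1)
    return np.where(found, target[idx], fill_value)


def join_sorted(foreign_key, primary_key, target, fill_value):
    """Hypothetical sorted variant: one sort, then a binary search per foreign key."""
    order = np.argsort(primary_key)
    sorted_keys = primary_key[order]
    # Binary search: O(log n) per lookup, O(n) extra memory overall.
    pos = np.searchsorted(sorted_keys, foreign_key)
    # Keys larger than every primary key land one past the end; clip them.
    pos = np.clip(pos, 0, len(sorted_keys) - 1)
    found = sorted_keys[pos] == foreign_key
    return np.where(found, target[order[pos]], fill_value)


primary_key = np.array([3, 0, 2, 1])
target = np.array([30.0, 0.0, 20.0, 10.0])
foreign_key = np.array([1, -1, 3, 5])  # -1 and 5 have no matching primary key

dense = join_dense(foreign_key, primary_key, target, fill_value=-1.0)
fast = join_sorted(foreign_key, primary_key, target, fill_value=-1.0)
```

Both functions return the same values on this toy input; only their memory and runtime scaling differ.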

Timings with this PR

========================================================================================================================
3-STAGE TIMING BREAKDOWN
========================================================================================================================

=====================================================================================================
PERFORMANCE COMPARISON NUMPY <-> JAX
========================================================================================================
Households  Stage             NUMPY hash  JAX hash    NUMPY (s)   JAX (s)     Speedup
--------------------------------------------------------------------------------------------------------
32,767      pre-processing    -           -           0.9343      4.7734      1/5.11x
            computation       c02820a2    c4185b32    0.3726      37.5159     1/100.69x
            post-processing   fc9530ad    52f96b14    0.7784      2.2136      1/2.84x
            total time                                2.0852      44.5029     1/21.34x
--------------------------------------------------------------------------------------------------------
32,768      pre-processing    -           -           0.8053      2.2594      1/2.81x
            computation       84d1cbc8    c4185b32    0.4042      28.9499     1/71.63x
            post-processing   990dc20f    625ff853    0.7996      1.8655      1/2.33x
            total time                                2.0091      33.0747     1/16.46x
--------------------------------------------------------------------------------------------------------
65,536      pre-processing    -           -           0.8748      FAILED      N/A
            computation       a2fd2213    FAILED      0.7240      FAILED      N/A
            post-processing   c1d64a01    FAILED      0.7246      FAILED      N/A
            total time                                2.3234      FAILED      N/A
--------------------------------------------------------------------------------------------------------
131,072     pre-processing    -           -           1.0318      FAILED      N/A
            computation       8e5c852f    FAILED      1.4785      FAILED      N/A
            post-processing   dc3f1ea5    FAILED      0.8463      FAILED      N/A
            total time                                3.3566      FAILED      N/A
--------------------------------------------------------------------------------------------------------
262,144     pre-processing    -           -           1.3009      FAILED      N/A
            computation       a774f14b    FAILED      3.0404      FAILED      N/A
            post-processing   4a6e8c58    FAILED      0.7449      FAILED      N/A
            total time                                5.0862      FAILED      N/A
--------------------------------------------------------------------------------------------------------
524,288     pre-processing    -           -           2.0369      FAILED      N/A
            computation       de8b33a9    FAILED      6.3293      FAILED      N/A
            post-processing   197c7065    FAILED      0.8779      FAILED      N/A
            total time                                9.2441      FAILED      N/A
--------------------------------------------------------------------------------------------------------
1,048,576   pre-processing    -           -           3.3479      FAILED      N/A
            computation       9c6cc63f    FAILED      13.6209     FAILED      N/A
            post-processing   7b025b7b    FAILED      0.9363      FAILED      N/A
            total time                                17.9051     FAILED      N/A
--------------------------------------------------------------------------------------------------------

========================================================================================================================
MEMORY USAGE COMPARISON
========================================================================================================================
Households  NumPy Init  NumPy Final JAX Init    JAX Final   NumPy Δ     JAX Δ
------------------------------------------------------------------------------------------------------------------------
32,767      177.7       240.2       169.2       121.2       62.4        -48.0
32,768      206.6       243.8       162.1       105.6       37.2        -56.5
65,536      251.8       335.5       FAILED      FAILED      83.7        FAILED
131,072     354.4       508.5       FAILED      FAILED      154.1       FAILED
262,144     541.1       859.4       FAILED      FAILED      318.3       FAILED
524,288     926.3       1546.7      FAILED      FAILED      620.4       FAILED
1,048,576   1680.6      2814.8      FAILED      FAILED      1134.3      FAILED
------------------------------------------------------------------------------------------------------------------------

Legend:
  Stage 1: Data preprocessing & DAG creation
  Stage 2: Core computation (tax/transfer calculations)
  Stage 3: DataFrame formatting (JAX → pandas conversion)
  Init/Final: Memory usage before/after execution
  Δ: Memory increase during execution
  ✓/✗: Hash verification (results match/differ)
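
The quadratic growth of JAX's memory demand mentioned above can be checked directly from the failed allocation sizes in the JAX log. The short sketch below only re-uses the byte counts printed there; the variable names are illustrative. If memory grows with N², doubling the number of households should roughly quadruple the requested allocation.

```python
# Allocation sizes (bytes) copied from the failed JAX runs in the log above,
# keyed by number of households.
failed_allocations = {
    65_536: 139_639_391_848,
    131_072: 558_451_656_296,
    262_144: 2_233_594_807_912,
    524_288: 8_933_947_213_416,
    1_048_576: 35_734_958_376_552,
}

households = sorted(failed_allocations)
pairs = list(zip(households, households[1:]))

# Ratio of requested bytes each time the number of households doubles.
ratios = [failed_allocations[big] / failed_allocations[small] for small, big in pairs]

for (small, big), ratio in zip(pairs, ratios):
    print(f"{small:>9,} -> {big:>9,} households: x{ratio:.3f}")
```

Every ratio comes out very close to 4, which is what quadratic scaling in N predicts.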


codecov bot commented Aug 12, 2025

Codecov Report

❌ Patch coverage is 86.95652% with 3 lines in your changes missing coverage. Please review.

Files with missing lines | Patch % | Lines
src/ttsim/interface_dag_elements/fail_if.py | 78.57% | 2 Missing and 1 partial ⚠️


@mj023
Collaborator

mj023 commented Aug 12, 2025

It is a problem in GETTSIM I think. This part seems to be the culprit: https://github.com/ttsim-dev/gettsim/blob/701a2a68544317508916881c00ee7a8aa4ea342c/src/gettsim/germany/ids.py#L121-L124

np.isin() uses a better algorithm than jax.numpy.isin(): NumPy's is ~O(2n) in runtime and O(n) in memory (Link to the discussion). If I choose method='sort' for jnp.isin(), I can at least run it with JAX without the memory blowing up, and it's similarly fast. But for some reason this keyword is named kind in np.isin(), so it breaks our xnp strategy.

Maybe there is a workaround and we don't need isin() at all, but if there isn't, we would have to implement NumPy's algorithm or at least find a way to pass method='sort' without it breaking the NumPy version of the function.
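
To make the memory argument concrete, here is a rough NumPy sketch contrasting an all-pairs membership test with a sort-based one. Neither function is NumPy's or JAX's actual implementation; `isin_compare_all` and `isin_sorted` are hypothetical stand-ins that only show why one approach needs an n×m intermediate array and the other does not.

```python
import numpy as np


def isin_compare_all(elements, test_elements):
    """All-pairs comparison: allocates a (len(elements), len(test_elements)) array."""
    return (elements[:, None] == test_elements[None, :]).any(axis=1)


def isin_sorted(elements, test_elements):
    """Sort-based membership test: memory stays linear in the input sizes."""
    if test_elements.size == 0:
        return np.zeros(elements.shape, dtype=bool)
    sorted_test = np.sort(test_elements)
    # Binary search for each element in the sorted test set.
    pos = np.searchsorted(sorted_test, elements)
    # Elements larger than every test value land one past the end; clip them.
    pos = np.clip(pos, 0, sorted_test.size - 1)
    return sorted_test[pos] == elements


elements = np.array([4, 1, 7, 1, 9])
test_elements = np.array([9, 1, 3])

mask_all_pairs = isin_compare_all(elements, test_elements)
mask_sorted = isin_sorted(elements, test_elements)
```

On this toy input both masks agree with `np.isin(elements, test_elements)`.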

@hmgaudecker
Contributor

Thanks!!!

I am unlikely to find much time today, but since vectorisation does not interfere, we could just pass backend as an argument and do something like:

if backend == "jax":
    isin = partial(xnp.isin, method="sort")
else:
    isin = xnp.isin

But then, if you feel like it is not too hard to implement that algo, maybe it could become a contribution to Jax?
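
A self-contained version of the dispatch idea above might look like the following sketch. `make_isin` is a hypothetical helper, and the backend strings `"numpy"` and `"jax"` are assumptions; in the snippet above the array module is the project's `xnp`, whereas here the two modules are referenced directly so the example runs on its own. The JAX import is deferred so the NumPy path works without JAX installed.

```python
from functools import partial

import numpy as np


def make_isin(backend):
    """Return an isin callable for the given backend name (hypothetical helper)."""
    if backend == "jax":
        # Imported lazily so the NumPy path does not require JAX.
        import jax.numpy as jnp

        # `method="sort"` is the jnp.isin keyword discussed in this thread.
        return partial(jnp.isin, method="sort")
    if backend == "numpy":
        # np.isin names the corresponding keyword `kind`, so it is left at its default.
        return np.isin
    raise ValueError(f"Unknown backend: {backend!r}")


isin = make_isin("numpy")
mask = isin(np.array([1, 2, 3, 4]), np.array([2, 4]))
```

Because the backend name is an ordinary Python string, the branch is resolved once, before any array computation runs.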

@hmgaudecker
Contributor

Coming back to @mj023's point in an earlier PR, maybe we should sort by p_id when creating processed_data and assume sort order subsequently? While join looks more general, it really only works for p_id, I think.

Just need to go back to the original order when going from raw_results to results.
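
The sort-then-restore idea can be sketched with a permutation and its inverse. This is only an illustration with made-up data and a placeholder computation; `p_id` is the identifier from the discussion, while the variable names `income`, `raw_result`, and `result` are hypothetical and do not mirror the actual processed_data or raw_results structures.

```python
import numpy as np

p_id = np.array([3, 0, 2, 1])
income = np.array([30.0, 0.0, 20.0, 10.0])

# Sort once while preparing the data; downstream code may then assume order.
order = np.argsort(p_id, kind="stable")
sorted_p_id = p_id[order]
sorted_income = income[order]

# Placeholder for any computation that runs on the sorted data.
raw_result = sorted_income * 2.0

# Scatter the results back to the user's original row order.
result = np.empty_like(raw_result)
result[order] = raw_result
```

Here `result[i]` again belongs to the row that was at position `i` in the unsorted input.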

@JuergenWiemers
Collaborator Author

if backend == "jax":
    isin = partial(xnp.isin, method="sort")
else:
    isin = xnp.isin

But then, if you feel like it is not too hard to implement that algo, maybe it could become a contribution to Jax?

I just implemented Hans-Martin's suggestion, works perfectly. @mj023 I could simply add this to my "twin PR" ttsim-dev/gettsim#1076. Or would you prefer an alternative solution?

Coming back to @mj023's point in an earlier PR, maybe we should sort by p_id when creating processed_data and assume sort order subsequently? While join looks more general, it really only works for p_id, I think.

Sounds extremely reasonable to me! But that should probably be done in a separate PR, right?

@hmgaudecker
Contributor

Coming back to @mj023's point in an earlier PR, maybe we should sort by p_id when creating processed_data and assume sort order subsequently? While join looks more general, it really only works for p_id, I think.

Sounds extremely reasonable to me! But that should probably be done in a separate PR, right?

This would seem closely related enough to me, but could go either way. Whatever is easier for you guys, I won't be able to do more than looking at code until next week.

@JuergenWiemers
Collaborator Author

Two more thoughts:

  • I should add every existing GETTSIM target leaf to the tt_targets and override none of the intermediate nodes in the mapper in my benchmark script. This way the script should be able to catch badly scaling code in any GETTSIM policy function or in TTSIM.
  • In order to detect possible runtime/scaling regressions in GETTSIM/TTSIM that may be accidentally introduced in a future PR, it might even be useful to perform this type of benchmark (semi-)automatically. For example, the maintainers of the Julia language have set up a GitHub bot they call "nanosoldier". If called in a PR comment with the keyword @nanosoldier, it runs a large benchmark suite on the most popular packages in the Julia ecosystem, once for the PR and once for the main branch, and compares the results. Here is a recent example and corresponding report. I have no idea how complicated it is to set something like this up, and I guess GitHub doesn't do it for free. Or maybe it is possible to include this kind of benchmark in CI? With up to 1M households, it currently only takes about 4 minutes to run the benchmark twice (for main and PR).
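
A comparison step for such an automated benchmark could be as small as the sketch below. Everything in it is hypothetical: the function name `find_regressions`, the 20% tolerance, and the timings are made up for illustration and are not taken from the benchmark script or from any real run.

```python
def find_regressions(main_timings, pr_timings, tolerance=1.2):
    """Return {n_households: slowdown} where the PR exceeds main by more than `tolerance`.

    Both arguments map a number of households to a runtime in seconds.
    """
    regressions = {}
    for n_households, main_seconds in main_timings.items():
        pr_seconds = pr_timings.get(n_households)
        if pr_seconds is None:
            # Size was not benchmarked on the PR branch; nothing to compare.
            continue
        slowdown = pr_seconds / main_seconds
        if slowdown > tolerance:
            regressions[n_households] = slowdown
    return regressions


# Made-up timings (seconds) purely for illustration.
main_timings = {1_000: 1.0, 10_000: 2.0}
pr_timings = {1_000: 1.1, 10_000: 3.0}

regressions = find_regressions(main_timings, pr_timings)
```

A CI job could fail, or a bot could post a comment, whenever the returned dictionary is non-empty.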

@JuergenWiemers
Collaborator Author

Updated timings with fix for JAX's memory explosion:

========================================================================================================================
3-STAGE TIMING BREAKDOWN
========================================================================================================================

=====================================================================================================
PERFORMANCE COMPARISON NUMPY <-> JAX
========================================================================================================
Households  Stage             NUMPY hash  JAX hash    NUMPY (s)   JAX (s)     Speedup
--------------------------------------------------------------------------------------------------------
32,767      pre-processing    -           -           0.9245      4.8034      1/5.20x
            computation       c02820a2    c4185b32    0.3681      2.6295      1/7.14x
            post-processing   fc9530ad    52f96b14    0.6964      0.9190      1/1.32x
            total time                                1.9890      8.3519      1/4.20x
--------------------------------------------------------------------------------------------------------
32,768      pre-processing    -           -           0.7833      1.6503      1/2.11x
            computation       84d1cbc8    c4185b32    0.3677      2.5609      1/6.96x
            post-processing   990dc20f    625ff853    0.7821      0.9214      1/1.18x
            total time                                1.9331      5.1326      1/2.66x
--------------------------------------------------------------------------------------------------------
65,536      pre-processing    -           -           0.8489      1.7309      1/2.04x
            computation       a2fd2213    c4185b32    0.6934      2.7894      1/4.02x
            post-processing   c1d64a01    80e19c95    0.7032      0.9185      1/1.31x
            total time                                2.2455      5.4388      1/2.42x
--------------------------------------------------------------------------------------------------------
131,072     pre-processing    -           -           0.9830      1.7680      1/1.80x
            computation       8e5c852f    c4185b32    1.4201      3.4471      1/2.43x
            post-processing   dc3f1ea5    c9f77097    0.8230      0.9370      1/1.14x
            total time                                3.2261      6.1521      1/1.91x
--------------------------------------------------------------------------------------------------------
262,144     pre-processing    -           -           1.2871      2.2049      1/1.71x
            computation       a774f14b    c4185b32    2.9599      4.7176      1/1.59x
            post-processing   4a6e8c58    c8a25e8c    0.7397      1.2080      1/1.63x
            total time                                4.9867      8.1305      1/1.63x
--------------------------------------------------------------------------------------------------------
524,288     pre-processing    -           -           1.9231      3.3895      1/1.76x
            computation       de8b33a9    c4185b32    7.2729      7.3689      1/1.01x
            post-processing   197c7065    5a9bbda7    0.8154      0.9543      1/1.17x
            total time                                10.0114     11.7126     1/1.17x
--------------------------------------------------------------------------------------------------------
1,048,576   pre-processing    -           -           3.5525      5.9978      1/1.69x
            computation       9c6cc63f    c4185b32    15.6263     12.7934     1.22x
            post-processing   7b025b7b    f2444df8    1.0549      0.9939      1.06x
            total time                                20.2337     19.7851     1.02x
--------------------------------------------------------------------------------------------------------

========================================================================================================================
MEMORY USAGE COMPARISON
========================================================================================================================
Households  NumPy Init  NumPy Final JAX Init    JAX Final   NumPy Δ     JAX Δ
------------------------------------------------------------------------------------------------------------------------
32,767      177.9       241.0       161.1       577.2       63.1        416.1
32,768      207.2       244.0       499.2       600.1       36.8        100.9
65,536      254.1       336.2       583.8       707.9       82.1        124.2
131,072     353.9       508.6       746.8       906.0       154.7       159.2
262,144     541.1       860.0       994.7       1241.5      319.0       246.8
524,288     926.9       1545.9      1473.6      1872.9      619.0       399.3
1,048,576   1679.8      2357.9      2410.9      3130.3      678.1       719.4
------------------------------------------------------------------------------------------------------------------------

Legend:
  Stage 1: Data preprocessing & DAG creation
  Stage 2: Core computation (tax/transfer calculations)
  Stage 3: DataFrame formatting (JAX → pandas conversion)
  Init/Final: Memory usage before/after execution
  Δ: Memory increase during execution
  ✓/✗: Hash verification (results match/differ)

@mj023
Collaborator

mj023 commented Aug 13, 2025

This is a good solution, just add it to the "twin" PR. I initially thought it would not work, but obviously the backend isn't traced.

Coming back to @mj023's point in an earlier PR, maybe we should sort by p_id when creating processed_data and assume sort order subsequently? While join looks more general, it really only works for p_id, I think.

Sounds extremely reasonable to me! But that should probably be done in a separate PR, right?

Yes, I think this would be good to do, as of now we are really spending a lot of time sorting. I can start the PR for that, but I can only start with it next week.

But then, if you feel like it is not too hard to implement that algo, maybe it could become a contribution to Jax?

I think it should be fairly easy, as the algo is already in numpy. Maybe I can get them to change the keyword too. I will also try that next week.

@JuergenWiemers
Collaborator Author

  • I should add every existing GETTSIM target leaf to the tt_targets and override none of the intermediate nodes in the mapper in my benchmark script. This way the script should be able to catch badly scaling code in any GETTSIM policy function or in TTSIM.
  • Did that, and good news: Nothing blows up, profiles look good, timings remain very reasonable. It takes a bit longer than previously, but that's expected since there is much more to compute now.
  • I only tried policy_date_str="2025-01-01", so there still might be problematic code lurking in the dark corners of the policy functions that are no longer active in 2025. However, I also searched for potentially problematic "hot" Python loops in all of the GETTSIM/TTSIM codebase. Couldn't find any.
  • Updated benchmark code: benchmark.zip
  • (@hmgaudecker Regarding making a PR for the benchmark code in https://github.com/ttsim-dev/gettsim-code-for-picking: Is it okay if I push it with --no-verify? The code is 99% AI-generated and ruff almost gets a heart attack because of the ~100 "print"-calls...)

========================================================================================================================
3-STAGE TIMING BREAKDOWN
========================================================================================================================

=====================================================================================================
PERFORMANCE COMPARISON NUMPY <-> JAX
========================================================================================================
Households  Stage             NUMPY hash  JAX hash    NUMPY (s)   JAX (s)     Speedup
--------------------------------------------------------------------------------------------------------
32,767      pre-processing    -           -           0.9479      4.9022      1/5.17x
            computation       11efd36d    f944c56e    0.5979      4.2553      1/7.12x
            post-processing   45b763ae    877b7593    0.7524      1.1042      1/1.47x
            total time                                2.2982      10.2617     1/4.47x
--------------------------------------------------------------------------------------------------------
32,768      pre-processing    -           -           0.8378      1.6330      1/1.95x
            computation       61f33b4d    f944c56e    0.5954      4.5720      1/7.68x
            post-processing   8c8dd2e0    5fdfe6ae    0.8462      0.9738      1/1.15x
            total time                                2.2794      7.1788      1/3.15x
--------------------------------------------------------------------------------------------------------
65,536      pre-processing    -           -           0.9893      1.8400      1/1.86x
            computation       0a208989    f944c56e    1.1254      4.5157      1/4.01x
            post-processing   ad8e7a36    2f70a0fb    0.7994      0.9715      1/1.22x
            total time                                2.9142      7.3273      1/2.51x
--------------------------------------------------------------------------------------------------------
131,072     pre-processing    -           -           1.0842      2.0213      1/1.86x
            computation       d13051a5    f944c56e    2.2814      5.2040      1/2.28x
            post-processing   25c30f73    fb1aa29a    0.8906      1.0458      1/1.17x
            total time                                4.2562      8.2710      1/1.94x
--------------------------------------------------------------------------------------------------------
262,144     pre-processing    -           -           1.3893      2.7056      1/1.95x
            computation       78690020    f944c56e    4.6557      7.2447      1/1.56x
            post-processing   b502fd8e    159d25cb    0.8265      1.0649      1/1.29x
            total time                                6.8715      11.0152     1/1.60x
--------------------------------------------------------------------------------------------------------
524,288     pre-processing    -           -           2.1520      3.5982      1/1.67x
            computation       4df6d316    f944c56e    11.5373     10.7043     1.08x
            post-processing   f8bde208    b2729cce    1.0169      1.0686      1/1.05x
            total time                                14.7062     15.3710     1/1.05x
--------------------------------------------------------------------------------------------------------
1,048,576   pre-processing    -           -           3.6790      6.1551      1/1.67x
            computation       5058ecdb    f944c56e    31.3990     17.8046     1.76x
            post-processing   b62c58a2    17ff396f    1.6063      1.2503      1.28x
            total time                                36.6844     25.2100     1.46x
--------------------------------------------------------------------------------------------------------

========================================================================================================================
MEMORY USAGE COMPARISON
========================================================================================================================
Households  NumPy Init  NumPy Final JAX Init    JAX Final   NumPy Δ     JAX Δ
------------------------------------------------------------------------------------------------------------------------
32,767      184.2       301.6       164.0       671.7       117.4       507.8
32,768      227.4       304.1       525.2       705.4       76.7        180.2
65,536      280.2       456.0       602.8       841.2       175.8       238.4
131,072     403.4       747.4       824.7       1085.7      344.0       261.0
262,144     639.4       1331.5      1157.0      1560.8      692.1       403.8
524,288     1119.3      1982.8      1773.7      2491.7      863.5       718.0
1,048,576   2046.3      3256.6      2983.6      4307.1      1210.3      1323.5
------------------------------------------------------------------------------------------------------------------------

@hmgaudecker
Contributor

(@hmgaudecker Regarding making a PR for the benchmark code in https://github.com/ttsim-dev/gettsim-code-for-picking: Is it okay if I push it with --no-verify? The code is 99% AI-generated and ruff almost gets a heart attack because of the ~100 "print"-calls...)

Sure, though the pre-commit bot may complain... But nobody cares over there. Once we move to routine performance checks as you detail above (maybe make an issue out of it? We should definitely do that at some point), we can still change that.

Thanks a lot!

Contributor

@hmgaudecker hmgaudecker left a comment

Thanks a lot, just some minor stylistic issues! And don't forget the changelog!

Contributor

@hmgaudecker hmgaudecker left a comment

Excellent, thanks!!!

@hmgaudecker hmgaudecker merged commit 8d88fa7 into main Aug 14, 2025
11 checks passed
@hmgaudecker hmgaudecker deleted the JW/dev/optimizations branch August 14, 2025 08:12
hmgaudecker pushed a commit to ttsim-dev/gettsim that referenced this pull request Aug 21, 2025
"Twin PR" to ttsim-dev/ttsim#41. Uses the
optimized `tt.shared.join` in that PR.