Optimizations: `tt.shared.join`, several `ttsim.interface_dag_elements.fail_if` functions #41
It is a problem in GETTSIM I think. This part seems to be the culprit: https://github.com/ttsim-dev/gettsim/blob/701a2a68544317508916881c00ee7a8aa4ea342c/src/gettsim/germany/ids.py#L121-L124
Maybe there is a workaround and we don't need |
Thanks!!! I am unlikely to find much time today, but since vectorisation does not interfere, we could just pass

    if backend == "jax":
        isin = partial(xnp.isin, method="sort")
    else:
        isin = xnp.isin

But then, if you feel like it is not too hard to implement that algo, maybe it could become a contribution to JAX? |
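The suggestion above can be sketched as a small factory. This is a minimal sketch, not the code in the PR: `make_isin` is a hypothetical helper name, and the JAX branch simply repeats the `method="sort"` argument from the comment. JAX is imported lazily so the NumPy path runs without it.

```python
from functools import partial

import numpy as np


def make_isin(backend: str):
    """Return an `isin` callable for the given backend name (hypothetical helper)."""
    if backend == "jax":
        # Lazy import: only needed when the JAX backend is requested.
        import jax.numpy as jnp

        # `method="sort"` is taken from the suggestion in the comment above.
        return partial(jnp.isin, method="sort")
    # NumPy picks its own algorithm by default.
    return np.isin


# Usage on the NumPy path.
isin = make_isin("numpy")
mask = isin(np.array([3, 1, 4, 1, 5]), np.array([1, 5]))
```

The branch is on a plain Python string, so it is resolved in Python before any array code runs.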
Coming back to @mj023's point in an earlier PR, maybe we should sort by Just need to go back to the original order when going from |
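The "sort, then go back to the original order" idea can be illustrated with an inverse permutation. This is a sketch under assumptions: the comment above does not say which column to sort by, so `keys`, `apply_in_sorted_order`, and `group_total` are hypothetical names, and the input is assumed to be non-empty.

```python
import numpy as np


def group_total(sorted_keys, sorted_values):
    """Total per group, broadcast to every row. Requires keys sorted ascending."""
    # Index where each run of equal keys starts.
    starts = np.flatnonzero(np.r_[True, sorted_keys[1:] != sorted_keys[:-1]])
    totals = np.add.reduceat(sorted_values, starts)
    counts = np.diff(np.r_[starts, sorted_keys.size])
    return np.repeat(totals, counts)


def apply_in_sorted_order(keys, values, func):
    """Sort rows by `keys`, apply `func`, then restore the original row order."""
    order = np.argsort(keys, kind="stable")
    # inverse[i] is the position of original row i in the sorted arrays.
    inverse = np.empty_like(order)
    inverse[order] = np.arange(order.size)
    result_sorted = func(keys[order], values[order])
    return result_sorted[inverse]


keys = np.array([2, 0, 2, 1, 0])
values = np.array([10, 1, 20, 100, 2])
totals = apply_in_sorted_order(keys, values, group_total)
```

Sorting once up front lets every later step rely on sorted input, and the inverse permutation is only needed at the very end.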
I just implemented Hans-Martin's suggestion, works perfectly. @mj023 I could simply add this to my "twin PR" ttsim-dev/gettsim#1076. Or would you prefer an alternative solution?
Sounds extremely reasonable to me! But that should probably be done in a separate PR, right? |
This would seem closely related enough to me, but could go either way. Whatever is easier for you guys, I won't be able to do more than look at code until next week. |
Two more thoughts:
|
Updated timings with fix for JAX's memory explosion:
========================================================================================================================
3-STAGE TIMING BREAKDOWN
========================================================================================================================
=====================================================================================================
PERFORMANCE COMPARISON NUMPY <-> JAX
========================================================================================================
Households Stage NUMPY hash JAX hash NUMPY (s) JAX (s) Speedup
--------------------------------------------------------------------------------------------------------
32,767 pre-processing - - 0.9245 4.8034 1/5.20x
computation c02820a2 c4185b32 0.3681 2.6295 1/7.14x
post-processing fc9530ad 52f96b14 0.6964 0.9190 1/1.32x
total time 1.9890 8.3519 1/4.20x
--------------------------------------------------------------------------------------------------------
32,768 pre-processing - - 0.7833 1.6503 1/2.11x
computation 84d1cbc8 c4185b32 0.3677 2.5609 1/6.96x
post-processing 990dc20f 625ff853 0.7821 0.9214 1/1.18x
total time 1.9331 5.1326 1/2.66x
--------------------------------------------------------------------------------------------------------
65,536 pre-processing - - 0.8489 1.7309 1/2.04x
computation a2fd2213 c4185b32 0.6934 2.7894 1/4.02x
post-processing c1d64a01 80e19c95 0.7032 0.9185 1/1.31x
total time 2.2455 5.4388 1/2.42x
--------------------------------------------------------------------------------------------------------
131,072 pre-processing - - 0.9830 1.7680 1/1.80x
computation 8e5c852f c4185b32 1.4201 3.4471 1/2.43x
post-processing dc3f1ea5 c9f77097 0.8230 0.9370 1/1.14x
total time 3.2261 6.1521 1/1.91x
--------------------------------------------------------------------------------------------------------
262,144 pre-processing - - 1.2871 2.2049 1/1.71x
computation a774f14b c4185b32 2.9599 4.7176 1/1.59x
post-processing 4a6e8c58 c8a25e8c 0.7397 1.2080 1/1.63x
total time 4.9867 8.1305 1/1.63x
--------------------------------------------------------------------------------------------------------
524,288 pre-processing - - 1.9231 3.3895 1/1.76x
computation de8b33a9 c4185b32 7.2729 7.3689 1/1.01x
post-processing 197c7065 5a9bbda7 0.8154 0.9543 1/1.17x
total time 10.0114 11.7126 1/1.17x
--------------------------------------------------------------------------------------------------------
1,048,576 pre-processing - - 3.5525 5.9978 1/1.69x
computation 9c6cc63f c4185b32 15.6263 12.7934 1.22x
post-processing 7b025b7b f2444df8 1.0549 0.9939 1.06x
total time 20.2337 19.7851 1.02x
--------------------------------------------------------------------------------------------------------
========================================================================================================================
MEMORY USAGE COMPARISON
========================================================================================================================
Households NumPy Init NumPy Final JAX Init JAX Final NumPy Δ JAX Δ
------------------------------------------------------------------------------------------------------------------------
32,767 177.9 241.0 161.1 577.2 63.1 416.1
32,768 207.2 244.0 499.2 600.1 36.8 100.9
65,536 254.1 336.2 583.8 707.9 82.1 124.2
131,072 353.9 508.6 746.8 906.0 154.7 159.2
262,144 541.1 860.0 994.7 1241.5 319.0 246.8
524,288 926.9 1545.9 1473.6 1872.9 619.0 399.3
1,048,576 1679.8 2357.9 2410.9 3130.3 678.1 719.4
------------------------------------------------------------------------------------------------------------------------
Legend:
Stage 1: Data preprocessing & DAG creation
Stage 2: Core computation (tax/transfer calculations)
Stage 3: DataFrame formatting (JAX → pandas conversion)
Init/Final: Memory usage before/after execution
Δ: Memory increase during execution
✓/✗: Hash verification (results match/differ) |
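A stage-wise timing with a result hash, as in the legend above, might look like the following. This is only a sketch: the benchmark scripts are not shown here, so the hash function (SHA-256 truncated to 8 hex digits), the helper names `short_hash` and `timed`, and the toy stages are all assumptions.

```python
import hashlib
import time

import numpy as np


def short_hash(array) -> str:
    """First 8 hex digits of SHA-256 over the raw bytes (hash choice is an assumption)."""
    data = np.ascontiguousarray(array).tobytes()
    return hashlib.sha256(data).hexdigest()[:8]


def timed(stage_fn, *args):
    """Run one stage and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = stage_fn(*args)
    return result, time.perf_counter() - start


# Toy stand-ins for the three stages in the legend.
raw = np.arange(6, dtype=np.float64)
prepared, t_pre = timed(lambda x: x.reshape(2, 3), raw)      # Stage 1
computed, t_comp = timed(lambda x: x.sum(axis=1), prepared)  # Stage 2
listed, t_post = timed(lambda x: x.tolist(), computed)       # Stage 3

result_hash = short_hash(computed)
```

A byte-level hash like this also changes with the dtype, so it only indicates equal results when both backends produce the same dtypes.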
This is a good solution, just add it to the "twin" PR. I initially thought it would not work, but obviously the backend isn't traced.
Yes, I think this would be good to do, as of now we are really spending a lot of time sorting. I can start the PR for that, but I can only start with it next week.
I think it should be fairly easy, as the algo is already in numpy. Maybe I can get them to change the keyword too. I will also try that next week. |
========================================================================================================================
3-STAGE TIMING BREAKDOWN
========================================================================================================================
=====================================================================================================
PERFORMANCE COMPARISON NUMPY <-> JAX
========================================================================================================
Households Stage NUMPY hash JAX hash NUMPY (s) JAX (s) Speedup
--------------------------------------------------------------------------------------------------------
32,767 pre-processing - - 0.9479 4.9022 1/5.17x
computation 11efd36d f944c56e 0.5979 4.2553 1/7.12x
post-processing 45b763ae 877b7593 0.7524 1.1042 1/1.47x
total time 2.2982 10.2617 1/4.47x
--------------------------------------------------------------------------------------------------------
32,768 pre-processing - - 0.8378 1.6330 1/1.95x
computation 61f33b4d f944c56e 0.5954 4.5720 1/7.68x
post-processing 8c8dd2e0 5fdfe6ae 0.8462 0.9738 1/1.15x
total time 2.2794 7.1788 1/3.15x
--------------------------------------------------------------------------------------------------------
65,536 pre-processing - - 0.9893 1.8400 1/1.86x
computation 0a208989 f944c56e 1.1254 4.5157 1/4.01x
post-processing ad8e7a36 2f70a0fb 0.7994 0.9715 1/1.22x
total time 2.9142 7.3273 1/2.51x
--------------------------------------------------------------------------------------------------------
131,072 pre-processing - - 1.0842 2.0213 1/1.86x
computation d13051a5 f944c56e 2.2814 5.2040 1/2.28x
post-processing 25c30f73 fb1aa29a 0.8906 1.0458 1/1.17x
total time 4.2562 8.2710 1/1.94x
--------------------------------------------------------------------------------------------------------
262,144 pre-processing - - 1.3893 2.7056 1/1.95x
computation 78690020 f944c56e 4.6557 7.2447 1/1.56x
post-processing b502fd8e 159d25cb 0.8265 1.0649 1/1.29x
total time 6.8715 11.0152 1/1.60x
--------------------------------------------------------------------------------------------------------
524,288 pre-processing - - 2.1520 3.5982 1/1.67x
computation 4df6d316 f944c56e 11.5373 10.7043 1.08x
post-processing f8bde208 b2729cce 1.0169 1.0686 1/1.05x
total time 14.7062 15.3710 1/1.05x
--------------------------------------------------------------------------------------------------------
1,048,576 pre-processing - - 3.6790 6.1551 1/1.67x
computation 5058ecdb f944c56e 31.3990 17.8046 1.76x
post-processing b62c58a2 17ff396f 1.6063 1.2503 1.28x
total time 36.6844 25.2100 1.46x
--------------------------------------------------------------------------------------------------------
========================================================================================================================
MEMORY USAGE COMPARISON
========================================================================================================================
Households NumPy Init NumPy Final JAX Init JAX Final NumPy Δ JAX Δ
------------------------------------------------------------------------------------------------------------------------
32,767 184.2 301.6 164.0 671.7 117.4 507.8
32,768 227.4 304.1 525.2 705.4 76.7 180.2
65,536 280.2 456.0 602.8 841.2 175.8 238.4
131,072 403.4 747.4 824.7 1085.7 344.0 261.0
262,144 639.4 1331.5 1157.0 1560.8 692.1 403.8
524,288 1119.3 1982.8 1773.7 2491.7 863.5 718.0
1,048,576 2046.3 3256.6 2983.6 4307.1 1210.3 1323.5
------------------------------------------------------------------------------------------------------------------------ |
Sure, though the pre-commit bot may complain... But nobody cares over there. Once we move to routine performance checks as you detail above (maybe make an issue out of it? We should definitely do that at some point), we can still change that. Thanks a lot! |
Thanks a lot, just some minor stylistic issues! And don't forget the changelog!
Excellent, thanks!!!
"Twin PR" to ttsim-dev/ttsim#41. Uses the optimized `tt.shared.join` in that PR.
I noticed that in my previous benchmarks (Optimize `aggregation_numpy.sum_by_p_id` ... #40) I had chosen Bürgergeld, Wohngeld, and Kinderzuschlag as targets, but I had overwritten them in the mapper. Removing the overrides revealed new bottlenecks in the benchmarks.

The crucial change is in `tt.shared.join`. It looks like we currently use an NxN lookup table; we really shouldn't do that. 😅 With this PR, runtime complexity is $O(\log(n))$ instead of $O(n)$ and memory usage is $O(n)$ instead of $O(n^2)$.

The optimizations in the `fail_if` functions mostly help the "pre-processing stage".

I can't even show a before <-> after comparison table because with the current `tt.shared.join` even as little as `N=32768` households is too much for my laptop.

For NumPy, timings look fine with this PR. Note that the timings already include the effects of another optimization PR in the GETTSIM repository, which I will open in a minute (Optimize `bürgergeld__in_anderer_bg_als_kindergeldempfänger` gettsim#1076). Without this "twin PR", timings look much worse.

However, there is another problem. I can run JAX only up to 32,768 households; after that, JAX's memory usage seems to increase quadratically in `N`; with `N=1,048,576` JAX wants to allocate 35TB (!). (@mj023, do you have an idea what is going on? You can find the updated benchmark scripts below.)

Output for JAX

Benchmark scripts: Here are the benchmark_scripts.zip

Timings with this PR
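To illustrate the difference between an NxN lookup table and a sort-based join as described for `tt.shared.join`, here is a minimal NumPy sketch. It is not the code from this PR: the function names and signatures are hypothetical, and it assumes unique, non-empty primary keys.

```python
import numpy as np


def join_pairwise(foreign_key, primary_key, target, value_if_missing):
    """O(N^2) memory: compares every foreign key with every primary key."""
    matches = foreign_key[:, None] == primary_key[None, :]  # N x N boolean table
    found = matches.any(axis=1)
    idx = matches.argmax(axis=1)  # column of the first match (0 if none)
    return np.where(found, target[idx], value_if_missing)


def join_sorted(foreign_key, primary_key, target, value_if_missing):
    """O(N) memory: sort primary keys once, then binary-search each foreign key."""
    order = np.argsort(primary_key, kind="stable")
    sorted_pk = primary_key[order]
    pos = np.searchsorted(sorted_pk, foreign_key)
    pos = np.minimum(pos, sorted_pk.size - 1)  # clamp so indexing stays in bounds
    found = sorted_pk[pos] == foreign_key
    return np.where(found, target[order[pos]], value_if_missing)


primary_key = np.array([10, 30, 20])
target = np.array([1.0, 3.0, 2.0])
foreign_key = np.array([20, 10, 99, 30, -1])  # 99 and -1 have no match

slow = join_pairwise(foreign_key, primary_key, target, 0.0)
fast = join_sorted(foreign_key, primary_key, target, 0.0)
```

In the sorted variant each lookup is a binary search, and no intermediate array grows faster than linearly in the number of rows.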