Skip to content

Commit 8d88fa7

Browse files
Optimizations: tt.shared.join, several ttsim.interface_dag_elements.fail_if functions (#41)
- The crucial change is in `tt.shared.join`. It looks like we currently use an NxN lookup table; we really shouldn't do that. 😅With this PR, runtime complexity is $O(log(n))$ instead of $O(n)$ and memory usage is $O(n)$ instead of $O(n^2)$. - The optimizations in the `fail_if` functions mostly help the "pre-processing stage".
1 parent ad625be commit 8d88fa7

File tree

3 files changed

+52
-54
lines changed

3 files changed

+52
-54
lines changed

CHANGES.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ releases are available on [Anaconda.org](https://anaconda.org/conda-forge/ttsim)
66

77
## v1.0.1 — unpublished
88

9+
- {gh}`41` Improve performance of `tt.shared.join` and
10+
`ttsim.interface_dag_elements.fail_if.foreign_keys_are_invalid_in_data`
11+
({ghuser}`JuergenWiemers`)
12+
913
- {gh}`40` Improve performance of `aggregation_numpy` and `data_converters`
1014
({ghuser}`JuergenWiemers`)
1115

src/ttsim/interface_dag_elements/fail_if.py

Lines changed: 28 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -346,19 +346,17 @@ def environment_is_invalid(
346346
)
347347

348348
flat_policy_environment = dt.flatten_to_tree_paths(policy_environment)
349-
paths_with_incorrect_leaf_names = ""
350-
for p, f in flat_policy_environment.items():
351-
if hasattr(f, "leaf_name") and p[-1] != f.leaf_name:
352-
paths_with_incorrect_leaf_names += f" {p}\n"
349+
paths_with_incorrect_leaf_names = [
350+
f" {p}"
351+
for p, f in flat_policy_environment.items()
352+
if hasattr(f, "leaf_name") and p[-1] != f.leaf_name
353+
]
353354
if paths_with_incorrect_leaf_names:
354-
msg = (
355-
format_errors_and_warnings(
356-
"The last element of the object's path must be the same as the leaf name "
357-
"of that object. The following tree paths are not compatible with the "
358-
"corresponding object in the policy environment:\n\n"
359-
)
360-
+ paths_with_incorrect_leaf_names
361-
)
355+
msg = format_errors_and_warnings(
356+
"The last element of the object's path must be the same as the leaf name "
357+
"of that object. The following tree paths are not compatible with the "
358+
"corresponding object in the policy environment:\n\n"
359+
) + "\n".join(paths_with_incorrect_leaf_names)
362360
raise ValueError(msg)
363361

364362

@@ -390,31 +388,28 @@ def foreign_keys_are_invalid_in_data(
390388
if fk_name in labels__root_nodes:
391389
path = dt.tree_path_from_qname(fk_name)
392390
# Referenced `p_id` must exist in the input data
393-
if not all(i in valid_ids for i in input_data__flat[path].tolist()):
394-
message = format_errors_and_warnings(
395-
f"""
396-
For {path}, the following are not a valid p_id in the input
397-
data: {[i for i in input_data__flat[path] if i not in valid_ids]}.
398-
""",
391+
data_array = input_data__flat[path]
392+
valid_ids_array = numpy.array(list(valid_ids))
393+
valid_mask = numpy.isin(data_array, valid_ids_array)
394+
if not numpy.all(valid_mask):
395+
invalid_ids = data_array[~valid_mask].tolist()
396+
message = (
397+
f"For {path}, the following are not a valid p_id in the input "
398+
f"data: {invalid_ids}."
399399
)
400400
raise ValueError(message)
401401

402402
if fk.foreign_key_type == FKType.MUST_NOT_POINT_TO_SELF:
403-
equal_to_pid_in_same_row = [
404-
i
405-
for i, j in zip(
406-
input_data__flat[path].tolist(),
407-
input_data__flat[("p_id",)].tolist(),
408-
strict=False,
409-
)
410-
if i == j
411-
]
412-
if any(equal_to_pid_in_same_row):
413-
message = format_errors_and_warnings(
414-
f"""
415-
For {path}, the following are equal to the p_id in the same
416-
row: {equal_to_pid_in_same_row}.
417-
""",
403+
# Optimized check using numpy operations instead of Python iteration
404+
data_array = input_data__flat[path]
405+
p_id_array = input_data__flat[("p_id",)]
406+
# Use vectorized equality check
407+
self_references = data_array == p_id_array
408+
if numpy.any(self_references):
409+
equal_to_pid_in_same_row = data_array[self_references].tolist()
410+
message = (
411+
f"For {path}, the following are equal to the p_id in the same "
412+
f"row: {equal_to_pid_in_same_row}."
418413
)
419414
raise ValueError(message)
420415

src/ttsim/tt/shared.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -66,28 +66,27 @@ def join(
6666
-------
6767
The joined array.
6868
"""
69-
# For each foreign key and for each primary key, check if they match
70-
matches_foreign_key = foreign_key[:, None] == primary_key
71-
72-
# For each foreign key, add a column with True at the end, to later fall back to
73-
# the value for unresolved foreign keys
74-
padded_matches_foreign_key = xnp.pad(
75-
matches_foreign_key,
76-
((0, 0), (0, 1)),
77-
"constant",
78-
constant_values=True,
69+
# First, get the sort order of primary_key to enable efficient lookup
70+
sort_indices = xnp.argsort(primary_key)
71+
sorted_primary_key = primary_key[sort_indices]
72+
sorted_target = target[sort_indices]
73+
74+
# Find where each foreign_key would be inserted in the sorted primary_key array
75+
positions = xnp.searchsorted(sorted_primary_key, foreign_key, side="left")
76+
77+
# Check if the foreign keys actually match the primary keys at those positions
78+
# Handle out-of-bounds positions
79+
valid_positions = positions < len(sorted_primary_key)
80+
matches = valid_positions & (
81+
sorted_primary_key[xnp.minimum(positions, len(sorted_primary_key) - 1)]
82+
== foreign_key
7983
)
8084

81-
# For each foreign key, compute the index of the first matching primary key
82-
indices = xnp.argmax(padded_matches_foreign_key, axis=1)
83-
84-
# Add the value for unresolved foreign keys at the end of the target array
85-
padded_targets = xnp.pad(
86-
target,
87-
(0, 1),
88-
"constant",
89-
constant_values=value_if_foreign_key_is_missing,
85+
# Create result array initialized with the missing value
86+
result = xnp.full_like(
87+
foreign_key, value_if_foreign_key_is_missing, dtype=target.dtype
9088
)
9189

92-
# Return the target at the index of the first matching primary key
93-
return padded_targets.take(indices)
90+
# Get the corresponding target values for valid matches, use 0 for invalid indices
91+
valid_indices = xnp.where(matches, positions, 0)
92+
return xnp.where(matches, sorted_target[valid_indices], result)

0 commit comments

Comments
 (0)