4 changes: 4 additions & 0 deletions CHANGES.md
@@ -6,6 +6,10 @@ releases are available on [Anaconda.org](https://anaconda.org/conda-forge/ttsim)

 ## v1.0.1 — unpublished

+- {gh}`41` Improve performance of `tt.shared.join` and
+  `ttsim.interface_dag_elements.fail_if.foreign_keys_are_invalid_in_data`
+  ({ghuser}`JuergenWiemers`)
+
 - {gh}`40` Improve performance of `aggregation_numpy` and `data_converters`
   ({ghuser}`JuergenWiemers`)

61 changes: 28 additions & 33 deletions src/ttsim/interface_dag_elements/fail_if.py
@@ -346,19 +346,17 @@ def environment_is_invalid(
     )

     flat_policy_environment = dt.flatten_to_tree_paths(policy_environment)
-    paths_with_incorrect_leaf_names = ""
-    for p, f in flat_policy_environment.items():
-        if hasattr(f, "leaf_name") and p[-1] != f.leaf_name:
-            paths_with_incorrect_leaf_names += f"    {p}\n"
+    paths_with_incorrect_leaf_names = [
+        f"    {p}"
+        for p, f in flat_policy_environment.items()
+        if hasattr(f, "leaf_name") and p[-1] != f.leaf_name
+    ]
     if paths_with_incorrect_leaf_names:
-        msg = (
-            format_errors_and_warnings(
-                "The last element of the object's path must be the same as the leaf name "
-                "of that object. The following tree paths are not compatible with the "
-                "corresponding object in the policy environment:\n\n"
-            )
-            + paths_with_incorrect_leaf_names
-        )
+        msg = format_errors_and_warnings(
+            "The last element of the object's path must be the same as the leaf name "
+            "of that object. The following tree paths are not compatible with the "
+            "corresponding object in the policy environment:\n\n"
+        ) + "\n".join(paths_with_incorrect_leaf_names)
         raise ValueError(msg)
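The refactor above collects offending paths in one pass and joins them once at the end, rather than growing a string with `+=` inside the loop. A minimal standalone sketch of that pattern (the `Leaf` class and example paths are invented stand-ins for ttsim's policy-environment objects):

```python
# Hypothetical stand-in for a ttsim policy-environment object.
class Leaf:
    def __init__(self, leaf_name: str):
        self.leaf_name = leaf_name

flat_env = {
    ("a", "b"): Leaf("b"),  # leaf name matches the last path element
    ("a", "c"): Leaf("x"),  # mismatch: should be reported
}

# One pass to collect offenders, one join to build the message.
bad_paths = [
    f"    {p}"
    for p, f in flat_env.items()
    if hasattr(f, "leaf_name") and p[-1] != f.leaf_name
]
if bad_paths:
    print("Mismatched tree paths:\n\n" + "\n".join(bad_paths))  # reports ('a', 'c')
```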


@@ -390,31 +388,28 @@ def foreign_keys_are_invalid_in_data(
         if fk_name in labels__root_nodes:
             path = dt.tree_path_from_qname(fk_name)
             # Referenced `p_id` must exist in the input data
-            if not all(i in valid_ids for i in input_data__flat[path].tolist()):
-                message = format_errors_and_warnings(
-                    f"""
-                    For {path}, the following are not a valid p_id in the input
-                    data: {[i for i in input_data__flat[path] if i not in valid_ids]}.
-                    """,
+            data_array = input_data__flat[path]
+            valid_ids_array = numpy.array(list(valid_ids))
+            valid_mask = numpy.isin(data_array, valid_ids_array)
+            if not numpy.all(valid_mask):
+                invalid_ids = data_array[~valid_mask].tolist()
+                message = (
+                    f"For {path}, the following are not a valid p_id in the input "
+                    f"data: {invalid_ids}."
                 )
                 raise ValueError(message)

             if fk.foreign_key_type == FKType.MUST_NOT_POINT_TO_SELF:
-                equal_to_pid_in_same_row = [
-                    i
-                    for i, j in zip(
-                        input_data__flat[path].tolist(),
-                        input_data__flat[("p_id",)].tolist(),
-                        strict=False,
-                    )
-                    if i == j
-                ]
-                if any(equal_to_pid_in_same_row):
-                    message = format_errors_and_warnings(
-                        f"""
-                        For {path}, the following are equal to the p_id in the same
-                        row: {equal_to_pid_in_same_row}.
-                        """,
+                # Optimized check using numpy operations instead of Python iteration
+                data_array = input_data__flat[path]
+                p_id_array = input_data__flat[("p_id",)]
+                # Use vectorized equality check
+                self_references = data_array == p_id_array
+                if numpy.any(self_references):
+                    equal_to_pid_in_same_row = data_array[self_references].tolist()
+                    message = (
+                        f"For {path}, the following are equal to the p_id in the same "
+                        f"row: {equal_to_pid_in_same_row}."
                     )
                     raise ValueError(message)
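Both replacements in this hunk trade Python-level iteration for vectorized numpy operations: `numpy.isin` for the membership test and an elementwise `==` for the self-reference test. A quick illustration on toy arrays (the column name `p_id_spouse` is invented; the real columns come from `input_data__flat`):

```python
import numpy

p_id = numpy.array([0, 1, 2, 3])
p_id_spouse = numpy.array([1, 0, 9, 3])  # 9 is dangling; row 3 points to itself

# Vectorized membership check instead of `i in valid_ids` once per row.
valid_mask = numpy.isin(p_id_spouse, p_id)
print(p_id_spouse[~valid_mask].tolist())  # [9]

# Vectorized self-reference check instead of a zip() loop.
self_references = p_id_spouse == p_id
print(p_id_spouse[self_references].tolist())  # [3]
```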

41 changes: 20 additions & 21 deletions src/ttsim/tt/shared.py
@@ -66,28 +66,27 @@ def join(
     -------
     The joined array.
     """
-    # For each foreign key and for each primary key, check if they match
-    matches_foreign_key = foreign_key[:, None] == primary_key
-
-    # For each foreign key, add a column with True at the end, to later fall back to
-    # the value for unresolved foreign keys
-    padded_matches_foreign_key = xnp.pad(
-        matches_foreign_key,
-        ((0, 0), (0, 1)),
-        "constant",
-        constant_values=True,
+    # First, get the sort order of primary_key to enable efficient lookup
+    sort_indices = xnp.argsort(primary_key)
+    sorted_primary_key = primary_key[sort_indices]
+    sorted_target = target[sort_indices]
+
+    # Find where each foreign_key would be inserted in the sorted primary_key array
+    positions = xnp.searchsorted(sorted_primary_key, foreign_key, side="left")
+
+    # Check if the foreign keys actually match the primary keys at those positions
+    # Handle out-of-bounds positions
+    valid_positions = positions < len(sorted_primary_key)
+    matches = valid_positions & (
+        sorted_primary_key[xnp.minimum(positions, len(sorted_primary_key) - 1)]
+        == foreign_key
     )

-    # For each foreign key, compute the index of the first matching primary key
-    indices = xnp.argmax(padded_matches_foreign_key, axis=1)
-
-    # Add the value for unresolved foreign keys at the end of the target array
-    padded_targets = xnp.pad(
-        target,
-        (0, 1),
-        "constant",
-        constant_values=value_if_foreign_key_is_missing,
+    # Create result array initialized with the missing value
+    result = xnp.full_like(
+        foreign_key, value_if_foreign_key_is_missing, dtype=target.dtype
     )

-    # Return the target at the index of the first matching primary key
-    return padded_targets.take(indices)
+    # Get the corresponding target values for valid matches, use 0 for invalid indices
+    valid_indices = xnp.where(matches, positions, 0)
+    return xnp.where(matches, sorted_target[valid_indices], result)
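Read end to end, the new `join` sorts the primary keys once and binary-searches each foreign key into them, replacing the old pairwise broadcast comparison (O(n·m) time and memory) with an O((n + m) log m) lookup. A sketch with plain numpy standing in for `xnp` (toy keys and values invented):

```python
import numpy as xnp  # stand-in; in ttsim, xnp is a numpy-compatible array namespace

primary_key = xnp.array([3, 1, 2])
target = xnp.array([30, 10, 20])
foreign_key = xnp.array([2, 5, 1])  # 5 matches no primary key

sort_indices = xnp.argsort(primary_key)
sorted_primary_key = primary_key[sort_indices]  # [1, 2, 3]
sorted_target = target[sort_indices]            # [10, 20, 30]

# Binary-search each foreign key into the sorted primary keys.
positions = xnp.searchsorted(sorted_primary_key, foreign_key, side="left")

# A position only counts as a match if it is in bounds and the key found
# there really equals the foreign key.
valid_positions = positions < len(sorted_primary_key)
matches = valid_positions & (
    sorted_primary_key[xnp.minimum(positions, len(sorted_primary_key) - 1)]
    == foreign_key
)

missing = -1  # plays the role of value_if_foreign_key_is_missing
result = xnp.full_like(foreign_key, missing, dtype=target.dtype)
valid_indices = xnp.where(matches, positions, 0)
print(xnp.where(matches, sorted_target[valid_indices], result))  # [20 -1 10]
```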