@@ -346,19 +346,17 @@ def environment_is_invalid(
346
346
)
347
347
348
348
flat_policy_environment = dt .flatten_to_tree_paths (policy_environment )
349
- paths_with_incorrect_leaf_names = ""
350
- for p , f in flat_policy_environment .items ():
351
- if hasattr (f , "leaf_name" ) and p [- 1 ] != f .leaf_name :
352
- paths_with_incorrect_leaf_names += f" { p } \n "
349
+ paths_with_incorrect_leaf_names = [
350
+ f" { p } "
351
+ for p , f in flat_policy_environment .items ()
352
+ if hasattr (f , "leaf_name" ) and p [- 1 ] != f .leaf_name
353
+ ]
353
354
if paths_with_incorrect_leaf_names :
354
- msg = (
355
- format_errors_and_warnings (
356
- "The last element of the object's path must be the same as the leaf name "
357
- "of that object. The following tree paths are not compatible with the "
358
- "corresponding object in the policy environment:\n \n "
359
- )
360
- + paths_with_incorrect_leaf_names
361
- )
355
+ msg = format_errors_and_warnings (
356
+ "The last element of the object's path must be the same as the leaf name "
357
+ "of that object. The following tree paths are not compatible with the "
358
+ "corresponding object in the policy environment:\n \n "
359
+ ) + "\n " .join (paths_with_incorrect_leaf_names )
362
360
raise ValueError (msg )
363
361
364
362
@@ -390,31 +388,28 @@ def foreign_keys_are_invalid_in_data(
390
388
if fk_name in labels__root_nodes :
391
389
path = dt .tree_path_from_qname (fk_name )
392
390
# Referenced `p_id` must exist in the input data
393
- if not all (i in valid_ids for i in input_data__flat [path ].tolist ()):
394
- message = format_errors_and_warnings (
395
- f"""
396
- For { path } , the following are not a valid p_id in the input
397
- data: { [i for i in input_data__flat [path ] if i not in valid_ids ]} .
398
- """ ,
391
+ data_array = input_data__flat [path ]
392
+ valid_ids_array = numpy .array (list (valid_ids ))
393
+ valid_mask = numpy .isin (data_array , valid_ids_array )
394
+ if not numpy .all (valid_mask ):
395
+ invalid_ids = data_array [~ valid_mask ].tolist ()
396
+ message = (
397
+ f"For { path } , the following are not a valid p_id in the input "
398
+ f"data: { invalid_ids } ."
399
399
)
400
400
raise ValueError (message )
401
401
402
402
if fk .foreign_key_type == FKType .MUST_NOT_POINT_TO_SELF :
403
- equal_to_pid_in_same_row = [
404
- i
405
- for i , j in zip (
406
- input_data__flat [path ].tolist (),
407
- input_data__flat [("p_id" ,)].tolist (),
408
- strict = False ,
409
- )
410
- if i == j
411
- ]
412
- if any (equal_to_pid_in_same_row ):
413
- message = format_errors_and_warnings (
414
- f"""
415
- For { path } , the following are equal to the p_id in the same
416
- row: { equal_to_pid_in_same_row } .
417
- """ ,
403
+ # Optimized check using numpy operations instead of Python iteration
404
+ data_array = input_data__flat [path ]
405
+ p_id_array = input_data__flat [("p_id" ,)]
406
+ # Use vectorized equality check
407
+ self_references = data_array == p_id_array
408
+ if numpy .any (self_references ):
409
+ equal_to_pid_in_same_row = data_array [self_references ].tolist ()
410
+ message = (
411
+ f"For { path } , the following are equal to the p_id in the same "
412
+ f"row: { equal_to_pid_in_same_row } ."
418
413
)
419
414
raise ValueError (message )
420
415
0 commit comments