Skip to content

Commit a88a093

Browse files
authored
Add char arrays output (#43)
* ✨ add char comparison when encountering int8 arrays * adding detection of wider strings * better management of strings separators & whitespaces * removed potencially disturbing character * 🎨 code style * 🎨 code style + msg fix
1 parent 8cc89f5 commit a88a093

File tree

1 file changed

+87
-21
lines changed

1 file changed

+87
-21
lines changed

geos-ats/src/geos/ats/helpers/restart_check.py

Lines changed: 87 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import argparse
88
import logging
99
import time
10+
import string
1011
from pathlib import Path
1112
try:
1213
from geos.ats.helpers.permute_array import permuteArray # type: ignore[import]
@@ -375,35 +376,100 @@ def compareIntArrays( self, path, arr, base_arr ):
375376
ARR [in]: The hdf5 Dataset to compare.
376377
BASE_ARR [in]: The hdf5 Dataset to compare against.
377378
"""
378-
# If the shapes are different they can't be compared.
379+
message = ""
379380
if arr.shape != base_arr.shape:
380-
msg = "Datasets have different shapes and therefore can't be compared: %s, %s.\n" % ( arr.shape,
381-
base_arr.shape )
382-
self.errorMsg( path, msg, True )
383-
return
381+
message = "Datasets have different shapes and therefore can't be compared statistically: %s, %s.\n" % (
382+
arr.shape, base_arr.shape )
383+
else:
384+
# Calculate the absolute difference.
385+
difference = np.subtract( arr, base_arr )
386+
np.abs( difference, out=difference )
384387

385-
# Create a copy of the arrays.
388+
offenders = difference != 0.0
389+
n_offenders = np.sum( offenders )
386390

387-
# Calculate the absolute difference.
388-
difference = np.subtract( arr, base_arr )
389-
np.abs( difference, out=difference )
391+
if n_offenders != 0:
392+
max_index = np.unravel_index( np.argmax( difference ), difference.shape )
393+
max_difference = difference[ max_index ]
394+
offenders_mean = np.mean( difference[ offenders ] )
395+
offenders_std = np.std( difference[ offenders ] )
390396

391-
offenders = difference != 0.0
392-
n_offenders = np.sum( offenders )
397+
message = "Arrays of types %s and %s have %s values of which %d have differing values.\n" % (
398+
arr.dtype, base_arr.dtype, offenders.size, n_offenders )
399+
message += "Statistics of the differences greater than 0:\n"
400+
message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % ( max_index, max_difference,
401+
offenders_mean, offenders_std )
393402

394-
if n_offenders != 0:
395-
max_index = np.unravel_index( np.argmax( difference ), difference.shape )
396-
max_difference = difference[ max_index ]
397-
offenders_mean = np.mean( difference[ offenders ] )
398-
offenders_std = np.std( difference[ offenders ] )
403+
# actually, int8 arrays are almost always char arrays, so we sould add a character comparison.
404+
if arr.dtype == np.int8 and base_arr.dtype == np.int8:
405+
message += self.compareCharArrays( arr, base_arr )
399406

400-
message = "Arrays of types %s and %s have %s values of which %d have differing values.\n" % (
401-
arr.dtype, base_arr.dtype, offenders.size, n_offenders )
402-
message += "Statistics of the differences greater than 0:\n"
403-
message += "\tmax_index = %s, max = %s, mean = %s, std = %s\n" % ( max_index, max_difference,
404-
offenders_mean, offenders_std )
407+
if message != "":
405408
self.errorMsg( path, message, True )
406409

410+
def compareCharArrays( self, comp_arr, base_arr ):
411+
"""
412+
Compare the valid characters of two arrays and return a formatted string showing differences.
413+
414+
COMP_ARR [in]: The hdf5 Dataset to compare.
415+
BASE_ARR [in]: The hdf5 Dataset to compare against.
416+
417+
Returns a formatted string highlighting the differing characters.
418+
"""
419+
comp_ndarr = np.array( comp_arr ).flatten()
420+
base_ndarr = np.array( base_arr ).flatten()
421+
422+
# Replace invalid characters by group-separator characters ('\x1D')
423+
valid_chars = set( string.printable )
424+
invalid_char = '\x1D'
425+
comp_str = "".join(
426+
[ chr( x ) if ( x >= 0 and chr( x ) in valid_chars ) else invalid_char for x in comp_ndarr ] )
427+
base_str = "".join(
428+
[ chr( x ) if ( x >= 0 and chr( x ) in valid_chars ) else invalid_char for x in base_ndarr ] )
429+
430+
# replace whitespaces sequences by only one space (preventing indentation / spacing changes detection)
431+
whitespace_pattern = r"[ \t\n\r\v\f]+"
432+
comp_str = re.sub( whitespace_pattern, " ", comp_str )
433+
base_str = re.sub( whitespace_pattern, " ", base_str )
434+
# replace invalid characters sequences by a double space (for clear display)
435+
invalid_char_pattern = r"\x1D+"
436+
comp_str_display = re.sub( invalid_char_pattern, " ", comp_str )
437+
base_str_display = re.sub( invalid_char_pattern, " ", base_str )
438+
439+
message = ""
440+
441+
def limited_display( n, string ):
442+
return string[ :n ] + f"... ({len(string)-n} omitted chars)" if len( string ) > n else string
443+
444+
if len( comp_str ) != len( base_str ):
445+
max_display = 250
446+
message = f"Character arrays have different sizes: {len( comp_str )}, {len( base_str )}.\n"
447+
message += f" {limited_display( max_display, comp_str_display )}\n"
448+
message += f" {limited_display( max_display, base_str_display )}\n"
449+
else:
450+
# We need to trim arrays to the length of the shortest one for the comparisons
451+
min_length = min( len( comp_str_display ), len( base_str_display ) )
452+
comp_str_trim = comp_str_display[ :min_length ]
453+
base_str_trim = base_str_display[ :min_length ]
454+
455+
differing_indices = np.where( np.array( list( comp_str_trim ) ) != np.array( list( base_str_trim ) ) )[ 0 ]
456+
if differing_indices.size != 0:
457+
# check for reordering
458+
arr_set = sorted( set( comp_str.split( invalid_char ) ) )
459+
base_arr_set = sorted( set( base_str.split( invalid_char ) ) )
460+
reordering_detected = arr_set == base_arr_set
461+
462+
max_display = 110 if reordering_detected else 250
463+
message = "Differing valid characters"
464+
message += " (substrings reordering detected):\n" if reordering_detected else ":\n"
465+
466+
message += f" {limited_display( max_display, comp_str_display )}\n"
467+
message += f" {limited_display( max_display, base_str_display )}\n"
468+
message += " " + "".join(
469+
[ "^" if i in differing_indices else " " for i in range( min( max_display, min_length ) ) ] ) + "\n"
470+
471+
return message
472+
407473
def compareStringArrays( self, path, arr, base_arr ):
408474
"""
409475
Compare two string datasets. Exact equality is used as the acceptance criteria.

0 commit comments

Comments
 (0)