4242from cluster import select_best_variant
4343import config
4444from gp_query import ask_multi_query
45+ from gp_query import variable_substitution_deep_narrow_mut_query
4546from gp_query import calibrate_query_timeout
4647from gp_query import combined_ask_count_multi_query
4748from gp_query import predict_query
@@ -420,7 +421,7 @@ def _mutate_expand_node_helper(node, pb_en_out_link=config.MUTPB_EN_OUT_LINK):
420421 new_triple = (node , var_edge , var_node )
421422 else :
422423 new_triple = (var_node , var_edge , node )
423- return new_triple , var_node
424+ return new_triple , var_node , var_edge
424425
425426
426427def mutate_expand_node (child , node = None ):
@@ -433,11 +434,10 @@ def mutate_expand_node(child, node=None):
433434
434435
435436def mutate_deep_narrow_path (
436- child ,
437+ child , sparql , timeout , gtp_scores ,
437438 min_len = config .MUTPB_DN_MIN_LEN ,
438439 max_len = config .MUTPB_DN_MAX_LEN ,
439440 term_pb = config .MUTPB_DN_TERM_PB ,
440- pb_en_out_link = config .MUTPB_EN_OUT_LINK ,
441441):
442442 assert isinstance (child , GraphPattern )
443443 nodes = list (child .nodes )
@@ -451,15 +451,76 @@ def mutate_deep_narrow_path(
451451 if hop >= max_len :
452452 break
453453 hop += 1
454- new_triple , var_node = _mutate_expand_node_helper (start_node )
455- gp += [new_triple ]
456- start_node = var_node
454+ new_triple , var_node , var_edge = _mutate_expand_node_helper (start_node )
455+ test_gp = gp + [new_triple ]
456+ test_gp , fixed = _mutate_deep_narrow_path_helper (
457+ sparql , timeout , gtp_scores , test_gp , var_edge , var_node )
458+ if fixed :
459+ start_node = var_node
460+ gp = test_gp
457461
458462 # TODO: insert connection to a target node
459463 # TODO: fix edge or node ( to_count_var_over_values_query)
460464 return gp
461465
462466
def _mutate_deep_narrow_path_helper(
        sparql,
        timeout,
        gtp_scores,
        child,
        edge_var,
        node_var,
        gtp_sample_n=config.MUTPB_FV_RGTP_SAMPLE_N,
        limit_res=config.MUTPB_DN_QUERY_LIMIT,
        sample_n=config.MUTPB_FV_SAMPLE_MAXN,
):
    """Try to fix ``edge_var`` in ``child`` to concrete substitutions.

    A random-sized sample of ground truth pairs is drawn from ``gtp_scores``
    and passed, together with ``child``, ``edge_var``, ``node_var`` and
    ``limit_res``, to ``variable_substitution_deep_narrow_mut_query``. The
    returned substitution counts are filtered in place by
    ``mutate_fix_var_filter`` and up to ``sample_n`` substitutions are then
    drawn via ``sample_from_list`` using the counts as weights.

    :param sparql: handed through to the substitution query.
    :param timeout: handed through to the substitution query.
    :param gtp_scores: ``GTPScores`` instance; its ``remaining_gain`` caps the
        number of ground truth pairs that are sampled.
    :param child: ``GraphPattern`` in which ``edge_var`` should be fixed.
    :param edge_var: the variable to substitute.
    :param node_var: handed through to the substitution query.
    :param gtp_sample_n: upper bound for the number of sampled ground truth
        pairs (the actual number is drawn uniformly from 1..bound).
    :param limit_res: handed through to the substitution query.
    :param sample_n: number of substitutions requested from
        ``sample_from_list``.
    :return: tuple ``(patterns, fixed)``. If ``fixed`` is True, ``patterns``
        is a list of new ``GraphPattern`` objects, one per chosen
        substitution. If ``fixed`` is False, ``patterns`` is ``[child]``.
    """
    assert isinstance(child, GraphPattern)
    assert isinstance(gtp_scores, GTPScores)

    # The further we get, the fewer gtps are remaining. Sampling too many (all)
    # of them might hurt as common substitutions (> limit ones) which are dead
    # ends could cover less common ones that could actually help.
    gtp_sample_n = min(gtp_sample_n, int(gtp_scores.remaining_gain))
    # int() truncates, so a remaining gain below 1 would give an upper bound
    # of 0 and random.randint(1, 0) raises ValueError. Always sample at least
    # one ground truth pair instead.
    gtp_sample_n = random.randint(1, max(1, gtp_sample_n))

    ground_truth_pairs = gtp_scores.remaining_gain_sample_gtps(
        n=gtp_sample_n)
    # The first element of the query result is not needed here.
    _, substitution_counts = variable_substitution_deep_narrow_mut_query(
        sparql, timeout, child, edge_var, node_var, ground_truth_pairs,
        limit_res)
    if not substitution_counts:
        # the current pattern is unfit, as we can't find anything fulfilling it
        logger.debug("tried to fix a var %s without result:\n %s"
                     "seems as if the pattern can't be fulfilled!",
                     edge_var, child.to_sparql_select_query())
        return [child], False

    # Filters substitution_counts in place (its return value is unused).
    mutate_fix_var_filter(substitution_counts)
    if not substitution_counts:
        # could have happened that we removed the only possible substitution
        return [child], False

    # randomly pick n of the substitutions with a prob ~ to their counts
    items, counts = zip(*substitution_counts.most_common())
    substs = sample_from_list(items, counts, sample_n)
    logger.info(
        'fixed variable %s in %sto:\n %s\n <%d out of:\n %s\n ',
        edge_var.n3(),
        child,
        '\n '.join([subst.n3() for subst in substs]),
        sample_n,
        '\n '.join([' %d: %s' % (c, v.n3())
                    for v, c in substitution_counts.most_common()]),
    )
    res = [
        GraphPattern(child, mapping={edge_var: subst})
        for subst in substs
    ]
    return res, True
523+
463524def mutate_add_edge (child ):
464525 # TODO: can maybe be improved by sparqling
465526 nodes = list (child .nodes )
@@ -682,6 +743,7 @@ def mutate(
682743 pb_dt = config .MUTPB_DT ,
683744 pb_en = config .MUTPB_EN ,
684745 pb_fv = config .MUTPB_FV ,
746+ pb_dn = config .MUTPB_DN ,
685747 pb_id = config .MUTPB_ID ,
686748 pb_iv = config .MUTPB_IV ,
687749 pb_mv = config .MUTPB_MV ,
@@ -721,6 +783,9 @@ def mutate(
721783 if random .random () < pb_sp :
722784 child = mutate_simplify_pattern (child )
723785
786+ if random .random () < pb_dn :
787+ child = mutate_deep_narrow_path (child , sparql , timeout , gtp_scores )
788+
724789 if random .random () < pb_fv :
725790 child = canonicalize (child )
726791 children = mutate_fix_var (sparql , timeout , gtp_scores , child )
0 commit comments