@@ -493,6 +493,7 @@ func (reference *ForeignKeyReferenceHandler) IsInitialized() bool {
493
493
}
494
494
495
495
// CheckReference checks that the given row has an index entry in the referenced table.
496
+ // Performs MySQL-compatible foreign key constraint validation with type-specific checks.
496
497
func (reference * ForeignKeyReferenceHandler ) CheckReference (ctx * sql.Context , row sql.Row ) error {
497
498
// If even one of the values are NULL then we don't check the parent
498
499
for _ , pos := range reference .RowMapper .IndexPositions {
@@ -507,7 +508,7 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
507
508
}
508
509
defer rowIter .Close (ctx )
509
510
510
- _ , err = rowIter .Next (ctx )
511
+ parentRow , err : = rowIter .Next (ctx )
511
512
if err != nil && err != io .EOF {
512
513
// For SET types, conversion failures during foreign key validation should be treated as foreign key violations
513
514
if sql .ErrConvertingToSet .Is (err ) || sql .ErrInvalidSetValue .Is (err ) {
@@ -518,12 +519,10 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
518
519
}
519
520
if err == nil {
520
521
// We have a parent row, but check for type-specific validation
521
- if validationErr := reference .validateDecimalConstraints (row ); validationErr != nil {
522
- return validationErr
523
- }
524
- if validationErr := reference .validateTimeConstraints (row ); validationErr != nil {
522
+ if validationErr := reference .validateColumnTypeConstraints (ctx , row , parentRow ); validationErr != nil {
525
523
return validationErr
526
524
}
525
+
527
526
// We have a parent row so throw no error
528
527
return nil
529
528
}
@@ -551,76 +550,55 @@ func (reference *ForeignKeyReferenceHandler) CheckReference(ctx *sql.Context, ro
551
550
reference .ForeignKey .ParentTable , reference .RowMapper .GetKeyString (row ))
552
551
}
553
552
554
- // validateDecimalConstraints checks that decimal foreign key columns have compatible scales.
555
- func (reference * ForeignKeyReferenceHandler ) validateDecimalConstraints (row sql.Row ) error {
556
- if reference .RowMapper .Index == nil {
557
- return nil
558
- }
559
- indexColumnTypes := reference .RowMapper .Index .ColumnExpressionTypes ()
560
- for parentIdx , parentCol := range indexColumnTypes {
561
- if parentIdx >= len (reference .RowMapper .IndexPositions ) {
562
- break
563
- }
564
- parentType := parentCol .Type
565
- childColIdx := reference .RowMapper .IndexPositions [parentIdx ]
566
- childType := reference .RowMapper .SourceSch [childColIdx ].Type
567
- childDecimal , ok := childType .(sql.DecimalType )
568
- if ! ok {
569
- continue
570
- }
571
- parentDecimal , ok := parentType .(sql.DecimalType )
572
- if ! ok {
573
- continue
574
- }
575
- if childDecimal .Scale () != parentDecimal .Scale () {
576
- return sql .ErrForeignKeyChildViolation .New (
577
- reference .ForeignKey .Name ,
578
- reference .ForeignKey .Table ,
579
- reference .ForeignKey .ParentTable ,
580
- reference .RowMapper .GetKeyString (row ),
581
- )
582
- }
583
- }
584
- return nil
585
- }
586
553
587
- // validateTimeConstraints checks that time-related foreign key columns have exact type and precision matches .
588
- // MySQL requires strict matching for time types in foreign keys - even logically equivalent values
589
- // like '2001-02-03 12:34:56' vs '2001-02-03 12:34:56.000000' are rejected if precision differs.
590
- func ( reference * ForeignKeyReferenceHandler ) validateTimeConstraints ( row sql. Row ) error {
591
- if reference . RowMapper .Index == nil {
554
+ // validateColumnTypeConstraints validates that column types meet MySQL foreign key requirements .
555
+ // Centralizes type validation for decimal scale matching and exact time type precision matching.
556
+ func ( reference * ForeignKeyReferenceHandler ) validateColumnTypeConstraints ( ctx * sql. Context , childRow sql. Row , parentRow sql. Row ) error {
557
+ mapper := reference . RowMapper
558
+ if mapper .Index == nil {
592
559
return nil
593
560
}
594
- indexColumnTypes := reference . RowMapper . Index . ColumnExpressionTypes ()
595
- for parentIdx , parentCol := range indexColumnTypes {
596
- if parentIdx >= len (reference . RowMapper .IndexPositions ) {
561
+
562
+ for parentIdx , parentCol := range mapper . Index . ColumnExpressionTypes () {
563
+ if parentIdx >= len (mapper .IndexPositions ) {
597
564
break
598
565
}
566
+
599
567
parentType := parentCol .Type
600
- childColIdx := reference .RowMapper .IndexPositions [parentIdx ]
601
- childType := reference .RowMapper .SourceSch [childColIdx ].Type
568
+ childType := mapper.SourceSch [mapper.IndexPositions [parentIdx ]].Type
602
569
603
- // Check if both types are time-related
604
- isChildTime := types .IsTime (childType ) || types .IsTimespan (childType )
605
- isParentTime := types .IsTime (parentType ) || types .IsTimespan (parentType )
570
+ // Check for constraint violations
571
+ hasViolation := false
606
572
607
- if ! isChildTime || ! isParentTime {
608
- continue
573
+ // Decimal scale must match
574
+ if childDecimal , ok := childType .(sql.DecimalType ); ok {
575
+ if parentDecimal , ok := parentType .(sql.DecimalType ); ok {
576
+ hasViolation = childDecimal .Scale () != parentDecimal .Scale ()
577
+ }
609
578
}
610
579
611
- // MySQL requires exact type matching for time types in foreign key validation
612
- if ! childType .Equals (parentType ) {
580
+ // Time types must match exactly (including precision)
581
+ if ! hasViolation {
582
+ isChildTime := types .IsTime (childType ) || types .IsTimespan (childType )
583
+ isParentTime := types .IsTime (parentType ) || types .IsTimespan (parentType )
584
+ if isChildTime && isParentTime {
585
+ hasViolation = ! childType .Equals (parentType )
586
+ }
587
+ }
588
+
589
+ if hasViolation {
613
590
return sql .ErrForeignKeyChildViolation .New (
614
591
reference .ForeignKey .Name ,
615
592
reference .ForeignKey .Table ,
616
593
reference .ForeignKey .ParentTable ,
617
- reference . RowMapper . GetKeyString (row ),
594
+ mapper . GetKeyString (childRow ),
618
595
)
619
596
}
620
597
}
621
598
return nil
622
599
}
623
600
601
+
624
602
// CheckTable checks that every row in the table has an index entry in the referenced table.
625
603
func (reference * ForeignKeyReferenceHandler ) CheckTable (ctx * sql.Context , tbl sql.ForeignKeyTable ) error {
626
604
partIter , err := tbl .Partitions (ctx )
@@ -678,6 +656,7 @@ func (mapper *ForeignKeyRowMapper) GetIter(ctx *sql.Context, row sql.Row, refChe
678
656
}
679
657
680
658
targetType := mapper .SourceSch [rowPos ].Type
659
+
681
660
// Transform the type of the value in this row to the one in the other table for the index lookup, if necessary
682
661
if mapper .TargetTypeConversions != nil && mapper .TargetTypeConversions [rowPos ] != nil {
683
662
var err error
0 commit comments