From c2c56ce9a5f5929c8b488801030b05d2e156a42c Mon Sep 17 00:00:00 2001 From: Dan McGee Date: Sat, 29 Jun 2024 09:10:47 -0500 Subject: [PATCH] Add support for PG 17 interval infinity values Added in this commit: https://github.com/postgres/postgres/commit/519fc1bd9 The wire format for infinity/-infinity adds no breaking changes- it is just the various values set to either minimum or maximum possible int64/int32 values. I attempted to match how InfinityModifier works in other types, such as date and timestamptz. --- go.sum | 4 - pgtype/date.go | 12 +-- pgtype/interval.go | 215 ++++++++++++++++++++++++---------------- pgtype/interval_test.go | 12 +++ pgtype/timestamp.go | 12 +-- pgtype/timestamptz.go | 12 +-- 6 files changed, 152 insertions(+), 115 deletions(-) diff --git a/go.sum b/go.sum index 4b02a0365..29fe452b2 100644 --- a/go.sum +++ b/go.sum @@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA= -github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= diff --git a/pgtype/date.go b/pgtype/date.go index 784b16deb..61e29f243 100644 --- a/pgtype/date.go +++ b/pgtype/date.go @@ -80,10 +80,8 @@ func (src Date) MarshalJSON() ([]byte, error) { switch src.InfinityModifier { case Finite: s = src.Time.Format("2006-01-02") - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" + case Infinity, NegativeInfinity: + s = src.InfinityModifier.String() } return json.Marshal(s) @@ -213,10 +211,8 @@ func (encodePlanDateCodecText) Encode(value any, buf []byte) (newBuf []byte, err if bc { buf = append(buf, " BC"...) } - case Infinity: - buf = append(buf, "infinity"...) - case NegativeInfinity: - buf = append(buf, "-infinity"...) + case Infinity, NegativeInfinity: + buf = append(buf, date.InfinityModifier.String()...) } return buf, nil diff --git a/pgtype/interval.go b/pgtype/interval.go index 4b5116295..9a95aa449 100644 --- a/pgtype/interval.go +++ b/pgtype/interval.go @@ -4,6 +4,7 @@ import ( "database/sql/driver" "encoding/binary" "fmt" + "math" "strconv" "strings" @@ -27,10 +28,11 @@ type IntervalValuer interface { } type Interval struct { - Microseconds int64 - Days int32 - Months int32 - Valid bool + Microseconds int64 + Days int32 + Months int32 + InfinityModifier InfinityModifier + Valid bool } func (interval *Interval) ScanInterval(v Interval) error { @@ -63,6 +65,10 @@ func (interval Interval) Value() (driver.Value, error) { return nil, nil } + if interval.InfinityModifier != Finite { + return interval.InfinityModifier.String(), nil + } + buf, err := IntervalCodec{}.PlanEncode(nil, 0, TextFormatCode, interval).Encode(interval, nil) if err != nil { return nil, err @@ -107,9 +113,21 @@ func (encodePlanIntervalCodecBinary) Encode(value any, buf []byte) (newBuf []byt return nil, nil } - buf = pgio.AppendInt64(buf, interval.Microseconds) - buf = pgio.AppendInt32(buf, interval.Days) - buf = pgio.AppendInt32(buf, interval.Months) + switch interval.InfinityModifier { + case Finite: + buf = pgio.AppendInt64(buf, interval.Microseconds) + buf = pgio.AppendInt32(buf, interval.Days) + buf = pgio.AppendInt32(buf, interval.Months) + case Infinity: + buf = pgio.AppendInt64(buf, math.MaxInt64) + buf = pgio.AppendInt32(buf, math.MaxInt32) + buf = pgio.AppendInt32(buf, math.MaxInt32) + case NegativeInfinity: + buf = pgio.AppendInt64(buf, math.MinInt64) + buf = pgio.AppendInt32(buf, math.MinInt32) + buf = pgio.AppendInt32(buf, math.MinInt32) + } + return buf, nil } @@ -125,32 +143,37 @@ func (encodePlanIntervalCodecText) Encode(value any, buf []byte) (newBuf []byte, return nil, nil } - if interval.Months != 0 { - buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) - buf = append(buf, " mon "...) - } + switch interval.InfinityModifier { + case Finite: + if interval.Months != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Months), 10)...) + buf = append(buf, " mon "...) + } - if interval.Days != 0 { - buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) - buf = append(buf, " day "...) - } + if interval.Days != 0 { + buf = append(buf, strconv.FormatInt(int64(interval.Days), 10)...) + buf = append(buf, " day "...) + } - absMicroseconds := interval.Microseconds - if absMicroseconds < 0 { - absMicroseconds = -absMicroseconds - buf = append(buf, '-') - } + absMicroseconds := interval.Microseconds + if absMicroseconds < 0 { + absMicroseconds = -absMicroseconds + buf = append(buf, '-') + } - hours := absMicroseconds / microsecondsPerHour - minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute - seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond + hours := absMicroseconds / microsecondsPerHour + minutes := (absMicroseconds % microsecondsPerHour) / microsecondsPerMinute + seconds := (absMicroseconds % microsecondsPerMinute) / microsecondsPerSecond - timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) - buf = append(buf, timeStr...) + timeStr := fmt.Sprintf("%02d:%02d:%02d", hours, minutes, seconds) + buf = append(buf, timeStr...) - microseconds := absMicroseconds % microsecondsPerSecond - if microseconds != 0 { - buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) + microseconds := absMicroseconds % microsecondsPerSecond + if microseconds != 0 { + buf = append(buf, fmt.Sprintf(".%06d", microseconds)...) + } + case Infinity, NegativeInfinity: + buf = append(buf, interval.InfinityModifier.String()...) } return buf, nil @@ -184,14 +207,22 @@ func (scanPlanBinaryIntervalToIntervalScanner) Scan(src []byte, dst any) error { } if len(src) != 16 { - return fmt.Errorf("Received an invalid size for an interval: %d", len(src)) + return fmt.Errorf("received an invalid size for an interval: %d", len(src)) } microseconds := int64(binary.BigEndian.Uint64(src)) days := int32(binary.BigEndian.Uint32(src[8:])) months := int32(binary.BigEndian.Uint32(src[12:])) - return scanner.ScanInterval(Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true}) + interval := Interval{Microseconds: microseconds, Days: days, Months: months, Valid: true} + + if microseconds == math.MaxInt64 && days == math.MaxInt32 && months == math.MaxInt32 { + interval.InfinityModifier = Infinity + } else if microseconds == math.MinInt64 && days == math.MinInt32 && months == math.MinInt32 { + interval.InfinityModifier = NegativeInfinity + } + + return scanner.ScanInterval(interval) } type scanPlanTextAnyToIntervalScanner struct{} @@ -203,80 +234,90 @@ func (scanPlanTextAnyToIntervalScanner) Scan(src []byte, dst any) error { return scanner.ScanInterval(Interval{}) } - var microseconds int64 - var days int32 - var months int32 - - parts := strings.Split(string(src), " ") - - for i := 0; i < len(parts)-1; i += 2 { - scalar, err := strconv.ParseInt(parts[i], 10, 64) - if err != nil { - return fmt.Errorf("bad interval format") - } - - switch parts[i+1] { - case "year", "years": - months += int32(scalar * 12) - case "mon", "mons": - months += int32(scalar) - case "day", "days": - days = int32(scalar) - } - } + var interval Interval + sbuf := string(src) + switch sbuf { + case "infinity": + interval = Interval{InfinityModifier: Infinity, Valid: true} + case "-infinity": + interval = Interval{InfinityModifier: NegativeInfinity, Valid: true} + default: + var microseconds int64 + var days int32 + var months int32 + + parts := strings.Split(sbuf, " ") + + for i := 0; i < len(parts)-1; i += 2 { + scalar, err := strconv.ParseInt(parts[i], 10, 64) + if err != nil { + return fmt.Errorf("bad interval format") + } - if len(parts)%2 == 1 { - timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) - if len(timeParts) != 3 { - return fmt.Errorf("bad interval format") + switch parts[i+1] { + case "year", "years": + months += int32(scalar * 12) + case "mon", "mons": + months += int32(scalar) + case "day", "days": + days = int32(scalar) + } } - var negative bool - if timeParts[0][0] == '-' { - negative = true - timeParts[0] = timeParts[0][1:] - } + if len(parts)%2 == 1 { + timeParts := strings.SplitN(parts[len(parts)-1], ":", 3) + if len(timeParts) != 3 { + return fmt.Errorf("bad interval format") + } - hours, err := strconv.ParseInt(timeParts[0], 10, 64) - if err != nil { - return fmt.Errorf("bad interval hour format: %s", timeParts[0]) - } + var negative bool + if timeParts[0][0] == '-' { + negative = true + timeParts[0] = timeParts[0][1:] + } - minutes, err := strconv.ParseInt(timeParts[1], 10, 64) - if err != nil { - return fmt.Errorf("bad interval minute format: %s", timeParts[1]) - } + hours, err := strconv.ParseInt(timeParts[0], 10, 64) + if err != nil { + return fmt.Errorf("bad interval hour format: %s", timeParts[0]) + } - sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".") + minutes, err := strconv.ParseInt(timeParts[1], 10, 64) + if err != nil { + return fmt.Errorf("bad interval minute format: %s", timeParts[1]) + } - seconds, err := strconv.ParseInt(sec, 10, 64) - if err != nil { - return fmt.Errorf("bad interval second format: %s", sec) - } + sec, secFrac, secFracFound := strings.Cut(timeParts[2], ".") - var uSeconds int64 - if secFracFound { - uSeconds, err = strconv.ParseInt(secFrac, 10, 64) + seconds, err := strconv.ParseInt(sec, 10, 64) if err != nil { - return fmt.Errorf("bad interval decimal format: %s", secFrac) + return fmt.Errorf("bad interval second format: %s", sec) } - for i := 0; i < 6-len(secFrac); i++ { - uSeconds *= 10 + var uSeconds int64 + if secFracFound { + uSeconds, err = strconv.ParseInt(secFrac, 10, 64) + if err != nil { + return fmt.Errorf("bad interval decimal format: %s", secFrac) + } + + for i := 0; i < 6-len(secFrac); i++ { + uSeconds *= 10 + } } - } - microseconds = hours * microsecondsPerHour - microseconds += minutes * microsecondsPerMinute - microseconds += seconds * microsecondsPerSecond - microseconds += uSeconds + microseconds = hours * microsecondsPerHour + microseconds += minutes * microsecondsPerMinute + microseconds += seconds * microsecondsPerSecond + microseconds += uSeconds - if negative { - microseconds = -microseconds + if negative { + microseconds = -microseconds + } } + interval = Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true} } - return scanner.ScanInterval(Interval{Months: months, Days: days, Microseconds: microseconds, Valid: true}) + return scanner.ScanInterval(interval) } func (c IntervalCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { diff --git a/pgtype/interval_test.go b/pgtype/interval_test.go index c06c3b2df..c628ca6f3 100644 --- a/pgtype/interval_test.go +++ b/pgtype/interval_test.go @@ -127,6 +127,16 @@ func TestIntervalCodec(t *testing.T) { new(pgtype.Interval), isExpectedEq(pgtype.Interval{Months: -13, Valid: true}), }, + { + "infinity", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{InfinityModifier: pgtype.Infinity, Valid: true}), + }, + { + "-infinity", + new(pgtype.Interval), + isExpectedEq(pgtype.Interval{InfinityModifier: pgtype.NegativeInfinity, Valid: true}), + }, {time.Hour, new(time.Duration), isExpectedEq(time.Hour)}, { pgtype.Interval{Months: 1, Days: 1, Valid: true}, @@ -149,6 +159,8 @@ func TestIntervalTextEncode(t *testing.T) { {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 0, Valid: true}, result: "00:00:00"}, {source: pgtype.Interval{Months: 0, Days: 0, Microseconds: 6 * 60 * 1000000, Valid: true}, result: "00:06:00"}, {source: pgtype.Interval{Months: 0, Days: 1, Microseconds: 6*60*1000000 + 30, Valid: true}, result: "1 day 00:06:00.000030"}, + {source: pgtype.Interval{InfinityModifier: pgtype.Infinity, Valid: true}, result: "infinity"}, + {source: pgtype.Interval{InfinityModifier: pgtype.NegativeInfinity, Valid: true}, result: "-infinity"}, } for i, tt := range successfulTests { buf, err := m.Encode(pgtype.DateOID, pgtype.TextFormatCode, tt.source, nil) diff --git a/pgtype/timestamp.go b/pgtype/timestamp.go index 677a2c6ea..b74ca2f1b 100644 --- a/pgtype/timestamp.go +++ b/pgtype/timestamp.go @@ -77,10 +77,8 @@ func (ts Timestamp) MarshalJSON() ([]byte, error) { switch ts.InfinityModifier { case Finite: s = ts.Time.Format(time.RFC3339Nano) - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" + case Infinity, NegativeInfinity: + s = ts.InfinityModifier.String() } return json.Marshal(s) @@ -205,10 +203,8 @@ func (encodePlanTimestampCodecText) Encode(value any, buf []byte) (newBuf []byte if bc { s = s + " BC" } - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" + case Infinity, NegativeInfinity: + s = ts.InfinityModifier.String() } buf = append(buf, s...) diff --git a/pgtype/timestamptz.go b/pgtype/timestamptz.go index 7efbcffd2..13eb150c6 100644 --- a/pgtype/timestamptz.go +++ b/pgtype/timestamptz.go @@ -85,10 +85,8 @@ func (tstz Timestamptz) MarshalJSON() ([]byte, error) { switch tstz.InfinityModifier { case Finite: s = tstz.Time.Format(time.RFC3339Nano) - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" + case Infinity, NegativeInfinity: + s = tstz.InfinityModifier.String() } return json.Marshal(s) @@ -213,10 +211,8 @@ func (encodePlanTimestamptzCodecText) Encode(value any, buf []byte) (newBuf []by if bc { s = s + " BC" } - case Infinity: - s = "infinity" - case NegativeInfinity: - s = "-infinity" + case Infinity, NegativeInfinity: + s = ts.InfinityModifier.String() } buf = append(buf, s...)