diff --git a/pgtype/vector.go b/pgtype/vector.go new file mode 100644 index 000000000..359dd82d8 --- /dev/null +++ b/pgtype/vector.go @@ -0,0 +1,271 @@ +package pgtype + +import ( + "database/sql/driver" + "encoding/binary" + "fmt" + "math" + "strconv" + "strings" + + "github.com/jackc/pgx/v5/internal/pgio" +) + +type VectorScanner interface { + ScanVector(v Vector) error +} + +type VectorValuer interface { + VectorValue() (Vector, error) +} + +type Vector struct { + Vec []float32 + Valid bool +} + +// ScanVector implements the [VectorScanner] interface. +func (v *Vector) ScanVector(val Vector) error { + *v = val + return nil +} + +// VectorValue implements the [VectorValuer] interface. +func (v Vector) VectorValue() (Vector, error) { + return v, nil +} + +// Scan implements the [database/sql.Scanner] interface. +func (dst *Vector) Scan(src any) error { + if src == nil { + *dst = Vector{} + return nil + } + + switch src := src.(type) { + case string: + return scanPlanTextAnyToVectorScanner{}.Scan([]byte(src), dst) + case []byte: + return scanPlanTextAnyToVectorScanner{}.Scan(src, dst) + } + + return fmt.Errorf("cannot scan %T", src) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (src Vector) Value() (driver.Value, error) { + if !src.Valid { + return nil, nil + } + + buf, err := VectorCodec{}.PlanEncode(nil, 0, TextFormatCode, src).Encode(src, nil) + if err != nil { + return nil, err + } + return string(buf), err +} + +// MarshalJSON implements the [encoding/json.Marshaler] interface. +func (src Vector) MarshalJSON() ([]byte, error) { + if !src.Valid { + return []byte("null"), nil + } + + return []byte(src.String()), nil +} + +// UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. +func (dst *Vector) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + *dst = Vector{} + return nil + } + + vec, err := parseVector(string(b)) + if err != nil { + return err + } + *dst = vec + return nil +} + +func (v Vector) String() string { + if !v.Valid { + return "" + } + + var b strings.Builder + b.WriteString("[") + for i, val := range v.Vec { + if i > 0 { + b.WriteString(",") + } + b.WriteString(strconv.FormatFloat(float64(val), 'g', -1, 32)) + } + b.WriteString("]") + return b.String() +} + +func parseVector(s string) (Vector, error) { + s = strings.TrimSpace(s) + if len(s) < 2 || s[0] != '[' || s[len(s)-1] != ']' { + return Vector{}, fmt.Errorf("invalid vector format") + } + + s = s[1 : len(s)-1] + if s == "" { + return Vector{Vec: []float32{}, Valid: true}, nil + } + + parts := strings.Split(s, ",") + vec := make([]float32, len(parts)) + for i, part := range parts { + f, err := strconv.ParseFloat(strings.TrimSpace(part), 32) + if err != nil { + return Vector{}, err + } + vec[i] = float32(f) + } + + return Vector{Vec: vec, Valid: true}, nil +} + +type VectorCodec struct{} + +func (VectorCodec) FormatSupported(format int16) bool { + return format == TextFormatCode || format == BinaryFormatCode +} + +func (VectorCodec) PreferredFormat() int16 { + return BinaryFormatCode +} + +func (VectorCodec) PlanEncode(m *Map, oid uint32, format int16, value any) EncodePlan { + if _, ok := value.(VectorValuer); !ok { + return nil + } + + switch format { + case BinaryFormatCode: + return encodePlanVectorCodecBinary{} + case TextFormatCode: + return encodePlanVectorCodecText{} + } + + return nil +} + +type encodePlanVectorCodecBinary struct{} + +func (encodePlanVectorCodecBinary) Encode(value any, buf []byte) (newBuf []byte, err error) { + vector, err := value.(VectorValuer).VectorValue() + if err != nil { + return nil, err + } + + if !vector.Valid { + return nil, nil + } + + dim := uint16(len(vector.Vec)) + buf = pgio.AppendUint16(buf, dim) + buf = pgio.AppendUint16(buf, 0) + for _, v := range vector.Vec { + buf = pgio.AppendUint32(buf, math.Float32bits(v)) + } + return buf, nil +} + +type encodePlanVectorCodecText struct{} + +func (encodePlanVectorCodecText) Encode(value any, buf []byte) (newBuf []byte, err error) { + vector, err := value.(VectorValuer).VectorValue() + if err != nil { + return nil, err + } + + if !vector.Valid { + return nil, nil + } + + return append(buf, vector.String()...), nil +} + +func (VectorCodec) PlanScan(m *Map, oid uint32, format int16, target any) ScanPlan { + switch format { + case BinaryFormatCode: + switch target.(type) { + case VectorScanner: + return scanPlanBinaryVectorToVectorScanner{} + } + case TextFormatCode: + switch target.(type) { + case VectorScanner: + return scanPlanTextAnyToVectorScanner{} + } + } + + return nil +} + +func (c VectorCodec) DecodeDatabaseSQLValue(m *Map, oid uint32, format int16, src []byte) (driver.Value, error) { + return codecDecodeToTextFormat(c, m, oid, format, src) +} + +func (c VectorCodec) DecodeValue(m *Map, oid uint32, format int16, src []byte) (any, error) { + if src == nil { + return nil, nil + } + + var vector Vector + err := codecScan(c, m, oid, format, src, &vector) + if err != nil { + return nil, err + } + return vector, nil +} + +type scanPlanBinaryVectorToVectorScanner struct{} + +func (scanPlanBinaryVectorToVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(VectorScanner) + + if src == nil { + return scanner.ScanVector(Vector{}) + } + + if len(src) < 4 { + return fmt.Errorf("invalid length for vector: %v", len(src)) + } + + dim := binary.BigEndian.Uint16(src) + expectedLen := 4 + int(dim)*4 + if len(src) != expectedLen { + return fmt.Errorf("invalid length for vector: expected %d, got %d", expectedLen, len(src)) + } + + vec := make([]float32, dim) + for i := 0; i < int(dim); i++ { + bits := binary.BigEndian.Uint32(src[4+i*4:]) + vec[i] = math.Float32frombits(bits) + } + + return scanner.ScanVector(Vector{Vec: vec, Valid: true}) +} + +type scanPlanTextAnyToVectorScanner struct{} + +func (scanPlanTextAnyToVectorScanner) Scan(src []byte, dst any) error { + scanner := (dst).(VectorScanner) + + if src == nil { + return scanner.ScanVector(Vector{}) + } + + vector, err := parseVector(string(src)) + if err != nil { + return err + } + + return scanner.ScanVector(vector) +} diff --git a/pgtype/vector_test.go b/pgtype/vector_test.go new file mode 100644 index 000000000..08a2dcce1 --- /dev/null +++ b/pgtype/vector_test.go @@ -0,0 +1,379 @@ +package pgtype_test + +import ( + "encoding/binary" + "math" + "reflect" + "testing" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/require" +) + +func TestVectorMarshalJSON(t *testing.T) { + tests := []struct { + name string + vector pgtype.Vector + want string + }{ + { + name: "valid vector", + vector: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + want: "[1,2,3]", + }, + { + name: "empty vector", + vector: pgtype.Vector{Vec: []float32{}, Valid: true}, + want: "[]", + }, + { + name: "null vector", + vector: pgtype.Vector{}, + want: "null", + }, + { + name: "vector with decimals", + vector: pgtype.Vector{Vec: []float32{1.5, 2.25, 3.75}, Valid: true}, + want: "[1.5,2.25,3.75]", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.vector.MarshalJSON() + require.NoError(t, err) + require.Equal(t, tt.want, string(got)) + }) + } +} + +func TestVectorUnmarshalJSON(t *testing.T) { + tests := []struct { + name string + input string + want pgtype.Vector + wantErr bool + }{ + { + name: "valid vector", + input: "[1,2,3]", + want: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + wantErr: false, + }, + { + name: "empty vector", + input: "[]", + want: pgtype.Vector{Vec: []float32{}, Valid: true}, + wantErr: false, + }, + { + name: "null vector", + input: "null", + want: pgtype.Vector{}, + wantErr: false, + }, + { + name: "vector with decimals", + input: "[1.5,2.25,3.75]", + want: pgtype.Vector{Vec: []float32{1.5, 2.25, 3.75}, Valid: true}, + wantErr: false, + }, + { + name: "vector with spaces", + input: "[ 1.5 , 2.25 , 3.75 ]", + want: pgtype.Vector{Vec: []float32{1.5, 2.25, 3.75}, Valid: true}, + wantErr: false, + }, + { + name: "invalid format - no brackets", + input: "1,2,3", + wantErr: true, + }, + { + name: "invalid format - not a number", + input: "[1,2,abc]", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got pgtype.Vector + err := got.UnmarshalJSON([]byte(tt.input)) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want.Valid, got.Valid) + if tt.want.Valid { + require.Equal(t, tt.want.Vec, got.Vec) + } + } + }) + } +} + +func TestVectorString(t *testing.T) { + tests := []struct { + name string + vector pgtype.Vector + want string + }{ + { + name: "valid vector", + vector: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + want: "[1,2,3]", + }, + { + name: "empty vector", + vector: pgtype.Vector{Vec: []float32{}, Valid: true}, + want: "[]", + }, + { + name: "null vector", + vector: pgtype.Vector{}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.vector.String() + require.Equal(t, tt.want, got) + }) + } +} + +func TestVectorCodecEncodeBinary(t *testing.T) { + tests := []struct { + name string + vector pgtype.Vector + want []byte + wantNil bool + }{ + { + name: "valid vector", + vector: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + want: func() []byte { + buf := make([]byte, 0) + buf = binary.BigEndian.AppendUint16(buf, 3) + buf = binary.BigEndian.AppendUint16(buf, 0) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(1)) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(2)) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(3)) + return buf + }(), + }, + { + name: "empty vector", + vector: pgtype.Vector{Vec: []float32{}, Valid: true}, + want: func() []byte { + buf := make([]byte, 0) + buf = binary.BigEndian.AppendUint16(buf, 0) + buf = binary.BigEndian.AppendUint16(buf, 0) + return buf + }(), + }, + { + name: "null vector", + vector: pgtype.Vector{}, + wantNil: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + codec := pgtype.VectorCodec{} + plan := codec.PlanEncode(nil, 0, pgtype.BinaryFormatCode, tt.vector) + require.NotNil(t, plan) + + got, err := plan.Encode(tt.vector, nil) + require.NoError(t, err) + + if tt.wantNil { + require.Nil(t, got) + } else { + require.Equal(t, tt.want, got) + } + }) + } +} + +func TestVectorCodecDecodeTextFormat(t *testing.T) { + tests := []struct { + name string + input string + want pgtype.Vector + wantErr bool + }{ + { + name: "valid vector", + input: "[1,2,3]", + want: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + }, + { + name: "empty vector", + input: "[]", + want: pgtype.Vector{Vec: []float32{}, Valid: true}, + }, + { + name: "vector with decimals", + input: "[1.5,2.25,3.75]", + want: pgtype.Vector{Vec: []float32{1.5, 2.25, 3.75}, Valid: true}, + }, + { + name: "invalid format", + input: "1,2,3", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + codec := pgtype.VectorCodec{} + var got pgtype.Vector + plan := codec.PlanScan(nil, 0, pgtype.TextFormatCode, &got) + require.NotNil(t, plan) + + err := plan.Scan([]byte(tt.input), &got) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want.Valid, got.Valid) + if tt.want.Valid { + require.Equal(t, tt.want.Vec, got.Vec) + } + } + }) + } +} + +func TestVectorCodecDecodeBinaryFormat(t *testing.T) { + tests := []struct { + name string + input []byte + want pgtype.Vector + wantErr bool + }{ + { + name: "valid vector", + input: func() []byte { + buf := make([]byte, 0) + buf = binary.BigEndian.AppendUint16(buf, 3) + buf = binary.BigEndian.AppendUint16(buf, 0) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(1)) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(2)) + buf = binary.BigEndian.AppendUint32(buf, math.Float32bits(3)) + return buf + }(), + want: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + }, + { + name: "empty vector", + input: func() []byte { + buf := make([]byte, 0) + buf = binary.BigEndian.AppendUint16(buf, 0) + buf = binary.BigEndian.AppendUint16(buf, 0) + return buf + }(), + want: pgtype.Vector{Vec: []float32{}, Valid: true}, + }, + { + name: "invalid length", + input: []byte{0, 0}, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + codec := pgtype.VectorCodec{} + var got pgtype.Vector + plan := codec.PlanScan(nil, 0, pgtype.BinaryFormatCode, &got) + require.NotNil(t, plan) + + err := plan.Scan(tt.input, &got) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want.Valid, got.Valid) + if tt.want.Valid { + require.True(t, reflect.DeepEqual(tt.want.Vec, got.Vec)) + } + } + }) + } +} + +func TestVectorScan(t *testing.T) { + tests := []struct { + name string + input any + want pgtype.Vector + wantErr bool + }{ + { + name: "string input", + input: "[1,2,3]", + want: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + }, + { + name: "byte slice input", + input: []byte("[1,2,3]"), + want: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + }, + { + name: "nil input", + input: nil, + want: pgtype.Vector{}, + }, + { + name: "invalid type", + input: 123, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got pgtype.Vector + err := got.Scan(tt.input) + if tt.wantErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, tt.want.Valid, got.Valid) + if tt.want.Valid { + require.Equal(t, tt.want.Vec, got.Vec) + } + } + }) + } +} + +func TestVectorValue(t *testing.T) { + tests := []struct { + name string + vector pgtype.Vector + want string + wantNil bool + }{ + { + name: "valid vector", + vector: pgtype.Vector{Vec: []float32{1, 2, 3}, Valid: true}, + want: "[1,2,3]", + }, + { + name: "null vector", + vector: pgtype.Vector{}, + wantNil: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.vector.Value() + require.NoError(t, err) + + if tt.wantNil { + require.Nil(t, got) + } else { + require.Equal(t, tt.want, got) + } + }) + } +}