Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions go/adbc/driver/internal/driverbase/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ import (
"time"

"github.com/apache/arrow-adbc/go/adbc"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
Expand Down Expand Up @@ -79,6 +81,34 @@ var getExporterName = sync.OnceValue(func() string {
return os.Getenv(otelTracesExporter)
})

// registerExtensionTypes ensures that canonical Arrow extension types are registered.
// This is called once during driver initialization to make sure extension types like
// UUID are available for use throughout the driver.
var registerExtensionTypes = sync.OnceFunc(func() {
// The arrow/extensions package automatically registers canonical extension types
// (UUID, Bool8, JSON, Opaque, Variant) in its init() function.
// However, we explicitly ensure registration here in case the package wasn't
// imported elsewhere, and to handle any registration errors gracefully.

// List of canonical extension types to ensure are registered
canonicalTypes := []arrow.ExtensionType{
extensions.NewUUIDType(),
extensions.NewBool8Type(),
&extensions.JSONType{},
&extensions.OpaqueType{},
&extensions.VariantType{},
}

for _, extType := range canonicalTypes {
// RegisterExtensionType is idempotent - it returns an error only if
// a different type with the same name is already registered
if err := arrow.RegisterExtensionType(extType); err != nil {
// Log but don't fail - the type might already be registered
// which is fine (the extensions package init() may have done it)
}
}
})

// DatabaseImpl is an interface that drivers implement to provide
// vendor-specific functionality.
type DatabaseImpl interface {
Expand Down Expand Up @@ -115,6 +145,9 @@ type DatabaseImplBase struct {
// - driver is a DriverImplBase containing the common resources from the parent
// driver, allowing the Arrow allocator and error handler to be reused.
func NewDatabaseImplBase(ctx context.Context, driver *DriverImplBase) (DatabaseImplBase, error) {
// Ensure extension types are registered before creating the database
registerExtensionTypes()

database := DatabaseImplBase{
Alloc: driver.Alloc,
ErrorHelper: driver.ErrorHelper,
Expand Down
91 changes: 91 additions & 0 deletions go/adbc/driver/internal/driverbase/driver_test.go
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ import (
"github.com/apache/arrow-adbc/go/adbc/validation"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/array"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/apache/arrow-go/v18/arrow/memory"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -791,3 +793,92 @@ func TestRequiredList(t *testing.T) {
require.NoError(t, json.Unmarshal([]byte(`["d", "e", "f"]`), &v))
assert.Equal(t, driverbase.RequiredList([]string{"d", "e", "f"}), v)
}

func TestExtensionTypesRegistered(t *testing.T) {
// Test that extension types are properly registered when creating a database
alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer alloc.AssertSize(t, 0)

var handler MockedHandler
handler.On("Handle", mock.Anything, mock.Anything).Return(nil)

drv := NewDriver(alloc, &handler, false)
db, err := drv.NewDatabase(nil)
require.NoError(t, err)
defer validation.CheckedClose(t, db)

// Test 1: UUID extension type
t.Run("UUID", func(t *testing.T) {
schema := arrow.NewSchema([]arrow.Field{
{Name: "uuid_col", Type: extensions.NewUUIDType(), Nullable: true},
}, nil)

bldr := array.NewRecordBuilder(alloc, schema)
defer bldr.Release()

uuidBldr := bldr.Field(0).(*extensions.UUIDBuilder)
testUUID := uuid.New()
uuidBldr.Append(testUUID)
uuidBldr.AppendNull()

rec := bldr.NewRecordBatch()
defer rec.Release()

require.Equal(t, int64(2), rec.NumRows())
uuidArr := rec.Column(0).(*extensions.UUIDArray)
require.True(t, uuidArr.IsValid(0))
require.Equal(t, testUUID, uuidArr.Value(0))
require.False(t, uuidArr.IsValid(1))
})

// Test 2: Bool8 extension type
t.Run("Bool8", func(t *testing.T) {
schema := arrow.NewSchema([]arrow.Field{
{Name: "bool8_col", Type: extensions.NewBool8Type(), Nullable: true},
}, nil)

bldr := array.NewRecordBuilder(alloc, schema)
defer bldr.Release()

bool8Bldr := bldr.Field(0).(*extensions.Bool8Builder)
bool8Bldr.Append(true)
bool8Bldr.Append(false)
bool8Bldr.AppendNull()

rec := bldr.NewRecordBatch()
defer rec.Release()

require.Equal(t, int64(3), rec.NumRows())
bool8Arr := rec.Column(0).(*extensions.Bool8Array)
require.True(t, bool8Arr.IsValid(0))
require.Equal(t, true, bool8Arr.Value(0))
require.True(t, bool8Arr.IsValid(1))
require.Equal(t, false, bool8Arr.Value(1))
require.False(t, bool8Arr.IsValid(2))
})

// Test 3: Verify all canonical types can be instantiated
t.Run("AllTypesInstantiable", func(t *testing.T) {
// Just verify we can create instances of all canonical extension types
// without errors (they're registered)
uuidType := extensions.NewUUIDType()
require.NotNil(t, uuidType)
require.Equal(t, "arrow.uuid", uuidType.ExtensionName())

bool8Type := extensions.NewBool8Type()
require.NotNil(t, bool8Type)
require.Equal(t, "arrow.bool8", bool8Type.ExtensionName())

jsonType := &extensions.JSONType{}
require.NotNil(t, jsonType)
require.Equal(t, "arrow.json", jsonType.ExtensionName())

opaqueType := extensions.NewOpaqueType(arrow.BinaryTypes.String, "test.opaque", "test_vendor")
require.NotNil(t, opaqueType)
require.Equal(t, "arrow.opaque", opaqueType.ExtensionName())

variantType := &extensions.VariantType{}
require.NotNil(t, variantType)
require.Equal(t, "parquet.variant", variantType.ExtensionName())
})
}
8 changes: 7 additions & 1 deletion go/adbc/driver/internal/shared_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -694,6 +694,7 @@ const (
XdbcDataType_XDBC_BIT XdbcDataType = -7
XdbcDataType_XDBC_WCHAR XdbcDataType = -8
XdbcDataType_XDBC_WVARCHAR XdbcDataType = -9
XdbcDataType_XDBC_GUID XdbcDataType = -11
)

func ToXdbcDataType(dt arrow.DataType) (xdbcType XdbcDataType) {
Expand All @@ -703,7 +704,12 @@ func ToXdbcDataType(dt arrow.DataType) (xdbcType XdbcDataType) {

switch dt.ID() {
case arrow.EXTENSION:
return ToXdbcDataType(dt.(arrow.ExtensionType).StorageType())
switch dt.(arrow.ExtensionType).ExtensionName() {
case "arrow.uuid":
return XdbcDataType_XDBC_GUID
default:
return ToXdbcDataType(dt.(arrow.ExtensionType).StorageType())
}
case arrow.DICTIONARY:
return ToXdbcDataType(dt.(*arrow.DictionaryType).ValueType)
case arrow.RUN_END_ENCODED:
Expand Down
147 changes: 147 additions & 0 deletions go/adbc/driver/internal/shared_utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package internal_test

import (
"testing"

"github.com/apache/arrow-adbc/go/adbc/driver/internal"
"github.com/apache/arrow-go/v18/arrow"
"github.com/apache/arrow-go/v18/arrow/extensions"
"github.com/stretchr/testify/require"
)

func TestToXdbcDataType(t *testing.T) {
tests := []struct {
name string
dataType arrow.DataType
expected internal.XdbcDataType
}{
{
name: "Nil type",
dataType: nil,
expected: internal.XdbcDataType_XDBC_UNKNOWN_TYPE,
},
{
name: "INT8",
dataType: arrow.PrimitiveTypes.Int8,
expected: internal.XdbcDataType_XDBC_TINYINT,
},
{
name: "INT16",
dataType: arrow.PrimitiveTypes.Int16,
expected: internal.XdbcDataType_XDBC_SMALLINT,
},
{
name: "INT32",
dataType: arrow.PrimitiveTypes.Int32,
expected: internal.XdbcDataType_XDBC_INTEGER,
},
{
name: "INT64",
dataType: arrow.PrimitiveTypes.Int64,
expected: internal.XdbcDataType_XDBC_BIGINT,
},
{
name: "FLOAT32",
dataType: arrow.PrimitiveTypes.Float32,
expected: internal.XdbcDataType_XDBC_FLOAT,
},
{
name: "String",
dataType: arrow.BinaryTypes.String,
expected: internal.XdbcDataType_XDBC_VARCHAR,
},
{
name: "Binary",
dataType: arrow.BinaryTypes.Binary,
expected: internal.XdbcDataType_XDBC_BINARY,
},
{
name: "Boolean",
dataType: arrow.FixedWidthTypes.Boolean,
expected: internal.XdbcDataType_XDBC_BIT,
},
{
name: "Date32",
dataType: arrow.FixedWidthTypes.Date32,
expected: internal.XdbcDataType_XDBC_DATE,
},
{
name: "Timestamp",
dataType: arrow.FixedWidthTypes.Timestamp_us,
expected: internal.XdbcDataType_XDBC_TIMESTAMP,
},
{
name: "UUID Extension Type",
dataType: extensions.NewUUIDType(),
expected: internal.XdbcDataType_XDBC_GUID,
},
{
name: "Bool8 Extension Type",
dataType: extensions.NewBool8Type(),
expected: internal.XdbcDataType_XDBC_TINYINT,
},
{
name: "Opaque Extension Type",
dataType: extensions.NewOpaqueType(arrow.BinaryTypes.String, "test.opaque", "test_vendor"),
expected: internal.XdbcDataType_XDBC_VARCHAR,
},
{
name: "Bool8 Extension Type",
dataType: extensions.NewBool8Type(),
expected: internal.XdbcDataType_XDBC_TINYINT,
},
{
name: "Bool8 Extension Type",
dataType: extensions.NewJSONType(arrow.BinaryTypes.String),
expected: internal.XdbcDataType_XDBC_TINYINT,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := internal.ToXdbcDataType(tt.dataType)
require.Equal(t, tt.expected, result, "Expected XDBC type %v for %s, got %v", tt.expected, tt.name, result)
})
}
}

func TestToXdbcDataType_ExtensionTypes(t *testing.T) {

t.Run("JSON falls back to storage type", func(t *testing.T) {
jsonType, err := extensions.NewJSONType(arrow.BinaryTypes.String)
require.NoError(t, err)
require.NotNil(t, jsonType)
require.Equal(t, "arrow.json", jsonType.ExtensionName())

// JSON storage type is String, which maps to VARCHAR
xdbcType := internal.ToXdbcDataType(jsonType)
require.Equal(t, internal.XdbcDataType_XDBC_VARCHAR, xdbcType)
})

t.Run("Opaque falls back to storage type", func(t *testing.T) {
opaqueType := extensions.NewOpaqueType(arrow.BinaryTypes.Binary, "test.opaque", "test_vendor")
require.NotNil(t, opaqueType)
require.Equal(t, "arrow.opaque", opaqueType.ExtensionName())

// Opaque storage type is Binary, which maps to BINARY
xdbcType := internal.ToXdbcDataType(opaqueType)
require.Equal(t, internal.XdbcDataType_XDBC_BINARY, xdbcType)
})
}
Loading