From d6a5b9de6565e49ee61d5a4fd7860389510d07f1 Mon Sep 17 00:00:00 2001
From: jhrotko
Date: Fri, 12 Dec 2025 15:42:22 +0000
Subject: [PATCH] register extension types in database

---
 .../driver/internal/driverbase/database.go    |  33 ++++
 .../driver/internal/driverbase/driver_test.go |  91 +++++++++++
 go/adbc/driver/internal/shared_utils.go       |   8 +-
 go/adbc/driver/internal/shared_utils_test.go  | 147 ++++++++++++++++++
 4 files changed, 278 insertions(+), 1 deletion(-)
 create mode 100644 go/adbc/driver/internal/shared_utils_test.go

diff --git a/go/adbc/driver/internal/driverbase/database.go b/go/adbc/driver/internal/driverbase/database.go
index 16a4edb807..122070a3da 100644
--- a/go/adbc/driver/internal/driverbase/database.go
+++ b/go/adbc/driver/internal/driverbase/database.go
@@ -27,6 +27,8 @@ import (
 	"time"
 
 	"github.com/apache/arrow-adbc/go/adbc"
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/extensions"
 	"github.com/apache/arrow-go/v18/arrow/memory"
 	"go.opentelemetry.io/otel"
 	"go.opentelemetry.io/otel/attribute"
@@ -79,6 +81,34 @@ var getExporterName = sync.OnceValue(func() string {
 	return os.Getenv(otelTracesExporter)
 })
 
+// registerExtensionTypes ensures that canonical Arrow extension types are registered.
+// It runs at most once per process (sync.OnceFunc) and is invoked from
+// NewDatabaseImplBase so that extension types like UUID are available to drivers.
+var registerExtensionTypes = sync.OnceFunc(func() {
+	// Importing arrow/extensions is expected to register the canonical types
+	// (UUID, Bool8, JSON, Opaque, Variant) from its init() function; this is
+	// a safety net that only fills in whatever is still missing.
+	canonicalTypes := []arrow.ExtensionType{
+		extensions.NewUUIDType(),
+		extensions.NewBool8Type(),
+		&extensions.JSONType{},
+		&extensions.OpaqueType{},
+		&extensions.VariantType{},
+	}
+
+	for _, extType := range canonicalTypes {
+		// RegisterExtensionType reports an error for a name that is already
+		// taken, so look the name up first instead of discarding that error.
+		if arrow.GetExtensionType(extType.ExtensionName()) != nil {
+			continue
+		}
+
+		// Best effort: the only expected failure is losing a race with
+		// another registration of the same name, which leaves it usable.
+		_ = arrow.RegisterExtensionType(extType)
+	}
+})
+
 // DatabaseImpl is an interface that drivers implement to provide
 // vendor-specific functionality.
 type DatabaseImpl interface {
@@ -115,6 +145,9 @@ type DatabaseImplBase struct {
 //   - driver is a DriverImplBase containing the common resources from the parent
 //     driver, allowing the Arrow allocator and error handler to be reused.
 func NewDatabaseImplBase(ctx context.Context, driver *DriverImplBase) (DatabaseImplBase, error) {
+	// Ensure extension types are registered before creating the database
+	registerExtensionTypes()
+
 	database := DatabaseImplBase{
 		Alloc:       driver.Alloc,
 		ErrorHelper: driver.ErrorHelper,
diff --git a/go/adbc/driver/internal/driverbase/driver_test.go b/go/adbc/driver/internal/driverbase/driver_test.go
index 1f1a24f5f9..495ce5a155 100644
--- a/go/adbc/driver/internal/driverbase/driver_test.go
+++ b/go/adbc/driver/internal/driverbase/driver_test.go
@@ -31,7 +31,9 @@ import (
 	"github.com/apache/arrow-adbc/go/adbc/validation"
 	"github.com/apache/arrow-go/v18/arrow"
 	"github.com/apache/arrow-go/v18/arrow/array"
+	"github.com/apache/arrow-go/v18/arrow/extensions"
 	"github.com/apache/arrow-go/v18/arrow/memory"
+	"github.com/google/uuid"
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/mock"
 	"github.com/stretchr/testify/require"
@@ -791,3 +793,92 @@ func TestRequiredList(t *testing.T) {
 	require.NoError(t, json.Unmarshal([]byte(`["d", "e", "f"]`), &v))
 	assert.Equal(t, driverbase.RequiredList([]string{"d", "e", "f"}), v)
 }
+
+func TestExtensionTypesRegistered(t *testing.T) {
+	// Test that extension types are properly registered when creating a database
+	alloc := memory.NewCheckedAllocator(memory.DefaultAllocator)
+	defer alloc.AssertSize(t, 0)
+
+	var handler MockedHandler
+	handler.On("Handle", mock.Anything, mock.Anything).Return(nil)
+
+	drv := NewDriver(alloc, &handler, false)
+	db, err := drv.NewDatabase(nil)
+	require.NoError(t, err)
+	defer validation.CheckedClose(t, db)
+
+	// Test 1: UUID extension type
+	t.Run("UUID", func(t *testing.T) {
+		schema := arrow.NewSchema([]arrow.Field{
+			{Name: "uuid_col", Type: extensions.NewUUIDType(), Nullable: true},
+		}, nil)
+
+		bldr := array.NewRecordBuilder(alloc, schema)
+		defer bldr.Release()
+
+		uuidBldr := bldr.Field(0).(*extensions.UUIDBuilder)
+		testUUID := uuid.New()
+		uuidBldr.Append(testUUID)
+		uuidBldr.AppendNull()
+
+		rec := bldr.NewRecordBatch()
+		defer rec.Release()
+
+		require.Equal(t, int64(2), rec.NumRows())
+		uuidArr := rec.Column(0).(*extensions.UUIDArray)
+		require.True(t, uuidArr.IsValid(0))
+		require.Equal(t, testUUID, uuidArr.Value(0))
+		require.False(t, uuidArr.IsValid(1))
+	})
+
+	// Test 2: Bool8 extension type
+	t.Run("Bool8", func(t *testing.T) {
+		schema := arrow.NewSchema([]arrow.Field{
+			{Name: "bool8_col", Type: extensions.NewBool8Type(), Nullable: true},
+		}, nil)
+
+		bldr := array.NewRecordBuilder(alloc, schema)
+		defer bldr.Release()
+
+		bool8Bldr := bldr.Field(0).(*extensions.Bool8Builder)
+		bool8Bldr.Append(true)
+		bool8Bldr.Append(false)
+		bool8Bldr.AppendNull()
+
+		rec := bldr.NewRecordBatch()
+		defer rec.Release()
+
+		require.Equal(t, int64(3), rec.NumRows())
+		bool8Arr := rec.Column(0).(*extensions.Bool8Array)
+		require.True(t, bool8Arr.IsValid(0))
+		require.True(t, bool8Arr.Value(0))
+		require.True(t, bool8Arr.IsValid(1))
+		require.False(t, bool8Arr.Value(1))
+		require.False(t, bool8Arr.IsValid(2))
+	})
+
+	// Test 3: Verify all canonical types are instantiable and registered
+	t.Run("AllTypesInstantiable", func(t *testing.T) {
+		// Every canonical extension type must report its canonical name and
+		// be resolvable by that name through the global Arrow type registry.
+		uuidType := extensions.NewUUIDType()
+		require.Equal(t, "arrow.uuid", uuidType.ExtensionName())
+		require.NotNil(t, arrow.GetExtensionType(uuidType.ExtensionName()))
+
+		bool8Type := extensions.NewBool8Type()
+		require.Equal(t, "arrow.bool8", bool8Type.ExtensionName())
+		require.NotNil(t, arrow.GetExtensionType(bool8Type.ExtensionName()))
+
+		jsonType := &extensions.JSONType{}
+		require.Equal(t, "arrow.json", jsonType.ExtensionName())
+		require.NotNil(t, arrow.GetExtensionType(jsonType.ExtensionName()))
+
+		opaqueType := extensions.NewOpaqueType(arrow.BinaryTypes.String, "test.opaque", "test_vendor")
+		require.Equal(t, "arrow.opaque", opaqueType.ExtensionName())
+		require.NotNil(t, arrow.GetExtensionType(opaqueType.ExtensionName()))
+
+		variantType := &extensions.VariantType{}
+		require.Equal(t, "parquet.variant", variantType.ExtensionName())
+		require.NotNil(t, arrow.GetExtensionType(variantType.ExtensionName()))
+	})
+}
diff --git a/go/adbc/driver/internal/shared_utils.go b/go/adbc/driver/internal/shared_utils.go
index 6ea07fcbed..90724362b6 100644
--- a/go/adbc/driver/internal/shared_utils.go
+++ b/go/adbc/driver/internal/shared_utils.go
@@ -694,6 +694,7 @@ const (
 	XdbcDataType_XDBC_BIT           XdbcDataType = -7
 	XdbcDataType_XDBC_WCHAR         XdbcDataType = -8
 	XdbcDataType_XDBC_WVARCHAR      XdbcDataType = -9
+	XdbcDataType_XDBC_GUID          XdbcDataType = -11
 )
 
 func ToXdbcDataType(dt arrow.DataType) (xdbcType XdbcDataType) {
@@ -703,7 +704,12 @@ func ToXdbcDataType(dt arrow.DataType) (xdbcType XdbcDataType) {
 
 	switch dt.ID() {
 	case arrow.EXTENSION:
-		return ToXdbcDataType(dt.(arrow.ExtensionType).StorageType())
+		ext := dt.(arrow.ExtensionType)
+		// UUID has a dedicated XDBC type; other extensions use their storage type.
+		if ext.ExtensionName() == "arrow.uuid" {
+			return XdbcDataType_XDBC_GUID
+		}
+		return ToXdbcDataType(ext.StorageType())
 	case arrow.DICTIONARY:
 		return ToXdbcDataType(dt.(*arrow.DictionaryType).ValueType)
 	case arrow.RUN_END_ENCODED:
diff --git a/go/adbc/driver/internal/shared_utils_test.go b/go/adbc/driver/internal/shared_utils_test.go
new file mode 100644
index 0000000000..4c25cafe32
--- /dev/null
+++ b/go/adbc/driver/internal/shared_utils_test.go
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package internal_test
+
+import (
+	"testing"
+
+	"github.com/apache/arrow-adbc/go/adbc/driver/internal"
+	"github.com/apache/arrow-go/v18/arrow"
+	"github.com/apache/arrow-go/v18/arrow/extensions"
+	"github.com/stretchr/testify/require"
+)
+
+func TestToXdbcDataType(t *testing.T) {
+	tests := []struct {
+		name     string
+		dataType arrow.DataType
+		expected internal.XdbcDataType
+	}{
+		{
+			name:     "Nil type",
+			dataType: nil,
+			expected: internal.XdbcDataType_XDBC_UNKNOWN_TYPE,
+		},
+		{
+			name:     "INT8",
+			dataType: arrow.PrimitiveTypes.Int8,
+			expected: internal.XdbcDataType_XDBC_TINYINT,
+		},
+		{
+			name:     "INT16",
+			dataType: arrow.PrimitiveTypes.Int16,
+			expected: internal.XdbcDataType_XDBC_SMALLINT,
+		},
+		{
+			name:     "INT32",
+			dataType: arrow.PrimitiveTypes.Int32,
+			expected: internal.XdbcDataType_XDBC_INTEGER,
+		},
+		{
+			name:     "INT64",
+			dataType: arrow.PrimitiveTypes.Int64,
+			expected: internal.XdbcDataType_XDBC_BIGINT,
+		},
+		{
+			name:     "FLOAT32",
+			dataType: arrow.PrimitiveTypes.Float32,
+			expected: internal.XdbcDataType_XDBC_FLOAT,
+		},
+		{
+			name:     "String",
+			dataType: arrow.BinaryTypes.String,
+			expected: internal.XdbcDataType_XDBC_VARCHAR,
+		},
+		{
+			name:     "Binary",
+			dataType: arrow.BinaryTypes.Binary,
+			expected: internal.XdbcDataType_XDBC_BINARY,
+		},
+		{
+			name:     "Boolean",
+			dataType: arrow.FixedWidthTypes.Boolean,
+			expected: internal.XdbcDataType_XDBC_BIT,
+		},
+		{
+			name:     "Date32",
+			dataType: arrow.FixedWidthTypes.Date32,
+			expected: internal.XdbcDataType_XDBC_DATE,
+		},
+		{
+			name:     "Timestamp",
+			dataType: arrow.FixedWidthTypes.Timestamp_us,
+			expected: internal.XdbcDataType_XDBC_TIMESTAMP,
+		},
+		{
+			name:     "UUID Extension Type",
+			dataType: extensions.NewUUIDType(),
+			expected: internal.XdbcDataType_XDBC_GUID,
+		},
+		{
+			name:     "Bool8 Extension Type",
+			dataType: extensions.NewBool8Type(),
+			expected: internal.XdbcDataType_XDBC_TINYINT,
+		},
+		{
+			name:     "Opaque Extension Type",
+			dataType: extensions.NewOpaqueType(arrow.BinaryTypes.String, "test.opaque", "test_vendor"),
+			expected: internal.XdbcDataType_XDBC_VARCHAR,
+		},
+		{
+			name:     "Dictionary-encoded String",
+			dataType: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: arrow.BinaryTypes.String},
+			expected: internal.XdbcDataType_XDBC_VARCHAR,
+		},
+		{
+			name:     "Dictionary-encoded UUID Extension Type",
+			dataType: &arrow.DictionaryType{IndexType: arrow.PrimitiveTypes.Int32, ValueType: extensions.NewUUIDType()},
+			expected: internal.XdbcDataType_XDBC_GUID,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			result := internal.ToXdbcDataType(tt.dataType)
+			require.Equal(t, tt.expected, result, "Expected XDBC type %v for %s, got %v", tt.expected, tt.name, result)
+		})
+	}
+}
+
+func TestToXdbcDataType_ExtensionTypes(t *testing.T) {
+
+	t.Run("JSON falls back to storage type", func(t *testing.T) {
+		jsonType, err := extensions.NewJSONType(arrow.BinaryTypes.String)
+		require.NoError(t, err)
+		require.NotNil(t, jsonType)
+		require.Equal(t, "arrow.json", jsonType.ExtensionName())
+
+		// JSON storage type is String, which maps to VARCHAR
+		xdbcType := internal.ToXdbcDataType(jsonType)
+		require.Equal(t, internal.XdbcDataType_XDBC_VARCHAR, xdbcType)
+	})
+
+	t.Run("Opaque falls back to storage type", func(t *testing.T) {
+		opaqueType := extensions.NewOpaqueType(arrow.BinaryTypes.Binary, "test.opaque", "test_vendor")
+		require.NotNil(t, opaqueType)
+		require.Equal(t, "arrow.opaque", opaqueType.ExtensionName())
+
+		// Opaque storage type is Binary, which maps to BINARY
+		xdbcType := internal.ToXdbcDataType(opaqueType)
+		require.Equal(t, internal.XdbcDataType_XDBC_BINARY, xdbcType)
+	})
+}