Skip to content

Commit 8ae14b2

Browse files
committed
Async bind callback
1 parent c4fd303 commit 8ae14b2

File tree

4 files changed

+161
-115
lines changed

4 files changed

+161
-115
lines changed

DuckDB.NET.Data/DuckDBConnection.TableFunction.cs

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
using System.Diagnostics.CodeAnalysis;
1010
using System.Runtime.CompilerServices;
1111
using System.Runtime.InteropServices;
12+
using System.Threading;
13+
using System.Threading.Tasks;
1214

1315
namespace DuckDB.NET.Data;
1416

@@ -20,55 +22,55 @@ partial class DuckDBConnection
2022
{
2123
#if NET8_0_OR_GREATER
2224
[Experimental("DuckDBNET001")]
23-
public void RegisterTableFunction<T>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
25+
public void RegisterTableFunction<T>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
2426
{
2527
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T));
2628
}
2729

2830
[Experimental("DuckDBNET001")]
29-
public void RegisterTableFunction<T1, T2>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
31+
public void RegisterTableFunction<T1, T2>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
3032
{
3133
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2));
3234
}
3335

3436
[Experimental("DuckDBNET001")]
35-
public void RegisterTableFunction<T1, T2, T3>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
37+
public void RegisterTableFunction<T1, T2, T3>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
3638
{
3739
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3));
3840
}
3941

4042
[Experimental("DuckDBNET001")]
41-
public void RegisterTableFunction<T1, T2, T3, T4>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
43+
public void RegisterTableFunction<T1, T2, T3, T4>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
4244
{
4345
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4));
4446
}
4547

4648
[Experimental("DuckDBNET001")]
47-
public void RegisterTableFunction<T1, T2, T3, T4, T5>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
49+
public void RegisterTableFunction<T1, T2, T3, T4, T5>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
4850
{
4951
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5));
5052
}
5153

5254
[Experimental("DuckDBNET001")]
53-
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
55+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
5456
{
5557
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6));
5658
}
5759

5860
[Experimental("DuckDBNET001")]
59-
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
61+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
6062
{
6163
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7));
6264
}
6365

6466
[Experimental("DuckDBNET001")]
65-
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7, T8>(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
67+
public void RegisterTableFunction<T1, T2, T3, T4, T5, T6, T7, T8>(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback)
6668
{
6769
RegisterTableFunctionInternal(name, resultCallback, mapperCallback, typeof(T1), typeof(T2), typeof(T3), typeof(T4), typeof(T5), typeof(T6), typeof(T7), typeof(T8));
6870
}
6971

7072
[Experimental("DuckDBNET001")]
71-
private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback, params Type[] parameterTypes)
73+
private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> resultCallback, Action<object?, IDuckDBDataWriter[], ulong> mapperCallback, params Type[] parameterTypes)
7274
{
7375
var function = NativeMethods.TableFunction.DuckDBCreateTableFunction();
7476
using (var handle = name.ToUnmanagedString())
@@ -100,7 +102,7 @@ private unsafe void RegisterTableFunctionInternal(string name, Func<IReadOnlyLis
100102
}
101103

102104
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]
103-
public static unsafe void Bind(IntPtr info)
105+
public static async void Bind(IntPtr info)
104106
{
105107
var handle = GCHandle.FromIntPtr(NativeMethods.TableFunction.DuckDBBindGetExtraInfo(info));
106108

@@ -117,7 +119,7 @@ public static unsafe void Bind(IntPtr info)
117119
parameters[i] = value;
118120
}
119121

120-
var tableFunctionData = functionInfo.Bind(parameters);
122+
var tableFunctionData = await functionInfo.Bind(parameters);
121123

122124
foreach (var parameter in parameters)
123125
{
@@ -132,7 +134,10 @@ public static unsafe void Bind(IntPtr info)
132134

133135
var bindData = new TableFunctionBindData(tableFunctionData.Columns, tableFunctionData.Data.GetEnumerator());
134136

135-
NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
137+
unsafe
138+
{
139+
NativeMethods.TableFunction.DuckDBBindSetBindData(info, bindData.ToHandle(), &DestroyExtraInfo);
140+
}
136141
}
137142

138143
[UnmanagedCallersOnly(CallConvs = [typeof(CallConvCdecl)])]

DuckDB.NET.Data/Internal/TableFunctionInfo.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
using System;
44
using System.Collections;
55
using System.Collections.Generic;
6+
using System.Threading.Tasks;
67

78
namespace DuckDB.NET.Data.Internal;
89

9-
class TableFunctionInfo(Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> bind, Action<object?, VectorDataWriterBase[], ulong> mapper)
10+
class TableFunctionInfo(Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> bind, Action<object?, VectorDataWriterBase[], ulong> mapper)
1011
{
11-
public Func<IReadOnlyList<IDuckDBValueReader>, TableFunction> Bind { get; private set; } = bind;
12+
public Func<IReadOnlyList<IDuckDBValueReader>, Task<TableFunction>> Bind { get; private set; } = bind;
1213
public Action<object?, VectorDataWriterBase[], ulong> Mapper { get; private set; } = mapper;
1314
}
1415

DuckDB.NET.Test/TableFunctionTests.cs

Lines changed: 140 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
using System.Collections.Generic;
66
using System.Diagnostics.CodeAnalysis;
77
using System.Linq;
8+
using Octokit;
89
using Xunit;
910

1011
namespace DuckDB.NET.Test;
@@ -15,7 +16,7 @@ public class TableFunctionTests(DuckDBDatabaseFixture db) : DuckDBTestBase(db)
1516
[Fact]
1617
public void RegisterTableFunctionWithOneParameter()
1718
{
18-
Connection.RegisterTableFunction<int>("demo", (parameters) =>
19+
Connection.RegisterTableFunction<int>("demo", async (parameters) =>
1920
{
2021
var value = parameters[0].GetValue<int>();
2122

@@ -33,119 +34,157 @@ public void RegisterTableFunctionWithOneParameter()
3334
}
3435

3536
[Fact]
36-
public void RegisterTableFunctionWithTwoParameterTwoColumns()
37+
public void RegisterTableFunctionWithTwoParameters()
3738
{
38-
var count = 50;
39-
40-
Connection.RegisterTableFunction<short, string>("demo2", (parameters) =>
41-
{
42-
var start = parameters[0].GetValue<short>();
43-
var prefix = parameters[1].GetValue<string>();
44-
45-
return new TableFunction(new List<ColumnInfo>()
46-
{
47-
new ColumnInfo("foo", typeof(int)),
48-
new ColumnInfo("bar", typeof(string)),
49-
}, Enumerable.Range(start, count).Select(index => KeyValuePair.Create(index, prefix + index)));
50-
}, (item, writers, rowIndex) =>
39+
Connection.RegisterTableFunction<string, int>("github_search", async (parameters) =>
5140
{
52-
var pair = (KeyValuePair<int, string>)item;
53-
writers[0].WriteValue(pair.Key, rowIndex);
54-
writers[1].WriteValue(pair.Value, rowIndex);
55-
});
56-
57-
var data = Connection.Query<(int, string)>($"SELECT * FROM demo2(30::SmallInt, 'DuckDB');").ToList();
58-
59-
data.Select(tuple => tuple.Item1).Should().BeEquivalentTo(Enumerable.Range(30, count));
60-
data.Select(tuple => tuple.Item2).Should().BeEquivalentTo(Enumerable.Range(30, count).Select(i => $"DuckDB{i}"));
61-
}
41+
var term = parameters[0].GetValue<string>();
42+
var stars = parameters[1].GetValue<int>();
6243

63-
[Fact]
64-
public void RegisterTableFunctionWithThreeParameters()
65-
{
66-
var count = 30;
67-
var startDate = new DateTime(2024, 11, 6);
68-
var minutesParam = 10;
69-
var secondsParam = 2.5;
70-
71-
Connection.RegisterTableFunction<DateTime, long, double>("demo3", (parameters) =>
72-
{
73-
var date = parameters[0].GetValue<DateTime>();
74-
var minutes = parameters[1].GetValue<long>();
75-
var seconds = parameters[2].GetValue<double>();
44+
var client = new GitHubClient(new ProductHeaderValue("DuckDB-Table-Valued-Function"));
7645

77-
return new TableFunction(new List<ColumnInfo>()
46+
var request = new SearchRepositoriesRequest(term)
7847
{
79-
new ColumnInfo("foo", typeof(DateTime)),
80-
}, Enumerable.Range(0, count).Select(i => date.AddDays(i).AddMinutes(minutes).AddSeconds(seconds)));
81-
}, (item, writers, rowIndex) =>
82-
{
83-
writers[0].WriteValue((DateTime)item, rowIndex);
84-
});
85-
86-
var data = Connection.Query<DateTime>($"SELECT * FROM demo3('2024-11-06'::TIMESTAMP, 10, 2.5 );").ToList();
48+
Stars = new Octokit.Range(stars, SearchQualifierOperator.GreaterThan)
49+
};
8750

88-
var dateTimes = Enumerable.Range(0, count).Select(i => startDate.AddDays(i).AddMinutes(minutesParam).AddSeconds(secondsParam));
89-
data.Should().BeEquivalentTo(dateTimes);
90-
}
51+
var result = await client.Search.SearchRepo(request);
9152

92-
[Fact]
93-
public void RegisterTableFunctionWithFourParameters()
94-
{
95-
var guid = Guid.NewGuid();
96-
97-
Connection.RegisterTableFunction<bool, decimal, byte, Guid>("demo4", (parameters) =>
98-
{
99-
var param1 = parameters[0].GetValue<bool>();
100-
var param2 = parameters[1].GetValue<decimal>();
101-
var param3 = parameters[2].GetValue<byte>();
102-
var param4 = parameters[3].GetValue<Guid>();
103-
104-
var enumerable = param4.ToByteArray(param1).Append(param3);
105-
10653
return new TableFunction(new List<ColumnInfo>()
10754
{
108-
new ColumnInfo("foo", typeof(byte)),
109-
}, enumerable);
55+
new ColumnInfo("name", typeof(string)),
56+
new ColumnInfo("description", typeof(string)),
57+
new ColumnInfo("stargazers", typeof(int)),
58+
new ColumnInfo("url", typeof(string)),
59+
new ColumnInfo("owner", typeof(string)),
60+
}, result.Items);
11061
}, (item, writers, rowIndex) =>
11162
{
112-
writers[0].WriteValue((byte)item, rowIndex);
63+
var repo = (Repository)item;
64+
writers[0].WriteValue(repo.Name, rowIndex);
65+
writers[1].WriteValue(repo.Description, rowIndex);
66+
writers[2].WriteValue(repo.StargazersCount, rowIndex);
67+
writers[3].WriteValue(repo.Url, rowIndex);
68+
writers[4].WriteValue(repo.Owner.Login, rowIndex);
11369
});
11470

115-
var data = Connection.Query<byte>($"SELECT * FROM demo4(false, 10::DECIMAL(18, 3), 4::UTINYINT, '{guid}'::UUID );").ToList();
116-
117-
var bytes = guid.ToByteArray(false).Append((byte)4);
118-
data.Should().BeEquivalentTo(bytes);
71+
var data = Connection.Query<(string, string, int, string, string)>("SELECT * FROM github_search('duckdb', 400);").ToList();
11972
}
12073

121-
[Fact]
122-
public void RegisterTableFunctionWithEmptyResult()
123-
{
124-
Connection.RegisterTableFunction<sbyte, ushort, uint, ulong, float>("demo5", (parameters) =>
125-
{
126-
var param1 = parameters[0].GetValue<sbyte>();
127-
var param2 = parameters[1].GetValue<ushort>();
128-
var param3 = parameters[2].GetValue<uint>();
129-
var param4 = parameters[3].GetValue<ulong>();
130-
var param5 = parameters[4].GetValue<float>();
131-
132-
param1.Should().Be(1);
133-
param2.Should().Be(2);
134-
param3.Should().Be(3);
135-
param4.Should().Be(4);
136-
param5.Should().Be(5.6f);
137-
138-
return new TableFunction(new List<ColumnInfo>()
139-
{
140-
new ColumnInfo("foo", typeof(int)),
141-
}, Enumerable.Empty<int>());
142-
}, (item, writers, rowIndex) =>
143-
{
144-
writers[0].WriteValue((int)item, rowIndex);
145-
});
146-
147-
var data = Connection.Query<int>($"SELECT * FROM demo5(1::TINYINT, 2::USMALLINT, 3::UINTEGER, 4::UBIGINT, 5.6);").ToList();
148-
149-
data.Should().BeEquivalentTo(Enumerable.Empty<int>());
150-
}
74+
//[Fact]
75+
//public void RegisterTableFunctionWithTwoParameterTwoColumns()
76+
//{
77+
// var count = 50;
78+
79+
// Connection.RegisterTableFunction<short, string>("demo2", (parameters) =>
80+
// {
81+
// var start = parameters[0].GetValue<short>();
82+
// var prefix = parameters[1].GetValue<string>();
83+
84+
// return new TableFunction(new List<ColumnInfo>()
85+
// {
86+
// new ColumnInfo("foo", typeof(int)),
87+
// new ColumnInfo("bar", typeof(string)),
88+
// }, Enumerable.Range(start, count).Select(index => KeyValuePair.Create(index, prefix + index)));
89+
// }, (item, writers, rowIndex) =>
90+
// {
91+
// var pair = (KeyValuePair<int, string>)item;
92+
// writers[0].WriteValue(pair.Key, rowIndex);
93+
// writers[1].WriteValue(pair.Value, rowIndex);
94+
// });
95+
96+
// var data = Connection.Query<(int, string)>($"SELECT * FROM demo2(30::SmallInt, 'DuckDB');").ToList();
97+
98+
// data.Select(tuple => tuple.Item1).Should().BeEquivalentTo(Enumerable.Range(30, count));
99+
// data.Select(tuple => tuple.Item2).Should().BeEquivalentTo(Enumerable.Range(30, count).Select(i => $"DuckDB{i}"));
100+
//}
101+
102+
//[Fact]
103+
//public void RegisterTableFunctionWithThreeParameters()
104+
//{
105+
// var count = 30;
106+
// var startDate = new DateTime(2024, 11, 6);
107+
// var minutesParam = 10;
108+
// var secondsParam = 2.5;
109+
110+
// Connection.RegisterTableFunction<DateTime, long, double>("demo3", (parameters) =>
111+
// {
112+
// var date = parameters[0].GetValue<DateTime>();
113+
// var minutes = parameters[1].GetValue<long>();
114+
// var seconds = parameters[2].GetValue<double>();
115+
116+
// return new TableFunction(new List<ColumnInfo>()
117+
// {
118+
// new ColumnInfo("foo", typeof(DateTime)),
119+
// }, Enumerable.Range(0, count).Select(i => date.AddDays(i).AddMinutes(minutes).AddSeconds(seconds)));
120+
// }, (item, writers, rowIndex) =>
121+
// {
122+
// writers[0].WriteValue((DateTime)item, rowIndex);
123+
// });
124+
125+
// var data = Connection.Query<DateTime>($"SELECT * FROM demo3('2024-11-06'::TIMESTAMP, 10, 2.5 );").ToList();
126+
127+
// var dateTimes = Enumerable.Range(0, count).Select(i => startDate.AddDays(i).AddMinutes(minutesParam).AddSeconds(secondsParam));
128+
// data.Should().BeEquivalentTo(dateTimes);
129+
//}
130+
131+
//[Fact]
132+
//public void RegisterTableFunctionWithFourParameters()
133+
//{
134+
// var guid = Guid.NewGuid();
135+
136+
// Connection.RegisterTableFunction<bool, decimal, byte, Guid>("demo4", (parameters) =>
137+
// {
138+
// var param1 = parameters[0].GetValue<bool>();
139+
// var param2 = parameters[1].GetValue<decimal>();
140+
// var param3 = parameters[2].GetValue<byte>();
141+
// var param4 = parameters[3].GetValue<Guid>();
142+
143+
// var enumerable = param4.ToByteArray(param1).Append(param3);
144+
145+
// return new TableFunction(new List<ColumnInfo>()
146+
// {
147+
// new ColumnInfo("foo", typeof(byte)),
148+
// }, enumerable);
149+
// }, (item, writers, rowIndex) =>
150+
// {
151+
// writers[0].WriteValue((byte)item, rowIndex);
152+
// });
153+
154+
// var data = Connection.Query<byte>($"SELECT * FROM demo4(false, 10::DECIMAL(18, 3), 4::UTINYINT, '{guid}'::UUID );").ToList();
155+
156+
// var bytes = guid.ToByteArray(false).Append((byte)4);
157+
// data.Should().BeEquivalentTo(bytes);
158+
//}
159+
160+
//[Fact]
161+
//public void RegisterTableFunctionWithEmptyResult()
162+
//{
163+
// Connection.RegisterTableFunction<sbyte, ushort, uint, ulong, float>("demo5", (parameters) =>
164+
// {
165+
// var param1 = parameters[0].GetValue<sbyte>();
166+
// var param2 = parameters[1].GetValue<ushort>();
167+
// var param3 = parameters[2].GetValue<uint>();
168+
// var param4 = parameters[3].GetValue<ulong>();
169+
// var param5 = parameters[4].GetValue<float>();
170+
171+
// param1.Should().Be(1);
172+
// param2.Should().Be(2);
173+
// param3.Should().Be(3);
174+
// param4.Should().Be(4);
175+
// param5.Should().Be(5.6f);
176+
177+
// return new TableFunction(new List<ColumnInfo>()
178+
// {
179+
// new ColumnInfo("foo", typeof(int)),
180+
// }, Enumerable.Empty<int>());
181+
// }, (item, writers, rowIndex) =>
182+
// {
183+
// writers[0].WriteValue((int)item, rowIndex);
184+
// });
185+
186+
// var data = Connection.Query<int>($"SELECT * FROM demo5(1::TINYINT, 2::USMALLINT, 3::UINTEGER, 4::UBIGINT, 5.6);").ToList();
187+
188+
// data.Should().BeEquivalentTo(Enumerable.Empty<int>());
189+
//}
151190
}

0 commit comments

Comments
 (0)