Skip to content

Commit af1ed38

Browse files
committed
feat: improves some type mapping
1 parent 18aa9d4 commit af1ed38

File tree

1 file changed

+128
-54
lines changed

1 file changed

+128
-54
lines changed

datafusion-postgres/src/datatypes.rs

Lines changed: 128 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -21,34 +21,52 @@ pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
2121
DataType::Int16 | DataType::UInt16 => Type::INT2,
2222
DataType::Int32 | DataType::UInt32 => Type::INT4,
2323
DataType::Int64 | DataType::UInt64 => Type::INT8,
24-
DataType::Timestamp(_, _) => Type::TIMESTAMP,
24+
DataType::Timestamp(_, tz) => {
25+
if tz.is_some() {
26+
Type::TIMESTAMPTZ
27+
} else {
28+
Type::TIMESTAMP
29+
}
30+
}
2531
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
2632
DataType::Date32 | DataType::Date64 => Type::DATE,
27-
DataType::Binary => Type::BYTEA,
28-
DataType::Float32 => Type::FLOAT4,
33+
DataType::Interval(_) => Type::INTERVAL,
34+
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
35+
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
2936
DataType::Float64 => Type::FLOAT8,
3037
DataType::Utf8 => Type::VARCHAR,
31-
DataType::List(field) => match field.data_type() {
32-
DataType::Boolean => Type::BOOL_ARRAY,
33-
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
34-
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
35-
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
36-
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
37-
DataType::Timestamp(_, _) => Type::TIMESTAMP_ARRAY,
38-
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
39-
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
40-
DataType::Binary => Type::BYTEA_ARRAY,
41-
DataType::Float32 => Type::FLOAT4_ARRAY,
42-
DataType::Float64 => Type::FLOAT8_ARRAY,
43-
DataType::Utf8 => Type::VARCHAR_ARRAY,
44-
list_type => {
45-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
46-
"ERROR".to_owned(),
47-
"XX000".to_owned(),
48-
format!("Unsupported List Datatype {list_type}"),
49-
))));
38+
DataType::LargeUtf8 => Type::TEXT,
39+
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
40+
match field.data_type() {
41+
DataType::Boolean => Type::BOOL_ARRAY,
42+
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
43+
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
44+
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
45+
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
46+
DataType::Timestamp(_, tz) => {
47+
if tz.is_some() {
48+
Type::TIMESTAMPTZ_ARRAY
49+
} else {
50+
Type::TIMESTAMP_ARRAY
51+
}
52+
}
53+
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
54+
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
55+
DataType::Interval(_) => Type::INTERVAL_ARRAY,
56+
DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY,
57+
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
58+
DataType::Float64 => Type::FLOAT8_ARRAY,
59+
DataType::Utf8 => Type::VARCHAR_ARRAY,
60+
DataType::LargeUtf8 => Type::TEXT_ARRAY,
61+
list_type => {
62+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
63+
"ERROR".to_owned(),
64+
"XX000".to_owned(),
65+
format!("Unsupported List Datatype {list_type}"),
66+
))));
67+
}
5068
}
51-
},
69+
}
5270
_ => {
5371
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
5472
"ERROR".to_owned(),
@@ -147,6 +165,27 @@ fn get_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> &str {
147165
.value(idx)
148166
}
149167

168+
fn get_large_utf8_value(arr: &Arc<dyn Array>, idx: usize) -> &str {
169+
arr.as_any()
170+
.downcast_ref::<LargeStringArray>()
171+
.unwrap()
172+
.value(idx)
173+
}
174+
175+
fn get_binary_value(arr: &Arc<dyn Array>, idx: usize) -> &[u8] {
176+
arr.as_any()
177+
.downcast_ref::<BinaryArray>()
178+
.unwrap()
179+
.value(idx)
180+
}
181+
182+
fn get_large_binary_value(arr: &Arc<dyn Array>, idx: usize) -> &[u8] {
183+
arr.as_any()
184+
.downcast_ref::<LargeBinaryArray>()
185+
.unwrap()
186+
.value(idx)
187+
}
188+
150189
fn get_date32_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDate> {
151190
arr.as_any()
152191
.downcast_ref::<Date32Array>()
@@ -246,6 +285,9 @@ fn encode_value(
246285
DataType::Float32 => encoder.encode_field(&get_f32_value(arr, idx))?,
247286
DataType::Float64 => encoder.encode_field(&get_f64_value(arr, idx))?,
248287
DataType::Utf8 => encoder.encode_field(&get_utf8_value(arr, idx))?,
288+
DataType::LargeUtf8 => encoder.encode_field(&get_large_utf8_value(arr, idx))?,
289+
DataType::Binary => encoder.encode_field(&get_binary_value(arr, idx))?,
290+
DataType::LargeBinary => encoder.encode_field(&get_large_binary_value(arr, idx))?,
249291
DataType::Date32 => encoder.encode_field(&get_date32_value(arr, idx))?,
250292
DataType::Date64 => encoder.encode_field(&get_date64_value(arr, idx))?,
251293
DataType::Time32(unit) => match unit {
@@ -262,45 +304,77 @@ fn encode_value(
262304
TimeUnit::Nanosecond => encoder.encode_field(&get_time64_nanosecond_value(arr, idx))?,
263305
_ => {}
264306
},
265-
DataType::Timestamp(unit, _) => match unit {
266-
TimeUnit::Second => encoder.encode_field(&get_timestamp_second_value(arr, idx))?,
307+
DataType::Timestamp(unit, timezone) => match unit {
308+
TimeUnit::Second => {
309+
let value = get_timestamp_second_value(arr, idx);
310+
if timezone.is_some() {
311+
let value_tz = value.map(|datetime| datetime.and_utc());
312+
313+
encoder.encode_field(&value_tz)?;
314+
} else {
315+
encoder.encode_field(&value)?
316+
}
317+
}
267318
TimeUnit::Millisecond => {
268-
encoder.encode_field(&get_timestamp_millisecond_value(arr, idx))?
319+
let value = get_timestamp_millisecond_value(arr, idx);
320+
if timezone.is_some() {
321+
let value_tz = value.map(|datetime| datetime.and_utc());
322+
323+
encoder.encode_field(&value_tz)?;
324+
} else {
325+
encoder.encode_field(&value)?
326+
}
269327
}
270328
TimeUnit::Microsecond => {
271-
encoder.encode_field(&get_timestamp_microsecond_value(arr, idx))?
329+
let value = get_timestamp_microsecond_value(arr, idx);
330+
if timezone.is_some() {
331+
let value_tz = value.map(|datetime| datetime.and_utc());
332+
333+
encoder.encode_field(&value_tz)?;
334+
} else {
335+
encoder.encode_field(&value)?
336+
}
272337
}
273338
TimeUnit::Nanosecond => {
274-
encoder.encode_field(&get_timestamp_nanosecond_value(arr, idx))?
339+
let value = get_timestamp_nanosecond_value(arr, idx);
340+
if timezone.is_some() {
341+
let value_tz = value.map(|datetime| datetime.and_utc());
342+
343+
encoder.encode_field(&value_tz)?;
344+
} else {
345+
encoder.encode_field(&value)?
346+
}
275347
}
276348
},
277-
DataType::List(field) => match field.data_type() {
278-
DataType::Null => encoder.encode_field(&None::<i8>)?,
279-
DataType::Boolean => encoder.encode_field(&get_bool_list_value(arr, idx))?,
280-
DataType::Int8 => encoder.encode_field(&get_i8_list_value(arr, idx))?,
281-
DataType::Int16 => encoder.encode_field(&get_i16_list_value(arr, idx))?,
282-
DataType::Int32 => encoder.encode_field(&get_i32_list_value(arr, idx))?,
283-
DataType::Int64 => encoder.encode_field(&get_i64_list_value(arr, idx))?,
284-
DataType::UInt8 => encoder.encode_field(&get_u8_list_value(arr, idx))?,
285-
DataType::UInt16 => encoder.encode_field(&get_u16_list_value(arr, idx))?,
286-
DataType::UInt32 => encoder.encode_field(&get_u32_list_value(arr, idx))?,
287-
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
288-
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
289-
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
290-
DataType::Utf8 => encoder.encode_field(&get_utf8_list_value(arr, idx))?,
291-
292-
// TODO: more types
293-
list_type => {
294-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
295-
"ERROR".to_owned(),
296-
"XX000".to_owned(),
297-
format!(
298-
"Unsupported List Datatype {} and array {:?}",
299-
list_type, &arr
300-
),
301-
))))
349+
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
350+
match field.data_type() {
351+
DataType::Null => encoder.encode_field(&None::<i8>)?,
352+
DataType::Boolean => encoder.encode_field(&get_bool_list_value(arr, idx))?,
353+
DataType::Int8 => encoder.encode_field(&get_i8_list_value(arr, idx))?,
354+
DataType::Int16 => encoder.encode_field(&get_i16_list_value(arr, idx))?,
355+
DataType::Int32 => encoder.encode_field(&get_i32_list_value(arr, idx))?,
356+
DataType::Int64 => encoder.encode_field(&get_i64_list_value(arr, idx))?,
357+
DataType::UInt8 => encoder.encode_field(&get_u8_list_value(arr, idx))?,
358+
DataType::UInt16 => encoder.encode_field(&get_u16_list_value(arr, idx))?,
359+
DataType::UInt32 => encoder.encode_field(&get_u32_list_value(arr, idx))?,
360+
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
361+
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
362+
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
363+
DataType::Utf8 => encoder.encode_field(&get_utf8_list_value(arr, idx))?,
364+
365+
// TODO: more types
366+
list_type => {
367+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
368+
"ERROR".to_owned(),
369+
"XX000".to_owned(),
370+
format!(
371+
"Unsupported List Datatype {} and array {:?}",
372+
list_type, &arr
373+
),
374+
))))
375+
}
302376
}
303-
},
377+
}
304378
_ => {
305379
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
306380
"ERROR".to_owned(),

0 commit comments

Comments
 (0)