From 7d789efb6c4ccc2380d2a780682bcef104b9c7af Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 8 Aug 2024 23:46:52 +0100 Subject: [PATCH] make json_get_int and json_get_float more lax --- src/common.rs | 8 +++++++- src/json_get_float.rs | 14 ++++++++------ src/json_get_int.rs | 15 ++++++++------- src/rewrite.rs | 2 +- tests/main.rs | 22 ++++++++++++++++++++++ 5 files changed, 46 insertions(+), 15 deletions(-) diff --git a/src/common.rs b/src/common.rs index 9cbee92..75ef382 100644 --- a/src/common.rs +++ b/src/common.rs @@ -4,7 +4,7 @@ use arrow::array::{Array, ArrayRef, Int64Array, LargeStringArray, StringArray, U use arrow_schema::DataType; use datafusion_common::{exec_err, plan_err, Result as DataFusionResult, ScalarValue}; use datafusion_expr::ColumnarValue; -use jiter::{Jiter, JiterError, Peek}; +use jiter::{Jiter, JiterError, JsonError, Peek}; use crate::common_union::{is_json_union, json_from_union_scalar, nested_json_array}; @@ -230,6 +230,12 @@ impl From for GetError { } } +impl From for GetError { + fn from(_: JsonError) -> Self { + GetError + } +} + impl From for GetError { fn from(_: Utf8Error) -> Self { GetError diff --git a/src/json_get_float.rs b/src/json_get_float.rs index bed8c67..6ca3b9f 100644 --- a/src/json_get_float.rs +++ b/src/json_get_float.rs @@ -65,7 +65,8 @@ impl ScalarUDFImpl for JsonGetFloat { fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { - match peek { + let n = match peek { + // Peek::String => NumberAny::try_from(jiter.next_bytes()?)?, // numbers are represented by everything else in peek, hence doing it this way Peek::Null | Peek::True @@ -75,11 +76,12 @@ fn jiter_json_get_float(json_data: Option<&str>, path: &[JsonPath]) -> Result get_err!(), - _ => match jiter.known_number(peek)? { - NumberAny::Float(f) => Ok(f), - NumberAny::Int(int) => Ok(int.into()), - }, + | Peek::Object => return get_err!(), + _ => jiter.known_number(peek)?, + }; + match n { + NumberAny::Float(f) => Ok(f), + NumberAny::Int(int) => Ok(int.into()), } } else { get_err!() diff --git a/src/json_get_int.rs b/src/json_get_int.rs index 4f80256..71be4d8 100644 --- a/src/json_get_int.rs +++ b/src/json_get_int.rs @@ -65,7 +65,8 @@ impl ScalarUDFImpl for JsonGetInt { fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result { if let Some((mut jiter, peek)) = jiter_json_find(json_data, path) { - match peek { + let n = match peek { + Peek::String => NumberInt::try_from(jiter.next_bytes()?)?, // numbers are represented by everything else in peek, hence doing it this way Peek::Null | Peek::True @@ -73,13 +74,13 @@ fn jiter_json_get_int(json_data: Option<&str>, path: &[JsonPath]) -> Result get_err!(), - _ => match jiter.known_int(peek)? { - NumberInt::Int(i) => Ok(i), - NumberInt::BigInt(_) => get_err!(), - }, + | Peek::Object => return get_err!(), + _ => jiter.known_int(peek)?, + }; + match n { + NumberInt::Int(i) => Ok(i), + NumberInt::BigInt(_) => get_err!(), } } else { get_err!() diff --git a/src/rewrite.rs b/src/rewrite.rs index 60403c8..814bd62 100644 --- a/src/rewrite.rs +++ b/src/rewrite.rs @@ -78,7 +78,7 @@ fn unnest_json_calls(func: &ScalarFunction) -> Option> { fn extract_scalar_function(expr: &Expr) -> Option<&ScalarFunction> { match expr { Expr::ScalarFunction(func) => Some(func), - Expr::Alias(alias) => extract_scalar_function(&*alias.expr), + Expr::Alias(alias) => extract_scalar_function(&alias.expr), _ => None, } } diff --git a/tests/main.rs b/tests/main.rs index 1bbd85c..34570d8 100644 --- a/tests/main.rs +++ b/tests/main.rs @@ -1131,8 +1131,30 @@ async fn test_long_arrow_cast() { assert_batches_eq!(expected, &batches); } +#[tokio::test] async fn test_arrow_cast_numeric() { let sql = r#"select ('{"foo": 420}'->'foo')::numeric = 420"#; let batches = run_query(sql).await.unwrap(); assert_eq!(display_val(batches).await, (DataType::Boolean, "true".to_string())); } + +#[tokio::test] +async fn test_json_get_int_string() { + let sql = r#"select json_get_int('{"foo": "420"}'->'foo')"#; + let batches = run_query(sql).await.unwrap(); + assert_eq!(display_val(batches).await, (DataType::Int64, "420".to_string())); +} + +// #[tokio::test] +// async fn test_json_get_float_string() { +// let sql = r#"select json_get_float('{"foo": "420.123"}'->'foo')"#; +// let batches = run_query(sql).await.unwrap(); +// assert_eq!(display_val(batches).await, (DataType::Int64, "420.123".to_string())); +// } +// +// #[tokio::test] +// async fn test_json_get_float_string_2() { +// let sql = r#"select json_get_float('{"foo": "420"}'->'foo')"#; +// let batches = run_query(sql).await.unwrap(); +// assert_eq!(display_val(batches).await, (DataType::Int64, "420".to_string())); +// }