diff --git a/src/meta/app/src/principal/mod.rs b/src/meta/app/src/principal/mod.rs index f5e814dac5fcb..8572fd34e9524 100644 --- a/src/meta/app/src/principal/mod.rs +++ b/src/meta/app/src/principal/mod.rs @@ -123,6 +123,7 @@ pub use user_defined_function::UDAFScript; pub use user_defined_function::UDFDefinition; pub use user_defined_function::UDFScript; pub use user_defined_function::UDFServer; +pub use user_defined_function::UDTFServer; pub use user_defined_function::UserDefinedFunction; pub use user_defined_function::UDTF; pub use user_grant::GrantEntry; diff --git a/src/meta/app/src/principal/user_defined_function.rs b/src/meta/app/src/principal/user_defined_function.rs index 37a9cf30ea35b..6612d7b00c201 100644 --- a/src/meta/app/src/principal/user_defined_function.rs +++ b/src/meta/app/src/principal/user_defined_function.rs @@ -52,6 +52,18 @@ pub struct UDFScript { pub immutable: Option, } +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +pub struct UDTFServer { + pub address: String, + pub handler: String, + pub headers: BTreeMap, + pub language: String, + pub arg_names: Vec, + pub arg_types: Vec, + pub return_types: Vec<(String, DataType)>, + pub immutable: Option, +} + /// User Defined Table Function (UDTF) /// /// # Fields @@ -98,6 +110,7 @@ pub enum UDFDefinition { UDFServer(UDFServer), UDFScript(UDFScript), UDAFScript(UDAFScript), + UDTFServer(UDTFServer), UDTF(UDTF), ScalarUDF(ScalarUDF), } @@ -110,7 +123,8 @@ impl UDFDefinition { Self::UDFScript(_) => "UDFScript", Self::UDAFScript(_) => "UDAFScript", Self::UDTF(_) => "UDTF", - UDFDefinition::ScalarUDF(_) => "ScalarUDF", + Self::UDTFServer(_) => "UDTFServer", + Self::ScalarUDF(_) => "ScalarUDF", } } @@ -120,6 +134,7 @@ impl UDFDefinition { Self::UDFServer(_) => false, Self::UDFScript(_) => false, Self::UDTF(_) => false, + Self::UDTFServer(_) => false, Self::ScalarUDF(_) => false, Self::UDAFScript(_) => true, } @@ -130,6 +145,7 @@ impl UDFDefinition { Self::LambdaUDF(_) => "SQL", Self::UDTF(_) => "SQL", Self::ScalarUDF(_) => "SQL", + Self::UDTFServer(x) => x.language.as_str(), Self::UDFServer(x) => x.language.as_str(), Self::UDFScript(x) => x.language.as_str(), Self::UDAFScript(x) => x.language.as_str(), @@ -220,6 +236,13 @@ impl UserDefinedFunction { created_on: Utc::now(), } } + + pub fn as_udtf_server(self) -> Option { + if let UDFDefinition::UDTFServer(udtf_server) = self.definition { + return Some(udtf_server); + } + None + } } impl Display for UDFDefinition { @@ -353,6 +376,50 @@ impl Display for UDFDefinition { } write!(f, ") AS $${sql}$$")?; } + UDFDefinition::UDTFServer(UDTFServer { + address, + handler, + headers, + language, + arg_names, + arg_types, + return_types, + immutable, + }) => { + for (i, (name, ty)) in arg_names.iter().zip(arg_types.iter()).enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{name} {ty}")?; + } + write!(f, ") RETURNS (")?; + for (i, (name, ty)) in return_types.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{name} {ty}")?; + } + write!(f, ") LANGUAGE {language}")?; + if let Some(immutable) = immutable { + if *immutable { + write!(f, " IMMUTABLE")?; + } else { + write!(f, " VOLATILE")?; + } + } + write!(f, " HANDLER = {handler}")?; + if !headers.is_empty() { + write!(f, " HEADERS = (")?; + for (i, (key, value)) in headers.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{key} = {value}")?; + } + write!(f, ")")?; + } + write!(f, " ADDRESS = {address}")?; + } UDFDefinition::ScalarUDF(ScalarUDF { arg_types, return_type, diff --git 
a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs index 5ce96fc9254b0..1ec19e2ef8f04 100644 --- a/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs +++ b/src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs @@ -359,6 +359,88 @@ impl FromToProto for mt::UDTF { } } +impl FromToProto for mt::UDTFServer { + type PB = pb::UdtfServer; + + fn get_pb_ver(p: &Self::PB) -> u64 { + p.ver + } + + fn from_pb(p: Self::PB) -> Result + where Self: Sized { + reader_check_msg(p.ver, p.min_reader_ver)?; + + let mut arg_types = Vec::with_capacity(p.arg_types.len()); + for arg_type in p.arg_types { + let arg_type = DataType::from(&TableDataType::from_pb(arg_type)?); + arg_types.push(arg_type); + } + let mut return_types = Vec::new(); + for return_ty in p.return_types { + let ty_pb = return_ty.ty.ok_or_else(|| { + Incompatible::new("UDTF.arg_types.ty can not be None".to_string()) + })?; + let ty = TableDataType::from_pb(ty_pb)?; + + return_types.push((return_ty.name, (&ty).into())); + } + + Ok(mt::UDTFServer { + address: p.address, + arg_types, + return_types, + handler: p.handler, + headers: p.headers, + language: p.language, + immutable: p.immutable, + arg_names: p.arg_names, + }) + } + + fn to_pb(&self) -> Result { + let mut arg_types = Vec::with_capacity(self.arg_types.len()); + for arg_type in self.arg_types.iter() { + let arg_type = infer_schema_type(arg_type) + .map_err(|e| { + Incompatible::new(format!( + "Convert DataType to TableDataType failed: {}", + e.message() + )) + })? + .to_pb()?; + arg_types.push(arg_type); + } + let mut return_types = Vec::with_capacity(self.return_types.len()); + for (return_name, return_type) in self.return_types.iter() { + let return_type = infer_schema_type(return_type) + .map_err(|e| { + Incompatible::new(format!( + "Convert DataType to TableDataType failed: {}", + e.message() + )) + })? + .to_pb()?; + return_types.push(UdtfArg { + name: return_name.clone(), + ty: Some(return_type), + }); + } + + Ok(pb::UdtfServer { + ver: VER, + min_reader_ver: MIN_READER_VER, + address: self.address.clone(), + handler: self.handler.clone(), + headers: self.headers.clone(), + language: self.language.clone(), + arg_types, + return_types, + immutable: self.immutable, + arg_names: self.arg_names.clone(), + }) + } +} + impl FromToProto for mt::ScalarUDF { type PB = pb::ScalarUdf; @@ -454,6 +536,9 @@ impl FromToProto for mt::UserDefinedFunction { Some(pb::user_defined_function::Definition::ScalarUdf(scalar_udf)) => { mt::UDFDefinition::ScalarUDF(mt::ScalarUDF::from_pb(scalar_udf)?) } + Some(pb::user_defined_function::Definition::UdtfServer(udtf_server)) => { + mt::UDFDefinition::UDTFServer(mt::UDTFServer::from_pb(udtf_server)?) + } None => { return Err(Incompatible::new( "UserDefinedFunction.definition cannot be None".to_string(), @@ -492,6 +577,9 @@ impl FromToProto for mt::UserDefinedFunction { mt::UDFDefinition::ScalarUDF(scalar_udf) => { pb::user_defined_function::Definition::ScalarUdf(scalar_udf.to_pb()?) } + mt::UDFDefinition::UDTFServer(udtf_server) => { + pb::user_defined_function::Definition::UdtfServer(udtf_server.to_pb()?) 
+ } }; Ok(pb::UserDefinedFunction { diff --git a/src/meta/proto-conv/src/util.rs b/src/meta/proto-conv/src/util.rs index 22200f00b210d..f43acf20a64cf 100644 --- a/src/meta/proto-conv/src/util.rs +++ b/src/meta/proto-conv/src/util.rs @@ -187,6 +187,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[ (155, "2025-10-24: Add: RowAccessPolicyMeta::RowAccessPolicyArg"), (156, "2025-10-22: Add: DataMaskMeta add DataMaskArg"), (157, "2025-10-22: Add: TableDataType TimestampTz"), + (158, "2025-10-22: Add: Server UDTF"), // Dear developer: // If you're gonna add a new metadata version, you'll have to add a test for it. // You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`) diff --git a/src/meta/proto-conv/tests/it/main.rs b/src/meta/proto-conv/tests/it/main.rs index 8e6ab544da646..dbd4618aa3ea4 100644 --- a/src/meta/proto-conv/tests/it/main.rs +++ b/src/meta/proto-conv/tests/it/main.rs @@ -149,3 +149,4 @@ mod v154_vacuum_watermark; mod v155_row_access_policy_args; mod v156_data_mask_args; mod v157_type_timestamp_tz; +mod v158_udtf_server; diff --git a/src/meta/proto-conv/tests/it/v158_udtf_server.rs b/src/meta/proto-conv/tests/it/v158_udtf_server.rs new file mode 100644 index 0000000000000..7ee196e9e811e --- /dev/null +++ b/src/meta/proto-conv/tests/it/v158_udtf_server.rs @@ -0,0 +1,79 @@ +// Copyright 2023 Datafuse Labs. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use chrono::DateTime; +use chrono::Utc; +use databend_common_expression::types::DataType; +use databend_common_expression::types::NumberDataType; +use databend_common_meta_app::principal::UDFDefinition; +use databend_common_meta_app::principal::UDTFServer; +use databend_common_meta_app::principal::UserDefinedFunction; +use fastrace::func_name; + +use crate::common; + +// These bytes are built when a new version in introduced, +// and are kept for backward compatibility test. +// +// ************************************************************* +// * These messages should never be updated, * +// * only be added when a new version is added, * +// * or be removed when an old version is no longer supported. 
* +// ************************************************************* +// +// The message bytes are built from the output of `test_pb_from_to()` +#[test] +fn test_decode_v158_server_udtf() -> anyhow::Result<()> { + let bytes = vec![ + 10, 15, 116, 101, 115, 116, 95, 115, 99, 97, 108, 97, 114, 95, 117, 100, 102, 18, 21, 84, + 104, 105, 115, 32, 105, 115, 32, 97, 32, 100, 101, 115, 99, 114, 105, 112, 116, 105, 111, + 110, 82, 144, 1, 10, 21, 104, 116, 116, 112, 58, 47, 47, 108, 111, 99, 97, 108, 104, 111, + 115, 116, 58, 56, 56, 56, 56, 18, 11, 112, 108, 117, 115, 95, 105, 110, 116, 95, 112, 121, + 26, 6, 112, 121, 116, 104, 111, 110, 34, 10, 146, 2, 0, 160, 6, 158, 1, 168, 6, 24, 34, 10, + 138, 2, 0, 160, 6, 158, 1, 168, 6, 24, 42, 16, 10, 2, 99, 49, 18, 10, 146, 2, 0, 160, 6, + 158, 1, 168, 6, 24, 42, 25, 10, 2, 99, 50, 18, 19, 154, 2, 9, 42, 0, 160, 6, 158, 1, 168, + 6, 24, 160, 6, 158, 1, 168, 6, 24, 50, 14, 10, 4, 107, 101, 121, 49, 18, 6, 118, 97, 108, + 117, 101, 49, 66, 2, 99, 49, 66, 2, 99, 50, 160, 6, 158, 1, 168, 6, 24, 42, 23, 50, 48, 50, + 51, 45, 49, 50, 45, 49, 53, 32, 48, 49, 58, 50, 54, 58, 48, 57, 32, 85, 84, 67, 160, 6, + 158, 1, 168, 6, 24, + ]; + + let want = || UserDefinedFunction { + name: "test_scalar_udf".to_string(), + description: "This is a description".to_string(), + definition: UDFDefinition::UDTFServer(UDTFServer { + address: "http://localhost:8888".to_string(), + handler: "plus_int_py".to_string(), + headers: vec![("key1".to_string(), "value1".to_string())] + .into_iter() + .collect(), + language: "python".to_string(), + arg_names: vec![s("c1"), s("c2")], + arg_types: vec![DataType::String, DataType::Boolean], + return_types: vec![ + (s("c1"), DataType::String), + (s("c2"), DataType::Number(NumberDataType::Int8)), + ], + immutable: None, + }), + created_on: DateTime::::from_timestamp(1702603569, 0).unwrap(), + }; + + common::test_pb_from_to(func_name!(), want())?; + common::test_load_old(func_name!(), bytes.as_slice(), 158, want()) +} + +fn s(ss: impl ToString) -> String { + ss.to_string() +} diff --git a/src/meta/protos/proto/udf.proto b/src/meta/protos/proto/udf.proto index a0efd94134051..3c0ba6e5448f8 100644 --- a/src/meta/protos/proto/udf.proto +++ b/src/meta/protos/proto/udf.proto @@ -86,6 +86,21 @@ message UDTF { string sql = 3; } +message UDTFServer { + uint64 ver = 100; + uint64 min_reader_ver = 101; + + string address = 1; + string handler = 2; + string language = 3; + repeated DataType arg_types = 4; + // return column name with data type + repeated UDTFArg return_types = 5; + map headers = 6; + optional bool immutable = 7; + repeated string arg_names = 8; +} + message ScalarUDF { uint64 ver = 100; uint64 min_reader_ver = 101; @@ -112,6 +127,7 @@ message UserDefinedFunction { UDAFScript udaf_script = 7; UDTF udtf = 8; ScalarUDF scalar_udf = 9; + UDTFServer udtf_server = 10; } // The time udf created. 
optional string created_on = 5; diff --git a/src/query/ast/src/ast/statements/udf.rs b/src/query/ast/src/ast/statements/udf.rs index e268394c0fca2..e7208c15d9a4e 100644 --- a/src/query/ast/src/ast/statements/udf.rs +++ b/src/query/ast/src/ast/statements/udf.rs @@ -88,6 +88,15 @@ pub enum UDFDefinition { return_types: Vec<(Identifier, TypeName)>, sql: String, }, + UDTFServer { + arg_types: Vec<(Identifier, TypeName)>, + return_types: Vec<(Identifier, TypeName)>, + address: String, + handler: String, + headers: BTreeMap, + language: String, + immutable: Option, + }, ScalarUDF { arg_types: Vec<(Identifier, TypeName)>, definition: String, @@ -282,6 +291,46 @@ impl Display for UDFDefinition { )?; write!(f, ") AS $$\n{sql}\n$$")?; } + UDFDefinition::UDTFServer { + arg_types, + return_types, + address, + handler, + headers, + language, + immutable, + } => { + write!(f, "(")?; + write_comma_separated_list( + f, + arg_types.iter().map(|(name, ty)| format!("{name} {ty}")), + )?; + write!(f, ") RETURNS TABLE (")?; + write_comma_separated_list( + f, + return_types.iter().map(|(name, ty)| format!("{name} {ty}")), + )?; + write!(f, ") LANGUAGE {language}")?; + if let Some(immutable) = immutable { + if *immutable { + write!(f, " IMMUTABLE")?; + } else { + write!(f, " VOLATILE")?; + } + } + write!(f, " HANDLER = '{handler}'")?; + if !headers.is_empty() { + write!(f, " HEADERS = (")?; + for (i, (key, value)) in headers.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "'{key}' = '{value}'")?; + } + write!(f, ")")?; + } + write!(f, " ADDRESS = '{address}'")?; + } UDFDefinition::ScalarUDF { arg_types, definition, diff --git a/src/query/ast/src/parser/statement.rs b/src/query/ast/src/parser/statement.rs index 9749ea6ccc284..abb0ed475c9d7 100644 --- a/src/query/ast/src/parser/statement.rs +++ b/src/query/ast/src/parser/statement.rs @@ -5377,6 +5377,50 @@ pub fn udf_definition(i: Input) -> IResult { .parse(i) } + enum FuncBody { + Sql(String), + Server { + address: String, + handler: String, + headers: BTreeMap, + language: String, + immutable: Option, + }, + } + + fn func_body(i: Input) -> IResult { + let sql = map( + rule! { + AS ~ ^#code_string + }, + |(_, sql)| FuncBody::Sql(sql), + ); + let server = map( + rule! { + LANGUAGE ~ #ident + ~ (#udf_immutable)? + ~ HANDLER ~ ^"=" ~ ^#literal_string + ~ ( HEADERS ~ ^"=" ~ "(" ~ #comma_separated_list0(udf_header) ~ ")" )? + ~ ADDRESS ~ ^"=" ~ ^#literal_string + }, + |(_, language, immutable, _, _, handler, headers, _, _, address)| FuncBody::Server { + address, + handler, + headers: headers + .map(|(_, _, _, headers, _)| BTreeMap::from_iter(headers)) + .unwrap_or_default(), + language: language.to_string(), + immutable, + }, + ); + + rule!( + #sql: "AS " + | #server: "LANGUAGE HANDLER= ADDRESS=" + ) + .parse(i) + } + let lambda_udf = map( rule! { AS ~ #lambda_udf_params @@ -5449,23 +5493,49 @@ pub fn udf_definition(i: Input) -> IResult { }, ); - let scalar_udf_or_udtf = map( + let scalar_udf_or_udtf = map_res( rule! 
{ "(" ~ #comma_separated_list0(udtf_arg) ~ ")" ~ RETURNS ~ ^#return_body - ~ AS ~ ^#code_string + ~ #func_body }, - |(_, arg_types, _, _, return_body, _, sql)| match return_body { - ReturnBody::Scalar(return_type) => UDFDefinition::ScalarUDF { - arg_types, - definition: sql, - return_type, - }, - ReturnBody::Table(return_types) => UDFDefinition::UDTFSql { - arg_types, - return_types, - sql, - }, + |(_, arg_types, _, _, return_body, func_body)| { + let definition = match (return_body, func_body) { + (ReturnBody::Scalar(return_type), FuncBody::Sql(sql)) => UDFDefinition::ScalarUDF { + arg_types, + definition: sql, + return_type, + }, + (ReturnBody::Scalar(_), FuncBody::Server { .. }) => { + return Err(nom::Err::Failure(ErrorKind::Other( + "ScalarUDF unsupported external Server", + ))) + } + (ReturnBody::Table(return_types), FuncBody::Sql(sql)) => UDFDefinition::UDTFSql { + arg_types, + return_types, + sql, + }, + ( + ReturnBody::Table(return_types), + FuncBody::Server { + address, + handler, + headers, + language, + immutable, + }, + ) => UDFDefinition::UDTFServer { + arg_types, + return_types, + address, + handler, + headers, + language, + immutable, + }, + }; + Ok(definition) }, ); @@ -5531,7 +5601,7 @@ pub fn udf_definition(i: Input) -> IResult { #lambda_udf: "AS (, ...) -> " | #udaf: "(<[arg_name] arg_type>, ...) STATE {, ...} RETURNS LANGUAGE { ADDRESS= | AS } " | #udf: "(<[arg_name] arg_type>, ...) RETURNS LANGUAGE HANDLER= { ADDRESS= | AS } " - | #scalar_udf_or_udtf: "(, ...) RETURNS AS }" + | #scalar_udf_or_udtf: "(, ...) RETURNS { AS | LANGUAGE HANDLER= ADDRESS= } }" ).parse(i) } diff --git a/src/query/catalog/src/catalog/interface.rs b/src/query/catalog/src/catalog/interface.rs index 4cb409811fff8..64a5c9c11a31a 100644 --- a/src/query/catalog/src/catalog/interface.rs +++ b/src/query/catalog/src/catalog/interface.rs @@ -21,6 +21,7 @@ use databend_common_ast::ast::Engine; use databend_common_exception::ErrorCode; use databend_common_exception::ErrorCodeResultExt; use databend_common_exception::Result; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::least_visible_time_ident::LeastVisibleTimeIdent; @@ -123,6 +124,7 @@ use log::info; use crate::database::Database; use crate::table::Table; use crate::table_args::TableArgs; +use crate::table_context::TableContext; use crate::table_function::TableFunction; #[derive(Default, Clone)] @@ -618,4 +620,12 @@ pub trait Catalog: DynClone + Send + Sync + Debug { } async fn rename_dictionary(&self, req: RenameDictionaryReq) -> Result<()>; + + fn transform_udtf_as_table_function( + &self, + ctx: &dyn TableContext, + table_args: &TableArgs, + udtf: UDTFServer, + func_name: &str, + ) -> Result>; } diff --git a/src/query/expression/src/utils/udf_client.rs b/src/query/expression/src/utils/udf_client.rs index 6471e9963eaca..e5847e14c7bc5 100644 --- a/src/query/expression/src/utils/udf_client.rs +++ b/src/query/expression/src/utils/udf_client.rs @@ -44,6 +44,7 @@ use futures::stream; use futures::StreamExt; use futures::TryStreamExt; use hyper_util::client::legacy::connect::HttpConnector; +use itertools::Itertools; use tonic::metadata::KeyAndValueRef; use tonic::metadata::MetadataKey; use tonic::metadata::MetadataMap; @@ -301,15 +302,14 @@ impl UDFFlightClient { &mut self, name: &str, func_name: &str, - num_rows: usize, + num_rows: Option, args: Vec, 
return_type: &DataType, - ) -> Result { + ) -> Result { let instant = Instant::now(); Profile::record_usize_profile(ProfileStatisticsName::ExternalServerRequestCount, 1); record_running_requests_external_start(name, 1); - record_request_external_batch_rows(func_name, num_rows); let args = args .into_iter() @@ -332,7 +332,9 @@ impl UDFFlightClient { .collect::>(); let data_schema = DataSchema::new(fields); - let input_batch = DataBlock::new(args, num_rows) + // at least 1 for `UDFFlightClient::batch_rows` + let input_num_rows = args.first().map(|entry| entry.len()).unwrap_or(1); + let input_batch = DataBlock::new(args, input_num_rows) .to_record_batch_with_dataschema(&data_schema) .map_err(|err| ErrorCode::from_string(format!("{err}")))?; @@ -344,11 +346,12 @@ impl UDFFlightClient { let result_batch = result_batch?; let schema = DataSchema::try_from(&(*result_batch.schema()))?; - let result_block = DataBlock::from_record_batch(&schema, &result_batch).map_err(|err| { - ErrorCode::UDFDataError(format!( - "Cannot convert arrow record batch to data block: {err}" - )) - })?; + let mut result_block = + DataBlock::from_record_batch(&schema, &result_batch).map_err(|err| { + ErrorCode::UDFDataError(format!( + "Cannot convert arrow record batch to data block: {err}" + )) + })?; let result_fields = schema.fields(); if result_fields.is_empty() || result_block.is_empty() { @@ -357,26 +360,53 @@ impl UDFFlightClient { )); } - if result_fields[0].data_type() != return_type { + if let Some(expected_rows) = num_rows { + if result_block.num_rows() != expected_rows { + return Err(ErrorCode::UDFDataError(format!( + "UDF server should return {} rows, but it returned {} rows", + expected_rows, + result_block.num_rows() + ))); + } + } + record_request_external_batch_rows(func_name, result_block.num_rows()); + + if return_type.remove_nullable().is_tuple() && result_fields.len() > 1 { + if let DataType::Tuple(tys) = return_type.remove_nullable() { + if tys + .iter() + .zip(result_fields.iter().map(|f| f.data_type())) + .any(|(ty_a, ty_b)| ty_a != ty_b) + { + return Err(ErrorCode::UDFSchemaMismatch(format!( + "UDF server return incorrect type, expected: {}, but got: {}", + tys.iter().map(|ty| ty.to_string()).join(","), + result_fields + .iter() + .map(|f| f.data_type()) + .map(|ty| ty.to_string()) + .join(", ") + ))); + } + } + } else if result_fields[0].data_type() != return_type { return Err(ErrorCode::UDFSchemaMismatch(format!( - "The user-defined function \"{func_name}\" returned an unexpected schema. 
Expected result type {return_type}, but got {}.", + "UDF server return incorrect type, expected: {}, but got: {}", + return_type, result_fields[0].data_type() ))); } - if result_block.num_rows() != num_rows { - return Err(ErrorCode::UDFDataError(format!( - "UDF server should return {} rows, but it returned {} rows", - num_rows, - result_block.num_rows() - ))); - } - - if contains_variant(return_type) { - let value = transform_variant(&result_block.get_by_offset(0).value(), false)?; - Ok(BlockEntry::Column(value.as_column().unwrap().clone())) - } else { - Ok(result_block.get_by_offset(0).clone()) + for (entry, field) in result_block + .columns_mut() + .iter_mut() + .zip(result_fields.iter()) + { + if contains_variant(field.data_type()) { + let value = transform_variant(&entry.value(), false)?; + *entry = BlockEntry::Column(value.as_column().unwrap().clone()); + } } + Ok(result_block) } #[async_backtrace::framed] diff --git a/src/query/service/src/catalogs/default/database_catalog.rs b/src/query/service/src/catalogs/default/database_catalog.rs index 5d33440bec848..ce53102d85d54 100644 --- a/src/query/service/src/catalogs/default/database_catalog.rs +++ b/src/query/service/src/catalogs/default/database_catalog.rs @@ -22,11 +22,13 @@ use databend_common_catalog::catalog::Catalog; use databend_common_catalog::catalog::StorageDescription; use databend_common_catalog::database::Database; use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; use databend_common_catalog::table_function::TableFunction; use databend_common_config::InnerConfig; use databend_common_exception::ErrorCode; use databend_common_exception::ErrorCodeResultExt; use databend_common_exception::Result; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::least_visible_time_ident::LeastVisibleTimeIdent; @@ -126,6 +128,7 @@ use crate::catalogs::default::MutableCatalog; use crate::catalogs::default::SessionCatalog; use crate::storages::Table; use crate::table_functions::TableFunctionFactory; +use crate::table_functions::UDTFTable; /// Combine two catalogs together /// - read/search like operations are always performed at @@ -928,4 +931,14 @@ impl Catalog for DatabaseCatalog { ) -> Result { self.mutable_catalog.get_autoincrement_next_value(req).await } + + fn transform_udtf_as_table_function( + &self, + ctx: &dyn TableContext, + table_args: &TableArgs, + udtf: UDTFServer, + func_name: &str, + ) -> Result> { + UDTFTable::create(ctx, "default", func_name, table_args, udtf) + } } diff --git a/src/query/service/src/catalogs/default/immutable_catalog.rs b/src/query/service/src/catalogs/default/immutable_catalog.rs index 394e2bf22ed2d..61d91c5165a5e 100644 --- a/src/query/service/src/catalogs/default/immutable_catalog.rs +++ b/src/query/service/src/catalogs/default/immutable_catalog.rs @@ -18,9 +18,13 @@ use std::fmt::Formatter; use std::sync::Arc; use databend_common_catalog::catalog::Catalog; +use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; use databend_common_config::InnerConfig; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_meta_app::principal::UDTFServer; use 
databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::CatalogInfo; @@ -611,4 +615,14 @@ impl Catalog for ImmutableCatalog { ) -> Result { unimplemented!() } + + fn transform_udtf_as_table_function( + &self, + _ctx: &dyn TableContext, + _table_args: &TableArgs, + _udtf: UDTFServer, + _func_name: &str, + ) -> Result> { + unimplemented!() + } } diff --git a/src/query/service/src/catalogs/default/mutable_catalog.rs b/src/query/service/src/catalogs/default/mutable_catalog.rs index 00111a22bb1a0..79b62c0c81e34 100644 --- a/src/query/service/src/catalogs/default/mutable_catalog.rs +++ b/src/query/service/src/catalogs/default/mutable_catalog.rs @@ -20,6 +20,9 @@ use std::time::Instant; use databend_common_base::base::BuildInfoRef; use databend_common_catalog::catalog::Catalog; +use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; use databend_common_config::InnerConfig; use databend_common_exception::ErrorCode; use databend_common_exception::Result; @@ -35,6 +38,7 @@ use databend_common_meta_api::SecurityApi; use databend_common_meta_api::SequenceApi; use databend_common_meta_api::TableApi; use databend_common_meta_app::app_error::AppError; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::index_id_ident::IndexId; @@ -148,6 +152,7 @@ use crate::databases::DatabaseFactory; use crate::storages::StorageDescription; use crate::storages::StorageFactory; use crate::storages::Table; +use crate::table_functions::UDTFTable; /// Catalog based on MetaStore /// - System Database NOT included @@ -972,6 +977,16 @@ impl Catalog for MutableCatalog { Ok(res) } + fn transform_udtf_as_table_function( + &self, + ctx: &dyn TableContext, + table_args: &TableArgs, + udtf: UDTFServer, + func_name: &str, + ) -> Result> { + UDTFTable::create(ctx, "default", func_name, table_args, udtf) + } + #[async_backtrace::framed] async fn get_autoincrement_next_value( &self, diff --git a/src/query/service/src/catalogs/default/session_catalog.rs b/src/query/service/src/catalogs/default/session_catalog.rs index 5fe072ca987c0..01c487b5a6efa 100644 --- a/src/query/service/src/catalogs/default/session_catalog.rs +++ b/src/query/service/src/catalogs/default/session_catalog.rs @@ -20,9 +20,11 @@ use databend_common_catalog::catalog::StorageDescription; use databend_common_catalog::database::Database; use databend_common_catalog::table::Table; use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; use databend_common_catalog::table_function::TableFunction; use databend_common_exception::ErrorCode; use databend_common_exception::Result; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::least_visible_time_ident::LeastVisibleTimeIdent; @@ -799,6 +801,17 @@ impl Catalog for SessionCatalog { ) -> Result { self.inner.get_autoincrement_next_value(req).await } + + fn transform_udtf_as_table_function( + &self, + ctx: &dyn TableContext, + 
table_args: &TableArgs, + udtf: UDTFServer, + func_name: &str, + ) -> Result> { + self.inner + .transform_udtf_as_table_function(ctx, table_args, udtf, func_name) + } } impl SessionCatalog { diff --git a/src/query/service/src/catalogs/iceberg/iceberg_catalog.rs b/src/query/service/src/catalogs/iceberg/iceberg_catalog.rs index 2aa28e92ca70a..0625b47a4f4ec 100644 --- a/src/query/service/src/catalogs/iceberg/iceberg_catalog.rs +++ b/src/query/service/src/catalogs/iceberg/iceberg_catalog.rs @@ -22,9 +22,13 @@ use databend_common_catalog::catalog::Catalog; use databend_common_catalog::catalog::CatalogCreator; use databend_common_catalog::catalog::StorageDescription; use databend_common_catalog::database::Database; +use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; use databend_common_exception::ErrorCode; use databend_common_exception::ErrorCodeResultExt; use databend_common_exception::Result; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent; use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent; use databend_common_meta_app::schema::CatalogInfo; @@ -600,4 +604,14 @@ impl Catalog for IcebergCatalog { ) -> Result { unimplemented!() } + + fn transform_udtf_as_table_function( + &self, + _ctx: &dyn TableContext, + _table_args: &TableArgs, + _udtf: UDTFServer, + _func_name: &str, + ) -> Result> { + unimplemented!() + } } diff --git a/src/query/service/src/pipelines/builders/builder_udtf.rs b/src/query/service/src/pipelines/builders/builder_udtf.rs new file mode 100644 index 0000000000000..b00d5eb1593d7 --- /dev/null +++ b/src/query/service/src/pipelines/builders/builder_udtf.rs @@ -0,0 +1,201 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::BTreeMap; +use std::sync::Arc; +use std::time::Duration; + +use backon::ExponentialBuilder; +use backon::Retryable; +use databend_common_base::runtime::profile::Profile; +use databend_common_base::runtime::profile::ProfileStatisticsName; +use databend_common_catalog::table_context::TableContext; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::types::DataType; +use databend_common_expression::udf_client::error_kind; +use databend_common_expression::udf_client::UDFFlightClient; +use databend_common_expression::BlockEntry; +use databend_common_expression::DataBlock; +use databend_common_expression::Scalar; +use databend_common_metrics::external_server::record_retry_external; +use databend_common_pipeline::sources::AsyncSource; +use tokio::sync::Semaphore; +use tonic::transport::Endpoint; + +pub struct UdtfServerSource { + ctx: Arc, + + func: UdtfFunctionDesc, + connect_timeout: u64, + semaphore: Arc, + endpoint: Arc, + retry_times: usize, + + done: bool, +} + +impl UdtfServerSource { + pub fn init_semaphore(ctx: Arc) -> Result> { + let settings = ctx.get_settings(); + let request_max_threads = settings.get_external_server_request_max_threads()? as usize; + let semaphore = Arc::new(Semaphore::new(request_max_threads)); + Ok(semaphore) + } + + pub fn init_endpoints( + ctx: Arc, + func: &UdtfFunctionDesc, + ) -> Result> { + let settings = ctx.get_settings(); + let connect_timeout = settings.get_external_server_connect_timeout_secs()?; + let request_timeout = settings.get_external_server_request_timeout_secs()?; + + let endpoint = UDFFlightClient::build_endpoint( + &func.server, + connect_timeout, + request_timeout, + &ctx.get_version().udf_client_user_agent(), + )?; + + Ok(endpoint) + } + + pub fn new( + ctx: Arc, + func: UdtfFunctionDesc, + semaphore: Arc, + endpoint: Arc, + ) -> Result { + let settings = ctx.get_settings(); + let connect_timeout = settings.get_external_server_connect_timeout_secs()?; + let retry_times = settings.get_external_server_request_retry_times()? as usize; + + Ok(Self { + ctx, + func, + connect_timeout, + semaphore, + endpoint, + retry_times, + done: false, + }) + } + + async fn source_inner( + ctx: Arc, + endpoint: Arc, + semaphore: Arc, + connect_timeout: u64, + func: UdtfFunctionDesc, + ) -> Result { + // Must obtain the permit to execute, prevent too many connections being executed concurrently + let permit = semaphore.acquire_owned().await.map_err(|e| { + ErrorCode::Internal(format!("Udtf transformer acquire permit failure. {}", e)) + })?; + // construct input record_batch + let mut client = + UDFFlightClient::connect(&func.func_name, endpoint, connect_timeout, 65536) + .await? + .with_tenant(ctx.get_tenant().tenant_name())? + .with_func_name(&func.name)? + .with_handler_name(&func.func_name)? + .with_query_id(&ctx.get_id())? 
+ .with_headers(func.headers)?; + + let args = func + .args + .into_iter() + .filter(|(_, ty)| ty.remove_nullable() != DataType::StageLocation) + .map(|(scalar, ty)| BlockEntry::Const(scalar, ty, 1)) + .collect(); + + debug_assert!(func.return_ty.is_tuple()); + let result = client + .do_exchange(&func.name, &func.func_name, None, args, &func.return_ty) + .await?; + + drop(permit); + Ok(result) + } +} + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct UdtfFunctionDesc { + pub name: String, + pub func_name: String, + pub return_ty: DataType, + pub args: Vec<(Scalar, DataType)>, + pub headers: BTreeMap, + pub server: String, +} + +fn retry_on(err: &databend_common_exception::ErrorCode) -> bool { + if err.code() == ErrorCode::U_D_F_DATA_ERROR { + let message = err.message(); + // this means the server can't handle the request in 60s + if message.contains("h2 protocol error") { + return false; + } + } + true +} + +#[async_trait::async_trait] +impl AsyncSource for UdtfServerSource { + const NAME: &'static str = "UdtfServerSource"; + const SKIP_EMPTY_DATA_BLOCK: bool = true; + + async fn generate(&mut self) -> Result> { + if self.done { + return Ok(None); + } + let ctx = self.ctx.clone(); + let endpoint = self.endpoint.clone(); + let connect_timeout = self.connect_timeout; + let semaphore = self.semaphore.clone(); + let func = self.func.clone(); + let name = func.name.clone(); + + let f = { + move || { + Self::source_inner( + ctx.clone(), + endpoint.clone(), + semaphore.clone(), + connect_timeout, + func.clone(), + ) + } + }; + let backoff = ExponentialBuilder::default() + .with_min_delay(Duration::from_millis(50)) + .with_factor(2.0) + .with_max_delay(Duration::from_secs(30)) + .with_max_times(self.retry_times); + + let data_block = f + .retry(backoff) + .when(retry_on) + .notify(move |err, dur| { + Profile::record_usize_profile(ProfileStatisticsName::ExternalServerRetryCount, 1); + record_retry_external(name.clone(), error_kind(&err.message())); + log::warn!("Retry udtf error: {:?} after {:?}", err.message(), dur); + }) + .await?; + + self.done = true; + Ok(Some(data_block)) + } +} diff --git a/src/query/service/src/pipelines/builders/mod.rs b/src/query/service/src/pipelines/builders/mod.rs index 229465b2ad902..61c2836588b3a 100644 --- a/src/query/service/src/pipelines/builders/mod.rs +++ b/src/query/service/src/pipelines/builders/mod.rs @@ -24,9 +24,12 @@ mod builder_on_finished; mod builder_project; mod builder_replace_into; mod builder_sort; +mod builder_udtf; mod merge_into_join_optimizations; mod transform_builder; pub use builder_replace_into::RawValueSource; pub use builder_replace_into::ValueSource; pub use builder_sort::SortPipelineBuilder; +pub use builder_udtf::UdtfFunctionDesc; +pub use builder_udtf::UdtfServerSource; diff --git a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs index a1e7b4344443a..ea7f8dfbd4903 100644 --- a/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs +++ b/src/query/service/src/pipelines/processors/transforms/transform_udf_server.rs @@ -140,13 +140,13 @@ impl TransformUdfServer { .do_exchange( &func.name, &func.func_name, - num_rows, + Some(num_rows), block_entries, &func.data_type, ) .await?; - data_block.add_entry(result); + data_block.add_entry(result.take_columns().pop().unwrap()); drop(permit); Ok(data_block) diff --git a/src/query/service/src/sessions/query_ctx.rs 
b/src/query/service/src/sessions/query_ctx.rs index 560a63f27326a..46198d9b71061 100644 --- a/src/query/service/src/sessions/query_ctx.rs +++ b/src/query/service/src/sessions/query_ctx.rs @@ -228,8 +228,26 @@ impl QueryContext { .shared .catalog_manager .get_default_catalog(self.session_state()?)?; - let table_function = - default_catalog.get_table_function(&table_info.name, table_args)?; + let udtf_result = databend_common_base::runtime::block_on(async { + if let Some(udtf) = UserApiProvider::instance() + .get_udf(&self.get_tenant(), &table_info.name) + .await? + .and_then(|func| func.as_udtf_server()) + { + return default_catalog + .transform_udtf_as_table_function( + self, + &table_args, + udtf, + &table_info.name, + ) + .map(Some); + } + Ok(None) + }); + let table_function = udtf_result.transpose().unwrap_or_else(|| { + default_catalog.get_table_function(&table_info.name, table_args) + })?; Ok(table_function.as_table()) } (Some(_), false) => Err(ErrorCode::InvalidArgument( diff --git a/src/query/service/src/table_functions/mod.rs b/src/query/service/src/table_functions/mod.rs index 79eec1c4cdb81..dbf09caea6666 100644 --- a/src/query/service/src/table_functions/mod.rs +++ b/src/query/service/src/table_functions/mod.rs @@ -32,6 +32,7 @@ mod system; mod table_function; mod table_function_factory; mod temporary_tables_table; +mod udf_table; pub use copy_history::CopyHistoryTable; pub use numbers::generate_numbers_parts; @@ -46,3 +47,4 @@ pub use system::TableStatisticsFunc; pub use table_function::TableFunction; pub use table_function_factory::TableFunctionFactory; pub use temporary_tables_table::TemporaryTablesTable; +pub use udf_table::UDTFTable; diff --git a/src/query/service/src/table_functions/others/udf.rs b/src/query/service/src/table_functions/others/udf.rs index c7475efacf604..6412b30dce85e 100644 --- a/src/query/service/src/table_functions/others/udf.rs +++ b/src/query/service/src/table_functions/others/udf.rs @@ -163,10 +163,10 @@ impl Table for UdfEchoTable { let return_type = DataType::Nullable(Box::new(DataType::String)); let result = client - .do_exchange(name, name, num_rows, block_entries, &return_type) + .do_exchange(name, name, Some(num_rows), block_entries, &return_type) .await?; - let scalar = unsafe { result.index_unchecked(0) }; + let scalar = unsafe { result.get_by_offset(0).index_unchecked(0) }; let value = scalar.as_string().unwrap(); let parts = vec![Arc::new(Box::new(StringPart { value: value.to_string(), diff --git a/src/query/service/src/table_functions/udf_table.rs b/src/query/service/src/table_functions/udf_table.rs new file mode 100644 index 0000000000000..854720f1aec08 --- /dev/null +++ b/src/query/service/src/table_functions/udf_table.rs @@ -0,0 +1,228 @@ +// Copyright 2021 Datafuse Labs +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::any::Any; +use std::sync::Arc; + +use databend_common_ast::Span; +use databend_common_catalog::plan::DataSourcePlan; +use databend_common_catalog::plan::PartStatistics; +use databend_common_catalog::plan::Partitions; +use databend_common_catalog::plan::PushDownInfo; +use databend_common_catalog::table::Table; +use databend_common_catalog::table_args::TableArgs; +use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; +use databend_common_exception::ErrorCode; +use databend_common_exception::Result; +use databend_common_expression::cast_scalar; +use databend_common_expression::infer_schema_type; +use databend_common_expression::types::DataType; +use databend_common_expression::TableField; +use databend_common_expression::TableSchema; +use databend_common_expression::TableSchemaRefExt; +use databend_common_functions::BUILTIN_FUNCTIONS; +use databend_common_meta_app::principal::StageType; +use databend_common_meta_app::principal::UDTFServer; +use databend_common_meta_app::schema::TableInfo; +use databend_common_meta_app::storage::StorageParams; +use databend_common_pipeline::core::Pipeline; +use databend_common_pipeline::sources::AsyncSourcer; +use databend_common_sql::binder::resolve_stage_location; +use databend_common_sql::StageLocationParam; + +use crate::pipelines::builders::UdtfFunctionDesc; +use crate::pipelines::builders::UdtfServerSource; + +pub struct UDTFTable { + desc: UdtfFunctionDesc, + table_info: TableInfo, +} + +impl UDTFTable { + pub fn create( + ctx: &dyn TableContext, + database_name: &str, + table_func_name: &str, + table_args: &TableArgs, + mut udtf: UDTFServer, + ) -> Result> { + let schema = Self::schema(&udtf)?; + + let table_info = TableInfo { + ident: databend_common_meta_app::schema::TableIdent::new(0, 0), + desc: format!("'{}'.'{}'", database_name, table_func_name), + name: table_func_name.to_string(), + meta: databend_common_meta_app::schema::TableMeta { + schema: schema.clone(), + ..Default::default() + }, + ..Default::default() + }; + let mut stage_locations = Vec::new(); + for (i, (argument, dest_type)) in table_args + .positioned + .iter() + .zip(udtf.arg_types.iter()) + .enumerate() + { + if dest_type.remove_nullable() == DataType::StageLocation { + let Some(location) = argument.as_string() else { + return Err(ErrorCode::SemanticError(format!( + "invalid parameter {argument} for udf function, expected constant string", + ))); + }; + let (stage_info, relative_path) = + databend_common_base::runtime::block_on(resolve_stage_location(ctx, location))?; + + if !matches!(stage_info.stage_type, StageType::External) { + return Err(ErrorCode::SemanticError(format!( + "stage {} type is {}, UDF only support External Stage", + stage_info.stage_name, stage_info.stage_type, + ))); + } + if let StorageParams::S3(config) = &stage_info.stage_params.storage { + if !config.security_token.is_empty() || !config.role_arn.is_empty() { + return Err(ErrorCode::SemanticError(format!( + "StageLocation: @{} must use a separate credential", + location + ))); + } + } + + stage_locations.push(StageLocationParam { + param_name: udtf.arg_names[i].clone(), + relative_path, + stage_info, + }); + } + } + if !stage_locations.is_empty() { + let stage_location_value = serde_json::to_string(&stage_locations)?; + udtf.headers + .insert("databend-stage-mapping".to_string(), stage_location_value); + } + + if table_args.positioned.len() != udtf.arg_types.len() { + return Err(ErrorCode::SyntaxException(format!( + "UDTF '{}' 
argument types length {} does not match input arguments length {}", + table_func_name, + udtf.arg_types.len(), + table_args.positioned.len() + ))); + } + let args = table_args + .positioned + .iter() + .cloned() + .zip(udtf.arg_types) + .map(|(scalar, ty)| { + cast_scalar(Span::None, scalar, &ty, &BUILTIN_FUNCTIONS).map(|scalar| (scalar, ty)) + }) + .collect::>>()?; + + Ok(Arc::new(Self { + desc: UdtfFunctionDesc { + name: table_func_name.to_string(), + func_name: udtf.handler, + return_ty: DataType::Tuple( + udtf.return_types.into_iter().map(|(_, ty)| ty).collect(), + ), + args, + headers: udtf.headers, + server: udtf.address, + }, + table_info, + })) + } + + fn schema(udtf: &UDTFServer) -> Result> { + let fields = udtf + .return_types + .iter() + .map(|(name, ty)| infer_schema_type(ty).map(|ty| TableField::new(name.as_str(), ty))) + .collect::>>()?; + + Ok(TableSchemaRefExt::create(fields)) + } +} + +impl TableFunction for UDTFTable { + fn function_name(&self) -> &str { + self.desc.name.as_str() + } + + fn as_table<'a>(self: Arc) -> Arc + where Self: 'a { + self + } +} + +#[async_trait::async_trait] +impl Table for UDTFTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_table_info(&self) -> &TableInfo { + &self.table_info + } + + #[async_backtrace::framed] + async fn read_partitions( + &self, + _ctx: Arc, + _push_downs: Option, + _dry_run: bool, + ) -> Result<(PartStatistics, Partitions)> { + Ok((PartStatistics::default(), Partitions::default())) + } + + fn table_args(&self) -> Option { + let scalars = self + .desc + .args + .iter() + .map(|(scalar, _)| scalar) + .cloned() + .collect(); + + Some(TableArgs::new_positioned(scalars)) + } + + fn read_data( + &self, + ctx: Arc, + _plan: &DataSourcePlan, + pipeline: &mut Pipeline, + _put_cache: bool, + ) -> Result<()> { + let semaphore = UdtfServerSource::init_semaphore(ctx.clone())?; + let endpoints = UdtfServerSource::init_endpoints(ctx.clone(), &self.desc)?; + pipeline.add_source( + |output| { + let inner = UdtfServerSource::new( + ctx.clone(), + self.desc.clone(), + semaphore.clone(), + endpoints.clone(), + )?; + AsyncSourcer::create(ctx.get_scan_progress(), output, inner) + }, + 1, + )?; + + Ok(()) + } +} diff --git a/src/query/service/tests/it/pipelines/udf_transport.rs b/src/query/service/tests/it/pipelines/udf_transport.rs index 3868ab4fc6684..312f871b09b7d 100644 --- a/src/query/service/tests/it/pipelines/udf_transport.rs +++ b/src/query/service/tests/it/pipelines/udf_transport.rs @@ -246,7 +246,7 @@ async fn malformed_data_returns_parse_hint() -> Result<()> { async fn schema_mismatch_returns_schema_hint() -> Result<()> { let message = run_mock_exchange(MockMode::SchemaMismatch).await?; assert!( - message.contains("returned an unexpected schema"), + message.contains("return incorrect type"), "unexpected schema mismatch message: {message}" ); Ok(()) @@ -279,7 +279,13 @@ async fn run_mock_exchange(mode: MockMode) -> Result { let return_type = DataType::Null; let result = timeout( std::time::Duration::from_secs(5), - client.do_exchange("mock_udf", "mock_handler", num_rows, args, &return_type), + client.do_exchange( + "mock_udf", + "mock_handler", + Some(num_rows), + args, + &return_type, + ), ) .await .expect("do_exchange future timed out"); diff --git a/src/query/service/tests/it/sql/exec/get_table_bind_test.rs b/src/query/service/tests/it/sql/exec/get_table_bind_test.rs index 32030373b1d53..c0cbfe1f87754 100644 --- a/src/query/service/tests/it/sql/exec/get_table_bind_test.rs +++ 
b/src/query/service/tests/it/sql/exec/get_table_bind_test.rs @@ -39,11 +39,13 @@ use databend_common_catalog::runtime_filter_info::RuntimeFilterReport; use databend_common_catalog::session_type::SessionType; use databend_common_catalog::statistics::data_cache_statistics::DataCacheMetrics; use databend_common_catalog::table::Table; +use databend_common_catalog::table_args::TableArgs; use databend_common_catalog::table_context::ContextError; use databend_common_catalog::table_context::FilteredCopyFiles; use databend_common_catalog::table_context::ProcessInfo; use databend_common_catalog::table_context::StageAttachment; use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::BlockThresholds; @@ -55,6 +57,7 @@ use databend_common_meta_app::principal::FileFormatParams; use databend_common_meta_app::principal::GrantObject; use databend_common_meta_app::principal::OnErrorMode; use databend_common_meta_app::principal::RoleInfo; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::principal::UserDefinedConnection; use databend_common_meta_app::principal::UserInfo; use databend_common_meta_app::principal::UserPrivilegeType; @@ -488,6 +491,16 @@ impl Catalog for FakedCatalog { ) -> Result { todo!() } + + fn transform_udtf_as_table_function( + &self, + _ctx: &dyn TableContext, + _table_args: &TableArgs, + _udtf: UDTFServer, + _func_name: &str, + ) -> Result> { + todo!() + } } struct CtxDelegation { diff --git a/src/query/service/tests/it/storages/fuse/operations/commit.rs b/src/query/service/tests/it/storages/fuse/operations/commit.rs index b9ef74130d3ec..0fec83392569e 100644 --- a/src/query/service/tests/it/storages/fuse/operations/commit.rs +++ b/src/query/service/tests/it/storages/fuse/operations/commit.rs @@ -37,11 +37,13 @@ use databend_common_catalog::runtime_filter_info::RuntimeFilterReady; use databend_common_catalog::runtime_filter_info::RuntimeFilterReport; use databend_common_catalog::statistics::data_cache_statistics::DataCacheMetrics; use databend_common_catalog::table::Table; +use databend_common_catalog::table_args::TableArgs; use databend_common_catalog::table_context::ContextError; use databend_common_catalog::table_context::FilteredCopyFiles; use databend_common_catalog::table_context::ProcessInfo; use databend_common_catalog::table_context::StageAttachment; use databend_common_catalog::table_context::TableContext; +use databend_common_catalog::table_function::TableFunction; use databend_common_exception::ErrorCode; use databend_common_exception::Result; use databend_common_expression::BlockThresholds; @@ -54,6 +56,7 @@ use databend_common_meta_app::principal::FileFormatParams; use databend_common_meta_app::principal::GrantObject; use databend_common_meta_app::principal::OnErrorMode; use databend_common_meta_app::principal::RoleInfo; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::principal::UserDefinedConnection; use databend_common_meta_app::principal::UserInfo; use databend_common_meta_app::principal::UserPrivilegeType; @@ -1238,4 +1241,14 @@ impl Catalog for FakedCatalog { ) -> Result { todo!() } + + fn transform_udtf_as_table_function( + &self, + _ctx: &dyn TableContext, + _table_args: &TableArgs, + _udtf: UDTFServer, + _func_name: &str, + ) -> Result> { + todo!() + } } diff --git 
a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs index 31632e3773c35..e595074d60fdf 100644 --- a/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs +++ b/src/query/sql/src/planner/binder/bind_table_reference/bind_table_function.rs @@ -140,111 +140,128 @@ impl Binder { let tenant = self.ctx.get_tenant(); let udtf_result = databend_common_base::runtime::block_on(async { - if let Some(UDFDefinition::UDTF(udtf)) = UserApiProvider::instance() + match UserApiProvider::instance() .get_udf(&tenant, &func_name.name) .await? .map(|udf| udf.definition) { - let mut stmt = Planner::new(self.ctx.clone()) - .parse_sql(&udtf.sql)? - .statement; - - if udtf.arg_types.len() != table_args.positioned.len() { - return Err(ErrorCode::SyntaxException(format!( - "UDTF '{}' argument types length {} does not match input arguments length {}", - func_name, - udtf.arg_types.len(), - table_args.positioned.len() - ))); - } - - let args_expr = table_args - .positioned - .iter() - .map(|scalar| Expr::Literal { - span: None, - value: Literal::String(scalar_ref_to_string(&scalar.as_ref())), - }) - .collect::>(); - let mut visitor = UDFArgVisitor::new(&udtf.arg_types, &args_expr); - stmt.drive_mut(&mut visitor); - - let binder = Binder::new( - self.ctx.clone(), - CatalogManager::instance(), - self.name_resolution_ctx.clone(), - self.metadata.clone(), - ) - .with_subquery_executor(self.subquery_executor.clone()); - let plan = binder.bind(&stmt).await?; - - let Plan::Query { - s_expr, - mut bind_context, - .. - } = plan - else { - return Err(ErrorCode::UDFRuntimeError( - "Query in UDTF returned no result set", - )); - }; - let mut output_bindings = Vec::with_capacity(bind_context.columns.len()); - let mut output_items = Vec::with_capacity(bind_context.columns.len()); - - if udtf.return_types.len() != bind_context.columns.len() { - return Err(ErrorCode::UDFSchemaMismatch(format!( - "UDTF '{}' return types length {} does not match output columns length {}", - func_name, - udtf.return_types.len(), - bind_context.columns.len() - ))); - } + Some(UDFDefinition::UDTF(udtf)) => { + let mut stmt = Planner::new(self.ctx.clone()) + .parse_sql(&udtf.sql)? 
+ .statement; + + if udtf.arg_types.len() != table_args.positioned.len() { + return Err(ErrorCode::SyntaxException(format!( + "UDTF '{}' argument types length {} does not match input arguments length {}", + func_name, + udtf.arg_types.len(), + table_args.positioned.len() + ))); + } - for ((return_name, return_type), output_binding) in udtf - .return_types - .into_iter() - .zip(bind_context.columns.iter()) - { - let input_expr = ScalarExpr::BoundColumnRef(BoundColumnRef { - span: None, - column: output_binding.clone(), - }); - let cast_expr = ScalarExpr::CastExpr(CastExpr { - span: None, - is_try: false, - argument: Box::new(input_expr), - target_type: Box::new(return_type.clone()), - }); - let index = self - .metadata - .write() - .add_derived_column(return_name.clone(), return_type.clone()); - let output_binding = ColumnBindingBuilder::new( - return_name, - index, - Box::new(return_type), - Visibility::Visible, + let args_expr = table_args + .positioned + .iter() + .map(|scalar| Expr::Literal { + span: None, + value: Literal::String(scalar_ref_to_string(&scalar.as_ref())), + }) + .collect::>(); + let mut visitor = UDFArgVisitor::new(&udtf.arg_types, &args_expr); + stmt.drive_mut(&mut visitor); + + let binder = Binder::new( + self.ctx.clone(), + CatalogManager::instance(), + self.name_resolution_ctx.clone(), + self.metadata.clone(), ) - .build(); + .with_subquery_executor(self.subquery_executor.clone()); + let plan = binder.bind(&stmt).await?; + + let Plan::Query { + s_expr, + mut bind_context, + .. + } = plan + else { + return Err(ErrorCode::UDFRuntimeError( + "Query in UDTF returned no result set", + )); + }; + let mut output_bindings = Vec::with_capacity(bind_context.columns.len()); + let mut output_items = Vec::with_capacity(bind_context.columns.len()); + + if udtf.return_types.len() != bind_context.columns.len() { + return Err(ErrorCode::UDFSchemaMismatch(format!( + "UDTF '{}' return types length {} does not match output columns length {}", + func_name, + udtf.return_types.len(), + bind_context.columns.len() + ))); + } - output_items.push(ScalarItem { - scalar: cast_expr, - index: output_binding.index, - }); - output_bindings.push(output_binding); + for ((return_name, return_type), output_binding) in udtf + .return_types + .into_iter() + .zip(bind_context.columns.iter()) + { + let input_expr = ScalarExpr::BoundColumnRef(BoundColumnRef { + span: None, + column: output_binding.clone(), + }); + let cast_expr = ScalarExpr::CastExpr(CastExpr { + span: None, + is_try: false, + argument: Box::new(input_expr), + target_type: Box::new(return_type.clone()), + }); + let index = self + .metadata + .write() + .add_derived_column(return_name.clone(), return_type.clone()); + let output_binding = ColumnBindingBuilder::new( + return_name, + index, + Box::new(return_type), + Visibility::Visible, + ) + .build(); + + output_items.push(ScalarItem { + scalar: cast_expr, + index: output_binding.index, + }); + output_bindings.push(output_binding); + } + bind_context.columns = output_bindings; + let s_expr = SExpr::create_unary( + Arc::new( + EvalScalar { + items: output_items, + } + .into(), + ), + s_expr, + ); + + return Ok(Some((s_expr, *bind_context))); + } + Some(UDFDefinition::UDTFServer(udtf)) => { + let table = self + .catalogs + .get_default_catalog(self.ctx.session_state()?)? 
+ .transform_udtf_as_table_function( + self.ctx.as_ref(), + &table_args, + udtf, + &func_name.name, + )?; + let (s_expr, bind_context) = + self.bind_base_table_inner(bind_context, alias, sample, table)?; + return Ok(Some((s_expr, bind_context))); } - bind_context.columns = output_bindings; - let s_expr = SExpr::create_unary( - Arc::new( - EvalScalar { - items: output_items, - } - .into(), - ), - s_expr, - ); - - return Ok(Some((s_expr, *bind_context))); + _ => (), } Ok(None) }); @@ -260,31 +277,39 @@ impl Binder { .catalogs .get_default_catalog(self.ctx.session_state()?)? .get_table_function(&func_name.name, table_args)?; - let table = table_meta.as_table(); - let table_alias_name = if let Some(table_alias) = alias { - Some(normalize_identifier(&table_alias.name, &self.name_resolution_ctx).name) - } else { - None - }; - let table_index = self.metadata.write().add_table( - CATALOG_DEFAULT.to_string(), - "system".to_string(), - table.clone(), - table_alias_name, - false, - false, - false, - None, - false, - ); + self.bind_base_table_inner(bind_context, alias, sample, table_meta) + } + } - let (s_expr, mut bind_context) = - self.bind_base_table(bind_context, "system", table_index, None, sample)?; - if let Some(alias) = alias { - bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; - } - Ok((s_expr, bind_context)) + fn bind_base_table_inner( + &mut self, + bind_context: &mut BindContext, + alias: &Option, + sample: &Option, + table: Arc, + ) -> Result<(SExpr, BindContext)> { + let table_alias_name = if let Some(table_alias) = alias { + Some(normalize_identifier(&table_alias.name, &self.name_resolution_ctx).name) + } else { + None + }; + let table_index = self.metadata.write().add_table( + CATALOG_DEFAULT.to_string(), + "system".to_string(), + table.as_table(), + table_alias_name, + false, + false, + false, + None, + false, + ); + let (s_expr, mut bind_context) = + self.bind_base_table(bind_context, "system", table_index, None, sample)?; + if let Some(alias) = alias { + bind_context.apply_table_alias(alias, &self.name_resolution_ctx)?; } + Ok((s_expr, bind_context)) } fn bind_result_scan( diff --git a/src/query/sql/src/planner/binder/udf.rs b/src/query/sql/src/planner/binder/udf.rs index 13cf84b6d3de7..22eb1502a671e 100644 --- a/src/query/sql/src/planner/binder/udf.rs +++ b/src/query/sql/src/planner/binder/udf.rs @@ -33,6 +33,7 @@ use databend_common_meta_app::principal::UDAFScript; use databend_common_meta_app::principal::UDFDefinition as PlanUDFDefinition; use databend_common_meta_app::principal::UDFScript; use databend_common_meta_app::principal::UDFServer; +use databend_common_meta_app::principal::UDTFServer; use databend_common_meta_app::principal::UserDefinedFunction; use databend_common_meta_app::principal::UDTF; @@ -257,6 +258,50 @@ impl Binder { created_on: Utc::now(), }) } + UDFDefinition::UDTFServer { + arg_types, + return_types, + address, + handler, + headers, + language, + immutable, + } => { + UDFValidator::is_udf_server_allowed(address.as_str())?; + + let mut arg_datatypes = Vec::with_capacity(arg_types.len()); + let mut arg_names = Vec::with_capacity(arg_types.len()); + + for (arg_name, arg_type) in arg_types { + arg_names.push(normalize_identifier(arg_name, &self.name_resolution_ctx).name); + arg_datatypes.push(DataType::from(&resolve_type_name_udf(arg_type)?)); + } + + let return_types = return_types + .iter() + .map(|(name, arg_type)| { + let column = normalize_identifier(name, &self.name_resolution_ctx).name; + let ty = 
diff --git a/src/query/sql/src/planner/plans/scan.rs b/src/query/sql/src/planner/plans/scan.rs
index 64e19bf2de701..d47b83b367be7 100644
--- a/src/query/sql/src/planner/plans/scan.rs
+++ b/src/query/sql/src/planner/plans/scan.rs
@@ -50,7 +50,7 @@ use crate::plans::SortItem;
 use crate::ColumnSet;
 use crate::IndexType;
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq, Eq, Hash)]
 pub struct Prewhere {
     // columns needed to be output after prewhere scan
     pub output_columns: ColumnSet,
diff --git a/src/query/sql/src/planner/semantic/mod.rs b/src/query/sql/src/planner/semantic/mod.rs
index 4164d8f66d802..0f22a8bc6c742 100644
--- a/src/query/sql/src/planner/semantic/mod.rs
+++ b/src/query/sql/src/planner/semantic/mod.rs
@@ -47,6 +47,7 @@ pub use type_check::resolve_type_name;
 pub use type_check::resolve_type_name_by_str;
 pub use type_check::resolve_type_name_udf;
 pub use type_check::validate_function_arg;
+pub use type_check::StageLocationParam;
 pub use type_check::TypeChecker;
 pub use udf_rewriter::UDFArgVisitor;
 pub(crate) use udf_rewriter::UdfRewriter;
diff --git a/src/query/sql/src/planner/semantic/type_check.rs b/src/query/sql/src/planner/semantic/type_check.rs
index c84452c55fa4b..b6c78425c52cc 100644
--- a/src/query/sql/src/planner/semantic/type_check.rs
+++ b/src/query/sql/src/planner/semantic/type_check.rs
@@ -207,6 +207,13 @@ use crate::UDFArgVisitor;
 const DEFAULT_DECIMAL_PRECISION: i64 = 38;
 const DEFAULT_DECIMAL_SCALE: i64 = 0;
 
+#[derive(serde::Serialize, serde::Deserialize)]
+pub struct StageLocationParam {
+    pub param_name: String,
+    pub relative_path: String,
+    pub stage_info: StageInfo,
+}
+
 /// A helper for type checking.
 ///
 /// `TypeChecker::resolve` will resolve types of `Expr` and transform `Expr` into
@@ -5152,6 +5159,7 @@ impl<'a> TypeChecker<'a> {
             self.resolve_udaf_script(span, name, arguments, udf_def)?,
         )),
         UDFDefinition::UDTF(_) => unreachable!(),
+        UDFDefinition::UDTFServer(_) => unreachable!(),
         UDFDefinition::ScalarUDF(udf_def) => Ok(Some(
             self.resolve_scalar_udf(span, name, arguments, udf_def)?,
         )),
@@ -5165,13 +5173,6 @@ impl<'a> TypeChecker<'a> {
         arguments: &[Expr],
         mut udf_definition: UDFServer,
     ) -> Result<Box<(ScalarExpr, DataType)>> {
-        #[derive(serde::Serialize, serde::Deserialize)]
-        struct StageLocationParam {
-            param_name: String,
-            relative_path: String,
-            stage_info: StageInfo,
-        }
-
         UDFValidator::is_udf_server_allowed(&udf_definition.address)?;
         if arguments.len() != udf_definition.arg_types.len() {
             return Err(ErrorCode::InvalidArgument(format!(
@@ -5346,13 +5347,13 @@ impl<'a> TypeChecker<'a> {
             .do_exchange(
                 name,
                 &udf_definition.handler,
-                num_rows,
+                Some(num_rows),
                 block_entries,
                 &udf_definition.return_type,
             )
             .await?;
 
-        let value = unsafe { result.index_unchecked(0) };
+        let value = unsafe { result.get_by_offset(0).index_unchecked(0) };
         Ok(value.to_owned())
     }
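`StageLocationParam` is promoted from a function-local struct in `resolve_udf_server` to a public export, so the table-function path can serialize the same stage payload for the UDF server. A simplified round-trip of that serialization, assuming `serde` (with the derive feature) and `serde_json` are available, and with `StageInfo` reduced to a plain string for the sketch:

use serde::{Deserialize, Serialize};

// Simplified stand-in for the exported StageLocationParam; the real struct
// carries a full StageInfo rather than a string.
#[derive(Serialize, Deserialize, Debug, PartialEq)]
struct StageLocationParam {
    param_name: String,
    relative_path: String,
    stage_info: String, // StageInfo in the real struct
}

fn main() -> Result<(), serde_json::Error> {
    let param = StageLocationParam {
        param_name: "data_stage".into(),
        relative_path: "output/2024".into(),
        stage_info: "s3_stage".into(),
    };
    // Serialize for transport to the UDF server, then read it back.
    let json = serde_json::to_string(&param)?;
    let back: StageLocationParam = serde_json::from_str(&json)?;
    assert_eq!(param, back);
    Ok(())
}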
diff --git a/src/query/storages/hive/hive/src/hive_catalog.rs b/src/query/storages/hive/hive/src/hive_catalog.rs
index a90840dfd9534..dfd729bde8f5f 100644
--- a/src/query/storages/hive/hive/src/hive_catalog.rs
+++ b/src/query/storages/hive/hive/src/hive_catalog.rs
@@ -24,9 +24,11 @@ use databend_common_catalog::catalog::StorageDescription;
 use databend_common_catalog::database::Database;
 use databend_common_catalog::table::Table;
 use databend_common_catalog::table_args::TableArgs;
+use databend_common_catalog::table_context::TableContext;
 use databend_common_catalog::table_function::TableFunction;
 use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
+use databend_common_meta_app::principal::UDTFServer;
 use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent;
 use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent;
 use databend_common_meta_app::schema::CatalogInfo;
@@ -765,4 +767,14 @@ impl Catalog for HiveCatalog {
     ) -> Result {
         unimplemented!()
     }
+
+    fn transform_udtf_as_table_function(
+        &self,
+        _ctx: &dyn TableContext,
+        _table_args: &TableArgs,
+        _udtf: UDTFServer,
+        _func_name: &str,
+    ) -> Result<Arc<dyn TableFunction>> {
+        unimplemented!()
+    }
 }
diff --git a/src/query/storages/iceberg/src/catalog.rs b/src/query/storages/iceberg/src/catalog.rs
index 3c1d6cffb2277..5c3013c854d31 100644
--- a/src/query/storages/iceberg/src/catalog.rs
+++ b/src/query/storages/iceberg/src/catalog.rs
@@ -24,9 +24,11 @@ use databend_common_catalog::catalog::StorageDescription;
 use databend_common_catalog::database::Database;
 use databend_common_catalog::table::Table;
 use databend_common_catalog::table_args::TableArgs;
+use databend_common_catalog::table_context::TableContext;
 use databend_common_catalog::table_function::TableFunction;
 use databend_common_exception::ErrorCode;
 use databend_common_exception::Result;
+use databend_common_meta_app::principal::UDTFServer;
 use databend_common_meta_app::schema::database_name_ident::DatabaseNameIdent;
 use databend_common_meta_app::schema::dictionary_name_ident::DictionaryNameIdent;
 use databend_common_meta_app::schema::CatalogInfo;
@@ -761,4 +763,14 @@ impl Catalog for IcebergMutableCatalog {
     ) -> Result {
         unimplemented!()
     }
+
+    fn transform_udtf_as_table_function(
+        &self,
+        _ctx: &dyn TableContext,
+        _table_args: &TableArgs,
+        _udtf: UDTFServer,
+        _func_name: &str,
+    ) -> Result<Arc<dyn TableFunction>> {
+        unimplemented!()
+    }
 }
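Both the Hive and Iceberg catalogs stub `transform_udtf_as_table_function` with `unimplemented!()`, which panics if a UDTF is ever bound through them. A sketch of an alternative shape, where a default trait method rejects the call with a regular error instead of panicking (trait and type names here are illustrative, not Databend's):

// Standalone sketch: a rejecting default keeps non-default catalogs from
// panicking and avoids repeating the stub in every implementation.
trait Catalog {
    fn transform_udtf_as_table_function(&self, func_name: &str) -> Result<String, String> {
        Err(format!(
            "catalog does not support UDTF table function '{func_name}'"
        ))
    }
}

struct HiveLike;
impl Catalog for HiveLike {} // inherits the rejecting default

fn main() {
    assert!(HiveLike
        .transform_udtf_as_table_function("stage_summary_udtf")
        .is_err());
}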
diff --git a/src/query/storages/system/src/user_functions_table.rs b/src/query/storages/system/src/user_functions_table.rs
index 374a84efb92ab..4bce99f8feedf 100644
--- a/src/query/storages/system/src/user_functions_table.rs
+++ b/src/query/storages/system/src/user_functions_table.rs
@@ -228,6 +228,24 @@ impl UserFunctionsTable {
             states: BTreeMap::new(),
             immutable: None,
         },
+        UDFDefinition::UDTFServer(x) => UserFunctionArguments {
+            arg_types: x
+                .arg_names
+                .iter()
+                .zip(x.arg_types.iter())
+                .map(|(name, ty)| format!("{name} {ty}"))
+                .collect(),
+            return_type: Some(
+                x.return_types
+                    .iter()
+                    .map(|(name, ty)| format!("{name} {ty}"))
+                    .collect(),
+            ),
+            server: Some(x.address.to_string()),
+            parameters: vec![],
+            states: BTreeMap::new(),
+            immutable: x.immutable,
+        },
         UDFDefinition::ScalarUDF(x) => UserFunctionArguments {
             arg_types: x
                 .arg_types
diff --git a/tests/sqllogictests/suites/udf_server/udf_server_test.test b/tests/sqllogictests/suites/udf_server/udf_server_test.test
index 4506b8edf2bd1..162e9df4d249e 100644
--- a/tests/sqllogictests/suites/udf_server/udf_server_test.test
+++ b/tests/sqllogictests/suites/udf_server/udf_server_test.test
@@ -689,6 +689,12 @@ CREATE OR REPLACE FUNCTION multi_stage_process(input_stage STAGE_LOCATION, outpu
 statement ok
 CREATE OR REPLACE FUNCTION immutable_multi_stage_process(input_stage STAGE_LOCATION, output_stage STAGE_LOCATION, value VARCHAR) RETURNS INT LANGUAGE python IMMUTABLE HANDLER = 'immutable_multi_stage_process' ADDRESS = 'http://0.0.0.0:8815';
 
+statement ok
+CREATE OR REPLACE FUNCTION stage_summary_udtf(data_stage STAGE_LOCATION, arg int) RETURNS TABLE (stage_name varchar, stage_type varchar, bucket varchar, relative_path varchar, value int, summary varchar) LANGUAGE python HANDLER = 'stage_summary_udtf' HEADERS = ('X-Authorization' = '123') ADDRESS = 'http://0.0.0.0:8815';
+
+statement ok
+CREATE OR REPLACE FUNCTION multi_stage_process_udtf(input_stage STAGE_LOCATION, output_stage STAGE_LOCATION, arg int) RETURNS TABLE (input_stage varchar, output_stage varchar, input_bucket varchar, output_bucket varchar, input_relative_path varchar, output_relative_path varchar, result int) LANGUAGE python HANDLER = 'multi_stage_process_udtf' HEADERS = ('X-Authorization' = '123') ADDRESS = 'http://0.0.0.0:8815';
+
 statement ok
 CREATE OR REPLACE STAGE s3_stage URL = 's3://test/' CONNECTION = ( AWS_KEY_ID = 'minioadmin' AWS_SECRET_KEY = 'minioadmin' ENDPOINT_URL = 'http://127.0.0.1:9900') FILE_FORMAT = (TYPE = CSV);
@@ -706,3 +712,15 @@ query I
 SELECT immutable_multi_stage_process(@s3_stage/input/2024/, @s3_stage/output/2024, 'hello')
 ----
 13
+
+query I
+SELECT * from stage_summary_udtf(@s3_stage/output/2024, 21)
+----
+s3_stage External test output/2024 21 s3_stage:test:output/2024:21
+s3_stage External test output/2024 22 s3_stage:test:output/2024:22
+
+query I
+SELECT * from multi_stage_process_udtf(@s3_stage/input/2024/, @s3_stage/output/2024, 21)
+----
+s3_stage s3_stage test test input/2024/ output/2024 29
+s3_stage s3_stage test test input/2024/ output/2024 30
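The `RETURNS TABLE (...)` column list exercised by these tests is the same data that `system.user_functions` renders by zipping names with types. A minimal sketch of that rendering; the `join(", ")` at the end is this sketch's choice for a single-string form, not necessarily what the system table emits:

// Standalone sketch of the "name TYPE" signature rendering.
fn format_signature(names: &[&str], types: &[&str]) -> Vec<String> {
    names
        .iter()
        .zip(types)
        .map(|(name, ty)| format!("{name} {ty}"))
        .collect()
}

fn main() {
    let args = format_signature(&["data_stage", "arg"], &["STAGE_LOCATION", "INT"]);
    assert_eq!(args, ["data_stage STAGE_LOCATION", "arg INT"]);
    // A single-column rendering needs an explicit separator:
    assert_eq!(args.join(", "), "data_stage STAGE_LOCATION, arg INT");
}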
result_type="VARCHAR") def stage_summary(data_stage: StageLocation, value: int) -> str: assert data_stage.stage_type.lower() == "external" assert data_stage.storage - bucket = data_stage.storage.get("bucket", data_stage.storage.get("container", "")) + bucket = _stage_bucket(data_stage) return f"{data_stage.stage_name}:{bucket}:{data_stage.relative_path}:{value}" +@udf( + stage_refs=["data_stage"], + input_types=["INT"], + result_type=[ + ("stage_name", "VARCHAR"), + ("stage_type", "VARCHAR"), + ("bucket", "VARCHAR"), + ("relative_path", "VARCHAR"), + ("value", "INT"), + ("summary", "VARCHAR"), + ], +) +def stage_summary_udtf(data_stage: StageLocation, value: int): + assert data_stage.stage_type.lower() == "external" + assert data_stage.storage + bucket = _stage_bucket(data_stage) + rows = [] + for offset in (0, 1): + current_value = value + offset + summary = ( + f"{data_stage.stage_name}:{bucket}:{data_stage.relative_path}:{current_value}" + ) + rows.append( + { + "stage_name": data_stage.stage_name or "", + "stage_type": data_stage.stage_type or "", + "bucket": bucket, + "relative_path": data_stage.relative_path or "", + "value": current_value, + "summary": summary, + } + ) + return rows + + @udf( stage_refs=["input_stage", "output_stage"], input_types=["INT"], @@ -452,6 +492,45 @@ def multi_stage_process( ) +@udf( + stage_refs=["input_stage", "output_stage"], + input_types=["INT"], + result_type=[ + ("input_stage", "VARCHAR"), + ("output_stage", "VARCHAR"), + ("input_bucket", "VARCHAR"), + ("output_bucket", "VARCHAR"), + ("input_relative_path", "VARCHAR"), + ("output_relative_path", "VARCHAR"), + ("result", "INT"), + ], +) +def multi_stage_process_udtf( + input_stage: StageLocation, output_stage: StageLocation, value: int +): + assert input_stage.storage and output_stage.storage + assert input_stage.stage_type.lower() == "external" + assert output_stage.stage_type.lower() == "external" + input_bucket = _stage_bucket(input_stage) + output_bucket = _stage_bucket(output_stage) + rows = [] + for offset in (0, 1): + current_value = value + offset + result = current_value + len(input_bucket) + len(output_bucket) + rows.append( + { + "input_stage": input_stage.stage_name or "", + "output_stage": output_stage.stage_name or "", + "input_bucket": input_bucket, + "output_bucket": output_bucket, + "input_relative_path": input_stage.relative_path or "", + "output_relative_path": output_stage.relative_path or "", + "result": result, + } + ) + return rows + + @udf( stage_refs=["input_stage", "output_stage"], input_types=["VARCHAR"], @@ -503,7 +582,9 @@ def immutable_multi_stage_process( udf_server.add_function(check_headers) udf_server.add_function(embedding_4) udf_server.add_function(stage_summary) + udf_server.add_function(stage_summary_udtf) udf_server.add_function(multi_stage_process) + udf_server.add_function(multi_stage_process_udtf) udf_server.add_function(immutable_multi_stage_process) # Built-in function