diff --git a/src/db.rs b/src/db.rs index ab457b6b..5605fd19 100644 --- a/src/db.rs +++ b/src/db.rs @@ -8,6 +8,7 @@ use crate::function::char_length::CharLength; use crate::function::current_date::CurrentDate; use crate::function::lower::Lower; use crate::function::numbers::Numbers; +use crate::function::octet_length::OctetLength; use crate::function::upper::Upper; use crate::optimizer::heuristic::batch::HepBatchStrategy; use crate::optimizer::heuristic::optimizer::HepOptimizer; @@ -61,6 +62,7 @@ impl DataBaseBuilder { builder.register_scala_function(CharLength::new("character_length".to_lowercase())); builder = builder.register_scala_function(CurrentDate::new()); builder = builder.register_scala_function(Lower::new()); + builder = builder.register_scala_function(OctetLength::new()); builder = builder.register_scala_function(Upper::new()); builder = builder.register_table_function(Numbers::new()); builder diff --git a/src/function/char_length.rs b/src/function/char_length.rs index 6cf0c100..817f591a 100644 --- a/src/function/char_length.rs +++ b/src/function/char_length.rs @@ -43,7 +43,7 @@ impl ScalarFunctionImpl for CharLength { } let mut length: u64 = 0; if let DataValue::Utf8 { value, ty, unit } = &mut value { - length = value.len() as u64; + length = value.chars().count() as u64; } Ok(DataValue::UInt64(length)) } diff --git a/src/function/mod.rs b/src/function/mod.rs index 6c660c57..7930e807 100644 --- a/src/function/mod.rs +++ b/src/function/mod.rs @@ -2,4 +2,5 @@ pub(crate) mod char_length; pub(crate) mod current_date; pub(crate) mod lower; pub(crate) mod numbers; +pub(crate) mod octet_length; pub(crate) mod upper; diff --git a/src/function/octet_length.rs b/src/function/octet_length.rs new file mode 100644 index 00000000..b712ee3c --- /dev/null +++ b/src/function/octet_length.rs @@ -0,0 +1,63 @@ +use crate::catalog::ColumnRef; +use crate::errors::DatabaseError; +use crate::expression::function::scala::FuncMonotonicity; +use crate::expression::function::scala::ScalarFunctionImpl; +use crate::expression::function::FunctionSummary; +use crate::expression::ScalarExpression; +use crate::types::tuple::Tuple; +use crate::types::value::DataValue; +use crate::types::LogicalType; +use serde::Deserialize; +use serde::Serialize; +use sqlparser::ast::CharLengthUnits; +use std::sync::Arc; + +#[derive(Debug, Serialize, Deserialize)] +pub(crate) struct OctetLength { + summary: FunctionSummary, +} + +impl OctetLength { + pub(crate) fn new() -> Arc { + let function_name = "octet_length".to_lowercase(); + let arg_types = vec![LogicalType::Varchar(None, CharLengthUnits::Characters)]; + Arc::new(Self { + summary: FunctionSummary { + name: function_name, + arg_types, + }, + }) + } +} + +#[typetag::serde] +impl ScalarFunctionImpl for OctetLength { + #[allow(unused_variables, clippy::redundant_closure_call)] + fn eval( + &self, + exprs: &[ScalarExpression], + tuples: Option<(&Tuple, &[ColumnRef])>, + ) -> Result { + let mut value = exprs[0].eval(tuples)?; + if !matches!(value.logical_type(), LogicalType::Varchar(_, _)) { + value = value.cast(&LogicalType::Varchar(None, CharLengthUnits::Characters))?; + } + let mut length: u64 = 0; + if let DataValue::Utf8 { value, ty, unit } = &mut value { + length = value.len() as u64; + } + Ok(DataValue::UInt64(length)) + } + + fn monotonicity(&self) -> Option { + todo!() + } + + fn return_type(&self) -> &LogicalType { + &LogicalType::Varchar(None, CharLengthUnits::Characters) + } + + fn summary(&self) -> &FunctionSummary { + &self.summary + } +} diff --git a/tests/slt/sql_2016/E021_04.slt b/tests/slt/sql_2016/E021_04.slt index ac401405..8d4842b5 100644 --- a/tests/slt/sql_2016/E021_04.slt +++ b/tests/slt/sql_2016/E021_04.slt @@ -5,8 +5,17 @@ SELECT CHARACTER_LENGTH ( 'foo' ) ---- 3 - query I SELECT CHAR_LENGTH ( 'foo' ) ---- 3 + +query I +SELECT CHARACTER_LENGTH ( '测试' ) +---- +2 + +query I +SELECT CHAR_LENGTH ( '测试' ) +---- +2 diff --git a/tests/slt/sql_2016/E021_05.slt b/tests/slt/sql_2016/E021_05.slt index e29fffa7..a40733ff 100644 --- a/tests/slt/sql_2016/E021_05.slt +++ b/tests/slt/sql_2016/E021_05.slt @@ -1,8 +1,11 @@ -# E021-05: OCTET_LENGTH function +#E021-05: OCTET_LENGTH function -# TODO: OCTET_LENGTH() +query I +SELECT OCTET_LENGTH ( 'foo' ) +---- +3 -# query I -# SELECT OCTET_LENGTH ( 'foo' ) -# ---- -# 3 +query I +SELECT OCTET_LENGTH ( '测试' ) +---- +6