diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ac7f6766..4dfc0693 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -4,7 +4,7 @@ on: push: branches: [ main ] pull_request: - branches: [ main ] + workflow_dispatch: env: CARGO_TERM_COLOR: always diff --git a/CHANGELOG.md b/CHANGELOG.md index 21f54c33..2bf5d300 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,9 +9,35 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - ReleaseDate ### Removed + - **BREAKING** watch: `Client::watch()` API is removed ([#245]). - **BREAKING** mock: `watch()` and `watch_only_events()` are removed ([#245]). +### Changed + +- **BREAKING** query: `RowBinaryWithNamesAndTypes` is now used by default for query results. This may cause panics if + the row struct definition does not match the database schema. Use `Client::with_validation(false)` to revert to the + previous behavior, which uses the plain `RowBinary` format for fetching rows. ([#221]) +- **BREAKING** mock: when using the `test-util` feature, it is now required to use `Client::with_mock(&mock)` to set up + the mock server, so that it properly handles the response format and automatically disables + `RowBinaryWithNamesAndTypes` header parsing and validation. Additionally, it is no longer required to call `with_url` + explicitly. See the [updated example](./examples/mock.rs). +- query: due to `RowBinaryWithNamesAndTypes` format usage, there might be an impact on fetch performance, which largely + depends on how the dataset is defined. If you notice decreased performance, consider disabling validation via + `Client::with_validation(false)`. +- serde: it is now possible to deserialize the ClickHouse `Map` type into a `HashMap` (or `BTreeMap`, `IndexMap`, + `DashMap`, etc.). + +### Added + +- client: added the `Client::with_validation` builder method. Validation is enabled by default, meaning that + the `RowBinaryWithNamesAndTypes` format will be used to fetch rows from the database. If validation is disabled, + the `RowBinary` format will be used, as in previous versions. ([#221]) +- types: a new crate `clickhouse-types` was added to the project workspace. This crate is required for + `RowBinaryWithNamesAndTypes` struct definition validation, as it contains the ClickHouse data types AST, as well as + functions and utilities to parse the types out of the ClickHouse server response. ([#221])
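As an illustration of the mock change above, a minimal sketch of the migration (assuming the `test-util` feature and a Tokio runtime; see the referenced example for the full version):

```rust
use clickhouse::{test, Client};

#[tokio::main]
async fn main() {
    let mock = test::Mock::new();

    // 0.13.x required pointing the client at the mock URL manually:
    // let client = Client::default().with_url(mock.url());

    // 0.14.0 binds the client to the mock server in one call and also
    // disables `RowBinaryWithNamesAndTypes` header parsing and validation:
    let _client = Client::default().with_mock(&mock);
}
```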
+ +[#221]: https://github.com/ClickHouse/clickhouse-rs/pull/221 [#245]: https://github.com/ClickHouse/clickhouse-rs/pull/245 ## [0.13.3] - 2025-05-29 diff --git a/Cargo.toml b/Cargo.toml index 80324a88..def1d4c3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,10 +9,21 @@ homepage = "https://clickhouse.com" license = "MIT OR Apache-2.0" readme = "README.md" edition = "2021" -# update `derive/Cargo.toml` and CI if changed +# update `workspace.package.rust-version` below and CI if changed # TODO: after bumping to v1.80, remove `--precise` in the "msrv" CI job rust-version = "1.73.0" +[workspace] +members = ["derive", "types"] + +[workspace.package] +authors = ["ClickHouse Contributors", "Paul Loyd "] +repository = "https://github.com/ClickHouse/clickhouse-rs" +homepage = "https://clickhouse.com" +edition = "2021" +license = "MIT OR Apache-2.0" +rust-version = "1.73.0" + [lints.rust] rust_2018_idioms = { level = "warn", priority = -1 } unreachable_pub = "warn" @@ -26,16 +37,21 @@ undocumented_unsafe_blocks = "warn" all-features = true rustdoc-args = ["--cfg", "docsrs"] +[[bench]] +name = "select_nyc_taxi_data" +harness = false +required-features = ["time"] + [[bench]] name = "select_numbers" harness = false [[bench]] -name = "insert" +name = "mocked_insert" harness = false [[bench]] -name = "select" +name = "mocked_select" harness = false [[example]] @@ -97,6 +113,7 @@ rustls-tls-native-roots = [ [dependencies] clickhouse-derive = { version = "0.2.0", path = "derive" } +clickhouse-types = { version = "0.1.0", path = "types" } thiserror = "2.0" serde = "1.0.106" @@ -128,6 +145,7 @@ quanta = { version = "0.12", optional = true } replace_with = { version = "0.1.7" } [dev-dependencies] +clickhouse-derive = { version = "0.2.0", path = "derive" } criterion = "0.6" serde = { version = "1.0.106", features = ["derive"] } tokio = { version = "1.0.1", features = ["full", "test-util"] } @@ -136,6 +154,6 @@ serde_bytes = "0.11.4" serde_json = "1" serde_repr = "0.1.7" uuid = { version = "1", features = ["v4", "serde"] } -time = { version = "0.3.17", features = ["macros", "rand"] } +time = { version = "0.3.17", features = ["macros", "rand", "parsing"] } fixnum = { version = "0.9.2", features = ["serde", "i32", "i64", "i128"] } rand = { version = "0.9", features = ["small_rng"] } diff --git a/README.md b/README.md index 2d8e0d94..ea24f1a5 100644 --- a/README.md +++ b/README.md @@ -18,8 +18,10 @@ Official pure Rust typed client for ClickHouse DB. * Uses `serde` for encoding/decoding rows. * Supports `serde` attributes: `skip_serializing`, `skip_deserializing`, `rename`. -* Uses `RowBinary` encoding over HTTP transport. - * There are plans to switch to `Native` over TCP. +* Uses `RowBinaryWithNamesAndTypes` or `RowBinary` formats over HTTP transport. + * By default, `RowBinaryWithNamesAndTypes` with database schema validation is used. + * It is possible to switch to `RowBinary`, which can potentially lead to increased performance ([see below](#validation)). + * There are plans to implement `Native` format over TCP. * Supports TLS (see `native-tls` and `rustls-tls` features below). * Supports compression and decompression (LZ4 and LZ4HC). * Provides API for selecting. @@ -29,9 +31,30 @@ Official pure Rust typed client for ClickHouse DB. Note: [ch2rs](https://github.com/ClickHouse/ch2rs) is useful to generate a row type from ClickHouse. +## Validation + +Starting from 0.14.0, the crate uses the `RowBinaryWithNamesAndTypes` format by default, which enables row type validation +against the ClickHouse schema. This enables clearer error messages in case of a schema mismatch, at the cost of some +performance. Additionally, when validation is enabled, the crate supports structs whose field names and types match the +schema but are declared in a different order, at a slight additional (5-10%) performance penalty. + +If you are looking to maximize performance, you can disable validation using `Client::with_validation(false)`. When +validation is disabled, the client switches to the plain `RowBinary` format instead (see the sketch after the usage snippet below). + +The downside of plain `RowBinary` is that, instead of a clear error message, a mismatch between `Row` and the database +schema results in a `NotEnoughData` error without specific details. + +However, disabling validation may yield a 1.1x to 3x performance improvement, depending heavily on the shape and +volume of the dataset. + +It is always recommended to measure the performance impact of validation in your specific use case. Additionally, if +you plan to disable validation in your application, writing smoke tests that ensure the row types match the ClickHouse +schema is highly recommended. + ## Usage To use the crate, add this to your `Cargo.toml`: + ```toml [dependencies] clickhouse = "0.13.3" @@ -43,16 +66,6 @@ clickhouse = { version = "0.13.3", features = ["test-util"] }
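A minimal sketch of the validation toggle described in the section above (the URL is illustrative):

```rust
use clickhouse::Client;

fn main() {
    // Default: `RowBinaryWithNamesAndTypes` with schema validation.
    let _validated = Client::default().with_url("http://localhost:8123");

    // Maximum throughput: plain `RowBinary`, no schema checks, so a schema
    // mismatch surfaces as `NotEnoughData` instead of a descriptive error.
    let _unchecked = Client::default()
        .with_url("http://localhost:8123")
        .with_validation(false);
}
```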
-### Note about ClickHouse prior to v22.6 - - - -CH server older than v22.6 (2022-06-16) handles `RowBinary` [incorrectly](https://github.com/ClickHouse/ClickHouse/issues/37420) in some rare cases. Use 0.11 and enable `wa-37420` feature to solve this problem. Don't use it for newer versions. - -
-
- - ### Create a client @@ -249,7 +262,8 @@ How to choose between all these features? Here are some considerations: } ```
-* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). +* `Enum(8|16)` are supported using [serde_repr](https://docs.rs/serde_repr/latest/serde_repr/). Use + `#[repr(i8)]` for `Enum8` and `#[repr(i16)]` for `Enum16`.
Example @@ -262,7 +276,7 @@ How to choose between all these features? Here are some considerations: } #[derive(Debug, Serialize_repr, Deserialize_repr)] - #[repr(u8)] + #[repr(i8)] enum Level { Debug = 1, Info = 2, @@ -387,7 +401,7 @@ How to choose between all these features? Here are some considerations:
* `Tuple(A, B, ...)` maps to/from `(A, B, ...)` or a newtype around it. * `Array(_)` maps to/from any slice, e.g. `Vec<_>`, `&[_]`. Newtypes are also supported. -* `Map(K, V)` behaves like `Array((K, V))`. +* `Map(K, V)` can be deserialized as `HashMap<K, V>` or `Vec<(K, V)>` (see the example below). * `LowCardinality(_)` is supported seamlessly. * `Nullable(_)` maps to/from `Option<_>`. For `clickhouse::serde::*` helpers add `::option`.
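For example, a `Map(String, UInt32)` column can be fetched straight into a `HashMap` (a sketch; the row and column names are illustrative):

```rust
use std::collections::HashMap;

use clickhouse::Row;
use serde::Deserialize;

#[derive(Row, Deserialize)]
struct PageStats {
    url: String,
    // Map(String, UInt32) in the table schema.
    counters: HashMap<String, u32>,
}
```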
@@ -416,7 +430,8 @@ How to choose between all these features? Here are some considerations: } ```
-* `Geo` types are supported. `Point` behaves like a tuple `(f64, f64)`, and the rest of the types are just slices of points. +* `Geo` types are supported. `Point` behaves like a tuple `(f64, f64)`, and the rest of the types are just slices of + points.
Example diff --git a/benches/README.md b/benches/README.md index d39bc8ab..a57105ba 100644 --- a/benches/README.md +++ b/benches/README.md @@ -4,31 +4,41 @@ All cases are run with `cargo bench --bench <name>`. ## With a mocked server -These benchmarks are run against a mocked server, which is a simple HTTP server that responds with a fixed response. This is useful to measure the overhead of the client itself: -* `select` checks throughput of `Client::query()`. -* `insert` checks throughput of `Client::insert()` and `Client::inserter()` (if the `inserter` features is enabled). +These benchmarks are run against a mocked server, which is a simple HTTP server that responds with a fixed response. +This is useful to measure the overhead of the client itself. + +### Scenarios + +* [mocked_select.rs](mocked_select.rs) checks throughput of `Client::query()`. +* [mocked_insert.rs](mocked_insert.rs) checks throughput of `Client::insert()` and `Client::inserter()` + (requires the `inserter` feature). ### How to collect perf data The crate's code runs on the thread with the name `testee`: + ```bash cargo bench --bench <name> & perf record -p `ps -AT | grep testee | awk '{print $2}'` --call-graph dwarf,65528 --freq 5000 -g -- sleep 5 perf script > perf.script ``` -Then upload the `perf.script` file to [Firefox Profiler](https://profiler.firefox.com). +Then upload the `perf.script` file to [Firefox Profiler]. ## With a running ClickHouse server These benchmarks are run against a real ClickHouse server, so it must be started: + ```bash docker compose up -d cargo bench --bench <name> ``` -Cases: -* `select_numbers` measures time of running a big SELECT query to the `system.numbers_mt` table. +### Scenarios + +* [select_numbers.rs](select_numbers.rs) measures the time of running a big SELECT query against the `system.numbers_mt` table. +* [select_nyc_taxi_data.rs](select_nyc_taxi_data.rs) measures the time of running a fairly large SELECT query (approximately + 3 million records) against the `nyc_taxi_data` table using the [NYC taxi dataset]. ### How to collect perf data @@ -38,4 +48,10 @@ perf record -p `ps -AT | grep <name> | awk '{print $2}'` --call-graph dwarf,65528 perf script > perf.script -Then upload the `perf.script` file to [Firefox Profiler](https://profiler.firefox.com). +Then upload the `perf.script` file to [Firefox Profiler]. 
+ + +[Firefox Profiler]: https://profiler.firefox.com + +[NYC taxi dataset]: https://clickhouse.com/docs/getting-started/example-datasets/nyc-taxi#create-the-table-trips \ No newline at end of file diff --git a/benches/common.rs b/benches/common.rs index 637447ab..7894928b 100644 --- a/benches/common.rs +++ b/benches/common.rs @@ -11,6 +11,7 @@ use std::{ }; use bytes::Bytes; +use clickhouse::error::Result; use futures::stream::StreamExt; use http_body_util::BodyExt; use hyper::{ @@ -25,35 +26,65 @@ use tokio::{ sync::{mpsc, oneshot}, }; -use clickhouse::error::Result; +pub(crate) struct ServerHandle { + handle: Option<thread::JoinHandle<()>>, + shutdown_tx: Option<oneshot::Sender<()>>, +} -pub(crate) struct ServerHandle; +impl ServerHandle { + fn shutdown(&mut self) { + if let Some(tx) = self.shutdown_tx.take() { + tx.send(()).unwrap(); + } + if let Some(handle) = self.handle.take() { + handle.join().unwrap(); + } + } +} -pub(crate) fn start_server<S, F, B>(addr: SocketAddr, serve: S) -> ServerHandle +impl Drop for ServerHandle { + fn drop(&mut self) { + self.shutdown(); + } +} + +pub(crate) async fn start_server<S, F, B>(addr: SocketAddr, serve: S) -> ServerHandle where S: Fn(Request<Incoming>) -> F + Send + Sync + 'static, F: Future<Output = Response<B>> + Send, B: Body<Data = Bytes, Error = Infallible> + Send + 'static, { + let (shutdown_tx, mut shutdown_rx) = oneshot::channel::<()>(); + let (ready_tx, ready_rx) = oneshot::channel::<()>(); + let serving = async move { let listener = TcpListener::bind(addr).await.unwrap(); + ready_tx.send(()).unwrap(); loop { let (stream, _) = listener.accept().await.unwrap(); - - let service = - service::service_fn(|request| async { Ok::<_, Infallible>(serve(request).await) }); - - // SELECT benchmark doesn't read the whole body, so ignore possible errors. - let _ = conn::http1::Builder::new() + let server_future = conn::http1::Builder::new() .timer(TokioTimer::new()) - .serve_connection(TokioIo::new(stream), service) - .await; + .serve_connection( + TokioIo::new(stream), + service::service_fn(|request| async { + Ok::<_, Infallible>(serve(request).await) + }), + ); + tokio::select! { + _ = server_future => {} + _ = &mut shutdown_rx => { break; } + } } }; - run_on_st_runtime("server", serving); - ServerHandle + let handle = Some(run_on_st_runtime("server", serving)); + ready_rx.await.unwrap(); + + ServerHandle { + handle, + shutdown_tx: Some(shutdown_tx), + } } pub(crate) async fn skip_incoming(request: Request<Incoming>) { @@ -105,7 +136,7 @@ pub(crate) fn start_runner() -> RunnerHandle { RunnerHandle { tx } } -fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) { +fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) -> thread::JoinHandle<()> { let name = name.to_string(); thread::Builder::new() .name(name.clone()) @@ -121,5 +152,5 @@ fn run_on_st_runtime(name: &str, f: impl Future + Send + 'static) { .unwrap() .block_on(f); }) - .unwrap(); + .unwrap() } diff --git a/benches/common_select.rs b/benches/common_select.rs new file mode 100644 index 00000000..2b1ebaf1 --- /dev/null +++ b/benches/common_select.rs @@ -0,0 +1,136 @@ +#![allow(dead_code)] + +use clickhouse::query::RowCursor; +use clickhouse::{Client, Compression, Row}; +use serde::Deserialize; +use std::time::{Duration, Instant}; + +pub(crate) trait WithId { + fn id(&self) -> u64; +} +pub(crate) trait WithAccessType { + const ACCESS_TYPE: &'static str; +} +pub(crate) trait BenchmarkRow<'a>: Row + Deserialize<'a> + WithId + WithAccessType {} + +#[macro_export] +macro_rules! 
impl_benchmark_row { + ($type:ty, $id_field:ident, $access_type:literal) => { + impl WithId for $type { + fn id(&self) -> u64 { + self.$id_field as u64 + } + } + + impl WithAccessType for $type { + const ACCESS_TYPE: &'static str = $access_type; + } + + impl<'a> BenchmarkRow<'a> for $type {} + }; +} + +#[macro_export] +macro_rules! impl_benchmark_row_no_access_type { + ($type:ty, $id_field:ident) => { + impl WithId for $type { + fn id(&self) -> u64 { + self.$id_field + } + } + + impl WithAccessType for $type { + const ACCESS_TYPE: &'static str = ""; + } + + impl<'a> BenchmarkRow<'a> for $type {} + }; +} + +pub(crate) fn print_header(add: Option<&str>) { + let add = add.unwrap_or(""); + println!("compress validation elapsed throughput received{add}"); +} + +pub(crate) fn print_results<'a, T: BenchmarkRow<'a>>( + stats: &BenchmarkStats<u64>, + compression: Compression, + validation: bool, +) { + let BenchmarkStats { + throughput_mbytes_sec, + received_mbytes, + elapsed, + .. + } = stats; + let validation_mode = if validation { "enabled" } else { "disabled" }; + let compression = match compression { + Compression::None => "none", + #[cfg(feature = "lz4")] + Compression::Lz4 => "lz4", + _ => panic!("Unexpected compression mode"), + }; + let access = if T::ACCESS_TYPE.is_empty() { + "" + } else { + let access_type = T::ACCESS_TYPE; + &format!(" {access_type:>6}") + }; + println!("{compression:>8} {validation_mode:>10} {elapsed:>9.3?} {throughput_mbytes_sec:>4.0} MiB/s {received_mbytes:>4.0} MiB{access}"); +} + +pub(crate) async fn fetch_cursor<'a, T: BenchmarkRow<'a>>( + compression: Compression, + validation: bool, + query: &str, +) -> RowCursor<T> { + let client = Client::default() + .with_compression(compression) + .with_url("http://localhost:8123") + .with_validation(validation); + client.query(query).fetch::<T>().unwrap() +} + +pub(crate) async fn do_select_bench<'a, T: BenchmarkRow<'a>>( + query: &str, + compression: Compression, + validation: bool, +) -> BenchmarkStats<u64> { + let start = Instant::now(); + let mut cursor = fetch_cursor::<T>(compression, validation, query).await; + + let mut sum = 0; + while let Some(row) = cursor.next().await.unwrap() { + sum += row.id(); + std::hint::black_box(&row); + } + + BenchmarkStats::new(&cursor, &start, sum) +} + +pub(crate) struct BenchmarkStats<R> { + pub(crate) throughput_mbytes_sec: f64, + pub(crate) decoded_mbytes: f64, + pub(crate) received_mbytes: f64, + pub(crate) elapsed: Duration, + // RustRover is unhappy with pub(crate) + pub result: R, +} + +impl<R> BenchmarkStats<R> { + pub(crate) fn new<T>(cursor: &RowCursor<T>, start: &Instant, result: R) -> Self { + let elapsed = start.elapsed(); + let dec_bytes = cursor.decoded_bytes(); + let decoded_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; + let recv_bytes = cursor.received_bytes(); + let received_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; + let throughput_mbytes_sec = decoded_mbytes / elapsed.as_secs_f64(); + BenchmarkStats { + throughput_mbytes_sec, + decoded_mbytes, + received_mbytes, + elapsed, + result, + } + } +} diff --git a/benches/insert.rs b/benches/mocked_insert.rs similarity index 83% rename from benches/insert.rs rename to benches/mocked_insert.rs index 2c8fbb3b..cdbef62c 100644 --- a/benches/insert.rs +++ b/benches/mocked_insert.rs @@ -1,15 +1,15 @@ -use std::{ - future::Future, - mem, - time::{Duration, Instant}, -}; - use bytes::Bytes; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use http_body_util::Empty; use hyper::{body::Incoming, Request, Response}; use serde::Serialize; 
use std::hint::black_box; +use std::net::SocketAddr; +use std::{ + future::Future, + mem, + time::{Duration, Instant}, +}; use clickhouse::{error::Result, Client, Compression, Row}; @@ -47,7 +47,9 @@ impl SomeRow { } } -async fn run_insert(client: Client, iters: u64) -> Result<Duration> { +async fn run_insert(client: Client, addr: SocketAddr, iters: u64) -> Result<Duration> { + let _server = common::start_server(addr, serve).await; + let start = Instant::now(); let mut insert = client.insert("table")?; @@ -60,7 +62,13 @@ async fn run_insert(client: Client, iters: u64) -> Result<Duration> { } #[cfg(feature = "inserter")] -async fn run_inserter(client: Client, iters: u64) -> Result<Duration> { +async fn run_inserter( + client: Client, + addr: SocketAddr, + iters: u64, +) -> Result<Duration> { + let _server = common::start_server(addr, serve).await; + let start = Instant::now(); let mut inserter = client.inserter("table")?.with_max_rows(iters); @@ -78,12 +86,11 @@ async fn run_inserter(client: Client, iters: u64) -> Re Ok(start.elapsed()) } -fn run(c: &mut Criterion, name: &str, port: u16, f: impl Fn(Client, u64) -> F) +fn run(c: &mut Criterion, name: &str, port: u16, f: impl Fn(Client, SocketAddr, u64) -> F) where F: Future<Output = Result<Duration>> + Send + 'static, { - let addr = format!("127.0.0.1:{port}").parse().unwrap(); - let _server = common::start_server(addr, serve); + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap(); let runner = common::start_runner(); let mut group = c.benchmark_group(name); @@ -93,7 +100,7 @@ where let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::None); - runner.run((f)(client, iters)) + runner.run((f)(client, addr, iters)) }) }); #[cfg(feature = "lz4")] @@ -102,7 +109,7 @@ where let client = Client::default() .with_url(format!("http://{addr}")) .with_compression(Compression::Lz4); - runner.run((f)(client, iters)) + runner.run((f)(client, addr, iters)) }) }); group.finish(); diff --git a/benches/select.rs b/benches/mocked_select.rs similarity index 54% rename from benches/select.rs rename to benches/mocked_select.rs index 378329c7..e9fcef75 100644 --- a/benches/select.rs +++ b/benches/mocked_select.rs @@ -1,10 +1,9 @@ -use std::{ - convert::Infallible, - mem, - time::{Duration, Instant}, -}; - use bytes::Bytes; +use clickhouse::{ + error::{Error, Result}, + Client, Compression, Row, +}; +use clickhouse_types::{Column, DataTypeNode}; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use futures::stream::{self, StreamExt as _}; use http_body_util::StreamBody; @@ -13,22 +12,42 @@ use hyper::{ Request, Response, }; use serde::Deserialize; -use std::hint::black_box; - -use clickhouse::{ - error::{Error, Result}, - Client, Compression, Row, -}; +use std::convert::Infallible; +use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; +use std::time::{Duration, Instant}; mod common; async fn serve( request: Request<Incoming>, - chunk: Bytes, + compression: Compression, ) -> Response> { common::skip_incoming(request).await; - let stream = stream::repeat(chunk).map(|chunk| Ok(Frame::data(chunk))); + let write_schema = async move { + let schema = vec![ + Column::new("a".to_string(), DataTypeNode::UInt64), + Column::new("b".to_string(), DataTypeNode::Int64), + Column::new("c".to_string(), DataTypeNode::Int32), + Column::new("d".to_string(), DataTypeNode::UInt32), + ]; + + let mut buffer = Vec::new(); + clickhouse_types::put_rbwnat_columns_header(&schema, &mut buffer).unwrap(); + + let buffer = match compression { + Compression::None => Bytes::from(buffer), + #[cfg(feature 
= "lz4")] + Compression::Lz4 => clickhouse::_priv::lz4_compress(&buffer).unwrap(), + _ => unreachable!(), + }; + + Ok(Frame::data(buffer)) + }; + + let chunk = prepare_chunk(); + let stream = + stream::once(write_schema).chain(stream::repeat(chunk).map(|chunk| Ok(Frame::data(chunk)))); Response::new(StreamBody::new(stream)) } @@ -53,10 +72,13 @@ fn prepare_chunk() -> Bytes { chunk } +const ADDR: SocketAddr = SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 6523)); + fn select(c: &mut Criterion) { - let addr = "127.0.0.1:6543".parse().unwrap(); - let chunk = prepare_chunk(); - let _server = common::start_server(addr, move |req| serve(req, chunk.clone())); + async fn start_server(compression: Compression) -> common::ServerHandle { + common::start_server(ADDR, move |req| serve(req, compression)).await + } + let runner = common::start_runner(); #[derive(Default, Debug, Row, Deserialize)] @@ -67,7 +89,9 @@ fn select(c: &mut Criterion) { d: u32, } - async fn select_rows(client: Client, iters: u64) -> Result { + async fn select_rows(client: Client, iters: u64, compression: Compression) -> Result { + let _server = start_server(compression).await; + let mut sum = SomeRow::default(); let start = Instant::now(); let mut cursor = client @@ -84,11 +108,19 @@ fn select(c: &mut Criterion) { sum.d = sum.d.wrapping_add(row.d); } - black_box(sum); - Ok(start.elapsed()) + std::hint::black_box(sum); + + let elapsed = start.elapsed(); + Ok(elapsed) } - async fn select_bytes(client: Client, min_size: u64) -> Result { + async fn select_bytes( + client: Client, + min_size: u64, + compression: Compression, + ) -> Result { + let _server = start_server(compression).await; + let start = Instant::now(); let mut cursor = client .query("SELECT value FROM some") @@ -96,7 +128,7 @@ fn select(c: &mut Criterion) { let mut size = 0; while size < min_size { - let buf = black_box(cursor.next().await?); + let buf = std::hint::black_box(cursor.next().await?); size += buf.unwrap().len() as u64; } @@ -104,22 +136,24 @@ fn select(c: &mut Criterion) { } let mut group = c.benchmark_group("rows"); - group.throughput(Throughput::Bytes(mem::size_of::() as u64)); + group.throughput(Throughput::Bytes(size_of::() as u64)); group.bench_function("uncompressed", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::None); - runner.run(select_rows(client, iters)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_rows(client, iters, compression)) }) }); #[cfg(feature = "lz4")] group.bench_function("lz4", |b| { b.iter_custom(|iters| { + let compression = Compression::Lz4; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::Lz4); - runner.run(select_rows(client, iters)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_rows(client, iters, compression)) }) }); group.finish(); @@ -129,19 +163,21 @@ fn select(c: &mut Criterion) { group.throughput(Throughput::Bytes(MIB)); group.bench_function("uncompressed", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::None); - runner.run(select_bytes(client, iters * MIB)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_bytes(client, iters * MIB, compression)) }) }); #[cfg(feature = "lz4")] 
group.bench_function("lz4", |b| { b.iter_custom(|iters| { + let compression = Compression::None; let client = Client::default() - .with_url(format!("http://{addr}")) - .with_compression(Compression::Lz4); - runner.run(select_bytes(client, iters * MIB)) + .with_url(format!("http://{ADDR}")) + .with_compression(compression); + runner.run(select_bytes(client, iters * MIB, compression)) }) }); group.finish(); diff --git a/benches/select_numbers.rs b/benches/select_numbers.rs index 869d6ba5..2adfde86 100644 --- a/benches/select_numbers.rs +++ b/benches/select_numbers.rs @@ -1,47 +1,38 @@ use serde::Deserialize; -use clickhouse::{Client, Compression, Row}; +use crate::common_select::{ + do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, +}; +use clickhouse::{Compression, Row}; + +mod common_select; #[derive(Row, Deserialize)] struct Data { - no: u64, -} - -async fn bench(name: &str, compression: Compression) { - let start = std::time::Instant::now(); - let (sum, dec_mbytes, rec_mbytes) = tokio::spawn(do_bench(compression)).await.unwrap(); - assert_eq!(sum, 124999999750000000); - let elapsed = start.elapsed(); - let throughput = dec_mbytes / elapsed.as_secs_f64(); - println!("{name:>8} {elapsed:>7.3?} {throughput:>4.0} MiB/s {rec_mbytes:>4.0} MiB"); + number: u64, } -async fn do_bench(compression: Compression) -> (u64, f64, f64) { - let client = Client::default() - .with_compression(compression) - .with_url("http://localhost:8123"); - - let mut cursor = client - .query("SELECT number FROM system.numbers_mt LIMIT 500000000") - .fetch::() - .unwrap(); - - let mut sum = 0; - while let Some(row) = cursor.next().await.unwrap() { - sum += row.no; - } - - let dec_bytes = cursor.decoded_bytes(); - let dec_mbytes = dec_bytes as f64 / 1024.0 / 1024.0; - let recv_bytes = cursor.received_bytes(); - let recv_mbytes = recv_bytes as f64 / 1024.0 / 1024.0; - (sum, dec_mbytes, recv_mbytes) +impl_benchmark_row_no_access_type!(Data, number); + +async fn bench(compression: Compression, validation: bool) { + let stats = do_select_bench::( + "SELECT number FROM system.numbers_mt LIMIT 500000000", + compression, + validation, + ) + .await; + assert_eq!(stats.result, 124999999750000000); + print_results::(&stats, compression, validation); } #[tokio::main] async fn main() { - println!("compress elapsed throughput received"); - bench("none", Compression::None).await; + print_header(None); + bench(Compression::None, false).await; + bench(Compression::None, true).await; #[cfg(feature = "lz4")] - bench("lz4", Compression::Lz4).await; + { + bench(Compression::Lz4, false).await; + bench(Compression::Lz4, true).await; + } } diff --git a/benches/select_nyc_taxi_data.rs b/benches/select_nyc_taxi_data.rs new file mode 100644 index 00000000..618ea8c5 --- /dev/null +++ b/benches/select_nyc_taxi_data.rs @@ -0,0 +1,100 @@ +#![cfg(feature = "time")] + +use crate::common_select::{ + do_select_bench, print_header, print_results, BenchmarkRow, WithAccessType, WithId, +}; +use clickhouse::{Compression, Row}; +use serde::Deserialize; +use serde_repr::Deserialize_repr; +use time::OffsetDateTime; + +mod common_select; + +#[derive(Debug, Clone, Deserialize_repr)] +#[repr(i8)] +pub enum PaymentType { + CSH = 1, + CRE = 2, + NOC = 3, + DIS = 4, + UNK = 5, +} + +/// Uses just `visit_seq` since the order of the fields matches the database schema. 
+#[derive(Row, Deserialize)] +#[allow(dead_code)] +struct TripSmallSeqAccess { + trip_id: u32, + #[serde(with = "clickhouse::serde::time::datetime")] + pickup_datetime: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime")] + dropoff_datetime: OffsetDateTime, + pickup_longitude: Option<f64>, + pickup_latitude: Option<f64>, + dropoff_longitude: Option<f64>, + dropoff_latitude: Option<f64>, + passenger_count: u8, + trip_distance: f32, + fare_amount: f32, + extra: f32, + tip_amount: f32, + tolls_amount: f32, + total_amount: f32, + payment_type: PaymentType, + pickup_ntaname: String, + dropoff_ntaname: String, +} + +/// Uses `visit_map` to deserialize instead of `visit_seq`, +/// since the fields definition is correct, but the order is wrong. +#[derive(Row, Deserialize)] +#[allow(dead_code)] +struct TripSmallMapAccess { + pickup_ntaname: String, + dropoff_ntaname: String, + trip_id: u32, + passenger_count: u8, + trip_distance: f32, + fare_amount: f32, + extra: f32, + tip_amount: f32, + tolls_amount: f32, + total_amount: f32, + payment_type: PaymentType, + #[serde(with = "clickhouse::serde::time::datetime")] + pickup_datetime: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime")] + dropoff_datetime: OffsetDateTime, + pickup_longitude: Option<f64>, + pickup_latitude: Option<f64>, + dropoff_longitude: Option<f64>, + dropoff_latitude: Option<f64>, +} + +impl_benchmark_row!(TripSmallSeqAccess, trip_id, "seq"); +impl_benchmark_row!(TripSmallMapAccess, trip_id, "map"); + +async fn bench<'a, T: BenchmarkRow<'a>>(compression: Compression, validation: bool) { + let stats = do_select_bench::<T>( + "SELECT * FROM nyc_taxi.trips_small ORDER BY trip_id DESC", + compression, + validation, + ) + .await; + assert_eq!(stats.result, 3630387815532582); + print_results::<T>(&stats, compression, validation); } + +#[tokio::main] +async fn main() { + print_header(Some(" access")); + bench::<TripSmallSeqAccess>(Compression::None, false).await; + bench::<TripSmallSeqAccess>(Compression::None, true).await; + bench::<TripSmallMapAccess>(Compression::None, true).await; + #[cfg(feature = "lz4")] + { + bench::<TripSmallSeqAccess>(Compression::Lz4, false).await; + bench::<TripSmallSeqAccess>(Compression::Lz4, true).await; + bench::<TripSmallMapAccess>(Compression::Lz4, true).await; + } +} diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 56cb3220..4c17fc85 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -1,14 +1,14 @@ [package] name = "clickhouse-derive" -version = "0.2.0" description = "A macro for deriving clickhouse::Row" -authors = ["ClickHouse Contributors", "Paul Loyd "] -repository = "https://github.com/ClickHouse/clickhouse-rs" -homepage = "https://clickhouse.com" -edition = "2021" -license = "MIT OR Apache-2.0" -# update `Cargo.toml` and CI if changed -rust-version = "1.73.0" +version = "0.2.0" + +authors.workspace = true +repository.workspace = true +homepage.workspace = true +edition.workspace = true +license.workspace = true +rust-version.workspace = true [lib] proc-macro = true diff --git a/derive/src/lib.rs b/derive/src/lib.rs index bd5675a8..5e539250 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -50,7 +50,7 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { }; // TODO: do something more clever? - let _ = cx.check().expect("derive context error"); + cx.check().expect("derive context error"); let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl(); @@ -58,7 +58,10 @@ pub fn row(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let expanded = quote! 
{ #[automatically_derived] impl #impl_generics clickhouse::Row for #name #ty_generics #where_clause { + const NAME: &'static str = stringify!(#name); const COLUMN_NAMES: &'static [&'static str] = #column_names; + const COLUMN_COUNT: usize = <Self as clickhouse::Row>::COLUMN_NAMES.len(); + const KIND: clickhouse::RowKind = clickhouse::RowKind::Struct; } }; diff --git a/docker-compose.yml b/docker-compose.yml index e08344d4..cc309127 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,3 +1,4 @@ +name: clickhouse-rs services: clickhouse: image: 'clickhouse/clickhouse-server:${CLICKHOUSE_VERSION-latest-alpine}' diff --git a/examples/async_insert.rs b/examples/async_insert.rs index 7a9c18f6..8b567266 100644 --- a/examples/async_insert.rs +++ b/examples/async_insert.rs @@ -10,7 +10,7 @@ use clickhouse::{error::Result, Client, Row}; #[derive(Debug, Serialize, Deserialize, Row)] struct Event { - timestamp: u64, + timestamp: i64, message: String, } @@ -70,9 +70,9 @@ async fn main() -> Result<()> { Ok(()) } -fn now() -> u64 { +fn now() -> i64 { UNIX_EPOCH .elapsed() .expect("invalid system time") - .as_nanos() as u64 + .as_nanos() as i64 } diff --git a/examples/clickhouse_cloud.rs b/examples/clickhouse_cloud.rs index 7002160d..5c84d84b 100644 --- a/examples/clickhouse_cloud.rs +++ b/examples/clickhouse_cloud.rs @@ -66,7 +66,7 @@ async fn main() -> clickhouse::error::Result<()> { #[derive(Debug, Serialize, Deserialize, Row)] struct Data { - id: u32, + id: i32, name: String, } diff --git a/examples/data_types_derive_simple.rs b/examples/data_types_derive_simple.rs index cab63808..82d2d2f8 100644 --- a/examples/data_types_derive_simple.rs +++ b/examples/data_types_derive_simple.rs @@ -53,15 +53,26 @@ async fn main() -> Result<()> { decimal64_18_8 Decimal(18, 8), decimal128_38_12 Decimal(38, 12), -- decimal256_76_20 Decimal(76, 20), - date Date, - date32 Date32, - datetime DateTime, - datetime_tz DateTime('UTC'), - datetime64_0 DateTime64(0), - datetime64_3 DateTime64(3), - datetime64_6 DateTime64(6), - datetime64_9 DateTime64(9), - datetime64_9_tz DateTime64(9, 'UTC') + + time_date Date, + time_date32 Date32, + time_datetime DateTime, + time_datetime_tz DateTime('UTC'), + time_datetime64_0 DateTime64(0), + time_datetime64_3 DateTime64(3), + time_datetime64_6 DateTime64(6), + time_datetime64_9 DateTime64(9), + time_datetime64_9_tz DateTime64(9, 'UTC'), + + chrono_date Date, + chrono_date32 Date32, + chrono_datetime DateTime, + chrono_datetime_tz DateTime('UTC'), + chrono_datetime64_0 DateTime64(0), + chrono_datetime64_3 DateTime64(3), + chrono_datetime64_6 DateTime64(6), + chrono_datetime64_9 DateTime64(9), + chrono_datetime64_9_tz DateTime64(9, 'UTC'), ) ENGINE MergeTree ORDER BY (); ", ) @@ -166,7 +177,7 @@ type Decimal128 = FixedPoint; // Decimal(38, 12) = Decimal128(12) #[derive(Clone, Debug, PartialEq)] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] -#[repr(u8)] +#[repr(i8)] pub enum Enum8 { Foo = 1, Bar = 2, @@ -174,7 +185,7 @@ #[derive(Clone, Debug, PartialEq)] #[derive(serde_repr::Serialize_repr, serde_repr::Deserialize_repr)] -#[repr(u16)] +#[repr(i16)] pub enum Enum16 { Qaz = 42, Qux = 255, diff --git a/examples/data_types_variant.rs b/examples/data_types_variant.rs index e575464b..35aa7568 100644 --- a/examples/data_types_variant.rs +++ b/examples/data_types_variant.rs @@ -140,7 +140,7 @@ fn get_rows() -> Vec { // This enum represents Variant(Array(UInt16), Bool, Date, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, 
UInt64, UInt8) #[derive(Debug, PartialEq, Serialize, Deserialize)] enum MyRowVariant { - Array(Vec), + Array(Vec), Boolean(bool), // attributes should work in this case, too #[serde(with = "clickhouse::serde::time::date")] diff --git a/examples/enums.rs b/examples/enums.rs index 851ca3fb..d20e5dc6 100644 --- a/examples/enums.rs +++ b/examples/enums.rs @@ -35,14 +35,14 @@ async fn main() -> Result<()> { #[derive(Debug, Serialize, Deserialize, Row)] struct Event { - timestamp: u64, + timestamp: i64, message: String, level: Level, } // How to define enums that map to `Enum8`/`Enum16`. #[derive(Debug, Serialize_repr, Deserialize_repr)] - #[repr(u8)] + #[repr(i8)] enum Level { Debug = 1, Info = 2, @@ -69,9 +69,9 @@ async fn main() -> Result<()> { Ok(()) } -fn now() -> u64 { +fn now() -> i64 { UNIX_EPOCH .elapsed() .expect("invalid system time") - .as_nanos() as u64 + .as_nanos() as i64 } diff --git a/examples/mock.rs b/examples/mock.rs index d6893d20..af5342f8 100644 --- a/examples/mock.rs +++ b/examples/mock.rs @@ -29,7 +29,9 @@ async fn make_insert(client: &Client, data: &[SomeRow]) -> Result<()> { #[tokio::main] async fn main() { let mock = test::Mock::new(); - let client = Client::default().with_url(mock.url()); + // Note that an explicit `with_url` call is not required; + // the URL is set automatically to the mock server URL. + let client = Client::default().with_mock(&mock); let list = vec![SomeRow { no: 1 }, SomeRow { no: 2 }]; // How to test DDL. diff --git a/rustfmt.toml b/rustfmt.toml index 5f62e976..ef4162c2 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1,6 +1,2 @@ edition = "2021" merge_derives = false -imports_granularity = "Crate" -normalize_comments = true -reorder_impl_items = true -wrap_comments = true diff --git a/src/cursors/row.rs b/src/cursors/row.rs index 6f17cfcc..9538e17a 100644 --- a/src/cursors/row.rs +++ b/src/cursors/row.rs @@ -1,10 +1,13 @@ +use crate::row_metadata::RowMetadata; use crate::{ bytes_ext::BytesExt, cursors::RawCursor, error::{Error, Result}, response::Response, - rowbinary, + rowbinary, Row, }; +use clickhouse_types::error::TypesError; +use clickhouse_types::parse_rbwnat_columns_header; use serde::Deserialize; use std::marker::PhantomData; @@ -13,15 +16,64 @@ use std::marker::PhantomData; pub struct RowCursor<T> { raw: RawCursor, bytes: BytesExt, + validation: bool, + /// [`None`] until the first call to [`RowCursor::next()`], + /// as [`RowCursor::new`] is not `async`, so it loads lazily. + row_metadata: Option<RowMetadata>, _marker: PhantomData<T>, } impl<T> RowCursor<T> { - pub(crate) fn new(response: Response) -> Self { + pub(crate) fn new(response: Response, validation: bool) -> Self { Self { + _marker: PhantomData, raw: RawCursor::new(response), bytes: BytesExt::default(), - _marker: PhantomData, + row_metadata: None, + validation, } } + + #[cold] + #[inline(never)] + async fn read_columns(&mut self) -> Result<()> + where + T: Row, + { + loop { + if self.bytes.remaining() > 0 { + let mut slice = self.bytes.slice(); + match parse_rbwnat_columns_header(&mut slice) { + Ok(columns) if !columns.is_empty() => { + self.bytes.set_remaining(slice.len()); + self.row_metadata = Some(RowMetadata::new::<T>(columns)); + return Ok(()); + } + Ok(_) => { + // This does not panic, as it could be a network issue + // or a malformed response from the server or LB, + // and a simple retry might help in certain cases. 
+ return Err(Error::BadResponse( + "Expected at least one column in the header".to_string(), + )); + } + Err(TypesError::NotEnoughData(_)) => {} + Err(err) => { + return Err(Error::InvalidColumnsHeader(err.into())); + } + } + } + match self.raw.next().await? { + Some(chunk) => self.bytes.extend(chunk), + None if self.row_metadata.is_none() => { + // Similar to the other BadResponse branch above + return Err(Error::BadResponse( + "Could not read columns header".to_string(), + )); + } + // if the result set is empty, there is only the columns header + None => return Ok(()), + } } } @@ -32,20 +84,38 @@ impl<T> RowCursor<T> { /// # Cancel safety /// /// This method is cancellation safe. - pub async fn next<'a, 'b: 'a>(&'a mut self) -> Result<Option<T>> + pub async fn next<'cursor, 'data: 'cursor>(&'cursor mut self) -> Result<Option<T>> where - T: Deserialize<'b>, + T: Deserialize<'data> + Row, { loop { - let mut slice = super::workaround_51132(self.bytes.slice()); - - match rowbinary::deserialize_from(&mut slice) { - Ok(value) => { - self.bytes.set_remaining(slice.len()); - return Ok(Some(value)); + if self.bytes.remaining() > 0 { + let mut slice: &[u8]; + let result = if self.validation { + if self.row_metadata.is_none() { + self.read_columns().await?; + if self.bytes.remaining() == 0 { + continue; + } + } + slice = super::workaround_51132(self.bytes.slice()); + rowbinary::deserialize_row_with_validation::<T>( + &mut slice, + // handled above + self.row_metadata.as_ref().unwrap(), + ) + } else { + slice = super::workaround_51132(self.bytes.slice()); + rowbinary::deserialize_row::<T>(&mut slice) + }; + match result { + Err(Error::NotEnoughData) => {} + Ok(value) => { + self.bytes.set_remaining(slice.len()); + return Ok(Some(value)); + } + Err(err) => return Err(err), } - Err(Error::NotEnoughData) => {} - Err(err) => return Err(err), } match self.raw.next().await? { @@ -70,8 +140,7 @@ impl<T> RowCursor<T> { self.raw.received_bytes() } - /// Returns the total size in bytes decompressed since the cursor was - /// created. + /// Returns the total size in bytes decompressed since the cursor was created. #[inline] pub fn decoded_bytes(&self) -> u64 { self.raw.decoded_bytes() diff --git a/src/error.rs b/src/error.rs index eedc52c3..438a6b87 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,8 +1,7 @@ //! Contains [`Error`] and corresponding [`Result`]. -use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; - use serde::{de, ser}; +use std::{error::Error as StdError, fmt, io, result, str::Utf8Error}; /// A result with a specified [`Error`] type. 
pub type Result<T, E = Error> = result::Result<T, E>; @@ -42,6 +41,8 @@ pub enum Error { BadResponse(String), #[error("timeout expired")] TimedOut, + #[error("error while parsing columns header from the response: {0}")] + InvalidColumnsHeader(#[source] BoxedError), #[error("unsupported: {0}")] Unsupported(String), #[error("{0}")] @@ -50,6 +51,12 @@ pub enum Error { assert_impl_all!(Error: StdError, Send, Sync); +impl From<clickhouse_types::error::TypesError> for Error { + fn from(err: clickhouse_types::error::TypesError) -> Self { + Self::InvalidColumnsHeader(Box::new(err)) + } +} + impl From<hyper::Error> for Error { fn from(error: hyper::Error) -> Self { Self::Network(Box::new(error)) diff --git a/src/lib.rs b/src/lib.rs index b9b9ed41..d2830e50 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,11 +5,10 @@ #[macro_use] extern crate static_assertions; +pub use self::{compression::Compression, row::Row, row::RowKind}; use self::{error::Result, http_client::HttpClient}; -use std::{collections::HashMap, fmt::Display, sync::Arc}; - -pub use self::{compression::Compression, row::Row}; pub use clickhouse_derive::Row; +use std::{collections::HashMap, fmt::Display, sync::Arc}; pub mod error; pub mod insert; @@ -29,6 +28,7 @@ mod http_client; mod request_body; mod response; mod row; +mod row_metadata; mod rowbinary; #[cfg(feature = "inserter")] mod ticks; @@ -45,6 +45,10 @@ pub struct Client { options: HashMap<String, String>, headers: HashMap<String, String>, products_info: Vec<ProductInfo>, + validation: bool, + + #[cfg(feature = "test-util")] + mocked: bool, } #[derive(Clone)] @@ -99,6 +103,9 @@ impl Client { options: HashMap::new(), headers: HashMap::new(), products_info: Vec::default(), + validation: true, + #[cfg(feature = "test-util")] + mocked: false, } } @@ -312,11 +319,59 @@ impl Client { query::Query::new(self, query) } + /// Enables or disables [`Row`] data type validation against the database schema, + /// at the cost of performance. Validation is enabled by default, and in this mode, + /// the client uses the `RowBinaryWithNamesAndTypes` format. + /// + /// If you are looking to maximize performance, you can disable validation using this method. + /// When validation is disabled, the client switches to the plain `RowBinary` format instead. + /// + /// The downside of plain `RowBinary` is that, instead of a clear error message, + /// a mismatch between [`Row`] and the database schema results + /// in a [`error::Error::NotEnoughData`] error without specific details. + /// + /// However, disabling validation may yield a 1.1x to 3x performance improvement, + /// depending heavily on the shape and volume of the dataset. + /// + /// It is always recommended to measure the performance impact of validation + /// in your specific use case. Additionally, if you plan to disable validation + /// in your application, writing smoke tests that ensure the row types match + /// the ClickHouse schema is highly recommended. + pub fn with_validation(mut self, enabled: bool) -> Self { + self.validation = enabled; + self + } + + /// Used internally to check if the validation mode is enabled, + /// as it takes into account the `test-util` feature flag. + #[inline] + pub(crate) fn get_validation(&self) -> bool { + #[cfg(feature = "test-util")] + if self.mocked { + return false; + } + self.validation + } + /// Used internally to modify the options map of an _already cloned_ /// [`Client`] instance. pub(crate) fn add_option(&mut self, name: impl Into<String>, value: impl Into<String>) { self.options.insert(name.into(), value.into()); } + + /// Use a mock server for testing purposes. 
+ /// + /// # Note + /// + /// The client will always use the `RowBinary` format instead of `RowBinaryWithNamesAndTypes`, + /// as otherwise the mocks would have to provide the RBWNAT header, + /// which is pointless in this kind of test. + #[cfg(feature = "test-util")] + pub fn with_mock(mut self, mock: &test::Mock) -> Self { + self.url = mock.url().to_string(); + self.mocked = true; + self + } } /// This is a private API exported only for internal purposes. @@ -448,4 +503,14 @@ mod client_tests { .with_access_token("my_jwt") .with_password("secret"); } + + #[test] + fn it_sets_validation_mode() { + let client = Client::default(); + assert!(client.validation); + let client = client.with_validation(false); + assert!(!client.validation); + let client = client.with_validation(true); + assert!(client.validation); + } } diff --git a/src/query.rs b/src/query.rs index 374eebb9..346836c6 100644 --- a/src/query.rs +++ b/src/query.rs @@ -44,7 +44,7 @@ impl Query { /// [`Identifier`], will be appropriately escaped. /// /// All possible errors will be returned as [`Error::InvalidParams`] - /// during query execution (`execute()`, `fetch()` etc). + /// during query execution (`execute()`, `fetch()`, etc.). /// /// WARNING: This means that the query must not have any extra `?`, even if /// they are in a string literal! Use `??` to have plain `?` in query. @@ -85,10 +85,16 @@ impl Query { /// ``` pub fn fetch<T: Row>(mut self) -> Result<RowCursor<T>> { self.sql.bind_fields::<T>(); - self.sql.set_output_format("RowBinary"); + + let validation = self.client.get_validation(); + if validation { + self.sql.set_output_format("RowBinaryWithNamesAndTypes"); + } else { + self.sql.set_output_format("RowBinary"); + } let response = self.do_execute(true)?; - Ok(RowCursor::new(response)) + Ok(RowCursor::new(response, validation)) } /// Executes the query and returns just a single row. diff --git a/src/row.rs b/src/row.rs index c5ca6808..d591ebf9 100644 --- a/src/row.rs +++ b/src/row.rs @@ -1,7 +1,18 @@ use crate::sql; +#[derive(Debug, Clone, PartialEq)] +pub enum RowKind { + Primitive, + Struct, + Tuple, + Vec, +} + pub trait Row { + const NAME: &'static str; const COLUMN_NAMES: &'static [&'static str]; + const COLUMN_COUNT: usize; + const KIND: RowKind; // TODO: count // TODO: different list for SELECT/INSERT (de/ser) @@ -23,16 +34,24 @@ impl_primitive_for![ bool, String, u8, u16, u32, u64, u128, usize, i8, i16, i32, i64, i128, isize, f32, f64, ]; +macro_rules! count_tokens { + () => { 0 }; + ($head:tt $($tail:tt)*) => { 1 + count_tokens!($($tail)*) }; +} + +/// Two forms are supported: +/// * (P1, P2, ...) +/// * (SomeRow, P1, P2, ...) +/// +/// The second one is useful for queries like +/// `SELECT ?fields, count() FROM ... GROUP BY ?fields`. macro_rules! 
impl_row_for_tuple { ($i:ident $($other:ident)+) => { - /// Two forms are supported: - /// * (P1, P2, ...) - /// * (SomeRow, P1, P2, ...) - /// - /// The second one is useful for queries like - /// `SELECT ?fields, count() FROM .. GROUP BY ?fields`. impl<$i: Row, $($other: Primitive),+> Row for ($i, $($other),+) { + const NAME: &'static str = $i::NAME; const COLUMN_NAMES: &'static [&'static str] = $i::COLUMN_NAMES; + const COLUMN_COUNT: usize = $i::COLUMN_COUNT + count_tokens!($($other)*); + const KIND: RowKind = RowKind::Tuple; } impl_row_for_tuple!($($other)+); @@ -44,13 +63,19 @@ macro_rules! impl_row_for_tuple { impl Primitive for () {} impl<P: Primitive> Row for P { + const NAME: &'static str = stringify!(P); const COLUMN_NAMES: &'static [&'static str] = &[]; + const COLUMN_COUNT: usize = 1; + const KIND: RowKind = RowKind::Primitive; } impl_row_for_tuple!(T0 T1 T2 T3 T4 T5 T6 T7 T8); impl<P: Primitive> Row for Vec<P> { + const NAME: &'static str = "Vec"; const COLUMN_NAMES: &'static [&'static str] = &[]; + const COLUMN_COUNT: usize = 1; + const KIND: RowKind = RowKind::Vec; } /// Collects all field names in depth and joins them with comma. diff --git a/src/row_metadata.rs b/src/row_metadata.rs new file mode 100644 index 00000000..dbf5dbad --- /dev/null +++ b/src/row_metadata.rs @@ -0,0 +1,195 @@ +// FIXME: this is allowed only temporarily, +// before the insert RBWNAT implementation is ready, +// because otherwise the caches are never used. +#![allow(dead_code)] +#![allow(unreachable_pub)] + +use crate::row::RowKind; +use crate::sql::Identifier; +use crate::Result; +use crate::Row; +use clickhouse_types::{parse_rbwnat_columns_header, Column}; +use std::collections::HashMap; +use std::fmt::Display; +use std::sync::Arc; +use tokio::sync::{OnceCell, RwLock}; + +/// Cache for [`RowMetadata`] to avoid allocating it for the same struct more than once +/// during the application lifecycle. Key: fully qualified table name (e.g. `database.table`). +type LockedRowMetadataCache = RwLock<HashMap<String, Arc<RowMetadata>>>; +static ROW_METADATA_CACHE: OnceCell<LockedRowMetadataCache> = OnceCell::const_new(); + +#[derive(Debug, PartialEq)] +pub(crate) enum AccessType { + WithSeqAccess, + WithMapAccess(Vec<usize>), +} + +/// [`RowMetadata`] should be owned outside the (de)serializer, +/// as it is calculated only once per struct. It does not have lifetimes, +/// so it does not introduce a breaking change to [`crate::cursors::RowCursor`]. +pub(crate) struct RowMetadata { +    /// Database schema, or columns, are parsed before the first call to (de)serializer. +    pub(crate) columns: Vec<Column>, +    /// This determines whether we can just use [`crate::rowbinary::de::RowBinarySeqAccess`] +    /// or a more sophisticated approach with [`crate::rowbinary::de::RowBinaryStructAsMapAccess`] +    /// to support structs defined with different fields order than in the schema. +    /// (De)serializing a struct as a map will be approximately 40% slower than as a sequence. 
+ access_type: AccessType, +} + +impl RowMetadata { + pub(crate) fn new<T: Row>(columns: Vec<Column>) -> Self { + let access_type = match T::KIND { + RowKind::Primitive => { + if columns.len() != 1 { + panic!( + "While processing a primitive row: \ + expected only 1 column in the database schema, \ + but got {} instead.\n#### All schema columns:\n{}", + columns.len(), + join_panic_schema_hint(&columns), + ); + } + AccessType::WithSeqAccess // ignored + } + RowKind::Tuple => { + if T::COLUMN_COUNT != columns.len() { + panic!( + "While processing a tuple row: database schema has {} columns, \ + but the tuple definition has {} fields in total.\ + \n#### All schema columns:\n{}", + columns.len(), + T::COLUMN_COUNT, + join_panic_schema_hint(&columns), + ); + } + AccessType::WithSeqAccess // ignored + } + RowKind::Vec => { + if columns.len() != 1 { + panic!( + "While processing a row defined as a vector: \ + expected only 1 column in the database schema, \ + but got {} instead.\n#### All schema columns:\n{}", + columns.len(), + join_panic_schema_hint(&columns), + ); + } + AccessType::WithSeqAccess // ignored + } + RowKind::Struct => { + if columns.len() != T::COLUMN_NAMES.len() { + panic!( + "While processing struct {}: database schema has {} columns, \ + but the struct definition has {} fields.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + T::NAME, + columns.len(), + T::COLUMN_NAMES.len(), + join_panic_schema_hint(T::COLUMN_NAMES), + join_panic_schema_hint(&columns), + ); + } + let mut mapping = Vec::with_capacity(T::COLUMN_NAMES.len()); + let mut expected_index = 0; + let mut should_use_map = false; + for col in &columns { + if let Some(index) = T::COLUMN_NAMES.iter().position(|field| col.name == *field) + { + if index != expected_index { + should_use_map = true + } + expected_index += 1; + mapping.push(index); + } else { + panic!( + "While processing struct {}: database schema has a column {} \ + that was not found in the struct definition.\ + \n#### All struct fields:\n{}\n#### All schema columns:\n{}", + T::NAME, + col, + join_panic_schema_hint(T::COLUMN_NAMES), + join_panic_schema_hint(&columns), + ); + } + } + if should_use_map { + AccessType::WithMapAccess(mapping) + } else { + AccessType::WithSeqAccess + } + } + }; + Self { + columns, + access_type, + } + } + + #[inline] + pub(crate) fn get_schema_index(&self, struct_idx: usize) -> usize { + match &self.access_type { + AccessType::WithMapAccess(mapping) => { + if struct_idx < mapping.len() { + mapping[struct_idx] + } else { + // unreachable + panic!("Struct has more fields than columns in the database schema",) + } + } + AccessType::WithSeqAccess => struct_idx, // should be unreachable + } + } + + #[inline] + pub(crate) fn is_field_order_wrong(&self) -> bool { + matches!(self.access_type, AccessType::WithMapAccess(_)) + } +} + +pub(crate) async fn get_row_metadata<T: Row>( + client: &crate::Client, + table_name: &str, +) -> Result<Arc<RowMetadata>> { + let locked_cache = ROW_METADATA_CACHE + .get_or_init(|| async { RwLock::new(HashMap::new()) }) + .await; + let cache_guard = locked_cache.read().await; + match cache_guard.get(table_name) { + Some(metadata) => Ok(metadata.clone()), + None => cache_row_metadata::<T>(client, table_name, locked_cache).await, + } +} + +/// Used internally to introspect and cache the table structure to allow validation +/// of serialized rows before submitting the first [`insert::Insert::write`]. 
+async fn cache_row_metadata<T: Row>( + client: &crate::Client, + table_name: &str, + locked_cache: &LockedRowMetadataCache, +) -> Result<Arc<RowMetadata>> { + let mut bytes_cursor = client + .query("SELECT * FROM ? LIMIT 0") + .bind(Identifier(table_name)) + .fetch_bytes("RowBinaryWithNamesAndTypes")?; + let mut buffer = Vec::<u8>::new(); + while let Some(chunk) = bytes_cursor.next().await? { + buffer.extend_from_slice(&chunk); + } + let columns = parse_rbwnat_columns_header(&mut buffer.as_slice())?; + let mut cache = locked_cache.write().await; + let metadata = Arc::new(RowMetadata::new::<T>(columns)); + cache.insert(table_name.to_string(), metadata.clone()); + Ok(metadata) +} + +fn join_panic_schema_hint<T: Display>(col: &[T]) -> String { + if col.is_empty() { + return String::default(); + } + col.iter() + .map(|c| format!("- {}", c)) + .collect::<Vec<String>>() + .join("\n") +} diff --git a/src/rowbinary/de.rs b/src/rowbinary/de.rs index c7c41392..a1e97854 100644 --- a/src/rowbinary/de.rs +++ b/src/rowbinary/de.rs @@ -1,30 +1,83 @@ -use std::{convert::TryFrom, mem, str}; - use crate::error::{Error, Result}; +use crate::row_metadata::RowMetadata; +use crate::rowbinary::utils::{ensure_size, get_unsigned_leb128}; +use crate::rowbinary::validation::SerdeType; +use crate::rowbinary::validation::{DataTypeValidator, SchemaValidator}; +use crate::Row; use bytes::Buf; +use core::mem::size_of; +use serde::de::MapAccess; use serde::{ de::{DeserializeSeed, Deserializer, EnumAccess, SeqAccess, VariantAccess, Visitor}, Deserialize, }; +use std::marker::PhantomData; +use std::{convert::TryFrom, str}; -/// Deserializes a value from `input` with a row encoded in `RowBinary`. +/// Deserializes a value from `input` with a row encoded in `RowBinary`, +/// i.e. only when [`crate::Row`] validation is disabled in the client. /// /// It accepts _a reference to_ a byte slice because it somehow leads to a more /// performant generated code than `(&[u8]) -> Result<(T, usize)>` and even /// `(&[u8], &mut Option) -> Result`. -pub(crate) fn deserialize_from<'data, T: Deserialize<'data>>(input: &mut &'data [u8]) -> Result<T> { - let mut deserializer = RowBinaryDeserializer { input }; +pub(crate) fn deserialize_row<'data, 'cursor, T: Deserialize<'data> + Row>( + input: &mut &'data [u8], +) -> Result<T> { + let mut deserializer = RowBinaryDeserializer::<T, _>::new(input, ()); + T::deserialize(&mut deserializer) +} + +/// Similar to [`deserialize_row`], but uses [`RowMetadata`] +/// parsed from `RowBinaryWithNamesAndTypes` header to validate the data types. +/// This is used when [`crate::Row`] validation is enabled in the client (default). +/// +/// It expects a slice of [`Column`] objects parsed from the beginning +/// of the `RowBinaryWithNamesAndTypes` data stream. After the header, +/// the rows format is the same as `RowBinary`. +pub(crate) fn deserialize_row_with_validation<'data, 'cursor, T: Deserialize<'data> + Row>( + input: &mut &'data [u8], + metadata: &'cursor RowMetadata, +) -> Result<T> { + let validator = DataTypeValidator::new(metadata); + let mut deserializer = RowBinaryDeserializer::<T, _>::new(input, validator); T::deserialize(&mut deserializer) } -/// A deserializer for the RowBinary format. +/// A deserializer for the `RowBinary(WithNamesAndTypes)` format. /// /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details. 
+
+/// A deserializer for the `RowBinary(WithNamesAndTypes)` format.
 ///
 /// See https://clickhouse.com/docs/en/interfaces/formats#rowbinary for details.
-struct RowBinaryDeserializer<'cursor, 'data> {
+struct RowBinaryDeserializer<'cursor, 'data, R: Row, V = ()>
+where
+    V: SchemaValidator,
+{
     input: &'cursor mut &'data [u8],
+    validator: V,
+    _marker: PhantomData<R>,
 }
 
-impl<'data> RowBinaryDeserializer<'_, 'data> {
+impl<'cursor, 'data, R: Row, V> RowBinaryDeserializer<'cursor, 'data, R, V>
+where
+    V: SchemaValidator,
+{
+    fn new(input: &'cursor mut &'data [u8], validator: V) -> Self {
+        Self {
+            input,
+            validator,
+            _marker: PhantomData,
+        }
+    }
+
+    fn inner(
+        &mut self,
+        serde_type: SerdeType,
+    ) -> RowBinaryDeserializer<'_, 'data, R, V::Inner<'_>> {
+        RowBinaryDeserializer {
+            input: self.input,
+            validator: self.validator.validate(serde_type),
+            _marker: PhantomData,
+        }
+    }
+
     fn read_vec(&mut self, size: usize) -> Result<Vec<u8>> {
         Ok(self.read_slice(size)?.to_vec())
     }
@@ -43,71 +96,71 @@ impl<'data> RowBinaryDeserializer<'_, 'data> {
     }
 }
 
-#[inline]
-fn ensure_size(buffer: impl Buf, size: usize) -> Result<()> {
-    if buffer.remaining() < size {
-        Err(Error::NotEnoughData)
-    } else {
-        Ok(())
-    }
+macro_rules! impl_num {
+    ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => {
+        #[inline(always)]
+        fn $deser_method<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
+            self.validator.validate($serde_type);
+            ensure_size(&mut self.input, core::mem::size_of::<$ty>())?;
+            let value = self.input.$reader_method();
+            visitor.$visitor_method(value)
+        }
+    };
 }
 
-macro_rules! impl_num {
-    ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident) => {
-        #[inline]
+macro_rules! impl_num_or_enum {
+    ($ty:ty, $deser_method:ident, $visitor_method:ident, $reader_method:ident, $serde_type:expr) => {
+        #[inline(always)]
         fn $deser_method<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
-            ensure_size(&mut self.input, mem::size_of::<$ty>())?;
+            let mut maybe_enum_validator = self.validator.validate($serde_type);
+            ensure_size(&mut self.input, core::mem::size_of::<$ty>())?;
             let value = self.input.$reader_method();
+            maybe_enum_validator.validate_identifier::<$ty>(value);
             visitor.$visitor_method(value)
         }
     };
 }
 
-impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> {
+impl<'data, R: Row, Validator> Deserializer<'data>
+    for &mut RowBinaryDeserializer<'_, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
     type Error = Error;
 
-    impl_num!(i8, deserialize_i8, visit_i8, get_i8);
-
-    impl_num!(i16, deserialize_i16, visit_i16, get_i16_le);
-
-    impl_num!(i32, deserialize_i32, visit_i32, get_i32_le);
-
-    impl_num!(i64, deserialize_i64, visit_i64, get_i64_le);
-
-    impl_num!(i128, deserialize_i128, visit_i128, get_i128_le);
-
-    impl_num!(u8, deserialize_u8, visit_u8, get_u8);
-
-    impl_num!(u16, deserialize_u16, visit_u16, get_u16_le);
-
-    impl_num!(u32, deserialize_u32, visit_u32, get_u32_le);
-
-    impl_num!(u64, deserialize_u64, visit_u64, get_u64_le);
+    impl_num_or_enum!(i8, deserialize_i8, visit_i8, get_i8, SerdeType::I8);
+    impl_num_or_enum!(i16, deserialize_i16, visit_i16, get_i16_le, SerdeType::I16);
 
-    impl_num!(u128, deserialize_u128, visit_u128, get_u128_le);
+    impl_num!(i32, deserialize_i32, visit_i32, get_i32_le, SerdeType::I32);
+    impl_num!(i64, deserialize_i64, visit_i64, get_i64_le, SerdeType::I64);
+    #[rustfmt::skip]
+    impl_num!(i128, deserialize_i128, visit_i128, get_i128_le, SerdeType::I128);
 
-    impl_num!(f32, deserialize_f32, visit_f32, get_f32_le);
+    impl_num!(u8, deserialize_u8, visit_u8, get_u8, SerdeType::U8);
+    impl_num!(u16, deserialize_u16, visit_u16, get_u16_le, SerdeType::U16);
+    impl_num!(u32, deserialize_u32, visit_u32, get_u32_le, SerdeType::U32);
+    impl_num!(u64, deserialize_u64, visit_u64, get_u64_le, SerdeType::U64);
+    #[rustfmt::skip]
+    impl_num!(u128, deserialize_u128, visit_u128, get_u128_le, SerdeType::U128);
 
-    impl_num!(f64, deserialize_f64, visit_f64, get_f64_le);
+    impl_num!(f32, deserialize_f32, visit_f32, get_f32_le, SerdeType::F32);
+    impl_num!(f64, deserialize_f64, visit_f64, get_f64_le, SerdeType::F64);
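For reference, this is roughly what one macro invocation produces inside the `Deserializer` impl (a paraphrased sketch, not the literal expansion):

```rust
// Sketch of `impl_num_or_enum!(i8, deserialize_i8, visit_i8, get_i8, SerdeType::I8)`:
#[inline(always)]
fn deserialize_i8<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
    // For an Enum8 column this returns an inner validator holding the
    // name/value map; for a plain Int8 column it is effectively a no-op.
    let mut maybe_enum_validator = self.validator.validate(SerdeType::I8);
    ensure_size(&mut self.input, core::mem::size_of::<i8>())?;
    let value = self.input.get_i8(); // RowBinary numbers are little-endian
    // Panics if the value is not one of the declared Enum8 variants.
    maybe_enum_validator.validate_identifier::<i8>(value);
    visitor.visit_i8(value)
}
```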
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_any<V: Visitor<'data>>(self, _: V) -> Result<V::Value> {
         Err(Error::DeserializeAnyNotSupported)
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_unit<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
         // TODO: revise this.
+        // TODO - skip validation?
         visitor.visit_unit()
     }
 
-    #[inline]
-    fn deserialize_char<V: Visitor<'data>>(self, _: V) -> Result<V::Value> {
-        panic!("character types are unsupported: `char`");
-    }
-
-    #[inline]
+    #[inline(always)]
     fn deserialize_bool<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
+        self.validator.validate(SerdeType::Bool);
         ensure_size(&mut self.input, 1)?;
         match self.input.get_u8() {
             0 => visitor.visit_bool(false),
@@ -116,171 +169,121 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> {
         }
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_str<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
+        self.validator.validate(SerdeType::Str);
         let size = self.read_size()?;
         let slice = self.read_slice(size)?;
         let str = str::from_utf8(slice).map_err(Error::from)?;
         visitor.visit_borrowed_str(str)
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_string<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
+        self.validator.validate(SerdeType::String);
         let size = self.read_size()?;
         let vec = self.read_vec(size)?;
         let string = String::from_utf8(vec).map_err(|err| Error::from(err.utf8_error()))?;
         visitor.visit_string(string)
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_bytes<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
         let size = self.read_size()?;
+        self.validator.validate(SerdeType::Bytes(size));
         let slice = self.read_slice(size)?;
         visitor.visit_borrowed_bytes(slice)
     }
 
-    #[inline]
+    #[inline(always)]
    fn deserialize_byte_buf<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
         let size = self.read_size()?;
+        self.validator.validate(SerdeType::ByteBuf(size));
         visitor.visit_byte_buf(self.read_vec(size)?)
     }
 
-    #[inline]
+    /// This is used to deserialize identifiers for either:
+    /// - a `Variant` data type discriminant
+    /// - a [`RowBinaryStructAsMapAccess`] field
+    #[inline(always)]
     fn deserialize_identifier<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
-        self.deserialize_u8(visitor)
+        ensure_size(&mut self.input, size_of::<u8>())?;
+        let value = self.input.get_u8();
+        // TODO: is there a better way to validate that the deserialized value matches the schema?
+        // TODO: theoretically, we can track if we are currently processing a struct field id,
+        //  and don't call the validator in that case, because it will never be a `Variant`.
+        self.validator.validate_identifier::<u8>(value);
+        visitor.visit_u8(value)
     }
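As a usage illustration of the enum path above (a hedged sketch; the column would be declared as, e.g., `Enum8('Debug' = 1, 'Info' = 2)`), `serde_repr` maps the wire `i8` onto a Rust enum, and the deserializer checks the value against the pairs parsed from the RBWNAT header:

```rust
use serde_repr::{Deserialize_repr, Serialize_repr};

#[derive(Debug, Serialize_repr, Deserialize_repr)]
#[repr(i8)]
enum Level {
    Debug = 1, // each value must exist in the column's Enum8(...) definition
    Info = 2,
}
```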
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_enum<V: Visitor<'data>>(
         self,
         _name: &'static str,
         _variants: &'static [&'static str],
         visitor: V,
     ) -> Result<V::Value> {
-        struct Access<'de, 'cursor, 'data> {
-            deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>,
-        }
-        struct VariantDeserializer<'de, 'cursor, 'data> {
-            deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>,
-        }
-        impl<'data> VariantAccess<'data> for VariantDeserializer<'_, '_, 'data> {
-            type Error = Error;
-
-            fn unit_variant(self) -> Result<()> {
-                Err(Error::Unsupported("unit variants".to_string()))
-            }
-
-            fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
-            where
-                T: DeserializeSeed<'data>,
-            {
-                DeserializeSeed::deserialize(seed, &mut *self.deserializer)
-            }
-
-            fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
-            where
-                V: Visitor<'data>,
-            {
-                self.deserializer.deserialize_tuple(len, visitor)
-            }
-
-            fn struct_variant<V>(
-                self,
-                fields: &'static [&'static str],
-                visitor: V,
-            ) -> Result<V::Value>
-            where
-                V: Visitor<'data>,
-            {
-                self.deserializer.deserialize_tuple(fields.len(), visitor)
-            }
-        }
-
-        impl<'de, 'cursor, 'data> EnumAccess<'data> for Access<'de, 'cursor, 'data> {
-            type Error = Error;
-            type Variant = VariantDeserializer<'de, 'cursor, 'data>;
-
-            fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error>
-            where
-                T: DeserializeSeed<'data>,
-            {
-                let value = seed.deserialize(&mut *self.deserializer)?;
-                let deserializer = VariantDeserializer {
-                    deserializer: self.deserializer,
-                };
-                Ok((value, deserializer))
-            }
-        }
-        visitor.visit_enum(Access { deserializer: self })
+        let deserializer = &mut self.inner(SerdeType::Enum);
+        visitor.visit_enum(RowBinaryEnumAccess { deserializer })
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_tuple<V: Visitor<'data>>(self, len: usize, visitor: V) -> Result<V::Value> {
-        struct Access<'de, 'cursor, 'data> {
-            deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data>,
-            len: usize,
-        }
-
-        impl<'data> SeqAccess<'data> for Access<'_, '_, 'data> {
-            type Error = Error;
-
-            fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
-            where
-                T: DeserializeSeed<'data>,
-            {
-                if self.len > 0 {
-                    self.len -= 1;
-                    let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?;
-                    Ok(Some(value))
-                } else {
-                    Ok(None)
-                }
-            }
-
-            fn size_hint(&self) -> Option<usize> {
-                Some(self.len)
-            }
-        }
-
-        visitor.visit_seq(Access {
-            deserializer: self,
-            len,
-        })
+        let deserializer = &mut self.inner(SerdeType::Tuple(len));
+        visitor.visit_seq(RowBinarySeqAccess { deserializer, len })
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_option<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
         ensure_size(&mut self.input, 1)?;
-
-        match self.input.get_u8() {
-            0 => visitor.visit_some(&mut *self),
+        let is_null = self.input.get_u8();
+        let deserializer = &mut self.inner(SerdeType::Option);
+        match is_null {
+            0 => visitor.visit_some(deserializer),
             1 => visitor.visit_none(),
             v => Err(Error::InvalidTagEncoding(v as usize)),
         }
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_seq<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
         let len = self.read_size()?;
-        self.deserialize_tuple(len, visitor)
+        let deserializer = &mut self.inner(SerdeType::Seq(len));
+        visitor.visit_seq(RowBinarySeqAccess { deserializer, len })
     }
 
-    #[inline]
-    fn deserialize_map<V: Visitor<'data>>(self, _visitor: V) -> Result<V::Value> {
-        panic!("maps are unsupported, use `Vec<(A, B)>` instead");
+    #[inline(always)]
+    fn deserialize_map<V: Visitor<'data>>(self, visitor: V) -> Result<V::Value> {
+        let len = self.read_size()?;
+        let deserializer = &mut self.inner(SerdeType::Map(len));
+        visitor.visit_map(RowBinaryMapAccess {
+            deserializer,
+            remaining: len,
+        })
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_struct<V: Visitor<'data>>(
         self,
-        _name: &str,
+        _name: &'static str,
         fields: &'static [&'static str],
         visitor: V,
     ) -> Result<V::Value> {
-        self.deserialize_tuple(fields.len(), visitor)
+        if !self.validator.is_field_order_wrong() {
+            visitor.visit_seq(RowBinarySeqAccess {
+                deserializer: self,
+                len: fields.len(),
+            })
+        } else {
+            visitor.visit_map(RowBinaryStructAsMapAccess {
+                deserializer: self,
+                current_field_idx: 0,
+                fields,
+            })
+        }
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_newtype_struct<V: Visitor<'data>>(
         self,
         _name: &str,
@@ -289,7 +292,12 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> {
         visitor.visit_newtype_struct(self)
     }
 
-    #[inline]
+    #[inline(always)]
+    fn deserialize_char<V: Visitor<'data>>(self, _: V) -> Result<V::Value> {
+        panic!("character types are unsupported: `char`");
+    }
+
+    #[inline(always)]
     fn deserialize_unit_struct<V: Visitor<'data>>(
         self,
         name: &'static str,
@@ -298,7 +306,7 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> {
         panic!("unit types are unsupported: `{name}`");
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_tuple_struct<V: Visitor<'data>>(
         self,
         name: &'static str,
@@ -308,43 +316,244 @@ impl<'data> Deserializer<'data> for &mut RowBinaryDeserializer<'_, 'data> {
         panic!("tuple struct types are unsupported: `{name}`");
     }
 
-    #[inline]
+    #[inline(always)]
     fn deserialize_ignored_any<V: Visitor<'data>>(self, _visitor: V) -> Result<V::Value> {
         panic!("ignored types are unsupported");
    }
 
-    #[inline]
+    #[inline(always)]
     fn is_human_readable(&self) -> bool {
         false
     }
 }
 
-fn get_unsigned_leb128(mut buffer: impl Buf) -> Result<u64> {
-    let mut value = 0u64;
-    let mut shift = 0;
+/// Used in [`Deserializer::deserialize_seq`], [`Deserializer::deserialize_tuple`],
+/// and it could be used in [`Deserializer::deserialize_struct`],
+/// if we detect that the field order matches the database schema.
+struct RowBinarySeqAccess<'de, 'cursor, 'data, R: Row, Validator>
+where
+    Validator: SchemaValidator,
+{
+    deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
+    len: usize,
+}
 
-    loop {
-        ensure_size(&mut buffer, 1)?;
+impl<'data, R: Row, Validator> SeqAccess<'data> for RowBinarySeqAccess<'_, '_, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
+    type Error = Error;
 
-        let byte = buffer.get_u8();
-        value |= (byte as u64 & 0x7f) << shift;
+    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
+    where
+        T: DeserializeSeed<'data>,
+    {
+        if self.len > 0 {
+            self.len -= 1;
+            let value = DeserializeSeed::deserialize(seed, &mut *self.deserializer)?;
+            Ok(Some(value))
+        } else {
+            Ok(None)
+        }
+    }
+
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.len)
+    }
+}
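A hedged sketch (hypothetical column name) of the two row definitions that the map machinery above supports for a `Map(String, UInt32)` column, per the `Map`/`HashMap` support added in this change:

```rust
use clickhouse::Row;
use serde::Deserialize;
use std::collections::HashMap;

#[derive(Row, Deserialize)]
struct AsMap {
    counters: HashMap<String, u32>, // goes through `deserialize_map`
}

#[derive(Row, Deserialize)]
struct AsPairs {
    counters: Vec<(String, u32)>, // goes through `deserialize_seq` + `deserialize_tuple`
}
```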
 
+/// Used in [`Deserializer::deserialize_map`].
+struct RowBinaryMapAccess<'de, 'cursor, 'data, R: Row, Validator>
+where
+    Validator: SchemaValidator,
+{
+    deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
+    remaining: usize,
+}
 
-        if byte & 0x80 == 0 {
-            break;
-        }
+impl<'data, R: Row, Validator> MapAccess<'data> for RowBinaryMapAccess<'_, '_, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
+    type Error = Error;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
+    where
+        K: DeserializeSeed<'data>,
+    {
+        if self.remaining == 0 {
+            return Ok(None);
+        }
+        self.remaining -= 1;
+        seed.deserialize(&mut *self.deserializer).map(Some)
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
+    where
+        V: DeserializeSeed<'data>,
+    {
+        seed.deserialize(&mut *self.deserializer)
+    }
+
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.remaining)
+    }
+}
+
+/// Used in [`Deserializer::deserialize_struct`] to support wrong struct field order
+/// as long as the data types and field names exactly match the database schema.
+struct RowBinaryStructAsMapAccess<'de, 'cursor, 'data, R: Row, Validator>
+where
+    Validator: SchemaValidator,
+{
+    deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
+    current_field_idx: usize,
+    fields: &'static [&'static str],
+}
+
+struct StructFieldIdentifier(&'static str);
+
+impl<'de> Deserializer<'de> for StructFieldIdentifier {
+    type Error = Error;
+
+    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
+    where
+        V: Visitor<'de>,
+    {
+        visitor.visit_borrowed_str(self.0)
+    }
+
+    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
+    where
+        V: Visitor<'de>,
+    {
+        panic!("StructFieldIdentifier is supposed to use `deserialize_identifier` only");
+    }
+
+    serde::forward_to_deserialize_any! {
+        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
+        bytes byte_buf option unit unit_struct newtype_struct seq tuple
+        tuple_struct map struct enum ignored_any
+    }
+}
 
-        shift += 7;
-        if shift > 57 {
-            // TODO: what about another error?
-            return Err(Error::NotEnoughData);
-        }
+/// Without schema order "restoration", the following query:
+///
+/// ```sql
+/// SELECT 'foo' :: String AS a,
+///        'bar' :: String AS c
+/// ```
+///
+/// will produce a wrong result if the struct is defined as:
+///
+/// ```rs
+/// struct Data {
+///     c: String,
+///     a: String,
+/// }
+/// ```
+///
+/// If we just use [`RowBinarySeqAccess`] here, `c` will be deserialized into the `a` field,
+/// and `a` will be deserialized into the `c` field, which is a classic case of data corruption.
+impl<'data, R: Row, Validator> MapAccess<'data>
+    for RowBinaryStructAsMapAccess<'_, '_, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
+    type Error = Error;
+
+    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
+    where
+        K: DeserializeSeed<'data>,
+    {
+        if self.current_field_idx >= self.fields.len() {
+            return Ok(None);
+        }
+        let schema_index = self
+            .deserializer
+            .validator
+            .get_schema_index(self.current_field_idx);
+        let field_id = StructFieldIdentifier(self.fields[schema_index]);
+        self.current_field_idx += 1;
+        seed.deserialize(field_id).map(Some)
+    }
+
+    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
+    where
+        V: DeserializeSeed<'data>,
+    {
+        seed.deserialize(&mut *self.deserializer)
+    }
+
+    fn size_hint(&self) -> Option<usize> {
+        Some(self.fields.len())
+    }
+}
+
+/// Used in [`Deserializer::deserialize_enum`].
+struct RowBinaryEnumAccess<'de, 'cursor, 'data, R: Row, Validator>
+where
+    Validator: SchemaValidator,
+{
+    deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
+}
+
+struct VariantDeserializer<'de, 'cursor, 'data, R: Row, Validator>
+where
+    Validator: SchemaValidator,
+{
+    deserializer: &'de mut RowBinaryDeserializer<'cursor, 'data, R, Validator>,
+}
+
+impl<'data, R: Row, Validator> VariantAccess<'data>
+    for VariantDeserializer<'_, '_, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
+    type Error = Error;
+
+    fn unit_variant(self) -> Result<()> {
+        panic!("unit variants are unsupported");
+    }
 
-    Ok(value)
-}
+    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
+    where
+        T: DeserializeSeed<'data>,
+    {
+        DeserializeSeed::deserialize(seed, &mut *self.deserializer)
+    }
+
+    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
+    where
+        V: Visitor<'data>,
+    {
+        self.deserializer.deserialize_tuple(len, visitor)
+    }
+
+    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
+    where
+        V: Visitor<'data>,
+    {
+        self.deserializer.deserialize_tuple(fields.len(), visitor)
+    }
+}
 
-#[test]
-fn it_deserializes_unsigned_leb128() {
-    let buf = &[0xe5, 0x8e, 0x26][..];
-    assert_eq!(get_unsigned_leb128(buf).unwrap(), 624_485);
+impl<'de, 'cursor, 'data, R: Row, Validator> EnumAccess<'data>
+    for RowBinaryEnumAccess<'de, 'cursor, 'data, R, Validator>
+where
+    Validator: SchemaValidator,
+{
+    type Error = Error;
+    type Variant = VariantDeserializer<'de, 'cursor, 'data, R, Validator>;
+
+    fn variant_seed<T>(self, seed: T) -> Result<(T::Value, Self::Variant), Self::Error>
+    where
+        T: DeserializeSeed<'data>,
+    {
+        let value = seed.deserialize(&mut *self.deserializer)?;
+        let deserializer = VariantDeserializer {
+            deserializer: self.deserializer,
+        };
+        Ok((value, deserializer))
+    }
 }
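A hedged sketch of the user-facing shape this enables: a `Variant(...)` value arrives as a `u8` discriminant followed by the payload, and `RowBinaryEnumAccess` reads the discriminant before dispatching to the matching variant. This sketch assumes the Rust variant order mirrors the type order the server reports for the `Variant` column:

```rust
use serde::Deserialize;

// Hypothetical mapping for a `Variant(Int64, String)` column.
#[derive(Debug, Deserialize)]
enum VariantValue {
    Int(i64),    // discriminant 0 (assumed)
    Str(String), // discriminant 1 (assumed)
}
```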
diff --git a/src/rowbinary/mod.rs b/src/rowbinary/mod.rs
index dbdb672e..9c96bca9 100644
--- a/src/rowbinary/mod.rs
+++ b/src/rowbinary/mod.rs
@@ -1,7 +1,11 @@
-pub(crate) use de::deserialize_from;
+pub(crate) use de::deserialize_row;
+pub(crate) use de::deserialize_row_with_validation;
 pub(crate) use ser::serialize_into;
 
+pub(crate) mod validation;
+
 mod de;
 mod ser;
 
 #[cfg(test)]
 mod tests;
+mod utils;
diff --git a/src/rowbinary/ser.rs b/src/rowbinary/ser.rs
index 68fec881..c644b118 100644
--- a/src/rowbinary/ser.rs
+++ b/src/rowbinary/ser.rs
@@ -1,4 +1,5 @@
 use bytes::BufMut;
+use clickhouse_types::put_leb128;
 use serde::{
     ser::{Impossible, SerializeSeq, SerializeStruct, SerializeTuple, Serializer},
     Serialize,
@@ -42,27 +43,16 @@ impl Serializer for &'_ mut RowBinarySerializer {
     type SerializeTupleVariant = Impossible<(), Error>;
 
     impl_num!(i8, serialize_i8, put_i8);
-
     impl_num!(i16, serialize_i16, put_i16_le);
-
     impl_num!(i32, serialize_i32, put_i32_le);
-
     impl_num!(i64, serialize_i64, put_i64_le);
-
     impl_num!(i128, serialize_i128, put_i128_le);
-
     impl_num!(u8, serialize_u8, put_u8);
-
     impl_num!(u16, serialize_u16, put_u16_le);
-
     impl_num!(u32, serialize_u32, put_u32_le);
-
     impl_num!(u64, serialize_u64, put_u64_le);
-
     impl_num!(u128, serialize_u128, put_u128_le);
-
     impl_num!(f32, serialize_f32, put_f32_le);
-
     impl_num!(f64, serialize_f64, put_f64_le);
 
     #[inline]
@@ -78,14 +68,14 @@ impl Serializer for &'_ mut RowBinarySerializer {
 
     #[inline]
     fn serialize_str(self, v: &str) -> Result<()> {
-        put_unsigned_leb128(&mut self.buffer, v.len() as u64);
+        put_leb128(&mut self.buffer, v.len() as u64);
         self.buffer.put_slice(v.as_bytes());
         Ok(())
     }
 
     #[inline]
     fn serialize_bytes(self, v: &[u8]) -> Result<()> {
-        put_unsigned_leb128(&mut self.buffer, v.len() as u64);
+        put_leb128(&mut self.buffer, v.len() as u64);
         self.buffer.put_slice(v);
         Ok(())
     }
@@ -148,9 +138,7 @@ impl Serializer for &'_ mut RowBinarySerializer {
         // Max number of types in the Variant data type is 255
         // See also: https://github.com/ClickHouse/ClickHouse/issues/54864
         if variant_index > 255 {
-            return Err(Error::VariantDiscriminatorIsOutOfBound(
-                variant_index as usize,
-            ));
+            panic!("max number of types in the Variant data type is 255, got {variant_index}")
         }
         self.buffer.put_u8(variant_index as u8);
         value.serialize(self)
@@ -159,7 +147,7 @@ impl Serializer for &'_ mut RowBinarySerializer {
     #[inline]
     fn serialize_seq(self, len: Option<usize>) -> Result<Self::SerializeSeq> {
         let len = len.ok_or(Error::SequenceMustHaveLength)?;
-        put_unsigned_leb128(&mut self.buffer, len as u64);
+        put_leb128(&mut self.buffer, len as u64);
         Ok(self)
     }
@@ -260,27 +248,3 @@ impl SerializeTuple for &'_ mut RowBinarySerializer {
         Ok(())
     }
 }
-
-fn put_unsigned_leb128(mut buffer: impl BufMut, mut value: u64) {
-    while {
-        let mut byte = value as u8 & 0x7f;
-        value >>= 7;
-
-        if value != 0 {
-            byte |= 0x80;
-        }
-
-        buffer.put_u8(byte);
-
-        value != 0
-    } {}
-}
-
-#[test]
-fn it_serializes_unsigned_leb128() {
-    let mut vec = Vec::new();
-
-    put_unsigned_leb128(&mut vec, 624_485);
-
-    assert_eq!(vec, [0xe5, 0x8e, 0x26]);
-}
diff --git a/src/rowbinary/tests.rs b/src/rowbinary/tests.rs
index 2865cbef..ac097d6c 100644
--- a/src/rowbinary/tests.rs
+++ b/src/rowbinary/tests.rs
@@ -1,3 +1,4 @@
+use crate::Row;
 use serde::{Deserialize, Serialize};
 
 #[derive(Debug, PartialEq, Serialize, Deserialize)]
@@ -36,6 +37,34 @@ struct Sample<'a> {
     boolean: bool,
 }
 
+// `clickhouse_derive` does not work here
+impl Row for Sample<'_> {
+    const NAME: &'static str = "Sample";
+    const COLUMN_NAMES: &'static [&'static str] = &[
+        "int8",
+        "int32",
+        "int64",
+        "uint8",
+        "uint32",
+        "uint64",
+        "float32",
+        "float64",
+        "datetime",
+        "datetime64",
+        "decimal64",
+        "decimal128",
+        "string",
+        "blob",
+        "optional_decimal64",
+        "optional_datetime",
+        "fixed_string",
+        "array",
+        "boolean",
+    ];
+    const COLUMN_COUNT: usize = 19;
+    const KIND: crate::RowKind = crate::RowKind::Struct;
+}
+
 fn sample() -> Sample<'static> {
     Sample {
         int8: -42,
@@ -122,10 +151,10 @@ fn it_deserializes() {
         let (mut left, mut right) = input.split_at(i);
 
         // It shouldn't panic.
-        let _: Result<Sample<'_>, _> = super::deserialize_from(&mut left);
-        let _: Result<Sample<'_>, _> = super::deserialize_from(&mut right);
+        let _: Result<Sample<'_>, _> = super::deserialize_row(&mut left);
+        let _: Result<Sample<'_>, _> = super::deserialize_row(&mut right);
 
-        let actual: Sample<'_> = super::deserialize_from(&mut input.as_slice()).unwrap();
+        let actual: Sample<'_> = super::deserialize_row(&mut input.as_slice()).unwrap();
         assert_eq!(actual, sample());
     }
 }
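For reference, the LEB128 test vector carried over into the new `utils.rs` below decodes as follows (7 payload bits per byte, high bit set while more bytes follow); a self-contained check:

```rust
// 0xE5 -> continuation bit set, payload 0x65 (bits 0..7)
// 0x8E -> continuation bit set, payload 0x0E (bits 7..14)
// 0x26 -> continuation bit clear, payload 0x26 (bits 14..21)
fn main() {
    let decoded = 0x65u64 | (0x0E << 7) | (0x26 << 14);
    assert_eq!(decoded, 624_485);
}
```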
diff --git a/src/rowbinary/utils.rs b/src/rowbinary/utils.rs
new file mode 100644
index 00000000..e1dc1d6e
--- /dev/null
+++ b/src/rowbinary/utils.rs
@@ -0,0 +1,44 @@
+use crate::error::Error;
+use bytes::Buf;
+
+/// TODO: it is theoretically possible to ensure size in chunks,
+///  at least for some types, given that we have the database schema.
+#[inline]
+pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> crate::error::Result<()> {
+    if buffer.remaining() < size {
+        Err(Error::NotEnoughData)
+    } else {
+        Ok(())
+    }
+}
+
+#[inline]
+pub(crate) fn get_unsigned_leb128(mut buffer: impl Buf) -> crate::error::Result<u64> {
+    let mut value = 0u64;
+    let mut shift = 0;
+
+    loop {
+        ensure_size(&mut buffer, 1)?;
+
+        let byte = buffer.get_u8();
+        value |= (byte as u64 & 0x7f) << shift;
+
+        if byte & 0x80 == 0 {
+            break;
+        }
+
+        shift += 7;
+        if shift > 57 {
+            // TODO: what about another error?
+            return Err(Error::NotEnoughData);
+        }
+    }
+
+    Ok(value)
+}
+
+#[test]
+fn it_deserializes_unsigned_leb128() {
+    let buf = &[0xe5, 0x8e, 0x26][..];
+    assert_eq!(get_unsigned_leb128(buf).unwrap(), 624_485);
+}
diff --git a/src/rowbinary/validation.rs b/src/rowbinary/validation.rs
new file mode 100644
index 00000000..e8b8ad0c
--- /dev/null
+++ b/src/rowbinary/validation.rs
@@ -0,0 +1,778 @@
+use crate::row_metadata::RowMetadata;
+use crate::{Row, RowKind};
+use clickhouse_types::data_types::{Column, DataTypeNode, DecimalType, EnumType};
+use std::collections::HashMap;
+use std::fmt::Display;
+use std::marker::PhantomData;
+
+/// This trait is used to validate the schema of a [`crate::Row`] against the parsed RBWNAT schema.
+/// Note that [`SchemaValidator`] is also implemented for `()`,
+/// which is used to skip validation if the user disabled it.
+pub(crate) trait SchemaValidator: Sized {
+    type Inner<'de>: SchemaValidator
+    where
+        Self: 'de;
+    /// The main entry point. The validation flow is based on the [`crate::Row::KIND`].
+    /// For container types (nullable, array, map, tuple, variant, etc.),
+    /// it will return an [`InnerDataTypeValidator`] instance (see [`InnerDataTypeValidatorKind`]),
+    /// which has its own implementation of this method, allowing recursive validation.
+    fn validate(&mut self, serde_type: SerdeType) -> Self::Inner<'_>;
+    /// Validates that an identifier exists in the values map for enums,
+    /// or stores the variant identifier for the next serde call.
+    fn validate_identifier<T: EnumOrVariantIdentifier>(&mut self, value: T);
+    /// Having the database schema from RBWNAT, the crate can detect that
+    /// while the field names and the types are correct, the field order in the struct
+    /// does not match the column order in the database schema, and we should use
+    /// `MapAccess` instead of `SeqAccess` to seamlessly deserialize the struct.
+    fn is_field_order_wrong(&self) -> bool;
+    /// Returns the "restored" index of the schema column for the given struct field index.
+    /// It is used only if the crate detects that while the field names and the types are correct,
+    /// the field order in the struct does not match the column order in the database schema.
+ fn get_schema_index(&self, struct_idx: usize) -> usize; +} + +pub(crate) struct DataTypeValidator<'cursor, R: Row> { + metadata: &'cursor RowMetadata, + current_column_idx: usize, + _marker: PhantomData, +} + +impl<'cursor, R: Row> DataTypeValidator<'cursor, R> { + pub(crate) fn new(metadata: &'cursor RowMetadata) -> Self { + Self { + metadata, + current_column_idx: 0, + _marker: PhantomData::, + } + } + + fn get_current_column(&self) -> Option<&Column> { + if self.current_column_idx > 0 && self.current_column_idx <= self.metadata.columns.len() { + // index is immediately moved to the next column after the root validator is called + let schema_index = self.get_schema_index(self.current_column_idx - 1); + Some(&self.metadata.columns[schema_index]) + } else { + None + } + } + + fn get_current_column_name_and_type(&self) -> (String, &DataTypeNode) { + self.get_current_column() + .map(|c| (format!("{}.{}", R::NAME, c.name), &c.data_type)) + // both should be defined at this point + .unwrap_or(("Struct".to_string(), &DataTypeNode::Bool)) + } + + fn panic_on_schema_mismatch<'de>( + &'de self, + data_type: &DataTypeNode, + serde_type: &SerdeType, + is_inner: bool, + ) -> Option> { + match R::KIND { + RowKind::Primitive => { + panic!( + "While processing row as a primitive: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowKind::Vec => { + panic!( + "While processing row as a vector: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowKind::Tuple => { + panic!( + "While processing row as a tuple: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + data_type, serde_type + ) + } + RowKind::Struct => { + if is_inner { + let (full_name, full_data_type) = self.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: attempting to deserialize \ + nested ClickHouse type {} as {} which is not compatible", + full_name, full_data_type, data_type, serde_type + ) + } else { + panic!( + "While processing column {}: attempting to deserialize \ + ClickHouse type {} as {} which is not compatible", + self.get_current_column_name_and_type().0, + data_type, + serde_type + ) + } + } + } + } +} + +impl<'cursor, R: Row> SchemaValidator for DataTypeValidator<'cursor, R> { + type Inner<'de> + = Option> + where + Self: 'de; + + #[inline] + fn validate(&'_ mut self, serde_type: SerdeType) -> Self::Inner<'_> { + match R::KIND { + // `fetch::` for a "primitive row" type + RowKind::Primitive => { + if self.current_column_idx == 0 && self.metadata.columns.len() == 1 { + let data_type = &self.metadata.columns[0].data_type; + validate_impl(self, data_type, &serde_type, false) + } else { + panic!( + "Primitive row is expected to be a single value, got columns: {:?}", + self.metadata.columns + ); + } + } + // `fetch::<(i16, i32)>` or `fetch::<(T, u64)>` for a "tuple row" type + RowKind::Tuple => { + match serde_type { + SerdeType::Tuple(_) => Some(InnerDataTypeValidator { + root: self, + kind: InnerDataTypeValidatorKind::RootTuple(&self.metadata.columns, 0), + }), + _ => { + // should be unreachable + panic!( + "While processing tuple row: expected serde type Tuple(N), got {}", + serde_type + ); + } + } + } + // `fetch::>` for a "vector row" type + RowKind::Vec => { + let data_type = &self.metadata.columns[0].data_type; + let kind = match data_type { + DataTypeNode::Array(inner_type) => { + 
InnerDataTypeValidatorKind::RootArray(inner_type) + } + _ => panic!( + "Expected Array type when validating root level sequence, but got {}", + self.metadata.columns[0].data_type + ), + }; + Some(InnerDataTypeValidator { root: self, kind }) + } + // `fetch::` for a "struct row" type, which is supposed to be the default flow + RowKind::Struct => { + if self.current_column_idx < self.metadata.columns.len() { + let current_column = &self.metadata.columns[self.current_column_idx]; + self.current_column_idx += 1; + validate_impl(self, ¤t_column.data_type, &serde_type, false) + } else { + panic!( + "Struct {} has more fields than columns in the database schema", + R::NAME + ) + } + } + } + } + + #[inline] + fn is_field_order_wrong(&self) -> bool { + self.metadata.is_field_order_wrong() + } + + #[inline] + fn get_schema_index(&self, struct_idx: usize) -> usize { + self.metadata.get_schema_index(struct_idx) + } + + #[cold] + fn validate_identifier(&mut self, _value: T) { + unreachable!() + } +} + +/// Having a ClickHouse `Map` defined as a `HashMap` in Rust, Serde will call: +/// - `deserialize_map` for `Vec<(K, V)>` +/// - `deserialize_` suitable for `K` +/// - `deserialize_` suitable for `V` +#[derive(Debug)] +pub(crate) enum MapValidatorState { + Key, + Value, +} + +/// Having a ClickHouse `Map` defined as `Vec<(K, V)>` in Rust, Serde will call: +/// - `deserialize_seq` for `Vec<(K, V)>` +/// - `deserialize_tuple` for `(K, V)` +/// - `deserialize_` suitable for `K` +/// - `deserialize_` suitable for `V` +#[derive(Debug)] +pub(crate) enum MapAsSequenceValidatorState { + Tuple, + Key, + Value, +} + +pub(crate) struct InnerDataTypeValidator<'de, 'cursor, R: Row> { + root: &'de DataTypeValidator<'cursor, R>, + kind: InnerDataTypeValidatorKind<'cursor>, +} + +#[derive(Debug)] +pub(crate) enum InnerDataTypeValidatorKind<'cursor> { + Array(&'cursor DataTypeNode), + FixedString(usize), + Map(&'cursor [Box; 2], MapValidatorState), + /// Allows supporting ClickHouse `Map` defined as `Vec<(K, V)>` in Rust + MapAsSequence(&'cursor [Box; 2], MapAsSequenceValidatorState), + Tuple(&'cursor [DataTypeNode]), + /// This is a hack to support deserializing tuples/arrays (and not structs) from fetch calls + RootTuple(&'cursor [Column], usize), + RootArray(&'cursor DataTypeNode), + Enum(&'cursor HashMap), + Variant(&'cursor [DataTypeNode], VariantValidationState), + Nullable(&'cursor DataTypeNode), +} + +#[derive(Debug)] +pub(crate) enum VariantValidationState { + Pending, + Identifier(u8), +} + +impl<'cursor, R: Row> SchemaValidator for Option> { + type Inner<'de> + = Self + where + Self: 'de; + + #[inline] + fn validate(&mut self, serde_type: SerdeType) -> Self { + let inner = self.as_mut()?; + match &mut inner.kind { + InnerDataTypeValidatorKind::Map(kv, state) => match state { + MapValidatorState::Key => { + let result = validate_impl(inner.root, &kv[0], &serde_type, true); + *state = MapValidatorState::Value; + result + } + MapValidatorState::Value => { + let result = validate_impl(inner.root, &kv[1], &serde_type, true); + *state = MapValidatorState::Key; + result + } + }, + InnerDataTypeValidatorKind::MapAsSequence(kv, state) => { + match state { + // the first state is simply skipped, as the same validator + // will be called again for the Key and then the Value types + MapAsSequenceValidatorState::Tuple => { + *state = MapAsSequenceValidatorState::Key; + self.take() + } + MapAsSequenceValidatorState::Key => { + let result = validate_impl(inner.root, &kv[0], &serde_type, true); + *state = 
MapAsSequenceValidatorState::Value; + result + } + MapAsSequenceValidatorState::Value => { + let result = validate_impl(inner.root, &kv[1], &serde_type, true); + *state = MapAsSequenceValidatorState::Tuple; + result + } + } + } + InnerDataTypeValidatorKind::Array(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Nullable(inner_type) => { + validate_impl(inner.root, inner_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Tuple(elements_types) => { + match elements_types.split_first() { + Some((first, rest)) => { + *elements_types = rest; + validate_impl(inner.root, first, &serde_type, true) + } + None => { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", + full_name, full_data_type, serde_type + ) + } + } + } + InnerDataTypeValidatorKind::FixedString(_len) => { + None // actually unreachable + } + InnerDataTypeValidatorKind::RootTuple(columns, current_index) => { + if *current_index < columns.len() { + let data_type = &columns[*current_index].data_type; + *current_index += 1; + validate_impl(inner.root, data_type, &serde_type, true) + } else { + let (full_name, full_data_type) = inner.root.get_current_column_name_and_type(); + panic!( + "While processing root tuple element {} defined as {}: \ + attempting to deserialize {} while no more elements are allowed", + full_name, full_data_type, serde_type + ) + } + } + InnerDataTypeValidatorKind::RootArray(inner_data_type) => { + validate_impl(inner.root, inner_data_type, &serde_type, true) + } + InnerDataTypeValidatorKind::Variant(possible_types, state) => match state { + VariantValidationState::Pending => { + unreachable!() + } + VariantValidationState::Identifier(value) => { + if *value as usize >= possible_types.len() { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Variant identifier {value} is out of bounds, max allowed index is {}", + possible_types.len() - 1 + ); + } + let data_type = &possible_types[*value as usize]; + validate_impl(inner.root, data_type, &serde_type, true) + } + }, + // TODO - check enum string value correctness in the hashmap? + // is this even possible? + InnerDataTypeValidatorKind::Enum(_values_map) => { + unreachable!() + } + } + } + + fn validate_identifier(&mut self, value: T) { + use InnerDataTypeValidatorKind::{Enum, Variant}; + if let Some(inner) = self { + match T::IDENTIFIER_TYPE { + IdentifierType::Enum8 | IdentifierType::Enum16 => { + if let Enum(values_map) = &inner.kind { + if !values_map.contains_key(&(value.into_i16())) { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Enum8 value {value} is not present in the database schema" + ); + } + } + } + IdentifierType::Variant => { + if let Variant(possible_types, state) = &mut inner.kind { + // ClickHouse guarantees max 255 variants, i.e. 
the same max value as u8 + if value.into_u8() < (possible_types.len() as u8) { + *state = VariantValidationState::Identifier(value.into_u8()); + } else { + let (full_name, full_data_type) = + inner.root.get_current_column_name_and_type(); + panic!( + "While processing column {full_name} defined as {full_data_type}: \ + Variant identifier {value} is out of bounds, max allowed index is {}", + possible_types.len() - 1 + ); + } + } + } + } + } + } + + #[inline(always)] + fn is_field_order_wrong(&self) -> bool { + false + } + + #[cold] + fn get_schema_index(&self, _struct_idx: usize) -> usize { + unreachable!() + } +} + +impl Drop for InnerDataTypeValidator<'_, '_, R> { + fn drop(&mut self) { + if let InnerDataTypeValidatorKind::Tuple(elements_types) = self.kind { + if !elements_types.is_empty() { + let (column_name, column_type) = self.root.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: tuple was not fully deserialized; \ + remaining elements: {}; likely, the field definition is incomplete", + column_name, + column_type, + elements_types + .iter() + .map(|c| c.to_string()) + .collect::>() + .join(", ") + ) + } + } + } +} + +// TODO: is there a way to eliminate multiple branches with similar patterns? +// static/const dispatch? +// separate smaller inline functions? +#[inline] +fn validate_impl<'de, 'cursor, R: Row>( + root: &'de DataTypeValidator<'cursor, R>, + column_data_type: &'cursor DataTypeNode, + serde_type: &SerdeType, + is_inner: bool, +) -> Option> { + let data_type = column_data_type.remove_low_cardinality(); + match serde_type { + SerdeType::Bool + if data_type == &DataTypeNode::Bool || data_type == &DataTypeNode::UInt8 => + { + None + } + SerdeType::I8 => match data_type { + DataTypeNode::Int8 => None, + DataTypeNode::Enum(EnumType::Enum8, values_map) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Enum(values_map), + }), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + }, + SerdeType::I16 => match data_type { + DataTypeNode::Int16 => None, + DataTypeNode::Enum(EnumType::Enum16, values_map) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Enum(values_map), + }), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + }, + SerdeType::I32 + if data_type == &DataTypeNode::Int32 + || data_type == &DataTypeNode::Date32 + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal32) + ) => + { + None + } + SerdeType::I64 + if data_type == &DataTypeNode::Int64 + || matches!(data_type, DataTypeNode::DateTime64(_, _)) + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal64) + ) => + { + None + } + SerdeType::I128 + if data_type == &DataTypeNode::Int128 + || matches!( + data_type, + DataTypeNode::Decimal(_, _, DecimalType::Decimal128) + ) => + { + None + } + SerdeType::U8 if data_type == &DataTypeNode::UInt8 => None, + SerdeType::U16 + if data_type == &DataTypeNode::UInt16 || data_type == &DataTypeNode::Date => + { + None + } + SerdeType::U32 + if data_type == &DataTypeNode::UInt32 + || matches!(data_type, DataTypeNode::DateTime(_)) + || data_type == &DataTypeNode::IPv4 => + { + None + } + SerdeType::U64 if data_type == &DataTypeNode::UInt64 => None, + SerdeType::U128 if data_type == &DataTypeNode::UInt128 => None, + SerdeType::F32 if data_type == &DataTypeNode::Float32 => None, + SerdeType::F64 if data_type == &DataTypeNode::Float64 => None, + SerdeType::Str | SerdeType::String + if data_type == 
&DataTypeNode::String || data_type == &DataTypeNode::JSON => + { + None + } + // allows to work with BLOB strings as well + SerdeType::Bytes(_) | SerdeType::ByteBuf(_) if data_type == &DataTypeNode::String => None, + SerdeType::Option => { + if let DataTypeNode::Nullable(inner_type) = data_type { + Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Nullable(inner_type), + }) + } else { + root.panic_on_schema_mismatch(data_type, serde_type, is_inner) + } + } + SerdeType::Seq(_) => match data_type { + DataTypeNode::Array(inner_type) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(inner_type), + }), + // A map can be defined as `Vec<(K, V)>` in the struct + DataTypeNode::Map(kv) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::MapAsSequence( + kv, + MapAsSequenceValidatorState::Tuple, + ), + }), + DataTypeNode::Ring => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), + }), + DataTypeNode::Polygon => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Ring), + }), + DataTypeNode::MultiPolygon => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Polygon), + }), + DataTypeNode::LineString => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::Point), + }), + DataTypeNode::MultiLineString => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::LineString), + }), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + }, + SerdeType::Tuple(len) => match data_type { + DataTypeNode::FixedString(n) => { + if n == len { + Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::FixedString(*n), + }) + } else { + let (full_name, full_data_type) = root.get_current_column_name_and_type(); + panic!( + "While processing column {} defined as {}: attempting to deserialize \ + nested ClickHouse type {} as {}", + full_name, full_data_type, data_type, serde_type, + ) + } + } + DataTypeNode::Tuple(elements) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(elements), + }), + DataTypeNode::Array(inner_type) => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(inner_type), + }), + DataTypeNode::IPv6 => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Array(&DataTypeNode::UInt8), + }), + DataTypeNode::UUID => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(UUID_TUPLE_ELEMENTS), + }), + DataTypeNode::Point => Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Tuple(POINT_TUPLE_ELEMENTS), + }), + _ => root.panic_on_schema_mismatch(data_type, serde_type, is_inner), + }, + SerdeType::Map(_) => { + if let DataTypeNode::Map(kv) = data_type { + Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Map(kv, MapValidatorState::Key), + }) + } else { + panic!( + "Expected Map for {} call, but got {}", + serde_type, data_type + ) + } + } + SerdeType::Enum => { + if let DataTypeNode::Variant(possible_types) = data_type { + Some(InnerDataTypeValidator { + root, + kind: InnerDataTypeValidatorKind::Variant( + possible_types, + VariantValidationState::Pending, + ), + }) + } else { + panic!( + "Expected Variant for {} call, but got {}", + serde_type, data_type + ) + } + } + + _ => root.panic_on_schema_mismatch( + data_type, + 
serde_type, + is_inner || matches!(column_data_type, DataTypeNode::LowCardinality { .. }), + ), + } +} + +impl SchemaValidator for () { + type Inner<'de> = (); + + #[inline(always)] + fn validate(&mut self, _serde_type: SerdeType) {} + + #[inline(always)] + fn is_field_order_wrong(&self) -> bool { + // We can't detect incorrect field order with just plain `RowBinary` format + false + } + + #[inline(always)] + fn validate_identifier(&mut self, _value: T) {} + + #[cold] + fn get_schema_index(&self, _struct_idx: usize) -> usize { + unreachable!() + } +} + +/// Which Serde data type (De)serializer used for the given type. +/// Displays into certain Rust types for convenience in errors reporting. +/// See also: available methods in [`serde::Serializer`] and [`serde::Deserializer`]. +#[derive(Clone, Debug, PartialEq)] +pub(crate) enum SerdeType { + Bool, + I8, + I16, + I32, + I64, + I128, + U8, + U16, + U32, + U64, + U128, + F32, + F64, + Str, + String, + Option, + Enum, + Bytes(usize), + ByteBuf(usize), + Tuple(usize), + Seq(usize), + Map(usize), + // Identifier, + // Char, + // Unit, + // Struct, + // NewtypeStruct, + // TupleStruct, + // UnitStruct, + // IgnoredAny, +} + +impl Display for SerdeType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + SerdeType::Bool => write!(f, "bool"), + SerdeType::I8 => write!(f, "i8"), + SerdeType::I16 => write!(f, "i16"), + SerdeType::I32 => write!(f, "i32"), + SerdeType::I64 => write!(f, "i64"), + SerdeType::I128 => write!(f, "i128"), + SerdeType::U8 => write!(f, "u8"), + SerdeType::U16 => write!(f, "u16"), + SerdeType::U32 => write!(f, "u32"), + SerdeType::U64 => write!(f, "u64"), + SerdeType::U128 => write!(f, "u128"), + SerdeType::F32 => write!(f, "f32"), + SerdeType::F64 => write!(f, "f64"), + SerdeType::Str => write!(f, "&str"), + SerdeType::String => write!(f, "String"), + SerdeType::Bytes(len) => write!(f, "&[u8; {len}]"), + SerdeType::ByteBuf(_len) => write!(f, "Vec"), + SerdeType::Option => write!(f, "Option"), + SerdeType::Enum => write!(f, "enum"), + SerdeType::Seq(_len) => write!(f, "Vec"), + SerdeType::Tuple(len) => write!(f, "a tuple or sequence with length {len}"), + SerdeType::Map(_len) => write!(f, "Map"), + // SerdeType::Identifier => "identifier", + // SerdeType::Char => "char", + // SerdeType::Unit => "()", + // SerdeType::Struct => "struct", + // SerdeType::NewtypeStruct => "newtype struct", + // SerdeType::TupleStruct => "tuple struct", + // SerdeType::UnitStruct => "unit struct", + // SerdeType::IgnoredAny => "ignored any", + } + } +} + +#[derive(Debug)] +pub(crate) enum IdentifierType { + Enum8, + Enum16, + Variant, +} +pub(crate) trait EnumOrVariantIdentifier: Display + Copy { + const IDENTIFIER_TYPE: IdentifierType; + fn into_u8(self) -> u8; + fn into_i16(self) -> i16; +} +impl EnumOrVariantIdentifier for u8 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Variant; + // none of these should be ever called + #[inline(always)] + fn into_u8(self) -> u8 { + self + } + #[inline(always)] + fn into_i16(self) -> i16 { + self as i16 + } +} +impl EnumOrVariantIdentifier for i8 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Enum8; + #[inline(always)] + fn into_i16(self) -> i16 { + self as i16 + } + // we need only i16 for enum values HashMap + #[inline(always)] + fn into_u8(self) -> u8 { + self as u8 + } +} +impl EnumOrVariantIdentifier for i16 { + const IDENTIFIER_TYPE: IdentifierType = IdentifierType::Enum16; + #[inline(always)] + fn into_i16(self) -> i16 { + self + } + // 
should not be ever called + #[inline(always)] + fn into_u8(self) -> u8 { + self as u8 + } +} + +const UUID_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::UInt64, DataTypeNode::UInt64]; +const POINT_TUPLE_ELEMENTS: &[DataTypeNode; 2] = &[DataTypeNode::Float64, DataTypeNode::Float64]; diff --git a/src/test/handlers.rs b/src/test/handlers.rs index f074adbd..63fec687 100644 --- a/src/test/handlers.rs +++ b/src/test/handlers.rs @@ -7,7 +7,7 @@ use sealed::sealed; use serde::{Deserialize, Serialize}; use super::{Handler, HandlerFn}; -use crate::rowbinary; +use crate::{rowbinary, Row}; const BUFFER_INITIAL_CAPACITY: usize = 1024; @@ -82,7 +82,7 @@ pub struct RecordControl { impl RecordControl where - T: for<'a> Deserialize<'a>, + T: for<'a> Deserialize<'a> + Row, { pub async fn collect(self) -> C where @@ -93,7 +93,8 @@ where let mut result = C::default(); while !slice.is_empty() { - let row: T = rowbinary::deserialize_from(slice).expect("failed to deserialize"); + let res = rowbinary::deserialize_row(slice); + let row: T = res.expect("failed to deserialize"); result.extend(std::iter::once(row)); } diff --git a/src/test/mock.rs b/src/test/mock.rs index 41636d45..18739e24 100644 --- a/src/test/mock.rs +++ b/src/test/mock.rs @@ -52,9 +52,9 @@ impl Mock { Self { url: format!("http://{addr}"), - shared, non_exhaustive: false, server_handle: server_handle.abort_handle(), + shared, } } diff --git a/tests/it/chrono.rs b/tests/it/chrono.rs index 536c24bb..18dcb737 100644 --- a/tests/it/chrono.rs +++ b/tests/it/chrono.rs @@ -101,11 +101,11 @@ async fn datetime() { let row_str = client .query( " - SELECT toString(dt), - toString(dt64s), - toString(dt64ms), - toString(dt64us), - toString(dt64ns) + SELECT toString(dt) AS dt, + toString(dt64s) AS dt64s, + toString(dt64ms) AS dt64ms, + toString(dt64us) AS dt64us, + toString(dt64ns) AS dt64ns FROM test ", ) diff --git a/tests/it/cursor_error.rs b/tests/it/cursor_error.rs index e4894dc4..afad60a6 100644 --- a/tests/it/cursor_error.rs +++ b/tests/it/cursor_error.rs @@ -1,20 +1,24 @@ -use serde::Deserialize; - -use clickhouse::{error::Error, Client, Compression, Row}; - -#[tokio::test] -async fn deferred() { - let client = prepare_database!(); - max_execution_time(client, false).await; -} +use clickhouse::{Client, Compression}; #[tokio::test] async fn wait_end_of_query() { let client = prepare_database!(); - max_execution_time(client, true).await; + let scenarios = vec![ + // wait_end_of_query=?, expected_rows + (false, 3), // server returns some rows before throwing an error + (true, 0), // server throws an error immediately + ]; + for (wait_end_of_query, expected_rows) in scenarios { + let result = max_execution_time(client.clone(), wait_end_of_query).await; + assert_eq!( + result, expected_rows, + "wait_end_of_query: {}, expected_rows: {}", + wait_end_of_query, expected_rows + ); + } } -async fn max_execution_time(mut client: Client, wait_end_of_query: bool) { +async fn max_execution_time(mut client: Client, wait_end_of_query: bool) -> u8 { if wait_end_of_query { client = client.with_option("wait_end_of_query", "1") } @@ -22,27 +26,24 @@ async fn max_execution_time(mut client: Client, wait_end_of_query: bool) { // TODO: check different `timeout_overflow_mode` let mut cursor = client .with_compression(Compression::None) + // fails on the 4th row .with_option("max_execution_time", "0.1") - .query("SELECT toUInt8(65 + number % 5) FROM system.numbers LIMIT 100000000") + // force streaming one row in a chunk + .with_option("max_block_size", "1") + 
.query("SELECT sleepEachRow(0.03) AS s FROM system.numbers LIMIT 5") .fetch::() .unwrap(); - let mut i = 0u64; - + let mut i = 0; let err = loop { match cursor.next().await { - Ok(Some(no)) => { - // Check that we haven't parsed something extra. - assert_eq!(no, (65 + i % 5) as u8); - i += 1; - } + Ok(Some(_)) => i += 1, Ok(None) => panic!("DB exception hasn't been found"), Err(err) => break err, } }; - - assert!(wait_end_of_query ^ (i != 0)); assert!(err.to_string().contains("TIMEOUT_EXCEEDED")); + i } #[cfg(feature = "lz4")] @@ -98,40 +99,3 @@ async fn deferred_lz4() { assert_ne!(i, 0); // we're interested only in errors during processing assert!(err.to_string().contains("TIMEOUT_EXCEEDED")); } - -// See #185. -#[tokio::test] -async fn invalid_schema() { - #[derive(Debug, Row, Deserialize)] - #[allow(dead_code)] - struct MyRow { - no: u32, - dec: Option, // valid schema: u64-based types - } - - let client = prepare_database!(); - - client - .query( - "CREATE TABLE test(no UInt32, dec Nullable(Decimal64(4))) - ENGINE = MergeTree - ORDER BY no", - ) - .execute() - .await - .unwrap(); - - client - .query("INSERT INTO test VALUES (1, 1.1), (2, 2.2), (3, 3.3)") - .execute() - .await - .unwrap(); - - let err = client - .query("SELECT ?fields FROM test") - .fetch_all::() - .await - .unwrap_err(); - - assert!(matches!(err, Error::NotEnoughData)); -} diff --git a/tests/it/cursor_stats.rs b/tests/it/cursor_stats.rs index 7ae43bdf..503885ad 100644 --- a/tests/it/cursor_stats.rs +++ b/tests/it/cursor_stats.rs @@ -28,7 +28,7 @@ async fn check(client: Client, expected_ratio: f64) { decoded = cursor.decoded_bytes(); } - assert_eq!(decoded, 15000); + assert_eq!(decoded, 15000 + 23); // 23 extra bytes for the RBWNAT header. assert_eq!(cursor.received_bytes(), dbg!(received)); assert_eq!(cursor.decoded_bytes(), dbg!(decoded)); assert_eq!( diff --git a/tests/it/insert.rs b/tests/it/insert.rs index 5e7a77e1..952314a1 100644 --- a/tests/it/insert.rs +++ b/tests/it/insert.rs @@ -1,26 +1,7 @@ use crate::{create_simple_table, fetch_rows, flush_query_log, SimpleRow}; -use clickhouse::{sql::Identifier, Client, Row}; +use clickhouse::{sql::Identifier, Row}; use serde::{Deserialize, Serialize}; -#[derive(Debug, Row, Serialize, Deserialize, PartialEq)] -#[serde(rename_all = "camelCase")] -struct RenameRow { - #[serde(rename = "fix_id")] - pub(crate) fix_id: i64, - #[serde(rename = "extComplexId")] - pub(crate) complex_id: String, - pub(crate) ext_float: f64, -} - -async fn create_rename_table(client: &Client, table_name: &str) { - client - .query("CREATE TABLE ?(fixId UInt64, extComplexId String, extFloat Float64) ENGINE = MergeTree ORDER BY fixId") - .bind(Identifier(table_name)) - .execute() - .await - .unwrap(); -} - #[tokio::test] async fn keeps_client_options() { let table_name = "insert_keeps_client_options"; @@ -144,11 +125,36 @@ async fn empty_insert() { #[tokio::test] async fn rename_insert() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + #[serde(rename_all = "camelCase")] + struct RenameRow { + #[serde(rename = "fixId")] + pub(crate) fix_id: u64, + #[serde(rename = "extComplexId")] + pub(crate) complex_id: String, + pub(crate) ext_float: f64, + } + let table_name = "insert_rename"; let query_id = uuid::Uuid::new_v4().to_string(); let client = prepare_database!(); - create_rename_table(&client, table_name).await; + client + .query( + " + CREATE TABLE ?( + fixId UInt64, + extComplexId String, + extFloat Float64 + ) + ENGINE = MergeTree + ORDER BY fixId + ", + ) + .bind(Identifier(table_name)) 
+ .execute() + .await + .unwrap(); let row = RenameRow { fix_id: 42, diff --git a/tests/it/main.rs b/tests/it/main.rs index 1c1e5d1f..5bc4f5e0 100644 --- a/tests/it/main.rs +++ b/tests/it/main.rs @@ -27,6 +27,45 @@ use clickhouse::{sql::Identifier, Client, Row}; use serde::{Deserialize, Serialize}; +macro_rules! assert_panic_on_fetch_with_client { + ($client:ident, $msg_parts:expr, $query:expr) => { + use futures::FutureExt; + let async_panic = + std::panic::AssertUnwindSafe(async { $client.query($query).fetch_all::().await }); + let result = async_panic.catch_unwind().await; + assert!(result.is_err()); + let panic_msg = *result.unwrap_err().downcast::().unwrap(); + for &msg in $msg_parts { + assert!( + panic_msg.contains(msg), + "panic message:\n{panic_msg}\ndid not contain the expected part:\n{msg}" + ); + } + }; +} + +macro_rules! assert_panic_on_fetch { + ($msg_parts:expr, $query:expr) => { + use futures::FutureExt; + let client = get_client(); + let async_panic = + std::panic::AssertUnwindSafe(async { client.query($query).fetch_all::().await }); + let result = async_panic.catch_unwind().await; + assert!( + result.is_err(), + "expected a panic, but got a result instead: {:?}", + result.unwrap() + ); + let panic_msg = *result.unwrap_err().downcast::().unwrap(); + for &msg in $msg_parts { + assert!( + panic_msg.contains(msg), + "panic message:\n{panic_msg}\ndid not contain the expected part:\n{msg}" + ); + } + }; +} + macro_rules! prepare_database { () => { crate::_priv::prepare_database({ @@ -87,7 +126,7 @@ impl SimpleRow { } } -async fn create_simple_table(client: &Client, table_name: &str) { +pub(crate) async fn create_simple_table(client: &Client, table_name: &str) { client .query("CREATE TABLE ?(id UInt64, data String) ENGINE = MergeTree ORDER BY id") .with_option("wait_end_of_query", "1") @@ -97,7 +136,7 @@ async fn create_simple_table(client: &Client, table_name: &str) { .unwrap(); } -async fn fetch_rows(client: &Client, table_name: &str) -> Vec +pub(crate) async fn fetch_rows(client: &Client, table_name: &str) -> Vec where T: Row + for<'b> Deserialize<'b>, { @@ -109,10 +148,21 @@ where .unwrap() } -async fn flush_query_log(client: &Client) { +pub(crate) async fn flush_query_log(client: &Client) { client.query("SYSTEM FLUSH LOGS").execute().await.unwrap(); } +pub(crate) async fn execute_statements(client: &Client, statements: &[&str]) { + for statement in statements { + client + .query(statement) + .with_option("wait_end_of_query", "1") + .execute() + .await + .unwrap_or_else(|err| panic!("cannot execute statement '{statement}', cause: {err}")); + } +} + mod chrono; mod cloud_jwt; mod compression; @@ -127,6 +177,7 @@ mod ip; mod mock; mod nested; mod query; +mod rbwnat; mod time; mod user_agent; mod uuid; diff --git a/tests/it/mock.rs b/tests/it/mock.rs index 2db04537..3cc92481 100644 --- a/tests/it/mock.rs +++ b/tests/it/mock.rs @@ -1,15 +1,14 @@ #![cfg(feature = "test-util")] -use std::time::Duration; - -use clickhouse::{test, Client}; - use crate::SimpleRow; +use clickhouse::{test, Client}; +use std::time::Duration; async fn test_provide() { let mock = test::Mock::new(); - let client = Client::default().with_url(mock.url()); + let client = Client::default().with_mock(&mock); let expected = vec![SimpleRow::new(1, "one"), SimpleRow::new(2, "two")]; + mock.add(test::handlers::provide(&expected)); let actual = crate::fetch_rows::(&client, "doesn't matter").await; diff --git a/tests/it/query.rs b/tests/it/query.rs index 195297ed..7b783e92 100644 --- a/tests/it/query.rs +++ 
diff --git a/tests/it/query.rs b/tests/it/query.rs
index 195297ed..7b783e92 100644
--- a/tests/it/query.rs
+++ b/tests/it/query.rs
@@ -93,9 +93,9 @@ async fn server_side_param() {
         .query("SELECT plus({val1: Int32}, {val2: Int32}) AS result")
         .param("val1", 42)
         .param("val2", 144)
-        .fetch_one::<u64>()
+        .fetch_one::<i64>()
         .await
-        .expect("failed to fetch u64");
+        .expect("failed to fetch Int64");
     assert_eq!(result, 186);
 
     let result = client
diff --git a/tests/it/rbwnat.rs b/tests/it/rbwnat.rs
new file mode 100644
index 00000000..73b5fc1f
--- /dev/null
+++ b/tests/it/rbwnat.rs
@@ -0,0 +1,1721 @@
+use crate::{execute_statements, get_client};
+use clickhouse::sql::Identifier;
+use clickhouse_derive::Row;
+use clickhouse_types::data_types::{Column, DataTypeNode};
+use clickhouse_types::parse_rbwnat_columns_header;
+use fixnum::typenum::{U12, U4, U8};
+use fixnum::FixedPoint;
+use serde::{Deserialize, Serialize};
+use serde_repr::{Deserialize_repr, Serialize_repr};
+use std::collections::HashMap;
+use std::str::FromStr;
+
+#[tokio::test]
+async fn header_parsing() {
+    let client = prepare_database!();
+    client
+        .query(
+            "
+            CREATE OR REPLACE TABLE visits
+            (
+                CounterID UInt32,
+                StartDate Date,
+                Sign      Int8,
+                IsNew     UInt8,
+                VisitID   UInt64,
+                UserID    UInt64,
+                Goals Nested
+                (
+                    ID         UInt32,
+                    Serial     UInt32,
+                    EventTime  DateTime,
+                    Price      Int64,
+                    OrderID    String,
+                    CurrencyID UInt32
+                )
+            ) ENGINE = MergeTree ORDER BY ()
+            ",
+        )
+        .execute()
+        .await
+        .unwrap();
+
+    let mut cursor = client
+        .query("SELECT * FROM visits LIMIT 0")
+        .fetch_bytes("RowBinaryWithNamesAndTypes")
+        .unwrap();
+
+    let data = cursor.collect().await.unwrap();
+    let result = parse_rbwnat_columns_header(&mut &data[..]).unwrap();
+    assert_eq!(
+        result,
+        vec![
+            Column {
+                name: "CounterID".to_string(),
+                data_type: DataTypeNode::UInt32,
+            },
+            Column {
+                name: "StartDate".to_string(),
+                data_type: DataTypeNode::Date,
+            },
+            Column {
+                name: "Sign".to_string(),
+                data_type: DataTypeNode::Int8,
+            },
+            Column {
+                name: "IsNew".to_string(),
+                data_type: DataTypeNode::UInt8,
+            },
+            Column {
+                name: "VisitID".to_string(),
+                data_type: DataTypeNode::UInt64,
+            },
+            Column {
+                name: "UserID".to_string(),
+                data_type: DataTypeNode::UInt64,
+            },
+            Column {
+                name: "Goals.ID".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)),
+            },
+            Column {
+                name: "Goals.Serial".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)),
+            },
+            Column {
+                name: "Goals.EventTime".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::DateTime(None))),
+            },
+            Column {
+                name: "Goals.Price".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::Int64)),
+            },
+            Column {
+                name: "Goals.OrderID".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::String)),
+            },
+            Column {
+                name: "Goals.CurrencyID".to_string(),
+                data_type: DataTypeNode::Array(Box::new(DataTypeNode::UInt32)),
+            }
+        ]
+    );
+}
+
+#[tokio::test]
+async fn fetch_primitive_row() {
+    let client = get_client();
+    let result = client
+        .query("SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)")
+        .fetch_one::<u64>()
+        .await;
+    assert_eq!(result.unwrap(), 3);
+}
+
+#[tokio::test]
+async fn fetch_primitive_row_schema_mismatch() {
+    type Data = i32; // expected type is UInt64
+    assert_panic_on_fetch!(
+        &["primitive", "UInt64", "i32"],
+        "SELECT count() FROM (SELECT * FROM system.numbers LIMIT 3)"
+    );
+}
+
+#[tokio::test]
+async fn fetch_vector_row() {
+    let client = get_client();
+    let result = client
+        .query("SELECT [1, 2, 3] :: Array(UInt32)")
+        .fetch_one::<Vec<u32>>()
+        .await;
+    assert_eq!(result.unwrap(), vec![1, 2,
3]); +} + +#[tokio::test] +async fn fetch_vector_row_schema_mismatch_nested_type() { + type Data = Vec; // expected type for Array(UInt32) is Vec + assert_panic_on_fetch!( + &["vector", "UInt32", "i128"], + "SELECT [1, 2, 3] :: Array(UInt32)" + ); +} + +#[tokio::test] +async fn fetch_tuple_row() { + let client = get_client(); + let result = client + .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b") + .fetch_one::<(u32, String)>() + .await; + assert_eq!(result.unwrap(), (42, "foo".to_string())); +} + +#[tokio::test] +async fn fetch_tuple_row_schema_mismatch_first_element() { + type Data = (i128, String); // expected u32 instead of i128 + assert_panic_on_fetch!( + &["tuple", "UInt32", "i128"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_schema_mismatch_second_element() { + type Data = (u32, i64); // expected String instead of i64 + assert_panic_on_fetch!( + &["tuple", "String", "i64"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_schema_mismatch_missing_element() { + type Data = (u32, String); // expected to have the third element as i64 + assert_panic_on_fetch!( + &[ + "database schema has 3 columns", + "tuple definition has 2 fields" + ], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: Int64 AS c" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_schema_mismatch_too_many_elements() { + type Data = (u32, String, i128); // i128 should not be there + assert_panic_on_fetch!( + &[ + "database schema has 2 columns", + "tuple definition has 3 fields" + ], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: String, + } + + let client = get_client(); + let result = client + .query("SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c") + .fetch_one::<(Data, u64)>() + .await; + assert_eq!( + result.unwrap(), + ( + Data { + a: 42, + b: "foo".to_string() + }, + 144 + ) + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct_schema_mismatch() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u64, // expected type is u32 + b: String, + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["tuple", "UInt32", "u64"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct_schema_mismatch_too_many_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + c: u64, // this field should not be here + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["3 columns", "4 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct_schema_mismatch_too_many_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + } + type Data = (_Data, u64, u64); // one too many u64 + assert_panic_on_fetch!( + &["3 columns", "4 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct_schema_mismatch_too_few_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, // the second field is missing now + } + type Data = (_Data, u64); + assert_panic_on_fetch!( + &["3 columns", 
"2 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c" + ); +} + +#[tokio::test] +async fn fetch_tuple_row_with_struct_schema_mismatch_too_few_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct _Data { + a: u32, + b: String, + } + type Data = (_Data, u64); // another u64 is missing here + assert_panic_on_fetch!( + &["4 columns", "3 fields"], + "SELECT 42 :: UInt32 AS a, 'foo' :: String AS b, 144 :: UInt64 AS c, 255 :: UInt64 AS d" + ); +} + +#[tokio::test] +async fn basic_types() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + uint8_val: u8, + uint16_val: u16, + uint32_val: u32, + uint64_val: u64, + uint128_val: u128, + int8_val: i8, + int16_val: i16, + int32_val: i32, + int64_val: i64, + int128_val: i128, + float32_val: f32, + float64_val: f64, + string_val: String, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 255 :: UInt8 AS uint8_val, + 65535 :: UInt16 AS uint16_val, + 4294967295 :: UInt32 AS uint32_val, + 18446744073709551615 :: UInt64 AS uint64_val, + 340282366920938463463374607431768211455 :: UInt128 AS uint128_val, + -128 :: Int8 AS int8_val, + -32768 :: Int16 AS int16_val, + -2147483648 :: Int32 AS int32_val, + -9223372036854775808 :: Int64 AS int64_val, + -170141183460469231731687303715884105728 :: Int128 AS int128_val, + 42.0 :: Float32 AS float32_val, + 144.0 :: Float64 AS float64_val, + 'test' :: String AS string_val + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + uint8_val: 255, + uint16_val: 65535, + uint32_val: 4294967295, + uint64_val: 18446744073709551615, + uint128_val: 340282366920938463463374607431768211455, + int8_val: -128, + int16_val: -32768, + int32_val: -2147483648, + int64_val: -9223372036854775808, + int128_val: -170141183460469231731687303715884105728, + float32_val: 42.0, + float64_val: 144.0, + string_val: "test".to_string(), + } + ); +} + +// FIXME: somehow this test breaks `cargo test`, but works from RustRover +#[ignore] +#[tokio::test] +async fn borrowed_data() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data<'a> { + str: &'a str, + array: Vec<&'a str>, + tuple: (&'a str, &'a str), + str_opt: Option<&'a str>, + vec_map_str: Vec<(&'a str, &'a str)>, + vec_map_f32: Vec<(&'a str, f32)>, + vec_map_nested: Vec<(&'a str, Vec<(&'a str, &'a str)>)>, + hash_map_str: HashMap<&'a str, &'a str>, + hash_map_f32: HashMap<&'a str, f32>, + hash_map_nested: HashMap<&'a str, HashMap<&'a str, &'a str>>, + } + + let client = get_client(); + let mut cursor = client + .query( + " + SELECT * FROM + ( + SELECT + 'a' :: String AS str, + ['b', 'c'] :: Array(String) AS array, + ('d', 'e') :: Tuple(String, String) AS tuple, + NULL :: Nullable(String) AS str_opt, + map('key1', 'value1', 'key2', 'value2') :: Map(String, String) AS hash_map_str, + map('key3', 100, 'key4', 200) :: Map(String, Float32) AS hash_map_f32, + map('n1', hash_map_str) :: Map(String, Map(String, String)) AS hash_map_nested, + hash_map_str AS vec_map_str, + hash_map_f32 AS vec_map_f32, + hash_map_nested AS vec_map_nested + UNION ALL + SELECT + 'f' :: String AS str, + ['g', 'h'] :: Array(String) AS array, + ('i', 'j') :: Tuple(String, String) AS tuple, + 'k' :: Nullable(String) AS str_opt, + map('key4', 'value4', 'key5', 'value5') :: Map(String, String) AS hash_map_str, + map('key6', 300, 'key7', 400) :: Map(String, Float32) AS hash_map_f32, + map('n2', hash_map_str) :: Map(String, Map(String, String)) AS hash_map_nested, + 
hash_map_str AS vec_map_str, + hash_map_f32 AS vec_map_f32, + hash_map_nested AS vec_map_nested + ) + ORDER BY str + ", + ) + .fetch::>() + .unwrap(); + + let mut result = Vec::new(); + while let Some(row) = cursor.next().await.unwrap() { + result.push(row); + } + + assert_eq!( + result, + vec![ + Data { + str: "a", + array: vec!["b", "c"], + tuple: ("d", "e"), + str_opt: None, + vec_map_str: vec![("key1", "value1"), ("key2", "value2")], + vec_map_f32: vec![("key3", 100.0), ("key4", 200.0)], + vec_map_nested: vec![("n1", vec![("key1", "value1"), ("key2", "value2")])], + hash_map_str: HashMap::from([("key1", "value1"), ("key2", "value2"),]), + hash_map_f32: HashMap::from([("key3", 100.0), ("key4", 200.0),]), + hash_map_nested: HashMap::from([( + "n1", + HashMap::from([("key1", "value1"), ("key2", "value2"),]), + )]), + }, + Data { + str: "f", + array: vec!["g", "h"], + tuple: ("i", "j"), + str_opt: Some("k"), + vec_map_str: vec![("key4", "value4"), ("key5", "value5")], + vec_map_f32: vec![("key6", 300.0), ("key7", 400.0)], + vec_map_nested: vec![("n2", vec![("key4", "value4"), ("key5", "value5")])], + hash_map_str: HashMap::from([("key4", "value4"), ("key5", "value5"),]), + hash_map_f32: HashMap::from([("key6", 300.0), ("key7", 400.0),]), + hash_map_nested: HashMap::from([( + "n2", + HashMap::from([("key4", "value4"), ("key5", "value5"),]), + )]), + }, + ] + ); +} + +#[tokio::test] +async fn several_simple_rows() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + num: u64, + str: String, + } + + let client = get_client(); + let result = client + .query("SELECT number AS num, toString(number) AS str FROM system.numbers LIMIT 3") + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { + num: 0, + str: "0".to_string(), + }, + Data { + num: 1, + str: "1".to_string(), + }, + Data { + num: 2, + str: "2".to_string(), + }, + ] + ); +} + +#[tokio::test] +async fn many_numbers() { + #[derive(Row, Deserialize)] + struct Data { + number: u64, + } + + let client = get_client(); + let mut cursor = client + .query("SELECT number FROM system.numbers_mt LIMIT 2000") + .fetch::() + .unwrap(); + + let mut sum = 0; + while let Some(row) = cursor.next().await.unwrap() { + sum += row.number; + } + assert_eq!(sum, (0..2000).sum::()); +} + +#[tokio::test] +async fn blob_string_with_serde_bytes() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = "serde_bytes")] + blob: Vec, + } + + let client = get_client(); + let result = client + .query("SELECT 'foo' :: String AS blob") + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + blob: "foo".as_bytes().to_vec(), + } + ); +} + +#[tokio::test] +async fn arrays() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + one_dim_array: Vec, + two_dim_array: Vec>, + three_dim_array: Vec>>, + description: String, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + [1, 2] :: Array(UInt32) AS one_dim_array, + [[1, 2], [3, 4]] :: Array(Array(Int64)) AS two_dim_array, + [[[1.1, 2.2], [3.3, 4.4]], [], [[5.5, 6.6], [7.7, 8.8]]] :: Array(Array(Array(Float64))) AS three_dim_array, + 'foobar' :: String AS description + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + one_dim_array: vec![1, 2], + two_dim_array: vec![vec![1, 2], vec![3, 4]], + three_dim_array: vec![ + vec![vec![1.1, 2.2], vec![3.3, 4.4]], + vec![], + vec![vec![5.5, 6.6], vec![7.7, 
8.8]] + ], + description: "foobar".to_string(), + } + ); +} + +#[tokio::test] +async fn maps() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m1: HashMap, + m2: HashMap>, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + map('key1', 'value1', 'key2', 'value2') :: Map(String, String) AS m1, + map(42, map('foo', 100, 'bar', 200), + 144, map('qaz', 300, 'qux', 400)) :: Map(UInt16, Map(String, Int32)) AS m2 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + m1: vec![ + ("key1".to_string(), "value1".to_string()), + ("key2".to_string(), "value2".to_string()), + ] + .into_iter() + .collect(), + m2: vec![ + ( + 42, + vec![("foo".to_string(), 100), ("bar".to_string(), 200)] + .into_iter() + .collect() + ), + ( + 144, + vec![("qaz".to_string(), 300), ("qux".to_string(), 400)] + .into_iter() + .collect() + ) + ] + .into_iter() + .collect::>>(), + } + ); +} + +#[tokio::test] +async fn map_as_vec_of_tuples() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m1: Vec<(i128, String)>, + m2: Vec<(u16, Vec<(String, i32)>)>, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + map(100, 'value1', 200, 'value2') :: Map(Int128, String) AS m1, + map(42, map('foo', 100, 'bar', 200), + 144, map('qaz', 300, 'qux', 400)) :: Map(UInt16, Map(String, Int32)) AS m2 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + m1: vec![(100, "value1".to_string()), (200, "value2".to_string()),], + m2: vec![ + ( + 42, + vec![("foo".to_string(), 100), ("bar".to_string(), 200)] + .into_iter() + .collect() + ), + ( + 144, + vec![("qaz".to_string(), 300), ("qux".to_string(), 400)] + .into_iter() + .collect() + ) + ], + } + ) +} + +#[tokio::test] +async fn map_as_vec_of_tuples_schema_mismatch() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m: Vec<(u16, Vec<(String, i32)>)>, + } + + assert_panic_on_fetch!( + &["Data.m", "Map(Int64, String)", "Int64", "u16"], + "SELECT map(100, 'value1', 200, 'value2') :: Map(Int64, String) AS m" + ); +} + +#[tokio::test] +async fn map_as_vec_of_tuples_schema_mismatch_nested() { + type Inner = Vec<(i32, i64)>; // the value should be i128 instead of i64 + + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + m: Vec<(u16, Vec<(String, Inner)>)>, + } + + assert_panic_on_fetch!( + &[ + "Data.m", + "Map(UInt16, Map(String, Map(Int32, Int128)))", + "Int128", + "i64" + ], + "SELECT map(42, map('foo', map(144, 255))) + :: Map(UInt16, Map(String, Map(Int32, Int128))) AS m" + ); +} + +#[tokio::test] +async fn enums() { + #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] + #[repr(i8)] + enum MyEnum8 { + Winter = -128, + Spring = 0, + Summer = 100, + Autumn = 127, + } + + #[derive(Debug, PartialEq, Serialize_repr, Deserialize_repr)] + #[repr(i16)] + enum MyEnum16 { + North = -32768, + East = 0, + South = 144, + West = 32767, + } + + #[derive(Debug, PartialEq, Row, Serialize, Deserialize)] + struct Data { + id: u16, + enum8: MyEnum8, + enum16: MyEnum16, + } + + let table_name = "test_rbwnat_enum"; + + let client = prepare_database!(); + client + .query( + " + CREATE OR REPLACE TABLE ? 
+ ( + id UInt16, + enum8 Enum8 ('Winter' = -128, 'Spring' = 0, 'Summer' = 100, 'Autumn' = 127), + enum16 Enum16('North' = -32768, 'East' = 0, 'South' = 144, 'West' = 32767) + ) ENGINE MergeTree ORDER BY id + ", + ) + .bind(Identifier(table_name)) + .execute() + .await + .unwrap(); + + let expected = vec![ + Data { + id: 1, + enum8: MyEnum8::Spring, + enum16: MyEnum16::East, + }, + Data { + id: 2, + enum8: MyEnum8::Autumn, + enum16: MyEnum16::North, + }, + Data { + id: 3, + enum8: MyEnum8::Winter, + enum16: MyEnum16::South, + }, + Data { + id: 4, + enum8: MyEnum8::Summer, + enum16: MyEnum16::West, + }, + ]; + + let mut insert = client.insert(table_name).unwrap(); + for row in &expected { + insert.write(row).await.unwrap() + } + insert.end().await.unwrap(); + + let result = client + .query("SELECT * FROM ? ORDER BY id ASC") + .bind(Identifier(table_name)) + .fetch_all::() + .await + .unwrap(); + + assert_eq!(result, expected); +} + +#[tokio::test] +async fn nullable() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: Option, + } + + let client = get_client(); + let result = client + .query( + " + SELECT * FROM ( + SELECT 1 :: UInt32 AS a, 2 :: Nullable(Int64) AS b + UNION ALL + SELECT 3 :: UInt32 AS a, NULL :: Nullable(Int64) AS b + UNION ALL + SELECT 4 :: UInt32 AS a, 5 :: Nullable(Int64) AS b + ) + ORDER BY a ASC + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { a: 1, b: Some(2) }, + Data { a: 3, b: None }, + Data { a: 4, b: Some(5) }, + ] + ); +} + +#[tokio::test] +async fn invalid_nullable() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: Option, + } + assert_panic_on_fetch!( + &["Data.n", "Array(UInt32)", "Option"], + "SELECT array(42) :: Array(UInt32) AS n" + ); +} + +#[tokio::test] +async fn low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: Option, + } + + let client = get_client(); + let result = client + .query( + " + SELECT * FROM ( + SELECT 1 :: LowCardinality(UInt32) AS a, 2 :: LowCardinality(Nullable(Int64)) AS b + UNION ALL + SELECT 3 :: LowCardinality(UInt32) AS a, NULL :: LowCardinality(Nullable(Int64)) AS b + UNION ALL + SELECT 4 :: LowCardinality(UInt32) AS a, 5 :: LowCardinality(Nullable(Int64)) AS b + ) + ORDER BY a ASC + ", + ) + .with_option("allow_suspicious_low_cardinality_types", "1") + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![ + Data { a: 1, b: Some(2) }, + Data { a: 3, b: None }, + Data { a: 4, b: Some(5) }, + ] + ); +} + +#[tokio::test] +async fn invalid_low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + } + let client = get_client().with_option("allow_suspicious_low_cardinality_types", "1"); + assert_panic_on_fetch_with_client!( + client, + &["Data.a", "LowCardinality(Int32)", "u32"], + "SELECT 144 :: LowCardinality(Int32) AS a" + ); +} + +#[tokio::test] +async fn invalid_nullable_low_cardinality() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: Option, + } + let client = get_client().with_option("allow_suspicious_low_cardinality_types", "1"); + assert_panic_on_fetch_with_client!( + client, + &["Data.a", "LowCardinality(Nullable(Int32))", "u32"], + "SELECT 144 :: LowCardinality(Nullable(Int32)) AS a" + ); +} + +#[tokio::test] +#[cfg(feature = "time")] +async fn invalid_serde_with() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = 
"clickhouse::serde::time::datetime64::millis")] + n1: time::OffsetDateTime, // underlying is still Int64; should not compose it from two (U)Int32 + } + assert_panic_on_fetch!(&["Data.n1", "UInt32", "i64"], "SELECT 42 :: UInt32 AS n1"); +} + +#[tokio::test] +async fn too_many_struct_fields() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + b: u32, + c: u32, + } + assert_panic_on_fetch!( + &["2 columns", "3 fields"], + "SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS b" + ); +} + +#[tokio::test] +async fn serde_skip_deserializing() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u32, + #[serde(skip_deserializing)] + b: u32, + c: u32, + } + + let client = get_client(); + let result = client + .query("SELECT 42 :: UInt32 AS a, 144 :: UInt32 AS c") + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + a: 42, + b: 0, // default value + c: 144, + } + ); +} + +#[tokio::test] +#[cfg(feature = "time")] +async fn date_and_time() { + use time::format_description::well_known::Iso8601; + use time::Month::{February, January}; + use time::OffsetDateTime; + + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + #[serde(with = "clickhouse::serde::time::date")] + date: time::Date, + #[serde(with = "clickhouse::serde::time::date32")] + date32: time::Date, + #[serde(with = "clickhouse::serde::time::datetime")] + date_time: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::secs")] + date_time64_0: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::millis")] + date_time64_3: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::micros")] + date_time64_6: OffsetDateTime, + #[serde(with = "clickhouse::serde::time::datetime64::nanos")] + date_time64_9: OffsetDateTime, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + '2023-01-01' :: Date AS date, + '2023-02-02' :: Date32 AS date32, + '2023-01-03 12:00:00' :: DateTime AS date_time, + '2023-01-04 13:00:00' :: DateTime64(0) AS date_time64_0, + '2023-01-05 14:00:00.123' :: DateTime64(3) AS date_time64_3, + '2023-01-06 15:00:00.123456' :: DateTime64(6) AS date_time64_6, + '2023-01-07 16:00:00.123456789' :: DateTime64(9) AS date_time64_9 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + date: time::Date::from_calendar_date(2023, January, 1).unwrap(), + date32: time::Date::from_calendar_date(2023, February, 2).unwrap(), + date_time: OffsetDateTime::parse("2023-01-03T12:00:00Z", &Iso8601::DEFAULT).unwrap(), + date_time64_0: OffsetDateTime::parse("2023-01-04T13:00:00Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_3: OffsetDateTime::parse("2023-01-05T14:00:00.123Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_6: OffsetDateTime::parse("2023-01-06T15:00:00.123456Z", &Iso8601::DEFAULT) + .unwrap(), + date_time64_9: OffsetDateTime::parse( + "2023-01-07T16:00:00.123456789Z", + &Iso8601::DEFAULT + ) + .unwrap(), + } + ); +} + +#[tokio::test] +#[cfg(feature = "uuid")] +async fn uuid() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + #[serde(with = "clickhouse::serde::uuid")] + uuid: uuid::Uuid, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + '550e8400-e29b-41d4-a716-446655440000' :: UUID AS uuid + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + uuid: 
uuid::Uuid::from_str("550e8400-e29b-41d4-a716-446655440000").unwrap(), + } + ); +} + +#[tokio::test] +async fn ipv4_ipv6() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u16, + #[serde(with = "clickhouse::serde::ipv4")] + ipv4: std::net::Ipv4Addr, + ipv6: std::net::Ipv6Addr, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 42 :: UInt16 AS id, + '192.168.0.1' :: IPv4 AS ipv4, + '2001:db8:3333:4444:5555:6666:7777:8888' :: IPv6 AS ipv6 + ", + ) + .fetch_all::() + .await; + + assert_eq!( + result.unwrap(), + vec![Data { + id: 42, + ipv4: std::net::Ipv4Addr::new(192, 168, 0, 1), + ipv6: std::net::Ipv6Addr::from_str("2001:db8:3333:4444:5555:6666:7777:8888").unwrap(), + }] + ) +} + +#[tokio::test] +async fn fixed_str() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: [u8; 4], + b: [u8; 3], + } + + let client = get_client(); + let result = client + .query("SELECT '1234' :: FixedString(4) AS a, '777' :: FixedString(3) AS b") + .fetch_one::() + .await; + + let data = result.unwrap(); + assert_eq!(String::from_utf8_lossy(&data.a), "1234"); + assert_eq!(String::from_utf8_lossy(&data.b), "777"); +} + +#[tokio::test] +async fn fixed_str_too_long() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: [u8; 4], + b: [u8; 3], + } + assert_panic_on_fetch!( + &["Data.a", "FixedString(5)", "with length 4"], + "SELECT '12345' :: FixedString(5) AS a, '777' :: FixedString(3) AS b" + ); +} + +#[tokio::test] +async fn tuple() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + + let client = get_client(); + let result = client + .query( + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + a: (42, "foo".to_string()), + b: (144, vec![(255, "bar".to_string())].into_iter().collect()), + } + ); +} + +#[tokio::test] +async fn tuple_invalid_definition() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + // Map key is UInt64 instead of UInt16 requested in the struct + assert_panic_on_fetch!( + &[ + "Data.b", + "Tuple(Int128, Map(UInt64, String))", + "UInt64 as u16" + ], + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt64, String)) AS b + " + ); +} + +#[tokio::test] +async fn tuple_too_many_elements_in_the_schema() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String), + b: (i128, HashMap), + } + // too many elements in the db type definition + assert_panic_on_fetch!( + &[ + "Data.a", + "Tuple(UInt32, String, Bool)", + "remaining elements: Bool" + ], + " + SELECT + (42, 'foo', true) :: Tuple(UInt32, String, Bool) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + " + ); +} + +#[tokio::test] +async fn tuple_too_many_elements_in_the_struct() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: (u32, String, bool), + b: (i128, HashMap), + } + // too many elements in the struct enum + assert_panic_on_fetch!( + &["Data.a", "Tuple(UInt32, String)", "deserialize bool"], + " + SELECT + (42, 'foo') :: Tuple(UInt32, String) AS a, + (144, map(255, 'bar')) :: Tuple(Int128, Map(UInt16, String)) AS b + " + ); +} + +#[tokio::test] +async fn 
deeply_nested_validation_incorrect_fixed_string() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u32, + col: Vec>>>, + } + // Struct has FixedString(2) instead of FixedString(1) + assert_panic_on_fetch!( + &["Data.col", "FixedString(1)", "with length 2"], + " + SELECT + 42 :: UInt32 AS id, + array(array(map(42, array('1', '2')))) :: Array(Array(Map(UInt32, Array(FixedString(1))))) AS col + " + ); +} + +#[tokio::test] +async fn geo() { + #[derive(Clone, Debug, PartialEq)] + #[derive(Row, serde::Serialize, serde::Deserialize)] + struct Data { + id: u32, + point: Point, + ring: Ring, + polygon: Polygon, + multi_polygon: MultiPolygon, + line_string: LineString, + multi_line_string: MultiLineString, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 42 :: UInt32 AS id, + (1.0, 2.0) :: Point AS point, + [(3.0, 4.0), (5.0, 6.0)] :: Ring AS ring, + [[(7.0, 8.0), (9.0, 10.0)], [(11.0, 12.0)]] :: Polygon AS polygon, + [[[(13.0, 14.0), (15.0, 16.0)], [(17.0, 18.0)]]] :: MultiPolygon AS multi_polygon, + [(19.0, 20.0), (21.0, 22.0)] :: LineString AS line_string, + [[(23.0, 24.0), (25.0, 26.0)], [(27.0, 28.0)]] :: MultiLineString AS multi_line_string + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + id: 42, + point: (1.0, 2.0), + ring: vec![(3.0, 4.0), (5.0, 6.0)], + polygon: vec![vec![(7.0, 8.0), (9.0, 10.0)], vec![(11.0, 12.0)]], + multi_polygon: vec![vec![vec![(13.0, 14.0), (15.0, 16.0)], vec![(17.0, 18.0)]]], + line_string: vec![(19.0, 20.0), (21.0, 22.0)], + multi_line_string: vec![vec![(23.0, 24.0), (25.0, 26.0)], vec![(27.0, 28.0)]], + } + ); +} + +// TODO: there are two panics; one about schema mismatch, +// another about not all Tuple elements being deserialized +// not easy to assert, same applies to the other Geo types +#[ignore] +#[tokio::test] +async fn geo_invalid_point() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + id: u32, + pt: (i32, i32), + } + assert_panic_on_fetch!( + &["Data.pt", "Point", "Float64 as i32"], + " + SELECT + 42 :: UInt32 AS id, + (1.0, 2.0) :: Point AS pt + " + ); +} + +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/100 +async fn issue_100() { + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: i8, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "i8"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } + + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: u8, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "u8"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } + + { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + n: bool, + } + assert_panic_on_fetch!( + &["Data.n", "Nullable(Bool)", "bool"], + "SELECT NULL :: Nullable(Bool) AS n" + ); + } +} + +// TODO: unignore after insert implementation uses RBWNAT, too +#[ignore] +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/109#issuecomment-2243197221 +async fn issue_109_1() { + #[derive(Debug, Serialize, Deserialize, Row)] + struct Data { + #[serde(skip_deserializing)] + en_id: String, + journey: u32, + drone_id: String, + call_sign: String, + } + let client = prepare_database!(); + execute_statements( + &client, + &[ + " + CREATE TABLE issue_109 ( + drone_id String, + call_sign String, + journey UInt32, + en_id String, + ) + ENGINE = MergeTree + ORDER BY (drone_id) + ", + " + INSERT INTO issue_109 VALUES + ('drone_1', 
'call_sign_1', 1, 'en_id_1'), + ('drone_2', 'call_sign_2', 2, 'en_id_2'), + ('drone_3', 'call_sign_3', 3, 'en_id_3') + ", + ], + ) + .await; + + let data = client + .query("SELECT journey, drone_id, call_sign FROM issue_109") + .fetch_all::() + .await + .unwrap(); + let mut insert = client.insert("issue_109").unwrap(); + for (id, elem) in data.iter().enumerate() { + let elem = Data { + en_id: format!("ABC-{}", id), + journey: elem.journey, + drone_id: elem.drone_id.clone(), + call_sign: elem.call_sign.clone(), + }; + insert.write(&elem).await.unwrap(); + } + insert.end().await.unwrap(); +} + +#[tokio::test] +async fn issue_112() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: bool, + b: bool, + } + + assert_panic_on_fetch!( + &["Data.a", "Nullable(Bool)", "bool"], + "WITH (SELECT true) AS a, (SELECT true) AS b SELECT ?fields" + ); +} + +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/113 +async fn issue_113() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + a: u64, + b: f64, + c: f64, + } + let client = prepare_database!(); + execute_statements(&client, &[ + " + CREATE TABLE issue_113_1( + id UInt32 + ) + ENGINE MergeTree + ORDER BY id + ", + " + CREATE TABLE issue_113_2( + id UInt32, + pos Float64 + ) + ENGINE MergeTree + ORDER BY id + ", + "INSERT INTO issue_113_1 VALUES (1), (2), (3), (4), (5), (6), (7), (8), (9), (10)", + "INSERT INTO issue_113_2 VALUES (1, 100.5), (2, 200.2), (3, 300.3), (4, 444.4), (5, 555.5)", + ]).await; + + // Struct should have had Option instead of f64 + assert_panic_on_fetch_with_client!( + client, + &["Data.b", "Nullable(Float64)", "f64"], + " + SELECT + COUNT(*) AS a, + (COUNT(*) / (SELECT COUNT(*) FROM issue_113_1)) * 100.0 AS b, + AVG(pos) AS c + FROM issue_113_2 + " + ); +} + +#[tokio::test] +#[cfg(feature = "time")] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/114 +async fn issue_114() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + #[serde(with = "clickhouse::serde::time::date")] + date: time::Date, + arr: Vec>, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + '2023-05-01' :: Date AS date, + array(map('k1', 'v1'), map('k2', 'v2')) :: Array(Map(String, String)) AS arr + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + date: time::Date::from_calendar_date(2023, time::Month::May, 1).unwrap(), + arr: vec![ + HashMap::from([("k1".to_owned(), "v1".to_owned())]), + HashMap::from([("k2".to_owned(), "v2".to_owned())]), + ], + } + ); +} + +#[tokio::test] +#[cfg(feature = "time")] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/173 +async fn issue_173() { + #[derive(Debug, Serialize, Deserialize, Row)] + struct Data { + log_id: String, + #[serde(with = "clickhouse::serde::time::datetime")] + ts: time::OffsetDateTime, + } + + let client = prepare_database!().with_option("date_time_input_format", "best_effort"); + + execute_statements(&client, &[ + " + CREATE OR REPLACE TABLE logs ( + log_id String, + timestamp DateTime('Europe/Berlin') + ) + ENGINE = MergeTree() + PRIMARY KEY (log_id, timestamp) + ", + "INSERT INTO logs VALUES ('56cde52f-5f34-45e0-9f08-79d6f582e913', '2024-11-05T11:52:52+01:00')", + "INSERT INTO logs VALUES ('0e967129-6271-44f2-967b-0c8d11a60fdc', '2024-11-05T11:59:21+01:00')", + ]).await; + + // panics as we fetch `ts` two times: one from `?fields` macro, and the second time explicitly + // the resulting dataset will, in fact, contain 3 columns 
instead of 2: + assert_panic_on_fetch_with_client!( + client, + &["3 columns", "2 fields"], + "SELECT ?fields, toUnixTimestamp(timestamp) AS ts FROM logs ORDER by ts DESC" + ); +} + +#[tokio::test] +/// See https://github.com/ClickHouse/clickhouse-rs/issues/185 +async fn issue_185() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + pk: u32, + decimal_col: Option, + } + + let client = prepare_database!(); + execute_statements( + &client, + &[ + " + CREATE TABLE issue_185( + pk UInt32, + decimal_col Nullable(Decimal(10, 4))) + ENGINE MergeTree + ORDER BY pk + ", + "INSERT INTO issue_185 VALUES (1, 1.1), (2, 2.2), (3, 3.3)", + ], + ) + .await; + + assert_panic_on_fetch_with_client!( + client, + &["Data.decimal_col", "Decimal(10, 4)", "String"], + "SELECT ?fields FROM issue_185" + ); +} + +#[tokio::test] +#[cfg(feature = "chrono")] +async fn issue_218() { + #[derive(Row, Serialize, Deserialize, Debug)] + struct Data { + max_time: chrono::DateTime, + } + + let client = prepare_database!(); + execute_statements( + &client, + &[" + CREATE TABLE IF NOT EXISTS issue_218 ( + my_time DateTime64(3, 'UTC') CODEC(Delta, ZSTD), + ) ENGINE = MergeTree + ORDER BY my_time + "], + ) + .await; + + // FIXME: It is not a super clear panic as it hints about `&str`, + // and not about the missing attribute for `chrono::DateTime`. + // Still better than a `premature end of input` error, though. + assert_panic_on_fetch_with_client!( + client, + &["Data.max_time", "DateTime64(3, 'UTC')", "&str"], + "SELECT max(my_time) AS max_time FROM issue_218" + ); +} + +#[tokio::test] +async fn variant_wrong_definition() { + #[derive(Debug, Deserialize, PartialEq)] + enum MyVariant { + Str(String), + U32(u32), + } + + #[derive(Debug, Row, Deserialize, PartialEq)] + struct Data { + id: u8, + var: MyVariant, + } + + let client = get_client().with_option("allow_experimental_variant_type", "1"); + + assert_panic_on_fetch_with_client!( + client, + &["Data.var", "Variant(String, UInt16)", "u32"], + " + SELECT * FROM ( + SELECT 0 :: UInt8 AS id, 'foo' :: Variant(String, UInt16) AS var + UNION ALL + SELECT 1 :: UInt8 AS id, 144 :: Variant(String, UInt16) AS var + ) ORDER BY id ASC + " + ); +} + +#[tokio::test] +async fn decimals() { + #[derive(Row, Deserialize, Debug, PartialEq)] + struct Data { + decimal32_9_4: Decimal32, + decimal64_18_8: Decimal64, + decimal128_38_12: Decimal128, + } + + let client = get_client(); + let result = client + .query( + " + SELECT + 42.1234 :: Decimal32(4) AS decimal32_9_4, + 144.56789012 :: Decimal64(8) AS decimal64_18_8, + -17014118346046923173168730.37158841057 :: Decimal128(12) AS decimal128_38_12 + ", + ) + .fetch_one::() + .await; + + assert_eq!( + result.unwrap(), + Data { + decimal32_9_4: Decimal32::from_str("42.1234").unwrap(), + decimal64_18_8: Decimal64::from_str("144.56789012").unwrap(), + decimal128_38_12: Decimal128::from_str("-17014118346046923173168730.37158841057") + .unwrap(), + } + ); +} + +#[tokio::test] +async fn decimal32_wrong_size() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + decimal32: i16, + } + + assert_panic_on_fetch!( + &["Data.decimal32", "Decimal(9, 4)", "i16"], + "SELECT 42 :: Decimal32(4) AS decimal32" + ); +} + +#[tokio::test] +async fn decimal64_wrong_size() { + #[derive(Debug, Row, Serialize, Deserialize, PartialEq)] + struct Data { + decimal64: i32, + } + + assert_panic_on_fetch!( + &["Data.decimal64", "Decimal(18, 8)", "i32"], + "SELECT 144 :: Decimal64(8) AS decimal64" + ); +} + +#[tokio::test] +async fn 
decimal128_wrong_size() {
+    #[derive(Debug, Row, Serialize, Deserialize, PartialEq)]
+    struct Data {
+        decimal128: i64,
+    }
+
+    assert_panic_on_fetch!(
+        &["Data.decimal128", "Decimal(38, 12)", "i64"],
+        "SELECT -17014118346046923173168730.37158841057 :: Decimal128(12) AS decimal128"
+    );
+}
+
+#[tokio::test]
+async fn different_struct_field_order_same_types() {
+    #[derive(Debug, Row, Deserialize, PartialEq)]
+    struct Data {
+        c: String,
+        a: String,
+    }
+
+    let client = get_client();
+    let result = client
+        .query("SELECT 'foo' AS a, 'bar' :: String AS c")
+        .fetch_one::<Data>()
+        .await;
+
+    assert_eq!(
+        result.unwrap(),
+        Data {
+            c: "bar".to_string(),
+            a: "foo".to_string(),
+        }
+    );
+}
+
+#[tokio::test]
+async fn different_struct_field_order_different_types() {
+    #[derive(Debug, Row, Deserialize, PartialEq)]
+    struct Data {
+        b: u32,
+        a: String,
+        c: Vec<bool>,
+    }
+
+    let client = get_client();
+    let result = client
+        .query(
+            "
+            SELECT array(true, false, true) AS c,
+                   42 :: UInt32 AS b,
+                   'foo' AS a
+            ",
+        )
+        .fetch_one::<Data>()
+        .await;
+
+    assert_eq!(
+        result.unwrap(),
+        Data {
+            c: vec![true, false, true],
+            b: 42,
+            a: "foo".to_string(),
+        }
+    );
+}
+
+// See https://clickhouse.com/docs/en/sql-reference/data-types/geo
+type Point = (f64, f64);
+type Ring = Vec<Point>;
+type Polygon = Vec<Ring>;
+type MultiPolygon = Vec<Polygon>;
+type LineString = Vec<Point>;
+type MultiLineString = Vec<LineString>;
+
+// See ClickHouse decimal sizes: https://clickhouse.com/docs/en/sql-reference/data-types/decimal
+type Decimal32 = FixedPoint<i32, U4>; // Decimal(9, 4) = Decimal32(4)
+type Decimal64 = FixedPoint<i64, U8>; // Decimal(18, 8) = Decimal64(8)
+type Decimal128 = FixedPoint<i128, U12>; // Decimal(38, 12) = Decimal128(12)
diff --git a/tests/it/time.rs b/tests/it/time.rs
index 4feaef0d..ef8d8153 100644
--- a/tests/it/time.rs
+++ b/tests/it/time.rs
@@ -93,11 +93,11 @@ async fn datetime() {
     let row_str = client
         .query(
             "
-            SELECT toString(dt),
-                   toString(dt64s),
-                   toString(dt64ms),
-                   toString(dt64us),
-                   toString(dt64ns)
+            SELECT toString(dt) AS dt,
+                   toString(dt64s) AS dt64s,
+                   toString(dt64ms) AS dt64ms,
+                   toString(dt64us) AS dt64us,
+                   toString(dt64ns) AS dt64ns
             FROM test
             ",
         )
diff --git a/tests/it/variant.rs b/tests/it/variant.rs
index 14e81901..1343905e 100644
--- a/tests/it/variant.rs
+++ b/tests/it/variant.rs
@@ -4,7 +4,6 @@
 use serde::{Deserialize, Serialize};
 use time::Month::January;
 
 use clickhouse::Row;
-
 // See also: https://clickhouse.com/docs/en/sql-reference/data-types/variant
 
 #[tokio::test]
@@ -30,10 +29,10 @@ async fn variant_data_type() {
         Int8(i8),
         String(String),
         UInt128(u128),
-        UInt16(i16),
+        UInt16(u16),
         UInt32(u32),
         UInt64(u64),
-        UInt8(i8),
+        UInt8(u8),
     }
 
     #[derive(Debug, PartialEq, Row, Serialize, Deserialize)]
@@ -42,14 +41,14 @@
     }
 
     // No matter the order of the definition on the Variant types, it will always be sorted as follows:
-    // Variant(Array(UInt16), Bool, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8)
+    // Variant(Array(Int16), Bool, FixedString(6), Float32, Float64, Int128, Int16, Int32, Int64, Int8, String, UInt128, UInt16, UInt32, UInt64, UInt8)
     client
         .query(
             "
             CREATE OR REPLACE TABLE test_var
             (
                 `var` Variant(
-                    Array(UInt16),
+                    Array(Int16),
                     Bool,
                     Date,
                     FixedString(6),
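One detail from the variant.rs fix above deserves a callout: ClickHouse normalizes the inner types of a `Variant` into a fixed alphabetical order, and the serialized discriminant refers to that sorted list rather than to the declaration order in the DDL. So a Rust enum reading a `Variant` column should list its variants in the sorted order, and each variant's inner Rust type has to match the ClickHouse type exactly (hence the `UInt16(u16)` and `UInt8(u8)` corrections). A trimmed-down sketch under those assumptions (the table layout and names are illustrative, not from the patch):

```rust
use clickhouse::Row;
use serde::{Deserialize, Serialize};

// Hypothetical DB column: `Variant(Array(Int16), Bool, String, UInt8)`.
// ClickHouse sorts the inner types alphabetically, so the enum declares its
// variants in that same order and uses exactly matching Rust types.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
enum MyVariant {
    Array(Vec<i16>),
    Bool(bool),
    String(String),
    UInt8(u8),
}

#[derive(Debug, PartialEq, Row, Serialize, Deserialize)]
struct MyRow {
    var: MyVariant,
}
```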
diff --git a/types/Cargo.toml b/types/Cargo.toml
new file mode 100644
index 00000000..a169b67c
--- /dev/null
+++ b/types/Cargo.toml
@@ -0,0 +1,17 @@
+[package]
+name = "clickhouse-types"
+description = "Data types utils to use with Native and RowBinary(WithNamesAndTypes) formats in ClickHouse"
+version = "0.1.0"
+authors.workspace = true
+repository.workspace = true
+homepage.workspace = true
+edition.workspace = true
+license.workspace = true
+rust-version.workspace = true
+
+[lints.rust]
+missing_docs = "warn"
+
+[dependencies]
+thiserror = "1.0.16"
+bytes = "1.10.1"
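Before the implementation below, a quick sketch of what the new crate exposes: `DataTypeNode::new` parses a type string exactly as it appears in the `RowBinaryWithNamesAndTypes` header, and the `Display` implementation renders the canonical form back, so parse/format should round-trip for well-formed inputs. This sketch relies only on the API visible in this diff:

```rust
use clickhouse_types::data_types::{Column, DataTypeNode};

fn sketch() {
    // Parse a nested type string as it would appear in the header.
    let dt = DataTypeNode::new("Map(String, Array(Nullable(Int64)))").unwrap();

    // `Display` renders the canonical form back.
    assert_eq!(dt.to_string(), "Map(String, Array(Nullable(Int64)))");

    // A `Column` pairs a name with its parsed type; this is what
    // `parse_rbwnat_columns_header` produces for every header entry.
    let column = Column::new("payload".to_string(), dt);
    assert_eq!(column.to_string(), "payload: Map(String, Array(Nullable(Int64)))");
}
```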
diff --git a/types/src/data_types.rs b/types/src/data_types.rs
new file mode 100644
index 00000000..d61e1979
--- /dev/null
+++ b/types/src/data_types.rs
@@ -0,0 +1,1410 @@
+use crate::error::TypesError;
+use std::collections::HashMap;
+use std::fmt::{Display, Formatter};
+
+/// A definition of a column in the result set,
+/// taken out of the `RowBinaryWithNamesAndTypes` header.
+#[derive(Debug, Clone, PartialEq)]
+pub struct Column {
+    /// The name of the column.
+    pub name: String,
+    /// The data type of the column.
+    pub data_type: DataTypeNode,
+}
+
+impl Column {
+    #[allow(missing_docs)]
+    pub fn new(name: String, data_type: DataTypeNode) -> Self {
+        Self { name, data_type }
+    }
+}
+
+impl Display for Column {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "{}: {}", self.name, self.data_type)
+    }
+}
+
+/// Represents a data type in ClickHouse.
+/// See <https://clickhouse.com/docs/en/sql-reference/data-types>
+#[derive(Debug, Clone, PartialEq)]
+#[non_exhaustive]
+#[allow(missing_docs)]
+pub enum DataTypeNode {
+    Bool,
+
+    UInt8,
+    UInt16,
+    UInt32,
+    UInt64,
+    UInt128,
+    UInt256,
+
+    Int8,
+    Int16,
+    Int32,
+    Int64,
+    Int128,
+    Int256,
+
+    Float32,
+    Float64,
+    BFloat16,
+
+    /// Precision, Scale, 32 | 64 | 128 | 256
+    Decimal(u8, u8, DecimalType),
+
+    String,
+    FixedString(usize),
+    UUID,
+
+    Date,
+    Date32,
+
+    /// Optional timezone
+    DateTime(Option<String>),
+    /// Precision and optional timezone
+    DateTime64(DateTimePrecision, Option<String>),
+
+    IPv4,
+    IPv6,
+
+    Nullable(Box<DataTypeNode>),
+    LowCardinality(Box<DataTypeNode>),
+
+    Array(Box<DataTypeNode>),
+    Tuple(Vec<DataTypeNode>),
+    Enum(EnumType, HashMap<i16, String>),
+
+    /// Key-Value pairs are defined as an array, so it can be used as a slice
+    Map([Box<DataTypeNode>; 2]),
+
+    /// Function name and its arguments
+    AggregateFunction(String, Vec<DataTypeNode>),
+
+    /// Contains all possible types for this variant
+    Variant(Vec<DataTypeNode>),
+
+    Dynamic,
+    JSON,
+
+    Point,
+    Ring,
+    LineString,
+    MultiLineString,
+    Polygon,
+    MultiPolygon,
+}
+
+impl DataTypeNode {
+    /// Parses a data type from a string that is received
+    /// in the `RowBinaryWithNamesAndTypes` and `Native` formats headers.
+    /// See also: <https://clickhouse.com/docs/en/sql-reference/data-types>
+    pub fn new(name: &str) -> Result<Self, TypesError> {
+        match name {
+            "UInt8" => Ok(Self::UInt8),
+            "UInt16" => Ok(Self::UInt16),
+            "UInt32" => Ok(Self::UInt32),
+            "UInt64" => Ok(Self::UInt64),
+            "UInt128" => Ok(Self::UInt128),
+            "UInt256" => Ok(Self::UInt256),
+            "Int8" => Ok(Self::Int8),
+            "Int16" => Ok(Self::Int16),
+            "Int32" => Ok(Self::Int32),
+            "Int64" => Ok(Self::Int64),
+            "Int128" => Ok(Self::Int128),
+            "Int256" => Ok(Self::Int256),
+            "Float32" => Ok(Self::Float32),
+            "Float64" => Ok(Self::Float64),
+            "BFloat16" => Ok(Self::BFloat16),
+            "String" => Ok(Self::String),
+            "UUID" => Ok(Self::UUID),
+            "Date" => Ok(Self::Date),
+            "Date32" => Ok(Self::Date32),
+            "IPv4" => Ok(Self::IPv4),
+            "IPv6" => Ok(Self::IPv6),
+            "Bool" => Ok(Self::Bool),
+            "Dynamic" => Ok(Self::Dynamic),
+            "JSON" => Ok(Self::JSON),
+            "Point" => Ok(Self::Point),
+            "Ring" => Ok(Self::Ring),
+            "LineString" => Ok(Self::LineString),
+            "MultiLineString" => Ok(Self::MultiLineString),
+            "Polygon" => Ok(Self::Polygon),
+            "MultiPolygon" => Ok(Self::MultiPolygon),
+
+            str if str.starts_with("Decimal") => parse_decimal(str),
+            str if str.starts_with("DateTime64") => parse_datetime64(str),
+            str if str.starts_with("DateTime") => parse_datetime(str),
+
+            str if str.starts_with("Nullable") => parse_nullable(str),
+            str if str.starts_with("LowCardinality") => parse_low_cardinality(str),
+            str if str.starts_with("FixedString") => parse_fixed_string(str),
+
+            str if str.starts_with("Array") => parse_array(str),
+            str if str.starts_with("Enum") => parse_enum(str),
+            str if str.starts_with("Map") => parse_map(str),
+            str if str.starts_with("Tuple") => parse_tuple(str),
+            str if str.starts_with("Variant") => parse_variant(str),
+
+            // ...
+            str => Err(TypesError::TypeParsingError(format!(
+                "Unknown data type: {}",
+                str
+            ))),
+        }
+    }
+
+    /// LowCardinality(T) -> T
+    pub fn remove_low_cardinality(&self) -> &DataTypeNode {
+        match self {
+            DataTypeNode::LowCardinality(inner) => inner,
+            _ => self,
+        }
+    }
+}
+
+impl From<DataTypeNode> for String {
+    fn from(value: DataTypeNode) -> Self {
+        value.to_string()
+    }
+}
+
+impl Display for DataTypeNode {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        use DataTypeNode::*;
+        let str = match self {
+            UInt8 => "UInt8".to_string(),
+            UInt16 => "UInt16".to_string(),
+            UInt32 => "UInt32".to_string(),
+            UInt64 => "UInt64".to_string(),
+            UInt128 => "UInt128".to_string(),
+            UInt256 => "UInt256".to_string(),
+            Int8 => "Int8".to_string(),
+            Int16 => "Int16".to_string(),
+            Int32 => "Int32".to_string(),
+            Int64 => "Int64".to_string(),
+            Int128 => "Int128".to_string(),
+            Int256 => "Int256".to_string(),
+            Float32 => "Float32".to_string(),
+            Float64 => "Float64".to_string(),
+            BFloat16 => "BFloat16".to_string(),
+            Decimal(precision, scale, _) => {
+                format!("Decimal({}, {})", precision, scale)
+            }
+            String => "String".to_string(),
+            UUID => "UUID".to_string(),
+            Date => "Date".to_string(),
+            Date32 => "Date32".to_string(),
+            DateTime(None) => "DateTime".to_string(),
+            DateTime(Some(tz)) => format!("DateTime('{}')", tz),
+            DateTime64(precision, None) => format!("DateTime64({})", precision),
+            DateTime64(precision, Some(tz)) => format!("DateTime64({}, '{}')", precision, tz),
+            IPv4 => "IPv4".to_string(),
+            IPv6 => "IPv6".to_string(),
+            Bool => "Bool".to_string(),
+            Nullable(inner) => format!("Nullable({})", inner),
+            Array(inner) => format!("Array({})", inner),
+            Tuple(elements) => {
+                let elements_str = data_types_to_string(elements);
+                format!("Tuple({})", elements_str)
+            }
+            Map([key, value]) => {
format!("Map({}, {})", key, value) + } + LowCardinality(inner) => { + format!("LowCardinality({})", inner) + } + Enum(enum_type, values) => { + let mut values_vec = values.iter().collect::>(); + values_vec.sort_by(|(i1, _), (i2, _)| (*i1).cmp(*i2)); + let values_str = values_vec + .iter() + .map(|(index, name)| format!("'{}' = {}", name, index)) + .collect::>() + .join(", "); + format!("{}({})", enum_type, values_str) + } + AggregateFunction(func_name, args) => { + let args_str = data_types_to_string(args); + format!("AggregateFunction({}, {})", func_name, args_str) + } + FixedString(size) => { + format!("FixedString({})", size) + } + Variant(types) => { + let types_str = data_types_to_string(types); + format!("Variant({})", types_str) + } + JSON => "JSON".to_string(), + Dynamic => "Dynamic".to_string(), + Point => "Point".to_string(), + Ring => "Ring".to_string(), + LineString => "LineString".to_string(), + MultiLineString => "MultiLineString".to_string(), + Polygon => "Polygon".to_string(), + MultiPolygon => "MultiPolygon".to_string(), + }; + write!(f, "{}", str) + } +} + +/// Represents the underlying integer size of an Enum type. +#[derive(Debug, Clone, PartialEq)] +pub enum EnumType { + /// Stored as an `Int8` + Enum8, + /// Stored as an `Int16` + Enum16, +} + +impl Display for EnumType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + EnumType::Enum8 => write!(f, "Enum8"), + EnumType::Enum16 => write!(f, "Enum16"), + } + } +} + +/// DateTime64 precision. +/// Defined as an enum, as it is valid only in the range from 0 to 9. +/// See also: +#[derive(Debug, Clone, PartialEq)] +#[allow(missing_docs)] +pub enum DateTimePrecision { + Precision0, + Precision1, + Precision2, + Precision3, + Precision4, + Precision5, + Precision6, + Precision7, + Precision8, + Precision9, +} + +impl DateTimePrecision { + pub(crate) fn new(char: char) -> Result { + match char { + '0' => Ok(DateTimePrecision::Precision0), + '1' => Ok(DateTimePrecision::Precision1), + '2' => Ok(DateTimePrecision::Precision2), + '3' => Ok(DateTimePrecision::Precision3), + '4' => Ok(DateTimePrecision::Precision4), + '5' => Ok(DateTimePrecision::Precision5), + '6' => Ok(DateTimePrecision::Precision6), + '7' => Ok(DateTimePrecision::Precision7), + '8' => Ok(DateTimePrecision::Precision8), + '9' => Ok(DateTimePrecision::Precision9), + _ => Err(TypesError::TypeParsingError(format!( + "Invalid DateTime64 precision, expected to be within [0, 9] interval, got {}", + char + ))), + } + } +} + +/// Represents the underlying integer type for a Decimal. 
+/// See also: +#[derive(Debug, Clone, PartialEq)] +pub enum DecimalType { + /// Stored as an `Int32` + Decimal32, + /// Stored as an `Int64` + Decimal64, + /// Stored as an `Int128` + Decimal128, + /// Stored as an `Int256` + Decimal256, +} + +impl Display for DecimalType { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DecimalType::Decimal32 => write!(f, "Decimal32"), + DecimalType::Decimal64 => write!(f, "Decimal64"), + DecimalType::Decimal128 => write!(f, "Decimal128"), + DecimalType::Decimal256 => write!(f, "Decimal256"), + } + } +} + +impl DecimalType { + pub(crate) fn new(precision: u8) -> Result { + if precision <= 9 { + Ok(DecimalType::Decimal32) + } else if precision <= 18 { + Ok(DecimalType::Decimal64) + } else if precision <= 38 { + Ok(DecimalType::Decimal128) + } else if precision <= 76 { + Ok(DecimalType::Decimal256) + } else { + return Err(TypesError::TypeParsingError(format!( + "Invalid Decimal precision: {}", + precision + ))); + } + } +} + +impl Display for DateTimePrecision { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + DateTimePrecision::Precision0 => write!(f, "0"), + DateTimePrecision::Precision1 => write!(f, "1"), + DateTimePrecision::Precision2 => write!(f, "2"), + DateTimePrecision::Precision3 => write!(f, "3"), + DateTimePrecision::Precision4 => write!(f, "4"), + DateTimePrecision::Precision5 => write!(f, "5"), + DateTimePrecision::Precision6 => write!(f, "6"), + DateTimePrecision::Precision7 => write!(f, "7"), + DateTimePrecision::Precision8 => write!(f, "8"), + DateTimePrecision::Precision9 => write!(f, "9"), + } + } +} + +fn data_types_to_string(elements: &[DataTypeNode]) -> String { + elements + .iter() + .map(|a| a.to_string()) + .collect::>() + .join(", ") +} + +fn parse_fixed_string(input: &str) -> Result { + if input.len() >= 14 { + let size_str = &input[12..input.len() - 1]; + let size = size_str.parse::().map_err(|err| { + TypesError::TypeParsingError(format!( + "Invalid FixedString size, expected a valid number. Underlying error: {}, input: {}, size_str: {}", + err, input, size_str + )) + })?; + if size == 0 { + return Err(TypesError::TypeParsingError(format!( + "Invalid FixedString size, expected a positive number, got zero. 
Input: {}", + input + ))); + } + return Ok(DataTypeNode::FixedString(size)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid FixedString format, expected FixedString(N), got {}", + input + ))) +} + +fn parse_array(input: &str) -> Result { + if input.len() >= 8 { + let inner_type_str = &input[6..input.len() - 1]; + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::Array(Box::new(inner_type))); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Array format, expected Array(InnerType), got {}", + input + ))) +} + +fn parse_enum(input: &str) -> Result { + if input.len() >= 9 { + let (enum_type, prefix_len) = if input.starts_with("Enum8") { + (EnumType::Enum8, 6) + } else if input.starts_with("Enum16") { + (EnumType::Enum16, 7) + } else { + return Err(TypesError::TypeParsingError(format!( + "Invalid Enum type, expected Enum8 or Enum16, got {}", + input + ))); + }; + let enum_values_map_str = &input[prefix_len..input.len() - 1]; + let enum_values_map = parse_enum_values_map(enum_values_map_str)?; + return Ok(DataTypeNode::Enum(enum_type, enum_values_map)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Enum format, expected Enum8('name' = value), got {}", + input + ))) +} + +fn parse_datetime(input: &str) -> Result { + if input == "DateTime" { + return Ok(DataTypeNode::DateTime(None)); + } + if input.len() >= 12 { + let timezone = input[10..input.len() - 2].to_string(); + return Ok(DataTypeNode::DateTime(Some(timezone))); + } + Err(TypesError::TypeParsingError(format!( + "Invalid DateTime format, expected DateTime('timezone'), got {}", + input + ))) +} + +fn parse_decimal(input: &str) -> Result { + if input.len() >= 10 { + let precision_and_scale_str = input[8..input.len() - 1].split(", ").collect::>(); + if precision_and_scale_str.len() != 2 { + return Err(TypesError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S), got {}", + input + ))); + } + let parsed = precision_and_scale_str + .iter() + .map(|s| s.parse::()) + .collect::, _>>() + .map_err(|err| { + TypesError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S), got {}. Underlying error: {}", + input, err + )) + })?; + let precision = parsed[0]; + let scale = parsed[1]; + if scale < 1 || precision < 1 { + return Err(TypesError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S) with P > 0 and S > 0, got {}", + input + ))); + } + if precision < scale { + return Err(TypesError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P, S) with P >= S, got {}", + input + ))); + } + let size = DecimalType::new(parsed[0])?; + return Ok(DataTypeNode::Decimal(precision, scale, size)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Decimal format, expected Decimal(P), got {}", + input + ))) +} + +fn parse_datetime64(input: &str) -> Result { + if input.len() >= 13 { + let mut chars = input[11..input.len() - 1].chars(); + let precision_char = chars.next().ok_or(TypesError::TypeParsingError(format!( + "Invalid DateTime64 precision, expected a positive number. 
Input: {}", + input + )))?; + let precision = DateTimePrecision::new(precision_char)?; + let maybe_tz = match chars.as_str() { + str if str.len() > 2 => Some(str[3..str.len() - 1].to_string()), + _ => None, + }; + return Ok(DataTypeNode::DateTime64(precision, maybe_tz)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid DateTime format, expected DateTime('timezone'), got {}", + input + ))) +} + +fn parse_low_cardinality(input: &str) -> Result { + if input.len() >= 16 { + let inner_type_str = &input[15..input.len() - 1]; + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::LowCardinality(Box::new(inner_type))); + } + Err(TypesError::TypeParsingError(format!( + "Invalid LowCardinality format, expected LowCardinality(InnerType), got {}", + input + ))) +} + +fn parse_nullable(input: &str) -> Result { + if input.len() >= 10 { + let inner_type_str = &input[9..input.len() - 1]; + let inner_type = DataTypeNode::new(inner_type_str)?; + return Ok(DataTypeNode::Nullable(Box::new(inner_type))); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Nullable format, expected Nullable(InnerType), got {}", + input + ))) +} + +fn parse_map(input: &str) -> Result { + if input.len() >= 5 { + let inner_types_str = &input[4..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + if inner_types.len() != 2 { + return Err(TypesError::TypeParsingError(format!( + "Expected two inner elements in a Map from input {}", + input + ))); + } + return Ok(DataTypeNode::Map([ + Box::new(inner_types[0].clone()), + Box::new(inner_types[1].clone()), + ])); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Map format, expected Map(KeyType, ValueType), got {}", + input + ))) +} + +fn parse_tuple(input: &str) -> Result { + if input.len() > 7 { + let inner_types_str = &input[6..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + if inner_types.is_empty() { + return Err(TypesError::TypeParsingError(format!( + "Expected at least one inner element in a Tuple from input {}", + input + ))); + } + return Ok(DataTypeNode::Tuple(inner_types)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Tuple format, expected Tuple(Type1, Type2, ...), got {}", + input + ))) +} + +fn parse_variant(input: &str) -> Result { + if input.len() >= 9 { + let inner_types_str = &input[8..input.len() - 1]; + let inner_types = parse_inner_types(inner_types_str)?; + return Ok(DataTypeNode::Variant(inner_types)); + } + Err(TypesError::TypeParsingError(format!( + "Invalid Variant format, expected Variant(Type1, Type2, ...), got {}", + input + ))) +} + +/// Considers the element type parsed once we reach a comma outside of parens AND after an unescaped tick. 
+/// Considers an element type fully parsed once a comma is reached outside of parens and outside of single quotes. +/// The most complicated cases are value names in user-defined Enum types: +/// ``` +/// let input1 = "Tuple(Enum8('f\'()' = 1))"; // the result is `f\'()` +/// let input2 = "Tuple(Enum8('(' = 1))"; // the result is `(` +/// ``` +fn parse_inner_types(input: &str) -> Result<Vec<DataTypeNode>, TypesError> { + let mut inner_types: Vec<DataTypeNode> = Vec::new(); + + let input_bytes = input.as_bytes(); + + let mut open_parens = 0; + let mut quote_open = false; + let mut char_escaped = false; + let mut last_element_index = 0; + + let mut i = 0; + while i < input_bytes.len() { + if char_escaped { + char_escaped = false; + } else if input_bytes[i] == b'\\' { + char_escaped = true; + } else if input_bytes[i] == b'\'' { + quote_open = !quote_open; // unescaped quote + } else if !quote_open { + if input_bytes[i] == b'(' { + open_parens += 1; + } else if input_bytes[i] == b')' { + open_parens -= 1; + } else if input_bytes[i] == b',' && open_parens == 0 { + let data_type_str = String::from_utf8(input_bytes[last_element_index..i].to_vec()) + .map_err(|_| { + TypesError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the inner data type: {}", + &input[last_element_index..] + )) + })?; + let data_type = DataTypeNode::new(&data_type_str)?; + inner_types.push(data_type); + // Skip ', ' (comma and space) + if i + 2 <= input_bytes.len() && input_bytes[i + 1] == b' ' { + i += 2; + } else { + i += 1; + } + last_element_index = i; + continue; // Skip the normal increment at the end of the loop + } + } + i += 1; + } + + // Push the remaining part of the type if it seems to be valid (at least all parentheses are closed) + if open_parens == 0 && last_element_index < input_bytes.len() { + let data_type_str = + String::from_utf8(input_bytes[last_element_index..].to_vec()).map_err(|_| { + TypesError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the inner data type: {}", + &input[last_element_index..] + )) + })?; + let data_type = DataTypeNode::new(&data_type_str)?; + inner_types.push(data_type); + } + + Ok(inner_types) +}
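A usage sketch for the splitter (callable like this from within the module, since the function is private): only a top-level comma separates elements, while commas nested in parentheses or single quotes are protected by the tracking above.

```rust
// The comma inside Map(...) and the one inside the quoted Enum name do not
// split; only the top-level comma between the two elements does.
let parts = parse_inner_types("Map(String, UInt8), Enum8('a,b' = 1)").unwrap();
assert_eq!(parts.len(), 2);
```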
Input: {}", + input + )) + }) +} + +fn parse_enum_values_map(input: &str) -> Result, TypesError> { + let mut names: Vec = Vec::new(); + let mut indices: Vec = Vec::new(); + let mut parsing_name = true; // false when parsing the index + let mut char_escaped = false; // we should ignore escaped ticks + let mut start_index = 1; // Skip the first ' + + let mut i = 1; + let input_bytes = input.as_bytes(); + while i < input_bytes.len() { + if parsing_name { + if char_escaped { + char_escaped = false; + } else if input_bytes[i] == b'\\' { + char_escaped = true; + } else if input_bytes[i] == b'\'' { + // non-escaped closing tick - push the name + let name_bytes = &input_bytes[start_index..i]; + let name = String::from_utf8(name_bytes.to_vec()).map_err(|_| { + TypesError::TypeParsingError(format!( + "Invalid UTF-8 sequence in input for the enum name: {}", + &input[start_index..i] + )) + })?; + names.push(name); + + // Skip ` = ` and the first digit, as it will always have at least one + if i + 4 >= input_bytes.len() { + return Err(TypesError::TypeParsingError(format!( + "Invalid Enum format - expected ` = ` after name, input: {}", + input, + ))); + } + i += 4; + start_index = i; + parsing_name = false; + } + } + // Parsing the index, skipping next iterations until the first non-digit one + else if input_bytes[i] < b'0' || input_bytes[i] > b'9' { + let index = parse_enum_index(&input_bytes[start_index..i], input)?; + indices.push(index); + + // the char at this index should be comma + // Skip `, '`, but not the first char - ClickHouse allows something like Enum8('foo' = 0, '' = 42) + if i + 2 >= input_bytes.len() { + break; // At the end of the enum, no more entries + } + i += 2; + start_index = i + 1; + parsing_name = true; + char_escaped = false; + } + + i += 1; + } + + let index = parse_enum_index(&input_bytes[start_index..i], input)?; + indices.push(index); + + if names.len() != indices.len() { + return Err(TypesError::TypeParsingError(format!( + "Invalid Enum format - expected the same number of names and indices, got names: {}, indices: {}", + names.join(", "), + indices.iter().map(|index| index.to_string()).collect::>().join(", "), + ))); + } + + Ok(indices + .into_iter() + .zip(names) + .collect::>()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_data_type_new_simple() { + assert_eq!(DataTypeNode::new("UInt8").unwrap(), DataTypeNode::UInt8); + assert_eq!(DataTypeNode::new("UInt16").unwrap(), DataTypeNode::UInt16); + assert_eq!(DataTypeNode::new("UInt32").unwrap(), DataTypeNode::UInt32); + assert_eq!(DataTypeNode::new("UInt64").unwrap(), DataTypeNode::UInt64); + assert_eq!(DataTypeNode::new("UInt128").unwrap(), DataTypeNode::UInt128); + assert_eq!(DataTypeNode::new("UInt256").unwrap(), DataTypeNode::UInt256); + assert_eq!(DataTypeNode::new("Int8").unwrap(), DataTypeNode::Int8); + assert_eq!(DataTypeNode::new("Int16").unwrap(), DataTypeNode::Int16); + assert_eq!(DataTypeNode::new("Int32").unwrap(), DataTypeNode::Int32); + assert_eq!(DataTypeNode::new("Int64").unwrap(), DataTypeNode::Int64); + assert_eq!(DataTypeNode::new("Int128").unwrap(), DataTypeNode::Int128); + assert_eq!(DataTypeNode::new("Int256").unwrap(), DataTypeNode::Int256); + assert_eq!(DataTypeNode::new("Float32").unwrap(), DataTypeNode::Float32); + assert_eq!(DataTypeNode::new("Float64").unwrap(), DataTypeNode::Float64); + assert_eq!( + DataTypeNode::new("BFloat16").unwrap(), + DataTypeNode::BFloat16 + ); + assert_eq!(DataTypeNode::new("String").unwrap(), DataTypeNode::String); + 
assert_eq!(DataTypeNode::new("UUID").unwrap(), DataTypeNode::UUID); + assert_eq!(DataTypeNode::new("Date").unwrap(), DataTypeNode::Date); + assert_eq!(DataTypeNode::new("Date32").unwrap(), DataTypeNode::Date32); + assert_eq!(DataTypeNode::new("IPv4").unwrap(), DataTypeNode::IPv4); + assert_eq!(DataTypeNode::new("IPv6").unwrap(), DataTypeNode::IPv6); + assert_eq!(DataTypeNode::new("Bool").unwrap(), DataTypeNode::Bool); + assert_eq!(DataTypeNode::new("Dynamic").unwrap(), DataTypeNode::Dynamic); + assert_eq!(DataTypeNode::new("JSON").unwrap(), DataTypeNode::JSON); + assert!(DataTypeNode::new("SomeUnknownType").is_err()); + } + + #[test] + fn test_data_type_new_fixed_string() { + assert_eq!( + DataTypeNode::new("FixedString(1)").unwrap(), + DataTypeNode::FixedString(1) + ); + assert_eq!( + DataTypeNode::new("FixedString(16)").unwrap(), + DataTypeNode::FixedString(16) + ); + assert_eq!( + DataTypeNode::new("FixedString(255)").unwrap(), + DataTypeNode::FixedString(255) + ); + assert_eq!( + DataTypeNode::new("FixedString(65535)").unwrap(), + DataTypeNode::FixedString(65_535) + ); + assert!(DataTypeNode::new("FixedString()").is_err()); + assert!(DataTypeNode::new("FixedString(0)").is_err()); + assert!(DataTypeNode::new("FixedString(-1)").is_err()); + assert!(DataTypeNode::new("FixedString(abc)").is_err()); + } + + #[test] + fn test_data_type_new_array() { + assert_eq!( + DataTypeNode::new("Array(UInt8)").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::UInt8)) + ); + assert_eq!( + DataTypeNode::new("Array(String)").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::String)) + ); + assert_eq!( + DataTypeNode::new("Array(FixedString(16))").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::FixedString(16))) + ); + assert_eq!( + DataTypeNode::new("Array(Nullable(Int32))").unwrap(), + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::Int32 + )))) + ); + assert!(DataTypeNode::new("Array()").is_err()); + assert!(DataTypeNode::new("Array(abc)").is_err()); + } + + #[test] + fn test_data_type_new_decimal() { + assert_eq!( + DataTypeNode::new("Decimal(7, 2)").unwrap(), + DataTypeNode::Decimal(7, 2, DecimalType::Decimal32) + ); + assert_eq!( + DataTypeNode::new("Decimal(12, 4)").unwrap(), + DataTypeNode::Decimal(12, 4, DecimalType::Decimal64) + ); + assert_eq!( + DataTypeNode::new("Decimal(27, 6)").unwrap(), + DataTypeNode::Decimal(27, 6, DecimalType::Decimal128) + ); + assert_eq!( + DataTypeNode::new("Decimal(42, 8)").unwrap(), + DataTypeNode::Decimal(42, 8, DecimalType::Decimal256) + ); + assert!(DataTypeNode::new("Decimal").is_err()); + assert!(DataTypeNode::new("Decimal(").is_err()); + assert!(DataTypeNode::new("Decimal()").is_err()); + assert!(DataTypeNode::new("Decimal(1)").is_err()); + assert!(DataTypeNode::new("Decimal(1,)").is_err()); + assert!(DataTypeNode::new("Decimal(1, )").is_err()); + assert!(DataTypeNode::new("Decimal(0, 0)").is_err()); // Precision must be > 0 + assert!(DataTypeNode::new("Decimal(x, 0)").is_err()); // Non-numeric precision + assert!(DataTypeNode::new("Decimal(', ')").is_err()); + assert!(DataTypeNode::new("Decimal(77, 1)").is_err()); // Max precision is 76 + assert!(DataTypeNode::new("Decimal(1, 2)").is_err()); // Scale must be less than precision + assert!(DataTypeNode::new("Decimal(1, x)").is_err()); // Non-numeric scale + assert!(DataTypeNode::new("Decimal(42, ,)").is_err()); + assert!(DataTypeNode::new("Decimal(42, ')").is_err()); + assert!(DataTypeNode::new("Decimal(foobar)").is_err()); + } + + #[test] + fn 
test_data_type_new_datetime() { + assert_eq!( + DataTypeNode::new("DateTime").unwrap(), + DataTypeNode::DateTime(None) + ); + assert_eq!( + DataTypeNode::new("DateTime('UTC')").unwrap(), + DataTypeNode::DateTime(Some("UTC".to_string())) + ); + assert_eq!( + DataTypeNode::new("DateTime('America/New_York')").unwrap(), + DataTypeNode::DateTime(Some("America/New_York".to_string())) + ); + assert!(DataTypeNode::new("DateTime()").is_err()); + } + + #[test] + fn test_data_type_new_datetime64() { + assert_eq!( + DataTypeNode::new("DateTime64(0)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision0, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(1)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision1, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(2)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision2, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(3)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision3, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(4)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision4, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(5)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision5, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(6)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision6, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(7)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision7, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(8)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision8, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(9)").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision9, None) + ); + assert_eq!( + DataTypeNode::new("DateTime64(0, 'UTC')").unwrap(), + DataTypeNode::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())) + ); + assert_eq!( + DataTypeNode::new("DateTime64(3, 'America/New_York')").unwrap(), + DataTypeNode::DateTime64( + DateTimePrecision::Precision3, + Some("America/New_York".to_string()) + ) + ); + assert_eq!( + DataTypeNode::new("DateTime64(6, 'America/New_York')").unwrap(), + DataTypeNode::DateTime64( + DateTimePrecision::Precision6, + Some("America/New_York".to_string()) + ) + ); + assert_eq!( + DataTypeNode::new("DateTime64(9, 'Europe/Amsterdam')").unwrap(), + DataTypeNode::DateTime64( + DateTimePrecision::Precision9, + Some("Europe/Amsterdam".to_string()) + ) + ); + assert!(DataTypeNode::new("DateTime64()").is_err()); + assert!(DataTypeNode::new("DateTime64(x)").is_err()); + } + + #[test] + fn test_data_type_new_low_cardinality() { + assert_eq!( + DataTypeNode::new("LowCardinality(UInt8)").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::UInt8)) + ); + assert_eq!( + DataTypeNode::new("LowCardinality(String)").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::String)) + ); + assert_eq!( + DataTypeNode::new("LowCardinality(Array(Int32))").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::Array(Box::new( + DataTypeNode::Int32 + )))) + ); + assert_eq!( + DataTypeNode::new("LowCardinality(Nullable(Int32))").unwrap(), + DataTypeNode::LowCardinality(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::Int32 + )))) + ); + assert!(DataTypeNode::new("LowCardinality").is_err()); + assert!(DataTypeNode::new("LowCardinality()").is_err()); + assert!(DataTypeNode::new("LowCardinality(X)").is_err()); + } + + #[test] + fn 
test_data_type_new_nullable() { + assert_eq!( + DataTypeNode::new("Nullable(UInt8)").unwrap(), + DataTypeNode::Nullable(Box::new(DataTypeNode::UInt8)) + ); + assert_eq!( + DataTypeNode::new("Nullable(String)").unwrap(), + DataTypeNode::Nullable(Box::new(DataTypeNode::String)) + ); + assert!(DataTypeNode::new("Nullable").is_err()); + assert!(DataTypeNode::new("Nullable()").is_err()); + assert!(DataTypeNode::new("Nullable(X)").is_err()); + } + + #[test] + fn test_data_type_new_map() { + assert_eq!( + DataTypeNode::new("Map(UInt8, String)").unwrap(), + DataTypeNode::Map([ + Box::new(DataTypeNode::UInt8), + Box::new(DataTypeNode::String) + ]) + ); + assert_eq!( + DataTypeNode::new("Map(String, Int32)").unwrap(), + DataTypeNode::Map([ + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::Int32) + ]) + ); + assert_eq!( + DataTypeNode::new("Map(String, Map(Int32, Array(Nullable(String))))").unwrap(), + DataTypeNode::Map([ + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::Map([ + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::Array(Box::new(DataTypeNode::Nullable( + Box::new(DataTypeNode::String) + )))) + ])) + ]) + ); + assert!(DataTypeNode::new("Map()").is_err()); + assert!(DataTypeNode::new("Map").is_err()); + assert!(DataTypeNode::new("Map(K)").is_err()); + assert!(DataTypeNode::new("Map(K, V)").is_err()); + assert!(DataTypeNode::new("Map(Int32, V)").is_err()); + assert!(DataTypeNode::new("Map(K, Int32)").is_err()); + assert!(DataTypeNode::new("Map(String, Int32").is_err()); + } + + #[test] + fn test_data_type_new_variant() { + assert_eq!( + DataTypeNode::new("Variant(UInt8, String)").unwrap(), + DataTypeNode::Variant(vec![DataTypeNode::UInt8, DataTypeNode::String]) + ); + assert_eq!( + DataTypeNode::new("Variant(String, Int32)").unwrap(), + DataTypeNode::Variant(vec![DataTypeNode::String, DataTypeNode::Int32]) + ); + assert_eq!( + DataTypeNode::new("Variant(Int32, Array(Nullable(String)), Map(Int32, String))") + .unwrap(), + DataTypeNode::Variant(vec![ + DataTypeNode::Int32, + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))), + DataTypeNode::Map([ + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::String) + ]) + ]) + ); + assert!(DataTypeNode::new("Variant").is_err()); + } + + #[test] + fn test_data_type_new_tuple() { + assert_eq!( + DataTypeNode::new("Tuple(UInt8, String)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::UInt8, DataTypeNode::String]) + ); + assert_eq!( + DataTypeNode::new("Tuple(String, Int32)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::String, DataTypeNode::Int32]) + ); + assert_eq!( + DataTypeNode::new("Tuple(Bool,Int32)").unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::Bool, DataTypeNode::Int32]) + ); + assert_eq!( + DataTypeNode::new( + "Tuple(Int32, Array(Nullable(String)), Map(Int32, Tuple(String, Array(UInt8))))" + ) + .unwrap(), + DataTypeNode::Tuple(vec![ + DataTypeNode::Int32, + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))), + DataTypeNode::Map([ + Box::new(DataTypeNode::Int32), + Box::new(DataTypeNode::Tuple(vec![ + DataTypeNode::String, + DataTypeNode::Array(Box::new(DataTypeNode::UInt8)) + ])) + ]) + ]) + ); + assert_eq!( + DataTypeNode::new(&format!("Tuple(String, {})", ENUM_WITH_ESCAPING_STR)).unwrap(), + DataTypeNode::Tuple(vec![DataTypeNode::String, enum_with_escaping()]) + ); + assert!(DataTypeNode::new("Tuple").is_err()); + assert!(DataTypeNode::new("Tuple(").is_err()); + assert!(DataTypeNode::new("Tuple()").is_err()); + 
assert!(DataTypeNode::new("Tuple(,)").is_err()); + assert!(DataTypeNode::new("Tuple(X)").is_err()); + assert!(DataTypeNode::new("Tuple(Int32, X)").is_err()); + assert!(DataTypeNode::new("Tuple(Int32, String, X)").is_err()); + } + + #[test] + fn test_data_type_new_enum() { + assert_eq!( + DataTypeNode::new("Enum8('A' = -42)").unwrap(), + DataTypeNode::Enum(EnumType::Enum8, HashMap::from([(-42, "A".to_string())])) + ); + assert_eq!( + DataTypeNode::new("Enum16('A' = -144)").unwrap(), + DataTypeNode::Enum(EnumType::Enum16, HashMap::from([(-144, "A".to_string())])) + ); + assert_eq!( + DataTypeNode::new("Enum8('A' = 1, 'B' = 2)").unwrap(), + DataTypeNode::Enum( + EnumType::Enum8, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) + ) + ); + assert_eq!( + DataTypeNode::new("Enum16('A' = 1, 'B' = 2)").unwrap(), + DataTypeNode::Enum( + EnumType::Enum16, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]) + ) + ); + assert_eq!( + DataTypeNode::new(ENUM_WITH_ESCAPING_STR).unwrap(), + enum_with_escaping() + ); + assert_eq!( + DataTypeNode::new("Enum8('foo' = 0, '' = 42)").unwrap(), + DataTypeNode::Enum( + EnumType::Enum8, + HashMap::from([(0, "foo".to_string()), (42, "".to_string())]) + ) + ); + + assert!(DataTypeNode::new("Enum()").is_err()); + assert!(DataTypeNode::new("Enum8()").is_err()); + assert!(DataTypeNode::new("Enum16()").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' = 2)").is_err()); + assert!(DataTypeNode::new("Enum32('A','B')").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B')").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' =)").is_err()); + assert!(DataTypeNode::new("Enum32('A' = 1, 'B' = )").is_err()); + assert!(DataTypeNode::new("Enum32('A'= 1,'B' =)").is_err()); + } + + #[test] + fn test_data_type_new_geo() { + assert_eq!(DataTypeNode::new("Point").unwrap(), DataTypeNode::Point); + assert_eq!(DataTypeNode::new("Ring").unwrap(), DataTypeNode::Ring); + assert_eq!( + DataTypeNode::new("LineString").unwrap(), + DataTypeNode::LineString + ); + assert_eq!(DataTypeNode::new("Polygon").unwrap(), DataTypeNode::Polygon); + assert_eq!( + DataTypeNode::new("MultiLineString").unwrap(), + DataTypeNode::MultiLineString + ); + assert_eq!( + DataTypeNode::new("MultiPolygon").unwrap(), + DataTypeNode::MultiPolygon + ); + } + + #[test] + fn test_data_type_to_string_simple() { + // Simple types + assert_eq!(DataTypeNode::UInt8.to_string(), "UInt8"); + assert_eq!(DataTypeNode::UInt16.to_string(), "UInt16"); + assert_eq!(DataTypeNode::UInt32.to_string(), "UInt32"); + assert_eq!(DataTypeNode::UInt64.to_string(), "UInt64"); + assert_eq!(DataTypeNode::UInt128.to_string(), "UInt128"); + assert_eq!(DataTypeNode::UInt256.to_string(), "UInt256"); + assert_eq!(DataTypeNode::Int8.to_string(), "Int8"); + assert_eq!(DataTypeNode::Int16.to_string(), "Int16"); + assert_eq!(DataTypeNode::Int32.to_string(), "Int32"); + assert_eq!(DataTypeNode::Int64.to_string(), "Int64"); + assert_eq!(DataTypeNode::Int128.to_string(), "Int128"); + assert_eq!(DataTypeNode::Int256.to_string(), "Int256"); + assert_eq!(DataTypeNode::Float32.to_string(), "Float32"); + assert_eq!(DataTypeNode::Float64.to_string(), "Float64"); + assert_eq!(DataTypeNode::BFloat16.to_string(), "BFloat16"); + assert_eq!(DataTypeNode::UUID.to_string(), "UUID"); + assert_eq!(DataTypeNode::Date.to_string(), "Date"); + assert_eq!(DataTypeNode::Date32.to_string(), "Date32"); + assert_eq!(DataTypeNode::IPv4.to_string(), "IPv4"); + assert_eq!(DataTypeNode::IPv6.to_string(), "IPv6"); + 
assert_eq!(DataTypeNode::Bool.to_string(), "Bool"); + assert_eq!(DataTypeNode::Dynamic.to_string(), "Dynamic"); + assert_eq!(DataTypeNode::JSON.to_string(), "JSON"); + assert_eq!(DataTypeNode::String.to_string(), "String"); + } + + #[test] + fn test_data_types_to_string_complex() { + assert_eq!(DataTypeNode::DateTime(None).to_string(), "DateTime"); + assert_eq!( + DataTypeNode::DateTime(Some("UTC".to_string())).to_string(), + "DateTime('UTC')" + ); + assert_eq!( + DataTypeNode::DateTime(Some("America/New_York".to_string())).to_string(), + "DateTime('America/New_York')" + ); + + assert_eq!( + DataTypeNode::Nullable(Box::new(DataTypeNode::UInt64)).to_string(), + "Nullable(UInt64)" + ); + assert_eq!( + DataTypeNode::LowCardinality(Box::new(DataTypeNode::String)).to_string(), + "LowCardinality(String)" + ); + assert_eq!( + DataTypeNode::Array(Box::new(DataTypeNode::String)).to_string(), + "Array(String)" + ); + assert_eq!( + DataTypeNode::Array(Box::new(DataTypeNode::Nullable(Box::new( + DataTypeNode::String + )))) + .to_string(), + "Array(Nullable(String))" + ); + assert_eq!( + DataTypeNode::Tuple(vec![ + DataTypeNode::String, + DataTypeNode::UInt32, + DataTypeNode::Float64 + ]) + .to_string(), + "Tuple(String, UInt32, Float64)" + ); + assert_eq!( + DataTypeNode::Map([ + Box::new(DataTypeNode::String), + Box::new(DataTypeNode::UInt32) + ]) + .to_string(), + "Map(String, UInt32)" + ); + assert_eq!( + DataTypeNode::Decimal(10, 2, DecimalType::Decimal32).to_string(), + "Decimal(10, 2)" + ); + assert_eq!( + DataTypeNode::Enum( + EnumType::Enum8, + HashMap::from([(1, "A".to_string()), (2, "B".to_string())]), + ) + .to_string(), + "Enum8('A' = 1, 'B' = 2)" + ); + assert_eq!( + DataTypeNode::Enum( + EnumType::Enum16, + HashMap::from([(42, "foo".to_string()), (144, "bar".to_string())]), + ) + .to_string(), + "Enum16('foo' = 42, 'bar' = 144)" + ); + assert_eq!(enum_with_escaping().to_string(), ENUM_WITH_ESCAPING_STR); + assert_eq!( + DataTypeNode::AggregateFunction("sum".to_string(), vec![DataTypeNode::UInt64]) + .to_string(), + "AggregateFunction(sum, UInt64)" + ); + assert_eq!(DataTypeNode::FixedString(16).to_string(), "FixedString(16)"); + assert_eq!( + DataTypeNode::Variant(vec![DataTypeNode::UInt8, DataTypeNode::Bool]).to_string(), + "Variant(UInt8, Bool)" + ); + } + + #[test] + fn test_datetime64_to_string() { + let test_cases = [ + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision0, None), + "DateTime64(0)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision1, None), + "DateTime64(1)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision2, None), + "DateTime64(2)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision3, None), + "DateTime64(3)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision4, None), + "DateTime64(4)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision5, None), + "DateTime64(5)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision6, None), + "DateTime64(6)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision7, None), + "DateTime64(7)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision8, None), + "DateTime64(8)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision9, None), + "DateTime64(9)", + ), + ( + DataTypeNode::DateTime64(DateTimePrecision::Precision0, Some("UTC".to_string())), + "DateTime64(0, 'UTC')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision3, + Some("America/New_York".to_string()), + ), + "DateTime64(3, 
'America/New_York')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision6, + Some("Europe/Amsterdam".to_string()), + ), + "DateTime64(6, 'Europe/Amsterdam')", + ), + ( + DataTypeNode::DateTime64( + DateTimePrecision::Precision9, + Some("Asia/Tokyo".to_string()), + ), + "DateTime64(9, 'Asia/Tokyo')", + ), + ]; + for (data_type, expected_str) in test_cases.iter() { + assert_eq!( + &data_type.to_string(), + expected_str, + "Expected data type {} to be formatted as {}", + data_type, + expected_str + ); + } + } + + #[test] + fn test_data_type_node_into_string() { + let data_type = DataTypeNode::new("Array(Int32)").unwrap(); + let data_type_string: String = data_type.into(); + assert_eq!(data_type_string, "Array(Int32)"); + } + + #[test] + fn test_data_type_to_string_geo() { + assert_eq!(DataTypeNode::Point.to_string(), "Point"); + assert_eq!(DataTypeNode::Ring.to_string(), "Ring"); + assert_eq!(DataTypeNode::LineString.to_string(), "LineString"); + assert_eq!(DataTypeNode::Polygon.to_string(), "Polygon"); + assert_eq!(DataTypeNode::MultiLineString.to_string(), "MultiLineString"); + assert_eq!(DataTypeNode::MultiPolygon.to_string(), "MultiPolygon"); + } + + #[test] + fn test_display_column() { + let column = Column::new( + "col".to_string(), + DataTypeNode::new("Array(Int32)").unwrap(), + ); + assert_eq!(column.to_string(), "col: Array(Int32)"); + } + + #[test] + fn test_display_decimal_size() { + assert_eq!(DecimalType::Decimal32.to_string(), "Decimal32"); + assert_eq!(DecimalType::Decimal64.to_string(), "Decimal64"); + assert_eq!(DecimalType::Decimal128.to_string(), "Decimal128"); + assert_eq!(DecimalType::Decimal256.to_string(), "Decimal256"); + } + + const ENUM_WITH_ESCAPING_STR: &'static str = + "Enum8('f\\'' = 1, 'x =' = 2, 'b\\'\\'' = 3, '\\'c=4=' = 42, '4' = 100)"; + + fn enum_with_escaping() -> DataTypeNode { + DataTypeNode::Enum( + EnumType::Enum8, + HashMap::from([ + (1, "f\\'".to_string()), + (2, "x =".to_string()), + (3, "b\\'\\'".to_string()), + (42, "\\'c=4=".to_string()), + (100, "4".to_string()), + ]), + ) + } +}
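The assertions above check parsing and formatting one case at a time; the invariant they circle around is a parse/format round-trip, e.g.:

```rust
// Round-trip sketch: a canonical type string parses and prints back unchanged
// (Enum is the notable exception, since its entries live in a HashMap).
let s = "Map(String, Array(Nullable(Int64)))";
assert_eq!(DataTypeNode::new(s).unwrap().to_string(), s);
```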
diff --git a/types/src/decoders.rs b/types/src/decoders.rs new file mode 100644 index 00000000..be4355ea --- /dev/null +++ b/types/src/decoders.rs @@ -0,0 +1,27 @@ +use crate::error::TypesError; +use crate::leb128::read_leb128; +use bytes::Buf; + +#[inline] +pub(crate) fn read_string(mut buffer: impl Buf) -> Result<String, TypesError> { + let length = read_leb128(&mut buffer)? as usize; + if length == 0 { + return Ok("".to_string()); + } + ensure_size(&mut buffer, length)?; + let result = String::from_utf8_lossy(&buffer.copy_to_bytes(length)).to_string(); + Ok(result) +} + +#[inline] +pub(crate) fn ensure_size(buffer: impl Buf, size: usize) -> Result<(), TypesError> { + if buffer.remaining() < size { + Err(TypesError::NotEnoughData(format!( + "expected at least {} bytes, but only {} bytes remaining", + size, + buffer.remaining() + ))) + } else { + Ok(()) + } +} diff --git a/types/src/error.rs b/types/src/error.rs new file mode 100644 index 00000000..8418d10e --- /dev/null +++ b/types/src/error.rs @@ -0,0 +1,11 @@ +#[derive(Debug, thiserror::Error)] +#[non_exhaustive] +#[doc(hidden)] +pub enum TypesError { + #[error("not enough data: {0}")] + NotEnoughData(String), + #[error("type parsing error: {0}")] + TypeParsingError(String), + #[error("unexpected empty list of columns")] + EmptyColumns, +} diff --git a/types/src/leb128.rs b/types/src/leb128.rs new file mode 100644 index 00000000..af7acc0a --- /dev/null +++ b/types/src/leb128.rs @@ -0,0 +1,132 @@ +use crate::error::TypesError; +use crate::error::TypesError::{NotEnoughData, TypeParsingError}; +use bytes::{Buf, BufMut}; + +#[inline] +#[doc(hidden)] +pub fn read_leb128(mut buffer: impl Buf) -> Result<u64, TypesError> { + let mut value = 0u64; + let mut shift = 0; + loop { + if buffer.remaining() < 1 { + return Err(NotEnoughData( + "decoding LEB128, 0 bytes remaining".to_string(), + )); + } + let byte = buffer.get_u8(); + value |= (byte as u64 & 0x7f) << shift; + if byte & 0x80 == 0 { + break; + } + shift += 7; + if shift > 57 { + return Err(TypeParsingError( + "decoding LEB128, unexpected shift value".to_string(), + )); + } + } + Ok(value) +} + +#[inline] +#[doc(hidden)] +pub fn put_leb128(mut buffer: impl BufMut, mut value: u64) { + while { + let mut byte = value as u8 & 0x7f; + value >>= 7; + + if value != 0 { + byte |= 0x80; + } + + buffer.put_u8(byte); + + value != 0 + } {} } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn read() { + let test_cases = vec![ + // (input bytes, expected value) + (vec![0], 0), + (vec![1], 1), + (vec![127], 127), + (vec![128, 1], 128), + (vec![255, 1], 255), + (vec![0x85, 0x91, 0x26], 624773), + (vec![0xE5, 0x8E, 0x26], 624485), + ]; + + for (input, expected) in test_cases { + let result = read_leb128(&mut input.as_slice()).unwrap(); + assert_eq!(result, expected, "Failed decoding {:?}", input); + } + } + + #[test] + fn read_errors() { + let test_cases = vec![ + // (input bytes, expected error message) + (vec![], "decoding LEB128, 0 bytes remaining"), + ( + vec![0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x01], + "decoding LEB128, unexpected shift value", + ), + ]; + + for (input, expected_error) in test_cases { + let result = read_leb128(&mut input.as_slice()); + assert!(result.is_err(), "Expected error for input {:?}", input); + if let Err(e) = result { + assert!( + e.to_string().contains(expected_error), + "Error message mismatch for `{:?}`; error was: `{}`, should contain: `{}`", + input, + e, + expected_error + ); + } + } + } + + #[test] + fn put_and_read() { + let test_cases: Vec<(u64, Vec<u8>)> = vec![ + // (value, expected encoding) + (0u64, vec![0x00]), + (1, vec![0x01]), + (127, vec![0x7F]), + (128, vec![0x80, 0x01]), + (255, vec![0xFF, 0x01]), + (300_000, vec![0xE0, 0xA7, 0x12]), + (624_773, vec![0x85, 0x91, 0x26]), + (624_485, vec![0xE5, 0x8E, 0x26]), + (10_000_000, vec![0x80, 0xAD, 0xE2, 0x04]), + (u32::MAX as u64, vec![0xFF, 0xFF, 0xFF, 0xFF, 0x0F]), + ]; + + for (value, expected_encoding) in test_cases { + // Test encoding + let mut encoded = Vec::new(); + put_leb128(&mut encoded, value); + assert_eq!( + encoded, expected_encoding, + "Incorrect encoding for {}", + value + ); + + // Test round-trip + let decoded = read_leb128(&mut encoded.as_slice()).unwrap(); + assert_eq!( + decoded, value, + "Failed round trip for {}: encoded as {:?}, decoded as {}", + value, encoded, decoded + ); + } + } +}
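To make the vectors in `put_and_read` less magical, here is the arithmetic for one of them, `624_485`:

```rust
// 624_485 = 0b100110_0001110_1100101: emit 7-bit groups starting from the
// least significant one, setting the continuation bit on all but the last.
//   0b1100101 | 0x80 = 0xE5
//   0b0001110 | 0x80 = 0x8E
//   0b0100110        = 0x26
let mut buf = Vec::new();
put_leb128(&mut buf, 624_485);
assert_eq!(buf, vec![0xE5, 0x8E, 0x26]);
```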
diff --git a/types/src/lib.rs b/types/src/lib.rs new file mode 100644 index 00000000..9271783b --- /dev/null +++ b/types/src/lib.rs @@ -0,0 +1,96 @@ +//! # clickhouse-types +//! +//! This crate is required for `RowBinaryWithNamesAndTypes` struct definition validation, +//! as it contains ClickHouse data types AST, as well as functions and utilities +//! to parse the types out of the ClickHouse server response. +//! +//! Note that this crate is not intended for public usage, +//! as it might introduce internal breaking changes not following semver. + +pub use crate::data_types::{Column, DataTypeNode}; +use crate::decoders::read_string; +use crate::error::TypesError; +use bytes::{Buf, BufMut}; + +/// Exported for internal usage only. +/// Do not use it directly in your code. +pub use crate::leb128::put_leb128; +pub use crate::leb128::read_leb128; + +/// ClickHouse data types AST and utilities to parse it from strings. +pub mod data_types; +/// Required decoders to parse the column definitions from the header of the response. +pub mod decoders; +/// Error types for this crate. +pub mod error; +/// Utils for working with LEB128 encoding and decoding. +pub mod leb128; + +/// Parses the column definitions from the response in `RowBinaryWithNamesAndTypes` format. +/// This is a mandatory step for this format, as it enables client-side data types validation. +#[doc(hidden)] +pub fn parse_rbwnat_columns_header(mut buffer: impl Buf) -> Result<Vec<Column>, TypesError> { + let num_columns = read_leb128(&mut buffer)?; + if num_columns == 0 { + return Err(TypesError::EmptyColumns); + } + let mut column_names: Vec<String> = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_name = read_string(&mut buffer)?; + column_names.push(column_name); + } + let mut column_data_types: Vec<DataTypeNode> = Vec::with_capacity(num_columns as usize); + for _ in 0..num_columns { + let column_type = read_string(&mut buffer)?; + let data_type = DataTypeNode::new(&column_type)?; + column_data_types.push(data_type); + } + let columns = column_names + .into_iter() + .zip(column_data_types) + .map(|(name, data_type)| Column::new(name, data_type)) + .collect(); + Ok(columns) +}
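Concretely, the header this function consumes is: the column count, then every column name, then every column type, with each string prefixed by its LEB128-encoded length. A hand-built example of that layout:

```rust
// Header for two columns, `id: Int32` and `name: String`; the leading byte of
// each string is its LEB128 length (all lengths here fit in one byte).
let mut header: &[u8] = &[
    2, // number of columns
    2, b'i', b'd', // name "id"
    4, b'n', b'a', b'm', b'e', // name "name"
    5, b'I', b'n', b't', b'3', b'2', // type "Int32"
    6, b'S', b't', b'r', b'i', b'n', b'g', // type "String"
];
let columns = parse_rbwnat_columns_header(&mut header).unwrap();
assert_eq!(columns[0].to_string(), "id: Int32");
assert_eq!(columns[1].to_string(), "name: String");
```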
+/// Given a table definition as a slice of [`Column`], +/// encodes it as the `RowBinaryWithNamesAndTypes` columns header and puts it into the provided buffer. +/// This is required to insert data in the `RowBinaryWithNamesAndTypes` format. +#[doc(hidden)] +pub fn put_rbwnat_columns_header( + columns: &[Column], + mut buffer: impl BufMut, +) -> Result<(), TypesError> { + if columns.is_empty() { + return Err(TypesError::EmptyColumns); + } + put_leb128(&mut buffer, columns.len() as u64); + for column in columns { + put_leb128(&mut buffer, column.name.len() as u64); + buffer.put_slice(column.name.as_bytes()); + } + for column in columns { + let data_type_str = column.data_type.to_string(); + put_leb128(&mut buffer, data_type_str.len() as u64); + buffer.put_slice(data_type_str.as_bytes()); + } + Ok(()) +} + +#[cfg(test)] +mod test { + use super::*; + use crate::data_types::DataTypeNode; + use bytes::BytesMut; + + #[test] + fn test_rbwnat_header_round_trip() { + let mut buffer = BytesMut::new(); + let columns = vec![ + Column::new("id".to_string(), DataTypeNode::Int32), + Column::new("name".to_string(), DataTypeNode::String), + ]; + put_rbwnat_columns_header(&columns, &mut buffer).unwrap(); + let parsed_columns = parse_rbwnat_columns_header(&mut buffer).unwrap(); + assert_eq!(parsed_columns, columns); + } +}
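Putting the pieces together, a rough sketch of how the client side presumably consumes this header when fetching query results (the surrounding function is illustrative; the actual wiring lives in the main `clickhouse` crate):

```rust
use bytes::Bytes;
use clickhouse_types::{parse_rbwnat_columns_header, Column};

// Illustrative only: read the RBWNAT header once per response, then use the
// returned columns to validate the row struct before decoding any rows.
fn read_header(mut body: Bytes) -> Vec<Column> {
    let columns = parse_rbwnat_columns_header(&mut body).expect("malformed RBWNAT header");
    for column in &columns {
        println!("{column}"); // e.g. "id: Int32"
    }
    columns
}
```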