From 5f8f2e153a50be4079cd7d6a88ca3005b755a378 Mon Sep 17 00:00:00 2001 From: Brad Paskewitz Date: Fri, 3 Oct 2025 14:33:11 -0400 Subject: [PATCH] feat(optimizer)!: annotate type for Snowflake COSH function --- sqlglot/dialects/snowflake.py | 1 + sqlglot/expressions.py | 4 ++++ tests/dialects/test_bigquery.py | 1 + tests/dialects/test_databricks.py | 1 + tests/dialects/test_duckdb.py | 1 + tests/dialects/test_postgres.py | 1 + tests/dialects/test_redshift.py | 1 + tests/dialects/test_snowflake.py | 1 + tests/fixtures/optimizer/annotate_functions.sql | 4 ++++ 9 files changed, 15 insertions(+) diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index d788210fd6..33952b01df 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -565,6 +565,7 @@ class Snowflake(Dialect): **Dialect.TYPE_TO_EXPRESSIONS, exp.DataType.Type.DOUBLE: { *Dialect.TYPE_TO_EXPRESSIONS[exp.DataType.Type.DOUBLE], + exp.Cosh, exp.Cot, exp.Sin, exp.Tan, diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 461edfda06..3ad3870cfc 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -5553,6 +5553,10 @@ class Tan(Func): pass +class Cosh(Func): + pass + + class CosineDistance(Func): arg_types = {"this": True, "expression": True} diff --git a/tests/dialects/test_bigquery.py b/tests/dialects/test_bigquery.py index a61ea17a53..5ed1830ca2 100644 --- a/tests/dialects/test_bigquery.py +++ b/tests/dialects/test_bigquery.py @@ -2108,6 +2108,7 @@ def test_ml_functions(self): self.validate_identity( "SELECT * FROM ML.PREDICT(MODEL mydataset.mymodel, (SELECT custom_label, column1, column2 FROM mydataset.mytable), STRUCT(0.55 AS threshold))" ) + self.validate_identity("SELECT COSH(1.5)") self.validate_identity( "SELECT * FROM ML.PREDICT(MODEL `my_project`.my_dataset.my_model, (SELECT * FROM input_data))" ) diff --git a/tests/dialects/test_databricks.py b/tests/dialects/test_databricks.py index 854e34ea7e..f855a3d827 100644 --- a/tests/dialects/test_databricks.py +++ b/tests/dialects/test_databricks.py @@ -7,6 +7,7 @@ class TestDatabricks(Validator): dialect = "databricks" def test_databricks(self): + self.validate_identity("SELECT COSH(1.5)") null_type = exp.DataType.build("VOID", dialect="databricks") self.assertEqual(null_type.sql(), "NULL") self.assertEqual(null_type.sql("databricks"), "VOID") diff --git a/tests/dialects/test_duckdb.py b/tests/dialects/test_duckdb.py index 590d0f4979..18e26b10a1 100644 --- a/tests/dialects/test_duckdb.py +++ b/tests/dialects/test_duckdb.py @@ -9,6 +9,7 @@ class TestDuckDB(Validator): dialect = "duckdb" def test_duckdb(self): + self.validate_identity("SELECT COSH(1.5)") with self.assertRaises(ParseError): parse_one("1 //", read="duckdb") diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 8a7b39cb51..c4aeeace1b 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -8,6 +8,7 @@ class TestPostgres(Validator): dialect = "postgres" def test_postgres(self): + self.validate_identity("SELECT COSH(1.5)") self.validate_identity( "select count() OVER(partition by a order by a range offset preceding exclude current row)", "SELECT COUNT() OVER (PARTITION BY a ORDER BY a range BETWEEN offset preceding AND CURRENT ROW EXCLUDE CURRENT ROW)", diff --git a/tests/dialects/test_redshift.py b/tests/dialects/test_redshift.py index 16c9cd2e30..194dd4f3d2 100644 --- a/tests/dialects/test_redshift.py +++ b/tests/dialects/test_redshift.py @@ -6,6 +6,7 @@ class TestRedshift(Validator): dialect = "redshift" def test_redshift(self): + self.validate_identity("SELECT COSH(1.5)") self.validate_all( "SELECT SPLIT_TO_ARRAY('12,345,6789')", write={ diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index fe4391a712..b8f12cf92c 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -49,6 +49,7 @@ def test_snowflake(self): self.validate_identity("SELECT SOUNDEX_P123(column_name)") self.validate_identity("SELECT ABS(x)") self.validate_identity("SELECT SIGN(x)") + self.validate_identity("SELECT COSH(1.5)") self.validate_identity("SELECT JAROWINKLER_SIMILARITY('hello', 'world')") self.validate_identity("SELECT TRANSLATE(column_name, 'abc', '123')") self.validate_identity("SELECT UNICODE(column_name)") diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 343fd8ae45..14117011ec 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -1631,6 +1631,10 @@ VARCHAR; COLLATE('hello', 'utf8'); VARCHAR; +# dialect: snowflake +COSH(1.5); +DOUBLE; + # dialect: snowflake COMPRESS('Hello World', 'SNAPPY'); BINARY;