diff --git a/sqlglot/dialects/snowflake.py b/sqlglot/dialects/snowflake.py index 64782e30f4..08af776275 100644 --- a/sqlglot/dialects/snowflake.py +++ b/sqlglot/dialects/snowflake.py @@ -148,6 +148,20 @@ def _build_if_from_div0(args: t.List) -> exp.If: return exp.If(this=cond, true=true, false=false) +# https://docs.snowflake.com/en/sql-reference/functions/div0null +def _build_if_from_div0null(args: t.List) -> exp.If: + lhs = exp._wrap(seq_get(args, 0), exp.Binary) + rhs = exp._wrap(seq_get(args, 1), exp.Binary) + + # Returns 0 when divisor is 0 OR NULL + cond = exp.EQ(this=rhs, expression=exp.Literal.number(0)).or_( + exp.Is(this=rhs, expression=exp.null()) + ) + true = exp.Literal.number(0) + false = exp.Div(this=lhs, expression=rhs) + return exp.If(this=cond, true=true, false=false) + + # https://docs.snowflake.com/en/sql-reference/functions/zeroifnull def _build_if_from_zeroifnull(args: t.List) -> exp.If: cond = exp.Is(this=seq_get(args, 0), expression=exp.Null()) @@ -746,6 +760,7 @@ class Parser(parser.Parser): "DATEDIFF": _build_datediff, "DAYOFWEEKISO": exp.DayOfWeekIso.from_arg_list, "DIV0": _build_if_from_div0, + "DIV0NULL": _build_if_from_div0null, "EDITDISTANCE": lambda args: exp.Levenshtein( this=seq_get(args, 0), expression=seq_get(args, 1), max_dist=seq_get(args, 2) ), diff --git a/tests/dialects/test_snowflake.py b/tests/dialects/test_snowflake.py index 4c0e2d0f5f..7ced732157 100644 --- a/tests/dialects/test_snowflake.py +++ b/tests/dialects/test_snowflake.py @@ -838,6 +838,28 @@ def test_snowflake(self): "duckdb": "CASE WHEN (c - d) = 0 AND NOT (a - b) IS NULL THEN 0 ELSE (a - b) / (c - d) END", }, ) + self.validate_all( + "DIV0NULL(foo, bar)", + write={ + "snowflake": "IFF(bar = 0 OR bar IS NULL, 0, foo / bar)", + "sqlite": "IIF(bar = 0 OR bar IS NULL, 0, CAST(foo AS REAL) / bar)", + "presto": "IF(bar = 0 OR bar IS NULL, 0, CAST(foo AS DOUBLE) / bar)", + "spark": "IF(bar = 0 OR bar IS NULL, 0, foo / bar)", + "hive": "IF(bar = 0 OR bar IS NULL, 0, foo / bar)", + "duckdb": "CASE WHEN bar = 0 OR bar IS NULL THEN 0 ELSE foo / bar END", + }, + ) + self.validate_all( + "DIV0NULL(a - b, c - d)", + write={ + "snowflake": "IFF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))", + "sqlite": "IIF((c - d) = 0 OR (c - d) IS NULL, 0, CAST((a - b) AS REAL) / (c - d))", + "presto": "IF((c - d) = 0 OR (c - d) IS NULL, 0, CAST((a - b) AS DOUBLE) / (c - d))", + "spark": "IF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))", + "hive": "IF((c - d) = 0 OR (c - d) IS NULL, 0, (a - b) / (c - d))", + "duckdb": "CASE WHEN (c - d) = 0 OR (c - d) IS NULL THEN 0 ELSE (a - b) / (c - d) END", + }, + ) self.validate_all( "ZEROIFNULL(foo)", write={ diff --git a/tests/fixtures/optimizer/annotate_functions.sql b/tests/fixtures/optimizer/annotate_functions.sql index 39a688b379..2518dded89 100644 --- a/tests/fixtures/optimizer/annotate_functions.sql +++ b/tests/fixtures/optimizer/annotate_functions.sql @@ -1639,6 +1639,22 @@ BINARY; DECOMPRESS_STRING('compressed_data', 'ZSTD'); VARCHAR; +# dialect: snowflake +DIV0(10, 0); +DOUBLE; + +# dialect: snowflake +DIV0(tbl.double_col, tbl.double_col); +DOUBLE; + +# dialect: snowflake +DIV0NULL(10, 0); +DOUBLE; + +# dialect: snowflake +DIV0NULL(tbl.double_col, tbl.double_col); +DOUBLE; + # dialect: snowflake LPAD('Hello', 10, '*'); VARCHAR;