Skip to content

Commit 5ce8254

Browse files
committed
feat(optimizer)!: Annotate type for snowflake STRTOK function. (#5991)
* feat(optimizer)!: Annotate type for snowflake STRTOK function * feat(optimizer)!: Map to split_part expression and annotate, update tests * fix: applied formatting * refactor: Modify functions and update tests as per comments * Remove redundant `TRANSFORMS` entry * Use `rename_func` directly * Make sure to set default delimiter as well, if missing --------- Co-authored-by: Jo <46752250+georgesittas@users.noreply.github.com>
1 parent 587196c commit 5ce8254

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

sqlglot/dialects/snowflake.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,18 @@
4141
from sqlglot._typing import E, B
4242

4343

44-
# from https://docs.snowflake.com/en/sql-reference/functions/to_timestamp.html
44+
def _build_strtok(args: t.List) -> exp.SplitPart:
    """Build a `SplitPart` expression from the arguments of a Snowflake STRTOK call.

    STRTOK accepts one to three arguments; the missing trailing ones are
    filled in with their defaults so the resulting node always carries an
    explicit delimiter and part index.

    Args:
        args: The parsed STRTOK arguments, in call order
            (string, optional delimiter, optional part index).

    Returns:
        The `SplitPart` expression built from the (defaulted) argument list.
    """
    # Work on a copy so the caller's argument list is never mutated.
    strtok_args = list(args)

    # Add default delimiter (space) if missing - per Snowflake docs
    if len(strtok_args) == 1:
        strtok_args.append(exp.Literal.string(" "))

    # Add default part_index (1) if missing
    if len(strtok_args) == 2:
        strtok_args.append(exp.Literal.number(1))

    return exp.SplitPart.from_arg_list(strtok_args)
54+
55+
4556
def _build_datetime(
4657
name: str, kind: exp.DataType.Type, safe: bool = False
4758
) -> t.Callable[[t.List], exp.Func]:
@@ -773,6 +784,7 @@ class Parser(parser.Parser):
773784
"SHA2_BINARY": exp.SHA2Digest.from_arg_list,
774785
"SHA2_HEX": exp.SHA2.from_arg_list,
775786
"SQUARE": lambda args: exp.Pow(this=seq_get(args, 0), expression=exp.Literal.number(2)),
787+
"STRTOK": _build_strtok,
776788
"TABLE": lambda args: exp.TableFromRows(this=seq_get(args, 0)),
777789
"TIMEADD": _build_date_time_add(exp.TimeAdd),
778790
"TIMEDIFF": _build_datediff,
@@ -1902,3 +1914,13 @@ def format_sql(self, expression: exp.Format) -> str:
19021914
return self.func("TO_CHAR", expression.expressions[0])
19031915

19041916
return self.function_fallback_sql(expression)
1917+
1918+
def splitpart_sql(self, expression: exp.SplitPart) -> str:
    """Generate a `SPLIT_PART(...)` call for a `SplitPart` expression.

    The delimiter and part index are optional on the expression; when either
    is absent, its default (a single space, and 1 respectively) is emitted.

    Args:
        expression: The `SplitPart` node to render.

    Returns:
        The SQL text of the `SPLIT_PART` function call.
    """
    # Default the delimiter to a single space if missing
    delimiter = expression.args.get("delimiter") or exp.Literal.string(" ")

    # Default the part_index to 1 if missing
    part_index = expression.args.get("part_index") or exp.Literal.number(1)

    # Pass the arguments explicitly: this keeps their positional order
    # independent of how the node's args were populated, and leaves the
    # expression being generated unmodified.
    return self.func("SPLIT_PART", expression.this, delimiter, part_index)

sqlglot/expressions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7390,8 +7390,10 @@ class Split(Func):
73907390

73917391

73927392
# https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.split_part.html
7393+
# https://docs.snowflake.com/en/sql-reference/functions/split_part
7394+
# https://docs.snowflake.com/en/sql-reference/functions/strtok
73937395
class SplitPart(Func):
7394-
arg_types = {"this": True, "delimiter": True, "part_index": True}
7396+
arg_types = {"this": True, "delimiter": False, "part_index": False}
73957397

73967398

73977399
# Start may be omitted in the case of postgres

tests/dialects/test_snowflake.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def test_snowflake(self):
5858
self.validate_identity("SELECT {* EXCLUDE (col1)} FROM my_table")
5959
self.validate_identity("SELECT {* EXCLUDE (col1, col2)} FROM my_table")
6060
self.validate_identity("SELECT a, b, COUNT(*) FROM x GROUP BY ALL LIMIT 100")
61+
self.validate_identity(
62+
"SELECT STRTOK('hello world')", "SELECT SPLIT_PART('hello world', ' ', 1)"
63+
)
64+
self.validate_identity(
65+
"SELECT STRTOK('hello world', ' ')", "SELECT SPLIT_PART('hello world', ' ', 1)"
66+
)
67+
self.validate_identity(
68+
"SELECT STRTOK('hello world', ' ', 2)", "SELECT SPLIT_PART('hello world', ' ', 2)"
69+
)
6170
self.validate_identity("STRTOK_TO_ARRAY('a b c')")
6271
self.validate_identity("STRTOK_TO_ARRAY('a.b.c', '.')")
6372
self.validate_identity("GET(a, b)")

tests/fixtures/optimizer/annotate_functions.sql

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2171,6 +2171,18 @@ ARRAY;
21712171
SPLIT_PART('11.22.33', '.', 1);
21722172
VARCHAR;
21732173

2174+
# dialect: snowflake
2175+
STRTOK('hello world');
2176+
VARCHAR;
2177+
2178+
# dialect: snowflake
2179+
STRTOK('hello world', ' ');
2180+
VARCHAR;
2181+
2182+
# dialect: snowflake
2183+
STRTOK('a.b.c', '.', 1);
2184+
VARCHAR;
2185+
21742186
# dialect: snowflake
21752187
STARTSWITH('hello world', 'hello');
21762188
BOOLEAN;

0 commit comments

Comments
 (0)