Skip to content

Commit 75376ff

Browse files
Fix incorrect quoting of identifiers with _ as initial character. (#569)
1 parent b9b26e5 commit 75376ff

File tree

3 files changed

+34
-2
lines changed

3 files changed

+34
-2
lines changed

DESCRIPTION.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ Source code is also available at:
99

1010
# Release Notes
1111
- (Unreleased)
12-
- Fix return value of snowflake get_table_names
12+
- Fix return value of snowflake get_table_names.
13+
- Fix incorrect quoting of identifiers with `_` as initial character.
1314
- Added `force_div_is_floordiv` flag to override `div_is_floordiv` new default value `False` in `SnowflakeDialect`.
1415
- With the flag in `False`, the `/` division operator will be treated as a float division and `//` as a floor division.
1516
- This flag is added to maintain backward compatibility with the previous behavior of Snowflake Dialect division.

src/snowflake/sqlalchemy/base.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,11 @@
117117
r"\s*(?:UPDATE|INSERT|DELETE|MERGE|COPY)", re.I | re.UNICODE
118118
)
119119
# used for quoting identifiers ie. table names, column names, etc.
120-
ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"_", "$"}))
120+
ILLEGAL_INITIAL_CHARACTERS = frozenset({d for d in string.digits}.union({"$"}))
121+
122+
123+
# used for quoting identifiers ie. table names, column names, etc.
124+
ILLEGAL_IDENTIFIERS = frozenset({d for d in string.digits}.union({"_"}))
121125

122126
"""
123127
Overwrite methods to handle Snowflake BCR change:
@@ -443,6 +447,7 @@ def _join_left_to_right(
443447
class SnowflakeIdentifierPreparer(compiler.IdentifierPreparer):
444448
reserved_words = {x.lower() for x in RESERVED_WORDS}
445449
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
450+
illegal_identifiers = ILLEGAL_IDENTIFIERS
446451

447452
def __init__(self, dialect, **kw):
448453
quote = '"'
@@ -471,6 +476,17 @@ def format_label(self, label, name=None):
471476

472477
return self.quote_identifier(s) if n.quote else s
473478

479+
def _requires_quotes(self, value: str) -> bool:
480+
"""Return True if the given identifier requires quoting."""
481+
lc_value = value.lower()
482+
return (
483+
lc_value in self.reserved_words
484+
or lc_value in self.illegal_identifiers
485+
or value[0] in self.illegal_initial_characters
486+
or not self.legal_characters.match(str(value))
487+
or (lc_value != value)
488+
)
489+
474490
def _split_schema_by_dot(self, schema):
475491
ret = []
476492
idx = 0

tests/test_compiler.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,21 @@ def test_underscore_as_valid_identifier(self):
5050
dialect="snowflake",
5151
)
5252

53+
def test_underscore_as_initial_character_as_non_quoted_identifier(self):
54+
_table = table(
55+
"table_1745924",
56+
column("ca", Integer),
57+
column("cb", String),
58+
column("_identifier", String),
59+
)
60+
61+
stmt = insert(_table).values(ca=1, cb="test", _identifier="test_")
62+
self.assert_compile(
63+
stmt,
64+
"INSERT INTO table_1745924 (ca, cb, _identifier) VALUES (%(ca)s, %(cb)s, %(_identifier)s)",
65+
dialect="snowflake",
66+
)
67+
5368
def test_multi_table_delete(self):
5469
statement = table1.delete().where(table1.c.id == table2.c.id)
5570
self.assert_compile(

0 commit comments

Comments
 (0)