Skip to content

Commit 140fec8

Browse files
Fix BOOLEAN not found in snowdialect (#551)
* Fix import BOOLEAN error * Update imports * Add test for explicit imports
1 parent 9b2c6d1 commit 140fec8

File tree

3 files changed

+67
-0
lines changed

3 files changed

+67
-0
lines changed

DESCRIPTION.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ Source code is also available at:
1111

1212
- (Unreleased)
1313
- Add support for partition by to copy into <location>
14+
- Fix BOOLEAN type not found in snowdialect
1415

1516
- v1.7.0(November 22, 2024)
1617

src/snowflake/sqlalchemy/snowdialect.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@
3939
_CUSTOM_Float,
4040
_CUSTOM_Time,
4141
)
42+
from .parser.custom_type_parser import * # noqa
43+
from .parser.custom_type_parser import _CUSTOM_DECIMAL # noqa
4244
from .parser.custom_type_parser import ischema_names, parse_type
4345
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
4446
from .util import (

tests/test_imports.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#
2+
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
3+
#
4+
5+
import importlib
6+
import inspect
7+
8+
import pytest
9+
10+
11+
def get_classes_from_module(module_name):
12+
"""Returns a set of class names from a given module."""
13+
try:
14+
module = importlib.import_module(module_name)
15+
members = inspect.getmembers(module)
16+
return {name for name, obj in members if inspect.isclass(obj)}
17+
18+
except ImportError:
19+
print(f"Module '{module_name}' could not be imported.")
20+
return set()
21+
22+
23+
def test_types_in_snowdialect():
24+
classes_a = get_classes_from_module(
25+
"snowflake.sqlalchemy.parser.custom_type_parser"
26+
)
27+
classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect")
28+
assert classes_a.issubset(classes_b), str(classes_a - classes_b)
29+
30+
31+
@pytest.mark.parametrize(
32+
"type_class_name",
33+
[
34+
"BIGINT",
35+
"BINARY",
36+
"BOOLEAN",
37+
"CHAR",
38+
"DATE",
39+
"DATETIME",
40+
"DECIMAL",
41+
"FLOAT",
42+
"INTEGER",
43+
"REAL",
44+
"SMALLINT",
45+
"TIME",
46+
"TIMESTAMP",
47+
"VARCHAR",
48+
"NullType",
49+
"_CUSTOM_DECIMAL",
50+
"ARRAY",
51+
"DOUBLE",
52+
"GEOGRAPHY",
53+
"GEOMETRY",
54+
"MAP",
55+
"OBJECT",
56+
"TIMESTAMP_LTZ",
57+
"TIMESTAMP_NTZ",
58+
"TIMESTAMP_TZ",
59+
"VARIANT",
60+
],
61+
)
62+
def test_snowflake_data_types_instance(type_class_name):
63+
classes_b = get_classes_from_module("snowflake.sqlalchemy.snowdialect")
64+
assert type_class_name in classes_b, type_class_name

0 commit comments

Comments
 (0)