Skip to content

Commit 64be07b

Browse files
remove un-necessary changes in thrift field if tests
Signed-off-by: varun-edachali-dbx <varun.edachali@databricks.com>
1 parent 7ea7b75 commit 64be07b

File tree

1 file changed

+21
-26
lines changed

1 file changed

+21
-26
lines changed

tests/unit/test_thrift_field_ids.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -16,77 +16,72 @@ class TestThriftFieldIds:
1616

1717
# Known exceptions that exceed the field ID limit
1818
KNOWN_EXCEPTIONS = {
19-
("TExecuteStatementReq", "enforceEmbeddedSchemaCorrectness"): 3353,
20-
("TSessionHandle", "serverProtocolVersion"): 3329,
19+
('TExecuteStatementReq', 'enforceEmbeddedSchemaCorrectness'): 3353,
20+
('TSessionHandle', 'serverProtocolVersion'): 3329,
2121
}
2222

2323
def test_all_thrift_field_ids_are_within_allowed_range(self):
2424
"""
2525
Validates that all field IDs in Thrift-generated classes are within the allowed range.
26-
26+
2727
This test prevents field ID conflicts and ensures compatibility with different
2828
Thrift implementations and protocols.
2929
"""
3030
violations = []
31-
31+
3232
# Get all classes from the ttypes module
3333
for name, obj in inspect.getmembers(ttypes):
34-
if (
35-
inspect.isclass(obj)
36-
and hasattr(obj, "thrift_spec")
37-
and obj.thrift_spec is not None
38-
):
39-
34+
if (inspect.isclass(obj) and
35+
hasattr(obj, 'thrift_spec') and
36+
obj.thrift_spec is not None):
37+
4038
self._check_class_field_ids(obj, name, violations)
41-
39+
4240
if violations:
4341
error_message = self._build_error_message(violations)
4442
pytest.fail(error_message)
4543

4644
def _check_class_field_ids(self, cls, class_name, violations):
4745
"""
4846
Checks all field IDs in a Thrift class and reports violations.
49-
47+
5048
Args:
5149
cls: The Thrift class to check
5250
class_name: Name of the class for error reporting
5351
violations: List to append violation messages to
5452
"""
5553
thrift_spec = cls.thrift_spec
56-
54+
5755
if not isinstance(thrift_spec, (tuple, list)):
5856
return
59-
57+
6058
for spec_entry in thrift_spec:
6159
if spec_entry is None:
6260
continue
63-
61+
6462
# Thrift spec format: (field_id, field_type, field_name, ...)
6563
if isinstance(spec_entry, (tuple, list)) and len(spec_entry) >= 3:
6664
field_id = spec_entry[0]
6765
field_name = spec_entry[2]
68-
66+
6967
# Skip known exceptions
7068
if (class_name, field_name) in self.KNOWN_EXCEPTIONS:
7169
continue
72-
70+
7371
if isinstance(field_id, int) and field_id >= self.MAX_ALLOWED_FIELD_ID:
7472
violations.append(
7573
"{} field '{}' has field ID {} (exceeds maximum of {})".format(
76-
class_name,
77-
field_name,
78-
field_id,
79-
self.MAX_ALLOWED_FIELD_ID - 1,
74+
class_name, field_name, field_id, self.MAX_ALLOWED_FIELD_ID - 1
8075
)
8176
)
8277

8378
def _build_error_message(self, violations):
8479
"""
8580
Builds a comprehensive error message for field ID violations.
86-
81+
8782
Args:
8883
violations: List of violation messages
89-
84+
9085
Returns:
9186
Formatted error message
9287
"""
@@ -95,8 +90,8 @@ def _build_error_message(self, violations):
9590
"This can cause compatibility issues and conflicts with reserved ID ranges.\n"
9691
"Violations found:\n".format(self.MAX_ALLOWED_FIELD_ID - 1)
9792
)
98-
93+
9994
for violation in violations:
10095
error_message += " - {}\n".format(violation)
101-
102-
return error_message
96+
97+
return error_message

0 commit comments

Comments
 (0)