11from enum import Enum
2- from typing import Optional , Any , Union
2+ from typing import Dict , Optional , Any , Union
33import uuid
44import logging
55
@@ -43,6 +43,7 @@ def __init__(
4343 backend_type : BackendType ,
4444 guid : Any ,
4545 secret : Optional [Any ] = None ,
46+ info : Optional [Dict [str , Any ]] = None ,
4647 ):
4748 """
4849 Initialize a SessionId.
@@ -51,13 +52,15 @@ def __init__(
5152 backend_type: The type of backend (THRIFT or SEA)
5253 guid: The primary identifier for the session
5354 secret: The secret part of the identifier (only used for Thrift)
55+ info: Additional information about the session
5456 """
5557 self .backend_type = backend_type
5658 self .guid = guid
5759 self .secret = secret
60+ self .info = info or {}
5861
5962 @classmethod
60- def from_thrift_handle (cls , session_handle ):
63+ def from_thrift_handle (cls , session_handle , info : Optional [ Dict [ str , Any ]] = None ):
6164 """
6265 Create a SessionId from a Thrift session handle.
6366
@@ -67,16 +70,23 @@ def from_thrift_handle(cls, session_handle):
6770 Returns:
6871 A SessionId instance
6972 """
70- if session_handle is None or session_handle . sessionId is None :
73+ if session_handle is None :
7174 return None
7275
7376 guid_bytes = session_handle .sessionId .guid
7477 secret_bytes = session_handle .sessionId .secret
7578
76- return cls (BackendType .THRIFT , guid_bytes , secret_bytes )
79+ if session_handle .serverProtocolVersion is not None :
80+ if info is None :
81+ info = {}
82+ info ["serverProtocolVersion" ] = session_handle .serverProtocolVersion
83+
84+ return cls (BackendType .THRIFT , guid_bytes , secret_bytes , info )
7785
7886 @classmethod
79- def from_sea_session_id (cls , session_id : str ):
87+ def from_sea_session_id (
88+ cls , session_id : str , info : Optional [Dict [str , Any ]] = None
89+ ):
8090 """
8191 Create a SessionId from a SEA session ID.
8292
@@ -86,7 +96,7 @@ def from_sea_session_id(cls, session_id: str):
8696 Returns:
8797 A SessionId instance
8898 """
89- return cls (BackendType .SEA , session_id )
99+ return cls (BackendType .SEA , session_id , info = info )
90100
91101 def to_thrift_handle (self ):
92102 """
@@ -101,7 +111,10 @@ def to_thrift_handle(self):
101111 from databricks .sql .thrift_api .TCLIService import ttypes
102112
103113 handle_identifier = ttypes .THandleIdentifier (guid = self .guid , secret = self .secret )
104- return ttypes .TSessionHandle (sessionId = handle_identifier )
114+ server_protocol_version = self .info .get ("serverProtocolVersion" )
115+ return ttypes .TSessionHandle (
116+ sessionId = handle_identifier , serverProtocolVersion = server_protocol_version
117+ )
105118
106119 def to_sea_session_id (self ):
107120 """
@@ -129,19 +142,12 @@ def to_hex_id(self) -> str:
129142
130143 def get_protocol_version (self ):
131144 """
132- Since the sessionHandle will sometimes have a serverProtocolVersion, it takes
133- precedence over the serverProtocolVersion defined in the OpenSessionResponse.
145+ Get the server protocol version for this session.
146+
147+ Returns:
148+ The server protocol version or None if this is not a Thrift session ID
134149 """
135- if self .backend_type != BackendType .THRIFT :
136- return None
137- session_handle = self .to_thrift_handle ()
138- if (
139- session_handle
140- and hasattr (session_handle , "serverProtocolVersion" )
141- and session_handle .serverProtocolVersion
142- ):
143- return session_handle .serverProtocolVersion
144- return None
150+ return self .info .get ("serverProtocolVersion" )
145151
146152
147153class CommandId :
@@ -190,7 +196,7 @@ def from_thrift_handle(cls, operation_handle):
190196 Returns:
191197 A CommandId instance
192198 """
193- if operation_handle is None or operation_handle . operationId is None :
199+ if operation_handle is None :
194200 return None
195201
196202 guid_bytes = operation_handle .operationId .guid
0 commit comments