11import time
2- from typing import Dict , Tuple , List , Optional , Any , Union , Sequence
2+ from typing import Dict , Tuple , List , Optional , Any , Union , Sequence , BinaryIO
33import pandas
44
55try :
6767)
6868from databricks .sql .telemetry .latency_logger import log_latency
6969from databricks .sql .telemetry .models .enums import StatementType
70+ from databricks .sql .common .http import DatabricksHttpClient , HttpMethod
7071
7172logger = logging .getLogger (__name__ )
7273
@@ -647,8 +648,34 @@ def _check_not_closed(self):
647648 session_id_hex = self .connection .get_session_id_hex (),
648649 )
649650
651+ def _validate_staging_http_response (
652+ self , response : requests .Response , operation_name : str = "staging operation"
653+ ) -> None :
654+
655+ # Check response codes
656+ OK = requests .codes .ok # 200
657+ CREATED = requests .codes .created # 201
658+ ACCEPTED = requests .codes .accepted # 202
659+ NO_CONTENT = requests .codes .no_content # 204
660+
661+ if response .status_code not in [OK , CREATED , NO_CONTENT , ACCEPTED ]:
662+ raise OperationalError (
663+ f"{ operation_name } over HTTP was unsuccessful: { response .status_code } -{ response .text } " ,
664+ session_id_hex = self .connection .get_session_id_hex (),
665+ )
666+
667+ if response .status_code == ACCEPTED :
668+ logger .debug (
669+ "Response code %s from server indicates %s was accepted "
670+ "but not yet applied on the server. It's possible this command may fail later." ,
671+ ACCEPTED ,
672+ operation_name ,
673+ )
674+
650675 def _handle_staging_operation (
651- self , staging_allowed_local_path : Union [None , str , List [str ]]
676+ self ,
677+ staging_allowed_local_path : Union [None , str , List [str ]],
678+ input_stream : Optional [BinaryIO ] = None ,
652679 ):
653680 """Fetch the HTTP request instruction from a staging ingestion command
654681 and call the designated handler.
@@ -657,6 +684,28 @@ def _handle_staging_operation(
657684 is not descended from staging_allowed_local_path.
658685 """
659686
687+ assert self .active_result_set is not None
688+ row = self .active_result_set .fetchone ()
689+ assert row is not None
690+
691+ # Parse headers
692+ headers = (
693+ json .loads (row .headers ) if isinstance (row .headers , str ) else row .headers
694+ )
695+ headers = dict (headers ) if headers else {}
696+
697+ # Handle __input_stream__ token for PUT operations
698+ if (
699+ row .operation == "PUT"
700+ and getattr (row , "localFile" , None ) == "__input_stream__"
701+ ):
702+ return self ._handle_staging_put_stream (
703+ presigned_url = row .presignedUrl ,
704+ stream = input_stream ,
705+ headers = headers ,
706+ )
707+
708+ # For non-streaming operations, validate staging_allowed_local_path
660709 if isinstance (staging_allowed_local_path , type (str ())):
661710 _staging_allowed_local_paths = [staging_allowed_local_path ]
662711 elif isinstance (staging_allowed_local_path , type (list ())):
@@ -671,10 +720,6 @@ def _handle_staging_operation(
671720 os .path .abspath (i ) for i in _staging_allowed_local_paths
672721 ]
673722
674- assert self .active_result_set is not None
675- row = self .active_result_set .fetchone ()
676- assert row is not None
677-
678723 # Must set to None in cases where server response does not include localFile
679724 abs_localFile = None
680725
@@ -697,19 +742,16 @@ def _handle_staging_operation(
697742 session_id_hex = self .connection .get_session_id_hex (),
698743 )
699744
700- # May be real headers, or could be json string
701- headers = (
702- json .loads (row .headers ) if isinstance (row .headers , str ) else row .headers
703- )
704-
705745 handler_args = {
706746 "presigned_url" : row .presignedUrl ,
707747 "local_file" : abs_localFile ,
708- "headers" : dict ( headers ) or {} ,
748+ "headers" : headers ,
709749 }
710750
711751 logger .debug (
712- f"Attempting staging operation indicated by server: { row .operation } - { getattr (row , 'localFile' , '' )} "
752+ "Attempting staging operation indicated by server: %s - %s" ,
753+ row .operation ,
754+ getattr (row , "localFile" , "" ),
713755 )
714756
715757 # TODO: Create a retry loop here to re-attempt if the request times out or fails
@@ -728,6 +770,43 @@ def _handle_staging_operation(
728770 session_id_hex = self .connection .get_session_id_hex (),
729771 )
730772
773+ @log_latency (StatementType .SQL )
774+ def _handle_staging_put_stream (
775+ self ,
776+ presigned_url : str ,
777+ stream : BinaryIO ,
778+ headers : dict = {},
779+ ) -> None :
780+ """Handle PUT operation with streaming data.
781+
782+ Args:
783+ presigned_url: The presigned URL for upload
784+ stream: Binary stream to upload
785+ headers: HTTP headers
786+
787+ Raises:
788+ ProgrammingError: If no input stream is provided
789+ OperationalError: If the upload fails
790+ """
791+
792+ if not stream :
793+ raise ProgrammingError (
794+ "No input stream provided for streaming operation" ,
795+ session_id_hex = self .connection .get_session_id_hex (),
796+ )
797+
798+ http_client = DatabricksHttpClient .get_instance ()
799+
800+ # Stream directly to presigned URL
801+ with http_client .execute (
802+ method = HttpMethod .PUT ,
803+ url = presigned_url ,
804+ data = stream ,
805+ headers = headers ,
806+ timeout = 300 , # 5 minute timeout
807+ ) as response :
808+ self ._validate_staging_http_response (response , "stream upload" )
809+
731810 @log_latency (StatementType .SQL )
732811 def _handle_staging_put (
733812 self , presigned_url : str , local_file : str , headers : Optional [dict ] = None
@@ -746,27 +825,7 @@ def _handle_staging_put(
746825 with open (local_file , "rb" ) as fh :
747826 r = requests .put (url = presigned_url , data = fh , headers = headers )
748827
749- # fmt: off
750- # Design borrowed from: https://stackoverflow.com/a/2342589/5093960
751-
752- OK = requests .codes .ok # 200
753- CREATED = requests .codes .created # 201
754- ACCEPTED = requests .codes .accepted # 202
755- NO_CONTENT = requests .codes .no_content # 204
756-
757- # fmt: on
758-
759- if r .status_code not in [OK , CREATED , NO_CONTENT , ACCEPTED ]:
760- raise OperationalError (
761- f"Staging operation over HTTP was unsuccessful: { r .status_code } -{ r .text } " ,
762- session_id_hex = self .connection .get_session_id_hex (),
763- )
764-
765- if r .status_code == ACCEPTED :
766- logger .debug (
767- f"Response code { ACCEPTED } from server indicates ingestion command was accepted "
768- + "but not yet applied on the server. It's possible this command may fail later."
769- )
828+ self ._validate_staging_http_response (r , "file upload" )
770829
771830 @log_latency (StatementType .SQL )
772831 def _handle_staging_get (
@@ -816,6 +875,7 @@ def execute(
816875 operation : str ,
817876 parameters : Optional [TParameterCollection ] = None ,
818877 enforce_embedded_schema_correctness = False ,
878+ input_stream : Optional [BinaryIO ] = None ,
819879 ) -> "Cursor" :
820880 """
821881 Execute a query and wait for execution to complete.
@@ -852,7 +912,6 @@ def execute(
852912 logger .debug (
853913 "Cursor.execute(operation=%s, parameters=%s)" , operation , parameters
854914 )
855-
856915 param_approach = self ._determine_parameter_approach (parameters )
857916 if param_approach == ParameterApproach .NONE :
858917 prepared_params = NO_NATIVE_PARAMS
@@ -890,7 +949,8 @@ def execute(
890949
891950 if self .active_result_set and self .active_result_set .is_staging_operation :
892951 self ._handle_staging_operation (
893- staging_allowed_local_path = self .connection .staging_allowed_local_path
952+ staging_allowed_local_path = self .connection .staging_allowed_local_path ,
953+ input_stream = input_stream ,
894954 )
895955
896956 return self
0 commit comments