327
327
import random
328
328
import re
329
329
import string
330
- import time
331
- from functools import wraps
332
- from typing import Any
330
+ from typing import Any , Dict
333
331
from typing import Iterator
334
332
from typing import List
335
333
from typing import Tuple
336
- from typing import TypedDict
337
334
338
- from ansible .errors import AnsibleConnectionFailure
339
335
from ansible .errors import AnsibleError
340
336
from ansible .errors import AnsibleFileNotFound
341
337
from ansible .module_utils ._text import to_bytes
353
349
SSMSessionManager ,
354
350
)
355
351
356
- display = Display ()
357
-
352
+ from ansible_collections .community .aws .plugins .plugin_utils .ssm .filetransfermanager import FileTransferManager
353
+ from ansible_collections .community .aws .plugins .plugin_utils .ssm .common import ssm_retry
354
+ from ansible_collections .community .aws .plugins .plugin_utils .ssm .common import CommandResult
358
355
359
- def _ssm_retry (func : Any ) -> Any :
360
- """
361
- Decorator to retry in the case of a connection failure
362
- Will retry if:
363
- * an exception is caught
364
- Will not retry if
365
- * remaining_tries is <2
366
- * retries limit reached
367
- """
368
-
369
- @wraps (func )
370
- def wrapped (self , * args : Any , ** kwargs : Any ) -> Any :
371
- remaining_tries = int (self .get_option ("reconnection_retries" )) + 1
372
- cmd_summary = f"{ args [0 ]} ..."
373
- for attempt in range (remaining_tries ):
374
- try :
375
- return_tuple = func (self , * args , ** kwargs )
376
- self .verbosity_display (4 , f"ssm_retry: (success) { to_text (return_tuple )} " )
377
- break
378
-
379
- except (AnsibleConnectionFailure , Exception ) as e :
380
- if attempt == remaining_tries - 1 :
381
- raise
382
- pause = 2 ** attempt - 1
383
- pause = min (pause , 30 )
384
-
385
- if isinstance (e , AnsibleConnectionFailure ):
386
- msg = f"ssm_retry: attempt: { attempt } , cmd ({ cmd_summary } ), pausing for { pause } seconds"
387
- else :
388
- msg = (
389
- f"ssm_retry: attempt: { attempt } , caught exception({ e } )"
390
- f"from cmd ({ cmd_summary } ),pausing for { pause } seconds"
391
- )
392
-
393
- self .verbosity_display (2 , msg )
394
-
395
- time .sleep (pause )
396
-
397
- # Do not attempt to reuse the existing session on retries
398
- # This will cause the SSM session to be completely restarted,
399
- # as well as reinitializing the boto3 clients
400
- self .close ()
401
-
402
- continue
403
356
404
- return return_tuple
405
-
406
- return wrapped
357
+ display = Display ()
407
358
408
359
409
360
def chunks (lst : List , n : int ) -> Iterator [List [Any ]]:
@@ -435,14 +386,13 @@ def filter_ansi(line: str, is_windows: bool) -> str:
435
386
return line
436
387
437
388
438
- class CommandResult ( TypedDict ) :
389
+ def escape_path ( path : str ) -> str :
439
390
"""
440
- A dictionary that contains the executed command results.
391
+ Converts a file path to a safe format by replacing backslashes with forward slashes.
392
+ :param path: The file path to escape.
393
+ :return: The escaped file path.
441
394
"""
442
-
443
- returncode : int
444
- stdout_combined : str
445
- stderr_combined : str
395
+ return path .replace ("\\ " , "/" )
446
396
447
397
448
398
class Connection (ConnectionBase , AwsConnectionPluginBase ):
@@ -457,8 +407,8 @@ class Connection(ConnectionBase, AwsConnectionPluginBase):
457
407
is_windows = False
458
408
459
409
_client = None
460
- s3_manager = None
461
- session_manager = None
410
+ _s3_manager = None
411
+ _session_manager = None
462
412
MARK_LENGTH = 26
463
413
464
414
def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
@@ -470,7 +420,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
470
420
self .host = self ._play_context .remote_addr
471
421
self ._instance_id = None
472
422
self .terminal_manager = TerminalManager (self )
473
- self .session_manager = None
423
+ self .reconnection_retries = self . get_option ( "reconnection_retries" )
474
424
475
425
if getattr (self ._shell , "SHELL_FAMILY" , "" ) == "powershell" :
476
426
self .delegate = None
@@ -482,6 +432,62 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
482
432
self ._shell_type = "powershell"
483
433
self .is_windows = True
484
434
435
+ @property
436
+ def s3_client (self ) -> None :
437
+ if self ._s3_manager is not None :
438
+ return self ._s3_manager .client
439
+ return None
440
+
441
+ @property
442
+ def s3_manager (self ) -> None :
443
+ if self ._s3_manager is None :
444
+ config = {"signature_version" : "s3v4" , "s3" : {"addressing_style" : self .get_option ("s3_addressing_style" )}}
445
+
446
+ bucket_endpoint_url = self .get_option ("bucket_endpoint_url" )
447
+ s3_endpoint_url , s3_region_name = S3ClientManager .get_bucket_endpoint (
448
+ bucket_name = self .get_option ("bucket_name" ),
449
+ bucket_endpoint_url = bucket_endpoint_url ,
450
+ access_key_id = self .get_option ("access_key_id" ),
451
+ secret_key_id = self .get_option ("secret_access_key" ),
452
+ session_token = self .get_option ("session_token" ),
453
+ region_name = self .get_option ("region" ),
454
+ profile_name = self .get_option ("profile" ),
455
+ )
456
+
457
+ s3_client = self ._get_boto_client (
458
+ "s3" , endpoint_url = s3_endpoint_url , region_name = s3_region_name , config = config
459
+ )
460
+
461
+ self ._s3_manager = S3ClientManager (s3_client )
462
+
463
+ return self ._s3_manager
464
+
465
+ @property
466
+ def session_manager (self ):
467
+ return self ._session_manager
468
+
469
+ @session_manager .setter
470
+ def session_manager (self , value ):
471
+ self ._session_manager = value
472
+
473
+ @property
474
+ def ssm_client (self ):
475
+ if self ._client is None :
476
+ config = {"signature_version" : "s3v4" , "s3" : {"addressing_style" : self .get_option ("s3_addressing_style" )}}
477
+
478
+ self ._client = self ._get_boto_client ("ssm" , region_name = self .get_option ("region" ), config = config )
479
+ return self ._client
480
+
481
+ @property
482
+ def instance_id (self ) -> str :
483
+ if not self ._instance_id :
484
+ self ._instance_id = self .host if self .get_option ("instance_id" ) is None else self .get_option ("instance_id" )
485
+ return self ._instance_id
486
+
487
+ @instance_id .setter
488
+ def instance_id (self , instance_id : str ) -> None :
489
+ self ._instance_id = instance_id
490
+
485
491
def __del__ (self ) -> None :
486
492
self .close ()
487
493
@@ -490,6 +496,7 @@ def _connect(self) -> Any:
490
496
self ._play_context .remote_user = getpass .getuser ()
491
497
if not self .session_manager :
492
498
self .start_session ()
499
+
493
500
return self
494
501
495
502
def _init_clients (self ) -> None :
@@ -500,39 +507,22 @@ def _init_clients(self) -> None:
500
507
501
508
self .verbosity_display (4 , "INITIALIZE BOTO3 CLIENTS" )
502
509
503
- # Create S3 and SSM clients
504
- config = {"signature_version" : "s3v4" , "s3" : {"addressing_style" : self .get_option ("s3_addressing_style" )}}
505
-
506
- bucket_endpoint_url = self .get_option ("bucket_endpoint_url" )
507
- s3_endpoint_url , s3_region_name = S3ClientManager .get_bucket_endpoint (
508
- bucket_name = self .get_option ("bucket_name" ),
509
- bucket_endpoint_url = bucket_endpoint_url ,
510
- access_key_id = self .get_option ("access_key_id" ),
511
- secret_key_id = self .get_option ("secret_access_key" ),
512
- session_token = self .get_option ("session_token" ),
513
- region_name = self .get_option ("region" ),
514
- profile_name = self .get_option ("profile" ),
515
- )
516
-
517
- self .verbosity_display (4 , f"BUCKET Information - Endpoint: { s3_endpoint_url } / Region: { s3_region_name } " )
518
-
519
- # Initialize S3ClientManager
520
- if not self .s3_manager :
521
- s3_client = self ._get_boto_client (
522
- "s3" , endpoint_url = s3_endpoint_url , region_name = s3_region_name , config = config
523
- )
524
- self .s3_manager = S3ClientManager (s3_client )
510
+ # Initialize S3 client
511
+ self .s3_manager
525
512
526
513
# Initialize SSM client
527
- if not self ._client :
528
- self ._client = self ._get_boto_client ("ssm" , region_name = self .get_option ("region" ), config = config )
514
+ self .ssm_client
529
515
530
- @property
531
- def s3_client (self ) -> None :
532
- client = None
533
- if self .s3_manager is not None :
534
- client = self .s3_manager .client
535
- return client
516
+ # Initialize FileTransferManager
517
+ self .file_transfer_manager = FileTransferManager (
518
+ bucket_name = self .get_option ("bucket_name" ),
519
+ instance_id = self .instance_id ,
520
+ s3_client = self .s3_client ,
521
+ reconnection_retries = self .reconnection_retries ,
522
+ verbosity_display = self .verbosity_display ,
523
+ close = self .close ,
524
+ exec_command = self .exec_command ,
525
+ )
536
526
537
527
def verbosity_display (self , level : int , message : str ) -> None :
538
528
"""
@@ -560,16 +550,6 @@ def reset(self) -> None:
560
550
self .close ()
561
551
self .start_session ()
562
552
563
- @property
564
- def instance_id (self ) -> str :
565
- if not self ._instance_id :
566
- self ._instance_id = self .host if self .get_option ("instance_id" ) is None else self .get_option ("instance_id" )
567
- return self ._instance_id
568
-
569
- @instance_id .setter
570
- def instance_id (self , instance_id : str ) -> None :
571
- self ._instance_id = instance_id
572
-
573
553
def get_executable (self ) -> str :
574
554
ssm_plugin_executable = self .get_option ("plugin" )
575
555
if ssm_plugin_executable :
@@ -597,11 +577,12 @@ def start_session(self) -> None:
597
577
598
578
if self .session_manager is None :
599
579
self .session_manager = SSMSessionManager (
600
- self ._client ,
580
+ self .ssm_client ,
601
581
self .instance_id ,
602
582
verbosity_display = self .verbosity_display ,
603
583
ssm_timeout = self .get_option ("ssm_timeout" ),
604
584
)
585
+
605
586
self .session_manager .start_session (
606
587
executable = executable ,
607
588
document_name = self .get_option ("ssm_document" ),
@@ -660,7 +641,7 @@ def generate_mark() -> str:
660
641
mark = "" .join ([random .choice (string .ascii_letters ) for i in range (Connection .MARK_LENGTH )])
661
642
return mark
662
643
663
- @_ssm_retry
644
+ @ssm_retry
664
645
def exec_command (self , cmd : str , in_data : bool = None , sudoable : bool = True ) -> Tuple [int , str , str ]:
665
646
"""When running a command on the SSM host, uses generate_mark to get delimiting strings"""
666
647
@@ -719,8 +700,31 @@ def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
719
700
720
701
return (returncode , stdout )
721
702
722
- def _escape_path (self , path : str ) -> str :
723
- return path .replace ("\\ " , "/" )
703
+ def generate_commands (self , in_path : str , out_path : str , ssm_action : str ) -> Tuple [str , str , Dict ]:
704
+ """
705
+ Generate S3 path and associated transport commands for file transfer.
706
+ :param in_path: The local file path to transfer from.
707
+ :param out_path: The remote file path to transfer to (used to build the S3 key).
708
+ :param ssm_action: The SSM action to perform ("get" or "put").
709
+ :return: A tuple containing:
710
+ - s3_path (str): The S3 key used for the transfer.
711
+ - commands (List[Dict]): A list of commands to be executed for the transfer.
712
+ - put_args (Dict): Additional arguments needed for a 'put' operation.
713
+ """
714
+ s3_path = escape_path (f"{ self .instance_id } /{ out_path } " )
715
+ command = ""
716
+ put_args = []
717
+ command , put_args = self .s3_manager .generate_host_commands (
718
+ self .get_option ("bucket_name" ),
719
+ self .get_option ("bucket_sse_mode" ),
720
+ self .get_option ("bucket_sse_kms_key_id" ),
721
+ s3_path ,
722
+ in_path ,
723
+ out_path ,
724
+ self .is_windows ,
725
+ ssm_action ,
726
+ )
727
+ return s3_path , command , put_args
724
728
725
729
def _exec_transport_commands (self , in_path : str , out_path : str , command : dict ) -> CommandResult :
726
730
"""
@@ -740,52 +744,7 @@ def _exec_transport_commands(self, in_path: str, out_path: str, command: dict) -
740
744
741
745
return returncode , stdout , stderr
742
746
743
- @_ssm_retry
744
- def _file_transport_command (
745
- self ,
746
- in_path : str ,
747
- out_path : str ,
748
- ssm_action : str ,
749
- ) -> CommandResult :
750
- """
751
- Transfer file(s) to/from host using an intermediate S3 bucket and then delete the file(s).
752
-
753
- :param in_path: The input path.
754
- :param out_path: The output path.
755
- :param ssm_action: The SSM action to perform ("get" or "put").
756
-
757
- :returns: The command's return code, stdout, and stderr in a tuple.
758
- """
759
-
760
- bucket_name = self .get_option ("bucket_name" )
761
- s3_path = self ._escape_path (f"{ self .instance_id } /{ out_path } " )
762
-
763
- command , put_args = self .s3_manager .generate_host_commands (
764
- bucket_name ,
765
- self .get_option ("bucket_sse_mode" ),
766
- self .get_option ("bucket_sse_kms_key_id" ),
767
- s3_path ,
768
- in_path ,
769
- out_path ,
770
- self .is_windows ,
771
- ssm_action ,
772
- )
773
-
774
- try :
775
- if ssm_action == "get" :
776
- result = self ._exec_transport_commands (in_path , out_path , command )
777
- with open (to_bytes (out_path , errors = "surrogate_or_strict" ), "wb" ) as data :
778
- self .s3_client .download_fileobj (bucket_name , s3_path , data )
779
- else :
780
- with open (to_bytes (in_path , errors = "surrogate_or_strict" ), "rb" ) as data :
781
- self .s3_client .upload_fileobj (data , bucket_name , s3_path , ExtraArgs = put_args )
782
- result = self ._exec_transport_commands (in_path , out_path , command )
783
- return result
784
- finally :
785
- # Remove the files from the bucket after they've been transferred
786
- self .s3_client .delete_object (Bucket = bucket_name , Key = s3_path )
787
-
788
- def put_file (self , in_path : str , out_path : str ) -> Tuple [int , str , str ]:
747
+ def put_file (self , in_path : str , out_path : str ) -> CommandResult :
789
748
"""transfer a file from local to remote"""
790
749
791
750
super ().put_file (in_path , out_path )
@@ -794,15 +753,18 @@ def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
794
753
if not os .path .exists (to_bytes (in_path , errors = "surrogate_or_strict" )):
795
754
raise AnsibleFileNotFound (f"file or module does not exist: { in_path } " )
796
755
797
- return self ._file_transport_command (in_path , out_path , "put" )
756
+ s3_path , command , put_args = self .generate_commands (in_path , out_path , "put" )
757
+ return self .file_transfer_manager ._file_transport_command (in_path , out_path , "put" , command , put_args , s3_path )
798
758
799
- def fetch_file (self , in_path : str , out_path : str ) -> Tuple [ int , str , str ] :
759
+ def fetch_file (self , in_path : str , out_path : str ) -> CommandResult :
800
760
"""fetch a file from remote to local"""
801
761
802
762
super ().fetch_file (in_path , out_path )
803
763
804
764
self .verbosity_display (3 , f"FETCH { in_path } TO { out_path } " )
805
- return self ._file_transport_command (in_path , out_path , "get" )
765
+
766
+ s3_path , command , put_args = self .generate_commands (in_path , out_path , "get" )
767
+ return self .file_transfer_manager ._file_transport_command (in_path , out_path , "get" , command , put_args , s3_path )
806
768
807
769
def close (self ) -> None :
808
770
"""terminate the connection"""
0 commit comments