Skip to content

Commit 705a8fa

Browse files
aws_ssm - Add FileTransferManager class (#2273)
SUMMARY Add FileTransferManager class ISSUE TYPE Feature Pull Request COMPONENT NAME aws_ssm Reviewed-by: GomathiselviS <gomathiselvi@gmail.com> Reviewed-by: Mike Graves <mgraves@redhat.com> Reviewed-by: Bikouo Aubin Reviewed-by: Mandar Kulkarni <mandar242@gmail.com>
1 parent 4579add commit 705a8fa

File tree

10 files changed

+471
-215
lines changed

10 files changed

+471
-215
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
---
2+
minor_changes:
3+
- aws_ssm - Refactor connection/aws_ssm to add new ``FileTransferManager`` class and move relevant methods to the new class (https://github.com/ansible-collections/community.aws/pull/2273).

plugins/connection/aws_ssm.py

Lines changed: 118 additions & 156 deletions
Original file line numberDiff line numberDiff line change
@@ -327,15 +327,11 @@
327327
import random
328328
import re
329329
import string
330-
import time
331-
from functools import wraps
332-
from typing import Any
330+
from typing import Any, Dict
333331
from typing import Iterator
334332
from typing import List
335333
from typing import Tuple
336-
from typing import TypedDict
337334

338-
from ansible.errors import AnsibleConnectionFailure
339335
from ansible.errors import AnsibleError
340336
from ansible.errors import AnsibleFileNotFound
341337
from ansible.module_utils._text import to_bytes
@@ -353,57 +349,12 @@
353349
SSMSessionManager,
354350
)
355351

356-
display = Display()
357-
352+
from ansible_collections.community.aws.plugins.plugin_utils.ssm.filetransfermanager import FileTransferManager
353+
from ansible_collections.community.aws.plugins.plugin_utils.ssm.common import ssm_retry
354+
from ansible_collections.community.aws.plugins.plugin_utils.ssm.common import CommandResult
358355

359-
def _ssm_retry(func: Any) -> Any:
360-
"""
361-
Decorator to retry in the case of a connection failure
362-
Will retry if:
363-
* an exception is caught
364-
Will not retry if
365-
* remaining_tries is <2
366-
* retries limit reached
367-
"""
368-
369-
@wraps(func)
370-
def wrapped(self, *args: Any, **kwargs: Any) -> Any:
371-
remaining_tries = int(self.get_option("reconnection_retries")) + 1
372-
cmd_summary = f"{args[0]}..."
373-
for attempt in range(remaining_tries):
374-
try:
375-
return_tuple = func(self, *args, **kwargs)
376-
self.verbosity_display(4, f"ssm_retry: (success) {to_text(return_tuple)}")
377-
break
378-
379-
except (AnsibleConnectionFailure, Exception) as e:
380-
if attempt == remaining_tries - 1:
381-
raise
382-
pause = 2**attempt - 1
383-
pause = min(pause, 30)
384-
385-
if isinstance(e, AnsibleConnectionFailure):
386-
msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds"
387-
else:
388-
msg = (
389-
f"ssm_retry: attempt: {attempt}, caught exception({e})"
390-
f"from cmd ({cmd_summary}),pausing for {pause} seconds"
391-
)
392-
393-
self.verbosity_display(2, msg)
394-
395-
time.sleep(pause)
396-
397-
# Do not attempt to reuse the existing session on retries
398-
# This will cause the SSM session to be completely restarted,
399-
# as well as reinitializing the boto3 clients
400-
self.close()
401-
402-
continue
403356

404-
return return_tuple
405-
406-
return wrapped
357+
display = Display()
407358

408359

409360
def chunks(lst: List, n: int) -> Iterator[List[Any]]:
@@ -435,14 +386,13 @@ def filter_ansi(line: str, is_windows: bool) -> str:
435386
return line
436387

437388

438-
class CommandResult(TypedDict):
389+
def escape_path(path: str) -> str:
439390
"""
440-
A dictionary that contains the executed command results.
391+
Converts a file path to a safe format by replacing backslashes with forward slashes.
392+
:param path: The file path to escape.
393+
:return: The escaped file path.
441394
"""
442-
443-
returncode: int
444-
stdout_combined: str
445-
stderr_combined: str
395+
return path.replace("\\", "/")
446396

447397

448398
class Connection(ConnectionBase, AwsConnectionPluginBase):
@@ -457,8 +407,8 @@ class Connection(ConnectionBase, AwsConnectionPluginBase):
457407
is_windows = False
458408

459409
_client = None
460-
s3_manager = None
461-
session_manager = None
410+
_s3_manager = None
411+
_session_manager = None
462412
MARK_LENGTH = 26
463413

464414
def __init__(self, *args: Any, **kwargs: Any) -> None:
@@ -470,7 +420,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
470420
self.host = self._play_context.remote_addr
471421
self._instance_id = None
472422
self.terminal_manager = TerminalManager(self)
473-
self.session_manager = None
423+
self.reconnection_retries = self.get_option("reconnection_retries")
474424

475425
if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
476426
self.delegate = None
@@ -482,6 +432,62 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
482432
self._shell_type = "powershell"
483433
self.is_windows = True
484434

435+
@property
436+
def s3_client(self) -> None:
437+
if self._s3_manager is not None:
438+
return self._s3_manager.client
439+
return None
440+
441+
@property
442+
def s3_manager(self) -> None:
443+
if self._s3_manager is None:
444+
config = {"signature_version": "s3v4", "s3": {"addressing_style": self.get_option("s3_addressing_style")}}
445+
446+
bucket_endpoint_url = self.get_option("bucket_endpoint_url")
447+
s3_endpoint_url, s3_region_name = S3ClientManager.get_bucket_endpoint(
448+
bucket_name=self.get_option("bucket_name"),
449+
bucket_endpoint_url=bucket_endpoint_url,
450+
access_key_id=self.get_option("access_key_id"),
451+
secret_key_id=self.get_option("secret_access_key"),
452+
session_token=self.get_option("session_token"),
453+
region_name=self.get_option("region"),
454+
profile_name=self.get_option("profile"),
455+
)
456+
457+
s3_client = self._get_boto_client(
458+
"s3", endpoint_url=s3_endpoint_url, region_name=s3_region_name, config=config
459+
)
460+
461+
self._s3_manager = S3ClientManager(s3_client)
462+
463+
return self._s3_manager
464+
465+
@property
466+
def session_manager(self):
467+
return self._session_manager
468+
469+
@session_manager.setter
470+
def session_manager(self, value):
471+
self._session_manager = value
472+
473+
@property
474+
def ssm_client(self):
475+
if self._client is None:
476+
config = {"signature_version": "s3v4", "s3": {"addressing_style": self.get_option("s3_addressing_style")}}
477+
478+
self._client = self._get_boto_client("ssm", region_name=self.get_option("region"), config=config)
479+
return self._client
480+
481+
@property
482+
def instance_id(self) -> str:
483+
if not self._instance_id:
484+
self._instance_id = self.host if self.get_option("instance_id") is None else self.get_option("instance_id")
485+
return self._instance_id
486+
487+
@instance_id.setter
488+
def instance_id(self, instance_id: str) -> None:
489+
self._instance_id = instance_id
490+
485491
def __del__(self) -> None:
486492
self.close()
487493

@@ -490,6 +496,7 @@ def _connect(self) -> Any:
490496
self._play_context.remote_user = getpass.getuser()
491497
if not self.session_manager:
492498
self.start_session()
499+
493500
return self
494501

495502
def _init_clients(self) -> None:
@@ -500,39 +507,22 @@ def _init_clients(self) -> None:
500507

501508
self.verbosity_display(4, "INITIALIZE BOTO3 CLIENTS")
502509

503-
# Create S3 and SSM clients
504-
config = {"signature_version": "s3v4", "s3": {"addressing_style": self.get_option("s3_addressing_style")}}
505-
506-
bucket_endpoint_url = self.get_option("bucket_endpoint_url")
507-
s3_endpoint_url, s3_region_name = S3ClientManager.get_bucket_endpoint(
508-
bucket_name=self.get_option("bucket_name"),
509-
bucket_endpoint_url=bucket_endpoint_url,
510-
access_key_id=self.get_option("access_key_id"),
511-
secret_key_id=self.get_option("secret_access_key"),
512-
session_token=self.get_option("session_token"),
513-
region_name=self.get_option("region"),
514-
profile_name=self.get_option("profile"),
515-
)
516-
517-
self.verbosity_display(4, f"BUCKET Information - Endpoint: {s3_endpoint_url} / Region: {s3_region_name}")
518-
519-
# Initialize S3ClientManager
520-
if not self.s3_manager:
521-
s3_client = self._get_boto_client(
522-
"s3", endpoint_url=s3_endpoint_url, region_name=s3_region_name, config=config
523-
)
524-
self.s3_manager = S3ClientManager(s3_client)
510+
# Initialize S3 client
511+
self.s3_manager
525512

526513
# Initialize SSM client
527-
if not self._client:
528-
self._client = self._get_boto_client("ssm", region_name=self.get_option("region"), config=config)
514+
self.ssm_client
529515

530-
@property
531-
def s3_client(self) -> None:
532-
client = None
533-
if self.s3_manager is not None:
534-
client = self.s3_manager.client
535-
return client
516+
# Initialize FileTransferManager
517+
self.file_transfer_manager = FileTransferManager(
518+
bucket_name=self.get_option("bucket_name"),
519+
instance_id=self.instance_id,
520+
s3_client=self.s3_client,
521+
reconnection_retries=self.reconnection_retries,
522+
verbosity_display=self.verbosity_display,
523+
close=self.close,
524+
exec_command=self.exec_command,
525+
)
536526

537527
def verbosity_display(self, level: int, message: str) -> None:
538528
"""
@@ -560,16 +550,6 @@ def reset(self) -> None:
560550
self.close()
561551
self.start_session()
562552

563-
@property
564-
def instance_id(self) -> str:
565-
if not self._instance_id:
566-
self._instance_id = self.host if self.get_option("instance_id") is None else self.get_option("instance_id")
567-
return self._instance_id
568-
569-
@instance_id.setter
570-
def instance_id(self, instance_id: str) -> None:
571-
self._instance_id = instance_id
572-
573553
def get_executable(self) -> str:
574554
ssm_plugin_executable = self.get_option("plugin")
575555
if ssm_plugin_executable:
@@ -597,11 +577,12 @@ def start_session(self) -> None:
597577

598578
if self.session_manager is None:
599579
self.session_manager = SSMSessionManager(
600-
self._client,
580+
self.ssm_client,
601581
self.instance_id,
602582
verbosity_display=self.verbosity_display,
603583
ssm_timeout=self.get_option("ssm_timeout"),
604584
)
585+
605586
self.session_manager.start_session(
606587
executable=executable,
607588
document_name=self.get_option("ssm_document"),
@@ -660,7 +641,7 @@ def generate_mark() -> str:
660641
mark = "".join([random.choice(string.ascii_letters) for i in range(Connection.MARK_LENGTH)])
661642
return mark
662643

663-
@_ssm_retry
644+
@ssm_retry
664645
def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) -> Tuple[int, str, str]:
665646
"""When running a command on the SSM host, uses generate_mark to get delimiting strings"""
666647

@@ -719,8 +700,31 @@ def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
719700

720701
return (returncode, stdout)
721702

722-
def _escape_path(self, path: str) -> str:
723-
return path.replace("\\", "/")
703+
def generate_commands(self, in_path: str, out_path: str, ssm_action: str) -> Tuple[str, str, Dict]:
704+
"""
705+
Generate S3 path and associated transport commands for file transfer.
706+
:param in_path: The local file path to transfer from.
707+
:param out_path: The remote file path to transfer to (used to build the S3 key).
708+
:param ssm_action: The SSM action to perform ("get" or "put").
709+
:return: A tuple containing:
710+
- s3_path (str): The S3 key used for the transfer.
711+
- commands (List[Dict]): A list of commands to be executed for the transfer.
712+
- put_args (Dict): Additional arguments needed for a 'put' operation.
713+
"""
714+
s3_path = escape_path(f"{self.instance_id}/{out_path}")
715+
command = ""
716+
put_args = []
717+
command, put_args = self.s3_manager.generate_host_commands(
718+
self.get_option("bucket_name"),
719+
self.get_option("bucket_sse_mode"),
720+
self.get_option("bucket_sse_kms_key_id"),
721+
s3_path,
722+
in_path,
723+
out_path,
724+
self.is_windows,
725+
ssm_action,
726+
)
727+
return s3_path, command, put_args
724728

725729
def _exec_transport_commands(self, in_path: str, out_path: str, command: dict) -> CommandResult:
726730
"""
@@ -740,52 +744,7 @@ def _exec_transport_commands(self, in_path: str, out_path: str, command: dict) -
740744

741745
return returncode, stdout, stderr
742746

743-
@_ssm_retry
744-
def _file_transport_command(
745-
self,
746-
in_path: str,
747-
out_path: str,
748-
ssm_action: str,
749-
) -> CommandResult:
750-
"""
751-
Transfer file(s) to/from host using an intermediate S3 bucket and then delete the file(s).
752-
753-
:param in_path: The input path.
754-
:param out_path: The output path.
755-
:param ssm_action: The SSM action to perform ("get" or "put").
756-
757-
:returns: The command's return code, stdout, and stderr in a tuple.
758-
"""
759-
760-
bucket_name = self.get_option("bucket_name")
761-
s3_path = self._escape_path(f"{self.instance_id}/{out_path}")
762-
763-
command, put_args = self.s3_manager.generate_host_commands(
764-
bucket_name,
765-
self.get_option("bucket_sse_mode"),
766-
self.get_option("bucket_sse_kms_key_id"),
767-
s3_path,
768-
in_path,
769-
out_path,
770-
self.is_windows,
771-
ssm_action,
772-
)
773-
774-
try:
775-
if ssm_action == "get":
776-
result = self._exec_transport_commands(in_path, out_path, command)
777-
with open(to_bytes(out_path, errors="surrogate_or_strict"), "wb") as data:
778-
self.s3_client.download_fileobj(bucket_name, s3_path, data)
779-
else:
780-
with open(to_bytes(in_path, errors="surrogate_or_strict"), "rb") as data:
781-
self.s3_client.upload_fileobj(data, bucket_name, s3_path, ExtraArgs=put_args)
782-
result = self._exec_transport_commands(in_path, out_path, command)
783-
return result
784-
finally:
785-
# Remove the files from the bucket after they've been transferred
786-
self.s3_client.delete_object(Bucket=bucket_name, Key=s3_path)
787-
788-
def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
747+
def put_file(self, in_path: str, out_path: str) -> CommandResult:
789748
"""transfer a file from local to remote"""
790749

791750
super().put_file(in_path, out_path)
@@ -794,15 +753,18 @@ def put_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
794753
if not os.path.exists(to_bytes(in_path, errors="surrogate_or_strict")):
795754
raise AnsibleFileNotFound(f"file or module does not exist: {in_path}")
796755

797-
return self._file_transport_command(in_path, out_path, "put")
756+
s3_path, command, put_args = self.generate_commands(in_path, out_path, "put")
757+
return self.file_transfer_manager._file_transport_command(in_path, out_path, "put", command, put_args, s3_path)
798758

799-
def fetch_file(self, in_path: str, out_path: str) -> Tuple[int, str, str]:
759+
def fetch_file(self, in_path: str, out_path: str) -> CommandResult:
800760
"""fetch a file from remote to local"""
801761

802762
super().fetch_file(in_path, out_path)
803763

804764
self.verbosity_display(3, f"FETCH {in_path} TO {out_path}")
805-
return self._file_transport_command(in_path, out_path, "get")
765+
766+
s3_path, command, put_args = self.generate_commands(in_path, out_path, "get")
767+
return self.file_transfer_manager._file_transport_command(in_path, out_path, "get", command, put_args, s3_path)
806768

807769
def close(self) -> None:
808770
"""terminate the connection"""

0 commit comments

Comments
 (0)