diff --git a/hazelcast/asyncio/__init__.py b/hazelcast/asyncio/__init__.py new file mode 100644 index 0000000000..6137aac760 --- /dev/null +++ b/hazelcast/asyncio/__init__.py @@ -0,0 +1,2 @@ +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.internal.asyncio_proxy.map import Map diff --git a/hazelcast/asyncio/client.py b/hazelcast/asyncio/client.py new file mode 100644 index 0000000000..0f6db252fe --- /dev/null +++ b/hazelcast/asyncio/client.py @@ -0,0 +1,367 @@ +import asyncio +import logging +import sys +import typing + +from hazelcast.internal.asyncio_cluster import ClusterService, _InternalClusterService +from hazelcast.internal.asyncio_compact import CompactSchemaService +from hazelcast.config import Config +from hazelcast.internal.asyncio_connection import ConnectionManager, DefaultAddressProvider +from hazelcast.core import DistributedObjectEvent, DistributedObjectInfo +from hazelcast.cp import CPSubsystem, ProxySessionManager +from hazelcast.discovery import HazelcastCloudAddressProvider +from hazelcast.errors import IllegalStateError, InvalidConfigurationError +from hazelcast.internal.asyncio_invocation import InvocationService, Invocation +from hazelcast.lifecycle import LifecycleService, LifecycleState, _InternalLifecycleService +from hazelcast.internal.asyncio_listener import ClusterViewListenerService, ListenerService +from hazelcast.near_cache import NearCacheManager +from hazelcast.partition import PartitionService, _InternalPartitionService +from hazelcast.protocol.codec import ( + client_add_distributed_object_listener_codec, + client_get_distributed_objects_codec, + client_remove_distributed_object_listener_codec, +) +from hazelcast.internal.asyncio_proxy.manager import ( + MAP_SERVICE, + ProxyManager, +) +from hazelcast.internal.asyncio_proxy.base import Proxy +from hazelcast.internal.asyncio_proxy.map import Map +from hazelcast.internal.asyncio_reactor import AsyncioReactor +from hazelcast.serialization import 
SerializationServiceV1 +from hazelcast.sql import SqlService, _InternalSqlService +from hazelcast.statistics import Statistics +from hazelcast.types import KeyType, ValueType, ItemType, MessageType +from hazelcast.util import AtomicInteger, RoundRobinLB + +__all__ = ("HazelcastClient",) + +_logger = logging.getLogger(__name__) + + +class HazelcastClient: + + _CLIENT_ID = AtomicInteger() + + @classmethod + async def create_and_start(cls, config: Config | None = None, **kwargs) -> "HazelcastClient": + client = HazelcastClient(config, **kwargs) + await client._start() + return client + + def __init__(self, config: Config | None = None, **kwargs): + if config: + if kwargs: + raise InvalidConfigurationError( + "Ambiguous client configuration is found. Either provide " + "the config object as the only parameter, or do not " + "pass it and use keyword arguments to configure the " + "client." + ) + else: + config = Config.from_dict(kwargs) + + self._config = config + self._context = _ClientContext() + client_id = HazelcastClient._CLIENT_ID.get_and_increment() + self._name = self._create_client_name(client_id) + self._reactor = AsyncioReactor() + self._serialization_service = SerializationServiceV1(config) + self._near_cache_manager = NearCacheManager(config, self._serialization_service) + self._internal_lifecycle_service = _InternalLifecycleService(config) + self._lifecycle_service = LifecycleService(self._internal_lifecycle_service) + self._internal_cluster_service = _InternalClusterService(self, config) + self._cluster_service = ClusterService(self._internal_cluster_service) + self._invocation_service = InvocationService(self, config, self._reactor) + self._compact_schema_service = CompactSchemaService( + self._serialization_service.compact_stream_serializer, + self._invocation_service, + self._cluster_service, + self._reactor, + self._config, + ) + self._address_provider = self._create_address_provider() + self._internal_partition_service = 
_InternalPartitionService(self) + self._partition_service = PartitionService( + self._internal_partition_service, + self._serialization_service, + self._compact_schema_service.send_schema_and_retry, + ) + self._connection_manager = ConnectionManager( + self, + config, + self._reactor, + self._address_provider, + self._internal_lifecycle_service, + self._internal_partition_service, + self._internal_cluster_service, + self._invocation_service, + self._near_cache_manager, + self._send_state_to_cluster, + ) + self._load_balancer = self._init_load_balancer(config) + self._listener_service = ListenerService( + self, + config, + self._connection_manager, + self._invocation_service, + self._compact_schema_service, + ) + self._proxy_manager = ProxyManager(self._context) + self._cp_subsystem = CPSubsystem(self._context) + self._proxy_session_manager = ProxySessionManager(self._context) + self._lock_reference_id_generator = AtomicInteger(1) + self._statistics = Statistics( + self, + config, + self._reactor, + self._connection_manager, + self._invocation_service, + self._near_cache_manager, + ) + self._cluster_view_listener = ClusterViewListenerService( + self, + self._connection_manager, + self._internal_partition_service, + self._internal_cluster_service, + self._invocation_service, + ) + self._shutdown_lock = asyncio.Lock() + self._invocation_service.init( + self._internal_partition_service, + self._connection_manager, + self._listener_service, + self._compact_schema_service, + ) + self._internal_sql_service = _InternalSqlService( + self._connection_manager, + self._serialization_service, + self._invocation_service, + self._compact_schema_service.send_schema_and_retry, + ) + self._sql_service = SqlService(self._internal_sql_service) + self._init_context() + + def _init_context(self): + self._context.init_context( + self, + self._config, + self._invocation_service, + self._internal_partition_service, + self._internal_cluster_service, + self._connection_manager, + 
self._serialization_service, + self._listener_service, + self._proxy_manager, + self._near_cache_manager, + self._lock_reference_id_generator, + self._name, + self._proxy_session_manager, + self._reactor, + self._compact_schema_service, + ) + + async def _start(self): + self._reactor.start() + try: + self._internal_lifecycle_service.start() + self._invocation_service.start() + membership_listeners = self._config.membership_listeners + self._internal_cluster_service.start(self._connection_manager, membership_listeners) + self._cluster_view_listener.start() + await self._connection_manager.start(self._load_balancer) + sync_start = not self._config.async_start + if sync_start: + await self._internal_cluster_service.wait_initial_member_list_fetched() + await self._connection_manager.connect_to_all_cluster_members(sync_start) + self._listener_service.start() + await self._invocation_service.add_backup_listener() + self._load_balancer.init(self._cluster_service) + self._statistics.start() + except Exception: + await self.shutdown() + raise + _logger.info("Client started") + + async def get_map(self, name: str) -> Map[KeyType, ValueType]: + return await self._proxy_manager.get_or_create(MAP_SERVICE, name) + + async def add_distributed_object_listener( + self, listener_func: typing.Callable[[DistributedObjectEvent], None] + ) -> str: + is_smart = self._config.smart_routing + codec = client_add_distributed_object_listener_codec + request = codec.encode_request(is_smart) + + def handle_distributed_object_event(name, service_name, event_type, source): + event = DistributedObjectEvent(name, service_name, event_type, source) + listener_func(event) + + def event_handler(client_message): + return codec.handle(client_message, handle_distributed_object_event) + + return await self._listener_service.register_listener( + request, + codec.decode_response, + client_remove_distributed_object_listener_codec.encode_request, + event_handler, + ) + + async def 
remove_distributed_object_listener(self, registration_id: str) -> bool: + return await self._listener_service.deregister_listener(registration_id) + + async def get_distributed_objects(self) -> typing.List[Proxy]: + request = client_get_distributed_objects_codec.encode_request() + invocation = Invocation(request, response_handler=lambda m: m) + await self._invocation_service.ainvoke(invocation) + + local_distributed_object_infos = { + DistributedObjectInfo(dist_obj.service_name, dist_obj.name) + for dist_obj in self._proxy_manager.get_distributed_objects() + } + + response = client_get_distributed_objects_codec.decode_response(invocation.future.result()) + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + for dist_obj_info in response: + local_distributed_object_infos.discard(dist_obj_info) + tg.create_task( + self._proxy_manager.get_or_create( + dist_obj_info.service_name, dist_obj_info.name, create_on_remote=False + ) + ) + + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + for dist_obj_info in local_distributed_object_infos: + tg.create_task( + self._proxy_manager.destroy_proxy( + dist_obj_info.service_name, dist_obj_info.name, destroy_on_remote=False + ) + ) + + return self._proxy_manager.get_distributed_objects() + + async def shutdown(self) -> None: + async with self._shutdown_lock: + if self._internal_lifecycle_service.running: + self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTTING_DOWN) + self._internal_lifecycle_service.shutdown() + self._proxy_session_manager.shutdown().result() + self._near_cache_manager.destroy_near_caches() + await self._connection_manager.shutdown() + self._invocation_service.shutdown() + self._statistics.shutdown() + self._reactor.shutdown() + self._internal_lifecycle_service.fire_lifecycle_event(LifecycleState.SHUTDOWN) + + @property + def name(self) -> str: + return self._name + + @property + def lifecycle_service(self) -> LifecycleService: + return 
self._lifecycle_service + + @property + def partition_service(self) -> PartitionService: + return self._partition_service + + @property + def cluster_service(self) -> ClusterService: + return self._cluster_service + + @property + def cp_subsystem(self) -> CPSubsystem: + return self._cp_subsystem + + def _create_address_provider(self): + config = self._config + cluster_members = config.cluster_members + address_list_provided = len(cluster_members) > 0 + cloud_discovery_token = config.cloud_discovery_token + cloud_enabled = cloud_discovery_token is not None + if address_list_provided and cloud_enabled: + raise IllegalStateError( + "Only one discovery method can be enabled at a time. " + "Cluster members given explicitly: %s, Hazelcast Cloud enabled: %s" + % (address_list_provided, cloud_enabled) + ) + + if cloud_enabled: + connection_timeout = self._get_connection_timeout(config) + return HazelcastCloudAddressProvider(cloud_discovery_token, connection_timeout) + + return DefaultAddressProvider(cluster_members) + + def _create_client_name(self, client_id): + client_name = self._config.client_name + if client_name: + return client_name + return "hz.client_%s" % client_id + + async def _send_state_to_cluster(self): + return await self._compact_schema_service.send_all_schemas() + + @staticmethod + def _get_connection_timeout(config): + timeout = config.connection_timeout + return sys.maxsize if timeout == 0 else timeout + + @staticmethod + def _init_load_balancer(config): + load_balancer = config.load_balancer + if not load_balancer: + load_balancer = RoundRobinLB() + return load_balancer + + +class _ClientContext: + def __init__(self): + self.client = None + self.config = None + self.invocation_service = None + self.partition_service = None + self.cluster_service = None + self.connection_manager = None + self.serialization_service = None + self.listener_service = None + self.proxy_manager = None + self.near_cache_manager = None + self.lock_reference_id_generator = None 
+ self.name = None + self.proxy_session_manager = None + self.reactor = None + self.compact_schema_service = None + + def init_context( + self, + client, + config, + invocation_service, + partition_service, + cluster_service, + connection_manager, + serialization_service, + listener_service, + proxy_manager, + near_cache_manager, + lock_reference_id_generator, + name, + proxy_session_manager, + reactor, + compact_schema_service, + ): + self.client = client + self.config = config + self.invocation_service = invocation_service + self.partition_service = partition_service + self.cluster_service = cluster_service + self.connection_manager = connection_manager + self.serialization_service = serialization_service + self.listener_service = listener_service + self.proxy_manager = proxy_manager + self.near_cache_manager = near_cache_manager + self.lock_reference_id_generator = lock_reference_id_generator + self.name = name + self.proxy_session_manager = proxy_session_manager + self.reactor = reactor + self.compact_schema_service = compact_schema_service diff --git a/hazelcast/internal/__init__.py b/hazelcast/internal/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/hazelcast/internal/asyncio_cluster.py b/hazelcast/internal/asyncio_cluster.py new file mode 100644 index 0000000000..4ce2491b65 --- /dev/null +++ b/hazelcast/internal/asyncio_cluster.py @@ -0,0 +1,392 @@ +import asyncio +import logging +import threading +import typing +import uuid +from collections import OrderedDict + +from hazelcast.core import EndpointQualifier, ProtocolType, MemberInfo, Address +from hazelcast.errors import TargetDisconnectedError, IllegalStateError +from hazelcast.util import check_not_none + +_logger = logging.getLogger(__name__) + + +class _MemberListSnapshot: + __slots__ = ("version", "members") + + def __init__(self, version, members): + self.version = version + self.members = members + + +class ClientInfo: + """Local information of the client. 
+ + Attributes: + uuid: Unique id of this client instance. + address: Local address that is used to communicate with cluster. + name: Name of the client. + labels: Read-only set of all labels of this client. + """ + + __slots__ = ("uuid", "address", "name", "labels") + + def __init__( + self, client_uuid: uuid.UUID, address: Address, name: str, labels: typing.Set[str] + ): + self.uuid = client_uuid + self.address = address + self.name = name + self.labels = labels + + def __repr__(self): + return "ClientInfo(uuid=%s, address=%s, name=%s, labels=%s)" % ( + self.uuid, + self.address, + self.name, + self.labels, + ) + + +_EMPTY_SNAPSHOT = _MemberListSnapshot(-1, OrderedDict()) +_INITIAL_MEMBERS_TIMEOUT_SECONDS = 120 +_CLIENT_ENDPOINT_QUALIFIER = EndpointQualifier(ProtocolType.CLIENT, None) +_MEMBER_ENDPOINT_QUALIFIER = EndpointQualifier(ProtocolType.MEMBER, None) + + +class ClusterService: + """ + Cluster service for Hazelcast clients. + + It provides access to the members in the cluster + and one can register for changes in the cluster members. + """ + + def __init__(self, internal_cluster_service): + self._service = internal_cluster_service + + def add_listener( + self, + member_added: typing.Callable[[MemberInfo], None] = None, + member_removed: typing.Callable[[MemberInfo], None] = None, + fire_for_existing=False, + ) -> str: + """ + Adds a membership listener to listen for membership updates. + + It will be notified when a member is added to the cluster or removed + from the cluster. There is no check for duplicate registrations, + so if you register the listener twice, it will get events twice. + + Args: + member_added: Function to be called when a member is added to the + cluster. + member_removed: Function to be called when a member is removed + from the cluster. + fire_for_existing: Whether or not fire member_added for existing + members. + + Returns: + Registration id of the listener which will be used for removing + this listener. 
+ """ + return self._service.add_listener(member_added, member_removed, fire_for_existing) + + def remove_listener(self, registration_id: str) -> bool: + """ + Removes the specified membership listener. + + Args: + registration_id: Registration id of the listener to be removed. + + Returns: + ``True`` if the registration is removed, ``False`` otherwise. + """ + return self._service.remove_listener(registration_id) + + def get_members( + self, member_selector: typing.Callable[[MemberInfo], bool] = None + ) -> typing.List[MemberInfo]: + """ + Lists the current members in the cluster. + + Every member in the cluster returns the members in the same order. + To obtain the oldest member in the cluster, you can retrieve the first + item in the list. + + Args: + member_selector: Function to filter members to return. If not + provided, the returned list will contain all the available + cluster members. + + Returns: + Current members in the cluster + """ + return self._service.get_members(member_selector) + + +class _InternalClusterService: + def __init__(self, client, config): + self._client = client + self._connection_manager = None + self._labels = frozenset(config.labels) + self._listeners = {} + self._member_list_snapshot = _EMPTY_SNAPSHOT + self._initial_list_fetched = asyncio.Event() + + def start(self, connection_manager, membership_listeners): + self._connection_manager = connection_manager + for listener in membership_listeners: + self.add_listener(*listener) + + def get_member(self, member_uuid): + check_not_none(uuid, "UUID must not be null") + snapshot = self._member_list_snapshot + return snapshot.members.get(member_uuid, None) + + def get_members(self, member_selector=None): + snapshot = self._member_list_snapshot + if not member_selector: + return list(snapshot.members.values()) + + members = [] + for member in snapshot.members.values(): + if member_selector(member): + members.append(member) + return members + + def size(self): + """ + Returns: + int: Size of 
the cluster. + """ + snapshot = self._member_list_snapshot + return len(snapshot.members) + + def get_local_client(self): + """ + Returns: + hazelcast.cluster.ClientInfo: The client info. + """ + connection_manager = self._connection_manager + connection = connection_manager.get_random_connection() + local_address = None if not connection else connection.local_address + return ClientInfo( + connection_manager.client_uuid, local_address, self._client.name, self._labels + ) + + def add_listener(self, member_added=None, member_removed=None, fire_for_existing=False): + registration_id = str(uuid.uuid4()) + self._listeners[registration_id] = (member_added, member_removed) + + if fire_for_existing and member_added: + snapshot = self._member_list_snapshot + for member in snapshot.members.values(): + member_added(member) + + return registration_id + + def remove_listener(self, registration_id): + try: + self._listeners.pop(registration_id) + return True + except KeyError: + return False + + async def wait_initial_member_list_fetched(self): + """Blocks until the initial member list is fetched from the cluster. + + If it is not received within the timeout, an error is raised. 
+ + Raises: + IllegalStateError: If the member list could not be fetched + """ + try: + await asyncio.wait_for( + self._initial_list_fetched.wait(), _INITIAL_MEMBERS_TIMEOUT_SECONDS + ) + except TimeoutError: + raise IllegalStateError("Could not get initial member list from cluster!") + + def clear_member_list_version(self): + _logger.debug("Resetting the member list version") + + current = self._member_list_snapshot + if current is not _EMPTY_SNAPSHOT: + self._member_list_snapshot = _MemberListSnapshot(0, current.members) + + def clear_member_list(self): + _logger.debug("Resetting the member list") + + current = self._member_list_snapshot + if current is not _EMPTY_SNAPSHOT: + previous_members = current.members + snapshot = _MemberListSnapshot(0, {}) + self._member_list_snapshot = snapshot + dead_members, new_members = self._detect_membership_events( + previous_members, snapshot.members + ) + self._fire_membership_events(dead_members, new_members) + + def handle_members_view_event(self, version, member_infos): + snapshot = self._create_snapshot(version, member_infos) + if _logger.isEnabledFor(logging.DEBUG): + _logger.debug( + "Handling new snapshot with membership version: %s, member string: %s", + version, + self._members_string(snapshot.members), + ) + + current = self._member_list_snapshot + if version > current.version: + self._apply_new_state_and_fire_events(current, snapshot) + + if current is _EMPTY_SNAPSHOT: + self._initial_list_fetched.set() + + def _apply_new_state_and_fire_events(self, current, snapshot): + self._member_list_snapshot = snapshot + dead_members, new_members = self._detect_membership_events( + current.members, snapshot.members + ) + self._fire_membership_events(dead_members, new_members) + + def _fire_membership_events(self, dead_members, new_members): + # Removal events should be fired first + for dead_member in dead_members: + for _, handler in self._listeners.values(): + if handler: + try: + handler(dead_member) + except: + 
_logger.exception("Exception in membership listener") + + for new_member in new_members: + for handler, _ in self._listeners.values(): + if handler: + try: + handler(new_member) + except: + _logger.exception("Exception in membership listener") + + def _detect_membership_events(self, previous_members, current_members): + new_members = [] + dead_members = set(previous_members.values()) + for member in current_members.values(): + try: + dead_members.remove(member) + except KeyError: + new_members.append(member) + + for dead_member in dead_members: + connection = self._connection_manager.get_connection(dead_member.uuid) + if connection: + connection.close_connection( + None, + TargetDisconnectedError( + "The client has closed the connection to this member, " + "after receiving a member left event from the cluster. " + "%s" % connection + ), + ) + + if (len(new_members) + len(dead_members)) > 0: + if len(current_members) > 0: + _logger.info(self._members_string(current_members)) + + return dead_members, new_members + + @staticmethod + def _members_string(members): + n = len(members) + return "\n\nMembers [%s] {\n\t%s\n}\n" % (n, "\n\t".join(map(str, members.values()))) + + @staticmethod + def _create_snapshot(version, member_infos): + new_members = OrderedDict() + for member_info in member_infos: + address_map = member_info.address_map + if address_map: + address = address_map.get( + _CLIENT_ENDPOINT_QUALIFIER, + address_map.get(_MEMBER_ENDPOINT_QUALIFIER, None), + ) + member_info.address = address + else: + # It might be None on 4.0 servers. + member_info.address_map = { + _MEMBER_ENDPOINT_QUALIFIER: member_info.address, + } + + new_members[member_info.uuid] = member_info + return _MemberListSnapshot(version, new_members) + + +class VectorClock: + """Vector clock consisting of distinct replica logical clocks. + + The vector clock may be read from different thread but concurrent + updates must be synchronized externally. There is no guarantee for + concurrent updates. 
+ + See Also: + https://en.wikipedia.org/wiki/Vector_clock + """ + + def __init__(self): + self._replica_timestamps = {} + + def is_after(self, other: "VectorClock") -> bool: + """Returns ``True`` if this vector clock is causally strictly after the + provided vector clock. This means that it the provided clock is neither + equal to, greater than or concurrent to this vector clock. + + Args: + other: Vector clock to be compared + + Returns: + ``True`` if this vector clock is strictly after the other vector + clock, ``False`` otherwise. + """ + any_timestamp_greater = False + for replica_id, other_timestamp in other.entry_set(): + local_timestamp = self._replica_timestamps.get(replica_id) + + if local_timestamp is None or local_timestamp < other_timestamp: + return False + elif local_timestamp > other_timestamp: + any_timestamp_greater = True + + # there is at least one local timestamp greater or local vector clock has additional timestamps + return any_timestamp_greater or other.size() < self.size() + + def set_replica_timestamp(self, replica_id: str, timestamp: int) -> None: + """Sets the logical timestamp for the given replica ID. + + Args: + replica_id: Replica ID. + timestamp: Timestamp for the given replica ID. + """ + self._replica_timestamps[replica_id] = timestamp + + def entry_set(self) -> typing.List[typing.Tuple[str, int]]: + """Returns the entry set of the replica timestamps in a format of list + of tuples. + + Each tuple contains the replica ID and the timestamp associated with + it. + + Returns: + List of tuples. + """ + return list(self._replica_timestamps.items()) + + def size(self) -> int: + """Returns the number of timestamps that are in the replica timestamps + dictionary. + + Returns: + Number of timestamps in the replica timestamps. 
+ """ + return len(self._replica_timestamps) diff --git a/hazelcast/internal/asyncio_compact.py b/hazelcast/internal/asyncio_compact.py new file mode 100644 index 0000000000..06f53ab97c --- /dev/null +++ b/hazelcast/internal/asyncio_compact.py @@ -0,0 +1,162 @@ +import asyncio +import logging +import typing + +from hazelcast.errors import HazelcastSerializationError, IllegalStateError +from hazelcast.internal.asyncio_invocation import Invocation +from hazelcast.protocol.codec import ( + client_fetch_schema_codec, + client_send_schema_codec, + client_send_all_schemas_codec, +) + +if typing.TYPE_CHECKING: + from hazelcast.config import Config + from hazelcast.protocol.client_message import OutboundMessage + from hazelcast.internal.asyncio_cluster import ClusterService + from hazelcast.internal.asyncio_invocation import InvocationService + from hazelcast.internal.asyncio_reactor import AsyncioReactor + from hazelcast.serialization.compact import ( + CompactStreamSerializer, + Schema, + SchemaNotReplicatedError, + ) + +_logger = logging.getLogger(__name__) + + +class CompactSchemaService: + _SEND_SCHEMA_RETRY_COUNT = 100 + + def __init__( + self, + compact_serializer: "CompactStreamSerializer", + invocation_service: "InvocationService", + cluster_service: "ClusterService", + reactor: "AsyncioReactor", + config: "Config", + ): + self._compact_serializer = compact_serializer + self._invocation_service = invocation_service + self._cluster_service = cluster_service + self._reactor = reactor + self._invocation_retry_pause = config.invocation_retry_pause + self._has_replicated_schemas = False + + def fetch_schema(self, schema_id: int) -> asyncio.Future: + _logger.debug( + "Could not find schema with the id %s locally. 
It will be fetched from the cluster.", + schema_id, + ) + + request = client_fetch_schema_codec.encode_request(schema_id) + fetch_schema_invocation = Invocation( + request, + response_handler=client_fetch_schema_codec.decode_response, + ) + self._invocation_service.invoke(fetch_schema_invocation) + return fetch_schema_invocation.future + + def send_schema_and_retry( + self, + error: "SchemaNotReplicatedError", + func: typing.Callable[..., asyncio.Future], + *args: typing.Any, + **kwargs: typing.Any, + ) -> asyncio.Future: + schema = error.schema + clazz = error.clazz + request = client_send_schema_codec.encode_request(schema) + + def callback(): + self._has_replicated_schemas = True + self._compact_serializer.register_schema_to_type(schema, clazz) + return func(*args, **kwargs) + + return self._replicate_schema( + schema, request, CompactSchemaService._SEND_SCHEMA_RETRY_COUNT, callback + ) + + def _replicate_schema( + self, + schema: "Schema", + request: "OutboundMessage", + remaining_retries: int, + callback: typing.Callable[..., asyncio.Future], + ) -> asyncio.Future: + def continuation(future: asyncio.Future): + replicated_members = future.result() + members = self._cluster_service.get_members() + for member in members: + if member.uuid not in replicated_members: + break + else: + # Loop completed normally. + # All members in our member list all known to have the schema + return callback() + + # There is a member in our member list that the schema + # is not known to be replicated yet. We should retry + # sending it in a random member. + if remaining_retries <= 1: + # We tried to send it a couple of times, but the member list + # in our local and the member list returned by the initiator + # nodes did not match. + raise IllegalStateError( + f"The schema {schema} cannot be replicated in the cluster, " + f"after {CompactSchemaService._SEND_SCHEMA_RETRY_COUNT} retries. 
" + f"It might be the case that the client is connected to the two " + f"halves of the cluster that is experiencing a split-brain, " + f"and continue putting the data associated with that schema " + f"might result in data loss. It might be possible to replicate " + f"the schema after some time, when the cluster is healed." + ) + + delayed_future: asyncio.Future = asyncio.get_running_loop().create_future() + self._reactor.add_timer( + self._invocation_retry_pause, + lambda: delayed_future.set_result(None), + ) + + def retry(_): + return self._replicate_schema( + schema, request.copy(), remaining_retries - 1, callback + ) + + return delayed_future.add_done_callback(retry) + + fut = self._send_schema_replication_request(request) + fut.add_done_callback(continuation) + return fut + + def _send_schema_replication_request(self, request: "OutboundMessage") -> asyncio.Future: + invocation = Invocation(request, response_handler=client_send_schema_codec.decode_response) + self._invocation_service.invoke(invocation) + return invocation.future + + async def send_all_schemas(self) -> None: + schemas = self._compact_serializer.get_schemas() + if not schemas: + _logger.debug("There is no schema to send to the cluster.") + return None + + _logger.debug("Sending the following schemas to the cluster: %s", schemas) + request = client_send_all_schemas_codec.encode_request(schemas) + invocation = Invocation(request, urgent=True) + self._invocation_service.invoke(invocation) + return await invocation.future + + def register_fetched_schema(self, schema_id: int, schema: typing.Optional["Schema"]) -> None: + if not schema: + raise HazelcastSerializationError( + f"The schema with the id {schema_id} can not be found in the cluster." + ) + + self._compact_serializer.register_schema_to_id(schema) + + def has_replicated_schemas(self): + """ + Returns ``True`` is the client has replicated + any Compact schemas to the cluster. 
+ """ + return self._has_replicated_schemas diff --git a/hazelcast/internal/asyncio_connection.py b/hazelcast/internal/asyncio_connection.py new file mode 100644 index 0000000000..8db67caf81 --- /dev/null +++ b/hazelcast/internal/asyncio_connection.py @@ -0,0 +1,1052 @@ +import asyncio +import io +import logging +import random +import struct +import time +import uuid +from typing import Coroutine + +from hazelcast import __version__ +from hazelcast.config import ReconnectMode +from hazelcast.core import ( + AddressHelper, + CLIENT_TYPE, + SERIALIZATION_VERSION, + EndpointQualifier, + ProtocolType, +) +from hazelcast.errors import ( + AuthenticationError, + TargetDisconnectedError, + HazelcastClientNotActiveError, + InvalidConfigurationError, + ClientNotAllowedInClusterError, + IllegalStateError, + ClientOfflineError, +) +from hazelcast.internal.asyncio_invocation import Invocation +from hazelcast.lifecycle import LifecycleState +from hazelcast.protocol.client_message import ( + SIZE_OF_FRAME_LENGTH_AND_FLAGS, + Frame, + InboundMessage, + ClientMessageBuilder, +) +from hazelcast.protocol.codec import ( + client_authentication_codec, + client_authentication_custom_codec, + client_ping_codec, +) +from hazelcast.util import ( + AtomicInteger, + calculate_version, + UNKNOWN_VERSION, + member_of_larger_same_version_group, +) + +_logger = logging.getLogger(__name__) + +_INF = float("inf") +_SQL_CONNECTION_RANDOM_ATTEMPTS = 10 +_CLIENT_PUBLIC_ENDPOINT_QUALIFIER = EndpointQualifier(ProtocolType.CLIENT, "public") + + +class WaitStrategy: + def __init__(self, initial_backoff, max_backoff, multiplier, cluster_connect_timeout, jitter): + self._initial_backoff = initial_backoff + self._max_backoff = max_backoff + self._multiplier = multiplier + self._cluster_connect_timeout = cluster_connect_timeout + self._jitter = jitter + self._attempt = None + self._cluster_connect_attempt_begin = None + self._current_backoff = None + + if cluster_connect_timeout == _INF: + 
self._cluster_connect_timeout_text = "INFINITE" + else: + self._cluster_connect_timeout_text = "%.2fs" % self._cluster_connect_timeout + + def reset(self): + self._attempt = 0 + self._cluster_connect_attempt_begin = time.time() + self._current_backoff = min(self._max_backoff, self._initial_backoff) + + def sleep(self): + self._attempt += 1 + time_passed = time.time() - self._cluster_connect_attempt_begin + if time_passed > self._cluster_connect_timeout: + _logger.warning( + "Unable to get live cluster connection, cluster connect timeout (%s) is reached. " + "Attempt %d.", + self._cluster_connect_timeout_text, + self._attempt, + ) + return False + + # random between (-jitter * current_backoff, jitter * current_backoff) + sleep_time = self._current_backoff + self._current_backoff * self._jitter * ( + 2 * random.random() - 1 + ) + sleep_time = min(sleep_time, self._cluster_connect_timeout - time_passed) + _logger.warning( + "Unable to get live cluster connection, retry in %.2fs, attempt: %d, " + "cluster connect timeout: %s, max backoff: %.2fs", + sleep_time, + self._attempt, + self._cluster_connect_timeout_text, + self._max_backoff, + ) + time.sleep(sleep_time) + self._current_backoff = min(self._current_backoff * self._multiplier, self._max_backoff) + return True + + +class AuthenticationStatus: + AUTHENTICATED = 0 + CREDENTIALS_FAILED = 1 + SERIALIZATION_VERSION_MISMATCH = 2 + NOT_ALLOWED_IN_CLUSTER = 3 + + +class ClientState: + INITIAL = 0 + """ + Clients start with this state. + Once a client connects to a cluster, it directly switches to + `INITIALIZED_ON_CLUSTER` instead of `CONNECTED_TO_CLUSTER` because on + startup a client has no local state to send to the cluster. + """ + + CONNECTED_TO_CLUSTER = 1 + """ + When a client switches to a new cluster, it moves to this state. It means + that the client has connected to a new cluster but not sent its local + state to the new cluster yet. 
+ """ + + INITIALIZED_ON_CLUSTER = 2 + """ + When a client sends its local state to the cluster it has connected, it + switches to this state. + Invocations are allowed in this state. + """ + + +class ConnectionManager: + """ConnectionManager is responsible for managing ``Connection`` objects.""" + + def __init__( + self, + client, + config, + reactor, + address_provider, + lifecycle_service, + partition_service, + cluster_service, + invocation_service, + near_cache_manager, + send_state_to_cluster_fn, + ): + self.live = False + self.active_connections = {} # uuid to connection, must be modified under the _lock + self.client_uuid = uuid.uuid4() + + self._client = client + self._config = config + self._reactor = reactor + self._address_provider = address_provider + self._lifecycle_service = lifecycle_service + self._partition_service = partition_service + self._cluster_service = cluster_service + self._invocation_service = invocation_service + self._near_cache_manager = near_cache_manager + self._send_state_to_cluster_fn = send_state_to_cluster_fn + self._client_state = ClientState.INITIAL # must be modified under the _lock + self._established_initial_cluster_connection = False # must be modified under the _lock + self._smart_routing_enabled = config.smart_routing + self._wait_strategy = self._init_wait_strategy(config) + self._reconnect_mode = config.reconnect_mode + self._heartbeat_manager = HeartbeatManager( + self, self._client, config, reactor, invocation_service + ) + self._connection_listeners = [] + self._connect_all_members_task: asyncio.Task | None = None + self._async_start = config.async_start + self._connect_to_cluster_thread_running = False + self._shuffle_member_list = config.shuffle_member_list + self._lock = asyncio.Lock() + self._connection_id_generator = AtomicInteger() + self._labels = frozenset(config.labels) + self._cluster_id = None + self._load_balancer = None + self._use_public_ip = ( + isinstance(address_provider, DefaultAddressProvider) 
and config.use_public_ip + ) + + def add_listener(self, on_connection_opened=None, on_connection_closed=None): + """Registers a ConnectionListener. + + If the same listener is registered multiple times, it will be notified multiple times. + + Args: + on_connection_opened (function): Function to be called when a connection is opened. (Default value = None) + on_connection_closed (function): Function to be called when a connection is removed. (Default value = None) + """ + self._connection_listeners.append((on_connection_opened, on_connection_closed)) + + def get_connection(self, member_uuid): + return self.active_connections.get(member_uuid, None) + + def get_random_connection(self): + # Try getting the connection from the load balancer, if smart routing is enabled + if self._smart_routing_enabled: + member = self._load_balancer.next() + if member: + connection = self.get_connection(member.uuid) + if connection: + return connection + + # Otherwise iterate over connections and return the first one + for connection in list(self.active_connections.values()): + return connection + + # Failed to get a connection + return None + + def get_random_connection_for_sql(self): + """Returns a random connection for SQL. + + The connection is tried to be selected in the following order. + + - Random connection to a data member from the larger same-version + group. + - Random connection to a data member. + - Any random connection + - ``None``, if there is no connection. + + Returns: + Connection: A random connection for SQL. + """ + if self._smart_routing_enabled: + # There might be a race - the chosen member might be just connected or disconnected. + # Try a couple of times, the member_of_larger_same_version_group returns a random + # connection, we might be lucky... 
+ for _ in range(_SQL_CONNECTION_RANDOM_ATTEMPTS): + members = self._cluster_service.get_members() + member = member_of_larger_same_version_group(members) + if not member: + break + + connection = self.get_connection(member.uuid) + if connection: + return connection + + # Otherwise iterate over connections and return the first one + # that's not to a lite member. + first_connection = None + for member_uuid, connection in list(self.active_connections.items()): + if not first_connection: + first_connection = connection + + member = self._cluster_service.get_member(member_uuid) + if not member or member.lite_member: + continue + + return connection + + # Failed to get a connection to a data member. + return first_connection + + async def start(self, load_balancer): + if self.live: + return + + self.live = True + self._load_balancer = load_balancer + self._heartbeat_manager.start() + await self._connect_to_cluster() + + async def shutdown(self): + if not self.live: + return + + self.live = False + if self._connect_all_members_task: + self._connect_all_members_task.cancel() + + self._heartbeat_manager.shutdown() + + # Need to create copy of connection values to avoid modification errors on runtime + async with asyncio.TaskGroup() as tg: + for connection in list(self.active_connections.values()): + tg.create_task( + connection.close_connection("Hazelcast client is shutting down", None) + ) + + self.active_connections.clear() + del self._connection_listeners[:] + + async def connect_to_all_cluster_members(self, sync_start): + if not self._smart_routing_enabled: + return + + if sync_start: + async with asyncio.TaskGroup() as tg: + for member in self._cluster_service.get_members(): + tg.create_task(self._get_or_connect_to_member(member)) + + self._start_connect_all_members_timer() + + async def on_connection_close(self, closed_connection): + remote_uuid = closed_connection.remote_uuid + remote_address = closed_connection.remote_address + + if not remote_address: + 
_logger.debug( + "Destroying %s, but it has no remote address, hence nothing is " + "removed from the connection dictionary", + closed_connection, + ) + return + + disconnected = False + removed = False + trigger_reconnection = False + async with self._lock: + connection = self.active_connections.get(remote_uuid, None) + if connection == closed_connection: + self.active_connections.pop(remote_uuid, None) + removed = True + _logger.info( + "Removed connection to %s:%s, connection: %s", + remote_address, + remote_uuid, + connection, + ) + + if not self.active_connections: + trigger_reconnection = True + if self._client_state == ClientState.INITIALIZED_ON_CLUSTER: + disconnected = True + + if disconnected: + self._lifecycle_service.fire_lifecycle_event(LifecycleState.DISCONNECTED) + + if trigger_reconnection: + await self._trigger_cluster_reconnection() + + if removed: + async with asyncio.TaskGroup() as tg: + # TODO: see on_connection_open + for _, on_connection_closed in self._connection_listeners: + if on_connection_closed: + try: + maybe_coro = on_connection_closed(closed_connection) + if isinstance(maybe_coro, Coroutine): + tg.create_task(maybe_coro) + except Exception: + _logger.exception("Exception in connection listener") + else: + _logger.debug( + "Destroying %s, but there is no mapping for %s in the connection dictionary", + closed_connection, + remote_uuid, + ) + + def check_invocation_allowed(self): + state = self._client_state + if state == ClientState.INITIALIZED_ON_CLUSTER and self.active_connections: + return + + if state == ClientState.INITIAL: + if self._async_start: + raise ClientOfflineError() + else: + raise IOError("No connection found to cluster since the client is starting.") + elif self._reconnect_mode == ReconnectMode.ASYNC: + raise ClientOfflineError() + else: + raise IOError("No connection found to cluster") + + def initialized_on_cluster(self) -> bool: + """ + Returns ``True`` if the client is initialized on the cluster, by + sending its 
local state, if necessary. + """ + return self._client_state == ClientState.INITIALIZED_ON_CLUSTER + + async def _get_or_connect_to_address(self, address): + for connection in list(self.active_connections.values()): + if connection.remote_address == address: + return connection + translated = self._translate(address) + connection = await self._create_connection(translated) + response = await self._authenticate(connection) + await self._on_auth(response, connection) + return connection + + async def _get_or_connect_to_member(self, member): + connection = self.active_connections.get(member.uuid, None) + if connection: + return connection + + translated = self._translate_member_address(member) + connection = await self._create_connection(translated) + response = await self._authenticate(connection) + await self._on_auth(response, connection) + return connection + + async def _create_connection(self, address): + return await self._reactor.connection_factory( + self, + self._connection_id_generator.get_and_increment(), + address, + self._config, + self._invocation_service.handle_client_message, + ) + + def _translate(self, address): + translated = self._address_provider.translate(address) + if not translated: + raise ValueError( + "Address provider %s could not translate address %s" + % (self._address_provider.__class__.__name__, address) + ) + + return translated + + def _translate_member_address(self, member): + if self._use_public_ip: + public_address = member.address_map.get(_CLIENT_PUBLIC_ENDPOINT_QUALIFIER, None) + if public_address: + return public_address + + return member.address + + return self._translate(member.address) + + async def _trigger_cluster_reconnection(self): + if self._reconnect_mode == ReconnectMode.OFF: + _logger.info("Reconnect mode is OFF. 
Shutting down the client")
+            await self._shutdown_client()
+            return
+
+        if self._lifecycle_service.running:
+            await self._start_connect_to_cluster_thread()
+
+    def _init_wait_strategy(self, config):
+        cluster_connect_timeout = config.cluster_connect_timeout
+        if cluster_connect_timeout == -1:
+            # If no timeout is specified by the
+            # user, or set to -1 explicitly, set
+            # the timeout to infinite.
+            cluster_connect_timeout = _INF
+
+        return WaitStrategy(
+            config.retry_initial_backoff,
+            config.retry_max_backoff,
+            config.retry_multiplier,
+            cluster_connect_timeout,
+            config.retry_jitter,
+        )
+
+    def _start_connect_all_members_timer(self):
+        connecting_uuids = set()
+
+        async def run():
+            await asyncio.sleep(1)
+            if not self._lifecycle_service.running:
+                return
+
+            async with asyncio.TaskGroup() as tg:
+                member_uuids = []
+                for member in self._cluster_service.get_members():
+                    member_uuid = member.uuid
+                    if self.active_connections.get(member_uuid, None):
+                        continue
+                    if member_uuid in connecting_uuids:
+                        continue
+                    connecting_uuids.add(member_uuid)
+                    if not self._lifecycle_service.running:
+                        break
+                    # TODO: ERROR:asyncio:Task was destroyed but it is pending!
tg.create_task(self._get_or_connect_to_member(member))
+                    member_uuids.append(member_uuid)
+
+                for item in member_uuids:
+                    connecting_uuids.discard(item)
+
+            self._connect_all_members_task = asyncio.create_task(run())
+
+        self._connect_all_members_task = asyncio.create_task(run())
+
+    async def _connect_to_cluster(self):
+        await self._sync_connect_to_cluster()
+
+    async def _start_connect_to_cluster_thread(self):
+        async with self._lock:
+            if self._connect_to_cluster_thread_running:
+                return
+
+            self._connect_to_cluster_thread_running = True
+
+        try:
+            while True:
+                await self._sync_connect_to_cluster()
+                async with self._lock:
+                    if self.active_connections:
+                        self._connect_to_cluster_thread_running = False
+                        return
+        except Exception:
+            _logger.exception("Could not connect to any cluster, shutting down the client")
+            await self._shutdown_client()
+
+    async def _shutdown_client(self):
+        try:
+            await self._client.shutdown()
+        except Exception:
+            _logger.exception("Exception during client shutdown")
+
+    async def _sync_connect_to_cluster(self):
+        tried_addresses = set()
+        self._wait_strategy.reset()
+        try:
+            while True:
+                tried_addresses_per_attempt = set()
+                members = self._cluster_service.get_members()
+                if self._shuffle_member_list:
+                    random.shuffle(members)
+
+                for member in members:
+                    self._check_client_active()
+                    tried_addresses_per_attempt.add(member.address)
+                    connection = await self._connect(member, self._get_or_connect_to_member)
+                    if connection:
+                        return
+
+                for address in self._get_possible_addresses():
+                    self._check_client_active()
+                    if address in tried_addresses_per_attempt:
+                        # We already tried this address from the member list
+                        continue
+
+                    tried_addresses_per_attempt.add(address)
+                    connection = await self._connect(address, self._get_or_connect_to_address)
+                    if connection:
+                        return
+
+                tried_addresses.update(tried_addresses_per_attempt)
+
+                # If the address providers load no addresses (which seems to be possible),
+                # then the above loop is not
entered and the lifecycle check is missing, + # hence we need to repeat the same check at this point. + if not tried_addresses_per_attempt: + self._check_client_active() + + if not self._wait_strategy.sleep(): + break + except (ClientNotAllowedInClusterError, InvalidConfigurationError): + cluster_name = self._config.cluster_name + _logger.exception("Stopped trying on cluster %s", cluster_name) + + cluster_name = self._config.cluster_name + _logger.info( + "Unable to connect to any address from the cluster with name: %s. " + "The following addresses were tried: %s", + cluster_name, + tried_addresses, + ) + if self._lifecycle_service.running: + msg = "Unable to connect to any cluster" + else: + msg = "Client is being shutdown" + raise IllegalStateError(msg) + + async def _connect(self, target, get_or_connect_func): + _logger.info("Trying to connect to %s", target) + try: + return await get_or_connect_func(target) + except (ClientNotAllowedInClusterError, InvalidConfigurationError) as e: + _logger.warning("Error during initial connection to %s", target, exc_info=True) + raise e + except Exception: + _logger.warning("Error during initial connection to %s", target, exc_info=True) + return None + + def _authenticate(self, connection) -> asyncio.Future: + client = self._client + cluster_name = self._config.cluster_name + client_name = client.name + if self._config.token_provider: + token = self._config.token_provider.token(connection.connected_address) + request = client_authentication_custom_codec.encode_request( + cluster_name, + token, + self.client_uuid, + CLIENT_TYPE, + SERIALIZATION_VERSION, + __version__, + client_name, + self._labels, + ) + else: + request = client_authentication_codec.encode_request( + cluster_name, + self._config.creds_username, + self._config.creds_password, + self.client_uuid, + CLIENT_TYPE, + SERIALIZATION_VERSION, + __version__, + client_name, + self._labels, + ) + invocation = Invocation( + request, connection=connection, urgent=True, 
response_handler=lambda m: m + ) + self._invocation_service.invoke(invocation) + return invocation.future + + async def _on_auth(self, response, connection): + try: + response = client_authentication_codec.decode_response(response) + except Exception as e: + await connection.close_connection("Failed to authenticate connection", e) + raise e + + status = response["status"] + if status == AuthenticationStatus.AUTHENTICATED: + return await self._handle_successful_auth(response, connection) + + if status == AuthenticationStatus.CREDENTIALS_FAILED: + err = AuthenticationError("Authentication failed. Check cluster name and credentials.") + elif status == AuthenticationStatus.NOT_ALLOWED_IN_CLUSTER: + err = ClientNotAllowedInClusterError("Client is not allowed in the cluster") + elif status == AuthenticationStatus.SERIALIZATION_VERSION_MISMATCH: + err = IllegalStateError("Server serialization version does not match to client") + else: + err = AuthenticationError( + "Authentication status code not supported. 
status: %s" % status + ) + + await connection.close_connection("Failed to authenticate connection", err) + raise err + + async def _handle_successful_auth(self, response, connection): + async with self._lock: + self._check_partition_count(response["partition_count"]) + + server_version_str = response["server_hazelcast_version"] + remote_address = response["address"] + remote_uuid = response["member_uuid"] + + connection.remote_address = remote_address + connection.server_version = calculate_version(server_version_str) + connection.remote_uuid = remote_uuid + + existing = self.active_connections.get(remote_uuid, None) + if existing: + await connection.close_connection( + "Duplicate connection to same member with UUID: %s" % remote_uuid, None + ) + return existing + + new_cluster_id = response["cluster_id"] + changed_cluster = self._cluster_id is not None and self._cluster_id != new_cluster_id + if changed_cluster: + await self._check_client_state_on_cluster_change(connection) + _logger.warning( + "Switching from current cluster: %s to new cluster: %s", + self._cluster_id, + new_cluster_id, + ) + self._on_cluster_restart() + + is_initial_connection = not self.active_connections + self.active_connections[remote_uuid] = connection + fire_connected_lifecycle_event = False + if is_initial_connection: + self._cluster_id = new_cluster_id + # In split brain, the client might connect to the one half + # of the cluster, and then later might reconnect to the + # other half, after the half it was connected to is + # completely dead. Since the cluster id is preserved in + # split brain scenarios, it is impossible to distinguish + # reconnection to the same cluster vs reconnection to the + # other half of the split brain. However, in the latter, + # we might need to send some state to the other half of + # the split brain (like Compact schemas). 
That forces us + # to send the client state to the cluster after the first + # cluster connection, regardless the cluster id is + # changed or not. + if self._established_initial_cluster_connection: + self._client_state = ClientState.CONNECTED_TO_CLUSTER + await self._initialize_on_cluster(new_cluster_id) + else: + fire_connected_lifecycle_event = True + self._established_initial_cluster_connection = True + self._client_state = ClientState.INITIALIZED_ON_CLUSTER + + if fire_connected_lifecycle_event: + self._lifecycle_service.fire_lifecycle_event(LifecycleState.CONNECTED) + + _logger.info( + "Authenticated with server %s:%s, server version: %s, local address: %s", + remote_address, + remote_uuid, + server_version_str, + connection.local_address, + ) + + async with asyncio.TaskGroup() as tg: + for on_connection_opened, _ in self._connection_listeners: + if on_connection_opened: + try: + # TODO: creating the task may not throw the exception + # TODO: protect the loop against exceptions, so all handlers run + maybe_coro = on_connection_opened(connection) + if isinstance(maybe_coro, Coroutine): + tg.create_task(maybe_coro) + except Exception: + _logger.exception("Exception in connection listener") + + if not connection.live: + await self.on_connection_close(connection) + + return connection + + async def _initialize_on_cluster(self, cluster_id) -> None: + # This method is only called in the reactor thread + if cluster_id != self._cluster_id: + _logger.warning( + f"Client won't send the state to the cluster: {cluster_id}" + f"because it switched to a new cluster: {self._cluster_id}" + ) + return + + async def callback(): + try: + if cluster_id == self._cluster_id: + _logger.debug("The client state is sent to the cluster %s", cluster_id) + self._client_state = ClientState.INITIALIZED_ON_CLUSTER + self._lifecycle_service.fire_lifecycle_event(LifecycleState.CONNECTED) + elif _logger.isEnabledFor(logging.DEBUG): + _logger.warning( + "Cannot set client state to 
'INITIALIZED_ON_CLUSTER'" + f"because current cluster id: {self._cluster_id}" + f"is different than the expected cluster id: {cluster_id}" + ) + except Exception: + await retry_on_error() + + async def retry_on_error(): + _logger.exception(f"Failure during sending client state to the cluster {cluster_id}") + + if cluster_id != self._cluster_id: + return + + if _logger.isEnabledFor(logging.DEBUG): + _logger.warning(f"Retrying sending client state to the cluster: {cluster_id}") + + await self._initialize_on_cluster(cluster_id) + + try: + await self._send_state_to_cluster_fn() + await callback() + except Exception: + await retry_on_error() + + async def _check_client_state_on_cluster_change(self, connection): + if self.active_connections: + # If there are other connections, we must be connected to the wrong cluster. + # We should not stay connected to this new connection. + # Note that, in some racy scenarios, we might close a connection that + # we can operate on. In those scenarios, we rely on the fact that we will + # reopen the connections. + reason = "Connection does not belong to the cluster %s" % self._cluster_id + await connection.close_connection(reason, None) + raise ValueError(reason) + + def _on_cluster_restart(self): + self._near_cache_manager.clear_near_caches() + self._cluster_service.clear_member_list() + + def _check_partition_count(self, partition_count): + if not self._partition_service.check_and_set_partition_count(partition_count): + raise ClientNotAllowedInClusterError( + "Client can not work with this cluster because it has a " + "different partition count. 
Expected partition count: %d, " + "Member partition count: %d" + % (self._partition_service.partition_count, partition_count) + ) + + def _check_client_active(self): + if not self._lifecycle_service.running: + raise HazelcastClientNotActiveError() + + def _get_possible_addresses(self): + primaries, secondaries = self._address_provider.load_addresses() + if self._shuffle_member_list: + # The relative order between primary and secondary addresses should + # not be changed. So we shuffle the lists separately and then add + # them to the final list so that secondary addresses are not tried + # before all primary addresses have been tried. Otherwise we can get + # startup delays + random.shuffle(primaries) + random.shuffle(secondaries) + + addresses = [] + addresses.extend(primaries) + addresses.extend(secondaries) + return addresses + + +class HeartbeatManager: + def __init__(self, connection_manager, client, config, reactor, invocation_service): + self._connection_manager = connection_manager + self._client = client + self._reactor = reactor + self._invocation_service = invocation_service + self._heartbeat_timeout = config.heartbeat_timeout + self._heartbeat_interval = config.heartbeat_interval + self._heartbeat_task: asyncio.Task | None = None + + def start(self): + """Starts sending periodic HeartBeat operations.""" + + async def _heartbeat(): + await asyncio.sleep(self._heartbeat_interval) + _logger.debug("heartbeat") + conn_manager = self._connection_manager + if not conn_manager.live: + return + + now = time.time() + async with asyncio.TaskGroup() as tg: + for connection in list(conn_manager.active_connections.values()): + tg.create_task(self._check_connection(now, connection)) + self._heartbeat_task = asyncio.create_task(_heartbeat()) + + self._heartbeat_task = asyncio.create_task(_heartbeat()) + + def shutdown(self): + """Stops HeartBeat operations.""" + if self._heartbeat_task: + self._heartbeat_task.cancel() + + async def _check_connection(self, now, 
connection): + if not connection.live: + return + + if (now - connection.last_read_time) > self._heartbeat_timeout: + _logger.warning("Heartbeat failed over the connection: %s", connection) + await connection.close_connection( + "Heartbeat timed out", + TargetDisconnectedError("Heartbeat timed out to connection %s" % connection), + ) + return + + if (now - connection.last_write_time) > self._heartbeat_interval: + request = client_ping_codec.encode_request() + invocation = Invocation(request, connection=connection, urgent=True) + asyncio.create_task(self._invocation_service.ainvoke(invocation)) + + +_frame_header = struct.Struct(" None: + if not invocation.timeout: + invocation.timeout = self._invocation_timeout + time.time() + + correlation_id = self._next_correlation_id.get_and_increment() + request = invocation.request + request.set_correlation_id(correlation_id) + request.set_partition_id(invocation.partition_id) + self._do_invoke(invocation) + + async def ainvoke(self, invocation: Invocation): + self.invoke(invocation) + return await invocation.future + + def shutdown(self): + if self._shutdown: + return + + self._shutdown = True + if self._clean_resources_timer: + self._clean_resources_timer.cancel() + for invocation in list(self._pending.values()): + self._notify_error(invocation, HazelcastClientNotActiveError()) + + def _invoke_on_partition_owner(self, invocation, partition_id): + owner_uuid = self._partition_service.get_partition_owner(partition_id) + if not owner_uuid: + _logger.debug("Partition owner is not assigned yet") + return False + return self._invoke_on_target(invocation, owner_uuid) + + def _invoke_on_target(self, invocation, owner_uuid): + connection = self._connection_manager.get_connection(owner_uuid) + if not connection: + _logger.debug("Client is not connected to target: %s", owner_uuid) + return False + return self._send(invocation, connection) + + def _invoke_on_random_connection(self, invocation): + connection = 
self._connection_manager.get_random_connection() + if not connection: + _logger.debug("No connection found to invoke") + return False + return self._send(invocation, connection) + + def _invoke_smart(self, invocation): + try: + if invocation.urgent: + self._check_urgent_invocation_allowed(invocation) + else: + self._check_invocation_allowed_fn() + + connection = invocation.connection + if connection: + invoked = self._send(invocation, connection) + if not invoked: + self._notify_error( + invocation, IOError("Could not invoke on connection %s" % connection) + ) + return + + if invocation.partition_id != -1: + invoked = self._invoke_on_partition_owner(invocation, invocation.partition_id) + elif invocation.uuid: + invoked = self._invoke_on_target(invocation, invocation.uuid) + else: + invoked = self._invoke_on_random_connection(invocation) + + if not invoked: + invoked = self._invoke_on_random_connection(invocation) + + if not invoked: + self._notify_error(invocation, IOError("No connection found to invoke")) + except Exception as e: + self._notify_error(invocation, e) + + def _invoke_non_smart(self, invocation): + try: + if invocation.urgent: + self._check_urgent_invocation_allowed(invocation) + else: + self._check_invocation_allowed_fn() + + connection = invocation.connection + if connection: + invoked = self._send(invocation, connection) + if not invoked: + self._notify_error( + invocation, IOError("Could not invoke on connection %s" % connection) + ) + return + + if not self._invoke_on_random_connection(invocation): + self._notify_error(invocation, IOError("No connection found to invoke")) + except Exception as e: + self._notify_error(invocation, e) + + def _send(self, invocation, connection): + if self._shutdown: + raise HazelcastClientNotActiveError() + + if self._backup_ack_to_client_enabled: + invocation.request.set_backup_aware_flag() + + message = invocation.request + correlation_id = message.get_correlation_id() + self._pending[correlation_id] = invocation 
+ + if invocation.event_handler: + self._listener_service.add_event_handler(correlation_id, invocation.event_handler) + + if not connection.send_message(message): + if invocation.event_handler: + self._listener_service.remove_event_handler(correlation_id) + return False + + invocation.sent_connection = connection + return True + + def _complete(self, invocation: Invocation, client_message: InboundMessage) -> None: + try: + result = invocation.response_handler(client_message) + invocation.future.set_result(result) + except SchemaNotFoundError as e: + self._fetch_schema_and_complete_again(e, invocation, client_message) + return + except Exception as e: + invocation.future.set_exception(e) + + correlation_id = invocation.request.get_correlation_id() + self._pending.pop(correlation_id, None) + + def _complete_with_error(self, invocation, error): + invocation.future.set_exception(error) + correlation_id = invocation.request.get_correlation_id() + self._pending.pop(correlation_id, None) + + def _fetch_schema_and_complete_again( + self, error: SchemaNotFoundError, invocation: Invocation, message: InboundMessage + ) -> None: + schema_id = error.schema_id + + def callback(future): + try: + schema = future.result() + self._compact_schema_service.register_fetched_schema(schema_id, schema) + except Exception as e: + self._complete_with_error(invocation, e) + return + + message.reset_next_frame() + self._complete(invocation, message) + + fetch_schema_future = self._compact_schema_service.fetch_schema(schema_id) + fetch_schema_future.add_done_callback(callback) + + def _notify_error(self, invocation, error): + _logger.debug("Got exception for request %s, error: %s", invocation.request, error) + + if not self._client.lifecycle_service.is_running(): + self._complete_with_error(invocation, HazelcastClientNotActiveError()) + return + + if not self._should_retry(invocation, error): + self._complete_with_error(invocation, error) + return + + if invocation.timeout < time.time(): + 
_logger.debug("Error will not be retried because invocation timed out: %s", error) + error = OperationTimeoutError( + "Request timed out because an error occurred " + "after invocation timeout: %s" % error + ) + self._complete_with_error(invocation, error) + return + + invocation.sent_connection = None + invoke_func = functools.partial(self._retry_if_not_done, invocation) + self._reactor.add_timer(self._invocation_retry_pause, invoke_func) + + def _retry_if_not_done(self, invocation): + if not invocation.future.done(): + self._do_invoke(invocation) + + def _should_retry(self, invocation, error): + if isinstance(error, InvocationMightContainCompactDataError): + return True + + if invocation.connection and isinstance(error, (IOError, TargetDisconnectedError)): + return False + + if invocation.uuid and isinstance(error, TargetNotMemberError): + return False + + if isinstance(error, (IOError, HazelcastInstanceNotActiveError)) or is_retryable_error( + error + ): + return True + + if isinstance(error, TargetDisconnectedError): + return invocation.request.retryable or self._is_redo_operation + + return False + + def _check_urgent_invocation_allowed(self, invocation: Invocation): + if self._connection_manager.initialized_on_cluster(): + # If the client is initialized on the cluster, that means we + # have sent all the schemas to the cluster, even if we are + # reconnected to it + return + + if not self._compact_schema_service.has_replicated_schemas(): + # If there were no Compact schemas to begin with, we don't need + # to perform the check below. If the client didn't send a Compact + # schema up until this point, the retries or listener registrations + # could not send a schema, because if they were, we wouldn't hit + # this line. + return + + # We are not yet initialized on cluster, so the Compact schemas might + # not be sent yet. This message contains some serialized data, + # and it is possible that it can also contain Compact serialized data. 
    def _notify(self, invocation, client_message):
        """Handle a response message for ``invocation``.

        If the member reported more backup acks than have arrived so far,
        park the response on the invocation; it is completed later by
        ``_notify_backup_complete`` or by the backup-timeout detection.
        """
        expected_backups = client_message.get_number_of_backup_acks()
        if expected_backups > invocation.backup_acks_received:
            invocation.pending_response_received_time = time.time()
            invocation.backup_acks_expected = expected_backups
            invocation.pending_response = client_message
            return

        self._complete(invocation, client_message)
    def _detect_and_handle_backup_timeout(self, invocation, now):
        """Resolve an invocation whose backup acks did not all arrive in time.

        Called periodically from the clean-resources timer for invocations
        that hold a parked pending response.
        """
        if not invocation.pending_response:
            # No response parked yet; nothing to time out.
            return

        if invocation.backup_acks_expected == invocation.backup_acks_received:
            # All acks arrived; the normal completion path handles it.
            return

        expiration_time = invocation.pending_response_received_time + self._backup_timeout
        timeout_reached = 0 < expiration_time < now
        if not timeout_reached:
            return

        if self._fail_on_indeterminate_state:
            error = IndeterminateOperationStateError(
                "Invocation failed because the backup acks are missed"
            )
            self._complete_with_error(invocation, error)
            return

        # Default behavior: give up on the missing acks and complete with
        # the response we already have.
        self._complete(invocation, invocation.pending_response)
class _EventRegistration:
    """Server-side registration of a single listener on one connection."""

    __slots__ = ("server_registration_id", "correlation_id")

    def __init__(self, server_registration_id, correlation_id):
        # Registration id assigned by the member; used to build the remote
        # deregistration request.
        self.server_registration_id = server_registration_id
        # Correlation id of the registration request; keys the local
        # event-handler table.
        self.correlation_id = correlation_id
self.deregister_listener(registration_id) + raise HazelcastError("Listener cannot be added") + + async def deregister_listener(self, user_registration_id): + check_not_none(user_registration_id, "None user_registration_id is not allowed!") + async with self._registration_lock: + listener_registration = self._active_registrations.pop(user_registration_id, None) + if not listener_registration: + return False + + async def handle(inv: Invocation, conn: AsyncioConnection): + try: + await inv.future + except Exception as e: + if not isinstance( + e, (HazelcastClientNotActiveError, IOError, TargetDisconnectedError) + ): + _logger.warning( + "Deregistration of listener with ID %s has failed for address %s: %s", + user_registration_id, + conn.remote_address, + e, + ) + + async with asyncio.TaskGroup() as tg: + items = listener_registration.connection_registrations.items() + for connection, event_registration in items: + # Remove local handler + self.remove_event_handler(event_registration.correlation_id) + # The rest is for deleting the remote registration + server_registration_id = event_registration.server_registration_id + deregister_request = listener_registration.encode_deregister_request( + server_registration_id + ) + if deregister_request is None: + # None means no remote registration (e.g. 
for backup acks) + continue + invocation = Invocation( + deregister_request, connection=connection, timeout=sys.maxsize, urgent=True + ) + self._invocation_service.invoke(invocation) + tg.create_task(handle(invocation, connection)) + + listener_registration.connection_registrations.clear() + return True + + def handle_client_message(self, message: InboundMessage, correlation_id: int): + handler = self._event_handlers.get(correlation_id, None) + if handler: + try: + handler(message) + except SchemaNotFoundError as e: + self._fetch_schema_and_handle_again(e, handler, message) + else: + _logger.debug("Got event message with unknown correlation id: %s", message) + + def _fetch_schema_and_handle_again( + self, + error: SchemaNotFoundError, + handler: typing.Callable[[InboundMessage], None], + message: InboundMessage, + ) -> None: + schema_id = error.schema_id + + def callback(future): + try: + schema = future.result() + self._compact_schema_service.register_fetched_schema(schema_id, schema) + except Exception: + _logger.exception( + f"Failed to call event handler: {handler} with message: {message}" + ) + return + + message.reset_next_frame() + try: + handler(message) + except SchemaNotFoundError as e: + self._fetch_schema_and_handle_again(e, handler, message) + + fetch_schema_future = self._compact_schema_service.fetch_schema(schema_id) + fetch_schema_future.add_done_callback(callback) + + def add_event_handler(self, correlation_id, event_handler): + self._event_handlers[correlation_id] = event_handler + + def remove_event_handler(self, correlation_id): + self._event_handlers.pop(correlation_id, None) + + async def _register_on_connection( + self, user_registration_id, listener_registration, connection + ): + registration_map = listener_registration.connection_registrations + + if connection in registration_map: + return + + registration_request = listener_registration.registration_request.copy() + invocation = Invocation( + registration_request, + 
    async def _connection_removed(self, connection):
        """Drop the per-connection registrations when ``connection`` goes away.

        Only the local event handlers are removed; no remote deregistration
        is attempted since the connection is already gone.
        """
        async with self._registration_lock:
            for listener_registration in self._active_registrations.values():
                event_registration = listener_registration.connection_registrations.pop(
                    connection, None
                )
                if event_registration:
                    self.remove_event_handler(event_registration.correlation_id)
    def _try_register(self, connection):
        """Register the cluster-view listener on ``connection`` if none is active."""
        if not self._connection_manager.live:
            # There is no point in trying to register the cluster view
            # listener if the client is about to shut down.
            return

        if self._listener_added_connection:
            # A connection already carries the listener.
            return

        self._listener_added_connection = connection
        request = client_add_cluster_view_listener_codec.encode_request()
        invocation = Invocation(
            request, connection=connection, event_handler=self._handler(connection), urgent=True
        )
        # Reset the member-list version so the next members-view event is
        # applied unconditionally.
        self._cluster_service.clear_member_list_version()
        self._invocation_service.invoke(invocation)

        def callback(f):
            try:
                f.result()
            except Exception:
                # Registration failed; fall back to another live connection.
                self._try_register_to_random_connection(connection)

        invocation.future.add_done_callback(callback)
b/hazelcast/internal/asyncio_proxy/base.py @@ -0,0 +1,297 @@ +import abc +import asyncio +import typing +import uuid + +from hazelcast.core import MemberInfo +from hazelcast.types import KeyType, ValueType, ItemType, MessageType, BlockingProxyType +from hazelcast.internal.asyncio_invocation import Invocation +from hazelcast.partition import string_partition_strategy +from hazelcast.util import get_attr_name + +MAX_SIZE = float("inf") + + +def _no_op_response_handler(_): + return None + + +class Proxy(typing.Generic[BlockingProxyType], abc.ABC): + """Provides basic functionality for Hazelcast Proxies.""" + + def __init__(self, service_name: str, name: str, context): + self.service_name = service_name + self.name = name + self._context = context + self._invocation_service = context.invocation_service + self._partition_service = context.partition_service + serialization_service = context.serialization_service + self._to_object = serialization_service.to_object + self._to_data = serialization_service.to_data + listener_service = context.listener_service + self._register_listener = listener_service.register_listener + self._deregister_listener = listener_service.deregister_listener + self._is_smart = context.config.smart_routing + self._send_schema_and_retry = context.compact_schema_service.send_schema_and_retry + + async def destroy(self) -> bool: + """Destroys this proxy. + + Returns: + ``True`` if this proxy is destroyed successfully, ``False`` + otherwise. 
    def _invoke_on_key(
        self, request, key_data, response_handler=_no_op_response_handler
    ) -> asyncio.Future:
        """Invoke ``request`` on the partition that owns ``key_data``.

        Returns:
            The invocation future; the caller awaits it for the handled
            response.
        """
        partition_id = self._partition_service.get_partition_id(key_data)
        invocation = Invocation(
            request, partition_id=partition_id, response_handler=response_handler
        )
        self._invocation_service.invoke(invocation)
        return invocation.future
class TransactionalProxy:
    """Provides an interface for all transactional distributed objects."""

    def __init__(self, name, transaction, context):
        self.name = name
        self.transaction = transaction
        self._invocation_service = context.invocation_service
        serialization_service = context.serialization_service
        self._to_object = serialization_service.to_object
        self._to_data = serialization_service.to_data
        self._send_schema_and_retry = context.compact_schema_service.send_schema_and_retry

    def _send_schema(self, error):
        # NOTE(review): calls .result() on the returned future without
        # awaiting it; on an asyncio future this raises InvalidStateError
        # unless the future is already done. Looks like a leftover from the
        # blocking client - confirm before this class is exercised.
        return self._send_schema_and_retry(error, lambda: None).result()

    def _invoke(self, request, response_handler=_no_op_response_handler):
        # Transactional requests are pinned to the transaction's connection.
        invocation = Invocation(
            request, connection=self.transaction.connection, response_handler=response_handler
        )
        self._invocation_service.invoke(invocation)
        # NOTE(review): same concern as _send_schema - .result() on a
        # not-yet-done asyncio future raises InvalidStateError.
        return invocation.future.result()

    def __repr__(self):
        return '%s(name="%s")' % (type(self).__name__, self.name)
class ItemEvent(typing.Generic[ItemType]):
    """Map Item event.

    Attributes:
        name: Name of the proxy that fired the event.
        item: The item related to the event.
        event_type: Type of the event.
        member: Member that fired the event.
    """

    def __init__(self, name: str, item: ItemType, event_type: int, member: MemberInfo):
        # Fix: ``item`` was annotated as ``ItemEventType`` (the flag-constants
        # class) instead of the ``ItemType`` type variable this class is
        # generic over.
        self.name = name
        self.item = item
        self.event_type = event_type
        self.member = member
def get_entry_listener_flags(**kwargs):
    """OR together the ``EntryEventType`` flags whose keyword argument is truthy."""
    combined = 0
    for flag_name, enabled in kwargs.items():
        if not enabled:
            continue
        combined |= getattr(EntryEventType, flag_name)
    return combined
    async def destroy_proxy(self, service_name, name, destroy_on_remote=True):
        """Destroy the cached proxy for ``(service_name, name)``.

        Returns:
            ``True`` if a proxy was cached locally (it is removed and, when
            ``destroy_on_remote`` is set, also destroyed on the cluster),
            ``False`` if no such proxy was known.
        """
        ns = (service_name, name)
        try:
            # pop() without a default raises KeyError when absent; that is
            # how the "not found" branch below is reached.
            # NOTE(review): the except also covers the remote invocation, so
            # a KeyError raised there would be misreported as "not found".
            self._proxies.pop(ns)
            if destroy_on_remote:
                request = client_destroy_proxy_codec.encode_request(name, service_name)
                invocation = Invocation(request)
                invocation_service = self._context.invocation_service
                await invocation_service.ainvoke(invocation)
            return True
        except KeyError:
            return False
map_put_all_codec, + map_put_if_absent_codec, + map_put_transient_codec, + map_size_codec, + map_remove_codec, + map_remove_if_same_codec, + map_remove_entry_listener_codec, + map_replace_codec, + map_replace_if_same_codec, + map_set_codec, + map_try_put_codec, + map_try_remove_codec, + map_values_codec, + map_values_with_predicate_codec, + map_add_interceptor_codec, + map_aggregate_codec, + map_aggregate_with_predicate_codec, + map_project_codec, + map_project_with_predicate_codec, + map_execute_on_all_keys_codec, + map_execute_on_key_codec, + map_execute_on_keys_codec, + map_execute_with_predicate_codec, + map_add_index_codec, + map_set_ttl_codec, + map_entries_with_paging_predicate_codec, + map_key_set_with_paging_predicate_codec, + map_values_with_paging_predicate_codec, + map_put_with_max_idle_codec, + map_put_if_absent_with_max_idle_codec, + map_put_transient_with_max_idle_codec, + map_set_with_max_idle_codec, + map_remove_interceptor_codec, + map_remove_all_codec, +) +from hazelcast.internal.asyncio_proxy.base import ( + Proxy, + EntryEvent, + EntryEventType, + get_entry_listener_flags, +) +from hazelcast.predicate import Predicate, _PagingPredicate +from hazelcast.serialization.data import Data +from hazelcast.types import AggregatorResultType, KeyType, ValueType, ProjectionType +from hazelcast.serialization.compact import SchemaNotReplicatedError +from hazelcast.util import ( + check_not_none, + thread_id, + to_millis, + IterationType, + deserialize_entry_list_in_place, + deserialize_list_in_place, +) + + +EntryEventCallable = typing.Callable[[EntryEvent[KeyType, ValueType]], None] + + +class Map(Proxy, typing.Generic[KeyType, ValueType]): + def __init__(self, service_name, name, context): + super(Map, self).__init__(service_name, name, context) + self._reference_id_generator = context.lock_reference_id_generator + + async def add_entry_listener( + self, + include_value: bool = False, + key: KeyType = None, + predicate: Predicate = None, + added_func: 
EntryEventCallable = None, + removed_func: EntryEventCallable = None, + updated_func: EntryEventCallable = None, + evicted_func: EntryEventCallable = None, + evict_all_func: EntryEventCallable = None, + clear_all_func: EntryEventCallable = None, + merged_func: EntryEventCallable = None, + expired_func: EntryEventCallable = None, + loaded_func: EntryEventCallable = None, + ) -> str: + flags = get_entry_listener_flags( + ADDED=added_func, + REMOVED=removed_func, + UPDATED=updated_func, + EVICTED=evicted_func, + EXPIRED=expired_func, + EVICT_ALL=evict_all_func, + CLEAR_ALL=clear_all_func, + MERGED=merged_func, + LOADED=loaded_func, + ) + if key is not None and predicate is not None: + try: + key_data = self._to_data(key) + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, + self.add_entry_listener, + include_value, + key, + predicate, + added_func, + removed_func, + updated_func, + evicted_func, + evict_all_func, + clear_all_func, + merged_func, + expired_func, + loaded_func, + ) + with_key_and_predicate_codec = map_add_entry_listener_to_key_with_predicate_codec + request = with_key_and_predicate_codec.encode_request( + self.name, key_data, predicate_data, include_value, flags, self._is_smart + ) + response_decoder = with_key_and_predicate_codec.decode_response + event_message_handler = with_key_and_predicate_codec.handle + elif key is not None and predicate is None: + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, + self.add_entry_listener, + include_value, + key, + predicate, + added_func, + removed_func, + updated_func, + evicted_func, + evict_all_func, + clear_all_func, + merged_func, + expired_func, + loaded_func, + ) + + with_key_codec = map_add_entry_listener_to_key_codec + request = with_key_codec.encode_request( + self.name, key_data, include_value, flags, self._is_smart + ) + response_decoder = 
with_key_codec.decode_response + event_message_handler = with_key_codec.handle + elif key is None and predicate is not None: + try: + predicate = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, + self.add_entry_listener, + include_value, + key, + predicate, + added_func, + removed_func, + updated_func, + evicted_func, + evict_all_func, + clear_all_func, + merged_func, + expired_func, + loaded_func, + ) + with_predicate_codec = map_add_entry_listener_with_predicate_codec + request = with_predicate_codec.encode_request( + self.name, predicate, include_value, flags, self._is_smart + ) + response_decoder = with_predicate_codec.decode_response + event_message_handler = with_predicate_codec.handle + else: + codec = map_add_entry_listener_codec + request = codec.encode_request(self.name, include_value, flags, self._is_smart) + response_decoder = codec.decode_response + event_message_handler = codec.handle + + def handle_event_entry( + key_data, + value_data, + old_value_data, + merging_value_data, + event_type, + uuid, + number_of_affected_entries, + ): + event = EntryEvent( + self._to_object(key_data), + self._to_object(value_data), + self._to_object(old_value_data), + self._to_object(merging_value_data), + event_type, + uuid, + number_of_affected_entries, + ) + if event.event_type == EntryEventType.ADDED: + added_func(event) + elif event.event_type == EntryEventType.REMOVED: + removed_func(event) + elif event.event_type == EntryEventType.UPDATED: + updated_func(event) + elif event.event_type == EntryEventType.EVICTED: + evicted_func(event) + elif event.event_type == EntryEventType.EVICT_ALL: + evict_all_func(event) + elif event.event_type == EntryEventType.CLEAR_ALL: + clear_all_func(event) + elif event.event_type == EntryEventType.MERGED: + merged_func(event) + elif event.event_type == EntryEventType.EXPIRED: + expired_func(event) + elif event.event_type == EntryEventType.LOADED: + loaded_func(event) + + 
return await self._register_listener( + request, + lambda r: response_decoder(r), + lambda reg_id: map_remove_entry_listener_codec.encode_request(self.name, reg_id), + lambda m: event_message_handler(m, handle_event_entry), + ) + + async def add_index( + self, + attributes: typing.Sequence[str] = None, + index_type: typing.Union[int, str] = IndexType.SORTED, + name: str = None, + bitmap_index_options: typing.Dict[str, typing.Any] = None, + ) -> None: + d = { + "name": name, + "type": index_type, + "attributes": attributes, + "bitmap_index_options": bitmap_index_options, + } + config = IndexConfig.from_dict(d) + validated = IndexUtil.validate_and_normalize(self.name, config) + request = map_add_index_codec.encode_request(self.name, validated) + return await self._invoke(request) + + async def add_interceptor(self, interceptor: typing.Any) -> str: + try: + interceptor_data = self._to_data(interceptor) + except SchemaNotReplicatedError as e: + return self._send_schema_and_retry(e, self.add_interceptor, interceptor) + + request = map_add_interceptor_codec.encode_request(self.name, interceptor_data) + return await self._invoke(request, map_add_interceptor_codec.decode_response) + + async def aggregate( + self, aggregator: Aggregator[AggregatorResultType], predicate: Predicate = None + ) -> AggregatorResultType: + check_not_none(aggregator, "aggregator can't be none") + if predicate: + if isinstance(predicate, _PagingPredicate): + raise AssertionError("Paging predicate is not supported.") + + try: + aggregator_data = self._to_data(aggregator) + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.aggregate, aggregator, predicate) + + def handler(message): + return self._to_object(map_aggregate_with_predicate_codec.decode_response(message)) + + request = map_aggregate_with_predicate_codec.encode_request( + self.name, aggregator_data, predicate_data + ) + else: + try: + aggregator_data = 
self._to_data(aggregator) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.aggregate, aggregator, predicate) + + def handler(message): + return self._to_object(map_aggregate_codec.decode_response(message)) + + request = map_aggregate_codec.encode_request(self.name, aggregator_data) + + return await self._invoke(request, handler) + + async def clear(self) -> None: + request = map_clear_codec.encode_request(self.name) + return await self._invoke(request) + + async def contains_key(self, key: KeyType) -> bool: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.contains_key, key) + + return await self._contains_key_internal(key_data) + + async def contains_value(self, value: ValueType) -> bool: + check_not_none(value, "value can't be None") + try: + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.contains_value, value) + request = map_contains_value_codec.encode_request(self.name, value_data) + return await self._invoke(request, map_contains_value_codec.decode_response) + + async def delete(self, key: KeyType) -> None: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.delete, key) + return await self._delete_internal(key_data) + + async def entry_set( + self, predicate: Predicate = None + ) -> typing.List[typing.Tuple[KeyType, ValueType]]: + if predicate: + if isinstance(predicate, _PagingPredicate): + predicate.iteration_type = IterationType.ENTRY + try: + holder = PagingPredicateHolder.of(predicate, self._to_data) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.entry_set, predicate) + + def handler(message): + response = 
map_entries_with_paging_predicate_codec.decode_response(message) + predicate.anchor_list = response["anchor_data_list"].as_anchor_list( + self._to_object + ) + entry_data_list = response["response"] + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_entries_with_paging_predicate_codec.encode_request(self.name, holder) + else: + try: + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.entry_set, predicate) + + def handler(message): + entry_data_list = map_entries_with_predicate_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_entries_with_predicate_codec.encode_request(self.name, predicate_data) + else: + + def handler(message): + entry_data_list = map_entry_set_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_entry_set_codec.encode_request(self.name) + + return await self._invoke(request, handler) + + async def evict(self, key: KeyType) -> bool: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.evict, key) + + return await self._evict_internal(key_data) + + async def evict_all(self) -> None: + request = map_evict_all_codec.encode_request(self.name) + return await self._invoke(request) + + async def execute_on_entries( + self, entry_processor: typing.Any, predicate: Predicate | None = None + ) -> typing.List[typing.Any]: + if predicate: + try: + entry_processor_data = self._to_data(entry_processor) + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.execute_on_entries, entry_processor, predicate + ) + + def handler(message): + entry_data_list = 
map_execute_with_predicate_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_execute_with_predicate_codec.encode_request( + self.name, entry_processor_data, predicate_data + ) + else: + try: + entry_processor_data = self._to_data(entry_processor) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.execute_on_entries, entry_processor, predicate + ) + + def handler(message): + entry_data_list = map_execute_on_all_keys_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_execute_on_all_keys_codec.encode_request(self.name, entry_processor_data) + + return await self._invoke(request, handler) + + async def execute_on_key(self, key: KeyType, entry_processor: typing.Any) -> typing.Any: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + entry_processor_data = self._to_data(entry_processor) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.execute_on_key, key, entry_processor) + + return await self._execute_on_key_internal(key_data, entry_processor_data) + + async def execute_on_keys( + self, keys: typing.Sequence[KeyType], entry_processor: typing.Any + ) -> typing.List[typing.Any]: + if len(keys) == 0: + return [] + try: + key_list = [] + for key in keys: + check_not_none(key, "key can't be None") + key_list.append(self._to_data(key)) + + entry_processor_data = self._to_data(entry_processor) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.execute_on_keys, keys, entry_processor) + + def handler(message): + entry_data_list = map_execute_on_keys_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + request = map_execute_on_keys_codec.encode_request( + self.name, entry_processor_data, key_list + ) + return await 
self._invoke(request, handler) + + async def flush(self) -> None: + request = map_flush_codec.encode_request(self.name) + return await self._invoke(request) + + async def get(self, key: KeyType) -> typing.Optional[ValueType]: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.get, key) + return await self._get_internal(key_data) + + async def get_all(self, keys: typing.Sequence[KeyType]) -> typing.Dict[KeyType, ValueType]: + check_not_none(keys, "keys can't be None") + if not keys: + return {} + partition_service = self._context.partition_service + partition_to_keys: typing.Dict[int, typing.Dict[KeyType, Data]] = {} + for key in keys: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.get_all, keys) + partition_id = partition_service.get_partition_id(key_data) + try: + partition_to_keys[partition_id][key] = key_data + except KeyError: + partition_to_keys[partition_id] = {key: key_data} + + return await self._get_all_internal(partition_to_keys) + + async def get_entry_view(self, key: KeyType) -> SimpleEntryView[KeyType, ValueType]: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.get_entry_view, key) + + def handler(message): + response = map_get_entry_view_codec.decode_response(message) + entry_view = response["response"] + if not entry_view: + return None + entry_view.key = self._to_object(entry_view.key) + entry_view.value = self._to_object(entry_view.value) + return entry_view + + request = map_get_entry_view_codec.encode_request(self.name, key_data, thread_id()) + return await self._invoke_on_key(request, key_data, handler) + + async def is_empty(self) -> bool: + request = 
map_is_empty_codec.encode_request(self.name) + return await self._invoke(request, map_is_empty_codec.decode_response) + + async def key_set(self, predicate: Predicate | None = None) -> typing.List[ValueType]: + if predicate: + if isinstance(predicate, _PagingPredicate): + predicate.iteration_type = IterationType.KEY + + try: + holder = PagingPredicateHolder.of(predicate, self._to_data) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.key_set, predicate) + + def handler(message): + response = map_key_set_with_paging_predicate_codec.decode_response(message) + predicate.anchor_list = response["anchor_data_list"].as_anchor_list( + self._to_object + ) + data_list = response["response"] + return deserialize_list_in_place(data_list, self._to_object) + + request = map_key_set_with_paging_predicate_codec.encode_request(self.name, holder) + else: + try: + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.key_set, predicate) + + def handler(message): + data_list = map_key_set_with_predicate_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_key_set_with_predicate_codec.encode_request(self.name, predicate_data) + else: + + def handler(message): + data_list = map_key_set_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_key_set_codec.encode_request(self.name) + + return await self._invoke(request, handler) + + async def load_all( + self, keys: typing.Sequence[KeyType] = None, replace_existing_values: bool = True + ) -> None: + if keys: + try: + key_data_list = [self._to_data(key) for key in keys] + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.load_all, keys, replace_existing_values + ) + + return await self._load_all_internal(key_data_list, replace_existing_values) + + request = 
map_load_all_codec.encode_request(self.name, replace_existing_values) + return await self._invoke(request) + + async def project( + self, projection: Projection[ProjectionType], predicate: Predicate = None + ) -> ProjectionType: + check_not_none(projection, "Projection can't be none") + if predicate: + if isinstance(predicate, _PagingPredicate): + raise AssertionError("Paging predicate is not supported.") + try: + projection_data = self._to_data(projection) + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.project, projection, predicate) + + def handler(message): + data_list = map_project_with_predicate_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_project_with_predicate_codec.encode_request( + self.name, projection_data, predicate_data + ) + else: + try: + projection_data = self._to_data(projection) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.project, projection, predicate) + + def handler(message): + data_list = map_project_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_project_codec.encode_request(self.name, projection_data) + + return await self._invoke(request, handler) + + async def put( + self, key: KeyType, value: ValueType, ttl: float = None, max_idle: float = None + ) -> typing.Optional[ValueType]: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.put, key, value, ttl, max_idle) + + return await self._put_internal(key_data, value_data, ttl, max_idle) + + async def put_all(self, map: typing.Dict[KeyType, ValueType]) -> None: + check_not_none(map, "map can't be None") + if not map: + return None + 
partition_service = self._context.partition_service + partition_map: typing.Dict[int, typing.List[typing.Tuple[Data, Data]]] = {} + for key, value in map.items(): + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + entry = (self._to_data(key), self._to_data(value)) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.put_all, map) + partition_id = partition_service.get_partition_id(entry[0]) + try: + partition_map[partition_id].append(entry) + except KeyError: + partition_map[partition_id] = [entry] + + async with asyncio.TaskGroup() as tg: # type: ignore[attr-defined] + for partition_id, entry_list in partition_map.items(): + request = map_put_all_codec.encode_request( + self.name, entry_list, False + ) # TODO trigger map loader + tg.create_task(self._ainvoke_on_partition(request, partition_id)) + return None + + async def put_if_absent( + self, key: KeyType, value: ValueType, ttl: float = None, max_idle: float = None + ) -> typing.Optional[ValueType]: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.put_if_absent, key, value, ttl, max_idle + ) + + return await self._put_if_absent_internal(key_data, value_data, ttl, max_idle) + + async def put_transient( + self, key: KeyType, value: ValueType, ttl: float = None, max_idle: float = None + ) -> None: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.put_transient, key, value, ttl, max_idle + ) + + return await self._put_transient_internal(key_data, value_data, ttl, max_idle) + + async def remove(self, key: KeyType) -> 
typing.Optional[ValueType]: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.remove, key) + + return await self._remove_internal(key_data) + + async def remove_all(self, predicate: Predicate) -> None: + check_not_none(predicate, "predicate can't be None") + try: + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.remove_all, predicate) + + return await self._remove_all_internal(predicate_data) + + async def remove_if_same(self, key: KeyType, value: ValueType) -> bool: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.remove_if_same, key, value) + return await self._remove_if_same_internal_(key_data, value_data) + + async def remove_entry_listener(self, registration_id: str) -> bool: + return await self._deregister_listener(registration_id) + + async def remove_interceptor(self, registration_id: str) -> bool: + check_not_none(registration_id, "Interceptor registration id should not be None") + request = map_remove_interceptor_codec.encode_request(self.name, registration_id) + return await self._invoke(request, map_remove_interceptor_codec.decode_response) + + async def replace(self, key: KeyType, value: ValueType) -> typing.Optional[ValueType]: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.replace, key, value) + return await self._replace_internal(key_data, value_data) + + async def replace_if_same( + self, key: ValueType, old_value: ValueType, new_value: 
ValueType + ) -> bool: + check_not_none(key, "key can't be None") + check_not_none(old_value, "old_value can't be None") + check_not_none(new_value, "new_value can't be None") + try: + key_data = self._to_data(key) + old_value_data = self._to_data(old_value) + new_value_data = self._to_data(new_value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry( + e, self.replace_if_same, key, old_value, new_value + ) + + return await self._replace_if_same_internal(key_data, old_value_data, new_value_data) + + async def set( + self, key: KeyType, value: ValueType, ttl: float = None, max_idle: float = None + ) -> None: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.set, key, value, ttl, max_idle) + return await self._set_internal(key_data, value_data, ttl, max_idle) + + async def set_ttl(self, key: KeyType, ttl: float) -> None: + check_not_none(key, "key can't be None") + check_not_none(ttl, "ttl can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.set_ttl, key, ttl) + return await self._set_ttl_internal(key_data, ttl) + + async def size(self) -> int: + request = map_size_codec.encode_request(self.name) + return await self._invoke(request, map_size_codec.decode_response) + + async def try_put(self, key: KeyType, value: ValueType, timeout: float = 0) -> bool: + check_not_none(key, "key can't be None") + check_not_none(value, "value can't be None") + try: + key_data = self._to_data(key) + value_data = self._to_data(value) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.try_put, key, value, timeout) + return await self._try_put_internal(key_data, value_data, timeout) + + async def try_remove(self, key: 
KeyType, timeout: float = 0) -> bool: + check_not_none(key, "key can't be None") + try: + key_data = self._to_data(key) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.try_remove, key, timeout) + return await self._try_remove_internal(key_data, timeout) + + async def values(self, predicate: Predicate = None) -> typing.List[ValueType]: + if predicate: + if isinstance(predicate, _PagingPredicate): + predicate.iteration_type = IterationType.VALUE + + try: + holder = PagingPredicateHolder.of(predicate, self._to_data) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.values, predicate) + + def handler(message): + response = map_values_with_paging_predicate_codec.decode_response(message) + predicate.anchor_list = response["anchor_data_list"].as_anchor_list( + self._to_object + ) + data_list = response["response"] + return deserialize_list_in_place(data_list, self._to_object) + + request = map_values_with_paging_predicate_codec.encode_request(self.name, holder) + else: + try: + predicate_data = self._to_data(predicate) + except SchemaNotReplicatedError as e: + return await self._send_schema_and_retry(e, self.values, predicate) + + def handler(message): + data_list = map_values_with_predicate_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_values_with_predicate_codec.encode_request(self.name, predicate_data) + else: + + def handler(message): + data_list = map_values_codec.decode_response(message) + return deserialize_list_in_place(data_list, self._to_object) + + request = map_values_codec.encode_request(self.name) + + return await self._invoke(request, handler) + + def _contains_key_internal(self, key_data): + request = map_contains_key_codec.encode_request(self.name, key_data, thread_id()) + return self._invoke_on_key(request, key_data, map_contains_key_codec.decode_response) + + def _get_internal(self, key_data): + def 
handler(message): + return self._to_object(map_get_codec.decode_response(message)) + + request = map_get_codec.encode_request(self.name, key_data, thread_id()) + return self._invoke_on_key(request, key_data, handler) + + async def _get_all_internal(self, partition_to_keys, tasks=None): + def handler(message): + entry_data_list = map_get_all_codec.decode_response(message) + return deserialize_entry_list_in_place(entry_data_list, self._to_object) + + tasks = tasks or [] + async with asyncio.TaskGroup() as tg: + for partition_id, key_dict in partition_to_keys.items(): + request = map_get_all_codec.encode_request(self.name, key_dict.values()) + task = tg.create_task(self._ainvoke_on_partition(request, partition_id, handler)) + tasks.append(task) + kvs = itertools.chain.from_iterable(task.result() for task in tasks) + return dict(kvs) + + def _remove_internal(self, key_data): + def handler(message): + return self._to_object(map_remove_codec.decode_response(message)) + + request = map_remove_codec.encode_request(self.name, key_data, thread_id()) + return self._invoke_on_key(request, key_data, handler) + + def _remove_all_internal(self, predicate_data): + request = map_remove_all_codec.encode_request(self.name, predicate_data) + return self._invoke(request) + + def _remove_if_same_internal_(self, key_data, value_data): + request = map_remove_if_same_codec.encode_request( + self.name, key_data, value_data, thread_id() + ) + return self._invoke_on_key( + request, key_data, response_handler=map_remove_if_same_codec.decode_response + ) + + def _delete_internal(self, key_data): + request = map_delete_codec.encode_request(self.name, key_data, thread_id()) + return self._invoke_on_key(request, key_data) + + def _put_internal(self, key_data, value_data, ttl, max_idle): + def handler(message): + return self._to_object(map_put_codec.decode_response(message)) + + if max_idle is not None: + request = map_put_with_max_idle_codec.encode_request( + self.name, key_data, value_data, 
thread_id(), to_millis(ttl), to_millis(max_idle) + ) + else: + request = map_put_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl) + ) + return self._invoke_on_key(request, key_data, handler) + + def _set_internal(self, key_data, value_data, ttl, max_idle): + if max_idle is not None: + request = map_set_with_max_idle_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl), to_millis(max_idle) + ) + else: + request = map_set_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl) + ) + return self._invoke_on_key(request, key_data) + + def _set_ttl_internal(self, key_data, ttl): + request = map_set_ttl_codec.encode_request(self.name, key_data, to_millis(ttl)) + return self._invoke_on_key(request, key_data, map_set_ttl_codec.decode_response) + + def _try_remove_internal(self, key_data, timeout): + request = map_try_remove_codec.encode_request( + self.name, key_data, thread_id(), to_millis(timeout) + ) + return self._invoke_on_key(request, key_data, map_try_remove_codec.decode_response) + + def _try_put_internal(self, key_data, value_data, timeout): + request = map_try_put_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(timeout) + ) + return self._invoke_on_key(request, key_data, map_try_put_codec.decode_response) + + def _put_transient_internal(self, key_data, value_data, ttl, max_idle): + if max_idle is not None: + request = map_put_transient_with_max_idle_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl), to_millis(max_idle) + ) + else: + request = map_put_transient_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl) + ) + return self._invoke_on_key(request, key_data) + + def _put_if_absent_internal(self, key_data, value_data, ttl, max_idle): + def handler(message): + return self._to_object(map_put_if_absent_codec.decode_response(message)) + + if max_idle is not None: + 
request = map_put_if_absent_with_max_idle_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl), to_millis(max_idle) + ) + else: + request = map_put_if_absent_codec.encode_request( + self.name, key_data, value_data, thread_id(), to_millis(ttl) + ) + return self._invoke_on_key(request, key_data, handler) + + def _replace_if_same_internal(self, key_data, old_value_data, new_value_data): + request = map_replace_if_same_codec.encode_request( + self.name, key_data, old_value_data, new_value_data, thread_id() + ) + return self._invoke_on_key(request, key_data, map_replace_if_same_codec.decode_response) + + def _replace_internal(self, key_data, value_data): + def handler(message): + return self._to_object(map_replace_codec.decode_response(message)) + + request = map_replace_codec.encode_request(self.name, key_data, value_data, thread_id()) + return self._invoke_on_key(request, key_data, handler) + + def _evict_internal(self, key_data): + request = map_evict_codec.encode_request(self.name, key_data, thread_id()) + return self._invoke_on_key(request, key_data, map_evict_codec.decode_response) + + def _load_all_internal(self, key_data_list, replace_existing_values): + request = map_load_given_keys_codec.encode_request( + self.name, key_data_list, replace_existing_values + ) + return self._invoke(request) + + def _execute_on_key_internal(self, key_data, entry_processor_data): + def handler(message): + return self._to_object(map_execute_on_key_codec.decode_response(message)) + + request = map_execute_on_key_codec.encode_request( + self.name, entry_processor_data, key_data, thread_id() + ) + return self._invoke_on_key(request, key_data, handler) + + +def create_map_proxy(service_name, name, context): + near_cache_config = context.config.near_caches.get(name, None) + if near_cache_config is None: + return Map(service_name, name, context) + raise InvalidConfigurationError("near cache is not supported") diff --git 
# --- hazelcast/internal/asyncio_reactor.py (new file) ---
import asyncio
import io
import logging
import ssl
import time
from asyncio import AbstractEventLoop, transports

from hazelcast.config import Config, SSLProtocol
from hazelcast.internal.asyncio_connection import Connection
from hazelcast.core import Address

# Minimum size (bytes) of the receive buffer handed to the transport.
_BUFFER_SIZE = 128000


_logger = logging.getLogger(__name__)


class AsyncioReactor:
    """Event-loop facade used by the asyncio client: timers, connections, byte counters."""

    def __init__(self, loop: AbstractEventLoop | None = None):
        self._is_live = False
        # Falls back to the currently running loop; must be constructed inside a loop then.
        self._loop = loop or asyncio.get_running_loop()
        self._bytes_sent = 0
        self._bytes_received = 0

    def add_timer(self, delay, callback):
        """Schedule ``callback`` after ``delay`` seconds; returns the loop's timer handle."""
        return self._loop.call_later(delay, callback)

    def start(self):
        self._is_live = True

    def shutdown(self):
        # Idempotent: does nothing if never started.
        if not self._is_live:
            return
        # TODO: cancel tasks

    async def connection_factory(
        self, connection_manager, connection_id, address: Address, network_config, message_callback
    ):
        """Create and connect a new :class:`AsyncioConnection` to ``address``."""
        return await AsyncioConnection.create_and_connect(
            self._loop,
            self,
            connection_manager,
            connection_id,
            address,
            network_config,
            message_callback,
        )

    def update_bytes_sent(self, sent: int):
        self._bytes_sent += sent

    def update_bytes_received(self, received: int):
        self._bytes_received += received


class AsyncioConnection(Connection):
    """A single member connection backed by an asyncio transport/protocol pair."""

    def __init__(
        self,
        loop,
        reactor: AsyncioReactor,
        connection_manager,
        connection_id,
        address,
        config,
        message_callback,
    ):
        super().__init__(connection_manager, connection_id, message_callback)
        self._loop = loop
        self._reactor = reactor
        self._address = address
        self._config = config
        # Set by _create_connection once the transport is established.
        self._proto = None

    @classmethod
    async def create_and_connect(
        cls,
        loop,
        reactor: AsyncioReactor,
        connection_manager,
        connection_id,
        address,
        config,
        message_callback,
    ):
        """Alternate constructor: build the connection object and open the socket."""
        this = cls(
            loop, reactor, connection_manager, connection_id, address, config, message_callback
        )
        await this._create_connection(config, address)
        return this

    def _create_protocol(self):
        return HazelcastProtocol(self)

    async def _create_connection(self, config, address):
        # TLS is optional; hostname checking only applies when explicitly enabled.
        ssl_context = None
        if config.ssl_enabled:
            ssl_context = self._create_ssl_context(config)
        server_hostname = None
        if config.ssl_check_hostname:
            server_hostname = address.host
        res = await self._loop.create_connection(
            self._create_protocol,
            host=self._address.host,
            port=self._address.port,
            ssl=ssl_context,
            server_hostname=server_hostname,
        )
        _sock, self._proto = res

    def _write(self, buf):
        # Buffered in the protocol; flushed by its periodic write loop.
        self._proto.write(buf)

    def _inner_close(self):
        self._proto.close()

    def _update_read_time(self, time):
        self.last_read_time = time

    def _update_write_time(self, time):
        self.last_write_time = time

    def _update_sent(self, sent):
        self._reactor.update_bytes_sent(sent)

    def _update_received(self, received):
        self._reactor.update_bytes_received(received)

    def _create_ssl_context(self, config: Config):
        ssl_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        protocol = config.ssl_protocol
        # Use only the configured protocol
        try:
            if protocol != SSLProtocol.SSLv2:
                ssl_context.options |= ssl.OP_NO_SSLv2
            if protocol != SSLProtocol.SSLv3:
                ssl_context.options |= ssl.OP_NO_SSLv3
            if protocol != SSLProtocol.TLSv1:
                ssl_context.options |= ssl.OP_NO_TLSv1
            if protocol != SSLProtocol.TLSv1_1:
                ssl_context.options |= ssl.OP_NO_TLSv1_1
            if protocol != SSLProtocol.TLSv1_2:
                ssl_context.options |= ssl.OP_NO_TLSv1_2
            if protocol != SSLProtocol.TLSv1_3:
                ssl_context.options |= ssl.OP_NO_TLSv1_3
        except AttributeError:
            # Older ssl modules may lack some OP_NO_* flags; skip them silently.
            pass

        ssl_context.verify_mode = ssl.CERT_REQUIRED
        if config.ssl_cafile:
            ssl_context.load_verify_locations(config.ssl_cafile)
        else:
            ssl_context.load_default_certs()
        if config.ssl_certfile:
            ssl_context.load_cert_chain(
                config.ssl_certfile, config.ssl_keyfile, config.ssl_password
            )
        if config.ssl_ciphers:
            ssl_context.set_ciphers(config.ssl_ciphers)
        if config.ssl_check_hostname:
            ssl_context.check_hostname = True

        return ssl_context


class HazelcastProtocol(asyncio.BufferedProtocol):
    """Buffered asyncio protocol speaking the Hazelcast client binary protocol."""

    # Bytes sent first on every connection to select the client protocol.
    PROTOCOL_STARTER = b"CP2"

    def __init__(self, conn: AsyncioConnection):
        self._conn = conn
        self._transport: transports.BaseTransport | None = None
        self.start_time: float | None = None
        # Outgoing bytes are accumulated here and flushed by _write_loop.
        self._write_buf = io.BytesIO()
        self._write_buf_size = 0
        self._recv_buf = None
        self._alive = True

    def connection_made(self, transport: transports.BaseTransport):
        self._transport = transport
        self.start_time = time.time()
        self.write(self.PROTOCOL_STARTER)
        _logger.debug("Connected to %s", self._conn._address)
        # Kick off the periodic flush of the write buffer.
        self._conn._loop.call_soon(self._write_loop)

    def connection_lost(self, exc):
        self._alive = False
        # NOTE(review): ``exc`` may be None on clean EOF, producing the string "None".
        self._conn._loop.create_task(self._conn.close_connection(str(exc), None))
        return False

    def close(self):
        self._transport.close()

    def write(self, buf):
        self._write_buf.write(buf)
        self._write_buf_size += len(buf)

    def get_buffer(self, sizehint):
        # Lazily allocate one reusable receive buffer of at least _BUFFER_SIZE bytes.
        if self._recv_buf is None:
            buf_size = max(sizehint, _BUFFER_SIZE)
            self._recv_buf = memoryview(bytearray(buf_size))
        return self._recv_buf

    def buffer_updated(self, nbytes):
        recv_bytes = self._recv_buf[:nbytes]
        self._conn._update_read_time(time.time())
        self._conn._update_received(nbytes)
        # Feed the connection's frame reader; process only when a full frame may exist.
        self._conn._reader.read(recv_bytes)
        if self._conn._reader.length:
            self._conn._reader.process()

    def eof_received(self):
        self._alive = False

    def _do_write(self):
        if not self._write_buf_size:
            return
        buf_bytes = self._write_buf.getvalue()
        self._transport.write(buf_bytes[: self._write_buf_size])
        self._conn._update_write_time(time.time())
        self._conn._update_sent(self._write_buf_size)
        # Reuse the BytesIO: rewind and track the logical size separately.
        self._write_buf.seek(0)
        self._write_buf_size = 0

    def _write_loop(self):
        # Flush pending bytes roughly every 10 ms for the connection's lifetime.
        self._do_write()
        return self._conn._loop.call_later(0.01, self._write_loop)


# --- tests/integration/asyncio/authentication_tests/authentication_test.py (new file) ---
import os
import unittest

import pytest

from hazelcast.errors import HazelcastError
from tests.integration.asyncio.base import HazelcastTestCase
from tests.util import get_abs_path, compare_client_version
from hazelcast.asyncio.client import HazelcastClient

try:
    from hazelcast.security import BasicTokenProvider
except ImportError:
    # Older client versions lack BasicTokenProvider; the skipIf below guards usage.
    pass


@pytest.mark.enterprise
@unittest.skipIf(
    compare_client_version("4.2.1") < 0, "Tests the features added in 4.2.1 version of the client"
)
class AuthenticationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase):
    """Integration tests for token and username/password authentication (asyncio client)."""

    current_directory = os.path.dirname(__file__)
    rc = None
    # Cluster configs are shared with the backward-compatible test suite.
    hazelcast_token_xml = get_abs_path(
        current_directory, "../../backward_compatible/authentication_tests/hazelcast-token.xml"
    )
    hazelcast_userpass_xml = get_abs_path(
        current_directory, "../../backward_compatible/authentication_tests/hazelcast-user-pass.xml"
    )

    def setUp(self):
        self.rc = self.create_rc()

    def tearDown(self):
        self.rc.exit()

    async def test_no_auth(self):
        # A client without credentials must fail to connect to a secured cluster.
        cluster = self.create_cluster(self.rc, self.configure_cluster(self.hazelcast_userpass_xml))
        cluster.start_member()

        with self.assertRaises(HazelcastError):
            await HazelcastClient.create_and_start(
                cluster_name=cluster.id, cluster_connect_timeout=2
            )
async def test_token_auth(self): + cluster = self.create_cluster(self.rc, self.configure_cluster(self.hazelcast_token_xml)) + cluster.start_member() + + token_provider = BasicTokenProvider("Hazelcast") + client = await HazelcastClient.create_and_start( + cluster_name=cluster.id, token_provider=token_provider + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_username_password_auth(self): + cluster = self.create_cluster(self.rc, self.configure_cluster(self.hazelcast_userpass_xml)) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + cluster_name=cluster.id, creds_username="member1", creds_password="s3crEt" + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + @classmethod + def configure_cluster(cls, filename): + with open(filename, "r") as f: + return f.read() diff --git a/tests/integration/asyncio/backup_acks_test.py b/tests/integration/asyncio/backup_acks_test.py new file mode 100644 index 0000000000..44d5662ae4 --- /dev/null +++ b/tests/integration/asyncio/backup_acks_test.py @@ -0,0 +1,94 @@ +import unittest + +from mock import MagicMock + +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.errors import IndeterminateOperationStateError +from tests.base import HazelcastTestCase + + +class BackupAcksTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.rc.createCluster(None, None) + cls.rc.startMember(cls.cluster.id) + cls.rc.startMember(cls.cluster.id) + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + def setUp(self): + self.client = None + + async def asyncTearDown(self): + if self.client: + await self.client.shutdown() + + async def test_smart_mode(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + 
fail_on_indeterminate_operation_state=True, + ) + m = await self.client.get_map("test") + # TODO: Remove the next line once + # https://github.com/hazelcast/hazelcast/issues/9398 is fixed + await m.get(1) + # it's enough for this operation to succeed + await m.set(1, 2) + + async def test_lost_backups_on_smart_mode_with_fail_on_indeterminate_operation_state(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + operation_backup_timeout=0.3, + fail_on_indeterminate_operation_state=True, + ) + client = self.client + # replace backup ack handler with a mock to emulate backup acks loss + client._invocation_service._backup_event_handler = MagicMock() + m = await client.get_map("test") + with self.assertRaises(IndeterminateOperationStateError): + await m.set(1, 2) + + async def test_lost_backups_on_smart_mode_without_fail_on_indeterminate_operation_state(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + operation_backup_timeout=0.3, + fail_on_indeterminate_operation_state=False, + ) + client = self.client + # replace backup ack handler with a mock to emulate backup acks loss + client._invocation_service._backup_event_handler = MagicMock() + m = await client.get_map("test") + # it's enough for this operation to succeed + await m.set(1, 2) + + async def test_backup_acks_disabled(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + backup_ack_to_client_enabled=False, + ) + m = await self.client.get_map("test") + # it's enough for this operation to succeed + await m.set(1, 2) + + async def test_unisocket_mode(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + smart_routing=False, + ) + m = await self.client.get_map("test") + # it's enough for this operation to succeed + await m.set(1, 2) + + async def test_unisocket_mode_with_disabled_backup_acks(self): + self.client = await 
HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + smart_routing=False, + backup_ack_to_client_enabled=False, + ) + m = await self.client.get_map("test") + # it's enough for this operation to succeed + await m.set(1, 2) diff --git a/tests/integration/asyncio/base.py b/tests/integration/asyncio/base.py new file mode 100644 index 0000000000..d860f89929 --- /dev/null +++ b/tests/integration/asyncio/base.py @@ -0,0 +1,123 @@ +import asyncio +import logging +import unittest +from typing import Awaitable + +from hazelcast.asyncio.client import HazelcastClient + +from tests.base import _Cluster +from tests.hzrc.client import HzRemoteController +from tests.util import get_current_timestamp + + +class HazelcastTestCase(unittest.TestCase): + clients = [] + + def __init__(self, methodName): + unittest.TestCase.__init__(self, methodName) + self.logger = logging.getLogger(methodName) + + @staticmethod + def create_rc(): + return HzRemoteController("127.0.0.1", 9701) + + @classmethod + def create_cluster(cls, rc, config=None): + return _Cluster(rc, rc.createCluster(None, config)) + + @classmethod + def create_cluster_keep_cluster_name(cls, rc, config=None): + return _Cluster(rc, rc.createClusterKeepClusterName(None, config)) + + async def create_client(self, config=None): + client = await HazelcastClient.create_and_start(**config) + self.clients.append(client) + return client + + async def shutdown_all_clients(self): + async with asyncio.TaskGroup() as tg: + for c in self.clients: + tg.create_task(c.shutdown()) + self.clients = [] + + async def assertTrueEventually(self, assertion, timeout=30): + timeout_time = get_current_timestamp() + timeout + last_exception = None + while get_current_timestamp() < timeout_time: + try: + maybe_awaitable = assertion() + if isinstance(maybe_awaitable, Awaitable): + await maybe_awaitable + return + except AssertionError as e: + last_exception = e + await asyncio.sleep(0.1) + if last_exception is None: + raise Exception("Could 
not enter the assertion loop!") + raise last_exception + + async def assertSetEventually(self, event: asyncio.Event, timeout=5): + is_set = asyncio.wait_for(event.wait(), timeout=timeout) + self.assertTrue(is_set, "Event was not set within %d seconds" % timeout) + + def assertEntryEvent( + self, + event, + event_type, + key=None, + value=None, + old_value=None, + merging_value=None, + number_of_affected_entries=1, + ): + + self.assertEqual(event.key, key) + self.assertEqual(event.event_type, event_type) + self.assertEqual(event.value, value) + self.assertEqual(event.merging_value, merging_value) + self.assertEqual(event.old_value, old_value) + self.assertEqual(event.number_of_affected_entries, number_of_affected_entries) + + def assertDistributedObjectEvent(self, event, name, service_name, event_type): + self.assertEqual(name, event.name) + self.assertEqual(service_name, event.service_name) + self.assertEqual(event_type, event.event_type) + + def set_logging_level(self, level): + logging.getLogger().setLevel(level) + + +class SingleMemberTestCase(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + """ + Test cases where a single member - client combination is needed + """ + + rc = None + cluster = None + member = None + client = None + + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, cls.configure_cluster()) + cls.member = cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + @classmethod + def configure_client(cls, config): + return config + + @classmethod + def configure_cluster(cls): + return None + + async def asyncSetUp(self): + self.client = await HazelcastClient.create_and_start(**self.configure_client({})) + + async def asyncTearDown(self): + await self.client.shutdown() diff --git a/tests/integration/asyncio/client_test.py b/tests/integration/asyncio/client_test.py new file mode 100644 index 0000000000..02ef4d1ea7 --- 
/dev/null +++ b/tests/integration/asyncio/client_test.py @@ -0,0 +1,141 @@ +import unittest + +from tests.integration.asyncio.base import HazelcastTestCase, SingleMemberTestCase +from hazelcast.asyncio.client import HazelcastClient +from tests.hzrc.ttypes import Lang +from tests.util import compare_client_version, random_string + +try: + from hazelcast.config import Config + from hazelcast.errors import InvalidConfigurationError +except ImportError: + # For backward compatibility tests + pass + + +class ClientLabelsTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc) + cls.cluster.start_member() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def asyncTearDown(self): + await self.shutdown_all_clients() + + async def test_default_config(self): + client = await self.create_client({"cluster_name": self.cluster.id}) + self.assertIsNone(self.get_labels_from_member(client._connection_manager.client_uuid)) + + async def test_provided_labels_are_received(self): + client = await self.create_client( + { + "cluster_name": self.cluster.id, + "labels": [ + "test-label", + ], + } + ) + self.assertEqual( + b"test-label", self.get_labels_from_member(client._connection_manager.client_uuid) + ) + + def get_labels_from_member(self, client_uuid): + script = """ + var clients = instance_0.getClientService().getConnectedClients().toArray(); + for (i=0; i < clients.length; i++) { + var client = clients[i]; + if ("%s".equals(client.getUuid().toString())) { + result = client.getLabels().iterator().next(); + break; + } + }""" % str( + client_uuid + ) + return self.rc.executeOnController(self.cluster.id, script, Lang.JAVASCRIPT).result + + +@unittest.skipIf( + compare_client_version("4.2.2") < 0 or compare_client_version("5.0") == 0, + "Tests the features added in 5.1 version of the client, " + "which are backported 
into 4.2.2 and 5.0.1", +) +class ClientTcpMetricsTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def test_bytes_received(self): + reactor = self.client._reactor + bytes_received = reactor._bytes_received + self.assertGreater(bytes_received, 0) + m = await self.client.get_map(random_string()) + await m.get(random_string()) + self.assertGreater(reactor._bytes_received, bytes_received) + + async def test_bytes_sent(self): + reactor = self.client._reactor + bytes_sent = reactor._bytes_sent + self.assertGreater(bytes_sent, 0) + m = await self.client.get_map(random_string()) + await m.set(random_string(), random_string()) + self.assertGreater(reactor._bytes_sent, bytes_sent) + + +@unittest.skipIf( + compare_client_version("5.2") < 0, + "Tests the features added in 5.2 version of the client", +) +class ClientConfigurationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + rc = None + cluster = None + + @classmethod + def setUpClass(cls): + cls.rc = cls.create_rc() + cls.cluster = cls.create_cluster(cls.rc, None) + cls.cluster.start_member() + + def setUp(self): + self.client = None + + async def asyncTearDown(self): + if self.client: + await self.client.shutdown() + + @classmethod + def tearDownClass(cls): + cls.rc.terminateCluster(cls.cluster.id) + cls.rc.exit() + + async def test_keyword_args_configuration(self): + self.client = await HazelcastClient.create_and_start( + cluster_name=self.cluster.id, + ) + self.assertTrue(self.client.lifecycle_service.is_running()) + + async def test_configuration_object(self): + config = Config() + config.cluster_name = self.cluster.id + self.client = await HazelcastClient.create_and_start(config) + self.assertTrue(self.client.lifecycle_service.is_running()) + + async def test_configuration_object_as_keyword_argument(self): + config = Config() + config.cluster_name = self.cluster.id + self.client = await 
HazelcastClient.create_and_start(config=config) + self.assertTrue(self.client.lifecycle_service.is_running()) + + async def test_ambiguous_configuration(self): + config = Config() + with self.assertRaisesRegex( + InvalidConfigurationError, + "Ambiguous client configuration is found", + ): + self.client = await HazelcastClient.create_and_start(config, cluster_name="a-cluster") diff --git a/tests/integration/asyncio/proxy/__init__.py b/tests/integration/asyncio/proxy/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/proxy/map_test.py b/tests/integration/asyncio/proxy/map_test.py new file mode 100644 index 0000000000..b63ae9e0fe --- /dev/null +++ b/tests/integration/asyncio/proxy/map_test.py @@ -0,0 +1,1075 @@ +import asyncio +import os +import time +import unittest + + +try: + from hazelcast.aggregator import ( + count, + distinct, + double_avg, + double_sum, + fixed_point_sum, + floating_point_sum, + int_avg, + int_sum, + long_avg, + long_sum, + max_, + max_by, + min_, + min_by, + number_avg, + ) +except ImportError: + # If the import of those fail, we won't use + # them in the tests thanks to client version check. + pass + +try: + from hazelcast.projection import ( + single_attribute, + multi_attribute, + identity, + ) +except ImportError: + # If the import of those fail, we won't use + # them in the tests thanks to client version check. 
+ pass + +from hazelcast.core import HazelcastJsonValue +from hazelcast.config import IndexType, IntType +from hazelcast.errors import HazelcastError +from hazelcast.predicate import greater_or_equal, less_or_equal, sql, paging, true +from hazelcast.internal.asyncio_proxy.map import EntryEventType +from hazelcast.serialization.api import IdentifiedDataSerializable +from tests.integration.asyncio.base import SingleMemberTestCase +from tests.integration.backward_compatible.util import ( + read_string_from_input, + write_string_to_output, +) +from tests.util import ( + event_collector, + get_current_timestamp, + compare_client_version, + compare_server_version, + skip_if_client_version_older_than, + random_string, +) + +from tests.integration.asyncio.util import fill_map + + +class EntryProcessor(IdentifiedDataSerializable): + FACTORY_ID = 66 + CLASS_ID = 1 + + def __init__(self, value=None): + self.value = value + + def write_data(self, object_data_output): + write_string_to_output(object_data_output, self.value) + + def read_data(self, object_data_input): + self.value = read_string_from_input(object_data_input) + + def get_factory_id(self): + return self.FACTORY_ID + + def get_class_id(self): + return self.CLASS_ID + + +class MapGetInterceptor(IdentifiedDataSerializable): + + FACTORY_ID = 666 + CLASS_ID = 6 + + def __init__(self, prefix): + self.prefix = prefix + + def write_data(self, object_data_output): + write_string_to_output(object_data_output, self.prefix) + + def read_data(self, object_data_input): + pass + + def get_factory_id(self): + return self.FACTORY_ID + + def get_class_id(self): + return self.CLASS_ID + + +class MapTest(SingleMemberTestCase): + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open(os.path.join(dir_path, "../../backward_compatible/proxy/hazelcast.xml")) as f: + return f.read() + + @classmethod + def configure_client(cls, config): + config["cluster_name"] = 
cls.cluster.id + config["data_serializable_factories"] = { + EntryProcessor.FACTORY_ID: {EntryProcessor.CLASS_ID: EntryProcessor} + } + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_add_entry_listener_item_added(self): + collector = event_collector() + await self.map.add_entry_listener(include_value=True, added_func=collector) + await self.map.put("key", "value") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent(event, key="key", event_type=EntryEventType.ADDED, value="value") + + await self.assertTrueEventually(assert_event, 5) + + async def test_add_entry_listener_item_removed(self): + collector = event_collector() + await self.map.add_entry_listener(include_value=True, removed_func=collector) + await self.map.put("key", "value") + await self.map.remove("key") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, key="key", event_type=EntryEventType.REMOVED, old_value="value" + ) + + await self.assertTrueEventually(assert_event, 5) + + async def test_add_entry_listener_item_updated(self): + collector = event_collector() + await self.map.add_entry_listener(include_value=True, updated_func=collector) + await self.map.put("key", "value") + await self.map.put("key", "new_value") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, + key="key", + event_type=EntryEventType.UPDATED, + old_value="value", + value="new_value", + ) + + await self.assertTrueEventually(assert_event, 5) + + async def test_add_entry_listener_item_expired(self): + collector = event_collector() + await self.map.add_entry_listener(include_value=True, expired_func=collector) + 
await self.map.put("key", "value", ttl=0.1) + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, key="key", event_type=EntryEventType.EXPIRED, old_value="value" + ) + + await self.assertTrueEventually(assert_event, 10) + + async def test_add_entry_listener_with_key(self): + collector = event_collector() + await self.map.add_entry_listener(key="key1", include_value=True, added_func=collector) + await self.map.put("key2", "value2") + await self.map.put("key1", "value1") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, key="key1", event_type=EntryEventType.ADDED, value="value1" + ) + + await self.assertTrueEventually(assert_event, 5) + + async def test_add_entry_listener_with_predicate(self): + collector = event_collector() + await self.map.add_entry_listener( + predicate=sql("this == value1"), include_value=True, added_func=collector + ) + await self.map.put("key2", "value2") + await self.map.put("key1", "value1") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, key="key1", event_type=EntryEventType.ADDED, value="value1" + ) + + await self.assertTrueEventually(assert_event, 5) + + async def test_add_entry_listener_with_key_and_predicate(self): + collector = event_collector() + await self.map.add_entry_listener( + key="key1", predicate=sql("this == value3"), include_value=True, added_func=collector + ) + await self.map.put("key2", "value2") + await self.map.put("key1", "value1") + await self.map.remove("key1") + await self.map.put("key1", "value3") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent( + event, key="key1", event_type=EntryEventType.ADDED, value="value3" + ) + + await self.assertTrueEventually(assert_event, 5) + + async def 
test_add_index(self): + await self.map.add_index(attributes=["this"]) + await self.map.add_index(attributes=["this"], index_type=IndexType.HASH) + await self.map.add_index( + attributes=["this"], + index_type=IndexType.BITMAP, + bitmap_index_options={ + "unique_key": "this", + }, + ) + + async def test_add_index_duplicate_fields(self): + with self.assertRaises(ValueError): + await self.map.add_index(attributes=["this", "this"]) + + async def test_add_index_invalid_attribute(self): + with self.assertRaises(ValueError): + await self.map.add_index(attributes=["this.x."]) + + async def test_clear(self): + await self.fill_map() + await self.map.clear() + self.assertEqual(await self.map.size(), 0) + + async def test_contains_key(self): + await self.fill_map() + self.assertTrue(await self.map.contains_key("key-1")) + self.assertFalse(await self.map.contains_key("key-10")) + + async def test_contains_value(self): + await self.fill_map() + self.assertTrue(await self.map.contains_value("value-1")) + self.assertFalse(await self.map.contains_value("value-10")) + + async def test_delete(self): + await self.fill_map() + await self.map.delete("key-1") + self.assertEqual(await self.map.size(), 9) + self.assertFalse(await self.map.contains_key("key-1")) + + async def test_entry_set(self): + entries = await self.fill_map() + self.assertCountEqual(await self.map.entry_set(), list(entries.items())) + + async def test_entry_set_with_predicate(self): + await self.fill_map() + self.assertEqual(await self.map.entry_set(sql("this == 'value-1'")), [("key-1", "value-1")]) + + async def test_evict(self): + await self.fill_map() + await self.map.evict("key-1") + self.assertEqual(await self.map.size(), 9) + self.assertFalse(await self.map.contains_key("key-1")) + + async def test_evict_all(self): + await self.fill_map() + await self.map.evict_all() + self.assertEqual(await self.map.size(), 0) + + async def test_execute_on_entries(self): + m = await self.fill_map() + expected_entry_set = [(key, 
"processed") for key in m] + values = await self.map.execute_on_entries(EntryProcessor("processed")) + + self.assertCountEqual(expected_entry_set, await self.map.entry_set()) + self.assertCountEqual(expected_entry_set, values) + + async def test_execute_on_entries_with_predicate(self): + m = await self.fill_map() + expected_entry_set = [(key, "processed") if key < "key-5" else (key, m[key]) for key in m] + expected_values = [(key, "processed") for key in m if key < "key-5"] + values = await self.map.execute_on_entries( + EntryProcessor("processed"), sql("__key < 'key-5'") + ) + self.assertCountEqual(expected_entry_set, await self.map.entry_set()) + self.assertCountEqual(expected_values, values) + + async def test_execute_on_key(self): + await self.map.put("test-key", "test-value") + value = await self.map.execute_on_key("test-key", EntryProcessor("processed")) + self.assertEqual("processed", await self.map.get("test-key")) + self.assertEqual("processed", value) + + async def test_execute_on_keys(self): + m = await self.fill_map() + expected_entry_set = [(key, "processed") for key in m] + values = await self.map.execute_on_keys(list(m.keys()), EntryProcessor("processed")) + self.assertCountEqual(expected_entry_set, await self.map.entry_set()) + self.assertCountEqual(expected_entry_set, values) + + async def test_execute_on_keys_with_empty_key_list(self): + m = await self.fill_map() + expected_entry_set = [(key, m[key]) for key in m] + values = await self.map.execute_on_keys([], EntryProcessor("processed")) + self.assertEqual([], values) + self.assertCountEqual(expected_entry_set, await self.map.entry_set()) + + async def test_flush(self): + await self.fill_map() + await self.map.flush() + + async def test_get_all(self): + expected = await self.fill_map(1000) + actual = await self.map.get_all(list(expected.keys())) + self.assertCountEqual(expected, actual) + + async def test_get_all_when_no_keys(self): + self.assertEqual(await self.map.get_all([]), {}) + + async def 
test_get_entry_view(self): + await self.map.put("key", "value") + await self.map.get("key") + await self.map.put("key", "new_value") + + entry_view = await self.map.get_entry_view("key") + + self.assertEqual(entry_view.key, "key") + self.assertEqual(entry_view.value, "new_value") + self.assertIsNotNone(entry_view.cost) + self.assertIsNotNone(entry_view.creation_time) + self.assertIsNotNone(entry_view.expiration_time) + if compare_server_version(self.client, "4.2") < 0: + self.assertEqual(entry_view.hits, 2) + else: + # 4.2+ servers do not collect per entry stats by default + self.assertIsNotNone(entry_view.hits) + self.assertIsNotNone(entry_view.last_access_time) + self.assertIsNotNone(entry_view.last_stored_time) + self.assertIsNotNone(entry_view.last_update_time) + self.assertEqual(entry_view.version, 1) + self.assertIsNotNone(entry_view.ttl) + self.assertIsNotNone(entry_view.max_idle) + + async def test_is_empty(self): + await self.map.put("key", "value") + self.assertFalse(await self.map.is_empty()) + await self.map.clear() + self.assertTrue(await self.map.is_empty()) + + async def test_key_set(self): + keys = list((await self.fill_map()).keys()) + self.assertCountEqual(await self.map.key_set(), keys) + + async def test_key_set_with_predicate(self): + await self.fill_map() + self.assertEqual(await self.map.key_set(sql("this == 'value-1'")), ["key-1"]) + + async def test_put_all(self): + m = {"key-%d" % x: "value-%d" % x for x in range(0, 1000)} + await self.map.put_all(m) + + entries = await self.map.entry_set() + + self.assertCountEqual(entries, m.items()) + + async def test_put_all_when_no_keys(self): + self.assertIsNone(await self.map.put_all({})) + + async def test_put_if_absent_when_missing_value(self): + returned_value = await self.map.put_if_absent("key", "new_value") + + self.assertIsNone(returned_value) + self.assertEqual(await self.map.get("key"), "new_value") + + async def test_put_if_absent_when_existing_value(self): + await self.map.put("key", 
"value") + returned_value = await self.map.put_if_absent("key", "new_value") + self.assertEqual(returned_value, "value") + self.assertEqual(await self.map.get("key"), "value") + + async def test_put_get(self): + self.assertIsNone(await self.map.put("key", "value")) + self.assertEqual(await self.map.get("key"), "value") + + async def test_put_get_large_payload(self): + # The fix for reading large payloads is introduced in 4.2.1 + # See https://github.com/hazelcast/hazelcast-python-client/pull/436 + skip_if_client_version_older_than(self, "4.2.1") + + payload = bytearray(os.urandom(16 * 1024 * 1024)) + start = get_current_timestamp() + self.assertIsNone(await self.map.put("key", payload)) + self.assertEqual(await self.map.get("key"), payload) + self.assertLessEqual(get_current_timestamp() - start, 5) + + async def test_put_get2(self): + val = "x" * 5000 + self.assertIsNone(await self.map.put("key-x", val)) + self.assertEqual(await self.map.get("key-x"), val) + + async def test_put_when_existing(self): + await self.map.put("key", "value") + self.assertEqual(await self.map.put("key", "new_value"), "value") + self.assertEqual(await self.map.get("key"), "new_value") + + async def test_put_transient(self): + await self.map.put_transient("key", "value") + self.assertEqual(await self.map.get("key"), "value") + + async def test_remove(self): + await self.map.put("key", "value") + removed = await self.map.remove("key") + self.assertEqual(removed, "value") + self.assertEqual(0, await self.map.size()) + self.assertFalse(await self.map.contains_key("key")) + + async def test_remove_all_with_none_predicate(self): + skip_if_client_version_older_than(self, "5.2.0") + + with self.assertRaises(AssertionError): + await self.map.remove_all(None) + + async def test_remove_all(self): + skip_if_client_version_older_than(self, "5.2.0") + + await self.fill_map() + await self.map.remove_all(predicate=sql("__key > 'key-7'")) + self.assertEqual(await self.map.size(), 8) + + async def 
test_remove_if_same_when_same(self): + await self.map.put("key", "value") + self.assertTrue(await self.map.remove_if_same("key", "value")) + self.assertFalse(await self.map.contains_key("key")) + + async def test_remove_if_same_when_different(self): + await self.map.put("key", "value") + self.assertFalse(await self.map.remove_if_same("key", "another_value")) + self.assertTrue(await self.map.contains_key("key")) + + async def test_remove_entry_listener(self): + collector = event_collector() + reg_id = await self.map.add_entry_listener(added_func=collector) + + await self.map.put("key", "value") + await self.assertTrueEventually(lambda: self.assertEqual(len(collector.events), 1)) + await self.map.remove_entry_listener(reg_id) + await self.map.put("key2", "value") + + await asyncio.sleep(1) + self.assertEqual(len(collector.events), 1) + + async def test_remove_entry_listener_with_none_id(self): + with self.assertRaises(AssertionError) as cm: + await self.map.remove_entry_listener(None) + e = cm.exception + self.assertEqual(e.args[0], "None user_registration_id is not allowed!") + + async def test_replace(self): + await self.map.put("key", "value") + replaced = await self.map.replace("key", "new_value") + self.assertEqual(replaced, "value") + self.assertEqual(await self.map.get("key"), "new_value") + + async def test_replace_if_same_when_same(self): + await self.map.put("key", "value") + self.assertTrue(await self.map.replace_if_same("key", "value", "new_value")) + self.assertEqual(await self.map.get("key"), "new_value") + + async def test_replace_if_same_when_different(self): + await self.map.put("key", "value") + self.assertFalse(await self.map.replace_if_same("key", "another_value", "new_value")) + self.assertEqual(await self.map.get("key"), "value") + + async def test_set(self): + await self.map.set("key", "value") + + self.assertEqual(await self.map.get("key"), "value") + + async def test_set_ttl(self): + await self.map.put("key", "value") + await 
self.map.set_ttl("key", 0.1) + + async def evicted(): + self.assertFalse(await self.map.contains_key("key")) + + await self.assertTrueEventually(evicted, 5) + + async def test_size(self): + await self.fill_map() + + self.assertEqual(10, await self.map.size()) + + async def test_values(self): + values = list((await self.fill_map()).values()) + + self.assertCountEqual(list(await self.map.values()), values) + + async def test_values_with_predicate(self): + await self.fill_map() + self.assertEqual(await self.map.values(sql("this == 'value-1'")), ["value-1"]) + + def test_str(self): + self.assertTrue(str(self.map).startswith("Map")) + + async def test_add_interceptor(self): + interceptor = MapGetInterceptor(":") + registration_id = await self.map.add_interceptor(interceptor) + self.assertIsNotNone(registration_id) + + await self.map.set(1, ")") + value = await self.map.get(1) + self.assertEqual(":)", value) + + async def test_remove_interceptor(self): + skip_if_client_version_older_than(self, "5.0") + + interceptor = MapGetInterceptor(":") + registration_id = await self.map.add_interceptor(interceptor) + self.assertIsNotNone(registration_id) + self.assertTrue(await self.map.remove_interceptor(registration_id)) + + # Unknown registration id should return False + self.assertFalse(await self.map.remove_interceptor(registration_id)) + + # Make sure that the interceptor is indeed removed + await self.map.set(1, ")") + value = await self.map.get(1) + self.assertEqual(")", value) + + async def fill_map(self, count=10): + m = {"key-%d" % x: "value-%d" % x for x in range(0, count)} + await self.map.put_all(m) + return m + + +class MapStoreTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + @classmethod + def configure_cluster(cls): + path = os.path.abspath(__file__) + dir_path = os.path.dirname(path) + with open( + os.path.join(dir_path, 
"../../backward_compatible/proxy/hazelcast_mapstore.xml") + ) as f: + return f.read() + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map("mapstore-test") + self.entries = await fill_map(self.map, size=10, key_prefix="key", value_prefix="val") + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_load_all_with_no_args_loads_all_keys(self): + await self.map.evict_all() + await self.map.load_all() + entry_set = await self.map.get_all(self.entries.keys()) + self.assertCountEqual(entry_set, self.entries) + + async def test_load_all_with_key_set_loads_given_keys(self): + await self.map.evict_all() + await self.map.load_all(["key0", "key1"]) + entry_set = await self.map.get_all(["key0", "key1"]) + self.assertCountEqual(entry_set, {"key0": "val0", "key1": "val1"}) + + async def test_load_all_overrides_entries_in_memory_by_default(self): + await self.map.evict_all() + await self.map.put_transient("key0", "new0") + await self.map.put_transient("key1", "new1") + await self.map.load_all(["key0", "key1"]) + entry_set = await self.map.get_all(["key0", "key1"]) + self.assertCountEqual(entry_set, {"key0": "val0", "key1": "val1"}) + + async def test_load_all_with_replace_existing_false_does_not_override(self): + await self.map.evict_all() + await self.map.put_transient("key0", "new0") + await self.map.put_transient("key1", "new1") + await self.map.load_all(["key0", "key1"], replace_existing_values=False) + entry_set = await self.map.get_all(["key0", "key1"]) + self.assertCountEqual(entry_set, {"key0": "new0", "key1": "new1"}) + + async def test_evict(self): + await self.map.evict("key0") + self.assertEqual(await self.map.size(), 9) + + async def test_evict_non_existing_key(self): + await self.map.evict("non_existing_key") + self.assertEqual(await self.map.size(), 10) + + async def test_evict_all(self): + await self.map.evict_all() + self.assertEqual(await self.map.size(), 0) 
+ + async def test_add_entry_listener_item_loaded(self): + collector = event_collector() + await self.map.add_entry_listener(include_value=True, loaded_func=collector) + await self.map.put("key", "value", ttl=0.1) + time.sleep(2) + await self.map.get("key") + + def assert_event(): + self.assertEqual(len(collector.events), 1) + event = collector.events[0] + self.assertEntryEvent(event, key="key", value="value", event_type=EntryEventType.LOADED) + + await self.assertTrueEventually(assert_event, 10) + + +class MapTTLTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_put_default_ttl(self): + await self.map.put("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put(self): + async def assert_map_not_contains(): + self.assertFalse(await self.map.contains_key("key")) + + await self.map.put("key", "value", 0.1) + await self.assertTrueEventually(lambda: assert_map_not_contains()) + + async def test_put_transient_default_ttl(self): + await self.map.put_transient("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put_transient(self): + async def assert_map_not_contains(): + self.assertFalse(await self.map.contains_key("key")) + + await self.map.put_transient("key", "value", 0.1) + await self.assertTrueEventually(lambda: assert_map_not_contains()) + + async def test_put_if_absent_ttl(self): + await self.map.put_if_absent("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put_if_absent(self): + async def assert_map_not_contains(): + self.assertFalse(await 
self.map.contains_key("key")) + + await self.map.put_if_absent("key", "value", 0.1) + await self.assertTrueEventually(lambda: assert_map_not_contains()) + + async def test_set_default_ttl(self): + await self.map.set("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_set(self): + async def assert_map_not_contains(): + self.assertFalse(await self.map.contains_key("key")) + + await self.map.set("key", "value", 0.1) + await self.assertTrueEventually(lambda: assert_map_not_contains()) + + +class MapMaxIdleTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_put_default_max_idle(self): + await self.map.put("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put(self): + await self.map.put("key", "value", max_idle=0.1) + await asyncio.sleep(1.0) + self.assertFalse(await self.map.contains_key("key")) + + async def test_put_transient_default_max_idle(self): + await self.map.put_transient("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put_transient(self): + await self.map.put_transient("key", "value", max_idle=0.1) + await asyncio.sleep(1.0) + self.assertFalse(await self.map.contains_key("key")) + + async def test_put_if_absent_max_idle(self): + await self.map.put_if_absent("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_put_if_absent(self): + await self.map.put_if_absent("key", "value", max_idle=0.1) + await asyncio.sleep(1.0) + self.assertFalse(await self.map.contains_key("key")) + + async def 
test_set_default_ttl(self): + await self.map.set("key", "value") + await asyncio.sleep(1.0) + self.assertTrue(await self.map.contains_key("key")) + + async def test_set(self): + await self.map.set("key", "value", max_idle=0.1) + await asyncio.sleep(1.0) + self.assertFalse(await self.map.contains_key("key")) + + +@unittest.skipIf( + compare_client_version("4.2.1") < 0, "Tests the features added in 4.2.1 version of the client" +) +class MapAggregatorsIntTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["default_int_type"] = IntType.INT + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + await self.map.put_all({"key-%d" % i: i for i in range(50)}) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_aggregate_with_none_aggregator(self): + with self.assertRaises(AssertionError): + await self.map.aggregate(None) + + async def test_aggregate_with_paging_predicate(self): + with self.assertRaises(AssertionError): + await self.map.aggregate(int_avg("foo"), paging(true(), 10)) + + async def test_int_average(self): + average = await self.map.aggregate(int_avg()) + self.assertEqual(24.5, average) + + async def test_int_average_with_attribute_path(self): + average = await self.map.aggregate(int_avg("this")) + self.assertEqual(24.5, average) + + async def test_int_average_with_predicate(self): + average = await self.map.aggregate(int_avg(), greater_or_equal("this", 47)) + self.assertEqual(48, average) + + async def test_int_sum(self): + sum_ = await self.map.aggregate(int_sum()) + self.assertEqual(1225, sum_) + + async def test_int_sum_with_attribute_path(self): + sum_ = await self.map.aggregate(int_sum("this")) + self.assertEqual(1225, sum_) + + async def test_int_sum_with_predicate(self): + sum_ = await self.map.aggregate(int_sum(), greater_or_equal("this", 
47)) + self.assertEqual(144, sum_) + + async def test_fixed_point_sum(self): + sum_ = await self.map.aggregate(fixed_point_sum()) + self.assertEqual(1225, sum_) + + async def test_fixed_point_sum_with_attribute_path(self): + sum_ = await self.map.aggregate(fixed_point_sum("this")) + self.assertEqual(1225, sum_) + + async def test_fixed_point_sum_with_predicate(self): + sum_ = await self.map.aggregate(fixed_point_sum(), greater_or_equal("this", 47)) + self.assertEqual(144, sum_) + + async def test_distinct(self): + await self._fill_with_duplicate_values() + distinct_values = await self.map.aggregate(distinct()) + self.assertEqual(set(range(50)), distinct_values) + + async def test_distinct_with_attribute_path(self): + await self._fill_with_duplicate_values() + distinct_values = await self.map.aggregate(distinct("this")) + self.assertEqual(set(range(50)), distinct_values) + + async def test_distinct_with_predicate(self): + await self._fill_with_duplicate_values() + distinct_values = await self.map.aggregate(distinct(), greater_or_equal("this", 10)) + self.assertEqual(set(range(10, 50)), distinct_values) + + async def test_max_by(self): + max_item = await self.map.aggregate(max_by("this")) + self.assertEqual("key-49", max_item.key) + self.assertEqual(49, max_item.value) + + async def test_max_by_with_predicate(self): + max_item = await self.map.aggregate(max_by("this"), less_or_equal("this", 10)) + self.assertEqual("key-10", max_item.key) + self.assertEqual(10, max_item.value) + + async def test_min_by(self): + min_item = await self.map.aggregate(min_by("this")) + self.assertEqual("key-0", min_item.key) + self.assertEqual(0, min_item.value) + + async def test_min_by_with_predicate(self): + min_item = await self.map.aggregate(min_by("this"), greater_or_equal("this", 10)) + self.assertEqual("key-10", min_item.key) + self.assertEqual(10, min_item.value) + + async def _fill_with_duplicate_values(self): + # Map is initially filled with key-i: i mappings from [0, 50). 
+ # Add more values with different keys but the same values to + # test the behaviour of the distinct aggregator. + await self.map.put_all({"different-key-%d" % i: i for i in range(50)}) + + +@unittest.skipIf( + compare_client_version("4.2.1") < 0, "Tests the features added in 4.2.1 version of the client" +) +class MapAggregatorsLongTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + config["default_int_type"] = IntType.LONG + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + await self.map.put_all({"key-%d" % i: i for i in range(50)}) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_long_average(self): + average = await self.map.aggregate(long_avg()) + self.assertEqual(24.5, average) + + async def test_long_average_with_attribute_path(self): + average = await self.map.aggregate(long_avg("this")) + self.assertEqual(24.5, average) + + async def test_long_average_with_predicate(self): + average = await self.map.aggregate(long_avg(), greater_or_equal("this", 47)) + self.assertEqual(48, average) + + async def test_long_sum(self): + sum_ = await self.map.aggregate(long_sum()) + self.assertEqual(1225, sum_) + + async def test_long_sum_with_attribute_path(self): + sum_ = await self.map.aggregate(long_sum("this")) + self.assertEqual(1225, sum_) + + async def test_long_sum_with_predicate(self): + sum_ = await self.map.aggregate(long_sum(), greater_or_equal("this", 47)) + self.assertEqual(144, sum_) + + +@unittest.skipIf( + compare_client_version("4.2.1") < 0, "Tests the features added in 4.2.1 version of the client" +) +class MapAggregatorsDoubleTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await 
self.client.get_map(random_string()) + await self.map.put_all({"key-%d" % i: float(i) for i in range(50)}) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_count(self): + count_ = await self.map.aggregate(count()) + self.assertEqual(50, count_) + + async def test_count_with_attribute_path(self): + count_ = await self.map.aggregate(count("this")) + self.assertEqual(50, count_) + + async def test_count_with_predicate(self): + count_ = await self.map.aggregate(count(), greater_or_equal("this", 1)) + self.assertEqual(49, count_) + + async def test_double_average(self): + average = await self.map.aggregate(double_avg()) + self.assertEqual(24.5, average) + + async def test_double_average_with_attribute_path(self): + average = await self.map.aggregate(double_avg("this")) + self.assertEqual(24.5, average) + + async def test_double_average_with_predicate(self): + average = await self.map.aggregate(double_avg(), greater_or_equal("this", 47)) + self.assertEqual(48, average) + + async def test_double_sum(self): + sum_ = await self.map.aggregate(double_sum()) + self.assertEqual(1225, sum_) + + async def test_double_sum_with_attribute_path(self): + sum_ = await self.map.aggregate(double_sum("this")) + self.assertEqual(1225, sum_) + + async def test_double_sum_with_predicate(self): + sum_ = await self.map.aggregate(double_sum(), greater_or_equal("this", 47)) + self.assertEqual(144, sum_) + + async def test_floating_point_sum(self): + sum_ = await self.map.aggregate(floating_point_sum()) + self.assertEqual(1225, sum_) + + async def test_floating_point_sum_with_attribute_path(self): + sum_ = await self.map.aggregate(floating_point_sum("this")) + self.assertEqual(1225, sum_) + + async def test_floating_point_sum_with_predicate(self): + sum_ = await self.map.aggregate(floating_point_sum(), greater_or_equal("this", 47)) + self.assertEqual(144, sum_) + + async def test_number_avg(self): + average = await 
self.map.aggregate(number_avg()) + self.assertEqual(24.5, average) + + async def test_number_avg_with_attribute_path(self): + average = await self.map.aggregate(number_avg("this")) + self.assertEqual(24.5, average) + + async def test_number_avg_with_predicate(self): + average = await self.map.aggregate(number_avg(), greater_or_equal("this", 47)) + self.assertEqual(48, average) + + async def test_max(self): + average = await self.map.aggregate(max_()) + self.assertEqual(49, average) + + async def test_max_with_attribute_path(self): + average = await self.map.aggregate(max_("this")) + self.assertEqual(49, average) + + async def test_max_with_predicate(self): + average = await self.map.aggregate(max_(), less_or_equal("this", 3)) + self.assertEqual(3, average) + + async def test_min(self): + average = await self.map.aggregate(min_()) + self.assertEqual(0, average) + + async def test_min_with_attribute_path(self): + average = await self.map.aggregate(min_("this")) + self.assertEqual(0, average) + + async def test_min_with_predicate(self): + average = await self.map.aggregate(min_(), greater_or_equal("this", 3)) + self.assertEqual(3, average) + + +@unittest.skipIf( + compare_client_version("4.2.1") < 0, "Tests the features added in 4.2.1 version of the client" +) +class MapProjectionsTest(SingleMemberTestCase): + @classmethod + def configure_client(cls, config): + config["cluster_name"] = cls.cluster.id + return config + + async def asyncSetUp(self): + await super().asyncSetUp() + self.map = await self.client.get_map(random_string()) + await self.map.put(1, HazelcastJsonValue('{"attr1": 1, "attr2": 2, "attr3": 3}')) + await self.map.put(2, HazelcastJsonValue('{"attr1": 4, "attr2": 5, "attr3": 6}')) + + async def asyncTearDown(self): + await self.map.destroy() + await super().asyncTearDown() + + async def test_project_with_none_projection(self): + with self.assertRaises(AssertionError): + await self.map.project(None) + + async def test_project_with_paging_predicate(self): 
+ with self.assertRaises(AssertionError): + await self.map.project(single_attribute("foo"), paging(true(), 10)) + + async def test_single_attribute(self): + attributes = await self.map.project(single_attribute("attr1")) + self.assertCountEqual([1, 4], attributes) + + async def test_single_attribute_with_predicate(self): + attributes = await self.map.project(single_attribute("attr1"), greater_or_equal("attr1", 4)) + self.assertCountEqual([4], attributes) + + async def test_multi_attribute(self): + attributes = await self.map.project(multi_attribute("attr1", "attr2")) + self.assertCountEqual([[1, 2], [4, 5]], attributes) + + async def test_multi_attribute_with_predicate(self): + attributes = await self.map.project( + multi_attribute("attr1", "attr2"), + greater_or_equal("attr2", 3), + ) + self.assertCountEqual([[4, 5]], attributes) + + async def test_identity(self): + attributes = await self.map.project(identity()) + self.assertCountEqual( + [ + HazelcastJsonValue('{"attr1": 4, "attr2": 5, "attr3": 6}'), + HazelcastJsonValue('{"attr1": 1, "attr2": 2, "attr3": 3}'), + ], + [attribute.value for attribute in attributes], + ) + + async def test_identity_with_predicate(self): + attributes = await self.map.project(identity(), greater_or_equal("attr2", 3)) + self.assertCountEqual( + [HazelcastJsonValue('{"attr1": 4, "attr2": 5, "attr3": 6}')], + [attribute.value for attribute in attributes], + ) diff --git a/tests/integration/asyncio/ssl_tests/__init__.py b/tests/integration/asyncio/ssl_tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/ssl_tests/hostname_verification/__init__.py b/tests/integration/asyncio/ssl_tests/hostname_verification/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py b/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py new file mode 100644 
index 0000000000..87fff817ec --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/hostname_verification/ssl_hostname_verification_test.py @@ -0,0 +1,135 @@ +import os +import sys +import unittest + +import pytest + +from hazelcast.asyncio.client import HazelcastClient +from hazelcast.config import SSLProtocol +from hazelcast.errors import IllegalStateError +from tests.integration.asyncio.base import HazelcastTestCase +from tests.util import compare_client_version, get_abs_path + +current_directory = os.path.abspath( + os.path.join( + os.path.dirname(__file__), "../../../backward_compatible/ssl_tests/hostname_verification" + ) +) + +MEMBER_CONFIG = """ + + + + + com.hazelcast.nio.ssl.BasicSSLContextFactory + + + %s + 123456 + PKCS12 + TLSv1.2 + + + + +""" + + +@unittest.skipIf( + sys.version_info < (3, 7), + "Hostname verification feature requires Python 3.7+", +) +@unittest.skipIf( + compare_client_version("5.1") < 0, + "Tests the features added in 5.1 version of the client", +) +@pytest.mark.enterprise +class SslHostnameVerificationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + def setUp(self): + self.rc = self.create_rc() + self.cluster = None + + async def asyncTearDown(self): + await self.shutdown_all_clients() + self.rc.terminateCluster(self.cluster.id) + self.rc.exit() + + async def test_hostname_verification_with_loopback_san(self): + # SAN entry is present with different possible values + file_name = "tls-host-loopback-san" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + + async def test_hostname_verification_with_loopback_dns_san(self): + # SAN entry is present, but only with `dns:localhost` + file_name = "tls-host-loopback-san-dns" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to 
connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_different_san(self): + # There is a valid entry, but it does not match with the address of the member. + file_name = "tls-host-not-our-san" + self.start_member_with(f"{file_name}.p12") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_loopback_cn(self): + # No entry in SAN but an entry in CN which checked as a fallback + # when no entry in SAN is present. + file_name = "tls-host-loopback-cn" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + # See https://stackoverflow.com/a/8444863/12394291. IP addresses in CN + # are not supported. So, we don't have a test for it. + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_with_no_entry(self): + # No entry either in the SAN or CN. No way to verify hostname. 
+ file_name = "tls-host-no-entry" + self.start_member_with(f"{file_name}.p12") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "localhost:5701") + with self.assertRaisesRegex(IllegalStateError, "Unable to connect to any cluster"): + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701") + + async def test_hostname_verification_disabled(self): + # When hostname verification is disabled, the scenarious that + # would fail in `test_hostname_verification_with_no_entry` will + # no longer fail, showing that it is working as expected. + file_name = "tls-host-no-entry" + self.start_member_with(f"{file_name}.p12") + await self.start_client_with(f"{file_name}.pem", "localhost:5701", check_hostname=False) + await self.start_client_with(f"{file_name}.pem", "127.0.0.1:5701", check_hostname=False) + + async def start_client_with( + self, + truststore_name: str, + member_address: str, + *, + check_hostname=True, + ) -> HazelcastClient: + return await self.create_client( + { + "cluster_name": self.cluster.id, + "cluster_members": [member_address], + "ssl_enabled": True, + "ssl_protocol": SSLProtocol.TLSv1_2, + "ssl_cafile": get_abs_path(current_directory, truststore_name), + "ssl_check_hostname": check_hostname, + "cluster_connect_timeout": 0, + } + ) + + def start_member_with(self, keystore_name: str) -> None: + config = MEMBER_CONFIG % get_abs_path(current_directory, keystore_name) + self.cluster = self.create_cluster(self.rc, config) + self.cluster.start_member() diff --git a/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py b/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py new file mode 100644 index 0000000000..2d392278d2 --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/mutual_authentication_test.py @@ -0,0 +1,169 @@ +import os +import unittest + +import pytest + +from tests.integration.asyncio.base import HazelcastTestCase +from 
hazelcast.asyncio.client import HazelcastClient +from hazelcast.errors import HazelcastError +from tests.util import get_ssl_config, get_abs_path + + +@pytest.mark.enterprise +class MutualAuthenticationTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + current_directory = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../backward_compatible/ssl_tests") + ) + rc = None + mutual_auth = True + ma_req_xml = get_abs_path(current_directory, "hazelcast-ma-required.xml") + ma_opt_xml = get_abs_path(current_directory, "hazelcast-ma-optional.xml") + + def setUp(self): + self.rc = self.create_rc() + + def tearDown(self): + self.rc.exit() + + async def test_ma_required_client_and_server_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ma_required_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + + async def test_ma_required_client_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + 
get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_required_client_and_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_optional_client_and_server_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ma_optional_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client1-cert.pem"), + get_abs_path(self.current_directory, "client1-key.pem"), + ) + ) + + async def test_ma_optional_client_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + 
get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_optional_client_and_server_not_authenticated(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server2-cert.pem"), + get_abs_path(self.current_directory, "client2-cert.pem"), + get_abs_path(self.current_directory, "client2-key.pem"), + ) + ) + + async def test_ma_required_with_no_cert_file(self): + cluster = self.create_cluster(self.rc, self.read_config(True)) + cluster.start_member() + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + + async def test_ma_optional_with_no_cert_file(self): + cluster = self.create_cluster(self.rc, self.read_config(False)) + cluster.start_member() + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + def read_config(self, is_ma_required): + file_path = self.ma_req_xml if is_ma_required else self.ma_opt_xml + with open(file_path, "r") as f: + xml_config = f.read() + keystore_path = get_abs_path(self.current_directory, "server1.keystore") + truststore_path = get_abs_path(self.current_directory, "server1.truststore") + return xml_config % (keystore_path, truststore_path) diff --git a/tests/integration/asyncio/ssl_tests/ssl_test.py b/tests/integration/asyncio/ssl_tests/ssl_test.py new file mode 100644 index 0000000000..6190c7571a --- /dev/null +++ b/tests/integration/asyncio/ssl_tests/ssl_test.py @@ -0,0 +1,134 @@ +import os +import unittest + +import pytest + +from tests.integration.asyncio.base import HazelcastTestCase 
+from hazelcast.asyncio.client import HazelcastClient +from hazelcast.errors import HazelcastError +from hazelcast.config import SSLProtocol +from tests.util import get_ssl_config, get_abs_path +from tests.integration.asyncio.util import fill_map + + +@pytest.mark.enterprise +class SSLTest(unittest.IsolatedAsyncioTestCase, HazelcastTestCase): + current_directory = os.path.abspath( + os.path.join(os.path.dirname(__file__), "../../backward_compatible/ssl_tests") + ) + rc = None + hazelcast_ssl_xml = get_abs_path(current_directory, "hazelcast-ssl.xml") + default_ca_xml = get_abs_path(current_directory, "hazelcast-default-ca.xml") + + def setUp(self): + self.rc = self.create_rc() + + def tearDown(self): + self.rc.exit() + + async def test_ssl_disabled(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start(**get_ssl_config(cluster.id, False)) + + async def test_ssl_enabled_is_client_live(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_trust_default_certificates(self): + cluster = self.create_cluster(self.rc, self.read_default_ca_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start(**get_ssl_config(cluster.id, True)) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_dont_trust_self_signed_certificates(self): + # Member started with self-signed certificate + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await 
HazelcastClient.create_and_start(**get_ssl_config(cluster.id, True)) + + async def test_ssl_enabled_map_size(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, True, get_abs_path(self.current_directory, "server1-cert.pem") + ) + ) + test_map = await client.get_map("test_map") + await fill_map(test_map, 10) + self.assertEqual(await test_map.size(), 10) + await client.shutdown() + + async def test_ssl_enabled_with_custom_ciphers(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + client = await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + ciphers="ECDHE-RSA-AES128-SHA256:ECDHE-RSA-AES256-GCM-SHA384", + ) + ) + self.assertTrue(client.lifecycle_service.is_running()) + await client.shutdown() + + async def test_ssl_enabled_with_invalid_ciphers(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + ciphers="INVALID-CIPHER1:INVALID_CIPHER2", + ) + ) + + async def test_ssl_enabled_with_protocol_mismatch(self): + cluster = self.create_cluster(self.rc, self.read_ssl_config()) + cluster.start_member() + + # Member configured with TLSv1 + with self.assertRaises(HazelcastError): + await HazelcastClient.create_and_start( + **get_ssl_config( + cluster.id, + True, + get_abs_path(self.current_directory, "server1-cert.pem"), + protocol=SSLProtocol.SSLv3, + ) + ) + + def read_default_ca_config(self): + with open(self.default_ca_xml, "r") as f: + xml_config = f.read() + + keystore_path = get_abs_path(self.current_directory, "keystore.jks") + return xml_config % (keystore_path, 
keystore_path) + + def read_ssl_config(self): + with open(self.hazelcast_ssl_xml, "r") as f: + xml_config = f.read() + + keystore_path = get_abs_path(self.current_directory, "server1.keystore") + return xml_config % keystore_path diff --git a/tests/integration/asyncio/util.py b/tests/integration/asyncio/util.py new file mode 100644 index 0000000000..e101a58103 --- /dev/null +++ b/tests/integration/asyncio/util.py @@ -0,0 +1,6 @@ +async def fill_map(map, size=10, key_prefix="key", value_prefix="val"): + entries = dict() + for i in range(size): + entries[key_prefix + str(i)] = value_prefix + str(i) + await map.put_all(entries) + return entries