|
7 | 7 |
|
8 | 8 | from __future__ import annotations
|
9 | 9 |
|
| 10 | +import contextlib |
10 | 11 | import datetime
|
11 | 12 | import gc
|
12 | 13 | import os
|
|
31 | 32 | AF_INET6,
|
32 | 33 | MSG_PEEK,
|
33 | 34 | SHUT_RDWR,
|
| 35 | + SO_RCVBUF, |
| 36 | + SO_SNDBUF, |
| 37 | + SOL_SOCKET, |
34 | 38 | gaierror,
|
35 | 39 | socket,
|
36 | 40 | )
|
@@ -413,6 +417,88 @@ def handshake_in_memory(
|
413 | 417 | interact_in_memory(client_conn, server_conn)
|
414 | 418 |
|
415 | 419 |
|
| 420 | +def get_ssl_error_reason(ssl_error: SSL.Error) -> str | None: |
| 421 | + """ |
| 422 | + Extracts the reason string from the first error tuple in an SSL.Error. |
| 423 | + Returns None if the expected error structure is not found. |
| 424 | + """ |
| 425 | + if ( |
| 426 | + ssl_error.args |
| 427 | + and isinstance(ssl_error.args, tuple) |
| 428 | + and len(ssl_error.args) > 0 |
| 429 | + ): |
| 430 | + error_details = ssl_error.args[0] # list of error tuples |
| 431 | + if isinstance(error_details, list) and len(error_details) > 0: |
| 432 | + first_error_tuple = error_details[0] |
| 433 | + if ( |
| 434 | + isinstance(first_error_tuple, tuple) |
| 435 | + and len(first_error_tuple) >= 3 |
| 436 | + ): |
| 437 | + reason = first_error_tuple[2] |
| 438 | + if isinstance(reason, str): |
| 439 | + return reason |
| 440 | + return None |
| 441 | + |
| 442 | + |
| 443 | +def create_ssl_nonblocking_connection( |
| 444 | + mode: int | None, request_send_buffer_size: int |
| 445 | +) -> tuple[Connection, Connection, int, int]: |
| 446 | + """ |
| 447 | + Create a pair of sockets and set up an SSL connection between them. |
| 448 | + mode: The mode to set if not None. |
| 449 | + request_send_buffer_size: requested size of the send buffer |
| 450 | + Returns the SSL Connection objects |
| 451 | + and the actual send/receive buffer sizes. |
| 452 | + """ |
| 453 | + |
| 454 | + client_socket, server_socket = socket_pair() |
| 455 | + |
| 456 | + # Set up client context |
| 457 | + client_ctx = Context(SSLv23_METHOD) |
| 458 | + |
| 459 | + # SSL_MODE_ENABLE_PARTIAL_WRITE and |
| 460 | + # SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER modes |
| 461 | + # are set by default when ctx is initialized. |
| 462 | + # Clear them if requested so tests can |
| 463 | + # be run without them if so desired. |
| 464 | + if mode is not None: |
| 465 | + client_ctx.clear_mode( |
| 466 | + _lib.SSL_MODE_ENABLE_PARTIAL_WRITE |
| 467 | + | _lib.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER |
| 468 | + ) |
| 469 | + # Set the new mode to the requested value |
| 470 | + client_ctx.set_mode(mode) |
| 471 | + |
| 472 | + # create the SSL connections |
| 473 | + client = Connection(client_ctx, client_socket) |
| 474 | + server = loopback_server_factory(server_socket) |
| 475 | + |
| 476 | + # Allow caller to request small buffer sizes so they can be easily filled. |
| 477 | + # Note the OS may not respect the requested values. |
| 478 | + # Make the receive buffer smaller than the send buffer. |
| 479 | + requested_receive_buffer_size = request_send_buffer_size // 2 |
| 480 | + client_socket.setsockopt(SOL_SOCKET, SO_SNDBUF, request_send_buffer_size) |
| 481 | + actual_sndbuf = client_socket.getsockopt(SOL_SOCKET, SO_SNDBUF) |
| 482 | + |
| 483 | + server_socket.setsockopt( |
| 484 | + SOL_SOCKET, SO_RCVBUF, requested_receive_buffer_size |
| 485 | + ) |
| 486 | + actual_rcvbuf = server_socket.getsockopt(SOL_SOCKET, SO_RCVBUF) |
| 487 | + |
| 488 | + # set the connection state |
| 489 | + client.set_connect_state() |
| 490 | + # loopback_server_factory already sets the accept state on the server |
| 491 | + |
| 492 | + handshake(client, server) |
| 493 | + |
| 494 | + return ( |
| 495 | + client, |
| 496 | + server, |
| 497 | + actual_sndbuf, |
| 498 | + actual_rcvbuf, |
| 499 | + ) |
| 500 | + |
| 501 | + |
416 | 502 | class TestVersion:
|
417 | 503 | """
|
418 | 504 | Tests for version information exposed by `OpenSSL.SSL.SSLeay_version` and
|
@@ -3011,6 +3097,185 @@ def test_wantWriteError(self) -> None:
|
3011 | 3097 |
|
3012 | 3098 | # XXX want_read
|
3013 | 3099 |
|
| 3100 | + def _attempt_want_write_error( |
| 3101 | + self, client: Connection, buffer_size: int |
| 3102 | + ) -> bytes: |
| 3103 | + """ |
| 3104 | + Deliberately attempts to send application data |
| 3105 | + over SSL to trigger WantWriteError. The send may need |
| 3106 | + to be repeated many times depending on the socket and |
| 3107 | + network buffer sizes allocated by the environment. |
| 3108 | + Returns the message that triggered the error so that |
| 3109 | + the buffer for the message is not immediately reclaimed. |
| 3110 | + """ |
| 3111 | + initial_want_write_triggered = False |
| 3112 | + max_num_of_attempts = 100000 |
| 3113 | + |
| 3114 | + for i in range(max_num_of_attempts): |
| 3115 | + msg = b"Y" * buffer_size |
| 3116 | + try: |
| 3117 | + client.send(msg) |
| 3118 | + except SSL.WantWriteError: |
| 3119 | + initial_want_write_triggered = True |
| 3120 | + break # Exit loop as desired error was triggered |
| 3121 | + |
| 3122 | + assert initial_want_write_triggered, ( |
| 3123 | + f"Could not induce WantWriteError within {i + 1} attempts" |
| 3124 | + ) |
| 3125 | + return msg |
| 3126 | + |
| 3127 | + def _drain_server_buffers(self, server: Connection) -> None: |
| 3128 | + """Reads from server SSL and raw sockets to drain any pending data.""" |
| 3129 | + total_ssl_read = 0 |
| 3130 | + consecutive_empty_ssl_reads = 0 |
| 3131 | + |
| 3132 | + while total_ssl_read < 1024 * 1024: |
| 3133 | + try: |
| 3134 | + data = server.recv(65536) |
| 3135 | + # if serverbuffer is empty the call should |
| 3136 | + # raise WantReadError not return None |
| 3137 | + assert data is not None, "SSL peer closed or empty data" |
| 3138 | + total_ssl_read += len(data) |
| 3139 | + # Reset counter on successful read |
| 3140 | + consecutive_empty_ssl_reads = 0 |
| 3141 | + except SSL.WantReadError: |
| 3142 | + consecutive_empty_ssl_reads += 1 |
| 3143 | + if consecutive_empty_ssl_reads >= 10: |
| 3144 | + # "No more SSL application data available after |
| 3145 | + # consecutive_empty_ssl_readss |
| 3146 | + return |
| 3147 | + # Small delay to allow time for clearing buffers |
| 3148 | + time.sleep(0.01) |
| 3149 | + |
| 3150 | + def _perform_moving_buffer_test( |
| 3151 | + self, client: Connection, buffer_size: int, want_bad_retry: bool |
| 3152 | + ) -> bool: |
| 3153 | + """ |
| 3154 | + Attempts a retry write with a moving buffer and checks for |
| 3155 | + 'bad write retry' error. |
| 3156 | + Returns True if 'bad write retry' occurs, False otherwise. |
| 3157 | + """ |
| 3158 | + # Attempt retry with different buffer but same size |
| 3159 | + msg2 = b"Z" * buffer_size |
| 3160 | + try: |
| 3161 | + bytes_written = client.send(msg2) |
| 3162 | + assert not want_bad_retry, ( |
| 3163 | + "_perform_moving_buffer_test() failed as retry succeeded " |
| 3164 | + f"unexpectedly with {bytes_written} bytes written." |
| 3165 | + ) |
| 3166 | + return False # Retry succeeded |
| 3167 | + except SSL.Error as e: |
| 3168 | + reason = get_ssl_error_reason(e) |
| 3169 | + assert reason == "bad write retry", ( |
| 3170 | + f"Retry failed with unexpected SSL error: {e!r}({reason})." |
| 3171 | + ) |
| 3172 | + return True # Bad write retry |
| 3173 | + |
| 3174 | + def _shutdown_connections( |
| 3175 | + self, |
| 3176 | + client: Connection, |
| 3177 | + server: Connection, |
| 3178 | + ) -> None: |
| 3179 | + """Helper to safely shut down SSL connections and close sockets.""" |
| 3180 | + if client: |
| 3181 | + with contextlib.suppress(SSL.Error): |
| 3182 | + # When closing connections in the test teardown stage, |
| 3183 | + # we don't care about possible TLS-level problems as the test |
| 3184 | + # was specifically emulating corner case situations |
| 3185 | + # pre-shutdown. We just attempt releasing resources |
| 3186 | + # if possible and disregard any possibly related |
| 3187 | + # problems that may occur at this point. |
| 3188 | + client.shutdown() |
| 3189 | + if server: |
| 3190 | + with contextlib.suppress(SSL.Error): |
| 3191 | + server.shutdown() |
| 3192 | + |
| 3193 | + @pytest.fixture |
| 3194 | + def ssl_connection_setup( |
| 3195 | + self, request: pytest.FixtureRequest |
| 3196 | + ) -> typing.Generator[ |
| 3197 | + tuple[Connection, Connection, int, bool], |
| 3198 | + None, |
| 3199 | + None, |
| 3200 | + ]: |
| 3201 | + """ |
| 3202 | + Sets up a non-blocking SSL connection for testing |
| 3203 | + bad_write_retry errors. |
| 3204 | + Modeflag allows the caller to turn off |
| 3205 | + SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER which is normally |
| 3206 | + on by default. |
| 3207 | + """ |
| 3208 | + want_bad_retry = request.param.get("want_bad_retry") |
| 3209 | + request_buffer_size = request.param.get("request_buffer_size") |
| 3210 | + modeflag = request.param.get("modeflag") |
| 3211 | + |
| 3212 | + client, server, sndbuf, rcvbuf = create_ssl_nonblocking_connection( |
| 3213 | + modeflag, request_buffer_size |
| 3214 | + ) |
| 3215 | + # Use a buffer size that is half the size |
| 3216 | + # of the allocated socket buffers |
| 3217 | + buffer_size = min(sndbuf, rcvbuf) // 2 |
| 3218 | + |
| 3219 | + # Yield the resources needed by the test |
| 3220 | + yield ( |
| 3221 | + client, |
| 3222 | + server, |
| 3223 | + buffer_size, |
| 3224 | + want_bad_retry, |
| 3225 | + ) |
| 3226 | + |
| 3227 | + # Teardown: Clean up the connections after the test finishes |
| 3228 | + self._shutdown_connections(client, server) |
| 3229 | + |
| 3230 | + @pytest.mark.parametrize( |
| 3231 | + "ssl_connection_setup", |
| 3232 | + [ |
| 3233 | + { |
| 3234 | + "request_buffer_size": 65536, |
| 3235 | + "modeflag": _lib.SSL_MODE_ENABLE_PARTIAL_WRITE, |
| 3236 | + "want_bad_retry": True, |
| 3237 | + }, |
| 3238 | + { |
| 3239 | + "request_buffer_size": 65536, |
| 3240 | + "modeflag": None, |
| 3241 | + "want_bad_retry": False, |
| 3242 | + }, |
| 3243 | + ], |
| 3244 | + indirect=True, |
| 3245 | + ) |
| 3246 | + def test_moving_buffer_behavior( |
| 3247 | + self, |
| 3248 | + ssl_connection_setup: tuple[Connection, Connection, int, bool], |
| 3249 | + ) -> None: |
| 3250 | + """Tests for possible "bad write retry" errors over an SSL connection. |
| 3251 | + If an SSL connection partially processes some data, |
| 3252 | + and then hits an `OpenSSL.SSL.WantWriteError`, |
| 3253 | + the connection may expect a retry. When PyOpenSSL creates |
| 3254 | + a new connection object, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER is |
| 3255 | + applied by default. This mode allows for data to be sent from a |
| 3256 | + different buffer location, something that may happen if Python moves a |
| 3257 | + mutable object such as a bytearray as part of its memory management. |
| 3258 | + If the mode is turned off, OpenSSL will reject the resend with |
| 3259 | + "bad_write_retry" error. |
| 3260 | + """ |
| 3261 | + ( |
| 3262 | + client, |
| 3263 | + server, |
| 3264 | + buffer_size, |
| 3265 | + want_bad_retry, |
| 3266 | + ) = ssl_connection_setup |
| 3267 | + |
| 3268 | + _ = self._attempt_want_write_error(client, buffer_size) |
| 3269 | + self._drain_server_buffers(server) |
| 3270 | + |
| 3271 | + # Perform the test and get the result |
| 3272 | + result = self._perform_moving_buffer_test( |
| 3273 | + client, buffer_size, want_bad_retry |
| 3274 | + ) |
| 3275 | + |
| 3276 | + # Assert that the result matches the expected outcome from the fixture |
| 3277 | + assert result == want_bad_retry |
| 3278 | + |
3014 | 3279 | def test_get_finished_before_connect(self) -> None:
|
3015 | 3280 | """
|
3016 | 3281 | `Connection.get_finished` returns `None` before TLS handshake
|
|
0 commit comments