Skip to content

Commit f9b96f2

Browse files
authored
Mock check_hostname everywhere (#152)
* Fix for #151. * Complete refactoring.
1 parent a109d33 commit f9b96f2

File tree

2 files changed

+26
-54
lines changed

2 files changed

+26
-54
lines changed

mocket/mocket.py

Lines changed: 26 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@
1414

1515
import decorator
1616
import urllib3
17+
from urllib3.connection import match_hostname as urllib3_match_hostname
1718
from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
1819
from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket
1920

2021
from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type
21-
from .utils import SSL_PROTOCOL, MocketSocketCore, hexdump, hexload, wrap_ssl_socket
22+
from .utils import SSL_PROTOCOL, MocketSocketCore, hexdump, hexload
2223

2324
xxh32 = None
2425
try:
@@ -49,22 +50,32 @@
4950
true_inet_pton = socket.inet_pton
5051
true_urllib3_wrap_socket = urllib3_wrap_socket
5152
true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket
53+
true_urllib3_match_hostname = urllib3_match_hostname
5254

5355

5456
class SuperFakeSSLContext(object):
55-
""" For Python 3.6 """
57+
"""For Python 3.6"""
5658

5759
class FakeSetter(int):
5860
def __set__(self, *args):
5961
pass
6062

6163
options = FakeSetter()
62-
verify_mode = FakeSetter(ssl.CERT_OPTIONAL)
64+
verify_mode = FakeSetter(ssl.CERT_NONE)
6365

6466

6567
class FakeSSLContext(SuperFakeSSLContext):
6668
sock = None
6769
post_handshake_auth = None
70+
_check_hostname = False
71+
72+
@property
73+
def check_hostname(self):
74+
return self._check_hostname
75+
76+
@check_hostname.setter
77+
def check_hostname(self, *args):
78+
self._check_hostname = False
6879

6980
def __init__(self, sock=None, server_hostname=None, _context=None, *args, **kwargs):
7081
if isinstance(sock, MocketSocket):
@@ -141,16 +152,6 @@ def __init__(
141152
self._truesocket_recording_dir = None
142153
self.kwargs = kwargs
143154

144-
sock = kwargs.get("sock")
145-
if sock is not None:
146-
self.__dict__ = dict(sock.__dict__)
147-
148-
self.true_socket = wrap_ssl_socket(
149-
true_ssl_socket,
150-
self.true_socket,
151-
true_ssl_context(protocol=SSL_PROTOCOL),
152-
)
153-
154155
def __unicode__(self): # pragma: no cover
155156
return str(self)
156157

@@ -323,16 +324,10 @@ def true_sendall(self, data, *args, **kwargs):
323324
host = true_gethostbyname(host)
324325

325326
if isinstance(self.true_socket, true_socket) and self._secure_socket:
326-
try:
327-
self = MocketSocket(sock=self)
328-
except TypeError:
329-
ssl_context = self.kwargs.get("ssl_context")
330-
server_hostname = self.kwargs.get("server_hostname")
331-
self.true_socket = true_ssl_context.wrap_socket(
332-
self=ssl_context,
333-
sock=self.true_socket,
334-
server_hostname=server_hostname,
335-
)
327+
self.true_socket = true_urllib3_ssl_wrap_socket(
328+
self.true_socket,
329+
**self.kwargs,
330+
)
336331

337332
try:
338333
self.true_socket.connect((host, port))
@@ -388,7 +383,7 @@ def close(self):
388383
self._fd = None
389384

390385
def __getattr__(self, name):
391-
""" Do nothing catchall function, for methods like close() and shutdown() """
386+
"""Do nothing catchall function, for methods like close() and shutdown()"""
392387

393388
def do_nothing(*args, **kwargs):
394389
pass
@@ -479,6 +474,9 @@ def enable(namespace=None, truesocket_recording_dir=None):
479474
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
480475
"ssl_wrap_socket"
481476
] = FakeSSLContext.wrap_socket
477+
urllib3.connection.match_hostname = urllib3.connection.__dict__[
478+
"match_hostname"
479+
] = lambda cert, hostname: None
482480
if pyopenssl_override: # pragma: no cover
483481
# Take out the pyopenssl version - use the default implementation
484482
extract_from_urllib3()
@@ -506,6 +504,9 @@ def disable():
506504
urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[
507505
"ssl_wrap_socket"
508506
] = true_urllib3_ssl_wrap_socket
507+
urllib3.connection.match_hostname = urllib3.connection.__dict__[
508+
"match_hostname"
509+
] = true_urllib3_match_hostname
509510
Mocket.reset()
510511
if pyopenssl_override: # pragma: no cover
511512
# Put the pyopenssl version back in place
@@ -521,7 +522,7 @@ def get_truesocket_recording_dir(cls):
521522

522523
@classmethod
523524
def assert_fail_if_entries_not_served(cls):
524-
""" Mocket checks that all entries have been served at least once. """
525+
"""Mocket checks that all entries have been served at least once."""
525526
assert all(
526527
entry._served for entry in itertools.chain(*cls._entries.values())
527528
), "Some Mocket entries have not been served"

mocket/utils.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,6 @@ def write(self, content):
1818
os.write(Mocket.w_fd, content)
1919

2020

21-
def wrap_ssl_socket(
22-
cls,
23-
sock,
24-
context,
25-
keyfile=None,
26-
certfile=None,
27-
server_side=False,
28-
cert_reqs=ssl.CERT_NONE,
29-
ssl_version=SSL_PROTOCOL,
30-
ca_certs=None,
31-
do_handshake_on_connect=True,
32-
suppress_ragged_eofs=True,
33-
ciphers=None,
34-
):
35-
return cls(
36-
sock=sock,
37-
keyfile=keyfile,
38-
certfile=certfile,
39-
server_side=server_side,
40-
cert_reqs=cert_reqs,
41-
ssl_version=ssl_version,
42-
ca_certs=ca_certs,
43-
do_handshake_on_connect=do_handshake_on_connect,
44-
suppress_ragged_eofs=suppress_ragged_eofs,
45-
ciphers=ciphers,
46-
_context=context,
47-
)
48-
49-
5021
def hexdump(binary_string):
5122
r"""
5223
>>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))

0 commit comments

Comments
 (0)