3
3
import errno
4
4
import hashlib
5
5
import io
6
+ import itertools
6
7
import json
7
8
import os
8
9
import select
16
17
from urllib3 .util .ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket
17
18
from urllib3 .util .ssl_ import wrap_socket as urllib3_wrap_socket
18
19
19
- from .compat import (
20
- basestring ,
21
- byte_type ,
22
- decode_from_bytes ,
23
- encode_to_bytes ,
24
- text_type ,
25
- )
20
+ from .compat import basestring , byte_type , decode_from_bytes , encode_to_bytes , text_type
26
21
from .utils import SSL_PROTOCOL , MocketSocketCore , hexdump , hexload , wrap_ssl_socket
27
22
28
23
xxh32 = None
@@ -517,6 +512,13 @@ def get_namespace(cls):
517
512
def get_truesocket_recording_dir (cls ):
518
513
return cls ._truesocket_recording_dir
519
514
515
+ @classmethod
516
+ def assert_fail_if_entries_not_served (cls ):
517
+ """ Mocket checks that all entries have been served at least once. """
518
+ assert all (
519
+ entry ._served for entry in itertools .chain (* cls ._entries .values ())
520
+ ), "Some Mocket entries have not been served"
521
+
520
522
521
523
class MocketEntry (object ):
522
524
class Response (byte_type ):
@@ -526,8 +528,11 @@ def data(self):
526
528
527
529
request_cls = str
528
530
response_cls = Response
531
+ responses = None
532
+ _served = None
529
533
530
534
def __init__ (self , location , responses ):
535
+ self ._served = False
531
536
self .location = location
532
537
self .response_index = 0
533
538
@@ -536,19 +541,18 @@ def __init__(self, location, responses):
536
541
):
537
542
responses = [responses ]
538
543
539
- lresponses = []
544
+ self . responses = []
540
545
for r in responses :
541
546
if isinstance (r , BaseException ):
542
547
pass
543
548
elif not getattr (r , "data" , False ):
544
549
if isinstance (r , text_type ):
545
550
r = encode_to_bytes (r )
546
551
r = self .response_cls (r )
547
- lresponses .append (r )
552
+ self . responses .append (r )
548
553
else :
549
554
if not responses :
550
- lresponses = [self .response_cls (encode_to_bytes ("" ))]
551
- self .responses = lresponses
555
+ self .responses = [self .response_cls (encode_to_bytes ("" ))]
552
556
553
557
def can_handle (self , data ):
554
558
return True
@@ -562,6 +566,8 @@ def get_response(self):
562
566
if self .response_index < len (self .responses ) - 1 :
563
567
self .response_index += 1
564
568
569
+ self ._served = True
570
+
565
571
if isinstance (response , BaseException ):
566
572
raise response
567
573
0 commit comments