11"""Client-side gRPC interceptors."""
22
33import abc
4+ from collections import namedtuple
45import logging
56
6- from grpc import StatusCode , UnaryUnaryClientInterceptor
7+ from grpc import (
8+ ClientCallDetails ,
9+ StatusCode ,
10+ UnaryStreamClientInterceptor ,
11+ UnaryUnaryClientInterceptor ,
12+ )
713
814from ansys .edb .core .inner .exceptions import EDBSessionException , ErrorCode , InvalidArgumentException
15+ from ansys .edb .core .utility .cache import get_cache
916
1017
11- class Interceptor (UnaryUnaryClientInterceptor , metaclass = abc .ABCMeta ):
18+ class Interceptor (UnaryUnaryClientInterceptor , UnaryStreamClientInterceptor , metaclass = abc .ABCMeta ):
1219 """Provides the base interceptor class."""
1320
1421 def __init__ (self , logger ):
@@ -20,14 +27,21 @@ def __init__(self, logger):
2027 def _post_process (self , response ):
2128 pass
2229
30+ def _continue_unary_unary (self , continuation , client_call_details , request ):
31+ return continuation (client_call_details , request )
32+
2333 def intercept_unary_unary (self , continuation , client_call_details , request ):
2434 """Intercept a gRPC call."""
25- response = continuation ( client_call_details , request )
35+ response = self . _continue_unary_unary ( continuation , client_call_details , request )
2636
2737 self ._post_process (response )
2838
2939 return response
3040
41+ def intercept_unary_stream (self , continuation , client_call_details , request ):
42+ """Intercept a gRPC streaming call."""
43+ return continuation (client_call_details , request )
44+
3145
3246class LoggingInterceptor (Interceptor ):
3347 """Logs EDB errors on each request."""
@@ -76,3 +90,78 @@ def _post_process(self, response):
7690
7791 if exception is not None :
7892 raise exception
93+
94+
95+ class CachingInterceptor (Interceptor ):
96+ """Returns cached values if a given request has already been made and caching is enabled."""
97+
98+ def __init__ (self , logger , rpc_counter ):
99+ """Initialize a caching interceptor with a logger and rpc counter."""
100+ super ().__init__ (logger )
101+ self ._rpc_counter = rpc_counter
102+ self ._reset_cache_entry_data ()
103+
104+ def _reset_cache_entry_data (self ):
105+ self ._current_rpc_method = ""
106+ self ._current_cache_key_details = None
107+
108+ def _should_log_traffic (self ):
109+ return self ._rpc_counter is not None
110+
111+ class _ClientCallDetails (
112+ namedtuple ("_ClientCallDetails" , ("method" , "timeout" , "metadata" , "credentials" )),
113+ ClientCallDetails ,
114+ ):
115+ pass
116+
117+ @classmethod
118+ def _get_client_call_details_with_caching_options (cls , client_call_details ):
119+ if get_cache () is None :
120+ return client_call_details
121+ metadata = []
122+ if client_call_details .metadata is not None :
123+ metadata = list (client_call_details .metadata )
124+ metadata .append (("enable-caching" , "1" ))
125+ return cls ._ClientCallDetails (
126+ client_call_details .method ,
127+ client_call_details .timeout ,
128+ metadata ,
129+ client_call_details .credentials ,
130+ )
131+
132+ def _continue_unary_unary (self , continuation , client_call_details , request ):
133+ if self ._should_log_traffic ():
134+ self ._current_rpc_method = client_call_details .method
135+ cache = get_cache ()
136+ if cache is not None :
137+ method_tokens = client_call_details .method .strip ("/" ).split ("/" )
138+ cache_key_details = method_tokens [0 ], method_tokens [1 ], request
139+ cached_response = cache .get (* cache_key_details )
140+ if cached_response is not None :
141+ return cached_response
142+ else :
143+ self ._current_cache_key_details = cache_key_details
144+ return super ()._continue_unary_unary (
145+ continuation ,
146+ self ._get_client_call_details_with_caching_options (client_call_details ),
147+ request ,
148+ )
149+
150+ def _cache_missed (self ):
151+ return self ._current_cache_key_details is not None
152+
153+ def _post_process (self , response ):
154+ cache = get_cache ()
155+ if cache is not None and self ._cache_missed ():
156+ cache .add (* self ._current_cache_key_details , response .result ())
157+ if self ._should_log_traffic () and (cache is None or self ._cache_missed ()):
158+ self ._rpc_counter [self ._current_rpc_method ] += 1
159+ self ._reset_cache_entry_data ()
160+
161+ def intercept_unary_stream (self , continuation , client_call_details , request ):
162+ """Intercept a gRPC streaming call."""
163+ return super ().intercept_unary_stream (
164+ continuation ,
165+ self ._get_client_call_details_with_caching_options (client_call_details ),
166+ request ,
167+ )
0 commit comments