|
| 1 | +use dashmap::DashMap; |
| 2 | +use ntex::web::HttpResponse; |
| 3 | +use redis::Commands; |
| 4 | +use sonic_rs::json; |
| 5 | + |
| 6 | +use crate::{ |
| 7 | + plugins::traits::{ |
| 8 | + ControlFlow, OnExecuteEnd, OnExecuteEndPayload, OnExecuteStart, OnExecuteStartPayload, |
| 9 | + OnSchemaReload, OnSchemaReloadPayload, |
| 10 | + }, |
| 11 | + utils::consts::TYPENAME_FIELD_NAME, |
| 12 | +}; |
| 13 | + |
| 14 | +pub struct ResponseCachePlugin { |
| 15 | + redis_client: redis::Client, |
| 16 | + ttl_per_type: DashMap<String, u64>, |
| 17 | +} |
| 18 | + |
| 19 | +impl ResponseCachePlugin { |
| 20 | + pub fn try_new(redis_url: &str) -> Result<Self, redis::RedisError> { |
| 21 | + let redis_client = redis::Client::open(redis_url)?; |
| 22 | + Ok(Self { |
| 23 | + redis_client, |
| 24 | + ttl_per_type: DashMap::new(), |
| 25 | + }) |
| 26 | + } |
| 27 | +} |
| 28 | + |
| 29 | +pub struct ResponseCacheContext { |
| 30 | + key: String, |
| 31 | +} |
| 32 | + |
| 33 | +impl OnExecuteStart for ResponseCachePlugin { |
| 34 | + fn on_execute_start(&self, payload: OnExecuteStartPayload) -> ControlFlow { |
| 35 | + let key = format!( |
| 36 | + "response_cache:{}:{:?}", |
| 37 | + payload.query_plan, payload.variable_values |
| 38 | + ); |
| 39 | + payload |
| 40 | + .router_http_request |
| 41 | + .extensions_mut() |
| 42 | + .insert(ResponseCacheContext { key: key.clone() }); |
| 43 | + if let Ok(mut conn) = self.redis_client.get_connection() { |
| 44 | + let cached_response: Option<Vec<u8>> = conn.get(&key).ok(); |
| 45 | + if let Some(cached_response) = cached_response { |
| 46 | + return ControlFlow::Break( |
| 47 | + HttpResponse::Ok() |
| 48 | + .header("Content-Type", "application/json") |
| 49 | + .body(cached_response), |
| 50 | + ); |
| 51 | + } |
| 52 | + } |
| 53 | + ControlFlow::Continue |
| 54 | + } |
| 55 | +} |
| 56 | + |
| 57 | +impl OnExecuteEnd for ResponseCachePlugin { |
| 58 | + fn on_execute_end(&self, payload: OnExecuteEndPayload) -> ControlFlow { |
| 59 | + // Do not cache if there are errors |
| 60 | + if !payload.errors.is_empty() { |
| 61 | + return ControlFlow::Continue; |
| 62 | + } |
| 63 | + if let Some(key) = payload |
| 64 | + .router_http_request |
| 65 | + .extensions() |
| 66 | + .get::<ResponseCacheContext>() |
| 67 | + .map(|ctx| &ctx.key) |
| 68 | + { |
| 69 | + if let Ok(mut conn) = self.redis_client.get_connection() { |
| 70 | + if let Ok(serialized) = sonic_rs::to_vec(&payload.data) { |
| 71 | + // Decide on the ttl somehow |
| 72 | + // Get the type names |
| 73 | + let mut max_ttl = 0; |
| 74 | + |
| 75 | + // Imagine this code is traversing the response data to find type names |
| 76 | + if let Some(obj) = payload.data.as_object() { |
| 77 | + if let Some(typename) = obj |
| 78 | + .iter() |
| 79 | + .position(|(k, _)| k == &TYPENAME_FIELD_NAME) |
| 80 | + .and_then(|idx| obj[idx].1.as_str()) |
| 81 | + { |
| 82 | + if let Some(ttl) = self.ttl_per_type.get(typename).map(|v| *v) { |
| 83 | + max_ttl = max_ttl.max(ttl); |
| 84 | + } |
| 85 | + } |
| 86 | + } |
| 87 | + |
| 88 | + // If no ttl found, default to 60 seconds |
| 89 | + if max_ttl == 0 { |
| 90 | + max_ttl = 60; |
| 91 | + } |
| 92 | + |
| 93 | + // Insert the ttl into extensions for client awareness |
| 94 | + payload |
| 95 | + .extensions |
| 96 | + .insert("response_cache_ttl".to_string(), json!(max_ttl)); |
| 97 | + |
| 98 | + // Set the cache with the decided ttl |
| 99 | + let _: () = conn.set_ex(key, serialized, max_ttl).unwrap_or(()); |
| 100 | + } |
| 101 | + } |
| 102 | + } |
| 103 | + ControlFlow::Continue |
| 104 | + } |
| 105 | +} |
| 106 | + |
| 107 | +impl OnSchemaReload for ResponseCachePlugin { |
| 108 | + fn on_schema_reload(&self, payload: OnSchemaReloadPayload) { |
| 109 | + // Visit the schema and update ttl_per_type based on some directive |
| 110 | + payload |
| 111 | + .new_schema |
| 112 | + .document |
| 113 | + .definitions |
| 114 | + .iter() |
| 115 | + .for_each(|def| { |
| 116 | + if let graphql_parser::schema::Definition::TypeDefinition(type_def) = def { |
| 117 | + if let graphql_parser::schema::TypeDefinition::Object(obj_type) = type_def { |
| 118 | + for directive in &obj_type.directives { |
| 119 | + if directive.name == "cacheControl" { |
| 120 | + for arg in &directive.arguments { |
| 121 | + if arg.0 == "maxAge" { |
| 122 | + if let graphql_parser::query::Value::Int(max_age) = &arg.1 { |
| 123 | + if let Some(max_age) = max_age.as_i64() { |
| 124 | + self.ttl_per_type |
| 125 | + .insert(obj_type.name.clone(), max_age as u64); |
| 126 | + } |
| 127 | + } |
| 128 | + } |
| 129 | + } |
| 130 | + } |
| 131 | + } |
| 132 | + } |
| 133 | + } |
| 134 | + }); |
| 135 | + } |
| 136 | +} |
0 commit comments