|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Optimize semantic cache threshold with RedisVL\n", |
| 8 | + "\n", |
| 9 | + "> **Note:** Threshold optimization in RedisVL requires `python > 3.9`.\n", |
| 10 | + "\n", |
| 11 | + "<a href=\"https://colab.research.google.com/github/redis-developer/redis-ai-resources/blob/main/python-recipes/semantic-cache/02_semantic_cache_optimization.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
| 12 | + ] |
| 13 | + }, |
| 14 | + { |
| 15 | + "cell_type": "markdown", |
| 16 | + "metadata": {}, |
| 17 | + "source": [ |
| 18 | + "# CacheThresholdOptimizer\n", |
| 19 | + "\n", |
| 20 | + "Let's say you set up the following semantic cache with a `distance_threshold` of `X` and store the entries:\n", |
| 21 | + "\n", |
| 22 | + "- prompt: `what is the capital of france?` response: `paris`\n", |
| 23 | + "- prompt: `what is the capital of morocco?` response: `rabat`" |
| 24 | + ] |
| 25 | + }, |
| 26 | + { |
| 27 | + "cell_type": "code", |
| 28 | + "execution_count": 1, |
| 29 | + "metadata": {}, |
| 30 | + "outputs": [ |
| 31 | + { |
| 32 | + "name": "stderr", |
| 33 | + "output_type": "stream", |
| 34 | + "text": [ |
| 35 | + "/Users/robert.shelton/.pyenv/versions/3.11.9/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", |
| 36 | + " warnings.warn(\n", |
| 37 | + "/Users/robert.shelton/.pyenv/versions/3.11.9/lib/python3.11/site-packages/huggingface_hub/file_download.py:1142: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", |
| 38 | + " warnings.warn(\n" |
| 39 | + ] |
| 40 | + } |
| 41 | + ], |
| 42 | + "source": [ |
| 43 | + "from redisvl.extensions.llmcache import SemanticCache\n", |
| 44 | + "\n", |
| 45 | + "sem_cache = SemanticCache(\n", |
| 46 | + " name=\"sem_cache\", # underlying search index name\n", |
| 47 | + " redis_url=\"redis://localhost:6379\", # redis connection url string\n", |
| 48 | + " distance_threshold=0.5 # semantic cache distance threshold\n", |
| 49 | + ")\n", |
| 50 | + "\n", |
| 51 | + "paris_key = sem_cache.store(prompt=\"what is the capital of france?\", response=\"paris\")\n", |
| 52 | + "rabat_key = sem_cache.store(prompt=\"what is the capital of morocco?\", response=\"rabat\")\n" |
| 53 | + ] |
| 54 | + }, |
| 55 | + { |
| 56 | + "cell_type": "markdown", |
| 57 | + "metadata": {}, |
| 58 | + "source": [ |
| 59 | + "This works well, but we want the cache to return hits only for appropriate questions. If we test the cache with a question that should not get a cached response, we see that the current `distance_threshold` is too high: the query still matches a stored entry." |
| 60 | + ] |
| 61 | + }, |
| 62 | + { |
| 63 | + "cell_type": "code", |
| 64 | + "execution_count": 2, |
| 65 | + "metadata": {}, |
| 66 | + "outputs": [ |
| 67 | + { |
| 68 | + "data": { |
| 69 | + "text/plain": [ |
| 70 | + "[{'entry_id': 'c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3',\n", |
| 71 | + " 'prompt': 'what is the capital of france?',\n", |
| 72 | + " 'response': 'paris',\n", |
| 73 | + " 'vector_distance': 0.421104669571,\n", |
| 74 | + " 'inserted_at': 1741039231.99,\n", |
| 75 | + " 'updated_at': 1741039231.99,\n", |
| 76 | + " 'key': 'sem_cache:c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3'}]" |
| 77 | + ] |
| 78 | + }, |
| 79 | + "execution_count": 2, |
| 80 | + "metadata": {}, |
| 81 | + "output_type": "execute_result" |
| 82 | + } |
| 83 | + ], |
| 84 | + "source": [ |
| 85 | + "sem_cache.check(\"what's the capital of britain?\")" |
| 86 | + ] |
| 87 | + }, |
| 88 | + { |
| 89 | + "cell_type": "markdown", |
| 90 | + "metadata": {}, |
| 91 | + "source": [ |
| 92 | + "### Define test_data and optimize\n", |
| 93 | + "\n", |
| 94 | + "With the `CacheThresholdOptimizer` you can quickly tune the distance threshold by providing some test data in the form:\n", |
| 95 | + "\n", |
| 96 | + "```python\n", |
| 97 | + "[\n", |
| 98 | + " {\n", |
| 99 | + " \"query\": \"What's the capital of Britain?\",\n", |
| 100 | + " \"query_match\": \"\"\n", |
| 101 | + " },\n", |
| 102 | + " {\n", |
| 103 | + " \"query\": \"What's the capital of France??\",\n", |
| 104 | + " \"query_match\": paris_key\n", |
| 105 | + " },\n", |
| 106 | + " {\n", |
| 107 | + " \"query\": \"What's the capital city of Morocco?\",\n", |
| 108 | + " \"query_match\": rabat_key\n", |
| 109 | + " },\n", |
| 110 | + "]\n", |
| 111 | + "```\n", |
| 112 | + "\n", |
| 113 | + "The threshold optimizer will then efficiently execute and score different thresholds against what is currently populated in your cache and automatically update the cache's threshold to the best-scoring setting." |
| 114 | + ] |
| 115 | + }, |
| 116 | + { |
| 117 | + "cell_type": "code", |
| 118 | + "execution_count": 3, |
| 119 | + "metadata": {}, |
| 120 | + "outputs": [ |
| 121 | + { |
| 122 | + "name": "stdout", |
| 123 | + "output_type": "stream", |
| 124 | + "text": [ |
| 125 | + "Distance threshold before: 0.5 \n", |
| 126 | + "\n", |
| 127 | + "Distance threshold after: 0.13050847457627118 \n", |
| 128 | + "\n" |
| 129 | + ] |
| 130 | + } |
| 131 | + ], |
| 132 | + "source": [ |
| 133 | + "from redisvl.utils.optimize import CacheThresholdOptimizer\n", |
| 134 | + "\n", |
| 135 | + "test_data = [\n", |
| 136 | + " {\n", |
| 137 | + " \"query\": \"What's the capital of Britain?\",\n", |
| 138 | + " \"query_match\": \"\"\n", |
| 139 | + " },\n", |
| 140 | + " {\n", |
| 141 | + " \"query\": \"What's the capital of France??\",\n", |
| 142 | + " \"query_match\": paris_key\n", |
| 143 | + " },\n", |
| 144 | + " {\n", |
| 145 | + " \"query\": \"What's the capital city of Morocco?\",\n", |
| 146 | + " \"query_match\": rabat_key\n", |
| 147 | + " },\n", |
| 148 | + "]\n", |
| 149 | + "\n", |
| 150 | + "print(f\"Distance threshold before: {sem_cache.distance_threshold} \\n\")\n", |
| 151 | + "optimizer = CacheThresholdOptimizer(sem_cache, test_data)\n", |
| 152 | + "optimizer.optimize()\n", |
| 153 | + "print(f\"Distance threshold after: {sem_cache.distance_threshold} \\n\")" |
| 154 | + ] |
| 155 | + }, |
| 156 | + { |
| 157 | + "cell_type": "markdown", |
| 158 | + "metadata": {}, |
| 159 | + "source": [ |
| 160 | + "We can also see that we no longer match on the incorrect example:" |
| 161 | + ] |
| 162 | + }, |
| 163 | + { |
| 164 | + "cell_type": "code", |
| 165 | + "execution_count": 4, |
| 166 | + "metadata": {}, |
| 167 | + "outputs": [ |
| 168 | + { |
| 169 | + "data": { |
| 170 | + "text/plain": [ |
| 171 | + "[]" |
| 172 | + ] |
| 173 | + }, |
| 174 | + "execution_count": 4, |
| 175 | + "metadata": {}, |
| 176 | + "output_type": "execute_result" |
| 177 | + } |
| 178 | + ], |
| 179 | + "source": [ |
| 180 | + "sem_cache.check(\"what's the capital of britain?\")" |
| 181 | + ] |
| 182 | + }, |
| 183 | + { |
| 184 | + "cell_type": "markdown", |
| 185 | + "metadata": {}, |
| 186 | + "source": [ |
| 187 | + "But we still match on highly relevant prompts:" |
| 188 | + ] |
| 189 | + }, |
| 190 | + { |
| 191 | + "cell_type": "code", |
| 192 | + "execution_count": 5, |
| 193 | + "metadata": {}, |
| 194 | + "outputs": [ |
| 195 | + { |
| 196 | + "data": { |
| 197 | + "text/plain": [ |
| 198 | + "[{'entry_id': 'c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3',\n", |
| 199 | + " 'prompt': 'what is the capital of france?',\n", |
| 200 | + " 'response': 'paris',\n", |
| 201 | + " 'vector_distance': 0.0835866332054,\n", |
| 202 | + " 'inserted_at': 1741039231.99,\n", |
| 203 | + " 'updated_at': 1741039231.99,\n", |
| 204 | + " 'key': 'sem_cache:c990cc06e5e77570e5f03360426d2b7f947cbb5a67daa8af8164bfe0b3e24fe3'}]" |
| 205 | + ] |
| 206 | + }, |
| 207 | + "execution_count": 5, |
| 208 | + "metadata": {}, |
| 209 | + "output_type": "execute_result" |
| 210 | + } |
| 211 | + ], |
| 212 | + "source": [ |
| 213 | + "sem_cache.check(\"what's the capital city of france?\")" |
| 214 | + ] |
| 215 | + }, |
| 216 | + { |
| 217 | + "cell_type": "markdown", |
| 218 | + "metadata": {}, |
| 219 | + "source": [ |
| 220 | + "# Additional configuration\n", |
| 221 | + "\n", |
| 222 | + "By default, threshold optimization selects the threshold with the highest `F1` score, but it can also be configured to rank results by `precision` or `recall` by specifying the `eval_metric` keyword argument." |
| 223 | + ] |
| 224 | + }, |
| 225 | + { |
| 226 | + "cell_type": "code", |
| 227 | + "execution_count": null, |
| 228 | + "metadata": {}, |
| 229 | + "outputs": [], |
| 250 | + "source": [ |
| 251 | + "print(f\"Distance threshold before: {sem_cache.distance_threshold} \\n\")\n", |
| 252 | + "optimizer = CacheThresholdOptimizer(sem_cache, test_data, eval_metric=\"precision\")\n", |
| 253 | + "optimizer.optimize()\n", |
| 254 | + "print(f\"Distance threshold after: {sem_cache.distance_threshold} \\n\")" |
| 255 | + ] |
| 256 | + }, |
| 257 | + { |
| 258 | + "cell_type": "code", |
| 259 | + "execution_count": null, |
| 260 | + "metadata": {}, |
| 261 | + "outputs": [], |
| 262 | + "source": [ |
| 263 | + "print(f\"Distance threshold before: {sem_cache.distance_threshold} \\n\")\n", |
| 264 | + "optimizer = CacheThresholdOptimizer(sem_cache, test_data, eval_metric=\"recall\")\n", |
| 265 | + "optimizer.optimize()\n", |
| 266 | + "print(f\"Distance threshold after: {sem_cache.distance_threshold} \\n\")" |
| 267 | + ] |
| 268 | + }, |
| 269 | + { |
| 270 | + "cell_type": "markdown", |
| 271 | + "metadata": {}, |
| 272 | + "source": [ |
| 273 | + "**Note**: the `CacheThresholdOptimizer` class also exposes an optional `opt_fn` argument, which can be used to define custom optimization logic. See the [source code for reference](https://github.com/redis/redis-vl-python/blob/18ff1008c5a40353c97c176d3d30028a87ff777a/redisvl/utils/optimize/cache.py#L48-L49)." |
| 274 | + ] |
| 275 | + }, |
| 276 | + { |
| 277 | + "cell_type": "markdown", |
| 278 | + "metadata": {}, |
| 279 | + "source": [ |
| 280 | + "## Cleanup" |
| 281 | + ] |
| 282 | + }, |
| 283 | + { |
| 284 | + "cell_type": "code", |
| 285 | + "execution_count": null, |
| 286 | + "metadata": {}, |
| 287 | + "outputs": [], |
| 288 | + "source": [ |
| 289 | + "sem_cache.delete()" |
| 290 | + ] |
| 291 | + } |
| 292 | + ], |
| 293 | + "metadata": { |
| 294 | + "kernelspec": { |
| 295 | + "display_name": "Python 3", |
| 296 | + "language": "python", |
| 297 | + "name": "python3" |
| 298 | + }, |
| 299 | + "language_info": { |
| 300 | + "codemirror_mode": { |
| 301 | + "name": "ipython", |
| 302 | + "version": 3 |
| 303 | + }, |
| 304 | + "file_extension": ".py", |
| 305 | + "mimetype": "text/x-python", |
| 306 | + "name": "python", |
| 307 | + "nbconvert_exporter": "python", |
| 308 | + "pygments_lexer": "ipython3", |
| 309 | + "version": "3.11.9" |
| 310 | + }, |
| 311 | + "orig_nbformat": 4 |
| 312 | + }, |
| 313 | + "nbformat": 4, |
| 314 | + "nbformat_minor": 2 |
| 315 | +} |
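
The notebook above says the optimizer scores different thresholds against the cache and keeps the best one. As a rough illustration of that idea only, the sketch below sweeps candidate thresholds over a few hand-written examples and scores each one by F1, precision, or recall. It is not RedisVL's implementation: the helper names (`score`, `best_threshold`), the distances, the candidate grid, and the tie-breaking rule (the lowest threshold wins) are all assumptions made for this example, and no Redis connection is involved.

```python
from typing import Optional, Sequence, Tuple

# Each example: (distance to the nearest cached entry,
#                key of that nearest entry,
#                expected key, or "" when no cache hit is wanted).
Example = Tuple[float, str, str]


def score(threshold: float, examples: Sequence[Example], metric: str = "f1") -> float:
    """Score one candidate threshold against labeled examples."""
    tp = fp = fn = 0
    for distance, nearest_key, expected_key in examples:
        hit = distance <= threshold  # assumption: a hit means distance within threshold
        correct = hit and nearest_key == expected_key
        if correct:
            tp += 1
        if hit and not correct:
            fp += 1  # returned an entry that was not wanted
        if expected_key and not correct:
            fn += 1  # wanted an entry but did not get the right one
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if metric == "precision":
        return precision
    if metric == "recall":
        return recall
    if metric == "f1":
        total = precision + recall
        return 2 * precision * recall / total if total else 0.0
    raise ValueError(f"unknown metric: {metric}")


def best_threshold(
    examples: Sequence[Example],
    metric: str = "f1",
    candidates: Optional[Sequence[float]] = None,
) -> Tuple[float, float]:
    """Return (threshold, score) for the best candidate; ties keep the lowest threshold."""
    if candidates is None:
        candidates = [i / 100 for i in range(101)]  # hypothetical grid: 0.00 .. 1.00
    best_t, best_s = candidates[0], score(candidates[0], examples, metric)
    for t in candidates[1:]:
        s = score(t, examples, metric)
        if s > best_s:  # strict comparison keeps the earliest (lowest) threshold on ties
            best_t, best_s = t, s
    return best_t, best_s


# Illustrative distances only; they are not measured from a real cache.
examples = [
    (0.421, "paris_key", ""),           # Britain query: no hit wanted
    (0.085, "paris_key", "paris_key"),  # France query: should hit the Paris entry
    (0.125, "rabat_key", "rabat_key"),  # Morocco query: should hit the Rabat entry
]

for metric in ("f1", "precision", "recall"):
    threshold, value = best_threshold(examples, metric)
    print(f"{metric}: threshold={threshold}, score={value}")
```

With these made-up distances, any threshold from 0.125 up to (but not including) 0.421 separates the two desired matches from the Britain query, which is why the F1 sweep lands in that range. Ranking by precision alone can settle on a tighter threshold that misses the Morocco query.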