|
15 | 15 | "id": "fcbb5d61-8a0b-47d9-a7c5-0c041c82b8bf", |
16 | 16 | "metadata": {}, |
17 | 17 | "source": [ |
18 | | - "# 🚀 Deploy `deepseek-ai/DeepSeek-R1-Distill-Llama-8B` on Amazon SageMaker" |
| 18 | + "# 🚀 Deploy `Qwen/Qwen3-4B-Instruct-2507` on Amazon SageMaker" |
19 | 19 | ] |
20 | 20 | }, |
21 | 21 | { |
22 | 22 | "cell_type": "markdown", |
23 | 23 | "id": "dd210e90-21e1-4f03-a08e-c3fba9aa6979", |
24 | 24 | "metadata": {}, |
25 | 25 | "source": [ |
| 26 | + "## Prerequisites\n", |
| 27 | + "\n", |
26 | 28 | "To start off, let's install some packages to help us through the notebooks. **Restart the kernel after packages have been installed.**" |
27 | 29 | ] |
28 | 30 | }, |
|
57 | 59 | "get_ipython().kernel.do_shutdown(True)" |
58 | 60 | ] |
59 | 61 | }, |
| 62 | + { |
| 63 | + "cell_type": "markdown", |
| 64 | + "id": "a947367a-bea3-498a-9548-d6e6e08f0d10", |
| 65 | + "metadata": {}, |
| 66 | + "source": [ |
| 67 | + "***" |
| 68 | + ] |
| 69 | + }, |
60 | 70 | { |
61 | 71 | "cell_type": "code", |
62 | 72 | "execution_count": null, |
|
66 | 76 | "source": [ |
67 | 77 | "import os\n", |
68 | 78 | "import sagemaker\n", |
69 | | - "from sagemaker.djl_inference import DJLModel\n", |
70 | | - "from ipywidgets import Dropdown\n", |
71 | | - "\n", |
| 79 | + "import boto3\n", |
| 80 | + "import shutil\n", |
| 81 | + "from sagemaker.config import load_sagemaker_config\n", |
72 | 82 | "import sys\n", |
73 | 83 | "sys.path.append(os.path.dirname(os.getcwd()))\n", |
74 | 84 | "\n", |
|
78 | 88 | " print_dialog,\n", |
79 | 89 | " format_messages,\n", |
80 | 90 | " write_eula\n", |
81 | | - ")" |
82 | | - ] |
83 | | - }, |
84 | | - { |
85 | | - "cell_type": "code", |
86 | | - "execution_count": null, |
87 | | - "id": "8b53f21c-3a65-44fc-b547-712d971cd652", |
88 | | - "metadata": {}, |
89 | | - "outputs": [], |
90 | | - "source": [ |
91 | | - "import boto3\n", |
92 | | - "import shutil\n", |
93 | | - "import sagemaker\n", |
94 | | - "from sagemaker.config import load_sagemaker_config\n", |
| 91 | + ")\n", |
95 | 92 | "\n", |
96 | 93 | "sagemaker_session = sagemaker.Session()\n", |
97 | 94 | "s3_client = boto3.client('s3')\n", |
98 | 95 | "\n", |
| 96 | + "region = sagemaker_session.boto_session.region_name\n", |
99 | 97 | "bucket_name = sagemaker_session.default_bucket()\n", |
100 | 98 | "default_prefix = sagemaker_session.default_bucket_prefix\n", |
101 | 99 | "configs = load_sagemaker_config()\n", |
102 | 100 | "\n", |
103 | 101 | "session = sagemaker.Session()\n", |
104 | 102 | "role = sagemaker.get_execution_role()\n", |
105 | 103 | "\n", |
| 104 | + "\n", |
106 | 105 | "print(f\"Execution Role: {role}\")\n", |
107 | 106 | "print(f\"Default S3 Bucket: {bucket_name}\")" |
108 | 107 | ] |
|
130 | 129 | "metadata": {}, |
131 | 130 | "outputs": [], |
132 | 131 | "source": [ |
133 | | - "inference_image_uri = sagemaker.image_uris.retrieve(\n", |
134 | | - " framework=\"djl-lmi\", \n", |
135 | | - " region=session.boto_session.region_name, \n", |
136 | | - " version=\"0.29.0\"\n", |
137 | | - ")\n", |
| 132 | + "# commented out until LMI 0.33.0 is available via the SageMaker SDK\n", |
| 133 | + "# inference_image_uri = sagemaker.image_uris.retrieve(\n", |
| 134 | + "# framework=\"djl-lmi\", \n", |
| 135 | + "# region=session.boto_session.region_name, \n", |
| 136 | + "# version=\"0.33.0\"\n", |
| 137 | + "# )\n", |
| 138 | + "\n", |
| 139 | + "inference_image_uri = f\"763104351884.dkr.ecr.{region}.amazonaws.com/djl-inference:0.33.0-lmi15.0.0-cu128\"\n", |
138 | 140 | "pretty_print_html(f\"using image to host: {inference_image_uri}\")" |
139 | 141 | ] |
140 | 142 | }, |
|
153 | 155 | "metadata": {}, |
154 | 156 | "outputs": [], |
155 | 157 | "source": [ |
156 | | - "model_id = \"deepseek-ai/DeepSeek-R1-Distill-Llama-8B\"\n", |
| 158 | + "model_id = \"Qwen/Qwen3-4B-Instruct-2507\"\n", |
157 | 159 | "model_id_filesafe = model_id.replace(\"/\",\"_\")\n", |
158 | 160 | "\n", |
159 | 161 | "use_local_model = True # set to False to let the deployment download from HF, otherwise True will download the model locally first" |
|
225 | 227 | "metadata": {}, |
226 | 228 | "outputs": [], |
227 | 229 | "source": [ |
228 | | - "model_name = \"DeepSeek-R1-Distill-Llama-8B\"\n", |
| 230 | + "model_name = \"Qwen3-4B-Instruct-2507\"\n", |
229 | 231 | "\n", |
230 | 232 | "lmi_model = sagemaker.Model(\n", |
231 | 233 | " image_uri=inference_image_uri,\n", |
|
242 | 244 | "metadata": {}, |
243 | 245 | "outputs": [], |
244 | 246 | "source": [ |
245 | | - "base_endpoint_name = f\"{model_name}-endpoint\"\n", |
| 247 | + "from sagemaker.utils import name_from_base\n", |
| 248 | + "\n", |
| 249 | + "endpoint_name = f\"{model_name}-endpoint\"\n", |
| 250 | + "BASE_ENDPOINT_NAME = name_from_base(endpoint_name)\n", |
246 | 251 | "\n", |
247 | 252 | "predictor = lmi_model.deploy(\n", |
248 | 253 | " initial_instance_count=1, \n", |
249 | 254 | " instance_type=\"ml.g5.2xlarge\",\n", |
250 | | - " endpoint_name=base_endpoint_name\n", |
| 255 | + " endpoint_name=BASE_ENDPOINT_NAME\n", |
251 | 256 | ")" |
252 | 257 | ] |
253 | 258 | }, |
|
258 | 263 | "metadata": {}, |
259 | 264 | "outputs": [], |
260 | 265 | "source": [ |
261 | | - "base_prompt = f\"\"\"\n", |
262 | | - "<|begin_of_text|>\n", |
263 | | - "<|start_header_id|>system<|end_header_id|>\n", |
264 | | - "You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n", |
| 266 | + "SYSTEM_PROMPT = f\"\"\"You are a medical expert with advanced knowledge in clinical reasoning, diagnostics, and treatment planning. \n", |
265 | 267 | "Below is an instruction that describes a task, paired with an input that provides further context. \n", |
266 | 268 | "Write a response that appropriately completes the request.\n", |
267 | | - "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\n", |
268 | | - "<|eot_id|><|start_header_id|>user<|end_header_id|>\n", |
269 | | - "{{question}}<|eot_id|>\n", |
270 | | - "<|start_header_id|>assistant<|end_header_id|>\"\"\"" |
271 | | - ] |
272 | | - }, |
273 | | - { |
274 | | - "cell_type": "code", |
275 | | - "execution_count": null, |
276 | | - "id": "6b37e7f1-730c-4b31-aa3b-55e2009f8f04", |
277 | | - "metadata": {}, |
278 | | - "outputs": [], |
279 | | - "source": [ |
280 | | - "prompt = base_prompt.format(\n", |
281 | | - " question=\"A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?\"\n", |
282 | | - ")\n", |
| 269 | + "Before answering, think carefully about the question and create a step-by-step chain of thoughts to ensure a logical and accurate response.\"\"\"\n", |
| 270 | + "\n", |
| 271 | + "USER_PROMPT = \"A 3-week-old child has been diagnosed with late onset perinatal meningitis, and the CSF culture shows gram-positive bacilli. What characteristic of this bacterium can specifically differentiate it from other bacterial agents?\"\n", |
283 | 272 | "\n", |
284 | | - "print(prompt)" |
| 273 | + "messages = [\n", |
| 274 | + " {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n", |
| 275 | + " {\"role\": \"user\", \"content\": USER_PROMPT},\n", |
| 276 | + "]\n", |
| 277 | + "\n", |
| 278 | + "messages" |
285 | 279 | ] |
286 | 280 | }, |
287 | 281 | { |
|
292 | 286 | "outputs": [], |
293 | 287 | "source": [ |
294 | 288 | "predictor = sagemaker.Predictor(\n", |
295 | | - " endpoint_name=base_endpoint_name,\n", |
| 289 | + " endpoint_name=BASE_ENDPOINT_NAME,\n", |
296 | 290 | " sagemaker_session=sagemaker_session,\n", |
297 | 291 | " serializer=sagemaker.serializers.JSONSerializer(),\n", |
298 | 292 | " deserializer=sagemaker.deserializers.JSONDeserializer(),\n", |
299 | 293 | ")\n", |
300 | 294 | "\n", |
301 | 295 | "response = predictor.predict({\n", |
302 | | - "\t\"inputs\": prompt,\n", |
| 296 | + "\t\"messages\": messages,\n", |
303 | 297 | " \"parameters\": {\n", |
304 | 298 | " \"temperature\": 0.2,\n", |
305 | 299 | " \"top_p\": 0.9,\n", |
306 | 300 | " \"return_full_text\": False,\n", |
307 | | - " \"max_new_tokens\": 1024,\n", |
308 | | - " \"stop\": ['<|eot_id|>']\n", |
| 301 | + " \"max_new_tokens\": 1024\n", |
309 | 302 | " }\n", |
310 | 303 | "})\n", |
311 | 304 | "\n", |
312 | | - "response = response[\"generated_text\"].split(\"<|eot_id|>\")[0]\n", |
| 305 | + "response[\"choices\"][0][\"message\"][\"content\"]" |
| 306 | + ] |
| 307 | + }, |
| 308 | + { |
| 309 | + "cell_type": "markdown", |
| 310 | + "id": "165c8660-ee18-411f-9d8a-8032c6171d77", |
| 311 | + "metadata": {}, |
| 312 | + "source": [ |
| 313 | + "### Store variables\n", |
313 | 314 | "\n", |
314 | | - "response" |
| 315 | + "Save the endpoint name for use later" |
315 | 316 | ] |
316 | 317 | }, |
317 | 318 | { |
318 | 319 | "cell_type": "code", |
319 | 320 | "execution_count": null, |
320 | | - "id": "dbfc37bb-dc1f-4ba7-9948-6e482c1c86b0", |
| 321 | + "id": "0ed6ca9e-705c-4d01-9118-110b86caaef6", |
321 | 322 | "metadata": {}, |
322 | 323 | "outputs": [], |
323 | | - "source": [] |
| 324 | + "source": [ |
| 325 | + "%store BASE_ENDPOINT_NAME" |
| 326 | + ] |
324 | 327 | } |
325 | 328 | ], |
326 | 329 | "metadata": { |
|
0 commit comments