|
27 | 27 | "source": [
|
28 | 28 | "The following server arguments are relevant for multi-LoRA serving:\n",
|
29 | 29 | "\n",
|
| 30 | + "* `enable_lora`: Enable LoRA support for the model. This argument is automatically set to True if `--lora-paths` is provided for backward compatibility.\n", |
| 31 | + "\n", |
30 | 32 | "* `lora_paths`: A mapping from each adapter's name to its path, in the form of `{name}={path} {name}={path}`.\n",
|
31 | 33 | "\n",
|
32 | 34 | "* `max_loras_per_batch`: Maximum number of adapters used by each batch. This argument affects the amount of GPU memory reserved for multi-LoRA serving, so it should be set to a smaller value when memory is scarce. Defaults to 8.\n",
|
|
35 | 37 | "\n",
|
36 | 38 | "* `max_lora_rank`: The maximum LoRA rank that should be supported. If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of larger LoRA rank after server startup.\n",
|
37 | 39 | "\n",
|
38 |
| - "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup.\n", |
| 40 | + "* `lora_target_modules`: The union set of all target modules where LoRA should be applied (e.g., `q_proj`, `k_proj`, `gate_proj`). If not specified, it will be automatically inferred from the adapters provided in `--lora-paths`. This argument is needed when you expect to dynamically load adapters of different target modules after server startup. You can also set it to `all` to enable LoRA for all supported modules. However, enabling LoRA on additional modules introduces a minor performance overhead. If your application is performance-sensitive, we recommend only specifying the modules for which you plan to load adapters.\n", |
39 | 41 | "\n",
|
40 | 42 | "* `tp_size`: SGLang supports LoRA serving together with tensor parallelism. `tp_size` controls the number of GPUs used for tensor parallelism. More details on the tensor sharding strategy can be found in the [S-LoRA](https://arxiv.org/pdf/2311.03285) paper.\n",
|
41 | 43 | "\n",
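As an illustrative sketch (not SGLang's actual argument parser), the `{name}={path}` pairs accepted by `--lora-paths` can be parsed like this:

```python
def parse_lora_paths(pairs: list[str]) -> dict[str, str]:
    """Parse `{name}={path}` pairs (as passed to `--lora-paths`) into a dict."""
    paths = {}
    for pair in pairs:
        name, sep, path = pair.partition("=")
        if not sep:
            raise ValueError(f"expected '{{name}}={{path}}', got {pair!r}")
        paths[name] = path
    return paths


print(parse_lora_paths([
    "lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora",
    "lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16",
]))
```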
|
|
79 | 81 | "server_process, port = launch_server_cmd(\n",
|
80 | 82 | " \"\"\"\n",
|
81 | 83 | "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
| 84 | + " --enable-lora \\\n", |
82 | 85 | " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
83 | 86 | " --max-loras-per-batch 1 --lora-backend triton \\\n",
|
84 | 87 | " --disable-radix-cache\n",
|
|
98 | 101 | "json_data = {\n",
|
99 | 102 | " \"text\": [\n",
|
100 | 103 | " \"List 3 countries and their capitals.\",\n",
|
101 |
| - " \"AI is a field of computer science focused on\",\n", |
| 104 | + " \"List 3 countries and their capitals.\",\n", |
102 | 105 | " ],\n",
|
103 | 106 | " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
104 | 107 | " # The first input uses lora0, and the second input uses the base model\n",
|
|
137 | 140 | "server_process, port = launch_server_cmd(\n",
|
138 | 141 | " \"\"\"\n",
|
139 | 142 | "python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
| 143 | + " --enable-lora \\\n", |
140 | 144 | " --lora-paths lora0=algoprog/fact-generation-llama-3.1-8b-instruct-lora \\\n",
|
141 | 145 | " lora1=Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16 \\\n",
|
142 | 146 | " --max-loras-per-batch 2 --lora-backend triton \\\n",
|
|
157 | 161 | "json_data = {\n",
|
158 | 162 | " \"text\": [\n",
|
159 | 163 | " \"List 3 countries and their capitals.\",\n",
|
160 |
| - " \"AI is a field of computer science focused on\",\n", |
| 164 | + " \"List 3 countries and their capitals.\",\n", |
161 | 165 | " ],\n",
|
162 | 166 | " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
163 | 167 | " # The first input uses lora0, and the second input uses lora1\n",
|
|
191 | 195 | "cell_type": "markdown",
|
192 | 196 | "metadata": {},
|
193 | 197 | "source": [
|
194 |
| - "### Basic Usage\n", |
195 |
| - "\n", |
196 | 198 | "Instead of specifying all adapters during server startup via `--lora-paths`, you can also load and unload LoRA adapters dynamically via the `/load_lora_adapter` and `/unload_lora_adapter` APIs.\n",
|
197 | 199 | "\n",
|
198 |
| - "(Please note that, currently we still require you to specify at least one adapter in `--lora-paths` to enable the LoRA feature, this limitation will be lifted soon.)" |
| 200 | + "When using dynamic LoRA loading, it's recommended to explicitly specify both `--max-lora-rank` and `--lora-target-modules` at startup. For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. However, in that case, you must ensure that all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths`, or are strictly \"smaller\"." |
199 | 201 | ]
|
200 | 202 | },
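As a rough sketch (a hypothetical helper, not part of SGLang), you could pre-check an adapter's `adapter_config.json` against the server's startup limits before calling `/load_lora_adapter`, to confirm it is strictly "smaller":

```python
def adapter_fits(adapter_config: dict, max_lora_rank: int, target_modules: set) -> bool:
    """Check that an adapter's rank and target modules fit within the server's limits.

    `adapter_config` mirrors the fields found in a PEFT adapter's adapter_config.json.
    """
    rank_ok = adapter_config.get("r", 0) <= max_lora_rank
    modules_ok = set(adapter_config.get("target_modules", [])) <= target_modules
    return rank_ok and modules_ok


# A rank-4 adapter targeting only attention projections fits a server started
# with --max-lora-rank 64 and a superset of its target modules.
cfg = {"r": 4, "target_modules": ["q_proj", "k_proj", "v_proj", "o_proj"]}
print(adapter_fits(cfg, max_lora_rank=64,
                   target_modules={"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"}))  # True
```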
|
201 | 203 | {
|
|
204 | 206 | "metadata": {},
|
205 | 207 | "outputs": [],
|
206 | 208 | "source": [
|
| 209 | + "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", |
| 210 | + "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", |
| 211 | + "lora0_new = \"philschmid/code-llama-3-1-8b-text-to-sql-lora\" # rank - 256, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", |
| 212 | + "\n", |
| 213 | + "\n", |
| 214 | + "# Since no adapters are provided via `--lora-paths` at startup, we explicitly set `--max-lora-rank` and\n", |
| 215 | + "# `--lora-target-modules` so the server can accommodate the adapters loaded dynamically below.\n", |
207 | 216 | "server_process, port = launch_server_cmd(\n",
|
208 | 217 | " \"\"\"\n",
|
209 | 218 | " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n",
|
210 |
| - " --lora-paths lora0=philschmid/code-llama-3-1-8b-text-to-sql-lora \\\n", |
| 219 | + " --enable-lora \\\n", |
211 | 220 | " --cuda-graph-max-bs 2 \\\n",
|
212 | 221 | " --max-loras-per-batch 2 --lora-backend triton \\\n",
|
213 | 222 | " --disable-radix-cache \\\n",
|
| 223 | + " --max-lora-rank 256 \\\n", |
| 224 | + " --lora-target-modules all\n", |
214 | 225 | " \"\"\"\n",
|
215 | 226 | ")\n",
|
216 | 227 | "\n",
|
217 | 228 | "url = f\"http://127.0.0.1:{port}\"\n",
|
218 | 229 | "wait_for_server(url)"
|
219 | 230 | ]
|
220 | 231 | },
|
| 232 | + { |
| 233 | + "cell_type": "markdown", |
| 234 | + "metadata": {}, |
| 235 | + "source": [ |
| 236 | + "Load adapter lora0:" |
| 237 | + ] |
| 238 | + }, |
221 | 239 | {
|
222 | 240 | "cell_type": "code",
|
223 | 241 | "execution_count": null,
|
|
227 | 245 | "response = requests.post(\n",
|
228 | 246 | " url + \"/load_lora_adapter\",\n",
|
229 | 247 | " json={\n",
|
230 |
| - " \"lora_name\": \"lora1\",\n", |
231 |
| - " \"lora_path\": \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\",\n", |
| 248 | + " \"lora_name\": \"lora0\",\n", |
| 249 | + " \"lora_path\": lora0,\n", |
232 | 250 | " },\n",
|
233 | 251 | ")\n",
|
234 | 252 | "\n",
|
|
239 | 257 | ]
|
240 | 258 | },
|
241 | 259 | {
|
242 |
| - "cell_type": "code", |
243 |
| - "execution_count": null, |
244 |
| - "metadata": {}, |
245 |
| - "outputs": [], |
246 |
| - "source": [ |
247 |
| - "response = requests.post(\n", |
248 |
| - " url + \"/generate\",\n", |
249 |
| - " json={\n", |
250 |
| - " \"text\": [\n", |
251 |
| - " \"List 3 countries and their capitals.\",\n", |
252 |
| - " \"List 3 countries and their capitals.\",\n", |
253 |
| - " ],\n", |
254 |
| - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", |
255 |
| - " \"lora_path\": [\"lora0\", \"lora1\"],\n", |
256 |
| - " },\n", |
257 |
| - ")\n", |
258 |
| - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", |
259 |
| - "print(f\"Output from lora1: {response.json()[1]['text']}\")" |
260 |
| - ] |
261 |
| - }, |
262 |
| - { |
263 |
| - "cell_type": "code", |
264 |
| - "execution_count": null, |
| 260 | + "cell_type": "markdown", |
265 | 261 | "metadata": {},
|
266 |
| - "outputs": [], |
267 | 262 | "source": [
|
268 |
| - "response = requests.post(\n", |
269 |
| - " url + \"/unload_lora_adapter\",\n", |
270 |
| - " json={\n", |
271 |
| - " \"lora_name\": \"lora0\",\n", |
272 |
| - " },\n", |
273 |
| - ")" |
| 263 | + "Load adapter lora1:" |
274 | 264 | ]
|
275 | 265 | },
|
276 | 266 | {
|
|
282 | 272 | "response = requests.post(\n",
|
283 | 273 | " url + \"/load_lora_adapter\",\n",
|
284 | 274 | " json={\n",
|
285 |
| - " \"lora_name\": \"lora2\",\n", |
286 |
| - " \"lora_path\": \"pbevan11/llama-3.1-8b-ocr-correction\",\n", |
| 275 | + " \"lora_name\": \"lora1\",\n", |
| 276 | + " \"lora_path\": lora1,\n", |
287 | 277 | " },\n",
|
288 | 278 | ")\n",
|
289 | 279 | "\n",
|
|
294 | 284 | ]
|
295 | 285 | },
|
296 | 286 | {
|
297 |
| - "cell_type": "code", |
298 |
| - "execution_count": null, |
| 287 | + "cell_type": "markdown", |
299 | 288 | "metadata": {},
|
300 |
| - "outputs": [], |
301 | 289 | "source": [
|
302 |
| - "response = requests.post(\n", |
303 |
| - " url + \"/generate\",\n", |
304 |
| - " json={\n", |
305 |
| - " \"text\": [\n", |
306 |
| - " \"List 3 countries and their capitals.\",\n", |
307 |
| - " \"List 3 countries and their capitals.\",\n", |
308 |
| - " ],\n", |
309 |
| - " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", |
310 |
| - " \"lora_path\": [\"lora1\", \"lora2\"],\n", |
311 |
| - " },\n", |
312 |
| - ")\n", |
313 |
| - "print(f\"Output from lora1: {response.json()[0]['text']}\")\n", |
314 |
| - "print(f\"Output from lora2: {response.json()[1]['text']}\")" |
| 290 | + "Check inference output:" |
315 | 291 | ]
|
316 | 292 | },
|
317 | 293 | {
|
|
320 | 296 | "metadata": {},
|
321 | 297 | "outputs": [],
|
322 | 298 | "source": [
|
323 |
| - "terminate_process(server_process)" |
| 299 | + "url = f\"http://127.0.0.1:{port}\"\n", |
| 300 | + "json_data = {\n", |
| 301 | + " \"text\": [\n", |
| 302 | + " \"List 3 countries and their capitals.\",\n", |
| 303 | + " \"List 3 countries and their capitals.\",\n", |
| 304 | + " ],\n", |
| 305 | + " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n", |
| 306 | + " # The first input uses lora0, and the second input uses lora1\n", |
| 307 | + " \"lora_path\": [\"lora0\", \"lora1\"],\n", |
| 308 | + "}\n", |
| 309 | + "response = requests.post(\n", |
| 310 | + " url + \"/generate\",\n", |
| 311 | + " json=json_data,\n", |
| 312 | + ")\n", |
| 313 | + "print(f\"Output from lora0: \\n{response.json()[0]['text']}\\n\")\n", |
| 314 | + "print(f\"Output from lora1: \\n{response.json()[1]['text']}\\n\")" |
324 | 315 | ]
|
325 | 316 | },
|
326 | 317 | {
|
327 | 318 | "cell_type": "markdown",
|
328 | 319 | "metadata": {},
|
329 | 320 | "source": [
|
330 |
| - "### Advanced: hosting adapters of different shapes\n", |
331 |
| - "\n", |
332 |
| - "In some cases, you may want to load LoRA adapters with different ranks or target modules (e.g., `q_proj`, `k_proj`) simultaneously. To ensure the server can accommodate all expected LoRA shapes, it's recommended to explicitly specify `--max-lora-rank` and/or `--lora-target-modules` at startup.\n", |
333 |
| - "\n", |
334 |
| - "For backward compatibility, SGLang will infer these values from `--lora-paths` if they are not explicitly provided. This means it's safe to omit them **only if** all dynamically loaded adapters share the same shape (rank and target modules) as those in the initial `--lora-paths` or are strictly \"smaller\"." |
| 321 | + "Unload lora0 and replace it with a different adapter:" |
335 | 322 | ]
|
336 | 323 | },
|
337 | 324 | {
|
|
340 | 327 | "metadata": {},
|
341 | 328 | "outputs": [],
|
342 | 329 | "source": [
|
343 |
| - "lora0 = \"Nutanix/Meta-Llama-3.1-8B-Instruct_lora_4_alpha_16\" # rank - 4, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj\n", |
344 |
| - "lora1 = \"algoprog/fact-generation-llama-3.1-8b-instruct-lora\" # rank - 64, target modules - q_proj, k_proj, v_proj, o_proj, gate_proj, up_proj, down_proj\n", |
345 |
| - "\n", |
346 |
| - "\n", |
347 |
| - "# The `--target-lora-modules` param below is technically not needed, as the server will infer it from lora0 which already has all the target modules specified.\n", |
348 |
| - "# We are adding it here just to demonstrate usage.\n", |
349 |
| - "server_process, port = launch_server_cmd(\n", |
350 |
| - " f\"\"\"\n", |
351 |
| - " python3 -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", |
352 |
| - " --lora-paths lora0={lora0} \\\n", |
353 |
| - " --cuda-graph-max-bs 2 \\\n", |
354 |
| - " --max-loras-per-batch 2 --lora-backend triton \\\n", |
355 |
| - " --disable-radix-cache\n", |
356 |
| - " --max-lora-rank 64\n", |
357 |
| - " --lora-target-modules q_proj k_proj v_proj o_proj down_proj up_proj gate_proj\n", |
358 |
| - " \"\"\"\n", |
| 330 | + "response = requests.post(\n", |
| 331 | + " url + \"/unload_lora_adapter\",\n", |
| 332 | + " json={\n", |
| 333 | + " \"lora_name\": \"lora0\",\n", |
| 334 | + " },\n", |
359 | 335 | ")\n",
|
360 | 336 | "\n",
|
361 |
| - "url = f\"http://127.0.0.1:{port}\"\n", |
362 |
| - "wait_for_server(url)" |
363 |
| - ] |
364 |
| - }, |
365 |
| - { |
366 |
| - "cell_type": "code", |
367 |
| - "execution_count": null, |
368 |
| - "metadata": {}, |
369 |
| - "outputs": [], |
370 |
| - "source": [ |
371 | 337 | "response = requests.post(\n",
|
372 | 338 | " url + \"/load_lora_adapter\",\n",
|
373 | 339 | " json={\n",
|
374 |
| - " \"lora_name\": \"lora1\",\n", |
375 |
| - " \"lora_path\": lora1,\n", |
| 340 | + " \"lora_name\": \"lora0\",\n", |
| 341 | + " \"lora_path\": lora0_new,\n", |
376 | 342 | " },\n",
|
377 | 343 | ")\n",
|
378 | 344 | "\n",
|
|
382 | 348 | " print(\"Failed to load LoRA adapter.\", response.json())"
|
383 | 349 | ]
|
384 | 350 | },
|
| 351 | + { |
| 352 | + "cell_type": "markdown", |
| 353 | + "metadata": {}, |
| 354 | + "source": [ |
| 355 | + "Check output again:" |
| 356 | + ] |
| 357 | + }, |
385 | 358 | {
|
386 | 359 | "cell_type": "code",
|
387 | 360 | "execution_count": null,
|
|
392 | 365 | "json_data = {\n",
|
393 | 366 | " \"text\": [\n",
|
394 | 367 | " \"List 3 countries and their capitals.\",\n",
|
395 |
| - " \"AI is a field of computer science focused on\",\n", |
| 368 | + " \"List 3 countries and their capitals.\",\n", |
396 | 369 | " ],\n",
|
397 | 370 | " \"sampling_params\": {\"max_new_tokens\": 32, \"temperature\": 0},\n",
|
398 | 371 | " # The first input uses lora0, and the second input uses lora1\n",
|
|
402 | 375 | " url + \"/generate\",\n",
|
403 | 376 | " json=json_data,\n",
|
404 | 377 | ")\n",
|
405 |
| - "print(f\"Output from lora0: {response.json()[0]['text']}\")\n", |
406 |
| - "print(f\"Output from lora1: {response.json()[1]['text']}\")" |
| 378 | + "print(f\"Output from lora0 (updated): \\n{response.json()[0]['text']}\\n\")\n", |
| 379 | + "print(f\"Output from lora1: \\n{response.json()[1]['text']}\\n\")" |
407 | 380 | ]
|
408 | 381 | },
|
409 | 382 | {
|
|