# Version: 0.8.x
#
# MultiLoRA 推理 (MultiLoRA inference)
#
# 源代码 (source code): vllm-project/vllm

"""
This example shows how to use the multi-LoRA functionality
for offline inference.

Requires HuggingFace credentials for access to Llama2.


该示例展示了如何在离线推理中使用多LoRA功能。
访问 Llama2 需要 HuggingFace 凭证。

"""

from typing import List, Optional, Tuple

from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
    lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    """Build the demo prompts together with their sampling parameters.

    The returned list holds 2 requests for the base model followed by 4
    requests that use a LoRA adapter. Two LoRA adapters are defined
    ("sql-lora" with id 1 and "sql-lora2" with id 2); both are loaded from
    the same ``lora_path`` for demo purposes. Since `max_loras=1` is also
    set, the expectation is that the requests with the second LoRA adapter
    will be run after all requests with the first adapter have finished.

    Args:
        lora_path: Path of the LoRA adapter handed to every `LoRARequest`.

    Returns:
        A list of ``(prompt, sampling_params, lora_request)`` tuples, where
        ``lora_request`` is ``None`` for the base-model requests.
    """
    icao_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]"  # noqa: E501
    elector_prompt = "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]"  # noqa: E501

    def greedy_sql_params() -> SamplingParams:
        # A fresh object per request, so no two requests share parameters.
        return SamplingParams(temperature=0.0,
                              logprobs=1,
                              prompt_logprobs=1,
                              max_tokens=128,
                              stop_token_ids=[32003])

    def beam_sql_params() -> SamplingParams:
        # NOTE(review): `use_beam_search` / `best_of` may not be accepted by
        # `SamplingParams` in every vLLM release -- confirm against the
        # installed version.
        return SamplingParams(n=3,
                              best_of=3,
                              use_beam_search=True,
                              temperature=0,
                              max_tokens=128,
                              stop_token_ids=[32003])

    # Requests served by the base model (no adapter attached).
    base_requests = [
        (
            "A robot may not injure a human being",
            SamplingParams(temperature=0.0,
                           logprobs=1,
                           prompt_logprobs=1,
                           max_tokens=128),
            None,
        ),
        (
            "To be or not to be,",
            SamplingParams(temperature=0.8,
                           top_k=5,
                           presence_penalty=0.2,
                           max_tokens=128),
            None,
        ),
    ]

    # Requests served through a LoRA adapter; the third one switches to the
    # second adapter identity and the fourth goes back to the first.
    lora_requests = [
        (icao_prompt, greedy_sql_params(),
         LoRARequest("sql-lora", 1, lora_path)),
        (elector_prompt, beam_sql_params(),
         LoRARequest("sql-lora", 1, lora_path)),
        (icao_prompt, greedy_sql_params(),
         LoRARequest("sql-lora2", 2, lora_path)),
        (elector_prompt, beam_sql_params(),
         LoRARequest("sql-lora", 1, lora_path)),
    ]

    return base_requests + lora_requests


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs.

    At most one pending prompt is submitted per loop iteration, followed by a
    single engine step. Every output reported as finished is printed.

    Args:
        engine: Engine that receives the requests and is stepped.
        test_prompts: ``(prompt, sampling_params, lora_request)`` tuples.
            The list is consumed in place from the front, so it is empty
            when this function returns.
    """
    submitted = 0

    while True:
        # Stop once nothing is left to submit and the engine has drained.
        if not test_prompts and not engine.has_unfinished_requests():
            break

        if test_prompts:
            text, params, adapter = test_prompts.pop(0)
            # The running submission count doubles as the request id.
            engine.add_request(str(submitted),
                               text,
                               params,
                               lora_request=adapter)
            submitted += 1

        for result in engine.step():
            if not result.finished:
                continue
            print(result)


def initialize_engine() -> LLMEngine:
    """Initialize the LLMEngine with LoRA support enabled.

    Returns:
        The engine built from the `EngineArgs` assembled below.
    """
    lora_settings = dict(
        enable_lora=True,
        # max_loras: controls the number of LoRAs that can be used in the
        # same batch. Larger numbers will cause higher memory usage, as each
        # LoRA slot requires its own preallocated tensor.
        max_loras=1,
        # max_lora_rank: controls the maximum supported rank of all LoRAs.
        # Larger numbers will cause higher memory usage. If you know that
        # all LoRAs will use the same rank, it is recommended to set this as
        # low as possible.
        max_lora_rank=8,
        # max_cpu_loras: controls the size of the CPU LoRA cache.
        max_cpu_loras=2,
    )
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
                             max_num_seqs=256,
                             **lora_settings)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""
    # The engine is created before the adapter download, as in the original
    # ordering of side effects.
    engine = initialize_engine()
    adapter_dir = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
    process_requests(engine, create_test_prompts(adapter_dir))


# Run the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()