移动网站虚拟主机,简洁中文网站模板,女性健康网站源码,网站建设及管理基本要求背景
LlamaFactory 的 LoRA 微调功能非常便捷#xff0c;微调后的模型#xff0c;没有直接支持 vllm 推理#xff0c;故导致推理速度不够快。
LlamaFactory 目前支持通过 VLLM API 进行部署#xff0c;调用 API 时的响应速度#xff0c;仍然没有vllm批量推理的速度快。 …背景
LlamaFactory 的 LoRA 微调功能非常便捷微调后的模型没有直接支持 vllm 推理故导致推理速度不够快。
LlamaFactory 目前支持通过 VLLM API 进行部署调用 API 时的响应速度仍然没有vllm批量推理的速度快。
如果模型是通过 LlamaFactory 微调的为了确保数据集的一致性建议在推理时也使用 LlamaFactory 提供的封装数据集。
简介
在上述的背景下我们使用 LlamaFactory 原生数据集支持 lora的 vllm 批量推理。 完整代码如下
import json
import os
from typing import Listfrom vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequestfrom llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizerdef vllm_infer():model_args, data_args, training_args, finetuning_args, generating_args (get_train_args())tokenizer load_tokenizer(model_args)[tokenizer]template get_template_and_fix_tokenizer(tokenizer, data_args)eval_dataset get_dataset(template, model_args, data_args, training_args, finetuning_args.stage, tokenizer)[eval_dataset]prompts [item[input_ids] for item in eval_dataset]prompts tokenizer.batch_decode(prompts, skip_special_tokensFalse)labels [list(filter(lambda x: x ! IGNORE_INDEX, item[labels]))for item in eval_dataset]labels tokenizer.batch_decode(labels, skip_special_tokensTrue)sampling_params SamplingParams(temperaturegenerating_args.temperature,top_kgenerating_args.top_k,top_pgenerating_args.top_p,max_tokens2048,)if model_args.adapter_name_or_path:if isinstance(model_args.adapter_name_or_path, list):lora_requests []for i, _lora_path in enumerate(model_args.adapter_name_or_path):lora_requests.append(LoRARequest(flora_adapter_{i}, i, lora_path_lora_path))else:lora_requests LoRARequest(lora_adapter_0, 0, lora_pathmodel_args.adapter_name_or_path)enable_lora Trueelse:lora_requests Noneenable_lora Falsellm LLM(modelmodel_args.model_name_or_path,trust_remote_codeTrue,tokenizermodel_args.model_name_or_path,enable_loraenable_lora,)outputs llm.generate(prompts, sampling_params, lora_requestlora_requests)if not os.path.exists(training_args.output_dir):os.makedirs(training_args.output_dir, exist_okTrue)output_prediction_file os.path.join(training_args.output_dir, generated_predictions.jsonl)with open(output_prediction_file, w, encodingutf-8) as writer:res: List[str] []for text, pred, label in zip(prompts, outputs, labels):res.append(json.dumps({prompt: text, predict: pred.outputs[0].text, label: label},ensure_asciiFalse,))writer.write(\n.join(res))vllm.yaml 示例:
## model
model_name_or_path: qwen/Qwen2.5-7B-Instruct
# adapter_name_or_path: lora模型### method
stage: sft
do_predict: true
finetuning_type: lora### dataset
dataset_dir: 数据集路径
eval_dataset: 数据集
template: qwen
cutoff_len: 1024
max_samples: 1000
overwrite_cache: true
preprocessing_num_workers: 16### output
output_dir: output/
overwrite_output_dir: true### eval
predict_with_generate: true程序调用:
python vllm_infer.py vllm.yaml程序运行速度
Processed prompts: 100%|█| 1000/1000 [01:5600:00, 8.60it/s, est. speed input: 5169.35 toks/s, output: 811.57总结
本方案在原生 LlamaFactory 数据集的基础上支持 LoRA 的 vllm 批量推理能提升了推理效率。
进一步阅读
如果微调模型后发现使用vllm模型批量效果不太好可以参考下述文章
基于 LLamafactory 的异步API高效调用实现与速度对比.https://blog.csdn.net/sjxgghg/article/details/144176645
亲测LLamafactory 部署 模型然后使用 Async API 调用后评估效果会好一些。