
Enabling streaming responses from a Huggingface LLM (langchain) in a Flask API

  • xerxes01  · Tech Community  · 2 years ago

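    I am building a Flask API that streams the response of an LLM (wrapped in a langchain pipeline). I have this working with the OpenAI LLM, but it does not seem to work with Hugging Face models.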

    Any help would be appreciated!

    Code snippet from model.py:

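    import os

    # Import paths as in langchain 0.0.x, current when this was posted.
    from langchain.chains import LLMChain, RetrievalQA
    from langchain.llms import HuggingFacePipeline
    from langchain.prompts import PromptTemplate
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextStreamer, pipeline


    def load_llama2_13b():
        model_str = "meta-llama/Llama-2-13b-chat-hf"
        # Token redacted; read it from the environment (the variable name is an assumption).
        access_token = os.environ["HF_ACCESS_TOKEN"]
        tokenizer = AutoTokenizer.from_pretrained(model_str, use_auth_token=access_token)
    
        model = AutoModelForCausalLM.from_pretrained(
            model_str,
            device_map="auto",
            # quantization_config=bnb_config,
            trust_remote_code=True,
            use_auth_token=access_token)
    
        # TextStreamer only prints decoded tokens to stdout; it does not hand them to the caller.
        streamer = TextStreamer(tokenizer)
        llm_pipeline = pipeline(
            "text-generation",  # task
            model=model,  # already loaded and placed on devices above
            tokenizer=tokenizer,
            do_sample=True,
            max_new_tokens=300,
            streamer=streamer,
            eos_token_id=tokenizer.eos_token_id,
            # Generation parameters are passed directly; model_kwargs is only used when loading a model.
            temperature=0.01,
            repetition_penalty=2.5,
        )
    
        llm = HuggingFacePipeline(pipeline=llm_pipeline)
        return llm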
    
    def query_llm(llm, query, task, prompt_template=None, vectordb=None):
        """Run `query` through `llm` and return the complete result as a dict."""
        valid_tasks = ["instructive", "answer"]
    
        if task not in valid_tasks:
            raise ValueError(f"Invalid task '{task}'. Allowed values are {valid_tasks}")
    
        if task == "answer":
            # Retrieval-augmented answer: needs a vector store to fetch context from.
            if vectordb is None:
                raise ValueError("task 'answer' requires a vectordb")
            if prompt_template:
                template = prompt_template
            else:
                template = """
                        You are an intelligent chatbot. Given the context below, answer the question given at the end:
                        Context: {context}
                        QUESTION: {question}
                        Answer:"""
    
            PROMPT = PromptTemplate(template=template, input_variables=["context", "question"])
            llm_chain = RetrievalQA.from_chain_type(llm,
                                                    chain_type="stuff",
                                                    retriever=vectordb.as_retriever(),
                                                    return_source_documents=True,
                                                    chain_type_kwargs={"prompt": PROMPT})
            res = llm_chain({'query': query})
            source_docs = [t.__dict__ for t in res["source_documents"]]
    
            return {"result": res["result"], "source_documents": source_docs}
    
        # task == "instructive": plain question answering without retrieval.
        template = """
                    You are an intelligent chatbot. Answer the question posed by user.
                    Question: {question}
                    Answer:"""
        PROMPT = PromptTemplate(template=template, input_variables=["question"])
    
        llm_chain = LLMChain(llm=llm, prompt=PROMPT, verbose=True)
        # llm_chain.run blocks until generation has finished, so nothing is streamed from here.
        return {"result": llm_chain.run(question=query)}
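
    One likely reason nothing reaches the client is that TextStreamer prints tokens to stdout and query_llm returns only after generation has finished. A common workaround is to run generation in a worker thread and push text into a queue that the caller iterates; as far as I can tell, transformers' TextIteratorStreamer is built around the same idea. The sketch below shows only that mechanism with the standard library. QueueStreamer, fake_generate and query_llm_stream are hypothetical names, and fake_generate merely stands in for the real blocking model call.

```python
import queue
import threading

TEMPLATE = (
    "You are an intelligent chatbot. Answer the question posed by user.\n"
    "Question: {question}\n"
    "Answer:"
)


class QueueStreamer:
    """Collects text pushed by a producer and exposes it as an iterator."""

    _END = object()  # sentinel marking the end of generation

    def __init__(self, timeout=None):
        self._queue = queue.Queue()
        self._timeout = timeout

    def put(self, text):
        # Called by the producer for every new piece of text.
        self._queue.put(text)

    def end(self):
        # Called by the producer once generation is finished.
        self._queue.put(self._END)

    def __iter__(self):
        return self

    def __next__(self):
        # Blocks until the producer delivers something (or the timeout expires).
        item = self._queue.get(timeout=self._timeout)
        if item is self._END:
            raise StopIteration
        return item


def fake_generate(prompt, streamer):
    """Hypothetical stand-in for a blocking model call; echoes the prompt word by word."""
    for word in prompt.split():
        streamer.put(word + " ")
    streamer.end()


def query_llm_stream(query, generate=fake_generate):
    """Generator counterpart of query_llm for the "instructive" task."""
    prompt = TEMPLATE.format(question=query)
    streamer = QueueStreamer(timeout=5)
    # Generation blocks, so it runs in a worker thread while the caller iterates.
    worker = threading.Thread(target=generate, args=(prompt, streamer))
    worker.start()
    yield from streamer
    worker.join()
```

    With a real model, the worker thread would call model.generate with a TextIteratorStreamer in place of fake_generate; that wiring is an assumption and is not exercised by this sketch.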
    

    This is how it is consumed in api.py:

    import json
    import time

    import flask
    from flask import Response, stream_with_context

    # app, logger, vectordb, falcon_7b_llm and query_llm are defined elsewhere (not shown in the post).


    @app.route('/llmgeneration_stream', methods=['POST'])
    def llm_generation_stream():
        json_data = {"status": 0, "message": "Failed"}
        start = time.time()
    
        data = json.loads(flask.request.data)
    
        model = data.get("model", "falcon-7b")
        query = data.get("query")
        task = data.get("task", "instructive")
        prompt_template = data.get("prompt_template", None)
    
        valid_models = ["falcon-7b", "openai", "falcon-40b", "llama2-13b", "llama2-7b", "llama2-70b"]
        if model not in valid_models:
            raise ValueError(f"Invalid model '{model}'. Allowed values are {valid_models}")
    
        try:
            if model == "falcon-7b":
                llm = falcon_7b_llm
            else:
                # Branches for the other models are not shown in the post.
                raise NotImplementedError(f"No LLM loaded for model '{model}'")
    
            # query_llm blocks until generation finishes and returns a dict, so iterating
            # the result yields only its keys; no tokens are streamed from here.
            result = query_llm(llm, query, task, prompt_template, vectordb=vectordb)
            logger.info("llm API: Time: %s Model: %s Query: %s text_generated: %s",
                        time.time() - start, model, query, result)
            return Response(stream_with_context(result), mimetype='application/json')
    
        except Exception as e:
            logger.error("llm API failed: %s (model=%s, query=%s)", e, model, query)
        return flask.jsonify(json_data)
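
    As written, the endpoint hands Flask a dict, and iterating a dict yields its keys, so the body would contain 'result' rather than generated text. Streaming needs a generator that yields serialized chunks as they arrive. A minimal sketch, assuming newline-delimited JSON as the wire format (the post does not specify one); ndjson_chunks and read_ndjson are hypothetical helpers:

```python
import json

# Iterating a dict yields its keys, which is all a response body built
# directly from query_llm's return value would contain.
dict_body = list({"result": "full answer", "source_documents": []})


def ndjson_chunks(chunks):
    """Serialize each text chunk as one JSON object per line, then a final marker."""
    for chunk in chunks:
        # json.dumps escapes newlines inside a chunk, so one line is always one object.
        yield json.dumps({"text": chunk}) + "\n"
    yield json.dumps({"done": True}) + "\n"


def read_ndjson(lines):
    """Client-side helper: rebuild the full text from the streamed lines."""
    parts = []
    for line in lines:
        payload = json.loads(line)
        if payload.get("done"):
            break
        parts.append(payload["text"])
    return "".join(parts)
```

    In the endpoint, the generator returned by ndjson_chunks(...) would be what is passed to Response(stream_with_context(...)), with a mimetype such as application/x-ndjson instead of application/json (the mimetype choice is an assumption).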
    
    
    