RAG开源项目Qanything源码阅读3-在线推理

原文:前沿重器[47] | RAG开源项目Qanything源码阅读3-在线推理
项目:https://github.com/netease-youdao/QAnything
第一篇:RAG开源项目Qanything源码阅读1-概述+服务
第二篇:RAG开源项目Qanything源码阅读2-离线文件处理


0,推理大概流程

  • 检索&粗排
  • 精排
  • 检索文档后处理
  • prompt和请求大模型

1,外部服务

回顾一下在“前沿重器[45] RAG开源项目Qanything源码阅读1-概述+服务”中提到的服务核心文件,所有的接口都是在qanything_kernel\qanything_server\sanic_api.py里面启动的:

app.add_route(document, "/api/docs", methods=['GET'])
app.add_route(new_knowledge_base, "/api/local_doc_qa/new_knowledge_base", methods=['POST'])  # tags=["新建知识库"]
app.add_route(upload_weblink, "/api/local_doc_qa/upload_weblink", methods=['POST'])  # tags=["上传网页链接"]
app.add_route(upload_files, "/api/local_doc_qa/upload_files", methods=['POST'])  # tags=["上传文件"] 
app.add_route(local_doc_chat, "/api/local_doc_qa/local_doc_chat", methods=['POST'])  # tags=["问答接口"] 
app.add_route(list_kbs, "/api/local_doc_qa/list_knowledge_base", methods=['POST'])  # tags=["知识库列表"] 
app.add_route(list_docs, "/api/local_doc_qa/list_files", methods=['POST'])  # tags=["文件列表"]
app.add_route(get_total_status, "/api/local_doc_qa/get_total_status", methods=['POST'])  # tags=["获取所有知识库状态"]
app.add_route(clean_files_by_status, "/api/local_doc_qa/clean_files_by_status", methods=['POST'])  # tags=["清理数据库"]
app.add_route(delete_docs, "/api/local_doc_qa/delete_files", methods=['POST'])  # tags=["删除文件"] 
app.add_route(delete_knowledge_base, "/api/local_doc_qa/delete_knowledge_base", methods=['POST'])  # tags=["删除知识库"] 
app.add_route(rename_knowledge_base, "/api/local_doc_qa/rename_knowledge_base", methods=['POST'])  # tags=["重命名知识库"]

而推理,就是这里的local_doc_chat,直接看这个函数,就在qanything_kernel\qanything_server\handler.py里面。

async def local_doc_chat(req: request):
    local_doc_qa: LocalDocQA = req.app.ctx.local_doc_qa
    user_id = safe_get(req, 'user_id')
    if user_id is None:
        return sanic_json({"code": 2002, "msg": f'输入非法!request.json:{req.json},请检查!'})
    is_valid = validate_user_id(user_id)
    if not is_valid:
        return sanic_json({"code": 2005, "msg": get_invalid_user_id_msg(user_id=user_id)})
    debug_logger.info('local_doc_chat %s', user_id)
    kb_ids = safe_get(req, 'kb_ids')
    question = safe_get(req, 'question')
    rerank = safe_get(req, 'rerank', default=True)
    debug_logger.info('rerank %s', rerank)
    streaming = safe_get(req, 'streaming', False)
    history = safe_get(req, 'history', [])
    debug_logger.info("history: %s ", history)
    debug_logger.info("question: %s", question)
    debug_logger.info("kb_ids: %s", kb_ids)
    debug_logger.info("user_id: %s", user_id)
 
    not_exist_kb_ids = local_doc_qa.milvus_summary.check_kb_exist(user_id, kb_ids)
    if not_exist_kb_ids:
        return sanic_json({"code": 2003, "msg": "fail, knowledge Base {} not found".format(not_exist_kb_ids)})
 
    file_infos = []
    milvus_kb = local_doc_qa.match_milvus_kb(user_id, kb_ids)
    for kb_id in kb_ids:
        file_infos.extend(local_doc_qa.milvus_summary.get_files(user_id, kb_id))
    valid_files = [fi for fi in file_infos if fi[2] == 'green']
    if len(valid_files) == 0:
        return sanic_json({"code": 200, "msg": "当前知识库为空,请上传文件或等待文件解析完毕", "question": question,
                           "response": "All knowledge bases {} are empty or haven't green file, please upload files".format(
                               kb_ids), "history": history, "source_documents": [{}]})
    else:
        debug_logger.info("streaming: %s", streaming)
        if streaming:
            debug_logger.info("start generate answer")
 
            async def generate_answer(response):
                debug_logger.info("start generate...")
                for resp, next_history in local_doc_qa.get_knowledge_based_answer(
                        query=question, milvus_kb=milvus_kb, chat_history=history, streaming=True, rerank=rerank
                ):
                    chunk_data = resp["result"]
                    if not chunk_data:
                        continue
                    chunk_str = chunk_data[6:]
                    if chunk_str.startswith("[DONE]"):
                        source_documents = []
                        for inum, doc in enumerate(resp["source_documents"]):
                            source_info = {'file_id': doc.metadata['file_id'],
                                           'file_name': doc.metadata['file_name'],
                                           'content': doc.page_content,
                                           'retrieval_query': doc.metadata['retrieval_query'],
                                           'score': str(doc.metadata['score'])}
                            source_documents.append(source_info)
 
                        retrieval_documents = format_source_documents(resp["retrieval_documents"])
                        source_documents = format_source_documents(resp["source_documents"])
                        chat_data = {'user_info': user_id, 'kb_ids': kb_ids, 'query': question, 'history': history,
                                     'prompt': resp['prompt'], 'result': next_history[-1][1],
                                     'retrieval_documents': retrieval_documents, 'source_documents': source_documents}
                        qa_logger.info("chat_data: %s", chat_data)
                        debug_logger.info("response: %s", chat_data['result'])
                        stream_res = {
                            "code": 200,
                            "msg": "success",
                            "question": question,
                            # "response":next_history[-1][1],
                            "response": "",
                            "history": next_history,
                            "source_documents": source_documents,
                        }
                    else:
                        chunk_js = json.loads(chunk_str)
                        delta_answer = chunk_js["answer"]
                        stream_res = {
                            "code": 200,
                            "msg": "success",
                            "question": "",
                            "response": delta_answer,
                            "history": [],
                            "source_documents": [],
                        }
                    await response.write(f"data: {json.dumps(stream_res, ensure_ascii=False)}\n\n")
                    if chunk_str.startswith("[DONE]"):
                        await response.eof()
                    await asyncio.sleep(0.001)
 
            response_stream = ResponseStream(generate_answer, content_type='text/event-stream')
            return response_stream
 
        else:
            for resp, history in local_doc_qa.get_knowledge_based_answer(
                    query=question, milvus_kb=milvus_kb, chat_history=history, streaming=False, rerank=rerank
            ):
                pass
            retrieval_documents = format_source_documents(resp["retrieval_documents"])
            source_documents = format_source_documents(resp["source_documents"])
            chat_data = {'user_id': user_id, 'kb_ids': kb_ids, 'query': question, 'history': history,
                         'retrieval_documents': retrieval_documents, 'prompt': resp['prompt'], 'result': resp['result'],
                         '`': source_documents}
            qa_logger.info("chat_data: %s", chat_data)
            debug_logger.info("response: %s", chat_data['result'])
            return sanic_json({"code": 200, "msg": "success chat", "question": question, "response": resp["result"],
                               "history": history, "source_documents": source_documents})

上面代码的重点内容:

  • 首先因为是正式项目,在鉴权、数据库检测上都做了很多健壮性的处理,例如,对user_id的判别、对数据库及其对应用户的权限判别check_kb_exist,再者还有知识库的判空等。
  • 此处有区分是否使用了流式streaming
  • 最终结果的输出有进行结构化,结构化这事的业务代码专门弄了个函数format_source_documents
  • 这里区分了retrieval_documentssource_documents,两者有所区别,在后面展开聊关键算法流程的时候会展开讲。
  • get_knowledge_based_answer是内部获取知识点并进行生成的关键函数,就是上一条所说的关键算法流程。
# qanything_kernel\utils\general_utils.py 

def format_source_documents(ori_source_documents):
    source_documents = []
    for inum, doc in enumerate(ori_source_documents):
        # for inum, doc in enumerate(answer_source_documents):
        # doc_source = doc.metadata['source']
        file_id = doc.metadata['file_id']
        file_name = doc.metadata['file_name']
        # source_str = doc_source if isURL(doc_source) else os.path.split(doc_source)[-1]
        source_info = {'file_id': doc.metadata['file_id'],
                       'file_name': doc.metadata['file_name'],
                       'content': doc.page_content,
                       'retrieval_query': doc.metadata['retrieval_query'],
                       'kernel': doc.metadata['kernel'],
                       'score': str(doc.metadata['score']),
                       'embed_version': doc.metadata['embed_version']}
        source_documents.append(source_info)
    return source_documents

2,RAG推理流程

get_knowledge_based_answer的函数很简单,不过单独拿出来,对可读性是有挺大帮助的。
RAG说白了就是先搜后交给大模型生成,终于讲到这段代码了,流程在这里qanything_kernel\core\local_doc_qa.py

   
# qanything_kernel\core\local_doc_qa.py
@get_time
    def get_knowledge_based_answer(self, query, milvus_kb, chat_history=None, streaming: bool = STREAMING,
                                   rerank: bool = False):
        if chat_history is None:
            chat_history = []
        retrieval_queries = [query]

        source_documents = self.get_source_documents(retrieval_queries, milvus_kb)

        deduplicated_docs = self.deduplicate_documents(source_documents)
        retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)
        if rerank and len(retrieval_documents) > 1:
            debug_logger.info(f"use rerank, rerank docs num: {len(retrieval_documents)}")
            retrieval_documents = self.rerank_documents(query, retrieval_documents)

        source_documents = self.reprocess_source_documents(query=query,
                                                           source_docs=retrieval_documents,
                                                           history=chat_history,
                                                           prompt_template=PROMPT_TEMPLATE)
        prompt = self.generate_prompt(query=query,
                                      source_docs=source_documents,
                                      prompt_template=PROMPT_TEMPLATE)
        t1 = time.time()
        for answer_result in self.llm.generatorAnswer(prompt=prompt,
                                                      history=chat_history,
                                                      streaming=streaming):
            resp = answer_result.llm_output["answer"]
            prompt = answer_result.prompt
            history = answer_result.history

            # logging.info(f"[debug] get_knowledge_based_answer history = {history}")
            history[-1][0] = query
            response = {"query": query,
                        "prompt": prompt,
                        "result": resp,
                        "retrieval_documents": retrieval_documents,
                        "source_documents": source_documents}
            yield response, history
        t2 = time.time()
        debug_logger.info(f"LLM time: {t2 - t1}")

首先注意到这里有个装饰器@get_time。可以用来记录执行的时间。

def get_time(func):
    def inner(*arg, **kwargs):
        s_time = time.time()
        res = func(*arg, **kwargs)
        e_time = time.time()
        print('函数 {} 执行耗时: {} 秒'.format(func.__name__, e_time - s_time))
        return res

    return inner

2.1 检索&粗排

get_source_documents是检索的过程,即给定了retrieval_queriesmilvus_kb,即query所需要查的数据库,开始进行查询。这个的返回结果,会放在retrieval_documents里面,即**“检索到的文档”**,下面是源码。

def get_source_documents(self, queries, milvus_kb, cosine_thresh=None, top_k=None):
    milvus_kb: MilvusClient
    if not top_k:
        top_k = self.top_k
    source_documents = []
    embs = self.embeddings._get_len_safe_embeddings(queries)
    t1 = time.time()
    batch_result = milvus_kb.search_emb_async(embs=embs, top_k=top_k, queries=queries)
    t2 = time.time()
    debug_logger.info(f"milvus search time: {t2 - t1}")
    for query, query_docs in zip(queries, batch_result):
        for doc in query_docs:
            doc.metadata['retrieval_query'] = query  # 添加查询到文档的元数据中
            doc.metadata['embed_version'] = self.embeddings.embed_version
            source_documents.append(doc)
    if cosine_thresh:
        source_documents = [item for item in source_documents if float(item.metadata['score']) > cosine_thresh]
 
    return source_documents
  • _get_len_safe_embeddings给定query获取向量。在上一期RAG开源项目Qanything源码阅读2-离线文件处理有讲过,这个内部是请求一个向量模型的服务,背后的模型是需要和离线文件处理那个模型一致,所以部署同一个就会比较稳当,当然的,接口也是triton,一个grpc接口,有关GRPC,上次忘了放链接,这次放这里心法利器[6] | python grpc实践,非常建议大家详细了解并且学会。

  • search_emb_async是用于做向量检索的。这个就是pymilvus的核心功能了。

  • 此处,查询出来还要过一个阈值卡控,对相似度达不到阈值的文档,需要过滤,阈值设置在cosine_thresh

_get_len_safe_embeddings 使用的embedding 代码(可跳过,继续回到 get_knowledge_based_answer
# qanything_kernel\connector\embedding\embedding_for_local.py
"""Wrapper around YouDao embedding models."""
from typing import List

from qanything_kernel.connector.embedding.embedding_client import EmbeddingClient
from qanything_kernel.configs.model_config import LOCAL_EMBED_SERVICE_URL, LOCAL_EMBED_MODEL_NAME, LOCAL_EMBED_MAX_LENGTH, LOCAL_EMBED_BATCH
from qanything_kernel.utils.custom_log import debug_logger
import concurrent.futures
from tqdm import tqdm 

embedding_client = EmbeddingClient(
    server_url=LOCAL_EMBED_SERVICE_URL,
    model_name=LOCAL_EMBED_MODEL_NAME,
    model_version='1',
    resp_wait_s=120,
    tokenizer_path='qanything_kernel/connector/embedding/embedding_model_0630')


class YouDaoLocalEmbeddings:
    def __init__(self):
        pass

    def _get_embedding(self, queries):
        embeddings = embedding_client.get_embedding(queries, max_length=LOCAL_EMBED_MAX_LENGTH)
        return embeddings

    def _get_len_safe_embeddings(self, texts: List[str]) -> List[List[float]]:
        all_embeddings = []
        batch_size = LOCAL_EMBED_BATCH

        with concurrent.futures.ThreadPoolExecutor() as executor:
            futures = []
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i + batch_size]
                future = executor.submit(self._get_embedding, batch)
                futures.append(future)
            debug_logger.info(f'embedding number: {len(futures)}')
            for future in tqdm(futures):
                embeddings = future.result()
                all_embeddings += embeddings
        return all_embeddings

    @property
    def embed_version(self):
        return embedding_client.getModelVersion()

回到 get_knowledge_based_answer
留意到 qanything_kernel\core\local_doc_qa.py 文件里的 get_knowledge_based_answer 里这一串代码:

retrieval_documents = sorted(deduplicated_docs, key=lambda x: x.metadata['score'], reverse=True)
if rerank and len(retrieval_documents) > 1:
    debug_logger.info(f"use rerank, rerank docs num: {len(retrieval_documents)}")
    retrieval_documents = self.rerank_documents(query, retrieval_documents)
  • 此处注意,这里的检索还涉及一个过程“粗排”(上面第一行代码),这个粗排是指查询数据库的时候,需要根据相似度进行排序,只取TOPN,毕竟如果不进行这个TOP的卡控,那数据库里所有的数据都会被查出来,这没什么意义了。这里之所以叫粗排,是因为这种相似度的对比是比较粗略的,只能过滤掉“肯定不是”的那些无关结果。具体“哪个好”,用额外的、更精准的模型来做会更好,达到“优中取优”的目的。

2.2 检索&粗排

继续关注这里的 qanything_kernel\core\local_doc_qa.py 的 get_knowledge_based_answer里调用的 rerank_documents,这个就是精排,或者像这里说的重排。

def rerank_documents(self, query, source_documents):
    return self.rerank_documents_for_local(query, source_documents)
 
def rerank_documents_for_local(self, query, source_documents):
    if len(query) > 300:  # tokens数量超过300时不使用local rerank
        return source_documents
 
    source_documents_reranked = []
    try:
        response = requests.post(f"{self.local_rerank_service_url}/rerank",
                                    json={"passages": [doc.page_content for doc in source_documents], "query": query})
        scores = response.json()
        for idx, score in enumerate(scores):
            source_documents[idx].metadata['score'] = score
            if score < 0.35 and len(source_documents_reranked) > 0:
                continue
            source_documents_reranked.append(source_documents[idx])
 
        source_documents_reranked = sorted(source_documents_reranked, key=lambda x: x.metadata['score'], reverse=True)
    except Exception as e:
        debug_logger.error("rerank error: %s", traceback.format_exc())
        debug_logger.warning("rerank error, use origin retrieval docs")
        source_documents_reranked = sorted(source_documents, key=lambda x: x.metadata['score'], reverse=True)
 
    return source_documents_reranked

简单地,这里就是把所有召回回来的文章请求到重排服务来算分,根据算分来进行过滤和排序,筛选出最优的文章。和向量模型类似,一样是用triton部署的,看模型名像是QAEnsemble_embed_rerank

2.3 检索文档后处理

更进一步,需要对文档进行后处理,即reprocess_source_documents函数。qanything_kernel\core\local_doc_qa.py

#source_documents = self.reprocess_source_documents(query=query,
#                                                           source_docs=retrieval_documents,
#                                                           history=chat_history,
#                                                           prompt_template=PROMPT_TEMPLATE)

def reprocess_source_documents(self, query: str,
                                source_docs: List[Document],
                                history: List[str],
                                prompt_template: str) -> List[Document]:
    # 组装prompt,根据max_token
    query_token_num = self.llm.num_tokens_from_messages([query])
    history_token_num = self.llm.num_tokens_from_messages([x for sublist in history for x in sublist])
    template_token_num = self.llm.num_tokens_from_messages([prompt_template])
 
    # logging.info(f"<self.llm.token_window, self.llm.max_token, self.llm.offcut_token, query_token_num, history_token_num, template_token_num>, types = {type(self.llm.token_window), type(self.llm.max_token), type(self.llm.offcut_token), type(query_token_num), type(history_token_num), type(template_token_num)}, values = {query_token_num, history_token_num, template_token_num}")
    limited_token_nums = self.llm.token_window - self.llm.max_token - self.llm.offcut_token - query_token_num - history_token_num - template_token_num
    new_source_docs = []
    total_token_num = 0
    for doc in source_docs:
        doc_token_num = self.llm.num_tokens_from_docs([doc])
        if total_token_num + doc_token_num <= limited_token_nums:
            new_source_docs.append(doc)
            total_token_num += doc_token_num
        else:
            remaining_token_num = limited_token_nums - total_token_num
            doc_content = doc.page_content
            doc_content_token_num = self.llm.num_tokens_from_messages([doc_content])
            while doc_content_token_num > remaining_token_num:
                # Truncate the doc content to fit the remaining tokens
                if len(doc_content) > 2 * self.llm.truncate_len:
                    doc_content = doc_content[self.llm.truncate_len: -self.llm.truncate_len]
                else:  # 如果最后不够truncate_len长度的2倍,说明不够切了,直接赋值为空
                    doc_content = ""
                    break
                doc_content_token_num = self.llm.num_tokens_from_messages([doc_content])
            doc.page_content = doc_content
            new_source_docs.append(doc)
            break
 
    debug_logger.info(f"limited token nums: {limited_token_nums}")
    debug_logger.info(f"template token nums: {template_token_num}")
    debug_logger.info(f"query token nums: {query_token_num}")
    debug_logger.info(f"history token nums: {history_token_num}")
    debug_logger.info(f"new_source_docs token nums: {self.llm.num_tokens_from_docs(new_source_docs)}")
    return new_source_docs
  • 这里的llm,是一个自己封装好的大模型工具,具体是在qanything_kernel\connector\llm\llm_for_fastchat.py这个位置。里面支持计算token请求大模型等通用功能。这个工具可以结合自己场景的需求搬过去直接使用。

  • 计算limited_token_nums主要是方便组装prompt,避免某些文字被吃掉

  • 这里是需要对文档进行新的拼接和调整,如果查询的文档太多太长,则需要截断,且截断的时候需要注意,要保证截断的位置必须是完整地句子,如果不够长直接不切了。

2.4 prompt和请求大模型

然后就是开始生成promptgenerate_prompt。说白了就是一个简单的拼接。另外,这里的prompt拼接,更多使用replace来完成,之前有看过别的模式,例如用字符串的format应该也可以,不过replace的适用范围会更广一些。

def generate_prompt(self, query, source_docs, prompt_template):
    context = "\n".join([doc.page_content for doc in source_docs])
    prompt = prompt_template.replace("{question}", query).replace("{context}", context)
    return prompt

顺带就看看他们的prompt吧,实际上并不复杂。

PROMPT_TEMPLATE = """参考信息:
{context}
---
我的问题或指令:
{question}
---
请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复,
你的回复:"""

最后一步就是开始请求大模型了。即generatorAnswer函数。

def generatorAnswer(self, prompt: str,
                    history: List[List[str]] = [],
                    streaming: bool = False) -> AnswerResult:
 
    if history is None or len(history) == 0:
        history = [[]]
    logging.info(f"history_len: {self.history_len}")
    logging.info(f"prompt: {prompt}")
    logging.info(f"prompt tokens: {self.num_tokens_from_messages([{'content': prompt}])}")
    logging.info(f"streaming: {streaming}")
            
    response = self._call(prompt, history[:-1], streaming)
    complete_answer = ""
    for response_text in response:
 
        if response_text:
            chunk_str = response_text[6:]
            if not chunk_str.startswith("[DONE]"):
                chunk_js = json.loads(chunk_str)
                complete_answer += chunk_js["answer"]
                
        history[-1] = [prompt, complete_answer]
        answer_result = AnswerResult()
        answer_result.history = history
        answer_result.llm_output = {"answer": response_text}
        answer_result.prompt = prompt
        yield answer_result

这里就是请求大模型的基本话术了,相对还是比较简单的,一方面是请求大模型,另一方面是解析大模型内的结果。有留意到,这里有对内容做一些校验:

if response_text:
    chunk_str = response_text[6:]
    if not chunk_str.startswith("[DONE]"):
        chunk_js = json.loads(chunk_str)
        complete_answer += chunk_js["answer"]

可以看出应该是有一些泛用性处理,能解决更多复杂的问题吧。

小结

本文把QAnything项目内的重要的推理部分穿讲了一遍,可以看出这个项目已经非常完成,基本具备上线所需的关键部分,同时也有很严格的校验逻辑,严格程度很高也比较稳定,经过这个学习,自己对工程代码和具体实施的理解有了很大的提升,希望大家也有收获。当然有空再复习一遍应该有更大收获。

QAnything在服务的完整性、健壮性,以及文档处理上都有了很多的更新,但都不要指望用上就能达到很高的水准,需要进一步提升还需要更多内里的修炼:

  • query理解辅助更好地提升检索的准确性
  • 联合训练提升大模型和检索结果的协同
  • 更深入定制的文档处理提升内容的可读性等

补充

qanything_kernel\connector\llm\llm_for_fastchat.py
from abc import ABC
import tiktoken
import os
from dotenv import load_dotenv
from openai import OpenAI
from typing import Optional, List
import sys
import json
import requests
import logging
sys.path.append("../../../")
from qanything_kernel.connector.llm.base import (BaseAnswer, AnswerResult)
from qanything_kernel.configs.model_config import LOCAL_LLM_SERVICE_URL, LOCAL_LLM_MODEL_NAME, LOCAL_LLM_MAX_LENGTH

load_dotenv()

logging.basicConfig(level=logging.INFO)

class OpenAICustomLLM(BaseAnswer, ABC):
    model: str = LOCAL_LLM_MODEL_NAME
    token_window: int = LOCAL_LLM_MAX_LENGTH
    max_token: int = 512
    offcut_token: int = 50
    truncate_len: int = 50
    temperature: float = 0
    stop_words: str = None
    history: List[List[str]] = []
    history_len: int = 2

    def __init__(self):
        super().__init__()
        # self.client = OpenAI(base_url="http://localhost:7802/v1", api_key="EMPTY")
        if LOCAL_LLM_SERVICE_URL.startswith("http://"):
            base_url = f"{LOCAL_LLM_SERVICE_URL}/v1" 
        else:
            base_url = f"http://{LOCAL_LLM_SERVICE_URL}/v1" 
        self.client = OpenAI(base_url=base_url, api_key="EMPTY")

    @property
    def _llm_type(self) -> str:
        return "CustomLLM using FastChat w/ huggingface transformers or vllm backend"

    @property
    def _history_len(self) -> int:
        return self.history_len

    def set_history_len(self, history_len: int = 10) -> None:
        self.history_len = history_len

    def token_check(self, query: str) -> int:
        
        if LOCAL_LLM_SERVICE_URL.startswith("http://"):
            base_url = f"{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" 
        else:
            base_url = f"http://{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" 

        headers = {"Content-Type": "application/json"}
        
        response = requests.post(
            base_url, 
            data=json.dumps(
                {'prompts': [{'model': self.model, 'prompt': query, 'max_tokens': self.max_token}]}
            ),
            headers=headers, timeout=60)

        # {'prompts': [{'fits': True, 'tokenCount': 317, 'contextLength': 8192}]}
        result = response.json()
        token_num = 0
        try:
            token_num = result['prompts'][0]['tokenCount']
            return token_num
        except Exception as e:
            logging.error(f"token_check Exception {base_url} w/ {e}")
            return token_num

    def num_tokens_from_messages(self, message_texts):
        num_tokens = 0
        for message in message_texts:
            num_tokens += self.token_check(message)
        return num_tokens

    def num_tokens_from_docs(self, docs):
        num_tokens = 0
        for doc in docs:
            num_tokens += self.token_check(doc.page_content)
        return num_tokens

    def _call(self, prompt: str, history: List[List[str]], streaming: bool=False) -> str:
        messages = []
        for pair in history:
            question, answer = pair
            messages.append({"role": "user", "content": question})
            messages.append({"role": "assistant", "content": answer})
        messages.append({"role": "user", "content": prompt})
        logging.info(messages)

        try:

            if streaming:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    stream=True,
                    max_tokens=self.max_token,
                    # temperature=self.temperature,
                    stop=[self.stop_words] if self.stop_words is not None else None,
                )

                for event in response:
                    if not isinstance(event, dict):
                        event = event.model_dump()

                    if event["choices"] is None:
                        event_text = event["text"] + " error_code:" + str(event["error_code"])
                    else:
                        event_text = event["choices"][0]['delta']['content']
                    if isinstance(event_text, str) and event_text != "":
                        # logging.info(f"[debug] event_text = [{event_text}]")
                        delta = {'answer': event_text}
                        yield "data: " + json.dumps(delta, ensure_ascii=False)

            else:
                response = self.client.chat.completions.create(
                    model=self.model,
                    messages=messages,
                    stream=False,
                    max_tokens=self.max_token,
                    # temperature=self.temperature,
                    stop=[self.stop_words] if self.stop_words is not None else None,
                )
                
                # logging.info(f"[debug] response.choices = [{response.choices}]")
                event_text = response.choices[0].message.content if response.choices else ""
                delta = {'answer': event_text}
                yield "data: " + json.dumps(delta, ensure_ascii=False)

        except Exception as e:
            logging.info(f"Error calling API: {e}")
            delta = {'answer': f"{e}"}
            yield "data: " + json.dumps(delta, ensure_ascii=False)

        finally:
            # logging.info("[debug] try-finally")
            yield f"data: [DONE]\n\n"

    def generatorAnswer(self, prompt: str,
                        history: List[List[str]] = [],
                        streaming: bool = False) -> AnswerResult:

        if history is None or len(history) == 0:
            history = [[]]
        logging.info(f"history_len: {self.history_len}")
        logging.info(f"prompt: {prompt}")
        logging.info(f"prompt tokens: {self.num_tokens_from_messages([prompt])}")
        logging.info(f"streaming: {streaming}")
                
        response = self._call(prompt, history[:-1], streaming)
        complete_answer = ""
        for response_text in response:

            if response_text:
                chunk_str = response_text[6:]
                if not chunk_str.startswith("[DONE]"):
                    chunk_js = json.loads(chunk_str)
                    complete_answer += chunk_js["answer"]
                    
            history[-1] = [prompt, complete_answer]
            answer_result = AnswerResult()
            answer_result.history = history
            if streaming:
                answer_result.llm_output = {"answer": response_text}
            else:
                answer_result.llm_output = {"answer": complete_answer}
            answer_result.prompt = prompt
            yield answer_result


if __name__ == "__main__":

    base_url = f"http://{LOCAL_LLM_SERVICE_URL}/api/v1/token_check" 
    headers = {"Content-Type": "application/json"}
    query = "hello"
    response = requests.post(
        base_url, 
        data=json.dumps(
            {'prompts': [{'model': LOCAL_LLM_MODEL_NAME, 'prompt': query, 'max_tokens': 512}]}
        ),
        headers=headers, timeout=60)

    # {'prompts': [{'fits': True, 'tokenCount': 317, 'contextLength': 8192}]}
    result = response.json()
    logging.info(f"[debug] result = {result}")


    llm = OpenAICustomLLM()
    streaming = True
    chat_history = []
    prompt = "你是谁"
    prompt = """参考信息:
中央纪委国家监委网站讯 据山西省纪委监委消息:山西转型综合改革示范区党工委副书记、管委会副主任董良涉嫌严重违纪违法,目前正接受山西省纪委监委纪律审查和监察调查。\\u3000\\u3000董良简历\\u3000\\u3000董良,男,汉族,1964年8月生,河南鹿邑人,在职研究生学历,邮箱random@xxx.com,联系电话131xxxxx909,1984年3月加入中国共产党,1984年8月参加工作\\u3000\\u3000历任太原经济技术开发区管委会副主任、太原武宿综合保税区专职副主任,山西转型综合改革示范区党工委委员、管委会副主任。2021年8月,任山西转型综合改革示范区党工委副书记、管委会副主任。(山西省纪委监委)
---
我的问题或指令:
帮我提取上述人物的中文名,英文名,性别,国籍,现任职位,最高学历,毕业院校,邮箱,电话
---
请根据上述参考信息回答我的问题或回复我的指令。前面的参考信息可能有用,也可能没用,你需要从我给出的参考信息中选出与我的问题最相关的那些,来为你的回答提供依据。回答一定要忠于原文,简洁但不丢信息,不要胡乱编造。我的问题或指令是什么语种,你就用什么语种回复,
你的回复:"""
    final_result = ""
    for answer_result in llm.generatorAnswer(prompt=prompt,
                                                      history=chat_history,
                                                      streaming=streaming):
        resp = answer_result.llm_output["answer"]
        if "DONE" not in resp:
            final_result += json.loads(resp[6:])["answer"]
        # logging.info(resp)

    logging.info(f"final_result = {final_result}")

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/767063.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

【Elasticsearch】一、概述,安装

文章目录 概述全文搜索引擎概述ES&#xff08;7.x&#xff09; 安装ES&#xff08;Docker&#xff09;测试&#xff0c;是否启动成功 可视化工具配置中文 客户端Postman下载 概述 ES是开源的高扩展的分布式全文搜索引擎&#xff0c;实时的存储、检索数据&#xff1b;本身扩展性…

function-calling初体验

课程地址&#xff1a;https://learn.deeplearning.ai/courses/function-calling-and-data-extraction-with-llms/lesson/1/introduction github notebook地址&#xff1a;https://github.com/kingglory/LLMs-function-calling/tree/main Function-Calling 介绍 函数调用(Funct…

Linux Centos7部署Zookeeper

目录 一、下载zookeeper 二、单机部署 1、创建目录 2、解压 3、修改配置文件名 ​4、创建保存数据的文件夹 ​5、修改配置文件保存数据的地址 ​6、启动服务 7、api创建节点 一、下载zookeeper 地址&#xff1a;Index of /dist/zookeeper/zookeeper-3.5.7 (apache.org…

Python23 使用Tensorflow实现线性回归

TensorFlow 是一个开源的软件库&#xff0c;用于数值计算&#xff0c;特别适用于大规模的机器学习。它由 Google 的研究人员和工程师在 Google Brain 团队内部开发&#xff0c;并在 2015 年首次发布。TensorFlow 的核心是使用数据流图来组织计算&#xff0c;使得它可以轻松地利…

【Python画图-驯化seaborn】一文搞懂seaborn中的箱线图实践技巧

【Python画图-驯化seaborn】一文搞懂seaborn中的箱线图实践技巧 本次修炼方法请往下查看 &#x1f308; 欢迎莅临我的个人主页 &#x1f448;这里是我工作、学习、实践 IT领域、真诚分享 踩坑集合&#xff0c;智慧小天地&#xff01; &#x1f387; 免费获取相关内容文档关注&a…

05 docker 镜像

目录 1. 镜像 2. 联合文件系统 3. docker镜像加载原理 4. 镜像分层 镜像分层的优势 5. 容器层 1. 镜像 镜像是一种轻量级、可执行的独立软件包&#xff0c;它包含运行某个软件所需的所有内容&#xff0c;我们把应用程序和配置依赖打包好行程一个可交付的运行环境&#xf…

每日一题 7月1日

1 设数组data[m]作为循环队列的存储空间,front为队头指针,rear为队尾指针,则执行出队操作后其头指针front值为____ 2 采用滑动窗口机制对两个相邻结点A(发送方)和B(接收方)的通信过程进行流量控制。假定帧的序号长度为3比特,发送窗口与接收窗口的大小均为7,当A发送了…

昇思25天学习打卡营第9天|MindSpore-Vision Transformer图像分类

Vision Transformer图像分类 Vision Transformer(ViT)简介 近些年,随着基于自注意(Self-Attention)结构的模型的发展,特别是Transformer模型的提出,极大地促进了自然语言处理模型的发展。由于Transformers的计算效率和可扩展性,它已经能够训练具有超过100B参数的空前…

传输线在阻抗匹配时串联端接电阻为什么要靠近发送端

传输线在阻抗匹配时串联端接电阻为什么要靠近发送端 在进行阻抗匹配的时候我们可以在电阻源端放置一个串联端接电阻&#xff0c;但是有时候受到空间的限制可能会把电阻摆的稍微远一点&#xff0c;那么这个时候大家可能会有疑问&#xff0c;电阻离发送端远一点或者电阻放置在接…

java+mysql教师管理系统

完整源码地址 教师信息管理系统使用命令行交互的方式及数据库连接实现教师信息管理系统&#xff0c;该系统旨在实现教师信息的管理&#xff0c;并根据需要进行教师信息展示。该软件的功能有如下功能 (1)基本信息管理(教师号、姓名、性别、出生年月、职称、学历、学位、教师类型…

【Git 学习笔记】1.3 Git 的三个阶段

1.3 Git 的三个阶段 由于远程代码库后续存在新的提交&#xff0c;因此实操过程中的结果与书中并不完全一致。根据书中 HEAD 指向的 SHA-1&#xff1a;34acc370b4d6ae53f051255680feaefaf7f7850d&#xff0c;可通过以下命令切换到对应版本&#xff0c;并新建一个 newdemo 分支来…

【STM32 RTC实时时钟如何配置!超详细的解析和超简单的配置,附上寄存器操作】

STM32 里面RTC模块和时钟配置系统(RCC_BDCR寄存器)处于后备区域&#xff0c;即在系统复位或从待机模式唤醒后&#xff0c;RTC的设置和时间维持不变。因为系统对后备寄存器和RTC相关寄存器有写保护&#xff0c;所以如果想要对后备寄存器和RTC进行访问&#xff0c;则需要通过操作…

社交媒体优化的智能顾问:Kompas.ai如何提升品牌社交表现

在社交媒体盛行的数字时代&#xff0c;品牌必须在社交平台上保持活跃和互动&#xff0c;以增强品牌社交互动和提升在线可见性。社交媒体优化不仅能够扩大品牌的影响力&#xff0c;还能够加深与消费者的联系。Kompas.ai&#xff0c;作为一款智能社交媒体顾问工具&#xff0c;能够…

【前端项目笔记】7 商品管理

商品管理 效果展示&#xff1a; 在功能开发之前&#xff0c;创建商品列表的子分支 git branch 查看所有分支 git checkout -b goods_list 创建并切换到新分支goods_list git push -u origin goods_list 将新分支goods_list推送到云端仓库origin并命名为goods_list保存 通过…

LLM学习记录

概述 语言模型的发展 语言模型经历过四个阶段的发展&#xff0c;依次从统计语言模型到神经网络语言模型&#xff08;NLM&#xff09;&#xff0c;到出现以 BERT 和 Transformer 架构为代表的预训练语言模型&#xff08;PLM&#xff09;&#xff0c;最终到大型语言模型阶段&am…

竞赛选题 交通目标检测-行人车辆检测流量计数 - 竞赛选题

文章目录 0 前言1\. 目标检测概况1.1 什么是目标检测&#xff1f;1.2 发展阶段 2\. 行人检测2.1 行人检测简介2.2 行人检测技术难点2.3 行人检测实现效果2.4 关键代码-训练过程 最后 0 前言 &#x1f525; 优质竞赛项目系列&#xff0c;今天要分享的是 &#x1f6a9; 毕业设计…

【Java环境配置过程详解(包括IDEA配置Java)】

目录 一、JDK下载安装 1. 官网下载JDK 2. 本地安装JDK 3. 配置环境变量 4. 验证是否安装成功 ​编辑二、IDEA进行安装下载 1. 官网下载 IDEA 2、IDEA进行Java开发 1. 创建Java项目 2. 程序测试 一、JDK下载安装 1. 官网下载JDK 1&#xff09;官网链接: https://www.o…

PTrade如何获取技术值班?如get_RSI - 相对强弱指标;PTrade量化软件如何获取?

get_RSI - 相对强弱指标 get_RSI(close, n6) 使用场景 该函数仅在回测、交易模块可用 接口说明 获取相对强弱指标RSI指标的计算结果 PTrade是恒生公司开发的一款专业量化软件&#xff0c;部分合作券商可提供&#xff0c;↑↑↑&#xff01; 参数 close&#xff1a;价格…

C语言的数据结构:图的基本概念

前言 之前学过了其它的数据结构&#xff0c;如&#xff1a; 集合 \color{#5ecffd}集合 集合 —— 数据元素属于一个集合。 线型结构 \color{#5ecffd}线型结构 线型结构 —— 一个对一个&#xff0c;如线性表、栈、队列&#xff0c;每一个节点和其它节点之间的关系 一个对一个…

rpm包下载

内网无法下载、选择外网的一台机器下载rpm包 下载后上传rpm包 1、创建下载目录 mkdir /data/asap/test 2、下载能留存包的工具 sudo yum install yum-utils -y 报错就是环境问题没下载成功&#xff0c;我换了个环境正常的机器就可以了 3、下载rpm包到指定目录/data/asa…