2 changes: 0 additions & 2 deletions LLM.py
@@ -15,9 +15,7 @@ def __init__(self, model_path :str):
        super().__init__()
        print("正在从本地加载模型...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-       print("完成AutoTokenizer...")
        self.model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to(torch.bfloat16).cuda()
-       print("完成AutoModelForCausalLM...")
        self.model = self.model.eval()
        print("完成本地模型的加载")

75 changes: 75 additions & 0 deletions README.md
@@ -0,0 +1,75 @@
<div align="center">
<img src="https://github.com/Everfighting/SpringFestQA/blob/main/assets/logo.jpeg" width=30% />
</div>

# SpringFestQA (年关走亲访友渡劫助手, the Spring Festival relative-visiting survival assistant)

## Introduction
SpringFestQA collects comebacks to relatives' Chinese New Year questions from around the web. It is a language model fine-tuned from InternLM2 and augmented with a retrieval chain (RAG) that answers relatives' probing questions cleverly, in a young person's voice.
During New Year visits you will inevitably face a relative's pointed question that leaves you speechless. Still annoyed that the perfect retort only comes to you once you are lying in bed? Hand the awkward question straight to the model and leave your relatives with nothing to say.

The web interface offers three answer styles, tactful and earnest (委婉认真), topic-changing (转换话题), and passive-aggressive (阴阳怪气), to suit different personalities. Click the buttons on the right to generate an answer in the corresponding style.

## OpenXLab demo:
```
https://openxlab.org.cn/apps/detail/SijieLyu/SpringFestQA
```
## Overall pipeline
To be filled in once the flowchart is finished.

## Data collection
The datasets live in the data directory of this repository:
### 1) MBTI
An open-source Chinese MBTI dataset in JSON format, covering four topics: relationships, income, studies, and housing.
See https://github.com/PKU-YuanGroup/Machine-Mindset/tree/main/datasets/behaviour
### 2) origin_data
A corpus of five answer styles generated with ChatGLM, in CSV format: humorous, topic-changing, tactful, passive-aggressive, and deliberately mysterious/brooding. Each style has 10,000 samples to guarantee sufficient data; during debugging this was narrowed down to the three styles with stable, clearly distinct outputs.
### 3) alpaca_data
origin_data converted into Alpaca-style JSON; see convert.py for the conversion code (a sketch of the idea follows below).
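As a rough illustration of what convert.py does (convert.py itself is authoritative; the CSV column and file names below are assumptions for the sketch):
```
import csv
import json

def csv_to_alpaca(csv_path, json_path):
    # Turn each origin_data row into an Alpaca-format record.
    # Column names are hypothetical; match them to the real CSV header.
    records = []
    with open(csv_path, encoding="utf-8") as f:
        for row in csv.DictReader(f):
            records.append({
                "instruction": row["question"],  # the relative's question
                "input": "",
                "output": row["answer"],         # the styled comeback
            })
    with open(json_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False, indent=2)

csv_to_alpaca("data/origin_data/style.csv", "data/alpaca_data/style.json")
```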

## LLM-based data augmentation
- Hand-write a small number of seed answers in each style
- Build prompts from those seed examples
- Feed the prompts to an LLM to generate a larger corpus
- Manually review the generated corpus to build the training set (see the sketch below)
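A minimal sketch of the generation step, assuming an OpenAI-compatible client (the client, model name, and seed data here are placeholders; the project actually generated its corpus with ChatGLM):
```
from openai import OpenAI

client = OpenAI()  # placeholder endpoint, not the project's actual setup

SEEDS = [  # hand-written style exemplars (illustrative only)
    ("你什么时候结婚?", "缘分这事急不得,好事多磨嘛。"),
]

def build_prompt(style, n=20):
    examples = "\n".join(f"问:{q}\n答:{a}" for q, a in SEEDS)
    return f"请模仿下面的示例,以「{style}」的语气,再生成{n}组亲戚提问和巧妙回答:\n{examples}"

resp = client.chat.completions.create(
    model="gpt-3.5-turbo",  # stand-in for the generator actually used
    messages=[{"role": "user", "content": build_prompt("委婉认真")}],
)
print(resp.choices[0].message.content)  # human-reviewed before training
```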

## Model fine-tuning
Using the MBTI data and the QA pairs, XTuner fine-tunes part of the parameters of InternLM-Chat-7B, shaping both its personality and its internal knowledge; the result is the fine-tuned SpringFest model.
The personality does train in, but the answers alone are not very useful: it takes multi-turn dialogue for the persona to show, and since the model's multi-turn ability is limited, results fell short of expectations. [Possible improvement: add an I/E toggle to the Gradio page and route to different models in the backend.]

## Building the knowledge base (RAG)
The QA pairs are embedded with the langchain framework to form an external knowledge base, so that a user query can be answered by semantic vector search returning the knowledge snippets most relevant to the question. [Possible improvement: show prompt hints on the page.]
### Prompt tuning
When the knowledge base has no good match, we instead draw on the fine-tuned SpringFest model itself, using in-context learning to show it correct examples before it answers. This includes, but is not limited to:
1. a system prompt that makes the task explicit;
2. three built-in prompts (tactful and earnest, topic-changing, passive-aggressive) so users can switch answer styles themselves; a retrieval sketch follows below.
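A rough sketch of the retrieval step (the paths and `k` below are assumptions; the real wiring is in app.py further down this diff):
```
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings

# Reopen the persisted vector store built by create_db.py.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformer")
vectordb = Chroma(persist_directory="data_base/vector_db/chroma",
                  embedding_function=embeddings)

# Top-k semantic search over the QA knowledge base.
for doc in vectordb.similarity_search("过年被问工资怎么回答?", k=3):
    print(doc.page_content)
```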

## ModelScope model
### Downloading the weights (the code downloads them automatically; no manual step needed)
https://www.modelscope.cn/binbeing/SpringFestQA.git
SpringFestQA uses InternLM2 as the base model, fine-tuned with XTuner on the Spring Festival topic dataset and the personality dataset. After installing the modelscope library, download it with:
```
import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
model_dir = snapshot_download('binbeing/SpringFestQA', cache_dir='./')
```
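Once downloaded, the weights load the same way LLM.py loads its model (a sketch continuing from the snippet above, where `model_dir` was set):
```
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_dir, trust_remote_code=True).to(torch.bfloat16).cuda().eval()
```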
## Deployment
- To an application platform (OpenXLab as the example):
  fork this repository, create a new project on OpenXLab, and link the forked repository to the new project; SpringFestQA will then deploy on OpenXLab.
- Locally (InternStudio dev machine as the example):
```
git clone https://github.com/Lyusijie/SpringFestQA.git

python app.py

ssh -CNg -L 7860:127.0.0.1:7860 root@ssh.intern-ai.org.cn -p 33471
```
(the `ssh` command sets up port forwarding; replace 33471 with your own dev machine's SSH port)
- Open 127.0.0.1:7860 in a local browser
166 changes: 1 addition & 165 deletions app.py
@@ -1,166 +1,2 @@
# Swap in pysqlite3 so Chroma gets a modern sqlite3 build on older systems
__import__('pysqlite3')
import sys
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')

# Import the required libraries
import gradio as gr
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
from LLM import InternLM_LLM
from langchain.prompts import PromptTemplate
import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import os
from langchain.chains import RetrievalQA
from langchain import LLMChain  # PromptTemplate is already imported above
from langchain.chains import ConversationalRetrievalChain
from langchain.chains.question_answering import load_qa_chain
from langchain.chains.conversational_retrieval.prompts import CONDENSE_QUESTION_PROMPT

prompt_template_qa = """直接回答问题,除非问题涉及结婚、生孩、工作、恋爱、买房这些内容时才参考使用内容后内容做出回答。

问题: {question}
内容:{context}
回答:"""

def change_prompt1(*args):
    global prompt_template_qa
    prompt_template_qa = """委婉认真地回答问题,除非问题涉及结婚、生孩、工作、恋爱、买房这些内容时才参考使用内容后内容做出回答。
问题: {question}
内容:{context}
回答:"""

def change_prompt2(*args):
    global prompt_template_qa
    prompt_template_qa = """转移话题地回答问题,除非问题涉及结婚、生孩、工作、恋爱、买房这些内容时才参考使用内容后内容做出回答。
问题: {question}
内容:{context}
回答:"""

def change_prompt3(*args):
    global prompt_template_qa
    prompt_template_qa = """阴阳怪气地回答问题,除非问题涉及结婚、生孩、工作、恋爱、买房这些内容时才参考使用内容后内容做出回答。
问题: {question}
内容:{context}
回答:"""
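# Note: load_chain() reads prompt_template_qa once at startup, so these
# change_prompt* callbacks mutate the global without affecting the chain
# that has already been built; the chain would need to be rebuilt for the
# style buttons to actually change the answers.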

def init():
    # Download the base chat model from ModelScope
    model_dir = snapshot_download('Shanghai_AI_Laboratory/internlm-chat-7b',
                                  cache_dir='./')
    os.environ['HF_ENDPOINT'] = 'https://hf-mirror.com'
    # Download the embedding model through the Hugging Face mirror
    os.system('huggingface-cli download --resume-download sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2 --local-dir sentence-transformer')


def load_chain():
    # Build the question-answering chain
    # Define the embeddings
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformer")

    # Persistence path of the vector database
    persist_directory = 'data_base/vector_db/chroma'

    # Load the database
    vectordb = Chroma(
        persist_directory=persist_directory,  # lets us keep the database on disk
        embedding_function=embeddings
    )

    llm = InternLM_LLM(model_path="Shanghai_AI_Laboratory/internlm-chat-7b")

    # Prompt template for condensing a follow-up question into a standalone one
    template = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question, in its original language.
Chat History: {chat_history}
Follow Up Input: {question}
Standalone question: """
    prompt_qg = PromptTemplate(
        template=template,
        input_variables=["chat_history", "question"],
    )
    global prompt_template_qa

    prompt_qa = PromptTemplate(
        template=prompt_template_qa,
        input_variables=["context", "question"]
    )
    question_generator = LLMChain(llm=llm, prompt=prompt_qg)
    doc_chain = load_qa_chain(llm=llm, chain_type="stuff", prompt=prompt_qa)

    # Assemble the chain
    qa_chain = ConversationalRetrievalChain(
        retriever=vectordb.as_retriever(),
        question_generator=question_generator,
        combine_docs_chain=doc_chain,
    )
    return qa_chain

class Model_center():
    """
    Holds the question-answering chain.
    """
    init()  # runs once, at class-definition time, to fetch the models

    def __init__(self):
        self.chain = load_chain()

    def qa_chain_self_answer(self, question: str, chat_history: list):
        """
        Answer a question through the QA chain.
        """
        chat_history_tuples = tuple(tuple(x) for x in chat_history)
        if question is None or len(question) < 1:
            return "", chat_history
        try:
            chat_history.append(
                (question, self.chain({"question": question, "chat_history": chat_history_tuples})["answer"]))
            # Append the result to the chat history; Gradio renders it
            return "", chat_history
        except Exception as e:
            return str(e), chat_history


model_center = Model_center()

block = gr.Blocks()
with block as demo:
    with gr.Row(equal_height=True):
        with gr.Column(scale=15):
            gr.Markdown("""<h1><center>SpringFestQA</center></h1>
                <center>年关走亲访友渡劫助手</center>
                """)
            # gr.Image(value=LOGO_PATH, scale=1, min_width=10, show_label=False, show_download_button=False)

    with gr.Row():
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(height=450, show_copy_button=True)
            # Text box for entering the prompt/question
            msg = gr.Textbox(label="Prompt/问题")

            with gr.Row():
                # Submit button
                db_wo_his_btn = gr.Button("Chat")
            with gr.Row():
                # Button that clears the chatbot component
                clear = gr.ClearButton(
                    components=[chatbot], value="Clear console")

            chat_history = []
            # On click, call qa_chain_self_answer with the user's message and
            # the chat history, then update the text box and the chatbot.
            db_wo_his_btn.click(model_center.qa_chain_self_answer, inputs=[
                                msg, chatbot], outputs=[msg, chatbot])
        # A second column holding the style buttons
        with gr.Column(scale=2):
            # Three buttons that switch the answer style by swapping the
            # global prompt template (see change_prompt1/2/3 above).
            type1_btn = gr.Button("委婉认真")
            type2_btn = gr.Button("转换话题")
            type3_btn = gr.Button("阴阳怪气")
            type1_btn.click(change_prompt1)
            type2_btn.click(change_prompt2)
            type3_btn.click(change_prompt3)
    gr.Markdown("""提醒:<br>
    1. 初始化数据库时间可能较长,请耐心等待。
    2. 使用中如果出现异常,将会在文本输入框进行展示,请不要惊慌。 <br>
    """)
# threads to consume the request
gr.close_all()
# Launch with sharing enabled, using the PORT1 env var as the server port:
# demo.launch(share=True, server_port=int(os.environ['PORT1']))
# Launch directly
demo.launch()
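# Note: demo.launch() blocks, so the streamlit command below only runs after
# the Gradio app exits; it is presumably the new two-line app.py entry point.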
os.system('streamlit run web_internlm2.py --server.address=0.0.0.0 --server.port 7860 --server.fileWatcherType none')
Binary file added assets/logo.jpeg
Binary file added assets/robot.jpeg
Binary file added assets/user.png
19 changes: 12 additions & 7 deletions create_db.py
@@ -1,8 +1,7 @@
# Import the required third-party libraries
from langchain.document_loaders import UnstructuredFileLoader
from langchain.document_loaders import UnstructuredMarkdownLoader
from langchain.document_loaders import PyPDFLoader # for loading the pdf
from langchain.chains import ChatVectorDBChain # for chatting with the pdf
from langchain.document_loaders.csv_loader import CSVLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
@@ -22,7 +21,7 @@ def get_files(dir_path):
                file_list.append(os.path.join(filepath, filename))
            elif filename.endswith(".txt"):
                file_list.append(os.path.join(filepath, filename))
-           elif filename.endswith(".pdf"):
+           elif filename.endswith(".csv"):
                file_list.append(os.path.join(filepath, filename))
    return file_list

@@ -40,8 +39,8 @@ def get_text(dir_path):
            loader = UnstructuredMarkdownLoader(one_file)
        elif file_type == 'txt':
            loader = UnstructuredFileLoader(one_file)
-       elif file_type == 'pdf':
-           loader = PyPDFLoader(one_file)
+       elif file_type == 'csv':
+           loader = CSVLoader(file_path=one_file)
        else:
            # Skip files that do not match any supported type
            continue
@@ -50,7 +49,13 @@ def get_text(dir_path):

# Target folders
tar_dir = [
-   "/root/data/paper_demo/graph",
+   "/root/data/InternLM",
+   "/root/data/InternLM-XComposer",
+   "/root/data/lagent",
+   "/root/data/lmdeploy",
+   "/root/data/opencompass",
+   "/root/data/xtuner",
+   "/root/data/newyear"
]

# Load the target files
@@ -61,7 +66,7 @@ def get_text(dir_path):
# Split the text into chunks
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500, chunk_overlap=150)
-split_docs = text_splitter.split_documents(docs[:10])
+split_docs = text_splitter.split_documents(docs)

# Load the open-source embedding model
embeddings = HuggingFaceEmbeddings(model_name="/root/data/model/sentence-transformer")
Binary file removed data_base/vector_db/chroma/chroma.sqlite3
Empty file added model/.gitkeep
80 changes: 80 additions & 0 deletions relation/ChinaRelationship.py
@@ -0,0 +1,80 @@
import json
import re


class RelationshipCounter:
    def __init__(self, data_file="data.json", filter_file="filter.json", reverse=False):
        self.data = self.load_json(data_file)
        self.filter = self.load_json(filter_file)["filter"]
        self.reverse = reverse  # whether to swap gendered terms in the result

    def load_json(self, file_name):
        with open(file_name, "r", encoding="utf-8") as f:
            return json.load(f)

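    # Relation-key legend (from the mappings below): f=father, m=mother,
    # h=husband, w=wife, s=son, d=daughter, ob/lb=older/younger brother,
    # os/ls=older/younger sister, xd/xs=sibling of unspecified age,
    # and "," chains the steps of the relationship.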
    # Convert a Chinese kinship title (e.g. "我的爸爸的哥哥") into a relation key
    def transform_title_to_key(self, string):
        # Ordered replacements, applied first to last, so compound titles
        # must come before the shorter titles they contain.
        replacements = [
            ("的", ","), ("我", ""),
            ("爸爸", "f"), ("父亲", "f"), ("妈妈", "m"), ("母亲", "m"),
            ("爷爷", "f,f"), ("奶奶", "f,m"),
            ("外公", "m,f"), ("姥爷", "m,f"), ("外婆", "m,m"), ("姥姥", "m,m"),
            ("老公", "h"), ("丈夫", "h"),
            ("老婆", "w"), ("妻子", "w"),  # bug fix: "妻子" (wife) was mapped to "h"
            ("儿子", "s"), ("女儿", "d"),
            ("兄弟", "xd"), ("哥哥", "ob"), ("弟弟", "lb"),
            ("姐妹", "xs"), ("姐姐", "os"), ("妹妹", "ls"),
        ]
        result = string
        for old, new in replacements:
            result = result.replace(old, new)
        result = result.strip(",") + ","
        for f in self.filter:
            exp = "^" + f["exp"].replace("&", "\\&").replace(",", ".*") + "$"
            result = re.sub(exp, f["str"].replace("$", "\\"), result)
        if self.reverse:
            # Swap f and m in one pass; chained replace() would turn both into "f".
            result = result.translate(str.maketrans("fm", "mf"))
        if result.endswith(","):
            result = result[:-1]
        return result

    # Reject relation chains that are invalid under PRC marriage law
    def error_message(self, key):
        same_sex = ("ob,h", "xb,h", "lb,h", "os,w", "ls,w", "xs,w",
                    "f,h", "m,w", "d,w", "s,h")
        if any(k in key for k in same_sex):
            return "根据我国法律暂不支持同性婚姻,怎么称呼你自己决定吧"
        if "h,h" in key or "w,w" in key:
            return "根据我国法律暂不支持重婚,怎么称呼你自己决定吧"
        return key

    # Convert a relation key back into a kinship title
    def transform_key_to_title(self, string):
        if not string:
            return None
        result = []
        seen = set()
        for s in string.split("#"):
            if s != self.error_message(s):
                return self.error_message(s)
            if s in self.data and self.data[s][0] not in seen:
                result.append(self.data[s][0])
                seen.add(self.data[s][0])
        # If nothing matched, fall back to splitting on commas
        if not result:
            for s in string.split(','):
                if s != self.error_message(s):
                    return self.error_message(s)
                if s in self.data and self.data[s][0] not in seen:
                    result.append(self.data[s][0])
                    seen.add(self.data[s][0])

        if '自己' in result:
            result.remove('自己')
        result_str = ','.join(result)
        if self.reverse:
            # Swap 父 and 母 in one pass; chained replace() would leave only 父.
            result_str = result_str.translate(str.maketrans("父母", "母父"))
        return result_str


if __name__ == '__main__':
    rc = RelationshipCounter()
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的爸爸")))
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的父亲的儿子")))
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的哥哥的丈夫")))
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的哥哥的弟弟")))
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的爸爸的爸爸")))
    print(rc.transform_key_to_title(rc.transform_title_to_key("我的哥哥的姐姐的妹妹")))