Skip to content

Is it possible to provide a demo code for bert-base-chinese-qa? #30

@WuJiunShiung

Description

@WuJiunShiung

Hi, I am new in this field. Is it possible to provide a demo code for bert-base-chinese-qa?
I tried the following code, following the book "Getting Started with Google BERT":

from transformers import BertTokenizerFast, BertForQuestionAnswering

Tokenizer = BertTokenizerFast.from_pretrained("ckiplab/bert-base-chinese")
model = BertForQuestionAnswering.from_pretrained("ckiplab/bert-base-chinese-qa")

paragraph = "李同 也 沒有 在意 , 大廈 中 , 几乎 每 天 都 有 人 搬進 搬出 , 原 不足為奇 。 \
             可是 , 當 李同 走進 大廈 時 , 卻 看見 了 那 個 老者 , 那 老者 是 倒退 著 身子 走出來 的 , \
             在 那 老者 的 面前 , 兩 個 搬運 工人 , 正 抬 著 一 只 箱子 。 那 是 一 只 木 箱子 , \
             很 殘舊 了 , 箱子 并 不 大 , 但是 兩 個 搬運 工人 抬 著 , 看來 十分 吃力 。[SEP]".strip(" ")

question = "[CLS]老者怎麼走出來的?[SEP]"

question_tokens = tokenizer.tokenize(question)
paragraph_tokens = tokenizer.tokenize(paragraph)

tokens = question_tokens + paragraph_tokens
input_ids = tokenizer.convert_tokens_to_ids(tokens)

segment_ids = [0] * len(question_tokens)
segment_ids += [1] * len(paragraph_tokens)

input_ids = torch.tensor([input_ids])
segment_ids = torch.tensor([segment_ids])

# Getting the answer

res = model(input_ids, token_type_ids=segment_ids)

start_scores, end_scores = res['start_logits'], res['end_logits']

start_index = torch.argmax(start_scores)
end_index = torch.argmax(end_scores)

print(" ".join(tokens[start_index:end_index+1]))

But, I got [CLS]. Could you provide a sample code to how how this Chinese QA model can work properly?
Thank you!

Metadata

Metadata

Assignees

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions