Keyphrase Chunking - Analysis of bert2chunk_dataloader.py
2021SC@SDUSC
Series overview
BERT-KPE is a method recently proposed by thunlp; it achieves state-of-the-art results with good robustness on both OpenKP and KP20K.
Table of contents
- Series overview
- Preface
- The tqdm library (background)
- Preprocessing labels
- Flatten, sort, and filter overlapping answer positions
- bert2chunk_preprocessor
- Convert the data to tensors and add [CLS], [SEP]
- Train dataloader && eval dataloader
- Test dataloader
Preface
JointKPE directly learns whether a given n-gram is a keyphrase, and a cross-entropy loss is used for this objective.
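As a minimal sketch of that objective (the names and shapes here are illustrative assumptions, not the repository's actual model code): every candidate n-gram receives a two-class score and cross-entropy is computed against 0/1 keyphrase labels, flattened exactly like the ngram_label tensor built later in this file.

```python
import torch
import torch.nn as nn

# Hypothetical example: 6 candidate n-grams, each with a 2-class score
# (class 0 = not a keyphrase, class 1 = keyphrase).
ngram_logits = torch.randn(6, 2)
ngram_label = torch.tensor([0, 1, 0, 0, 1, 0])  # gold 0/1 labels per n-gram

loss = nn.CrossEntropyLoss()(ngram_logits, ngram_label)
print(loss.item())
```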
The tqdm library (background)
In Python, the tqdm package can be used to display a progress bar. The principle behind a progress bar is actually quite simple: the content that has already been printed is repeatedly erased and the updated progress is written out again, so the bar appears to keep growing on the same line.
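A bare-bones illustration of that principle (not tqdm's actual implementation): writing a carriage return `\r` moves the cursor back to the start of the line, so each update overwrites the previous text instead of printing a new line.

```python
import sys, time

for i in range(1, 11):
    # '\r' returns to the start of the line, so the next write
    # overwrites the old progress text rather than appending a new line.
    sys.stdout.write("\rprogress: {}/10".format(i))
    sys.stdout.flush()
    time.sleep(0.1)
print()  # final newline
```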
Usage:
- Wrap an iterable directly
You can create a tqdm instance whose first argument is an iterable; in other words, tqdm can simply wrap an iterable, and iterating over it then displays a progress bar. For example, with a simple iterable such as range(10):
from tqdm import tqdm
import time, random

for i in tqdm(range(10)):
    time.sleep(random.random())
- Instantiate a tqdm class
The simple code above that wraps the range iterable with tqdm can be written more concisely with trange, i.e. trange(X) = tqdm(range(X)). The example below is equivalent to the first example above:
from tqdm import trange
import time, random

for i in trange(10):
    time.sleep(random.random())
- Use a with statement
An instantiated tqdm also needs to clean up its resources via the close method once you are done with it, much like opening a file for processing. A with statement can therefore be used here as well, so that cleanup happens automatically when the block finishes and no manual close call is needed:
from tqdm import tqdm
import time, random

with tqdm(total=100) as p_bar:
    for i in range(50):
        time.sleep(random.random())
        p_bar.update(2)
        p_bar.set_description("Processing {}-th iteration".format(i + 1))
Preprocessing labels
Flatten, sort, and filter overlapping answer positions
def get_ngram_label(valid_length, start_end_pos, max_phrase_words):
    # flatten, rank, filter overlap for answer positions
    sorted_positions = loader_utils.flat_rank_pos(start_end_pos)
    filter_positions = loader_utils.limit_phrase_length(sorted_positions, max_phrase_words)

    if len(filter_positions) != len(sorted_positions):
        overlen_flag = True
    else:
        overlen_flag = False

    s_label, e_label = [], []
    for s, e in filter_positions:
        if e < valid_length:
            s_label.append(s)
            e_label.append(e)
        else:
            break
    assert len(s_label) == len(e_label)
    return {"s_label": s_label, "e_label": e_label, "overlen_flag": overlen_flag}
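`flat_rank_pos` and `limit_phrase_length` come from `loader_utils` and are not shown in this file. Based on their names and how they are used above, they plausibly look like the following sketch (an assumption, not necessarily the repository's exact code):

```python
def flat_rank_pos(start_end_pos):
    # start_end_pos: one list of (start, end) word positions per gold keyphrase;
    # flatten all occurrences and sort them by start offset.
    flattened = [pos for poses in start_end_pos for pos in poses]
    return sorted(flattened, key=lambda x: x[0])


def limit_phrase_length(sorted_positions, max_phrase_words):
    # keep only spans whose length in words does not exceed max_phrase_words
    return [(s, e) for s, e in sorted_positions if (e - s + 1) <= max_phrase_words]
```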
- Process the tokenized data
if len(tokenize_output["tokens"]) < max_token:
    max_word = max_token
else:
    max_word = tokenize_output["tok_to_orig_index"][max_token - 1] + 1

new_ex = {}
new_ex["url"] = ex["url"]
new_ex["tokens"] = tokenize_output["tokens"][:max_token]
new_ex["valid_mask"] = tokenize_output["valid_mask"][:max_token]
new_ex["doc_words"] = ex["doc_words"][:max_word]
assert len(new_ex["tokens"]) == len(new_ex["valid_mask"])
assert sum(new_ex["valid_mask"]) == len(new_ex["doc_words"])
- Filter the examples to be returned
if mode == "train":
    parameter = {
        "valid_length": len(new_ex["doc_words"]),
        "start_end_pos": ex["start_end_pos"],
        "max_phrase_words": max_phrase_words,
    }
    # ------------------------------------------------
    label_dict = get_ngram_label(**parameter)
    if label_dict["overlen_flag"]:
        overlen_num += 1
    if not label_dict["s_label"]:
        continue
    new_ex["s_label"] = label_dict["s_label"]
    new_ex["e_label"] = label_dict["e_label"]
new_examples.append(new_ex)

# after the loop over all examples:
logger.info(
    "Delete Overlen Keyphrase (length > 5): %d (overlap / total = %.2f"
    % (overlen_num, float(overlen_num / len(examples) * 100))
    + "%)"
)
return new_examples
bert2chunk_preprocessor
- Input parameters (a toy call is sketched after the list below)
def bert2chunk_preprocessor(
    examples,
    tokenizer,
    max_token,
    pretrain_model,
    mode,
    max_phrase_words,
    stem_flag=False,
):
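For orientation, a toy call might look like the following. The per-example fields (`url`, `doc_words`, `start_end_pos`) match what the code below reads; the concrete values, the keyphrase span, and the choice of `bert-base-cased` tokenizer are purely illustrative assumptions.

```python
from transformers import BertTokenizer

toy_examples = [{
    "url": "http://example.com/doc1",
    "doc_words": ["deep", "keyphrase", "extraction", "with", "bert"],
    # one gold keyphrase "keyphrase extraction" covering word positions 1..2
    "start_end_pos": [[(1, 2)]],
}]

features = bert2chunk_preprocessor(
    examples=toy_examples,
    tokenizer=BertTokenizer.from_pretrained("bert-base-cased"),
    max_token=510,
    pretrain_model="bert-base-cased",
    mode="train",
    max_phrase_words=5,
)
```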
- Logging (the mode and pretrained model in use)
logger.info("start preparing (%s) features for bert2chunk (%s) ..." % (mode, pretrain_model))
- Tokenize with the pretrained model's tokenizer
overlen_num = 0
new_examples = []
for idx, ex in enumerate(tqdm(examples)):
    # tokenize
    tokenize_output = loader_utils.tokenize_for_bert(
        doc_words=ex["doc_words"], tokenizer=tokenizer
    )
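`loader_utils.tokenize_for_bert` is also defined elsewhere. Judging from the keys consumed earlier (`tokens`, `valid_mask`, `tok_to_orig_index`), it presumably runs WordPiece over each whole word and records which sub-tokens start a word. A rough sketch of that behaviour (an assumption, not the repository's exact code):

```python
def tokenize_for_bert(doc_words, tokenizer):
    tokens, valid_mask, tok_to_orig_index = [], [], []
    for word_idx, word in enumerate(doc_words):
        sub_tokens = tokenizer.tokenize(word) or [tokenizer.unk_token]
        for i, sub_token in enumerate(sub_tokens):
            tokens.append(sub_token)
            valid_mask.append(1 if i == 0 else 0)   # only the first sub-token is "valid"
            tok_to_orig_index.append(word_idx)      # map sub-token back to its word
    return {"tokens": tokens, "valid_mask": valid_mask,
            "tok_to_orig_index": tok_to_orig_index}
```

For instance, `["keyphrase", "chunking"]` might become tokens `["key", "##phrase", "chunk", "##ing"]` with `valid_mask = [1, 0, 1, 0]`, which is why `sum(valid_mask)` equals the number of words in the assertions above.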
Convert the data to tensors and add [CLS], [SEP]
src_tokens = [BOS_WORD] + ex["tokens"] + [EOS_WORD]
valid_ids = [0] + ex["valid_mask"] + [0]

src_tensor = torch.LongTensor(tokenizer.convert_tokens_to_ids(src_tokens))
valid_mask = torch.LongTensor(valid_ids)

orig_doc_len = sum(valid_ids)

if mode == "train":
    s_label = ex["s_label"]
    e_label = ex["e_label"]
    return (
        index,
        src_tensor,
        valid_mask,
        s_label,
        e_label,
        orig_doc_len,
        max_phrase_words,
    )
else:
    return index, src_tensor, valid_mask, orig_doc_len, max_phrase_words
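Continuing the toy WordPiece example above, and assuming BOS_WORD / EOS_WORD map to BERT's [CLS] / [SEP] as the section title suggests, the conversion produces something like the following (values are illustrative and tokenizer dependent):

```python
ex = {"tokens": ["key", "##phrase", "chunk", "##ing"], "valid_mask": [1, 0, 1, 0]}

src_tokens = ["[CLS]"] + ex["tokens"] + ["[SEP]"]   # BOS_WORD / EOS_WORD
valid_ids = [0] + ex["valid_mask"] + [0]            # [CLS] and [SEP] are never "valid"

# src_tokens   -> ["[CLS]", "key", "##phrase", "chunk", "##ing", "[SEP]"]
# valid_ids    -> [0, 1, 0, 1, 0, 0]
# orig_doc_len -> sum(valid_ids) = 2   (number of whole words)
# src_tensor   -> the six vocabulary ids of src_tokens
```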
Train dataloader && eval dataloader
def batchify_bert2chunk_features_for_train(batch):
- Setup
ids = [ex[0] for ex in batch]
docs = [ex[1] for ex in batch]
valid_mask = [ex[2] for ex in batch]
s_label_list = [ex[3] for ex in batch]
e_label_list = [ex[4] for ex in batch]
doc_word_lens = [ex[5] for ex in batch]
max_phrase_words = [ex[6] for ex in batch][0]

bert_output_dim = 768
max_word_len = max([word_len for word_len in doc_word_lens])  # word-level
- Source token tensors
# [1] [2] src tokens tensor
doc_max_length = max([d.size(0) for d in docs])
input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()
input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()
for i, d in enumerate(docs):
    input_ids[i, : d.size(0)].copy_(d)
    input_mask[i, : d.size(0)].fill_(1)
- Valid mask tensor
valid_max_length = max([v.size(0) for v in valid_mask])
valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()
for i, v in enumerate(valid_mask):
    valid_ids[i, : v.size(0)].copy_(v)
- Active mask: for n-grams
# [4] active mask : for n-gram
max_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])
active_mask = torch.LongTensor(len(docs), max_ngram_length).zero_()
for batch_i, word_len in enumerate(doc_word_lens):
    pad_len = max_word_len - word_len
    batch_mask = []
    for n in range(max_phrase_words):
        ngram_len = word_len - n
        if ngram_len > 0:
            gram_list = [1 for _ in range(ngram_len)] + [0 for _ in range(pad_len)]
        else:
            gram_list = [0 for _ in range(max_word_len - n)]
        batch_mask.extend(gram_list)
    active_mask[batch_i].copy_(torch.LongTensor(batch_mask))
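To make the flattened n-gram layout concrete, here is a standalone re-run of the inner loop above for a single document, with toy sizes `word_len = 3`, `max_word_len = 4`, `max_phrase_words = 2` (numbers chosen only for illustration):

```python
word_len, max_word_len, max_phrase_words = 3, 4, 2
pad_len = max_word_len - word_len

batch_mask = []
for n in range(max_phrase_words):
    ngram_len = word_len - n
    if ngram_len > 0:
        batch_mask.extend([1] * ngram_len + [0] * pad_len)
    else:
        batch_mask.extend([0] * (max_word_len - n))

# 1-gram block occupies positions 0..3, 2-gram block positions 4..6:
print(batch_mask)  # [1, 1, 1, 0, 1, 1, 0]
```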
- Labels for n-grams
- Empty label list
label_list = []
for _ in range(len(docs)):
    batch_label = []
    for n in range(max_phrase_words):
        batch_label.append(torch.LongTensor([0 for _ in range(max_word_len - n)]))
    label_list.append(batch_label)
- Valid label list
for batch_i in range(len(docs)):
    for s, e in zip(s_label_list[batch_i], e_label_list[batch_i]):
        gram = e - s
        label_list[batch_i][gram][s] = 1
- Label tensor
ngram_label = torch.LongTensor(len(docs), max_ngram_length).zero_()
for batch_i, label in enumerate(label_list):
    ngram_label[batch_i].copy_(torch.cat(label))
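In the flattened `ngram_label` tensor, a gold span `(s, e)` therefore lands at offset `sum(max_word_len - n for n in range(e - s)) + s`, i.e. after all shorter n-gram blocks. A small check with the same toy sizes as above:

```python
import torch

max_word_len, max_phrase_words = 4, 2
s, e = 1, 2                      # gold keyphrase covering words 1..2 (a 2-gram)
gram = e - s                     # index of its n-gram block (0-based)

# build the per-n-gram label blocks and flatten them, as in the code above
label = [torch.zeros(max_word_len - n, dtype=torch.long) for n in range(max_phrase_words)]
label[gram][s] = 1
flat = torch.cat(label)

offset = sum(max_word_len - n for n in range(gram)) + s
print(flat.tolist())             # [0, 0, 0, 0, 0, 1, 0]
print(flat[offset].item())       # 1
```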
- Valid output
valid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)
return input_ids, input_mask, valid_ids, active_mask, valid_output, ngram_label, ids
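The all-zero `valid_output` buffer of shape `(batch, max_word_len, 768)` is presumably filled inside the model by copying, for each document, the BERT hidden states at the "valid" (first sub-token) positions, so that n-gram scoring works at the word level. A hedged sketch of that gather; the model-side code is not shown in this file, so this is only an assumption about how the buffer is used:

```python
import torch

def gather_valid_outputs(sequence_output, valid_ids, valid_output):
    # sequence_output: (batch, seq_len, hidden) BERT outputs over sub-tokens
    # valid_ids:       (batch, seq_len) 1 where a sub-token starts a word
    # valid_output:    (batch, max_word_len, hidden) zero buffer from the collate fn
    for i in range(sequence_output.size(0)):
        k = 0
        for j in range(valid_ids.size(1)):
            if valid_ids[i, j].item() == 1:
                valid_output[i, k] = sequence_output[i, j]
                k += 1
    return valid_output
```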
Test dataloader
def batchify_bert2chunk_features_for_test(batch):
“”" test dataloader for Dev & Public_Valid."""
- Setup
ids = [ex[0] for ex in batch]
docs = [ex[1] for ex in batch]
valid_mask = [ex[2] for ex in batch]
doc_word_lens = [ex[3] for ex in batch]
max_phrase_words = [ex[4] for ex in batch][0]

bert_output_dim = 768
max_word_len = max([word_len for word_len in doc_word_lens])  # word-level
- Source token tensors
# [1] [2] src tokens tensor
doc_max_length = max([d.size(0) for d in docs])
input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()
input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()
for i, d in enumerate(docs):
    input_ids[i, : d.size(0)].copy_(d)
    input_mask[i, : d.size(0)].fill_(1)
- Valid mask tensor
# [3] valid mask tensor
valid_max_length = max([v.size(0) for v in valid_mask])
valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()
for i, v in enumerate(valid_mask):
    valid_ids[i, : v.size(0)].copy_(v)
- Active mask: for n-grams
# [4] active mask : for n-gram
max_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])
active_mask = torch.LongTensor(len(docs), max_ngram_length).zero_()
for batch_i, word_len in enumerate(doc_word_lens):
    pad_len = max_word_len - word_len
    batch_mask = []
    for n in range(max_phrase_words):
        ngram_len = word_len - n
        if ngram_len > 0:
            gram_list = [1 for _ in range(ngram_len)] + [0 for _ in range(pad_len)]
        else:
            gram_list = [0 for _ in range(max_word_len - n)]
        batch_mask.extend(gram_list)
    active_mask[batch_i].copy_(torch.LongTensor(batch_mask))
- Valid output
# [5] valid output
valid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)
return (
    input_ids,
    input_mask,
    valid_ids,
    active_mask,
    valid_output,
    doc_word_lens,
    ids,
)
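Both batchify functions are collate functions. Wiring the test one into a standard `torch.utils.data.DataLoader` would look roughly like this; `dataset` is a placeholder for whatever Dataset wraps the converted examples, so the exact class name is an assumption:

```python
from torch.utils.data import DataLoader

test_loader = DataLoader(
    dataset,                     # a Dataset yielding the tuples built by the converter above
    batch_size=32,
    shuffle=False,
    collate_fn=batchify_bert2chunk_features_for_test,
)

for input_ids, input_mask, valid_ids, active_mask, valid_output, doc_word_lens, ids in test_loader:
    pass  # feed each collated batch to the model
```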