当前位置: 首页>编程笔记>正文

Keyphrase Chunking - bert2chunk_dataloader.py分析

Keyphrase Chunking - bert2chunk_dataloader.py分析

2021SC@SDUSC

系列文章目录

BERT-KPE是最近由thunlp提出的方法,在OpenKP和KP20K上都达到了state-of-the-art和良好的鲁棒性。


文章目录

  • 系列文章目录
  • 前言
    • tdqm库(知识补充)
  • 预处理标签
    • 对答案位置进行展平、排序、筛选重叠
    • bert2chunk_preprocessor
    • 将数据转换为张量,添加[ CLS], [ LEP]
    • 训练数据加载器 && 评估数据加载器
    • 测试数据加载器


前言

JointKPE直接学习某个N-gram是否是keyphase,对此,采用交叉熵损失。
在这里插入图片描述

tdqm库(知识补充)

Python中可以使用tqdm包来显示进度条。进度条的原理其实很简单,就是不断地删除已经打印的内容,然后重新写出新的进度,从而完成在同一行中的进度条再不断增长的效果。

使用方式:

  • 直接封装可迭代对象
    可以直接使用tqdm创建一个类实例,第一个参数是一个可迭代对象,即tqdm可以直接包装一个可迭代对象,从而进行迭代时就会使用进度条了,比如range(100)一个简单的可迭代对象:
from tqdm import tqdm
import time, randomfor i in tqdm(range(10)):time.sleep(random.random())
  • 实例化一个tqdm类
    上文中一个简单的使用tqdm直接封装range可迭代对象的代码可以简单地使用trange代替,即 trange(X) = tqdm(range(X)),例如下边的例子等同于上边中的第一个例子:
from tqdm import trange
import time, randomfor i in trange(10):time.sleep(random.random())
  • 使用with语句
    一个实例化的tqdm也需要在使用完毕后通过close方法清理资源,这和打开一个文件进行处理是很类似的,因此同样可以使用with语句,让其在执行完后自动清理,就不再需要使用close方法手动关闭了:
from tqdm import tqdm
import time, randomwith tqdm(total=100) as p_bar:for i in range(50):time.sleep(random.random())p_bar.update(2)p_bar.set_description("Processing {}-th iteration".format(i+1))

预处理标签

对答案位置进行展平、排序、筛选重叠

def get_ngram_label(valid_length, start_end_pos, max_phrase_words):# flatten, rank, filter overlap for answer positionssorted_positions = loader_utils.flat_rank_pos(start_end_pos)filter_positions = loader_utils.limit_phrase_length(sorted_positions, max_phrase_words)if len(filter_positions) != len(sorted_positions):overlen_flag = Trueelse:overlen_flag = Falses_label, e_label = [], []for s, e in filter_positions:if e < valid_length:s_label.append(s)e_label.append(e)else:breakassert len(s_label) == len(e_label)return {"s_label": s_label, "e_label": e_label, "overlen_flag": overlen_flag}
  • 处理标记后的数据
if len(tokenize_output["tokens"]) < max_token:max_word = max_tokenelse:max_word = tokenize_output["tok_to_orig_index"][max_token - 1] + 1new_ex = {}new_ex["url"] = ex["url"]new_ex["tokens"] = tokenize_output["tokens"][:max_token]new_ex["valid_mask"] = tokenize_output["valid_mask"][:max_token]new_ex["doc_words"] = ex["doc_words"][:max_word]assert len(new_ex["tokens"]) == len(new_ex["valid_mask"])assert sum(new_ex["valid_mask"]) == len(new_ex["doc_words"])
  • 实现对要返回的数据的筛选处理
if mode == "train":parameter = {"valid_length": len(new_ex["doc_words"]),"start_end_pos": ex["start_end_pos"],"max_phrase_words": max_phrase_words,}# ------------------------------------------------label_dict = get_ngram_label(**parameter)if label_dict["overlen_flag"]:overlen_num += 1if not label_dict["s_label"]:continuenew_ex["s_label"] = label_dict["s_label"]new_ex["e_label"] = label_dict["e_label"]new_examples.append(new_ex)logger.info("Delete Overlen Keyphrase (length > 5): %d (overlap / total = %.2f"% (overlen_num, float(overlen_num / len(examples) * 100))+ "%)")return new_examples

bert2chunk_preprocessor

  • 传入的参数

examples,
tokenizer,
max_token,
pretrain_model,
mode,
max_phrase_words,
stem_flag=False,

  • 日志记录(使用的预训练模式和功能)
 logger.info("start preparing (%s) features for bert2chunk (%s) ..." % (mode, pretrain_model))
  • 调用预训练模型标记
overlen_num = 0new_examples = []for idx, ex in enumerate(tqdm(examples)):# tokenizetokenize_output = loader_utils.tokenize_for_bert(doc_words=ex["doc_words"], tokenizer=tokenizer)

将数据转换为张量,添加[ CLS], [ LEP]

 src_tokens = [BOS_WORD] + ex["tokens"] + [EOS_WORD]valid_ids = [0] + ex["valid_mask"] + [0]src_tensor = torch.LongTensor(tokenizer.convert_tokens_to_ids(src_tokens))valid_mask = torch.LongTensor(valid_ids)orig_doc_len = sum(valid_ids)if mode == "train":s_label = ex["s_label"]e_label = ex["e_label"]return (index,src_tensor,valid_mask,s_label,e_label,orig_doc_len,max_phrase_words,)else:return index, src_tensor, valid_mask, orig_doc_len, max_phrase_words

训练数据加载器 && 评估数据加载器

def batchify_bert2chunk_features_for_train(batch):

  • 准备操作
ids = [ex[0] for ex in batch]docs = [ex[1] for ex in batch]valid_mask = [ex[2] for ex in batch]s_label_list = [ex[3] for ex in batch]e_label_list = [ex[4] for ex in batch]doc_word_lens = [ex[5] for ex in batch]max_phrase_words = [ex[6] for ex in batch][0]bert_output_dim = 768max_word_len = max([word_len for word_len in doc_word_lens])  # word-level
  • 标记张量
# [1] [2] src tokens tensordoc_max_length = max([d.size(0) for d in docs])input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()for i, d in enumerate(docs):input_ids[i, : d.size(0)].copy_(d)input_mask[i, : d.size(0)].fill_(1)
  • 有效掩模张量
valid_max_length = max([v.size(0) for v in valid_mask])valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()for i, v in enumerate(valid_mask):valid_ids[i, : v.size(0)].copy_(v)
  • 主动掩码:n-gram
 # [4] active mask : for n-grammax_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])active_mask = torch.LongTensor(len(docs), max_ngram_length).zero_()for batch_i, word_len in enumerate(doc_word_lens):pad_len = max_word_len - word_lenbatch_mask = []for n in range(max_phrase_words):ngram_len = word_len - nif ngram_len > 0:gram_list = [1 for _ in range(ngram_len)] + [0 for _ in range(pad_len)]else:gram_list = [0 for _ in range(max_word_len - n)]batch_mask.extend(gram_list)active_mask[batch_i].copy_(torch.LongTensor(batch_mask))
  • 对于n-gram的标签
    • 清空标签列表 - empty label list
label_list = []for _ in range(len(docs)):batch_label = []for n in range(max_phrase_words):batch_label.append(torch.LongTensor([0 for _ in range(max_word_len - n)]))label_list.append(batch_label)
  • 有效标签列表 - valid label list
for batch_i in range(len(docs)):for s, e in zip(s_label_list[batch_i], e_label_list[batch_i]):gram = e - slabel_list[batch_i][gram][s] = 1
  • 标签张量 - label tensor
ngram_label = torch.LongTensor(len(docs), max_ngram_length).zero_()for batch_i, label in enumerate(label_list):ngram_label[batch_i].copy_(torch.cat(label))
  • 有效输出 - valid output
 valid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)return input_ids, input_mask, valid_ids, active_mask, valid_output, ngram_label, ids

测试数据加载器

def batchify_bert2chunk_features_for_test(batch):
“”" test dataloader for Dev & Public_Valid."""

  • 准备操作
 ids = [ex[0] for ex in batch]docs = [ex[1] for ex in batch]valid_mask = [ex[2] for ex in batch]doc_word_lens = [ex[3] for ex in batch]max_phrase_words = [ex[4] for ex in batch][0]bert_output_dim = 768max_word_len = max([word_len for word_len in doc_word_lens])  # word-level
  • 标记张量
# [1] [2] src tokens tensordoc_max_length = max([d.size(0) for d in docs])input_ids = torch.LongTensor(len(docs), doc_max_length).zero_()input_mask = torch.LongTensor(len(docs), doc_max_length).zero_()for i, d in enumerate(docs):input_ids[i, : d.size(0)].copy_(d)input_mask[i, : d.size(0)].fill_(1)
  • 有效掩码
 # [3] valid mask tensorvalid_max_length = max([v.size(0) for v in valid_mask])valid_ids = torch.LongTensor(len(valid_mask), valid_max_length).zero_()for i, v in enumerate(valid_mask):valid_ids[i, : v.size(0)].copy_(v)
  • 激活掩码 - for n-gram
 # [4] active mask : for n-grammax_ngram_length = sum([max_word_len - n for n in range(max_phrase_words)])active_mask = torch.LongTensor(len(docs), max_ngram_length).zero_()for batch_i, word_len in enumerate(doc_word_lens):pad_len = max_word_len - word_lenbatch_mask = []for n in range(max_phrase_words):ngram_len = word_len - nif ngram_len > 0:gram_list = [1 for _ in range(ngram_len)] + [0 for _ in range(pad_len)]else:gram_list = [0 for _ in range(max_word_len - n)]batch_mask.extend(gram_list)active_mask[batch_i].copy_(torch.LongTensor(batch_mask))
  • 有效输出
# [5] valid outputvalid_output = torch.zeros(len(docs), max_word_len, bert_output_dim)return (input_ids,input_mask,valid_ids,active_mask,valid_output,doc_word_lens,ids,)

https://www.nshth.com/bcbj/552.html
>

相关文章:

  • 2019北京智源大會,智源 - 看山杯 專家發現算法大賽 2019 知乎
  • 電腦id和ip是一個嗎,【運維心得】網絡ID與網絡IP的區別你知道嗎?
  • Deep Learning for Matching in Search and Recommendation 搜索與推薦中的深度學習匹配(1 引言)
  • woocommerce好用嗎,woocommerce 分類到菜單_我如何為每個WooCommerce產品類別創建不同的菜單?
  • wordpress底部菜單插件,sysbios掛鉤函數使用_使用動作掛鉤自定義WordPress主題
  • 醫學影像成像原理,醫學成像模式~~~
  • 學生請假系統app,基于微信小程序的學生請假系統開發
  • 上課睡覺,學生上課睡覺班主任怎么處理_學生上課睡覺,你能正確處理嗎?
  • 對計算機老師的課堂教學評價,計算機課學生評價用語,關于學生上課的評語及評課用語
  • 為什很多學生上課睡覺,學生上課睡覺班主任怎么處理_學生上課睡覺怎么辦
  • win10小盾牌怎么去掉,windows10軟件圖標去除小盾牌
  • matlab常用命令,matlab基礎之變量,matlab基礎知識(4):特殊變量
  • 應用程序右下角有個盾牌,Win7文件右下角盾牌標志去除方法---UAC阻止程序運行
  • 對ui設計的理解和認識,我對于UI設計這個領域的理解
  • 視頻直播間有哪些,直播平臺必備-百度音視頻直播 LSS
  • obs可以推流到哪些直播平臺,新版RTMP推流協議視頻直播點播平臺EasyDSS在進行視頻直播/錄像回看時如何創建視頻錄像計劃?
  • 歐美國家需要輸入法嗎,Mac刪除默認美國輸入法
  • iphone12忘記鎖屏密碼怎么解鎖,APPLE
  • 簡體字比繁體字的好處,雜談對抽象事物的審美——繁體字與簡體字,孰美?
  • 互聯網醫美是什么,醫美互聯網公司:新氧
  • 英語陳述句疑問句祈使句感嘆句,Wh問句,疑問句,祈使句,感嘆句,10
  • 反卷積原理,超越 ConvNeXt、RepLKNet | 看 51×51 卷積核如何破萬卷!
  • 三星最新概念機,三星提出XFormer | 超越MobileViT、DeiT、MobileNet等模型
  • 統計學屬于哪個大類,第四章 專業統計(上)-統計實務
  • json模塊,模塊講解——time,datetime,json,os,requests
  • 微信復制別人的話中間有虛線,微信小程序——繪制折線圖
  • 流固耦合作用,2018結構、流體、熱分析、多物理場耦合、電磁仿真計算特點與硬件配置方案分析
  • 流固耦合作用,結構、流體、熱分析、多物理場耦合、電磁仿真硬件配置推薦2018
  • 倉庫主管崗位職責,LeetCode:Database 115.倉庫經理
  • ios如何卸載軟件,ios13測試版怎么卸載軟件,蘋果手機升ios13.2后怎么刪除app ios13.2卸載軟件應用方法...