RNN NER

import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report


def read_data(file):
    """Read a BIO-style file: one "token tag" pair per line, blank lines between sentences."""
    with open(file, "r", encoding="utf-8") as f:
        all_data = f.read().split("\n")

    sentences = []
    labels = []
    sentence = []
    label = []
    for data in all_data:
        data_s = data.split(" ")
        # Any line that does not split into exactly (token, tag) ends the current sentence.
        if len(data_s) != 2:
            if len(sentence) > 0 and len(label) > 0:
                sentences.append(sentence)
                labels.append(label)
            sentence = []
            label = []
            continue
        sent, l = data_s
        sentence.append(sent)
        label.append(l)
    # Flush the final sentence in case the file does not end with a blank line.
    if len(sentence) > 0 and len(label) > 0:
        sentences.append(sentence)
        labels.append(label)
    return sentences, labels


def build_word(train_text):
    # Token vocabulary; indices 0 and 1 are reserved for padding and unknown tokens.
    word_2_index = {"PAD": 0, "UNK": 1}
    for text in train_text:
        for w in text:
            word_2_index[w] = word_2_index.get(w, len(word_2_index))
    return word_2_index


def build_tag(train_tag):
    # Tag vocabulary; "O" is registered up front so it always gets index 2.
    tag_2_index = {"PAD": 0, "UNK": 1, "O": 2}
    for tags in train_tag:
        for t in tags:
            tag_2_index[t] = tag_2_index.get(t, len(tag_2_index))
    return tag_2_index


class NDataset(Dataset):
    def __init__(self, all_text, all_tag, word_2_index, tag_2_index, max_len, is_dev=False):
        self.all_text = all_text
        self.all_tag = all_tag
        self.word_2_index = word_2_index
        self.tag_2_index = tag_2_index
        self.max_len = max_len
        self.is_dev = is_dev

    def __getitem__(self, index):
        max_len = self.max_len
        text = self.all_text[index]
        if self.is_dev:
            # Dev sentences are evaluated one at a time, so no truncation or padding is needed.
            max_len = len(text)
        text = text[:max_len]
        tag = self.all_tag[index][:max_len]
        assert len(text) == len(tag)
        text_len = len(tag)

        # Map tokens/tags to indices; anything unseen falls back to UNK (index 1).
        text_idx = [self.word_2_index.get(i, 1) for i in text]
        tag_idx = [self.tag_2_index.get(i, 1) for i in tag]

        # Right-pad to max_len with PAD (index 0).
        text_idx = text_idx + [0] * (max_len - len(text_idx))
        tag_idx = tag_idx + [0] * (max_len - len(tag_idx))
        return torch.tensor(text_idx), torch.tensor(tag_idx), text_len

    def __len__(self):
        return len(self.all_tag)


class NModel(nn.Module):
    def __init__(self, corpus_len, embedding_num, tag_num):
        super().__init__()
        self.embedding = nn.Embedding(corpus_len, embedding_num)
        # Bidirectional GRU: 150 hidden units per direction -> 300-dim outputs.
        self.rnn = nn.GRU(embedding_num, 150, batch_first=True, bidirectional=True)
        self.classifier = nn.Linear(300, tag_num)
        self.loss_fun = nn.CrossEntropyLoss()

    def forward(self, x, batch_label=None):
        batch_size, seq_len = x.shape
        x = self.embedding(x)
        x, _ = self.rnn(x)
        pre = self.classifier(x)

        if batch_label is not None:
            # Training: flatten (batch, seq, tag) -> (batch*seq, tag) for cross-entropy.
            loss = self.loss_fun(pre.reshape(batch_size * seq_len, -1), batch_label.reshape(-1))
            return loss
        # Inference: return the most likely tag index at each position.
        return torch.argmax(pre, dim=-1)


if __name__ == "__main__":
    # Evaluation metrics: f1-score, accuracy, precision, recall.
    # Common tagging schemes: BIOES (B I E S O), BIOE (B I E O), BIO (B I O).
    train_text, train_label = read_data(os.path.join("..", "data", "ner", "BIO", "train.txt"))
    dev_text, dev_label = read_data(os.path.join("..", "data", "ner", "BIO", "dev.txt"))

    word_2_index = build_word(train_text)
    tag_2_index = build_tag(train_label)
    index_2_tag = list(tag_2_index)  # index -> tag string (dict insertion order, Python 3.7+)

    max_len = 25
    batch_size = 10
    epoch = 20
    lr = 0.001
    embedding_num = 150
    tag_num = len(tag_2_index)
    corpus_len = len(word_2_index)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    train_dataset = NDataset(train_text, train_label, word_2_index, tag_2_index, max_len)
    train_dataloader = DataLoader(train_dataset, batch_size, shuffle=False)

    dev_dataset = NDataset(dev_text, dev_label, word_2_index, tag_2_index, max_len, is_dev=True)
    dev_dataloader = DataLoader(dev_dataset, 1, shuffle=False)

    model = NModel(corpus_len, embedding_num, tag_num).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(opt, 5, gamma=0.8)

    for e in range(epoch):
        model.train()
        for batch_text_idx, batch_tag_idx, batch_len in train_dataloader:
            batch_text_idx = batch_text_idx.to(device)
            batch_tag_idx = batch_tag_idx.to(device)
            loss = model(batch_text_idx, batch_tag_idx)
            loss.backward()
            opt.step()
            opt.zero_grad()
        lr_scheduler.step()  # decay the learning rate by 0.8 every 5 epochs
        # print(f"epoch:{e},loss:{loss:.2f}, lr:{opt.param_groups[0]['lr']}")

        model.eval()
        all_predict = []
        for batch_text_idx, batch_tag_idx, batch_len in dev_dataloader:
            batch_text_idx = batch_text_idx.to(device)
            pre = model(batch_text_idx).tolist()
            pre = [index_2_tag[i] for i in pre[0]]  # dev batch size is 1
            all_predict.append(pre)

        # Manual token-level accuracy/recall, kept for reference
        # (superseded by the entity-level seqeval metrics below):
        # right_num = 0
        # all_num = 0
        # all_not_O_num = 0
        # all_pre_not_O_num = 0
        # for batch_text_idx, batch_tag_idx, batch_len in dev_dataloader:
        #     batch_text_idx = batch_text_idx.to(device)
        #     batch_tag_idx = batch_tag_idx.to(device)
        #     pre = model(batch_text_idx)
        #     all_not_O_num += torch.sum(batch_tag_idx != 2)
        #     all_pre_not_O_num += torch.sum((pre == batch_tag_idx) & (pre != 2))
        #     right_num += torch.sum(pre == batch_tag_idx)
        #     all_num += batch_tag_idx.shape[-1]
        # acc = right_num / all_num * 100
        # rec = all_pre_not_O_num / all_not_O_num * 100

        f1 = f1_score(dev_label, all_predict, average="macro")
        print(f"epoch:{e},f1:{f1*100:.2f}%")
        print(classification_report(dev_label, all_predict))
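For reference, read_data expects the usual CoNLL-style BIO layout: one "token tag" pair per line separated by a single space, with a blank line (or any line that does not split into exactly two fields) ending a sentence. A hypothetical fragment of ../data/ner/BIO/train.txt, with made-up tokens and entities:

我 O
住 O
在 O
北 B-LOC
京 I-LOC

小 B-PER
明 I-PER
来 O
了 O

Each line's first field becomes one vocabulary entry in build_word, so this setup works equally well for character-level Chinese NER (as above) or word-level data.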
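One caveat in the listing: training batches are padded to max_len, and the loss in NModel.forward averages over those PAD positions too, so the model gets partial credit simply for predicting PAD on padding tokens. A minimal sketch of one common fix (not in the original code), assuming tag index 0 stays reserved for PAD as in build_tag:

import torch
import torch.nn as nn

# ignore_index drops targets equal to 0 (PAD) from the cross-entropy average,
# using the same flattened (batch*seq, tag_num) shapes as NModel.forward.
loss_fun = nn.CrossEntropyLoss(ignore_index=0)

logits = torch.randn(4, 6)            # (batch_size * seq_len, tag_num)
targets = torch.tensor([2, 3, 0, 0])  # the trailing 0s are PAD positions
loss = loss_fun(logits, targets)      # averaged over the two non-PAD targets only

Swapping this into NModel.__init__ leaves the rest of the training loop unchanged.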
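A note on the evaluation: seqeval scores at the entity level rather than the token level, so a predicted entity only counts as correct when both its span and its type match the gold annotation. A small self-contained check with hypothetical tags:

from seqeval.metrics import f1_score, classification_report

# Gold has a two-token LOC entity and a PER entity; the prediction recovers
# the LOC span exactly but misses PER, so only one of two entities is correct.
y_true = [["B-LOC", "I-LOC", "O", "B-PER"]]
y_pred = [["B-LOC", "I-LOC", "O", "O"]]

print(f1_score(y_true, y_pred))  # 0.67: precision 1/1, recall 1/2
print(classification_report(y_true, y_pred))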