NLP transformers - 翻译

from transformers import AutoTokenizer

#加载编码器
tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-en-ro',
                                          use_fast=True)

print(tokenizer)

#编码试算
tokenizer.batch_encode_plus(
    [['Hello, this one sentence!', 'This is another sentence.']])

PreTrainedTokenizer(name_or_path='Helsinki-NLP/opus-mt-en-ro', vocab_size=59543, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>'})
{'input_ids': [[125, 778, 3, 63, 141, 9191, 23, 187, 32, 716, 9191, 2, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

from datasets import load_dataset, load_from_disk

#加载数据
dataset = load_dataset(path='wmt16', name='ro-en')
# dataset = load_from_disk('datas/wmt16/ro-en')

#采样,数据量太大了跑不动
dataset['train'] = dataset['train'].shuffle(1).select(range(20000))
dataset['validation'] = dataset['validation'].shuffle(1).select(range(200))
dataset['test'] = dataset['test'].shuffle(1).select(range(200))


#数据预处理
def preprocess_function(data):
    #取出数据中的en和ro
    en = [ex['en'] for ex in data['translation']]
    ro = [ex['ro'] for ex in data['translation']]

    #源语言直接编码就行了
    data = tokenizer.batch_encode_plus(en, max_length=128, truncation=True)

    #目标语言在特殊模块中编码
    with tokenizer.as_target_tokenizer():
        data['labels'] = tokenizer.batch_encode_plus(
            ro, max_length=128, truncation=True)['input_ids']

    return data


dataset = dataset.map(function=preprocess_function,
                      batched=True,
                      batch_size=1000,
                      num_proc=4,
                      remove_columns=['translation'])

print(dataset['train'][0])

dataset

{'input_ids': [460, 354, 3794, 12, 10677, 20, 5046, 14, 4, 2546, 37, 8, 397, 5551, 30, 10113, 37, 3501, 19814, 18, 8465, 20, 4, 44690, 782, 2, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'labels': [902, 576, 2946, 76, 10815, 17, 5098, 14997, 5, 559, 1140, 43, 2434, 6624, 27, 50, 337, 19216, 46, 22174, 17, 2317, 121, 16825, 2, 0]}
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 20000
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 200
    })
})

#这个函数和下面这个工具类等价,但我也是模仿实现的,不确定有没有出入
#from transformers import DataCollatorForSeq2Seq
#DataCollatorForSeq2Seq(tokenizer, model=model)

import torch


#数据批处理函数
def collate_fn(data):
    #求最长的label
    max_length = max([len(i['labels']) for i in data])

    #把所有的label都补pad到最长
    for i in data:
        pads = [-100] * (max_length - len(i['labels']))
        i['labels'] = i['labels'] + pads

    #把多个数据整合成一个tensor
    data = tokenizer.pad(
        encoded_inputs=data,
        padding=True,
        max_length=None,
        pad_to_multiple_of=None,
        return_tensors='pt',
    )

    #定义decoder_input_ids
    data['decoder_input_ids'] = torch.full_like(data['labels'],
                                                tokenizer.get_vocab()['<pad>'],
                                                dtype=torch.long)
    data['decoder_input_ids'][:, 1:] = data['labels'][:, :-1]
    data['decoder_input_ids'][data['decoder_input_ids'] ==
                              -100] = tokenizer.get_vocab()['<pad>']

    return data


data = [{
    'input_ids': [21603, 10, 37, 3719, 13],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [10455, 120, 80]
}, {
    'input_ids': [21603, 10, 7086, 8408, 563],
    'attention_mask': [1, 1, 1, 1, 1],
    'labels': [301, 53, 4074, 1669]
}]

collate_fn(data)['decoder_input_ids']

tensor([[59542, 10455,   120,    80],
        [59542,   301,    53,  4074]])

import torch

#数据加载器
loader = torch.utils.data.DataLoader(
    dataset=dataset['train'],
    batch_size=8,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

for i, data in enumerate(loader):
    break

for k, v in data.items():
    print(k, v.shape, v[:2])

len(loader)

from transformers import AutoModelForSeq2SeqLM, MarianModel

#加载模型
#model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-en-ro')


#定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pretrained = MarianModel.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')

        self.register_buffer('final_logits_bias',
                             torch.zeros(1, tokenizer.vocab_size))

        self.fc = torch.nn.Linear(512, tokenizer.vocab_size, bias=False)

        #加载预训练模型的参数
        parameters = AutoModelForSeq2SeqLM.from_pretrained(
            'Helsinki-NLP/opus-mt-en-ro')
        self.fc.load_state_dict(parameters.lm_head.state_dict())

        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, labels, decoder_input_ids):
        logits = self.pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 decoder_input_ids=decoder_input_ids)
        logits = logits.last_hidden_state

        logits = self.fc(logits) + self.final_logits_bias

        loss = self.criterion(logits.flatten(end_dim=1), labels.flatten())

        return {'loss': loss, 'logits': logits}


model = Model()

#统计参数量
print(sum(i.numel() for i in model.parameters()) / 10000)

#out = model(**data)
#out['loss'], out['logits'].shape

from datasets import load_metric

#加载评价函数
metric = load_metric(path='sacrebleu')

#试算
metric.compute(predictions=['hello there', 'general kenobi'],
               references=[['hello there'], ['general kenobi']])



测试

#测试
def test():
    model.eval()

    #数据加载器
    loader_test = torch.utils.data.DataLoader(
        dataset=dataset['test'],
        batch_size=8,
        collate_fn=collate_fn,
        shuffle=True,
        drop_last=True,
    )

    predictions = []
    references = []
    for i, data in enumerate(loader_test):
        #计算
        with torch.no_grad():
            out = model(**data)

        pred = tokenizer.batch_decode(out['logits'].argmax(dim=2))
        label = tokenizer.batch_decode(data['decoder_input_ids'])
        predictions.extend(pred)
        references.extend(label)

        if i % 2 == 0:
            print(i)
            input_ids = tokenizer.decode(data['input_ids'][0])

            print('input_ids=', input_ids)
            print('pred=', pred[0])
            print('label=', label[0])

        if i == 10:
            break

    references = [[j] for j in references]
    metric_out = metric.compute(predictions=predictions, references=references)
    print(metric_out)


test()



from transformers import AdamW
from transformers.optimization import get_scheduler


#训练
def train():
    optimizer = AdamW(model.parameters(), lr=2e-5)
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    model.train()
    for i, data in enumerate(loader):
        out = model(**data)
        loss = out['loss']

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        optimizer.step()
        scheduler.step()

        optimizer.zero_grad()
        model.zero_grad()

        if i % 50 == 0:
            out = out['logits'].argmax(dim=2)
            correct = (data['decoder_input_ids'] == out).sum().item()
            total = data['decoder_input_ids'].shape[1] * 8
            accuracy = correct / total
            del correct
            del total

            predictions = []
            references = []
            for j in range(8):
                pred = tokenizer.decode(out[j])
                label = tokenizer.decode(data['decoder_input_ids'][j])
                predictions.append(pred)
                references.append([label])

            metric_out = metric.compute(predictions=predictions,
                                        references=references)

            lr = optimizer.state_dict()['param_groups'][0]['lr']

            print(i, loss.item(), accuracy, metric_out, lr)

    torch.save(model, 'models/7.翻译.model')


train()



model = torch.load('models/7.翻译.model')
test()











本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/582562.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

OSPF大作业

一&#xff0c;拓扑 二&#xff0c;要求 1&#xff0c;r4为ISP上只配ip&#xff0c;r3与其他设备之间只使用公有IP 2&#xff0c;r3-r5/6/7为MGRE环境&#xff0c;r3为中心 3&#xff0c;整个OSPF的ip地址基于172.16.0.0/16划分 所以设备都可以访问r4的环回 4减少LSA的数量加快…

【转载】C++代码中将函数返回类型后置有啥好处吗

C代码中将函数返回类型后置有啥好处吗&#xff1f; 内容如下&#xff1a; C代码中将函数返回类型后置有啥好处吗&#xff1f; 这种语法是 C11 新增的&#xff0c;学名叫 trailing return type[1]。翻译过来是后置返回类型&#xff0c;trailing 是后面的、拖尾的意思。书写 int …

质量管理,怎能不知道SPC?

SPC&#xff08;统计过程控制&#xff09;是质量管理的一个重要组成部分&#xff0c;它可以帮助企业更好地控制生产过程、提高产品质量、降低生产成本、增强企业的竞争力。 SPC是一种应用统计技术对过程中的各个阶段进行评估和监控&#xff0c;建立并保持过程处于可接受的并且稳…

深入理解Linux调试工具eBPF和strace、内存泄漏处理、Kubernetes容器调试以及C++协程的崩溃信息收集

在软件开发领域&#xff0c;无论是初级开发者还是资深工程师&#xff0c;都需要面对复杂的调试工作。本文将介绍几个重要的调试工具和技术&#xff0c;并提供实际调试方法的指导&#xff0c;包括Linux环境下的eBPF和strace&#xff0c;内存泄漏问题的处理&#xff0c;Kubernete…

020Node.js的FS模块使用fs.mkdir创建目录

Node.js的FS模块使用fs.mkdir创建目录 //fs.mkdir 创建目录 /*path 将创建的目录路径mode 目录权限&#xff08;读写权限&#xff09;&#xff0c;默认777callback 回调&#xff0c;传递异常参数err*/ const fsrequire(fs);fs.mkdir(./css,(err)>{if(err){console.log(err)…

时间序列模型(含python程序实现)

常用按时间顺序排列的一组随机变量来表示一个随机事件的时间序列&#xff0c;简记为 用表示该随机序列的n个有序观察值&#xff0c;称之为序列长度为n的观察值序列。 常用的时间序列模型 时间序列的预处理 拿到一个观察值序列后&#xff0c;首先要对它的纯随机性和平稳性进行…

PC-3000 Mobile Pro: 智能手机及平板设备数据提取及取证工具

天津鸿萌科贸发展有限公司从事数据安全业务20余年&#xff0c;在数据恢复、数据取证、数据备份等领域有丰富的案例经验、前沿专业技术及良好的行业口碑。同时&#xff0c;公司面向取证机构及数据恢复公司&#xff0c;提供数据恢复实验室建设方案&#xff0c;包含数据恢复硬件设…

书生·浦语 大模型(学习笔记-9)大模型微调的相关概念 预训练 与 微调 全量微调FFT 与 PEFT的区别

目录 一、什么是大模型微调 二、指令微调 三、微调的目的 三、微调的方式 四、微调的步骤 五、微调数据准备 六、微调的数据质量 一、什么是大模型微调 预训练和微调的区别&#xff0c;这个很关键 二、指令微调 这个地方主要是微调的角度不同&#xff0c;简单理解&#…

linux jmeter ant下载并安装【2024-亲测】

环境 centos7 一、下载jmeter 在这里插入代码片wget https://dlcdn.apache.org//jmeter/binaries/apache-jmeter-5.6.3.tgz --no-check-certificate解压 tar -zxvf apache-jmeter-5.6.3.tgz复制到安装目录、设置环境变量 vim /etc/profile添加环境变量&#xff0c;路径改成…

YOLOv8: 快速而准确的对象检测

YOLOv8: 快速而准确的对象检测 背景 对象检测是计算机视觉中的一个关键任务,它可以帮助我们在图像或视频中识别和定位感兴趣的物体。其中,YOLO(You Only Look Once)系列是一类非常出色的实时对象检测算法,以其快速和准确的特点而闻名。YOLOv8是YOLO系列的最新版本,由Ultralyti…

【ESP32S3】使用 Flash 下载工具完成 Flash 加密功能

此篇文档记录通过 Flash 下载工具 完成 Flash 加密 功能的实现&#xff0c;此文档不启用 Flash 加密方案的 NVS 加密。 Flash 加密启动的验证代码&#xff1a;esp-idf/components/bootloader_support/src/flash_encrypt.c Flash 加密测试例程&#xff1a;esp-idf/examples/sec…

【yolov8目标检测部署】TensorRT int8量化

原作者github&#xff1a;https://github.com/xuanandsix/Tensorrt-int8-quantization-pipline/tree/main 改进&#xff1a; 源代码支持的TensorRT版本为7.许多属性已经弃用&#xff1b; 在原有的代码上将支持的TensorRT版本从7改到8. &#xff01;&#xff01;不知道如何安装T…

酷得电子机器狗玩具 MCU方案介绍

机器狗是一种多功能、互动性强的机器人&#xff0c;适合家庭和学校环境。它不仅可以陪伴孩子们玩耍&#xff0c;还能帮助他们学习和成长。功能如下&#xff1a; 关节可动&#xff1a;机器狗的关节设计灵活&#xff0c;可以执行各种动作&#xff0c;如“坐下”、“俯卧撑”、“…

【Ant-Desgin-React 步骤条】步骤条配合组件使用

步骤条配合组件使用 基础使用多分组进度 基础使用 /* eslint-disable no-unused-vars */ import React, { useState } from react import { Button, message, Steps, theme } from antd import After from ./components/after import Now from ./components/now const steps …

纯js对比excel小工具

如何使用JavaScript和xlsx.js实现Excel文件对比&#xff1a;实战指南 在日常办公或数据分析工作中&#xff0c;我们经常需要比较两个Excel文件中的数据差异。手动对比不仅耗时费力&#xff0c;还容易出错。本文将带你通过一个简单的网页应用&#xff0c;利用JavaScript和开源库…

【产品经理】如果人人都是产品经理,那么如何提升自己的不可替代性?

任何职业都需要有危机感&#xff0c;不只是产品经理。 人有生老病死&#xff0c;相应的职场上也有升降变离。当乔布斯站在宇宙之巅望着芸芸众生说“活着就是为了改变世界”的时候&#xff0c;这话着实燃烧了我们一把。随之&#xff0c;马化腾、周鸿祎、张小龙、王小川等汹涌而入…

前端canvas项目实战——在线图文编辑器(九):逻辑画布

目录 前言一、 效果展示二、 实现步骤1. 调整布局&#xff0c;最大化利用屏幕空间2. 添加逻辑画布3. 添加遮罩4. 居中显示逻辑画布5. 一个容易被忽视的bug点 三、Show u the code后记 前言 上一篇博文中&#xff0c;我们实现了一组通用的功能按钮&#xff1a;复制、删除、锁定…

Eclipse内存分析器 Java内存分析工具MAT(Memory Analyzer Tool)的介绍与使用

1.visualvm实时监测 2.Memory Analyzer Tool打开 3.工具的使用可以参考 Java内存分析工具MAT(Memory Analyzer Tool)的介绍与使用 ------------------------ 1.我远程发现是其中一个客户端A请求服务器页面响应&#xff0c;一直得不到响应&#xff0c;然后客户端A一直请求&am…

js 字符串 第一个斜杠前最后一次出现英文字母的位置并添加自定义值,返回新值

要找到字符串中第一个斜杠&#xff08;/&#xff09;前最后一次英文字母出现的位置&#xff0c;可以使用正则表达式配合lastIndexOf方法。以下是实现这一功能的示例代码&#xff1a; 如果是匹配第一个数字前的字母加值可以看这里 function findLastLetterIndexBeforeSlash(str…

外贸旺季外贸人如何做好时间管理

第1步 记住这些原则 50-30-20原则 你工作日里50%的时间应该花在有益于你长期发展目标的事情上&#xff0c;30%的时间应该用于你完成中期(两年左右)目标的事情&#xff0c;20%的时间用于完成未来90天以内需要完成的任务。 “一个篮子”原则 One Bucket 尽可能减少自己接收新任务…