LLM模型

什么是LLM(大语音模型)

概述

Large Language Model(LLM),也称为大型语言模型,是一种基于机器学习和自然语言处理技术的模型,它通过对大量的文本数据进行训练,来学习服务人类语言理解和生成的能力

LLM的核心思想是通过大规模的无监督训练来学习自然语言的模式和语言结构,这在一定程度上能够模拟人类的语言认知和生成过程

与传统的NLP模型相比,LLM能够更好地理解和生成自然文本,同时还能够表现出一定的逻辑思维和推理能力

近年来,LLM得到了广泛的应用,其中最具代表性的是谷歌的BERT和OpenAI的GPT系列。这些模型在多个自然语言处理领域已经取得了显著的成果,包括文本分类、命名实体识别、情感分析、机器翻译、自动问答等

然而,在实际应用中,LLM面临着更多的挑战

  1. 首先,LLM需要大量的计算资源和大规模的数据集来训练,这对于一般的企业和个人来说十分困难
  2. 其次,由于LLM模型的复杂性和计算量较大,对于实时的语言处理应用来说,LLM在应用效率和响应速度上还存在一定的局限性

因此,如何解决模型训练和应用过程中的计算性能和效率问题,是LLM面临的主要挑战之一

微调

LLM大模型低资源微调p tuning v2和lora区别

prefix, p-tuningv2, lora finetune该怎么选择?

让天下没有难Tuning的大模型:PEFT技术简介 2023-04

大模型高效微调综述上:Adapter Tuning、AdaMix、PET、Prefix-Tuning、Prompt Tuning、P-tuning、P-tuning v2

大模型高效微调综述下: DiffPruning、BitFit、LoRa、AdaLoRA、MAM Adapters、UniPELT

微调(Fine-tuning)是一种常用的技术,用于将预训练的语言模型适应于特定的任务或领域。微调的目的是通过在特定任务上进行有监督的训练,调整模型参数以提高其性能和适应性

以下是微调在适应语言模型中的有效性的几个原因:

  1. 迁移学习:预训练的语言模型在大规模文本数据上进行了无监督的学习,从中学习到了通用的语言表示。通过微调,我们可以将这些通用的语言表示迁移到特定任务或领域上,因此可以利用模型在预训练阶段学到的知识
  2. 少样本学习:微调通常只需要在特定任务的相对较小的标注数据集上进行训练,而不是从头开始训练一个全新的模型。这对于许多任务来说是非常有益的,因为获得大规模标注数据可能是昂贵或困难的。通过利用预训练模型的泛化能力,微调可以在少量标注样本上实现较好的性能
  3. 领域自适应:通过微调,可以将语言模型从通用领域适应到特定领域。通过在特定领域的数据上微调,模型可以学习到该领域的特定语言模式、词汇和上下文,从而提高在该领域任务上的性能
  4. 模型个性化:微调还可以用于个性化模型,以适应特定用户或特定应用场景的需求。通过微调模型,可以根据个体用户的偏好、行为或数据特点进行定制,提供更准确和个性化的预测和推荐

微调语言模型是一种有效的方法,可以通过迁移学习、少样本学习、领域自适应和模型个性化等方式,利用预训练模型的优势和泛化能力,提高模型在特定任务或领域上的性能和适应性

为什么需要微调

  1. 高效训练,减少训练成本
  2. 共享基础大模型,在上面叠加自己的新模型

发展脉络

LLM微调技术发展脉络

Adapter系列

AdapterFusion: Non-Destructive Task Composition for Transfer Learning 2021

Lexicon Enhanced Chinese Sequence Labeling Using BERT Adapter 2021

LLaMA-Adapter: Efficient Fine-tuning of Language Models with Zero-init Attention 2023

LLaMA-Adapter V2: Parameter-Efficient Visual Instruction Model 2023

github LLaMA-Adapter: Efficient Fine-tuning of LLaMA

p-tunning系列

Prefix-Tuning: Optimizing Continuous Prompts for Generation 2021

The Power of Scale for Parameter-Efficient Prompt Tuning 2021

P-Tuning - GPT Understands, Too 2021

P-Tuning v2: Prompt Tuning Can Be Comparable to Fine-tuning Universally Across Scales and Tasks 2022

lora系列

LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS 2021

AdaLoRA Adaptive Budget Allocation for Parameter-Efficient Fine-Tuning 2023

QLORA: Efficient Finetuning of Quantized LLM 2023

另外huggingface很贴心的把常见的fine-Tuning方法都做了集成,只用几行代码就可添加和修改,十分方便,还有微软提供的加速库

解密Prompt系列3. 冻结LM微调Prompt: Prefix-Tuning & Prompt-Tuning & P-Tuning

微调LM和全部冻结的prompt模板相比,微调Prompt范式最大的区别就是prompt模板都是连续型(Embedding),而非和Token对应的离散型模板

核心在于我们并不关心prompt本身是否是自然语言,只关心prompt作为探针能否引导出预训练模型在下游任务上的特定能力

固定LM微调Prompt的范式有以下几个优点

  • 性价比高: 微调参数少,冻结LM只微调prompt部分的参数
  • 无人工参与: 无需人工设计prompt模板,依赖模型微调即可
  • 多任务共享模型: 因为LM被冻结,只需训练针对不同任务的prompt即可。因此可以固定预训练模型,拔插式加入Prompt用于不同下游任务

Adapter Tuning

预训练模型微调 | 一文带你了解Adapter Tuning

Adapter Tuning

随着计算机硬件性能的提高,预训练模型参数量越来越多,在训练下游任务时进行全模型微调变得昂贵且耗时,Adapter 的出现缓解了这个问题。Adapter在预训练模型每层中插入用于下游任务的参数,在微调时将模型主体冻结,仅训练特定于任务的参数,减少训练时算力开销

Adapter模块设计方法

2019年,Houlsby N等人将Adapter引入NLP领域,作为全模型微调的一种替代方案。Adapter主体架构下图所示

img

Prefix/Prompt-Tuning

hugging face参数高效微调peft源码解析

P系列关系:

  • Prefix-Tuning(软提示/连续提示)
  • Prompt-Tuning(软提示/连续提示)(可看做是Prefix-Tuning的简化版本)
  • P-Tuning(软提示/连续提示)
  • P-Tuning V2(软提示/连续提示)(可看做是Prefix-Tuning的优化版本)

Prefix Tuning和PTuning V2在实现上基本上是一样的,其实就是一样的

下面是peft作者回复的关于Prefix Tuning和PTuning V2在实现上的关系(How to switch to P-Tuning v2)

1
2
3
Hello, those are implemented together. P-Tuning v2 introduced optional parameterization of prompt tokens which you can specify via prefix_projection of PrefixTuningConfig. The other contribution was the ability of work without verbalizers using the linear classification head for NLU tasks whereas Prefix-Tuning paper which focused on NLG didn't focus on this.

So, they are supported via the same PrefixEncoder PEFT method

另外在peft/peft_model.p的代码中有这样一段(大概1106行)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
if peft_config.peft_type == PeftType.PREFIX_TUNING:
# PREFIX_TUNING、P_TUNING_V2
past_key_values = self.get_prompt(batch_size)
return self.base_model(
input_ids=input_ids, inputs_embeds=inputs_embeds, past_key_values=past_key_values, **kwargs
)
else:
# PROMPT_TUNING、P_TUNING
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
# concat prompt labels
if labels is not None:
prefix_labels = torch.full((batch_size, peft_config.num_virtual_tokens), -100).to(labels.device)
kwargs["labels"] = torch.cat((prefix_labels, labels), dim=1)
prompts = self.get_prompt(batch_size=batch_size, task_ids=task_ids)
prompts = prompts.to(inputs_embeds.dtype)
inputs_embeds = torch.cat((prompts, inputs_embeds), dim=1)
return self.base_model(inputs_embeds=inputs_embeds, **kwargs)

可以看出PREFIX_TUNING生效是通过past_key_values传播的,下面是通过拼接到inputs_embeds上实现的

Prefix-Tuning

等待…

Prefix-Tuning可以理解是CTRL[1]模型的连续化升级版,为了生成不同领域和话题的文本,CTRL是在预训练阶段在输入文本前加入了control code,例如好评前面加’Reviews Rating:5.0’,差评前面加’Reviews Rating:1.0’, 政治评论前面加‘Politics Title:’,把语言模型的生成概率,优化成了基于文本主题的条件概率

Prefix-Tuning进一步把control code优化成了虚拟Token,每个NLP任务对应多个虚拟Token的Embedding(prefix),对于Decoder-Only的GPT,prefix只加在句首,对于Encoder-Decoder的BART,不同的prefix同时加在编码器和解码器的开头。在下游微调时,LM的参数被冻结,只有prefix部分的参数进行更新。不过这里的prefix参数不只包括embedding层而是虚拟token位置对应的每一层的activation都进行更新

Prompt-Tuning

https://github.com/google-research/prompt-tuning

等待…

Prompt-Tunning是以上prefix-Tunning的简化版本,面向NLU任务,进行了更全面的效果对比,并且在大模型上成功打平了LM微调的效果

对比Prefix-Tunning,prompt-tuning的主要差异如下:

论文使用100个prefix token作为默认参数,大于以上prefix-tuning默认的10个token,不过差异在于prompt-Tunning只对输入层(Embedding)进行微调,而Prefix是对虚拟Token对应的上游layer全部进行微调。因此Prompt-Tunning的微调参数量级要更小,且不需要修改原始模型结构,这是“简化”的来源。相同的prefix长度,Prompt-Tunning(<0.01%)微调的参数量级要比Prefix-Tunning(0.1%~1%)小10倍以上

P-Tuning

P-Tuning V1

github THUDM/P-tuning

手动尝试最优的提示无异于大海捞针,于是便有了自动离散提示搜索的方法(左图),但提示是离散的,神经网络是连续的,所以寻找的最优提示可能是次优的。p-tuning依然是固定LLM参数,利用多层感知机和LSTM对prompt进行编码,编码之后与其他向量进行拼接之后正常输入LLM。注意,训练之后只保留prompt编码之后的向量即可,无需保留编码器

p_tunning架构v1 vs 离散型promote

动机

  • 一个刻板印象是GPT不适合理解类任务,这篇就是去思考这种刻板印象是否正确
  • GPT-3采用人工构造的模版来做in context learning,人工设计的模版的变化特别敏感,加一个词或者少一个词,或者变动位置啥的都会造成比较大的变化(这里作者做了一个简单的验证实验,具体看论文)。近来的自动化搜索模版工作成本也比较高,同时以前这种离散化的token的搜索出来的结果可能并不是最优的

和prefix-tuning差不多,反正是基于这两点去设计了一种连续可微的模版

相比prefix-tuning,这里加了可微的virtual token,但是仅限于输入,没有在每层加;另外virtual token的位置也不一定是前缀,插入的位置是可选的。这里的出发点实际是把传统人工设计模版中的真实token替换成可微的virtual token

P-Tuning V2

github THUDM/P-tuning-v2

P-tuning V2论文和代码实现详解

chatGLM的浅薄解析 P-tuning V2

大模型参数高效微调技术原理综述(三)-P-Tuning、P-Tuning v2

可以简单的将P-Tuning认为是针对Prompt Tuning的改进,P-Tuning v2认为是针对Prefix Tuning的改进

概述

p_tunning架构v1 vs v2

代码示例

PrefixEncoder类,为了获得连续prompt,设计的模块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch


class PrefixEncoder(torch.nn.Module):
r'''
The torch.nn model to encode the prefix

Input shape: (batch-size, prefix-length)

Output shape: (batch-size, prefix-length, 2*layers*hidden)
'''
def __init__(self, config):
super().__init__()
self.prefix_projection = config.prefix_projection
if self.prefix_projection:
# Use a two-layer MLP to encode the prefix
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.hidden_size)

# 初始化重参数化的编码器
self.trans = torch.nn.Sequential(
torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
torch.nn.Tanh(),
torch.nn.Linear(config.prefix_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
)
else:
self.embedding = torch.nn.Embedding(config.pre_seq_len, config.num_hidden_layers * 2 * config.hidden_size)

def forward(self, prefix: torch.Tensor):
if self.prefix_projection:
prefix_tokens = self.embedding(prefix)
past_key_values = self.trans(prefix_tokens)
else:
past_key_values = self.embedding(prefix)
return past_key_values

源码也可以看到 Prefix Tuning 与 P-Tuning v2 最主要的差别就是是否进行重新参数化编码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class BertPrefixForTokenClassification(BertPreTrainedModel):
def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels
self.bert = BertModel(config, add_pooling_layer=False)
self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)
self.classifier = torch.nn.Linear(config.hidden_size, config.num_labels)

from_pretrained = False
if from_pretrained:
self.classifier.load_state_dict(torch.load('model/checkpoint.pkl'))

for param in self.bert.parameters():
param.requires_grad = False

self.pre_seq_len = config.pre_seq_len
self.n_layer = config.num_hidden_layers
self.n_head = config.num_attention_heads
self.n_embd = config.hidden_size // config.num_attention_heads

self.prefix_tokens = torch.arange(self.pre_seq_len).long()
self.prefix_encoder = PrefixEncoder(config)

bert_param = 0
for name, param in self.bert.named_parameters():
bert_param += param.numel()
all_param = 0
for name, param in self.named_parameters():
all_param += param.numel()
total_param = all_param - bert_param
print('total param is {}'.format(total_param)) # 9860105

def get_prompt(self, batch_size):
prefix_tokens = self.prefix_tokens.unsqueeze(0).expand(batch_size, -1).to(self.bert.device)
# 得到连续Prompt
past_key_values = self.prefix_encoder(prefix_tokens)
# bsz, seqlen, _ = past_key_values.shape
# 改变形状
past_key_values = past_key_values.view(
batch_size,
self.pre_seq_len,
self.n_layer * 2,
self.n_head,
self.n_embd
)
past_key_values = self.dropout(past_key_values)
# 改变形状,划分成数组。每一个数组元素形状为:(2,batch_size,n_head,seq_len,head_dim)
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(2)
return past_key_values

def forward(
self,
input_ids=None,
attention_mask=None,
token_type_ids=None,
position_ids=None,
head_mask=None,
inputs_embeds=None,
labels=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

batch_size = input_ids.shape[0]
past_key_values = self.get_prompt(batch_size=batch_size)
prefix_attention_mask = torch.ones(batch_size, self.pre_seq_len).to(self.bert.device)
attention_mask = torch.cat((prefix_attention_mask, attention_mask), dim=1)

# 开始传递past_key_values
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
)

...

return TokenClassifierOutput(
loss=loss,
logits=logits,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)

一次前向计算中,P-tuning v2会通过self.get_prompt(batch_size=batch_size)得到要连续Prompt

BertEncoder会执行for循环,把past_key_values拆分到一个个BertLayer

1
2
3
4
5
6
7
8
9
10
11
12
13
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])

...

for i, layer_module in enumerate(self.layer):
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)

layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
...
# BertLayer
layer_module(..., past_key_value, ...)

巧妙的利用past_key_values参数,将past_key_values数组中每一个元素,拼接到BertSelfAttention中Key和Value

代码跟踪链路BertModel -> BertEncoder -> BertLayer -> BertAttention -> BertSelfAttention

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
class BertSelfAttention(nn.Module):
...

def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
# 将张量转换形状,调换维度。这个代码会在seq_length维度进行拼接,其他维度不可动
new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.FloatTensor] = None,
head_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor]:
mixed_query_layer = self.query(hidden_states)

# If this is instantiated as a cross-attention module, the keys
# and values come from an encoder; the attention mask needs to be
# such that the encoder's padding tokens are not attended to.
is_cross_attention = encoder_hidden_states is not None

if is_cross_attention and past_key_value is not None:
# reuse k,v, cross_attentions
key_layer = past_key_value[0]
value_layer = past_key_value[1]
attention_mask = encoder_attention_mask
elif is_cross_attention:
key_layer = self.transpose_for_scores(self.key(encoder_hidden_states))
value_layer = self.transpose_for_scores(self.value(encoder_hidden_states))
attention_mask = encoder_attention_mask
elif past_key_value is not None:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
else:
key_layer = self.transpose_for_scores(self.key(hidden_states))
value_layer = self.transpose_for_scores(self.value(hidden_states))

query_layer = self.transpose_for_scores(mixed_query_layer)

...

这里就会把past_key_value拼接到了原始的k、v上面,这样子就相当于给k、v添加了额外需要学习的参数了,再微调时只更新这部分新的参数即可

P-tuning V2连续Prompt代码实现仿真代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
#!/usr/bin/env Python
# -- coding: utf-8 --

"""
@version: v1.0
@author: huangyc
@file: p_tuning_test.py
@Description:
@time: 2023/6/6 15:31
"""
import torch
from torch import nn


def run():
def transpose_for_scores(x: torch.Tensor) -> torch.Tensor:
new_x_shape = x.size()[:-1] + (12, 64)
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)

prompt = torch.rand(32, 128, 48, 12, 64) # batch_size, seq_len, num_layer*2, num_head, head_size
prompt = prompt.permute([2, 0, 3, 1, 4])
print(f"P-tuningV2构造的trainable continuous embeddings形状:{prompt.shape}")
past_key_values = prompt.split(2)
num_layers = 24
hidden_dim = 768
n_head = 12
head_dim = hidden_dim // n_head
all_head_size = n_head * head_dim
hidden_states = torch.randn(32, 128, 768) # batch_size, seq_len, hidden_size
print(f"输入的向量形状:{hidden_states.shape}")
for i in range(num_layers):
past_key_value = past_key_values[i]
print(f"每一层BertLayer需要加入的prompt形状: {past_key_value.shape}")
self_attn_past_key_value = past_key_value[:2] if past_key_value is not None else None
# BertSelfAttention
query = nn.Linear(hidden_dim, all_head_size)
key = nn.Linear(hidden_dim, all_head_size)
value = nn.Linear(hidden_dim, all_head_size)

# 原始kv的大小
key_layer = transpose_for_scores(key(hidden_states))
old_key_layer_shape = key_layer.shape
print(f"经过transpose_for_scores后的key形状:{old_key_layer_shape}")
value_layer = transpose_for_scores(value(hidden_states))
old_value_layer_shape = value_layer.shape
print(f"经过transpose_for_scores后的value形状:{old_value_layer_shape}\n")

# 拼接后kv的大小
key_layer = torch.cat([past_key_value[0], key_layer], dim=2)
print(
f"past_key_value[0]的形状:{past_key_value[0].shape} 原始key_layer的形状:{old_key_layer_shape} 经过cat后的key_layer形状:{key_layer.shape}")
value_layer = torch.cat([past_key_value[1], value_layer], dim=2)
print(
f"past_key_value[1]的形状:{past_key_value[1].shape} 原始value_layer的形状:{old_value_layer_shape} 经过cat后的value_layer形状:{value_layer.shape}\n")

mixed_query_layer = query(hidden_states)
print(f"hidden_states经过query层后输出的形状:{mixed_query_layer.size()}") # batch seq len embed
query_layer = transpose_for_scores(mixed_query_layer)
print(f"经过transpose_for_scores后的query形状{query_layer.size()}") # batch

print("注意力分数开始计算")
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
print(f"attention_scores的形状:{attention_scores.size()}") # batch head seq_len seq_len
print("开始注意力汇聚计算")
context_layer = torch.matmul(attention_scores, value_layer)
print(f"注意力汇聚后输出矩阵context_layer的形状:{context_layer.size()}") # batch head seq_len embed/12
print("最后,将context_layer的形状恢复成输入hidden_states的形状")
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
new_context_layer_shape = context_layer.size()[:-2] + (768,)
context_layer = context_layer.view(new_context_layer_shape)
print(f"context_layer的形状恢复完成,其形状为:{context_layer.size()}")
print("一次P-tuningV2的BertLayer计算仿真结束")
break


if __name__ == '__main__':
run()

测试输出

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
S:\Anaconda3\envs\torch38\python.exe Q:\pyCharmWS\chatgpts\P-tuning-v2\tests\p_tuning_test.py 
P-tuningV2构造的trainable continuous embeddings形状:torch.Size([48, 32, 12, 128, 64])
输入的向量形状:torch.Size([32, 128, 768])
每一层BertLayer需要加入的prompt形状: torch.Size([2, 32, 12, 128, 64])
经过transpose_for_scores后的key形状:torch.Size([32, 12, 128, 64])
经过transpose_for_scores后的value形状:torch.Size([32, 12, 128, 64])

====================> 核心
past_key_value[0]的形状:torch.Size([32, 12, 128, 64]) 原始key_layer的形状:torch.Size([32, 12, 128, 64]) 经过cat后的key_layer形状:torch.Size([32, 12, 256, 64])
past_key_value[1]的形状:torch.Size([32, 12, 128, 64]) 原始value_layer的形状:torch.Size([32, 12, 128, 64]) 经过cat后的value_layer形状:torch.Size([32, 12, 256, 64])
====================> 核心

hidden_states经过query层后输出的形状:torch.Size([32, 128, 768])
经过transpose_for_scores后的query形状torch.Size([32, 12, 128, 64])
注意力分数开始计算
attention_scores的形状:torch.Size([32, 12, 128, 256])
开始注意力汇聚计算
注意力汇聚后输出矩阵context_layer的形状:torch.Size([32, 12, 128, 64])
最后,将context_layer的形状恢复成输入hidden_states的形状
context_layer的形状恢复完成,其形状为:torch.Size([32, 128, 768])
一次P-tuningV2的BertLayer计算仿真结束

LORA系列

点赞👍B站博主 小杨不努力0v0 + 博主相关的文章链接

LORA(转)

LoRA:训练你的GPT【论文粗读·1】

一种通过低秩近似增量矩阵的,经过广泛验证足够Robust的微调方法

摘要

随着自然语言处理(NLP)模型规模的不断增长,由于成本和资源限制,对其进行完全微调以用于下游任务的挑战日益增加

介绍低秩适应(Low-Rank Adaptation)LoRA通过引入参数矩阵来减少参数,并将GPU内存需求降低了3倍。相比于使用GPT-3进行微调,它将参数减少了10,000倍

尽管可训练参数更少,但LoRA在大多数语言模型上表现优于微调,具有更高的训练吞吐量和无推理延迟

对语言模型适应中的秩缺失进行的实证研究为LoRA的有效性提供了证据,LoRA是开源的

介绍

对于下游任务而言,完全微调大型语言模型是具有挑战性的。 受到[内在维度]的研究启发,我们提出了LoRA,具有以下优势:

  • 低任务切换开销: 一个预训练模型可以被共享并用于构建多个针对不同任务的小型LoRA模块
  • 参数高效: LoRA通过使用自适应优化器,在训练过程中使得训练更高效,并将硬件门槛降低了最多3倍
  • 无推理延迟: LoRA的简单线性设计使得可训练矩阵在部署时可以与冻结权重合并
  • 正交性: LoRA与许多先前的方法是正交的,并且可以与它们结合使用,比如前缀微调(prefix-tuning)

问题描述

模型训练时参数量评估,对于LLM,如果模型的参数时的话

  • 全量微调

使用Adam优化器下的,,并且使用混合精度,一个参数需要16个bytes来存储,这16个bytes分别为

权重需要来存储,激活值需要,为了更新权重还需存一个复制需要,优化器需要存两个值,分别是(方差),分别需要

其中占2bytes,占bytes,总共为bytes

因此,个参数就需要bytes

  • Lora

Lora在训练时,将原来的参数固定下来,只更新新增加的参数,因此Mem Required: bytes

在GPT-3的175B参数下,这里的可以达到原来的

为什么会提出Lora呢

原来的Adapter方法也是固定模型的参数,只训练MLP参数,但是这样子有几个弊端

  1. 增加了网络深度,增加了推理的时间
  2. 添加MLP之后,训练出来的最优方案也只能收敛到MLP层,不一定是全局最好的
  3. 直接去优化这个Promote并不能保证优化是单调的,也就是不是全局最优,很难优化好
  4. 减少了可用于处理下游任务的序列长度,因为新加入的Promote会占用输入的token长度

具体方法

最核心的思路如下公式所示,研究表明通常是一个欠秩的矩阵

因此,可以进行低秩分解

Lora思路

训练时,初始化为全零矩阵,这样子参数量就从变成了,这里的一般是远小于

优点

  • LoRA是对完全微调的一种推广方法

    • 在适应过程中,LoRA不需要对权重矩阵进行完全秩的累积梯度更新,而是可以基于预训练的权重矩阵设置秩
    • 当LoRA应用于所有权重矩阵并且偏置进行训练时,这种方法提供了类似于完全微调的表现能力
  • 没有额外的推理延迟

    在生产部署中,LoRA可以计算和存储,其中属于,当切换到另一个下游任务时,可以通过减去并加上不同的来恢复,这是一个快速操作,几乎没有额外的内存开销(潮汐GPU)

为什么低秩矩阵有效

  1. 当给定参数数量时,应该调整预训练Transformer中的哪些具体权重矩阵子集以实现最佳的下游性能?

    • 在给定参数预算的情况下,确定要调整的权重矩阵子集以实现最佳下游性能是一个复杂的问题,并且没有固定的答案

      通常,可以考虑根据下游任务的特点和需求进行权衡和选择

      一种常见的方法是通过对不同权重矩阵进行实验性微调,并根据性能评估来确定适合特定任务的权重矩阵子集

  2. 最优的适应矩阵是否真的是欠秩的吗?如果是,那么在实际情况下推荐的秩是多少?

    • 最优的适应矩阵是否真的是欠秩的,这取决于具体情况。秩缺失意味着矩阵的秩(矩阵的线性独立列数或行数的最大数量)较低
    • 对于实际目的,建议选择适当的秩以平衡模型性能和计算成本。具体推荐的秩取决于任务的复杂性、数据集的规模以及可用的计算资源等因素
  3. 之间的关系是什么?之间是否存在高相关性?的大小与相比如何?

    • 表示适应矩阵,用于调整预训练权重矩阵之间的关系取决于具体的适应方法和优化算法。在某些情况下,可以通过对的微小调整来获得,而在其他情况下,可能包含更大的变化
    • 之间的相关性取决于适应方法的设计和优化过程的细节。它们可以存在一定的相关性,但具体情况可能因模型架构、任务要求和数据集特征而异
    • 的大小与的大小之间没有固定的比较关系,因为它们的尺度取决于具体的数值范围和调整方法

对于attention参数附加到哪个上更有效

实验在GPT-3 175B模型上设置了一个参数预算为18M(如果以FP16存储,大约为35MB)。这对应于当我们适应一种注意力权重时r=8,或者当我们适应两种类型时r=4,适用于所有96层

Lora关于r参数选择的实验

需要注意的是,将所有参数放在中会导致显著降低性能,而注入到则产生最佳结果。这表明,即使秩为4,中包含了足够的信息,使得与使用具有较大秩的单一类型的权重相比,使用更多的权重矩阵更可取

回答:将参数设置到q、v上时,r取多少合适

Lora关于r参数取多少的实验

LoRA在非常小的秩(特别是对于而言)下已经取得了相当的竞争力,这表明更新矩阵可能具有非常小的内在秩,使用低秩矩阵对LLM进行fine tune的时候,可以用一个非常小的低秩矩阵,就可以捕捉到对下游任务的一些特征信息,这为我们提供了一个非常高效的LLM的fine tune的方式,同时提高了下游任务的性能

回答:之间的关系是什么

首先对进行奇异值分解,并且把它左奇异值向量和右奇异值向量乘到上,把映射到的子空间,并计算范数,同时还把映射到映射到随机矩阵上

以此来证明有更强的相关性

Lora下w和delta w的关系

首先,与随机矩阵相比,具有更强的相关性,表明放大了预训练模型中的中已经存在的某些特征

其次,不是重复的奇异值方向,而是只放大在中没有强调的方向

第三,放大因子相当巨大:当r=4时,21.5≈6.91/0.32,表明只是放大了中的一些特征,且放大倍数是很大的,相当于是把下游任务需要的特征提取出来并进行放大

AdaLoRA(转)

AdaLoRA:更强大的LoRA

github cauyxy/YourGPT

Lora中的r是一个确定值,但不是对于所有的层,像q、k这样的层,q内在秩比较大,v内在秩比较小,对于不同的矩阵应该使用不同的内在秩

AdaLoRA是在对同样的参数量下,对不同的矩阵使用不同的r,通过奇异值分解,判断r的大小,来取得更好的效果

提出问题

在NLP领域,对于下游任务进行大型预训练语言模型的微调已经成为一种重要的做法。一般而言,我们会采用对原有的预训练模型进行全量微调的方法来适配下游任务,但这种方法存在两个问题

  1. 训练阶段: 对于预训练模型进行微调的时候,为了更新权重参数,需要大量的显存来存储参数的梯度和优化器信息,在当今预训练模型的参数变得越来越大的情况下,针对下游任务微调门槛变得越来越高
  2. 推理阶段: 由于我们训练的时候是对于模型参数进行全量的更新,所以多个下游任务需要为每个任务维护一个大型模型的独立副本,这样就导致我们在实际应用的时候浪费了不必要的存储

现有方法: 为了解决这些问题,研究者提出了两个主要研究方向,以减少微调参数的数量,同时保持甚至提高预训练语言模型的性能

添加小型网络模块

将小型网络模块添加到PLMs中,保持基础模型保持不变的情况下仅针对每个任务微调这些模块,可以用于所有任务。这样,只需引入和更新少量任务特定的参数,就可以适配下游的任务,大大提高了预训练模型的实用性,方法示例

添加小型网络模块

  • Adapter tuning:是在基础模型的各层之间插入小型神经模块

  • Prefix tuning:将可训练的前缀标记附加到基础模型的输入或隐藏层上

  • Prompt Tuning: 修改模型的输入,在模型输入的前面加一些特定的前缀

可行之处:可以达到与完全微调几乎相当的性能,同时仅更新不到原始模型参数的1%,大大减少了内存消耗。

存在问题

  • Adapter tuning:引入了推理延迟,最终收敛到适配器层

  • Prefix or Prompt tuning:直接优化Prefix和Prompt是非单调的,比较难收敛,并且消耗了输入的token

下游任务增量更新

对预训练权重的增量更新进行建模,而无需修改模型架构

方法示例

  • Diff pruning:将初始化为与相同的维度,然后根据参数的大小按元素对进行剪枝

  • LoRA:通过两个小得多的矩阵的乘积将参数化为低阶矩阵

可行之处:可以达到与完全微调几乎相当的性能

存在问题:

  • Diff pruning
    • 需要底层实现来加速非结构化稀疏矩阵的计算,不能直接使用现有的框架
    • 训练过程中需要存储完整的矩阵,相比于Full finetune并没有降低计算成本
  • LoRA
    • 预先指定每个增量矩阵的内在秩r相同,忽略了在微调预训练模型时,权重矩阵的重要性在不同模块和层之间存在显著差异
    • 只训练了self-attention,没有训练feed-forward networks,事实上FFN更重要

问题总结

不能预先指定矩阵的秩,需要动态更新增量矩阵的R

  • 权重矩阵的重要性在不同模块和层之间存在显著差异

需要找到更加重要的矩阵,分配更多的参数,裁剪不重要的矩阵

  • 找到重要的矩阵,提升模型效果
  • 裁剪不重要的矩阵,降低参数计算量,降低模型效果差的风险

解决方案

目标:在类似LoRA的微调过程中动态分配参数预算给权重矩阵

  1. 调整增量矩阵的秩来控制预算分配。AdaLoRA将关键的增量矩阵分配高秩以捕捉更精细和任务特定的信息,而将较不重要的矩阵的秩降低以防止过拟合并节省计算预算
  2. 采用参数化矩阵来模拟SVD,并舍弃不重要的奇异值,同时保留奇异向量。由于对一个大矩阵进行精确SVD分解的计算消耗非常大,这种方法可以加速计算,同时保留未来恢复的可能性并稳定训练
  3. 在训练损失中添加了额外的惩罚项,以规范奇异矩阵P和Q的正交性,从而避免SVD的大量计算并稳定训练

SVD-BASED ADAPTATION

如上所述,我们把增量矩阵做一个奇异值分解的近似,即,对矩阵更新的描述则有如下表示

为了保证的正交性,即,我们提出如下所示的正则损失

为什么不直接在原来的BA上进行修剪?

  1. 当一对奇异向量被认为为不重要时,我们必须修剪它的所有元素。这就导致几乎不可能重新激活修剪过的奇异向量,因为它们的元素都被清零并且不再训练

    与之对比,AdaLoRA只是Mask了奇异值

  2. LoRA的A和B不是正交的,这意味着奇异向量可以相互依赖。 与截断最小的奇异值相比,丢弃奇异向量可能会导致原始矩阵发生更大的变化。

    因此,在分配完秩的每一步之后,增量矩阵通常会发生更多不可预测的显著变化,这导致训练不稳定,甚至损害模型的效果

IMPORTANCE-AWARE RANK ALLOCATION

我们将基于SVD的秩调整应用于每个权重矩阵,包括每个transformer层的。为了控制参数预算,我们在训练期间根据重要性得分迭代修剪奇异值

为了更好地表示,我们用来索引增量矩阵 for ,用来表示第个矩阵的奇异值,奇异向量三元组, 来表示这个三元组的重要性

作为参数训练的代价,同时加上正则化项,就得出了如下的目标函数: (是正则化系数)

我们在训练的时候就可以通过梯度下降的方式对进行更新,下面是的例子

然后我们再基于进行裁剪

其中包含所有三元组的重要性分数。是第步剩余奇异值的预算

通过这种方式,我们通过修剪不太重要的奇异值,将更多预算留给优先级较高的增量矩阵

Magnitude of singular values

这样的话只有最小的奇异值以及最不重要的奇异向量被去弃。它最大限度地减小了与原始矩阵的偏差,进一步稳定了训练。但是这个度量不能正确量化参数(三元组)对模型性能的贡献

Sensitivity-based importance

之前的工作利用灵敏度来量化单个参数的重要性,并据此对参数进行非结构化修剪。在我们的例子上,我们必须设计一个新的度量标准,因为三元组要被按组丢弃了,所以每一项的敏感性都应该被考虑,并适当地组合起来,以量化三元组对模型性能的整体贡献

我们设计了如下所示的函数来计算importance score

我们可以采用的灵敏度,定义为梯度权重乘积的大小:

本质上近似于参数归零时的损失变化。如果去除一个参数影响较大,则模型对该参数敏感,我们应该保留它

但之前的工作指出,直接计算的敏感性还不是一个可靠的重要指标。这样的分数是在抽样的minibatch上估计的。随机采样和复杂的训练动态导致灵敏度估计的变异性大,不确定性大,这样可能会导致对于参数的重要性的错误估计。提出通过灵敏度平滑和不确定性量化,加入累计灵敏度的影响来解决这一问题:

接下来,我们把定义为的乘积

这样,我们就得到了一个既考虑了三元组所有元素,又考虑了累计灵敏度足够平滑的一个重要性函数

GLOBAL BUDGET SCHEDULER

在低秩自适应的情况下,调整秩自然是为了控制参数预算。因此,我们将预算定义为所有增量矩阵的总秩,即总奇异值的数量

回想一下,预算分配是在微调期间迭代执行的。为了便于训练,我们提出了一个全局预算调度器。具体来说,我们从略高于目标预算的初始预算算开始(例如,的1.5倍)

我们将每个增量矩阵的初始秩设为。我们对步进行warmup,然后按照三次计划减少预算,直到达到

最后,我们得到的修正完预算分布,并对步骤的模型进行了微调

这使得AdaLoRA可以先探索参数空间,然后再关注最重要的权重

实验验证

QLORA

FineTune -> P_tuning -> P_tuning V2 -> LoRA -> QLoRA

BERT Adapter

评估指标含义

如何对大模型进行评估上

MMLU:mmlu数据集包含来自各个知识领域的多项选择题。该数据集涵盖了人文学科、社会科学、自然科学以及其他一些对某些人学习至关重要的领域。数据集包括57个任务,其中包括初等数学、美国历史、计算机科学、法律等内容。通过这个数据集可以评估大模型在不同领域的推理能力

CMMLU:CMMLU数据集是一个综合性的中文评估基准,由MBZUAI、上海交通大学、微软亚洲研究院共同推出,在评估语言模型在中文语境下的知识和推理能力方面极具权威性。一句话理解就是中文版本的MMLU

C-Eval:C-Eval是一个全面的中文基础模型评估套件,它包含了13948个多项选择题,涵盖了52个不同的学科和四个难度级别,一句话理解就是中文版本的mmlu

GSM-8K:GSM8K是由人类问题作者创建的8.5K高质量语言多样化小学数学单词问题的数据集,通过这套数据集可以评估大模型的数学推理运算能力。下图是考察大模型8大方面能力,例如写作,人文,推理,角色扮演等,众所周知,数学运算是所有大模型能力最弱的部分。GSM8K数据集就是专门用来评估大模型数学运算能力的

HumanEval:HumanEval是一个用于评估代码生成能力的数据集,由OpenAI在2021年推出。 这个数据集包含164个手工编写的编程问题,每个问题都包括一个函数签名、文档字符串(docstring)、函数体以及几个单元测试。 这些问题涵盖了语言理解、推理、算法和简单数学等方面

MBPP:MBPP(Mostly Basic Programming Problems)是一个数据集,主要包含了974个短小的Python函数问题,由谷歌在2021年推出,这些问题主要是为初级程序员设计的。 数据集还包含了这些程序的文本描述和用于检查功能正确性的测试用例。一句话理解,和HumanEval一样,也是用于评估大模型代码生成能力的数据集

BBH:一个包含23个具有挑战性的 BIG-Bench 任务的套件,我们称之为 BIG-Bench Hard(BBH)。这些任务是先前语言模型评估未能超越平均人类评分者的任务

Multi-HumanEval:包含多种编程语言的数据集,一句话理解就是HumanEval只包含了python的编程问题,multi-humaneval包含的多种编程语言,例如java,go,javascript等等

HumanEval-X:HumanEval-X 是一个用于评估代码生成模型的多语言能力的基准测试。 它包含了820个高质量的人工制作的数据样本(每个样本都包含测试用例),涵盖了Python、C++、Java、JavaScript和Go这五种编程语言,可用于各种任务,如代码生成和翻译。一句话理解:HumanEval-X数据集和Multi-HumanEval数据集作用相同,只是数据集推出的机构不同而已

RLHF

入门】大语言模型常用微调框架介绍|LoRA&Prefix-Tuning&Prompt-Tuning&P-Tuning v2&RLHF微调原理简介

RLHF: Reinforcement Learning from Human。Feedback,即基于人工反馈机制的强化学习。最早与2022年4月,由OpenAI研究团队系统总结并提出.并在GPT模型的对话类任务微调中大放异彩,被称为ChatGPT背后的功臣

RLHF也是目前为止常用的、最为复杂的基于强化学习的大语言模型微调方法,目前最好的端到端RLHF实现是DeepSpeedChat库,由微软开源并维护

基于强化学习的进阶微调方法RLHF方法

论文地址: https://arxiv.org/abs/2203.02155

步骤1: 监督微调 (SFT)-一 使用精选的人类回答来微调预训练的语言模型以应对各种查询

步骤2:奖励模型微调 — 使用一个包含人类对同一查询的多个答案打分的数据集来训练一个独立的 (通常比 SFT 小的) 奖励模型 (RW)

步骤3: RLHF 训练 —利用 Proximal Policy Optimization (PPO) 算法根据 RW 模型的奖励D九天Hector反馈进一步微调 SFT 模型。

Flash_Atten(转)

前置知识 GPU Arch:自顶向下分析 + B站 GPU Arch:自顶向下分析【浅谈底层·1】

随着人工智能特别是以GPT为代表的生成式AI的迅猛发展,GPU已经成为了一种不可或缺的工具,甚至企业都以拥有多少高端GPU作为抓住风口能力的衡量标准。相比之下,CPU虽然在传统计算领域占据主导地位,但在处理AI任务时却不及GPU出色

为什么AI计算通常选择GPU而不是CPU,分析GPU在AI计算中的优势,同时,从底层原理探讨从Volta到最新的Hopper四代NVIDIA GPU架构的演进,展示其不断提升的性能和功能

GPU主要由计算单元ALU组成。CPU不仅被Cache占据了大量空间,而且还有有复杂的控制逻辑和诸多优化电路,相比之下,计算能力只是CPU很小的一部分

GPU和CPU比较

通过上面自顶向下的分析,我们知道,对于GPU中的存储部分访问速度由快到慢,计算部分从大到小排列为

GPU架构发展参数

NVLink是什么?为什么需要他?

大模型通常具有巨大的参数数量和复杂的结构,需要处理大量的数据。分布式训练将这些大型模型分割成多个部分,由多个GPU或计算节点并行处理,每个部分处理自己的数据子集。然后通过全局通信,参数同步等方式进行梯度传播,此时GPU之间的通信带宽就变的越来越重要

在NVLink出现之前,GPU与GPU之间的数据交互通过PCIe(Peripheral Component Interconnect Express)总线进行。但PCIe存在两个问题,一是PCIe总线的带宽相对有限,其中PCIe 4.0x16的最大带宽也就64GB/s,二是PCIe总线的延迟相对较高,在GPU之间传输数据时,每次数据传输都需要通过CPU和主机内存来完成。这种传输路径会导致额外的延迟,并降低数据传输的效率。然而,深度学习应用中需要更高的带宽和更低的延迟,PCIe显然是无法满足当下的神经网络训练需求

引入NVLink

NVLink利用高带宽、低延迟的通信通道,直接将多个GPU连接在一起,实现快速、高效的数据传输和共享。通过NVLink,GPU之间的数据交互可以直接在GPU之间进行,而无需通过CPU和主机内存。这种直接内存访问(DMA)的方式大大减少了数据传输的复制和延迟,提高了数据共享的效率。此外,NVLink还提供了一致的内存空间,使得多个GPU能够共享同一份内存,简化了程序设计和数据管理的复杂性

概述

FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness Paper 2022

FlashAttention: 更快训练更长上下文的GPT

Transformer作为GPT类模型的基础架构提供了强大的特征处理能力,但是处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有O(N^2)的时间和内存复杂度😓

这篇Flash Attention的工作深入硬件,新提出了一种具有IO感知的快速的⚡️,节省内存的🧠,精确的🎯注意力算法。目前,Flash Attention已经集成至torch2.0,并且社区也提供了多种实现

78s看懂FlashAttention【有点意思·1】

核心要点

⚡️为什么加快了计算?Fast

降低了耗时的HBM访问次数。采用Tiling技术分块从HBM加载数据到SRAM进行融合计算

🧠为什么节省了内存?Memory-Efficient

不再对中间矩阵S,P进行存储。在反向的时候通过Recomputation重新计算来计算梯度

🎯为什么是精准注意力?Exact Attention

算法流程只是分块计算,无近似操作

提出问题

Transformer结构已成为自然语言处理和图像分类等应用中最常用的架构。尽管Transformer在规模上不断增大和加深,但处理更长上下文仍然是一个挑战,因为核心的自注意力模块在序列长度上具有二次方的时间和内存复杂度。这导致在处理长序列时速度变慢且内存需求巨大。因此,我们需要一些优化算法来提高注意力模块的计算速度和内存利用率

解决方案

flash_Atten架构图

Forward

Standard Attention

在注意力的一般实现中,对 三个输入执行以下算法得到输出,其中softmax行级别执行

在这个算法中,矩阵都是很大,需要在HBM中实例化来进行存储,这样就会带来很多HBM的访问次数, 最终体现到算法时间端到端较长的延迟

flash_Atten流程

FlashAttention(Tiling)

理论基础

在传统算法中,一种方式是将Mask和SoftMax部分融合,以减少访存次数。然而,FlashAttention则更加激进,它将从输入到输出的整个过程进行融合,以避免矩阵的存储开销,实现端到端的延迟缩减。然而,由于输入的长度通常很长,无法完全将完整的及中间计算结果存储在SRAM中。因此,需要依 赖HBM进行访存操作,与原始计算延迟相比没有太大差异,甚至会变慢(没具体测)

为了让计算过程的结果完全在SRAM中,摆脱对HBM的依赖,可以采用分片操作,每次进行部分计算,确保这些计算结果能在SRAM内进行交互,待得到对应的结果后再进行输出

这个过程中,有一点需要注意的是,之前对于softmax的计算是以行为单位的,如下所示:

当我们将输入进行分片后,无法对完整的行数据执行Softmax操作。这是因为Softmax函数在计算时需要考虑整个行的数据

然而,我们可以通过如下所示方法来获得与完整行Softmax相同的结果,而无需使用近似操作

具体的分块softmax代码演示

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch


q = torch.tensor([1,2]).float()
v = torch.tensor([1,2]).float()
q_sm = torch.softmax(q, 0)
print(q_sm) # tensor([0.2689, 0.7311])

torch.dot(q_sm, v) # tensor(1.7311)

m_pre = float("-inf")
l_pre = 0
cur_sum = 0

block1 = torch.tensor([1]).float()
# get cur max value
m_cur = max(torch.max(block1), m_pre)
# scale pre log value by max exp
l_pre *= torch.exp(m_pre - m_cur)
# calculate current log sum
p = torch.exp(block1 - m_cur)
l_cur = torch.sum(p) + l_pre
# scale pre result by log sum
cur_sum = cur_sum * l_pre / l_cur
p = p / l_cur
cur_sum = 1 * p[0]

l_pre = l_cur
m_pre = m_cur
print(cur_sum) # tensor(1.)

block2 = torch.tensor([2]).float()
m_cur = max(torch.max(block2), m_pre)
l_pre *= torch.exp(m_pre - m_cur)
p = torch.exp(block2 - m_cur)
l_cur = torch.sum(p) + l_pre
cur_sum = cur_sum * l_pre / l_cur
p = p / l_cur
cur_sum += 2 * p[0]
print(cur_sum) # tensor(1.7311)

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
@triton.jit
def _fwd_kernel(
Q, K, V, sm_scale,
L, M,
Out,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
stride_oz, stride_oh, stride_om, stride_on,
Z, H, N_CTX,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
start_m = tl.program_id(0)
off_hz = tl.program_id(1)
# initialize offsets
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_d = tl.arange(0, BLOCK_DMODEL)
off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk
off_k = off_hz * stride_qh + offs_n[None, :] * stride_kn + offs_d[:, None] * stride_kk
off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk
# Initialize pointers to Q, K, V
q_ptrs = Q + off_q
k_ptrs = K + off_k
v_ptrs = V + off_v
# initialize pointer to m and l
m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)
acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# load q: it will stay in SRAM throughout
q = tl.load(q_ptrs)
# loop over k, v and update accumulator
for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):
# -- compute qk ----
k = tl.load(k_ptrs)
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
qk += tl.dot(q, k)
qk *= sm_scale
qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf"))
# compute new m
m_curr = tl.maximum(tl.max(qk, 1), m_prev)
# correct old l
l_prev *= tl.exp(m_prev - m_curr)
# attention weights
p = tl.exp(qk - m_curr[:, None])
l_curr = tl.sum(p, 1) + l_prev
# rescale operands of matmuls
l_rcp = 1. / l_curr
p *= l_rcp[:, None]
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
l_prev = l_curr
m_prev = m_curr
# update pointers
k_ptrs += BLOCK_N * stride_kn
v_ptrs += BLOCK_N * stride_vk
# rematerialize offsets to save registers
start_m = tl.program_id(0)
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
# write back l and m
l_ptrs = L + off_hz * N_CTX + offs_m
m_ptrs = M + off_hz * N_CTX + offs_m
tl.store(l_ptrs, l_prev)
tl.store(m_ptrs, m_prev)
# initialize pointers to output
offs_n = tl.arange(0, BLOCK_DMODEL)
off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
out_ptrs = Out + off_o
tl.store(out_ptrs, acc)

IO复杂度分析

Standard Attention

对于标准注意力实现,初期我们需要把输入 从HBM中读取,并计算完毕后把输出写入到HBM中

  1. 第一步把读取出来计算出,然后把存回去,内存访问复杂度
  2. 第二步把读取出来计算出,然后把存回去,内存访问复杂度
  3. 第三步把读取出来计算出,然后计算出结果,内存访问复杂度

综上所述,整体的内存访问复杂度为

FlashAttention

对于FlashAttention,我们设置一个分块大小来把分成块,对于的每一块都要把部分的全部元素Load一遍,这样则有FlashAttention的内存访问复杂度为

在这里,我们需要两个分块大小,的分块大小的分块大小,我们设定SRAM的大小为,为了能把分块后的放进SRAM,那么则有一下限制:

相应的,有如下限制:

最终,还有一个中间态需要存储,则有如下限制:

综上,限制如下

进而推出

那么在 的前提下,则有FlashAttention的HBM内存访问复杂度为:

在语言建模中,通常有,则有。这样,在前向的过程中,我们采用分块计算的方式,避免了矩阵的存储开销,整体的运算都在SRAM内进行,降低了HBM访问次数,大大提升了计算的速度,减少了对存储的消耗

Backward

理论基础

在上面前向的时候我们为了减少HBM访存次数,降低内存消耗量,我们并没有对矩阵进行存储,而这个在反向传播计算梯度的时候确实需要的一个信息

之前有通过Gradient checkpointing的方式来实现梯度实现在前向的时候更加节省内存

我们这里则采用重新计算的方式来计算对应的梯度。在上面前向计算的时候我们不会存储矩阵,但是我们会存储对应的指数项之和来进行梯度的计算

我们在反向的过程中最重要的事情就是就是Loss函数对应的梯度

其中$\mathbf{O}\mathbf{d} \mathbf{O}=\frac{\partial \phi}{\partial \mathbf{O}}\mathbf{O}$$是现成的

对应的梯度也很好计算,由于,根据链式求导法则和矩阵求导法则则有,更详细如下所示:

对应的梯度算起来就比较复杂一点。这两个经过的计算逻辑步骤更多,我们可以一步一步的来进行计算。我们可以先计算。由于 ,则有如下表示

Fact: 的雅各比矩阵为,具体推导见

Derivative of the Softmax Function and the Categorical Cross-Entropy Loss

由于, 根据上述定理则有:

接下来我们定义如下表示:

根据上述定义简化上上式则有如下表示:

相应的可表示为如下形式:

又因为,结合上述推导利用链式求导法则对应的梯度有如下表示:

至此,我们得到了一个完整的包含前向和反向的,降低了HBM访问次数的,新的Attention算子

代码实现

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
@triton.jit
def _bwd_kernel(
Q, K, V, sm_scale, Out, DO,
DQ, DK, DV,
L, M,
D,
stride_qz, stride_qh, stride_qm, stride_qk,
stride_kz, stride_kh, stride_kn, stride_kk,
stride_vz, stride_vh, stride_vk, stride_vn,
Z, H, N_CTX,
num_block,
BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr,
):
off_hz = tl.program_id(0)
off_z = off_hz // H
off_h = off_hz % H
# offset pointers for batch/head
Q += off_z * stride_qz + off_h * stride_qh
K += off_z * stride_qz + off_h * stride_qh
V += off_z * stride_qz + off_h * stride_qh
DO += off_z * stride_qz + off_h * stride_qh
DQ += off_z * stride_qz + off_h * stride_qh
DK += off_z * stride_qz + off_h * stride_qh
DV += off_z * stride_qz + off_h * stride_qh
for start_n in range(0, num_block):
lo = start_n * BLOCK_M
# initialize row/col offsets
offs_qm = lo + tl.arange(0, BLOCK_M)
offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)
offs_m = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_DMODEL)
# initialize pointers to value-like data
q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
k_ptrs = K + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)
# pointer to row-wise quantities in value-like data
D_ptrs = D + off_hz * N_CTX
m_ptrs = M + off_hz * N_CTX
# initialize dv amd dk
dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)
# k and v stay in SRAM throughout
k = tl.load(k_ptrs)
v = tl.load(v_ptrs)
# loop over rows
for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):
offs_m_curr = start_m + offs_m
# load q, k, v, do on-chip
q = tl.load(q_ptrs)
# recompute p = softmax(qk, dim=-1).T
# NOTE: `do` is pre-divided by `l`; no normalization here
qk = tl.dot(q, tl.trans(k))
qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf"))
m = tl.load(m_ptrs + offs_m_curr)
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
dp += tl.dot(do, tl.trans(v))
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
q_ptrs += BLOCK_M * stride_qm
do_ptrs += BLOCK_M * stride_qm
# write-back
dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk)
dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk)
tl.store(dv_ptrs, dv)
tl.store(dk_ptrs, dk)

Block-Sparse

相比于上面的全量计算,块稀疏的FlashAttention需要额外提供一个Mask矩阵用于将一些元 素置零来保证块稀疏加速计算

本章对于块稀疏的一个计算只是一个简单的尝试,没有进行太深入的探索,所以这里我们先一笔带过,后面我们可以讲一篇对FlashAttention进行块稀疏优化的工作SCFA

实验验证

通过实验验证发现,FlashAttention在速度和内存占用方面都表现出明显的优势,并取得了良好的效果

flash_Atten实验验证1

flash_Atten实验验证2

目前,FlashAttention已经经过广泛验证, torch2.0中已提供flashattention的实现

正如标题《Fast and Memory-Efficient Exact Attention with IO-Awareness》所示,FlashAttention的优点在于充分考虑了在计算任务中IO的重要性,并通过分块计算的方式开发了一种快速、节省显存、精确无近似的注意力实现方法。这使得我们更便于训练具有更长上下文的Transformer模型,并且为后续注意力算法的优化提供了一个基准