LSTM MultiheadAttention 输入维度

news/2023/6/6 4:56:56

最近遇到点问题,对于模块的输入矩阵的维度搞不清楚,这里在学习一下,记录下来,方便以后查阅。

LSTM & Attention 输入维度

  • LSTM
    • 记忆单元
    • 门控机制
    • LSTM结构
    • LSTM的计算过程
      • 遗忘门
      • 输入门
      • 更新记忆单元
      • 输出门
    • LSTM单元的pytorch实现
    • Pytorch中的LSTM
      • 参数
      • 输入Inputs: input, (h_0, c_0)
      • 输出Outputs: output, (h_n, c_n)
      • 参数解释
  • MultiheadAttention
    • Self Attention 计算过程
    • Multihead Attention 计算过程
    • MultiheadAttention单元的pytorch实现
    • Pytorch中的MultiheadAttention
    • 输入的矩阵维度
  • 参考资料

LSTM

LSTM是RNN的一种变种,可以有效地解决RNN的梯度爆炸或者消失问题。

在这里插入图片描述

记忆单元

LSTM引入了一个新的记忆单元ctc_tct,用于进行线性的循环信息传递,同时输出信息给隐藏层的外部状态hth_tht。在每个时刻tttctc_tct记录了到当前时刻为止的历史信息。

门控机制

LSTM引入门控机制来控制信息传递的路径,类似于数字电路中的门,0即关闭,1即开启。

LSTM中的三个门为遗忘门ftf_tft,输入门iti_tit,和输出门oto_tot

  • ftf_tft控制上一个时刻的记忆单元ct−1c_{t-1}ct1需要遗忘多少信息
  • iti_tit控制当前时刻的候选状态c~t\tilde{c}_tc~t有多少信息需要存储
  • oto_tot控制当前时刻的记忆单元ctc_tct有多少信息需要输出给外部状态hth_tht

LSTM结构

如图一所示为LSTM的结构,LSTM网络由一个个的LSTM单元连接而成。

在这里插入图片描述

LSTM 的关键就是记忆单元,水平线在图上方贯穿运行。

记忆单元类似于传送带。直接在整个链上运行,只有一些少量的线性交互。信息在上面流传保持不变会很容易。

LSTM的计算过程

遗忘门

在这里插入图片描述

在这一步中,遗忘门读取ht−1h_{t-1}ht1xtx_txt,经由sigmoid,输入一个在0到1之间数值给每个在记忆单元ct−1c_{t-1}ct1中的数字,1表示完全保留,0表示完全舍弃。

输入门

在这里插入图片描述
输入门将确定什么样的信息内存放在记忆单元中,这里包含两个部分。

  1. sigmoid层同样输出[0,1]的数值,决定候选状态c~t\tilde{c}_tc~t有多少信息需要存储
  2. tanh层会创建候选状态c~t\tilde{c}_tc~t

更新记忆单元

随后更新旧的细胞状态,将ct−1c_{t-1}ct1更新为ctc_tct

在这里插入图片描述

首先将旧状态ct−1c_{t-1}ct1ftf_tft相乘,遗忘掉由ftf_tft所确定的需要遗忘的信息,然后加上it∗c~ti_t*\tilde{c}_titc~t,由此得到了新的记忆单元ctc_tct

输出门

结合输出门oto_tot将内部状态的信息传递给外部状态hth_tht。同样传递给外部状态的信息也是个过滤后的信息,首先sigmoid层确定记忆单元的那些信息被传递出去,然后,把细胞状态通过tanh层进行处理(得到[-1,1]的值)并将它和输出门的输出相乘,最终外部状态仅仅会得到输出门确定输出的那部分。

在这里插入图片描述

LSTM单元的pytorch实现

class LSTMCell(nn.Module):def __init__(self, input_size, hidden_size, cell_size, output_size):super().__init__()self.hidden_size = hidden_size # 隐含状态h的大小,也即LSTM单元隐含层神经元数量self.cell_size = cell_size # 记忆单元c的大小# 门self.gate = nn.Linear(input_size+hidden_size, cell_size)self.output = nn.Linear(hidden_size, output_size)self.sigmoid = nn.Sigmoid()self.tanh = nn.Tanh()self.softmax = nn.LogSoftmax(dim=1)def forward(self, input, hidden, cell):# 连接输入x与h combined = torch.cat((input, hidden), 1)# 遗忘门f_gate = self.sigmoid(self.gate(combined))# 输入门i_gate = self.sigmoid(self.gate(combined))z_state = self.tanh(self.gate(combined))# 输出门o_gate = self.sigmoid(self.gate(combined))# 更新记忆单元cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))# 更新隐藏状态hhidden = torch.mul(self.tanh(cell), o_gate)output = self.output(hidden)output = self.softmax(output)return output, hidden, celldef initHidden(self):return torch.zeros(1, self.hidden_size)def initCell(self):return torch.zeros(1, self.cell_size)

Pytorch中的LSTM

在这里插入图片描述

参数

  • input_size – 输入特征维数
  • hidden_size – 隐含状态h hh的维数
  • num_layers – RNN层的个数:(在竖直方向堆叠的多个相同个数单元的层数),默认为1
  • bias – 隐层状态是否带bias,默认为true
  • batch_first – 是否输入输出的第一维为batchsize
  • dropout – 是否在除最后一个RNN层外的RNN层后面加dropout层
  • bidirectional –是否是双向RNN,默认为false
  • proj_size – 如果>0, 则会使用相应投影大小的LSTM,默认值:0

其中比较重要的参数就是hidden_size与num_layers,hidden_size所代表的就是LSTM单元中神经元的个数。num_layers所代表的含义,就是depth的堆叠,也就是有几层的隐含层。

在这里插入图片描述

这张图是以MLP的形式展示LSTM的传播方式(不用管左边的符号,输出和隐状态其实是一样的),方便理解hidden_size这个参数。其实hidden_size在各个函数里含义都差不多,就是参数W的第一维(或最后一维)。那么对应前面的公式,hidden_size实际就是以这个size设置所有W的对应维。

在这里插入图片描述

这张图非常便于理解参数num_layers。实际上就是个depth堆叠,每个蓝色块都是LSTM单元。只不过第一层输入是xt,ht−1(0),ct−1(0)x_t, h_{t-1}^{(0)}, c_{t-1}^{(0)}xt,ht1(0),ct1(0),中间层输入是ht(k−1),ht−1(k),ct−1(k)h_{t}^{(k-1)}, h_{t-1}^{(k)}, c_{t-1}^{(k)}ht(k1),ht1(k),ct1(k)

输入Inputs: input, (h_0, c_0)

  • input:当batch_first = False 时形状为(L,N,H_in),当 batch_first = True 则为(N, L, H_in​) ,包含批量样本的时间序列输入。该输入也可是一个可变换长度的时间序序列。
  • h_0:形状为(D∗num_layers, N, H_out),指的是包含每一个批量样本的初始隐含状态。如果模型未提供(h_0, c_0) ,默认为是全0矩阵。
    c_0:形状为(D∗num_layers, N, H_cell), 指的是包含每一个批量样本的初始记忆细胞状态。 如果模型未提供(h_0, c_0) ,默认为是全0矩阵。

输出Outputs: output, (h_n, c_n)

  • output: 当batch_first = False 形状为(L, N, D∗H_out​) ,当batch_first = True 则为 (N, L, D∗H_out​) ,包含LSTM最后一层每一个时间步长 的输出特征()。
  • h_n: 形状为(D∗num_layers, N, H_out​),包括每一个批量样本最后一个时间步的隐含状态。
  • c_n: 形状为(D∗num_layers, N, H_cell​),包括每一个批量样本最后一个时间步的记忆细胞状态。

参数解释

  • N = 批量大小
  • L = 序列长度
  • D = 2 如果模型参数bidirectional = 2,否则为1
  • H_in = 输入的特征大小(input_size)
  • H_cell = 隐含单元数量(hidden_size)
  • H_out = proj_size, 如果proj_size > 0, 否则的话 = 隐含单元数量(hidden_size)

MultiheadAttention

Self Attention 计算过程

在这里插入图片描述

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

Multihead Attention 计算过程

在这里插入图片描述

MultiheadAttention单元的pytorch实现

class Attention(nn.Module):'''Attention Module used to perform self-attention operation allowing the model to attendinformation from different representation subspaces on an input sequence of embeddings.The sequence of operations is as follows :-Input -> Query, Key, Value -> ReshapeHeads -> Query.TransposedKey -> Softmax -> Dropout-> AttentionScores.Value -> ReshapeHeadsBack -> OutputArgs:embed_dim: Dimension size of the hidden embeddingheads: Number of parallel attention heads (Default=8)activation: Optional activation function to be applied to the input while transforming to query, key and value matrixes (Default=None)dropout: Dropout value for the layer on attention_scores (Default=0.1)Methods:_reshape_heads(inp) :- Changes the input sequence embeddings to reduced dimension according to the numberof attention heads to parallelize attention operation(batch_size, seq_len, embed_dim) -> (batch_size * heads, seq_len, reduced_dim)_reshape_heads_back(inp) :-Changes the reduced dimension due to parallel attention heads back to the originalembedding size(batch_size * heads, seq_len, reduced_dim) -> (batch_size, seq_len, embed_dim)forward(inp) :-Performs the self-attention operation on the input sequence embedding.Returns the output of self-attention as well as atttention scores(batch_size, seq_len, embed_dim) -> (batch_size, seq_len, embed_dim), (batch_size * heads, seq_len, seq_len)Examples:>>> attention = Attention(embed_dim, heads, activation, dropout)>>> out, weights = attention(inp)'''def __init__(self, embed_dim, heads=8, activation=None, dropout=0.1):super(Attention, self).__init__()self.heads = headsself.embed_dim = embed_dimself.query = nn.Linear(embed_dim, embed_dim)self.key = nn.Linear(embed_dim, embed_dim)self.value = nn.Linear(embed_dim, embed_dim)self.softmax = nn.Softmax(dim=-1)if activation == 'relu':self.activation = nn.ReLU()elif activation == 'elu':self.activation = nn.ELU()else:self.activation = nn.Identity()self.dropout = nn.Dropout(dropout)def forward(self, inp):# inp: (batch_size, data_aug, cha_tim_dim, embed_dim)batch_size, data_aug, cha_tim_dim, embed_dim = inp.size()assert embed_dim == self.embed_dimquery = self.activation(self.query(inp))key   = self.activation(self.key(inp))value = self.activation(self.value(inp))# output of _reshape_heads(): (batch_size * heads, data_aug, cha_tim_dim, reduced_dim) | reduced_dim = embed_dim // headsquery = self._reshape_heads(query)key   = self._reshape_heads(key)value = self._reshape_heads(value)# attention_scores: (batch_size * heads, data_aug, cha_tim_dim, cha_tim_dim) | Softmaxed along the last dimensionattention_scores = self.softmax(torch.matmul(query, key.transpose(2, 3)))# out: (batch_size * heads, data_aug, cha_tim_dim, reduced_dim)out = torch.matmul(self.dropout(attention_scores), value)# output of _reshape_heads_back(): (batch_size, data_aug, cha_tim_dim, embed_dim)out = self._reshape_heads_back(out)return out, attention_scoresdef _reshape_heads(self, inp):# inp: (batch_size, data_aug, cha_tim_dim, embed_dim)batch_size, data_aug, cha_tim_dim, embed_dim = inp.size()reduced_dim = self.embed_dim // self.headsassert reduced_dim * self.heads == self.embed_dimout = inp.reshape(batch_size, data_aug, cha_tim_dim, self.heads, reduced_dim)out = out.permute(0, 3, 1, 2, 4)out = out.reshape(-1, data_aug, cha_tim_dim, reduced_dim)# out: (batch_size * heads, data_aug, cha_tim_dim, reduced_dim)return out

Pytorch中的MultiheadAttention

在这里插入图片描述

在这里插入图片描述

输入的矩阵维度

在这里插入图片描述

参考资料

LSTM详解

Pytorch LSTM模型 参数详解

[译] 理解 LSTM 网络

https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html?highlight=attention#torch.nn.MultiheadAttention

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.exyb.cn/news/show-4559152.html

如若内容造成侵权/违法违规/事实不符,请联系郑州代理记账网进行投诉反馈,一经查实,立即删除!

相关文章

唞音最新无人直播技巧系统

抖音直播汉字找不同游戏系统 系统简介; 找不同游戏直播系统,简单来说就是从多个汉字里面找不出一个不一样的字, 然后根据排序,把序列号打在直播间公屏上面,打出正确的答案系统会自动播报语音,相反不是正确的答案系统…

高人气直播间必备的8个直播留人技巧

你有没有遇到这样的情况? 直播开场吸引了2000个人进入直播间,不到20分钟,直播间不到300人,直播结束时,直播间的人数只有2位数,十分惨淡。 也就是说,你的直播留不住人。 直播间是否能留得住人…

【问题描述】3.2.6 中国余数定理:“有物不知几何,三三数余一,五五数余二,七七数余三,问:物有几何?”。编程求1~1000以内所有解。

【输入输出样例】 【样例输出说明】 (1)一行输出5个数,每个数占位5个字符(输出结束后跟换行符) (2)该数同时满足:被3除余1,被5除余2,被7除余3 #include &l…

算法学习笔记--余数定理

tips:算法学习过程中,总是会遇到许多有趣的知识,当前的数学理论几乎西方数学理论所统治,但在学习的过程经常会发现我华夏先祖也曾拥有独特的思考。与西方文明喜欢用晦涩的特殊符号进行理论推导所不同的是,华夏先祖往往以更生动形象的方式,将他们的智慧结晶遗留下来。西方…

正确清理mysql-bin

1. 背景 机器磁盘满导致mysql连接不上,删除部分日志,错误将mysql-bin.0050* 都删除,重启mysql失败 $ service mysqld start Starting MySQL.......... ERROR! The server quit without updating PID file (/data/mysqldata/gzqc249-243-214…

余数定理问题和余数类问题的解法

一、引言 Python里面有一个重要的求模运算符号“%”,作为一个小白,实验了好多次求模的运算,发现这个算法不同于一般的四则运算,其运算效率简直可以用神奇来形容。 例如以当今知道的最大质数——梅森素数为例&#xff0…

余数 中国余数定理

在我们学校的OJ系统中刚好看到这道题目,于是尝试着做了一下。其实关键的不是这道题目,而是它涉及到的数据类型。long long类型。关于_int64与long long的区别,请移步至__int64 与long long 的区别。 题目: 解题思路: …

数据结构和算法 数论 中国余数定理

1、中国余数定理概述 找出所有整数,它们被3、5、7除时,余数分别为2,3和2。一个这样的解为,所有的解是形如(k为任意整数)的整数。“中国余数定理”提出,对一组两两互质的模数(如3、5、7)来说&…