我正在尝试建立一个Pytorch网络,用于图像字幕。
目前我有一个编码器和解码器的工作网络,我想添加
nn.MultiheadAttnetion
向它分层(用作自我关注)。
目前我的解码如下:
class Decoder(nn.Module):
def __init__(self, hidden_size, embed_dim, vocab_size, layers = 1):
super(Decoder, self).__init__()
self.embed_dim = embed_dim
self.vocab_size = vocab_size
self.layers = layers
self.hidden_size = hidden_size
self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
self.lstm = nn.LSTM(input_size = embed_dim, hidden_size = hidden_size, batch_first = True, num_layers = layers)
#self.attention = nn.MultiheadAttention(hidden_size, num_heads=1, batch_first= True)
self.fc = nn.Linear(hidden_size, self.vocab_size)
def init_hidden(self, batch_size):
h = torch.zeros(self.layers, batch_size, self.hidden_size).to(device)
c = torch.zeros(self.layers, batch_size, self.hidden_size).to(device)
return h,c
def forward(self, features, caption):
batch_size = caption.size(0)
caption_size = caption.size(1)
h,c = self.init_hidden(batch_size)
embeddings = self.embedding(caption)
lstm_input = torch.cat((features.unsqueeze(1), embeddings[:,:-1,:]), dim=1)
output, (h,c) = self.lstm(lstm_input, (h,c))
#output, _ = self.attention(output, output, output)
output = self.fc(output)
return output
def generate_caption(self, features, max_caption_size = MAX_LEN):
h,c = self.init_hidden(1)
caption = ""
embeddings = features.unsqueeze(1)
for i in range(max_caption_size):
output, (h, c) = self.lstm(embeddings, (h,c))
#output, _ = self.attention(output, output, output)
output = self.fc(output)
_, word_index = torch.max(output, dim=2) # take the word with highest probability
if word_index == vocab.get_index(END_WORD):
break
caption += vocab.get_word(word_index) + " "
embeddings = self.embedding(torch.LongTensor([word_index]).view(1,-1).to(device))
return caption
对于图像字幕来说,它给出了比较好的结果。
我想添加注释掉的行,以便模型使用注意。但是——当我这么做时——模型崩溃了,尽管损失变得极低(在训练期间从2.7降到0.2,而不是在没有注意的情况下从2.7降到1)——字幕生成并没有真正起作用(反复预测同一个单词)。
我的问题是:
-
我是在用手机吗
nn.MultiheadAttention
正确地对我来说,使用它是非常奇怪的
之后
LSTM,但我在网上看到了这一点,它是从尺寸角度工作的
-
知道为什么我的模型在我使用注意力时会崩溃吗?
编辑
:我也试着引起大家的注意
之前
LSTM,但效果不好(网络预测每张图片都有相同的标题)