错误消息正在正确诊断问题:有些参数出现在多个参数组中。你可以通过以下方式向自己证明这一点:
>>> parameter_ids = [[id(p) for p in group["params"]] for group in caption_params]
>>> parameter_ids[0]
[140666221372896]
>>> parameter_ids[3]
[140666221372896]
这表明,第一个和最后一个参数组(每个参数组都包含一个大的嵌入张量)实际上包含对同一精确张量的引用。这个张量是什么?让我们来看看它,使用这两条引用路径来进一步表明它是同一件事:
>>> a = next(language_model.lm_head.parameters())
>>> a
Parameter containing:
tensor([[-0.1101, -0.0393, 0.0331, ..., -0.1364, 0.0151, 0.0453],
[ 0.0403, -0.0486, 0.0462, ..., 0.0861, 0.0025, 0.0432],
[-0.1275, 0.0479, 0.1841, ..., 0.0899, -0.1297, -0.0879],
...,
[-0.0445, -0.0548, 0.0123, ..., 0.1044, 0.0978, -0.0695],
[ 0.1860, 0.0167, 0.0461, ..., -0.0963, 0.0785, -0.0225],
[ 0.0514, -0.0277, 0.0499, ..., 0.0070, 0.1552, 0.1207]],
requires_grad=True)
>>> b = next(language_model.transformer.wte.parameters())
>>> b
Parameter containing:
tensor([[-0.1101, -0.0393, 0.0331, ..., -0.1364, 0.0151, 0.0453],
[ 0.0403, -0.0486, 0.0462, ..., 0.0861, 0.0025, 0.0432],
[-0.1275, 0.0479, 0.1841, ..., 0.0899, -0.1297, -0.0879],
...,
[-0.0445, -0.0548, 0.0123, ..., 0.1044, 0.0978, -0.0695],
[ 0.1860, 0.0167, 0.0461, ..., -0.0963, 0.0785, -0.0225],
[ 0.0514, -0.0277, 0.0499, ..., 0.0070, 0.1552, 0.1207]],
requires_grad=True)
>>> a is b
True
这是有道理的,因为许多基于Transformer的模型在开始时(初始
Embedding
层)和模型的末端(LM头)。
对于您的特定问题,您可以接受绑定的权重将以相同的LR移动,也可以通过克隆并将参数的新副本分配给两个模块之一来解开它们。