F.Linear()

源码

可以看到nn.Linear内部调用了F.Linear,相当于是将其封装了,并自动地对参数进行了初始化。如果我们想自己初始化参数,那么可以不用nn.Linear。

为了灵活地对参数按照自己的方式进行初始化,可以借鉴fairseq的初始化做法

	def reset_parameters(self):
               nn.init.xavier_uniform_(self.in_proj_weight)
               nn.init.xavier_uniform_(self.out_proj.weight)
               if self.in_proj_bias is not None:
                   nn.init.constant_(self.in_proj_bias, 0.)
                   nn.init.constant_(self.out_proj.bias, 0.)
               if self.bias_k is not None:
                   nn.init.xavier_normal_(self.bias_k)
               if self.bias_v is not None:	
                   nn.init.xavier_normal_(self.bias_v)
 self.in_proj_weight = Parameter(torch.Tensor(3 * embed_dim, embed_dim))
 if bias:
      self.in_proj_bias = Parameter(torch.Tensor(3 * embed_dim))
 else:
      self.register_parameter(\'in_proj_bias\', None)
 self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
if add_bias_kv:
     self.bias_k = Parameter(torch.Tensor(1, 1, embed_dim))        
     self.bias_v = Parameter(torch.Tensor(1, 1, embed_dim))
 else:
     self.bias_k = self.bias_v = None

版权声明:本文为匿名原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。
本文链接: