torch.nn.functional.normalize详解
来源CSDN
torch.nn.functional.normalize
torch.nn.functional.normalize(input, p=2, dim=1, eps=1e-12, out=None)
功能:将某一个维度除以那个维度对应的范数(默认是2范数)。
主要讲以下三种情况:
输入为一维Tensor
a = torch.Tensor([1,2,3])
torch.nn.functional.normalize(a, dim=0)
tensor([0.2673, 0.5345, 0.8018])
输入为二维Tensor
b = torch.Tensor([[1,2,3], [4,5,6]])
torch.nn.functional.normalize(b, dim=0)
tensor([[0.2425, 0.3714, 0.4472],
[0.9701, 0.9285, 0.8944]])
b = torch.Tensor([[1,2,3], [4,5,6]])
torch.nn.functional.normalize(b, dim=1)
tensor([[0.2673, 0.5345, 0.8018],
[0.4558, 0.5698, 0.6838]])
因为dim=1,所以是对行操作。以第一行为例,整体除以了第一行的范数: