PyTorch CosineSimilarity 慎用
Johann Li | June 26, 2024
tl;dr
PyTorch 的 CosineSimilarity(余弦相似度) 有两个问题:算得慢,同时还占用显存。所以建议是使用 torch.norm
+ torch.mm
替代。
具体代码
我这里要对若干个超大矩阵使用 nn.CosineSimilarity
或者 F.cosine_similarity
计算余弦相似度。
须要超级大量的显存,然后我就只好将计算操作进行拆分,通常来说,torch.compile
能优化一些,但是有时有优化不一定有作用。
下面是我分块计算相似度的代码:
def cosine_similarity(self, feat: Tensor, prototypes: Tensor) -> Tensor:
'''
feat: shot x channel x height x width
prototype: num x channel
return: shot x num x height x width
'''
_, _, height, width = feat.shape
feat = feat.flatten(2, 3).unsqueeze(dim = 1) # N 1 C X
prototypes = prototypes.unsqueeze(dim = 0).unsqueeze(dim = 3) # 1 P C 1
return torch.concat( [
F.cosine_similarity(f, prototypes, dim = 2, eps=self.eps)
for f in feat.split(self.split_size, dim = 3)
] , dim=2).unflatten(2, (height, width))
这个是我使用 torch.norm
+ torch.mm
计算相似度的代码:
def nmm_similarity(self, feat: Tensor, prototypes: Tensor) -> Tensor:
'''
feat: shot x channel x height x width
prototype: num x channel
return: shot x num x height x width
'''
s, _, h, w = feat.shape
feat = feat.permute(0, 2, 3, 1).flatten(0, 2) # shot_height_width x channel
feat = feat / (feat.norm(dim = 1, p = 2, keepdim = True) + self.eps)
prototypes = prototypes.permute(1, 0) # channel x num
prototypes = prototypes / (prototypes.norm(dim = 0, p = 2, keepdim = True) + self.eps)
sim = (feat @ prototypes).unflatten(0, (s, h, w)).permute(0, 3, 1, 2)
return sim
在测试过程中(就是推理)
cosine_similarity
这个显存占用大致 20G,然后单个类别的测试需要 25-30 分钟,
而 nmm_similarity
在测试过程中,大致占用 12G 显存,然后单个类别需要 2 分钟左右。
(没有启用 torch.compile
因为这个对 cosine_similarity
提升效果较差)。