PyTorch 2.0 torch.compile 内存泄漏问题 (Triton 引起的)

简短的来说, PyTorch 2.0 新引入进来了一个 torch.compile API, 通过一定的编译的方式,加速伸进网络的训练和推理(虽然好像效果一般吧)。 但是这个东西,也带来了一个问题,就是说 内存泄漏

PyTorch 官方 GitHub 仓库中是有一个这样的 issue 的,已经关闭了。 里面提到的解决方式是安装最新的 Triton 就解决了。我测试了一下,能够解决这个问题。但是这里须要提到的一点是,出问题的版本是 2.0.0,而解决问题的版本是 2.0.0.post1。 这个要安装这个版本,可能须要执行 pip install triton==2.0.0.post1 这个命令。

除此之外,通过简单的测试,在更新前, torch.compile 里面如果将 backend 设置成 aot_ts_nvfuser 或者 nvprims_nvfuser 的话,没有出现过内存溢出的问题。

测试代码

我是用的网络是,通过 monkey-patch 魔改的 ResNet 网络:


def replace(m: nn.Module):
    for n, mx in m.named_children():
        if isinstance(mx, pool):
            m.add_module(n, nn.Identity())
        elif isinstance(mx, nn.Conv2d):
            if mx.stride != (1, 1) or mx.stride != 1:
                conv = nn.Conv2d(
                    mx.in_channels, mx.out_channels, mx.kernel_size,
                    stride=1, padding=mx.padding, dilation=mx.dilation,
                    groups=mx.groups, bias=mx.bias, padding_mode=mx.padding_mode,
                    device=mx.weight.device, dtype=mx.weight.dtype
                )
                m.add_module(n, conv)
        else:
            replace(mx)

class ResNet(M.resnet.ResNet):

    def __init__(
        self,
        block: Type[Union[M.resnet.BasicBlock, M.resnet.Bottleneck]],
        layers: List[int],
        planes: List[int] = [64, 64, 64, 64],
        zero_init_residual: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None,
        num_classes: int = 1000,
    ):
        super().__init__(
            block, layers,
            zero_init_residual=zero_init_residual,
            groups=groups, width_per_group=width_per_group,
            replace_stride_with_dilation=replace_stride_with_dilation,
            norm_layer=norm_layer,
            num_classes=num_classes,
        )
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]

        self.inplanes = 64
        self.add_module('layer1', self._make_layer(block, planes[0], layers[0]))
        self.add_module('layer2', self._make_layer(block, planes[1], layers[1],
            stride=1, dilate=replace_stride_with_dilation[0]))
        self.add_module('layer3', self._make_layer(block, planes[2], layers[2],
            stride=1, dilate=replace_stride_with_dilation[1]))
        self.add_module(
            'layer4', self._make_layer(block, planes[3], layers[3],
            stride=1, dilate=replace_stride_with_dilation[2])
        )

        project_planes = self.layer4[-1].bn3.num_features \
            if isinstance(self.layer4[-1], M.resnet.Bottleneck) \
            else self.layer4[-1].bn2.num_features

        replace(self)

然后测试的代码是

import torch
import model.resnet as R

net = R.resnet101(None, planes = [64,32,128,32]).cuda()

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda())

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda(), backend="aot_ts_nvfuser")

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

net = torch.compile(R.resnet101(None, planes = [64,32,128,32]).cuda(), backend="nvprims_nvfuser")

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

for i in range(100):
    net(torch.rand(1,3,128,128).cuda())

测试的时候,前项传播的过程是要执行两次的,因为如果只执行一次的话,不一定准确。

如何定为到 torch.compile 的问题的呢?

首先,一开始使用 objgraph 对内存使用状况进行监测。发现了内存虽然增长了,但是随着训练的时间增加,却没有检测出来任何增加的对象。

然后,就换了 memory_profiler。换之前,实际上通过各种方法,将可疑的的部份替换点,来看是否是有问题的,但是从来没有假设到 torch.compile 和 神经网络会有问题,所以就一直找不到问题。 然后更换了 memory_profiler 之后,就发现,模型执行的地方存在这个问题。 memory profiler 提示一直在分配内存。

一开始以为是我代码的实现上有问题,在模型或者损失函数的地方导致了内存泄漏。但是发现并不是,然后偶然将 torch.compile 关闭之后,就没有内存泄漏的问题了。

最后在网上搜索相关的信息后,发现问题并解决了。