谁动了我的显存?——深度学习训练过程显存占用分析及优化

在大语言模型时代,不仅语言模型变得越来越大,而且几乎所有的模型都想变得越来越大,试图在模型变大之后观察到一些涌现出来的能力。模型变大之后,最突出的问题就是显存不够用了。本文对深度学习训练过程中的显存占用问题进行一些具体分析,加深我对训练过程的理解,能够进行一些简单的显存优化操作。如果读者们有更多的相关资料、优化技巧,欢迎在评论区补充。显存占用概述深度学习训练过程中的显存占用,大致可以分为三部分:框架占用,例如pytorch框架的cuda context会占用大约几百MB显存模型参数相关的占用,比如7B的模型以FP16格式要占用14GB显存。此处还包括优化器、梯度相关的参数占用,全量微调的情况下,梯度与参数一样大,优化器状态是梯度的1~2倍(SGD为1倍,Adam为2倍)。如果使用DDP进行多卡训练,则还需要乘以显卡数量;如果使用FSDP进行多卡训练,显存占用与显卡数无关,但是会增加通信开销。特征相关的占用,这部分显存占用是最复杂的,因为它与模型的具体计算流程有关。很多地方只会笼统地说这类占用与batchsize成正比,但是具体的比例系数很难分析。本文希望详细解析特征相关的显存占用到底是多少。统计方法我们用一个样例程序,来使用不同的方法、在不同的情况下计算这样一个简单的函数。具体程序为:import torch # Create two tensors with 1GB memory footprint each, initialized randomly, in fp16 format # For a tensor of float16 (2 bytes), 1GB of memory can hold 1GB / 2B = 500M elements tensor_size = 512 * 1024 * 1024 x = torch.randn(tensor_size, dtype=torch.float16, device='cuda') y = torch.randn(tensor_size, dtype=torch.float16, device='cuda') # Record current memory footprint, and reset max memory counter current_memory = torch.cuda.memory_allocated() torch.cuda.reset_peak_memory_stats() def compute(x, y): return (x + 1) * (y + 1) z = compute(x, y) # Record the additional memory (both peak memory and persistent memory) after calculating the resulting tensor additional_memory = torch.cuda.memory_allocated() - (current_memory + 1e9) peak_memory = torch.cuda.max_memory_allocated() additional_peak_memory = peak_memory - (current_memory + 1e9) print(f"Additional memory used: {additional_memory / (1024 ** 3)} GB") print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB")在这个函数计算过程中,输入、,输出不可避免地要占用显存。我们希望在不同情况下、改变不同的计算方式,观察/理解为了计算这个函数所需要的额外显存。这里需要区分两个概念:峰值显存占用 与 持续显存占用 。在计算一个函数的过程中,我们可能创建了很多中间结果,它们需要临时占用显存;但是当函数计算完成之后,只有一部分结果需要持续存在(直到反向传播结束),另一部分可以被释放。上述示例小脚本,会分别输出持续显存占用和峰值显存占用。一:不需要计算梯度的情况上述示例脚本,直接运行的结果是:Additional memory used: 0.06867742538452148 GB Additional peak memory used: 2.0686774253845215 GB也就是说,函数运行期间需要大约2GB的显存占用,运行结束之后几乎不占显存。具体来说,函数计算过程中需要创建 和两个临时变量,乘积结果放在中。因此大约需要2GB的显存来存储临时变量,它们在计算结束后会被释放。至于为什么持续显存占用不严格为0、峰值显存占用不严格为2GB,这就与pytorch的具体显存管理策略、对象的显存布局有关,我们暂时不关心这部分内容。二:需要计算梯度的情况我们把计算函数改写为:def compute(x, y): x.requires_grad_(True) y.requires_grad_(True) return (x + 1) * (y + 1)得到的结果为:Additional memory used: 2.0686774253845215 GB Additional peak memory used: 2.0686774253845215 GB也就是说,需要计算梯度时,计算过程中的临时变量并不会被释放,反而会持续存在于显存中,等待后续用于反向传播计算。这个问题可以变得更复杂一些,如果我们让一个输入要求梯度、一个参数不要求梯度,会发生什么呢?def compute(x, y): x.requires_grad_(True) return (x + 1) * (y + 1)得到的结果是:Additional memory used: 1.0686774253845215 GB Additional peak memory used: 2.0686774253845215 GB可以看到,计算完成后释放了一个临时变量,还有一个临时变量持续存在。这是因为我们只要求能计算梯度,不用计算梯度。有意思的是,大部分人看到这里,都觉得既然不需要计算梯度,那么肯定是这个临时变量被释放了。然而,事实上是这个临时变量被释放掉了。为了说清楚这个问题,我们用具体的值来区分和,这里的值是1,的值是2,于是临时变量的值是2,的值是3.通过计算结果记录的中间变量的值,我们可以区分具体记录了哪个中间结果。def compute(x, y): x.zero_() y.zero_() x += 1 y += 2 x.requires_grad_(True) z = (x + 1) * (y + 1) print(z.grad_fn._saved_other.mean().item()) return z这段代码的运行结果是:3.0 Additional memory used: 1.0686774253845215 GB Additional peak memory used: 2.0686774253845215 GB可以看到,虽然是要求梯度,但是在计算过程中保留的变量却是。为了从原理上理解这个现象,我们来看看反向传播的本质:梯度求导。考虑神经网络中的某个函数,输入为和两个参数,输出为。将继续参与后续运算,得到损失函数。反向传播的任务,就是在已知的情况下,计算和。根据链式法则,且。于是,为了反向传播,我们需要记录和。不失一般性而言,是和的函数。于是,为了反向传播,我们需要完整记录和 。这是最简单粗暴的方法。实际上,对于很多简单函数来说,偏导数的表达式并不复杂。以本文的小脚本为例,,于是只和有关。也就是说,为了计算的反向传播,只需要记录的值。于是,我们就能理解,为什么需要梯度(对应地也需要梯度)时,反向传播记录的是。这部分的内容本质上就是自动微分的内容。当我们为每一个原子操作(例如加减乘除)写好了反向传播算法,自动微分就能够沿着计算图进行自动求导。这种方法写起来很简单,也很直观。然而,它的缺点也很明显:显存占用大。三:不使用自动微分计算梯度有什么办法能够绕开自动微分的限制,使得显存开销更低吗?有的,答案就是pytorch提供的torch.autograd.Function。我们把计算部分的代码替换成Function的实现,直接用一个算子实现的功能:from torch.autograd import Function class AddMulFunction(Function): @staticmethod def forward(ctx, x, y): ctx.save_for_backward(x, y) return (x + 1) * (y + 1) @staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors grad_x = grad_output * (y + 1) grad_y = grad_output * (x + 1) return grad_x, grad_y func = AddMulFunction.apply def compute(x, y): x.requires_grad_(True) y.requires_grad_(True) return func(x, y)输出结果为:Additional memory used: 0.06867742538452148 GB Additional peak memory used: 2.0686774253845215 GB这个算子也能够进行反向传播,而且计算结束之后并不会占用显存。这是因为我们在它的backward函数里手动计算了这个算子的梯度,使得它不用记录临时变量和也能进行反向传播。从这个算子的实现中,我们能清晰地看到ctx.save_for_backward函数,它为反向传播过程记录必要的参数。关于torch.autograd.Function,有一个细节值得注意:torch.autograd.Function设计的初衷就是为了让高级用户绕开自动微分的限制,因此torch.autograd.Function的forward和backward函数执行过程中,并不会记录梯度操作。大致可以理解为:torch.autograd.Function的forward和backward函数执行过程被包裹在 with torch.no_grad()环境中。例如,我们把计算代码改成:from torch.autograd import Function class AddMulFunction(Function): @staticmethod def forward(ctx, x, y): ctx.save_for_backward(x, y) z = (x + 1) * (y + 1) print(z.requires_grad) print(z.grad_fn) return z @staticmethod def backward(ctx, grad_output): x, y = ctx.saved_tensors grad_x = grad_output * (y + 1) grad_y = grad_output

Jul 7, 2023 - 07:00
 0
谁动了我的显存?——深度学习训练过程显存占用分析及优化

在大语言模型时代,不仅语言模型变得越来越大,而且几乎所有的模型都想变得越来越大,试图在模型变大之后观察到一些涌现出来的能力。

模型变大之后,最突出的问题就是显存不够用了。本文对深度学习训练过程中的显存占用问题进行一些具体分析,加深我对训练过程的理解,能够进行一些简单的显存优化操作。如果读者们有更多的相关资料、优化技巧,欢迎在评论区补充。

显存占用概述

深度学习训练过程中的显存占用,大致可以分为三部分:

  • 框架占用,例如pytorch框架的cuda context会占用大约几百MB显存
  • 模型参数相关的占用,比如7B的模型以FP16格式要占用14GB显存。此处还包括优化器、梯度相关的参数占用,全量微调的情况下,梯度与参数一样大,优化器状态是梯度的1~2倍(SGD为1倍,Adam为2倍)。如果使用DDP进行多卡训练,则还需要乘以显卡数量;如果使用FSDP进行多卡训练,显存占用与显卡数无关,但是会增加通信开销。
  • 特征相关的占用,这部分显存占用是最复杂的,因为它与模型的具体计算流程有关。很多地方只会笼统地说这类占用与batchsize成正比,但是具体的比例系数很难分析。

本文希望详细解析特征相关的显存占用到底是多少。

统计方法

我们用一个样例程序,来使用不同的方法、在不同的情况下计算(x+1)(y+1)这样一个简单的函数。具体程序为:

import torch

# Create two tensors with 1GB memory footprint each, initialized randomly, in fp16 format
# For a tensor of float16 (2 bytes), 1GB of memory can hold 1GB / 2B = 500M elements
tensor_size = 512 * 1024 * 1024 
x = torch.randn(tensor_size, dtype=torch.float16, device='cuda')
y = torch.randn(tensor_size, dtype=torch.float16, device='cuda')

# Record current memory footprint, and reset max memory counter
current_memory = torch.cuda.memory_allocated()
torch.cuda.reset_peak_memory_stats()

def compute(x, y):
    return (x + 1) * (y + 1)

z = compute(x, y)

# Record the additional memory (both peak memory and persistent memory) after calculating the resulting tensor
additional_memory = torch.cuda.memory_allocated() - (current_memory + 1e9)
peak_memory = torch.cuda.max_memory_allocated()
additional_peak_memory = peak_memory - (current_memory + 1e9)

print(f"Additional memory used: {additional_memory / (1024 ** 3)} GB")
print(f"Additional peak memory used: {additional_peak_memory / (1024 ** 3)} GB")

在这个函数计算过程中,输入xy,输出z不可避免地要占用显存。我们希望在不同情况下、改变不同的计算方式,观察/理解为了计算这个函数所需要的额外显存。

这里需要区分两个概念:峰值显存占用 与 持续显存占用 。在计算一个函数的过程中,我们可能创建了很多中间结果,它们需要临时占用显存;但是当函数计算完成之后,只有一部分结果需要持续存在(直到反向传播结束),另一部分可以被释放。上述示例小脚本,会分别输出持续显存占用和峰值显存占用。

一:不需要计算梯度的情况

上述示例脚本,直接运行的结果是:

Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB

也就是说,函数运行期间需要大约2GB的显存占用,运行结束之后几乎不占显存。

具体来说,函数计算过程中需要创建 x+1y+1两个临时变量,乘积结果放在z中。因此大约需要2GB的显存来存储临时变量,它们在计算结束后会被释放。

至于为什么持续显存占用不严格为0、峰值显存占用不严格为2GB,这就与pytorch的具体显存管理策略、对象的显存布局有关,我们暂时不关心这部分内容。

二:需要计算梯度的情况

我们把计算函数改写为:

def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return (x + 1) * (y + 1)

得到的结果为:

Additional memory used: 2.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB

也就是说,需要计算梯度时,计算过程中的临时变量并不会被释放,反而会持续存在于显存中,等待后续用于反向传播计算。

这个问题可以变得更复杂一些,如果我们让一个输入要求梯度、一个参数不要求梯度,会发生什么呢?

def compute(x, y):
    x.requires_grad_(True)
    return (x + 1) * (y + 1)

得到的结果是:

Additional memory used: 1.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB

可以看到,计算完成后释放了一个临时变量,还有一个临时变量持续存在。这是因为我们只要求x能计算梯度,y不用计算梯度。

有意思的是,大部分人看到这里,都觉得既然y不需要计算梯度,那么肯定是y+1这个临时变量被释放了。然而,事实上是x+1这个临时变量被释放掉了。

为了说清楚这个问题,我们用具体的值来区分xy,这里x的值是1,y的值是2,于是临时变量x+1的值是2,y+1的值是3.通过计算结果z记录的中间变量的值,我们可以区分z具体记录了哪个中间结果。

def compute(x, y):
    x.zero_()
    y.zero_()
    x += 1
    y += 2
    x.requires_grad_(True)
    z = (x + 1) * (y + 1)
    print(z.grad_fn._saved_other.mean().item())
    return z

这段代码的运行结果是:

3.0
Additional memory used: 1.0686774253845215 GB
Additional peak memory used: 2.0686774253845215 GB

可以看到,虽然是x要求梯度,但是在计算过程中保留的变量却是y+1

为了从原理上理解这个现象,我们来看看反向传播的本质:梯度求导。

考虑神经网络中的某个函数c = f(a,b),输入为ab两个参数,输出为cc将继续参与后续运算,得到损失函数J = g(c)。反向传播的任务,就是在已知\frac{\partial J}{\partial c}的情况下,计算\frac{\partial J}{\partial a}\frac{\partial J}{\partial b}

根据链式法则,\frac{\partial J}{\partial a} = \frac{\partial J}{\partial c} \frac{\partial c}{\partial a}\frac{\partial J}{\partial b} = \frac{\partial J}{\partial c} \frac{\partial c}{\partial b}。于是,为了反向传播,我们需要记录\frac{\partial c}{\partial a}\frac{\partial c}{\partial b}

不失一般性而言,\frac{\partial c}{\partial a}ab的函数。于是,为了反向传播,我们需要完整记录ab 。这是最简单粗暴的方法。

实际上,对于很多简单函数来说,偏导数的表达式并不复杂。以本文的小脚本为例,c = f(a,b)=a * b,于是\frac{\partial c}{\partial a}=b只和b有关。也就是说,为了计算a的反向传播,只需要记录b的值。

于是,我们就能理解,为什么x需要梯度(对应地x+1也需要梯度)时,反向传播记录的是y+1

这部分的内容本质上就是自动微分的内容。当我们为每一个原子操作(例如加减乘除)写好了反向传播算法,自动微分就能够沿着计算图进行自动求导。这种方法写起来很简单,也很直观。然而,它的缺点也很明显:显存占用大。

三:不使用自动微分计算梯度

有什么办法能够绕开自动微分的限制,使得显存开销更低吗?

有的,答案就是pytorch提供的torch.autograd.Function

我们把计算部分的代码替换成Function的实现,直接用一个算子实现(x+1)*(y+1)的功能:

from torch.autograd import Function

class AddMulFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        return (x + 1) * (y + 1)

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_output * (y + 1)
        grad_y = grad_output * (x + 1)
        return grad_x, grad_y

func = AddMulFunction.apply

def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return func(x, y)

输出结果为:

Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB

这个算子也能够进行反向传播,而且计算结束之后并不会占用显存。这是因为我们在它的backward函数里手动计算了这个算子的梯度,使得它不用记录临时变量x+1y+1也能进行反向传播。

从这个算子的实现中,我们能清晰地看到ctx.save_for_backward函数,它为反向传播过程记录必要的参数。

关于torch.autograd.Function,有一个细节值得注意:torch.autograd.Function设计的初衷就是为了让高级用户绕开自动微分的限制,因此torch.autograd.Functionforwardbackward函数执行过程中,并不会记录梯度操作。大致可以理解为:torch.autograd.Functionforwardbackward函数执行过程被包裹在 with torch.no_grad()环境中。

例如,我们把计算代码改成:

from torch.autograd import Function

class AddMulFunction(Function):
    @staticmethod
    def forward(ctx, x, y):
        ctx.save_for_backward(x, y)
        z = (x + 1) * (y + 1)
        print(z.requires_grad)
        print(z.grad_fn)
        return z

    @staticmethod
    def backward(ctx, grad_output):
        x, y = ctx.saved_tensors
        grad_x = grad_output * (y + 1)
        grad_y = grad_output * (x + 1)
        return grad_x, grad_y

func = AddMulFunction.apply

def compute(x, y):
    x.requires_grad_(True)
    y.requires_grad_(True)
    return func(x, y)

z = compute(x, y)

print(z.requires_grad)
print(z.grad_fn)

输出结果为:

False
None
True

Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 2.0686774253845215 GB

即使xy是需要梯度的,在Functionforward函数中,z是不需要梯度的。然而,当走出forward函数之后,pytorch会为它加上需要梯度的标志,并且通过grad_fn属性记录其反向传播需要执行的函数。

通过这一细节,我们可以理解,为什么定义了AddMulFunction之后,不能直接使用AddMulFunction.forward函数,而必须用func = AddMulFunction.apply

以上涉及的内容,其实就是“算子融合”,通过手动计算反向传播过程,节约不必要的显存开销。上述算子还可以进一步优化,把峰值显存占用也降下来。感兴趣的朋友可以试试。

我们日常使用的很多算子,都是融合过的。

以sigmoid算子为例,如果我们自己来实现:

def compute(x):
    x.requires_grad_(True)
    z = 1 / (1 + torch.exp(-x))
    return z

z = compute(x)

输出结果为:

Additional memory used: 2.0686774253845215 GB
Additional peak memory used: 3.0686774253845215 GB

峰值显存占用为3GB,持续显存占用为2GB。

如果改为pytorch自带的已经融合过的算子:

def compute(x):
    x.requires_grad_(True)
    z = torch.nn.Sigmoid()(x)
    return z

z = compute(x)

输出结果为:

Additional memory used: 0.06867742538452148 GB
Additional peak memory used: 0.06867742538452148 GB

峰值显存占用与持续显存占用几乎都是0!

这是怎么做到的呢?

  • sigmoid函数是element-wise的函数,只需要申请一次显存,把所有的操作都变成in-place,再把这块显存作为输出内容,就不用申请临时空间了。
  • sigmoid函数 z=\frac{1}{1+e^{-x}}的导数是z * (1-z),为了计算反向传播,只需要记录输出z。而在我们的示例程序中,z原本就会保留,因此sigmoid函数的反向传播记录的z就不用额外占用空间。

算子显存占用分析中的记账问题

上述分析中,关于sigmoid算子显存占用为0的结论并不严谨。它占用的显存刚好是我们的输出,因此没有算在它的显存开销中。

为了更准确地反映这一问题,我们让它多计算几次:

def compute(x):
    x.requires_grad_(True)
    for i in range(5):    
        x = torch.nn.Sigmoid()(x)
    return x

z = compute(x)

计算5次,额外占用显存为4GB:

Additional memory used: 4.0686774253845215 GB
Additional peak memory used: 4.0686774253845215 GB

大体上来说,一个算子持续占用的显存,就是它在前向传播过程中保存下来的变量所占的显存。但一个程序占用的显存总量,并不能用全部算子占用的显存数进行求和,因为这些变量之间可能有重复(正如我们的示例中的输入变量、输出变量那样)。

总结

本文介绍了深度学习训练过程中的显存占用分析方法、自动求导与手动算子融合、优化等技术原理。算子融合是深度学习编译器等技术的核心,而算子优化目前还需要人工设计。对算子优化感兴趣的朋友,可以看看FlashAttention论文(参见《Flashattention: Fast and memory-efficient exact attention with io-awareness》),它是一个十分优雅的算子优化的例子。

注:

如何查看pytorch自带算子为反向传播保存的变量?可以通过输出的grad_fn属性的dir(var.grad_fn)看到,里面的_saved_xxx就是为了反向传播保存的变量。

对于乘法,这个属性是_saved_other,因为乘法的梯度是另一个变量;对于sigmoid算子,这个属性是_saved_result,因为sigmoid的梯度和计算结果有关。

大部分的pytorch算子都可以通过这种方式获得保存的具体变量内容,例如卷积算子保留了以下内容:_saved_bias_sym_sizes_opt/_saved_dilation/_saved_groups/_saved_input/_saved_output_padding/_saved_padding/_saved_stride/_saved_transposed/_saved_weight. 其中大部分都是卷积的配置(例如padding大小、stride大小等内容),真正对显存占用影响最大的就是_saved_input_saved_weight

附上这部分代码,感兴趣的读者可以用它来分析pytorch自带算子的具体计算机制。

var = z
names =[k for k in dir(var.grad_fn) if k.startswith('_saved')]
for k in names:
    v = getattr(var.grad_fn, k)
    if isinstance(v, torch.Tensor):
        print(k, v.shape)
    else:
        print(k, v)



来源:知乎 www.zhihu.com
作者:游凯超

【知乎日报】千万用户的选择,做朋友圈里的新鲜事分享大牛。 点击下载

like

dislike

love

funny

angry

sad

wow

李芷晴 https://tszching.uk