grad_norm通常用模型参数梯度的范数表示,常用二范数计算;所以grad_norm出现NAN值先排查模型参数的梯度是否存在NAN值;

Megatron-LM中grad_norm计算方法是MegatronOptimizer类中的clip_grad_norm方法(megatron/optimizer/optimizer.py文件中)

def clip_grad_norm(self, clip_grad, check_for_nan_in_grad):
    params = self.get_parameters()
    grads_for_norm = self.get_main_grads_for_grad_norm()
    return clip_grad_norm_fp32(
        params, grads_for_norm, clip_grad,
        check_for_nan_in_grad,
        model_parallel_group=self.get_model_parallel_group())

出现该问题通常有以下几种原因:

机器问题

如果是机器问题导致的计算错误,只需要定位到有问题的卡即可;这时只需要检查各个卡上的参数梯度是否存在NAN值,如果存在,则打印卡的TP_rank、PP_rank、DP_rank即可定位出问题卡;在clip_grad_norm中添加如下代码即可:

def clip_grad_norm(self, clip_grad, check_for_nan_in_grad):
    params = self.get_parameters()

    # check param.grad to find NAN
    for param in params:
        if torch.isnan(param.grad).any():
            from megatron.core import parallel_state
            tp_rank = parallel_state.get_tensor_model_parallel_rank()
            pp_rank = parallel_state.get_pipeline_model_parallel_rank()
            dp_rank = parallel_state.get_data_parallel_rank()
            print(f"param occurrence NAN rank: tp={tp_rank}, pp={pp_rank}, dp={dp_rank}")

    grads_for_norm = self.get_main_grads_for_grad_norm()
    return clip_grad_norm_fp32(
        params, grads_for_norm, clip_grad,
        check_for_nan_in_grad,
        model_parallel_group=self.get_model_parallel_group())

数据触发的算子问题

如果机器硬件没问题,那么发生计算错误通常是由特定数据触发的算子bug,这时需要把算子找出来,然后dump下来反向计算的数据;

找到计算出错的算子

可以从参数梯度为NAN的参数名入手,然后在模型结构中找到使用该算子的模块; 优化器里的参数没有名字怎么找呢?在建立优化器参数时把参数位置和参数名持久化保存下来,在检查参数梯度时根据参数的位置去查找参数名即可。建立优化器参数的方法是megatron/optimizer/init.py里的get_param_groups方法,修改如下:

def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """creates param groups based on weight decay condition (regularized vs non regularized)
       and learning rate scale condition (args.lr vs lr_mult * args.lr)
       scale_lr_cond is used during finetuning where head of the network requires a scaled
       version of the base learning rate. 
    """
    wd_no_scale_lr, wd_no_scale_lr_names = [], []
    wd_scale_lr, wd_scale_lr_names = [], []
    no_wd_no_scale_lr, no_wd_no_scale_lr_names = [], []
    no_wd_scale_lr, no_wd_scale_lr_names = [], []
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # do not regularize biases nor Norm parameters
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_no_scale_lr.append(param)
                wd_no_scale_lr_names.append(name)
            elif not no_wd and scale_lr:
                wd_scale_lr.append(param)
                wd_scale_lr_names.append(name)
            elif no_wd and not scale_lr:
                no_wd_no_scale_lr.append(param)
                no_wd_no_scale_lr_names.append(name)
            else:
                no_wd_scale_lr.append(param)
                no_wd_scale_lr_names.append(name)

    param_groups = []
    if len(wd_no_scale_lr):
        param_groups.append({'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0, 'names': wd_no_scale_lr_names})
    if len(wd_scale_lr):
        param_groups.append({'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult, 'names': wd_scale_lr_names})
    if len(no_wd_no_scale_lr):
        param_groups.append(
            {'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0, 'names': no_wd_no_scale_lr_names})
    if len(no_wd_scale_lr):
        param_groups.append(
            {'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult, 'names': no_wd_scale_lr_names})

    # build the mapping between param id and name
    param_id = 0
    param_id_name_map = {}
    for param in param_groups:
        for name in param['names']:
            param_id_name_map[param_id] = name
            param_id += 1
    # Persistent mapping relationship
    from megatron.core import parallel_state
    tp_rank = parallel_state.get_tensor_model_parallel_rank()
    pp_rank = parallel_state.get_pipeline_model_parallel_rank()
    dp_rank = parallel_state.get_data_parallel_rank()
    import json
    with open(f"param_id_name_map_tp{tp_rank}pp{pp_rank}dp{dp_rank}.json", 'w') as f:
        f.write(json.dumps(param_id_name_map))

    return param_groups

在clip_grad_norm中检查出param.grad为NAN时,根据param_id去对应文件查找参数名即可,示例代码如下:

def clip_grad_norm(self, clip_grad, check_for_nan_in_grad):
        params = self.get_parameters()
        # check param.grad to find NAN
        for param_id, param in enumerate(params):
            if torch.isnan(param.grad).any():
                from megatron.core import parallel_state
                tp_rank = parallel_state.get_tensor_model_parallel_rank()
                pp_rank = parallel_state.get_pipeline_model_parallel_rank()
                dp_rank = parallel_state.get_data_parallel_rank()
                print(
                    f"find param name by param_id={param_id} in param_id_name_map_tp{tp_rank}pp{pp_rank}dp{dp_rank}.json")

        grads_for_norm = self.get_main_grads_for_grad_norm()
        return clip_grad_norm_fp32(
            params, grads_for_norm, clip_grad,
            check_for_nan_in_grad,
            model_parallel_group=self.get_model_parallel_group())

找到计算错误的模块,dump输入输出数据

找到param.grad为NAN的参数名后,就能在模型里找对应的nn.Module模块;然后注册模块的前向和反向钩子函数,获取该模块的前向输入输出、反向输入梯度和输出梯度,这样就能获取导致NAN的具体数据,保存下来复现问题了;示例如下:

import torch
from torch import nn
import torch.nn.functional as F


def check_nan(t):
    """
    检查 NAN
    """
    if isinstance(t, (list, tuple)) and t:
        for item in x:
            check_nan(item)
    else:
        if torch.isnan(t).any():
            return True
        return False


def check_nan_forward_fn_hook(module, input, output):
    """NAN检查的前向钩子函数
    module: 执行注册的 nn.Module模块
    input: module的forward输入,因为输入可能是多个,所以input是元组格式
    output: module forward的输出,tensor格式
    """
    print(f"module : {module}")
    print(f"input : {input}")
    print(f"output : {output}")

    if check_nan(input):
        # forward入参检查到NAN
        pass  # 此处添加保存逻辑,把数据保存下来
    if check_nan(output):
        # forward输出检测到NAN
        pass  # 此处添加保存逻辑,把数据保存下来


def check_nan_backward_fn_hook(module, input_grad, output_grad):
    """NAN检查的反向钩子函数
    module: 执行注册的 nn.Module模块
    input_grad: 输入梯度,元组格式,包括module的forward输入,module里的参数输入梯度, 排列顺序和算子实现相关
                如果输入或参数requires_grad=False, 相应位置值为None(forward入参默认就是False)
    output_grad: 输出梯度, 元组格式
    """
    print(f"module : {module}")
    print(f"input_grad : {input_grad}")
    print(f"output_grad : {output_grad}")
    if check_nan(input_grad):
        # forward入参梯度检查到NAN
        pass  # 此处添加保存逻辑,把数据保存下来
    if check_nan(output_grad):
        # forward输出梯度检测到NAN
        pass  # 此处添加保存逻辑,把数据保存下来


class MyNet(nn.Module):
    def __init__(self):
        super(MyNet, self).__init__()

        self.fc1 = nn.Linear(3, 2)
        self.fc2 = nn.Linear(2, 4)

        self.fc1.register_forward_hook(check_nan_forward_fn_hook)
        self.fc1.register_backward_hook(check_nan_backward_fn_hook)

    def forward(self, x):
        return self.fc2(F.silu(self.fc1(x)))


x = torch.randn(2, 3)
print(f"x : {x}")
# x.requires_grad = True
model = MyNet()
y = model(x)
z = torch.mean(y)
z.backward()
#以上代码执行的示例输出如下:
x : tensor([[ 0.3263,  0.6122,  0.8691],
        [ 0.3619, -0.7433,  0.7780]])
module : Linear(in_features=3, out_features=2, bias=True)
input : (tensor([[ 0.3263,  0.6122,  0.8691],
        [ 0.3619, -0.7433,  0.7780]]),)
output : tensor([[ 0.1120,  0.0568],
        [-0.1083, -0.4955]], grad_fn=<AddmmBackward0>)
module : Linear(in_features=3, out_features=2, bias=True)
input_grad : (tensor([ 0.1115, -0.0605]), None, tensor([[ 0.0382, -0.0204],
        [ 0.0010, -0.0098],
        [ 0.0924, -0.0507]]))
output_grad : (tensor([[ 0.0619, -0.0404],
        [ 0.0496, -0.0200]]),)
Logo

鲲鹏昇腾开发者社区是面向全社会开放的“联接全球计算开发者,聚合华为+生态”的社区,内容涵盖鲲鹏、昇腾资源,帮助开发者快速获取所需的知识、经验、软件、工具、算力,支撑开发者易学、好用、成功,成为核心开发者。

更多推荐