diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index 79ca92c93..bb11302ac 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -1217,7 +1217,11 @@ def _fsdp_foreach_allgather(
         else:
             origin_fsdp_size.append(load_spec.shape[self.FSDP_SHARD_DIM])
 
-        _fsdp_unsharded_tensor_list = foreach_all_gather(padded_tensor_list, self.fsdp_mesh.get_group())
+        _fsdp_unsharded_tensor_list = foreach_all_gather(
+            padded_tensor_list,
+            self.fsdp_mesh.get_group(),
+            [[tuple(t.size()) for t in padded_tensor_list]] * self.fsdp_mesh.size(),
+        )
         fsdp_unsharded_tensor_list = []
 
         # Concatenate the tensors along the FSDP shard dim
diff --git a/xtuner/v1/ops/comm/foreach_allgather.py b/xtuner/v1/ops/comm/foreach_allgather.py
index 2aa949e68..c46ed6cb7 100644
--- a/xtuner/v1/ops/comm/foreach_allgather.py
+++ b/xtuner/v1/ops/comm/foreach_allgather.py
@@ -1,3 +1,5 @@
+from functools import reduce
+from operator import mul
 from typing import cast
 
 import torch
@@ -9,6 +11,7 @@
 def foreach_all_gather(
     params: list[torch.Tensor],
     group: dist.ProcessGroup | None,
+    params_shapes_across_group: list[list[tuple[int, ...]]] | None = None,
 ) -> list[list[torch.Tensor]]:
     if group is None:
         group = dist.group.WORLD
@@ -19,35 +22,56 @@
     input_tensor_numels = [param.numel() for param in params]
     input_tensor_shapes = [param.shape for param in params]
 
-    flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
-    splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
-    torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
-
-    input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
-    global_input_tensor_numels = [
-        torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
-    ]
-
-    dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
-    copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
-    flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
-
-    dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
-    copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
-    splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
-
-    _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
-    dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
-    _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
-    global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
-
-    gathered_params: list[list[torch.Tensor]] = []
-    for i in range(len(params)):
-        single_gathered_params: list[torch.Tensor] = []
-        for rank in range(dist.get_world_size(group)):
-            offset = len(params) * rank
-            origin_shape: tuple = global_input_tensor_shapes[offset + i]
-            single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
-        gathered_params.append(single_gathered_params)
+    global_input_tensor_numels: list[torch.Tensor]
+    if params_shapes_across_group is None:
+        input_tensor_numels_tensor = torch.tensor(input_tensor_numels, dtype=torch.int64, device=param0.device)
+        global_input_tensor_numels = [
+            torch.zeros_like(input_tensor_numels_tensor) for _ in range(dist.get_world_size(group))
+        ]
+        dist.all_gather(global_input_tensor_numels, input_tensor_numels_tensor, group=group)
+    else:
+        global_input_tensor_numels = [
+            torch.tensor([reduce(mul, shape, 1) for shape in param_shapes], dtype=torch.int64, device="cpu")
+            for param_shapes in params_shapes_across_group  # each param_shapes lists all param shapes on one rank
+        ]
+
+    if len(params) == 1:
+        param0_shape_except_dim0 = list(param0.shape)[1:]
+        param0_numel_except_dim0 = param0[0].numel()
+        # Compute the dim-0 size of the gathered tensor; this also handles an uneven split along dim 0
+        split_dim0_sizes = [t.tolist()[0] // param0_numel_except_dim0 for t in global_input_tensor_numels]
+        gathered_tensor_dim0_size = sum(split_dim0_sizes)
+
+        # all_gather_into_tensor concatenates each rank's data along dimension 0
+        gathered_tensor = torch.empty(
+            (gathered_tensor_dim0_size, *param0_shape_except_dim0), dtype=param0.dtype, device=param0.device
+        )
+        dist.all_gather_into_tensor(gathered_tensor, param0, group=group)
+        return [list(gathered_tensor.split(split_dim0_sizes, dim=0))]
+    else:
+        flatten_copyin_tensor = torch.empty((sum(input_tensor_numels),), dtype=param0.dtype, device=param0.device)
+        splits_copyin_tensor = torch.split(flatten_copyin_tensor, input_tensor_numels)
+        torch._foreach_copy_(splits_copyin_tensor, [p.flatten() for p in params])
+
+        copyout_size = int(sum(sum(i) for i in global_input_tensor_numels))
+        flatten_copyout_tensor = torch.empty((copyout_size,), dtype=param0.dtype, device=param0.device)
+
+        dist.all_gather_into_tensor(flatten_copyout_tensor, flatten_copyin_tensor, group=group)
+        copyout_split_size: list[int] = sum([i.tolist() for i in global_input_tensor_numels], [])
+        splits_copyout_tensor = torch.split(flatten_copyout_tensor, copyout_split_size)
+
+        _global_input_tensor_shapes: list[None] | list[list[tuple]] = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(_global_input_tensor_shapes, input_tensor_shapes, group=group)
+        _global_input_tensor_shapes = cast(list[list[tuple]], _global_input_tensor_shapes)
+        global_input_tensor_shapes: list[tuple] = sum(_global_input_tensor_shapes, [])
+
+        gathered_params: list[list[torch.Tensor]] = []
+        for i in range(len(params)):
+            single_gathered_params: list[torch.Tensor] = []
+            for rank in range(dist.get_world_size(group)):
+                offset = len(params) * rank
+                origin_shape: tuple = global_input_tensor_shapes[offset + i]
+                single_gathered_params.append(splits_copyout_tensor[offset + i].view(origin_shape))
+            gathered_params.append(single_gathered_params)
     return gathered_params
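
A minimal usage sketch of the new fast path (illustrative, not part of the diff): it assumes a 2-rank launch via torchrun with the NCCL backend, and the shard shapes are made up for the example. Because every rank's shapes are supplied up front, foreach_all_gather derives the per-rank numels locally instead of issuing a metadata all_gather, and with a single input tensor it performs one all_gather_into_tensor and splits the result along dim 0.

# demo.py -- run with: torchrun --nproc_per_node=2 demo.py
import torch
import torch.distributed as dist

from xtuner.v1.ops.comm.foreach_allgather import foreach_all_gather

dist.init_process_group("nccl")
rank = dist.get_rank()
torch.cuda.set_device(rank)

# One padded shard per rank; all shapes are known ahead of time, so no
# collective is needed to exchange shape metadata.
shard = torch.full((4, 8), float(rank), device="cuda")
shapes_across_group = [[(4, 8)], [(4, 8)]]  # rank 0's shapes, rank 1's shapes

# group=None falls back to the WORLD group inside foreach_all_gather.
gathered = foreach_all_gather([shard], None, shapes_across_group)

# gathered[0] holds one (4, 8) tensor per rank, split from the single
# all_gather_into_tensor output along dim 0.
assert [tuple(t.shape) for t in gathered[0]] == [(4, 8), (4, 8)]

dist.destroy_process_group()

This mirrors the call site in _fsdp_foreach_allgather: every rank holds identically padded shards, so the local shape list can simply be replicated fsdp_mesh.size() times rather than gathered over the wire.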