# Source code for easycv.models.utils.gather_layer

# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import torch.distributed as dist


class GatherLayer(torch.autograd.Function):
    """Gather a tensor from every process while keeping the op differentiable.

    ``forward`` runs ``dist.all_gather`` and returns a tuple with one tensor
    per rank, ordered by rank; the entry at this process's rank holds its own
    ``input``. ``torch.distributed`` must already be initialised.

    Typical use::

        gathered = torch.cat(GatherLayer.apply(x), dim=0)

    Gradient semantics of ``backward``:

    * ``sum_grads=False`` (default, the historical behaviour): each rank keeps
      only the gradient its own local graph produced for its own slice.
      Gradients that other ranks' losses produce for this rank's slice are
      discarded.
    * ``sum_grads=True`` (``GatherLayer.apply(x, True)``): the per-slice
      gradients are summed across all ranks with ``dist.all_reduce`` before
      this rank takes its slice, so every rank's loss contributes to this
      rank's input gradient.
    """

    @staticmethod
    def forward(ctx, input, sum_grads=False):
        """Gather ``input`` from all ranks.

        Args:
            input: The local tensor to share. Every rank must pass a tensor of
                the same shape and dtype (an ``all_gather`` requirement).
            sum_grads: If true, ``backward`` sums gradients across ranks
                before taking this rank's slice (see the class docstring).

        Returns:
            tuple[Tensor, ...]: ``dist.get_world_size()`` tensors, indexed by
            rank.
        """
        ctx.sum_grads = bool(sum_grads)
        # all_gather needs contiguous buffers, and zeros_like preserves the
        # source's strides, so make the source contiguous before allocating.
        input = input.contiguous()
        output = [
            torch.zeros_like(input) for _ in range(dist.get_world_size())
        ]
        dist.all_gather(output, input)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        """Route the gathered gradients back to this rank's input.

        Args:
            *grads: One gradient per gathered output, indexed by rank.

        Returns:
            tuple: The gradient for ``input`` and ``None`` for ``sum_grads``.
        """
        rank = dist.get_rank()
        if ctx.sum_grads:
            # Collect contributions for every slice from every rank's loss.
            # all_reduce sums by default.
            stacked = torch.stack(grads)
            dist.all_reduce(stacked)
            grad_input = stacked[rank]
        else:
            grad_input = grads[rank]
        # One gradient per forward argument; the boolean flag gets None.
        return grad_input, None