U
    Ph!                     @  s,  d dl mZ d dlZd dlZd dlmZ d dlmZ ejdkrJd dl	m
Z
 d dl	mZ d dlZd dlmZ d dlmZ d d	lmZmZ ed
ejed\ZZddddgZdd ZeddddddZeddddddZeddddddZd#ddddddZd$dddd d!dZG d"d deZdS )%    )annotationsN)Callable)Filter)      )Literal)overload)
IgniteInfo)min_versionoptional_importZignitedistributedget_dist_deviceevenly_divisible_all_gatherstring_list_all_gather
RankFilterc                  C  sN   t  rJt  } | dkr8tj r8tdtj  S | dkrJtdS dS )a  
    Get the expected target device in the native PyTorch distributed data parallel.
    For NCCL backend, return GPU device of current process.
    For GLOO backend, return CPU.
    For any other backends, return None as the default, tensor.to(None) will not change the device.

    ncclzcuda:gloocpuN)distis_initializedget_backendtorchcudais_availabledevicecurrent_device)backend r   E/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/utils/dist.pyr   #   s    
torch.TensorzLiteral[True])dataconcatreturnc                 C  s   d S Nr   r    r!   r   r   r   r   4   s    zLiteral[False]list[torch.Tensor]c                 C  s   d S r#   r   r$   r   r   r   r   8   s    boolz!torch.Tensor | list[torch.Tensor]c                 C  s   d S r#   r   r$   r   r   r   r   <   s    Tc                   s   t | tjstd|  dkr.| jd nd ddd fdd}ddd fd	d
}tr~t dkrr| S || d}n0t	
 rt	 rt	 dkr| S || d}n| S |rtj|ddS |S )a)  
    Utility function for distributed data parallel to pad at first dim to make it evenly divisible and all_gather.
    The input data of every rank should have the same number of dimensions, only the first dim can be different.

    Note: If has ignite installed, will execute based on ignite distributed APIs, otherwise, if the native
    PyTorch distributed group initialized, will execute based on native PyTorch distributed APIs.

    Args:
        data: source tensor to pad and execute all_gather in distributed data parallel.
        concat: whether to concat the gathered list to be a Tensor, if False, return a list
            of Tensors, similar behavior as torch.distributed.all_gather(). default to True.

    Note:
        The input data on different ranks must have exactly same `dtype`.

    z"input data must be PyTorch Tensor.r      r   r%   )r    r"   c                   s   t  } j | dkr( dn  tjg|dfddtt D }t	| dd |D }t
|}|k r| gt jdd  }tj  |dgdd  fd	dtt D }t	|  fd
dt||D S )zY
        Implementation based on native PyTorch distributed data parallel APIs.

        r   )r   c                   s   g | ]}t  qS r   r   
zeros_like.0_)length_tensorr   r   
<listcomp>c   s     zJevenly_divisible_all_gather.<locals>._torch_all_gather.<locals>.<listcomp>c                 S  s   g | ]}t | qS r   )intitem)r+   ir   r   r   r.   e   s     r'   Ndimc                   s   g | ]}t  qS r   r(   r*   r    r   r   r.   l   s     c                   s8   g | ]0\}} d kr| d n|d|df qS )r   N.)squeezeto)r+   ol)ndimsorig_devicer   r   r.   o   s     )r   r   r6   	unsqueezer   	as_tensorranger   get_world_size
all_gathermaxlistshapecatnew_fullzip)r    r   all_lensZ	all_lens_max_lensizeoutputlengthr9   )r    r-   r:   r   _torch_all_gatherW   s    
z6evenly_divisible_all_gather.<locals>._torch_all_gatherc                   s   dkr|  dn| } t}t|  k rf  gt| jdd  }tj| | |dgdd} t| dkrttj	ddS  fddt
|D S )zi
        Implementation based on PyTorch ignite package, it can support more kinds of backends.

        r   r'   Nr2   c                   s,   g | ]$\}}|  |  | d f qS ).r   )r+   r1   r8   rG   rI   r   r   r.      s     zKevenly_divisible_all_gather.<locals>._ignite_all_gather.<locals>.<listcomp>)r;   idistr?   r@   rA   rB   r   rC   rD   unbind	enumerate)r    rF   rH   rJ   rM   r   _ignite_all_gatherq   s    

z7evenly_divisible_all_gather.<locals>._ignite_all_gatherr4   r2   )
isinstancer   Tensor
ValueError
ndimensionrB   
has_igniterN   r>   r   r   r   rC   )r    r!   rL   rQ   rI   r   rJ   r   r   @   s     	z	list[str]str)strings	delimiterr"   c                   s~   d}t rt }nt r*t r*t }|dkr6| S  | }ttj	t
|dtjddd} fdd|D }dd |D S )	a  
    Utility function for distributed data parallel to all gather a list of strings.
    Refer to the idea of ignite `all_gather(string)`:
    https://pytorch.org/ignite/v0.4.5/distributed.html#ignite.distributed.utils.all_gather.

    Note: If has ignite installed, will execute based on ignite distributed APIs, otherwise, if the native
    PyTorch distributed group initialized, will execute based on native PyTorch distributed APIs.

    Args:
        strings: a list of strings to all gather.
        delimiter: use the delimiter to join the string list to be a long string,
            then all gather across ranks and split to a list. default to "	".

    r'   utf-8)dtypeF)r!   c                   s$   g | ]}t | d  qS )r[   )	bytearraytolistdecodesplit)r+   grZ   r   r   r.      s     z*string_list_all_gather.<locals>.<listcomp>c                 S  s   g | ]}|D ]}|qqS r   r   )r+   kr1   r   r   r   r.      s       )rV   rN   r>   r   r   r   joinr   r   tensorr]   long)rY   rZ   
world_sizejoinedgatheredZ	_gatheredr   rb   r   r      s    

c                      s:   e Zd ZdZddd fddd fdd	Zd
d Z  ZS )r   aW  
    The RankFilter class is a convenient filter that extends the Filter class in the Python logging module.
    The purpose is to control which log records are processed based on the rank in a distributed environment.

    Args:
        rank: the rank of the process in the torch.distributed. Default is None and then it will use dist.get_rank().
        filter_fn: an optional lambda function used as the filtering criteria.
            The default function logs only if the rank of the process is 0,
            but the user can define their own function to implement custom filtering logic.
    Nc                 C  s   | dkS )Nr   r   )rankr   r   r   <lambda>       zRankFilter.<lambda>z
int | Noner   )rj   	filter_fnc                   sd   t    || _t r8t r8|d k	r,|nt | _n(tj	 rZtj	
 dkrZtd d| _d S )Nr'   zThe torch.distributed is either unavailable and uninitiated when RankFilter is instantiated.
If torch.distributed is used, please ensure that the RankFilter() is called
after torch.distributed.init_process_group() in the script.
r   )super__init__rm   r   r   r   get_rankrj   r   r   device_countwarningswarn)selfrj   rm   	__class__r   r   ro      s    
zRankFilter.__init__c                 G  s   |  | jS r#   )rm   rj   )rt   _argsr   r   r   filter   s    zRankFilter.filter)__name__
__module____qualname____doc__ro   rx   __classcell__r   r   ru   r   r      s   )T)rW   )
__future__r   sysrr   collections.abcr   loggingr   version_infotypingr   r   r   torch.distributedr   r   Zmonai.configr	   monai.utils.moduler
   r   ZOPT_IMPORT_VERSIONrN   rV   __all__r   r   r   r   r   r   r   r   <module>   s.   
T