U
    Ph                     @  sd   d dl mZ d dlmZ d dlZd dlmZ d dlmZ ddgZ	G dd deZG d	d deZ
dS )
    )annotations)SequenceN)Dataset)DistributedSamplerr    DistributedWeightedRandomSamplerc                      s0   e Zd ZdZd
dddddd fdd	Z  ZS )r   a  
    Enhance PyTorch DistributedSampler to support non-evenly divisible sampling.

    Args:
        dataset: Dataset used for sampling.
        even_divisible: if False, different ranks can have different data length.
            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].
        num_replicas: number of processes participating in distributed training.
            by default, `world_size` is retrieved from the current distributed group.
        rank: rank of the current process within `num_replicas`. by default,
            `rank` is retrieved from the current distributed group.
        shuffle: if `True`, sampler will shuffle the indices, default to True.
        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.

    More information about DistributedSampler, please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.distributed.DistributedSampler.

    TNr   bool
int | None)dataseteven_divisiblenum_replicasrankshufflec           	        sl   t  jf ||||d| |sht|}|| jk r:td| j| }| j| | jkrb|  jd8  _|| _d S )N)r	   r   r   r   zBthe dataset length is less than the number of participating ranks.   )super__init__lenr   
ValueError
total_sizer   num_samples)	selfr	   r
   r   r   r   kwargsdata_len
extra_size	__class__ H/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/data/samplers.pyr   +   s    	

zDistributedSampler.__init__)TNNT)__name__
__module____qualname____doc__r   __classcell__r   r   r   r   r      s       c                	      sH   e Zd ZdZddddddddd	 fd
dZ fddZdd Z  ZS )r   a	  
    Extend the `DistributedSampler` to support weighted sampling.
    Refer to `torch.utils.data.WeightedRandomSampler`, for more details please check:
    https://pytorch.org/docs/stable/data.html#torch.utils.data.WeightedRandomSampler.

    Args:
        dataset: Dataset used for sampling.
        weights: a sequence of weights, not necessary summing up to one, length should exactly
            match the full dataset.
        num_samples_per_rank: number of samples to draw for every rank, sample from
            the distributed subset of dataset.
            if None, default to the length of dataset split by DistributedSampler.
        generator: PyTorch Generator used in sampling.
        even_divisible: if False, different ranks can have different data length.
            for example, input data: [1, 2, 3, 4, 5], rank 0: [1, 3, 5], rank 1: [2, 4].'
        num_replicas: number of processes participating in distributed training.
            by default, `world_size` is retrieved from the current distributed group.
        rank: rank of the current process within `num_replicas`. by default,
            `rank` is retrieved from the current distributed group.
        kwargs: additional arguments for `DistributedSampler` super class, can be `seed` and `drop_last`.

    NTr   zSequence[float]r   ztorch.Generator | Noner   )r	   weightsnum_samples_per_rank	generatorr
   r   r   c           	        sL   | dd t jf ||||d| || _|d k	r:|n| j| _|| _d S )Nr   T)r	   r
   r   r   )
setdefaultr   r   r"   r   r#   r$   )	r   r	   r"   r#   r$   r
   r   r   r   r   r   r   r   X   s
    z)DistributedWeightedRandomSampler.__init__c                 #  sZ   t t  }tj fdd|D tjd}tj| jd jd}|D ]}|| V  qFd S )Nc                   s   g | ]} j | qS r   )r"   ).0ir   r   r   
<listcomp>k   s     z=DistributedWeightedRandomSampler.__iter__.<locals>.<listcomp>)dtypeT)r$   )	listr   __iter__torch	as_tensordoublemultinomialr#   r$   )r   indicesr"   rand_tensorr'   r   r(   r   r,   i   s
    z)DistributedWeightedRandomSampler.__iter__c                 C  s   | j S )N)r#   r(   r   r   r   __len__r   s    z(DistributedWeightedRandomSampler.__len__)NNTNN)r   r   r   r    r   r,   r3   r!   r   r   r   r   r   @   s         	)
__future__r   collections.abcr   r-   torch.utils.datar   r   Z_TorchDistributedSampler__all__r   r   r   r   r   <module>   s   )