o
    iB6                     @  sN   d Z ddlmZ ddlZddlZddlmZ G dd dZG dd deZdS )	z
The functions in this script are adapted from nnDetection,
https://github.com/MIC-DKFZ/nnDetection/blob/main/nndet/core/boxes/sampler.py
    )annotationsN)Tensorc                   @  s&   e Zd ZdZddddZdddZdS )HardNegativeSamplerBasea  
    Base class of hard negative sampler.

    Hard negative sampler is used to suppress false positive rate in classification tasks.
    During training, it select negative samples with high prediction scores.

    The training workflow is described as the follows:
    1) forward network and get prediction scores (classification prob/logits) for all the samples;
    2) use hard negative sampler to choose negative samples with high prediction scores and some positive samples;
    3) compute classification loss for the selected samples;
    4) do back propagation.

    Args:
        pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from
            ``num_neg * pool_size`` negative samples with the highest prediction scores.
            Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',
            i.e., negative samples with lower prediction scores.
    
   	pool_sizefloatreturnNonec                 C  s
   || _ d S )Nr   )selfr    r   r/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/hard_negative_sampler.py__init__<   s   
z HardNegativeSamplerBase.__init__negativer   num_negintfg_probsc                 C  s   |  |  krtdt|| j }t|  |}|| tjj|ddd\}}|| }tj	|  |j
dd| }|| }	tj|tjd}
d|
|	< |
S )	a;  
        Select hard negative samples.

        Args:
            negative: indices of all the negative samples, sized (P,),
                where P is the number of negative samples
            num_neg: number of negative samples to sample
            fg_probs: maximum foreground prediction scores (probability) across all the classes
                for each sample, sized (A,), where A is the number of samples.

        Returns:
            binary mask of negative samples to choose, sized (A,),
                where A is the number of samples in one image
        zSThe number of negative samples should not be larger than the number of all samples.r   T)dimsorteddeviceNdtype   )numel
ValueErrorr   r   mintotorchfloat32topkrandpermr   
zeros_likeuint8)r   r   r   r   pool_Znegative_idx_poolZhard_negativeperm2Zselected_neg_idxneg_maskr   r   r   select_negatives?   s    z(HardNegativeSamplerBase.select_negativesN)r   )r   r   r   r	   )r   r   r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r(   r   r   r   r   r   (   s    r   c                      sb   e Zd ZdZ	d(d) fddZd*ddZd+ddZd,ddZd-dd Zd.d#d$Z	d/d&d'Z
  ZS )0HardNegativeSamplera  
    HardNegativeSampler is used to suppress false positive rate in classification tasks.
    During training, it selects negative samples with high prediction scores.

    The training workflow is described as the follows:
    1) forward network and get prediction scores (classification prob/logits) for all the samples;
    2) use hard negative sampler to choose negative samples with high prediction scores and some positive samples;
    3) compute classification loss for the selected samples;
    4) do back propagation.

    Args:
        batch_size_per_image: number of training samples to be randomly selected per image
        positive_fraction: percentage of positive elements in the selected samples
        min_neg: minimum number of negative samples to select if possible.
        pool_size: when we need ``num_neg`` hard negative samples, they will be randomly selected from
            ``num_neg * pool_size`` negative samples with the highest prediction scores.
            Larger ``pool_size`` gives more randomness, yet selects negative samples that are less 'hard',
            i.e., negative samples with lower prediction scores.
    r   r   batch_size_per_imager   positive_fractionr   min_negr   r   r	   c                   s.   t  j|d || _|| _|| _td d S )Nr
   z,Sampling hard negatives on a per batch basis)superr   r0   r.   r/   logginginfo)r   r.   r/   r0   r   	__class__r   r   r   x   s
   zHardNegativeSampler.__init__target_labelslist[Tensor]concat_fg_probsr   !tuple[list[Tensor], list[Tensor]]c                 C  s&   dd |D }| |d}| ||S )a  
        Select positives and hard negatives from list samples per image.
        Hard negative sampler will be applied to each image independently.

        Args:
            target_labels: list of labels per image.
                For image i in the batch, target_labels[i] is a Tensor sized (A_i,),
                where A_i is the number of samples in image i.
                Positive samples have positive labels, negative samples have label 0.
            concat_fg_probs: concatenated maximum foreground probability for all the images, sized (R,),
                where R is the sum of all samples inside one batch, i.e., R = A_0 + A_1 + ...

        Returns:
            - list of binary mask for positive samples
            - list of binary mask for negative samples

        Example:
            .. code-block:: python

                sampler = HardNegativeSampler(
                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2
                )
                # two images with different number of samples
                target_labels = [ torch.tensor([0,1]), torch.tensor([1,0,2,1])]
                concat_fg_probs = torch.rand(6)
                pos_idx_list, neg_idx_list = sampler(target_labels, concat_fg_probs)
        c                 S  s   g | ]}|j d  qS )r   )shape).0Zsamples_in_imager   r   r   
<listcomp>   s    z0HardNegativeSampler.__call__.<locals>.<listcomp>r   )splitselect_samples_img_list)r   r6   r8   samples_per_imager   r   r   r   __call__   s   zHardNegativeSampler.__call__r   c           	      C  sv   g }g }t |t |krtdt | dt | dt||D ]\}}| ||\}}|| || q ||fS )a%  
        Select positives and hard negatives from list samples per image.
        Hard negative sampler will be applied to each image independently.

        Args:
            target_labels: list of labels per image.
                For image i in the batch, target_labels[i] is a Tensor sized (A_i,),
                where A_i is the number of samples in image i.
                Positive samples have positive labels, negative samples have label 0.
            fg_probs: list of maximum foreground probability per images,
                For image i in the batch, target_labels[i] is a Tensor sized (A_i,),
                where A_i is the number of samples in image i.

        Returns:
            - list of binary mask for positive samples
            - list binary mask for negative samples

        Example:
            .. code-block:: python

                sampler = HardNegativeSampler(
                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2
                )
                # two images with different number of samples
                target_labels = [ torch.tensor([0,1]), torch.tensor([1,0,2,1])]
                fg_probs = [ torch.rand(2), torch.rand(4)]
                pos_idx_list, neg_idx_list = sampler.select_samples_img_list(target_labels, fg_probs)
        zDRequire len(target_labels) == len(fg_probs). Got len(target_labels)=z, len(fg_probs)=.)lenr   zipselect_samples_per_imgappend)	r   r6   r   pos_idxneg_idxlabels_per_imgfg_probs_per_imgpos_idx_per_image_maskneg_idx_per_image_maskr   r   r   r>      s"   
z+HardNegativeSampler.select_samples_img_listrH   rI   tuple[Tensor, Tensor]c           	      C  sv   |  |  krtdt|dkd }t|dkd }| |}| |||}| ||}| |||}||fS )at  
        Select positives and hard negatives from samples.

        Args:
            labels_per_img: labels, sized (A,).
                Positive samples have positive labels, negative samples have label 0.
            fg_probs_per_img: maximum foreground probability, sized (A,)

        Returns:
            - binary mask for positive samples, sized (A,)
            - binary mask for negative samples, sized (A,)

        Example:
            .. code-block:: python

                sampler = HardNegativeSampler(
                    batch_size_per_image=6, positive_fraction=0.5, min_neg=1, pool_size=2
                )
                # two images with different number of samples
                target_labels = torch.tensor([1,0,2,1])
                fg_probs = torch.rand(4)
                pos_idx, neg_idx = sampler.select_samples_per_img(target_labels, fg_probs)
        zHlabels_per_img and fg_probs_per_img should have same number of elements.r   r   )r   r   r   whereget_num_posselect_positivesget_num_negr(   )	r   rH   rI   positiver   num_posrJ   r   rK   r   r   r   rD      s   
z*HardNegativeSampler.select_samples_per_imgrQ   torch.Tensorc                 C  s"   t | j| j }t| |}|S )z
        Number of positive samples to draw

        Args:
            positive: indices of positive samples

        Returns:
            number of positive sample
        )r   r.   r/   r   r   )r   rQ   rR   r   r   r   rN      s   zHardNegativeSampler.get_num_posr   rR   c                 C  s>   t td|tddt| j   }t| t|| j}|S )a  
        Sample enough negatives to fill up ``self.batch_size_per_image``

        Args:
            negative: indices of positive samples
            num_pos: number of positive samples to draw

        Returns:
            number of negative samples
        r   g      ?)r   maxabsr   r/   r   r   r0   )r   r   rR   r   r   r   r   rP     s   $zHardNegativeSampler.get_num_neglabelsc                 C  sX   |  |  krtdtj|  |jdd| }|| }tj|tjd}d||< |S )a  
        Select positive samples

        Args:
            positive: indices of positive samples, sized (P,),
                where P is the number of positive samples
            num_pos: number of positive samples to sample
            labels: labels for all samples, sized (A,),
                where A is the number of samples.

        Returns:
            binary mask of positive samples to choose, sized (A,),
                where A is the number of samples in one image
        zSThe number of positive samples should not be larger than the number of all samples.r   Nr   r   )r   r   r   r!   r   r"   r#   )r   rQ   rR   rV   perm1pos_idx_per_imagerJ   r   r   r   rO     s   z$HardNegativeSampler.select_positives)r   r   )
r.   r   r/   r   r0   r   r   r   r   r	   )r6   r7   r8   r   r   r9   )r6   r7   r   r7   r   r9   )rH   r   rI   r   r   rL   )rQ   rS   r   r   )r   rS   rR   r   r   r   )rQ   r   rR   r   rV   r   r   r   )r)   r*   r+   r,   r   r@   r>   rD   rN   rP   rO   __classcell__r   r   r4   r   r-   c   s    
	
 
0
'
r-   )r,   
__future__r   r2   r   r   r   r-   r   r   r   r   <module>   s   ;