U
    Phta                     @  sB  d dl mZ d dlZd dlmZ d dlZd dlZd dlm	Z	m
Z
 d dlmZ d dlmZmZmZ d dlmZmZmZ ed\ZZed	\ZZed
dd\ZZG dd deZG dd deZG dd deZG dd deZG dd deZG dd deeZG dd deZG dd deZ G dd deZ!G dd  d eZ"dS )!    )annotationsN)Any)KeysCollectionNdarrayOrTensor)GaussianFilter)MapTransformRandomizable
SpatialPad)StrEnumconvert_to_numpyoptional_importzskimage.measurezskimage.morphologyzscipy.ndimage.morphologydistance_transform_cdt)namec                   @  s@   e Zd ZdZdZdZdZdZdZdZ	dZ
d	Zd
ZdZdZdZdS )NuclickKeysz&
    Keys for nuclick transforms.
    imagelabelothers
foregroundcentroid
mask_valuelocation
nuc_pointsbounding_boxes
img_height	img_widthpred_classesN)__name__
__module____qualname____doc__IMAGELABELOTHERS
FOREGROUNDCENTROID
MASK_VALUELOCATION
NUC_POINTSBOUNDING_BOXES
IMG_HEIGHT	IMG_WIDTHPRED_CLASSES r,   r,   R/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/apps/nuclick/transforms.pyr      s   r   c                      s4   e Zd ZdZddddd fdd	Zd
d Z  ZS )FlattenLabelda  
    FlattenLabeld creates labels per closed object contour (defined by a connectivity). For e.g if there are
    12 small regions of 1's it will delineate them into 12 different label classes

    Args:
        connectivity: Max no. of orthogonal hops to consider a pixel/voxel as a neighbor. Refer skimage.measure.label
        allow_missing_keys: don't raise exception if key is missing.
       Fr   intbool)keysconnectivityallow_missing_keysc                   s   t  || || _d S N)super__init__r3   )selfr2   r3   r4   	__class__r,   r-   r7   =   s    zFlattenLabeld.__init__c                 C  sX   t |}| jD ]D}t|| tjr.t|| n|| }tj|| jd	t
j||< q|S )Nr3   )dictr2   
isinstancetorchTensorr   measurer   r3   astypenpuint8r8   datadkeyimgr,   r,   r-   __call__A   s
    
$zFlattenLabeld.__call__)r/   Fr   r   r   r   r7   rI   __classcell__r,   r,   r9   r-   r.   3   s   	r.   c                      sH   e Zd ZdZejddfdddddd	 fd
dZdd Zdd Z  Z	S )ExtractPatchda  
    Extracts a patch from the given image and label, however it is based on the centroid location.
    The centroid location is a 2D coordinate (H, W). The extracted patch is extracted around the centroid,
    if the centroid is towards the edge, the centroid will not be the center of the image as the patch will be
    extracted from the edges onwards

    Args:
        keys: image, label
        centroid_key: key where the centroid values are stored, defaults to ``"centroid"``
        patch_size: size of the extracted patch
        allow_missing_keys: don't raise exception if key is missing.
        pad_kwargs: other arguments for the SpatialPad transform
       Fr   strztuple[int, int] | intr1   r   )r2   centroid_key
patch_sizer4   kwargsc                   s$   t  || || _|| _|| _d S r5   )r6   r7   rO   rP   rQ   )r8   r2   rO   rP   r4   rQ   r9   r,   r-   r7   X   s    zExtractPatchd.__init__c                 C  s   t |}|| j }| j| jf}| jD ]d}|| }| | j||jdd  \}}}	}
|d d |||	|
f }tf d|i| j|||< q$|S )Nspatial_size)r<   rO   rP   r2   bboxshaper	   rQ   )r8   rE   rF   r   roi_sizerG   rH   x_startx_endy_starty_endcroppedr,   r,   r-   rI   e   s    

"zExtractPatchd.__call__c                 C  s   |\}}|\}}t t||d  d}t t||d  d}	|| }
|	| }|
|kr`|}
|| }||krt|}|| }	||
|	|fS )N   r   )r0   max)r8   rP   r   sizexymnrW   rY   rX   rZ   r,   r,   r-   rT   r   s    zExtractPatchd.bbox)
r   r   r   r   r   r$   r7   rI   rT   rK   r,   r,   r9   r-   rL   I   s   rL   c                      sH   e Zd ZdZejejdddfdddddd	d
 fddZdd Z  Z	S )SplitLabeldap  
    Extracts a single label from all the given classes, the single label is defined by mask_value, the remaining
    labels are kept in others

    Args:
        label: key of the label source
        others: other labels storage key, defaults to ``"others"``
        mask_value: the mask_value that will be kept for binarization of the label, defaults to ``"mask_value"``
        min_area: The smallest allowable object size.
        others_value: Value/class for other nuclei;  Use this to separate core nuclei vs others.
        to_binary_mask: Convert mask to binary;  Set it false to restore original class values
       r   Tr   rN   z
str | Noner0   r1   )r2   r   r   min_areaothers_valueto_binary_maskc                   s2   t  j|dd || _|| _|| _|| _|| _d S NF)r4   )r6   r7   r   r   re   rf   rg   )r8   r2   r   r   re   rf   rg   r9   r,   r-   r7      s    	zSplitLabeld.__init__c                 C  s  t |}t| jdkr"td d S | jD ]R}t|| tjrF|| nt|| }t|}| j	r||| j	 }d|||k< nd||| j
k< tt|}| jrd||dk< t|}d|||k< d||dk< t|rtjt|d dd}t|d  }t|tjr|tjn|}t|tjr2|tjn|}t|| tjrL|nt|||< t|| tjrn|nt||| j< q(|S )Nr/   z8Only 'label' key is supported, more than 1 key was foundr   r;   )r<   lenr2   printr=   r>   r?   
from_numpycloner   rf   r0   r]   rg   count_nonzeror@   r   r   typerC   r   )r8   rE   rF   rG   r   maskr   r   r,   r,   r-   rI      s2    &



"&zSplitLabeld.__call__)
r   r   r   r   r   r"   r%   r7   rI   rK   r,   r,   r9   r-   rc      s   rc   c                      sb   e Zd ZdZddddd fdd	Zd
d Zdd ZdddZdddZdd Z	dddZ
  ZS ) FilterImageda   
    Filters Green and Gray channel of the image using an allowable object size, this pre-processing transform
    is specific towards NuClick training process. More details can be referred in this paper Koohbanani,
    Navid Alemi, et al. "NuClick: a deep learning framework for interactive segmentation of microscopic images."
    Medical Image Analysis 65 (2020): 101771.

    Args:
        min_size: The smallest allowable object size
        allow_missing_keys: don't raise exception if key is missing.
      Fr   r0   r1   )r2   min_sizer4   c                   s   t  || || _d S r5   )r6   r7   rr   )r8   r2   rr   r4   r9   r,   r-   r7      s    zFilterImaged.__init__c                 C  sJ   t |}| jD ]6}t|| tjr.t|| n|| }| |||< q|S r5   )r<   r2   r=   r>   r?   r   filterrD   r,   r,   r-   rI      s
    
$zFilterImaged.__call__c                 C  sJ   |  |}| |}||@ }| jr2| j|| jdn|}|t|||g S )Nrr   )filter_green_channelfilter_graysrr   filter_remove_small_objectsrB   dstack)r8   rgbZmask_not_greenZmask_not_grayZmask_gray_greenro   r,   r,   r-   rs      s    

zFilterImaged.filter   TZ   c           
      C  st   |d d d d df }||k |dk@ }|  |}||krp|dk rp|dkrptd| d | }	| ||	|||}|S )Nr/   r      Tr\   )mask_percentmathceilru   )
r8   img_npZgreen_threshavoid_overmaskovermask_threshoutput_typegZ
gr_ch_maskmask_percentageZnew_green_threshr,   r,   r-   ru      s    
    z!FilterImaged.filter_green_channel   c                 C  s   t |d d d d df |d d d d df  |k}t |d d d d df |d d d d df  |k}t |d d d d df |d d d d df  |k}||@ |@  S )Nr   r/   r\   )abs)r8   ry   	toleranceZrg_diffZrb_diffZgb_diffr,   r,   r-   rv      s    444zFilterImaged.filter_graysc                 C  s   t |jdkrx|jd dkrx|d d d d df |d d d d df  |d d d d df  }dt||j d  }ndt||j d  }|S )N   r\   r   r/   d   )ri   rU   rB   rm   r^   )r8   r   Znp_sumr   r,   r,   r-   r}      s
    BzFilterImaged.mask_percent  _   c                 C  sV   t j|t|d}| |}||krR|dkrR|dkrRt|d }| ||||}|S )Nrt   r/   Tr\   )
morphologyremove_small_objectsrA   r1   r}   roundrw   )r8   r   rr   r   r   Zrem_smr   Znew_min_sizer,   r,   r-   rw     s    
z(FilterImaged.filter_remove_small_objects)rq   F)rz   Tr{   r1   )r   )r   Tr   )r   r   r   r   r7   rI   rs   ru   rv   r}   rw   rK   r,   r,   r9   r-   rp      s          

rp   c                   @  st   e Zd ZdZejejejdddddddf
dddd	d
dd	d	ddd
ddZdd Z	dd Z
dd Zdd Zdd ZdS )AddPointGuidanceSignalda  
    Adds Guidance Signal to the input image

    Args:
        image: key of source image, defaults to ``"image"``
        label: key of source label, defaults to ``"label"``
        others: source others (other labels from the binary mask which are not being used for training)
            defaults to ``"others"``
        drop_rate: probability of dropping the signal, defaults to ``0.5``
        jitter_range: noise added to the points in the point mask for exclusion mask, defaults to ``3``
        gaussian: add gaussian
        sigma: sigma value for gaussian
        truncated: spreads how many stds for gaussian
        add_exclusion_map: add exclusion map/signal
    g      ?r   F      ?       @TrN   floatr0   r1   )
r   r   r   	drop_ratejitter_rangegaussiansigma	truncatedadd_exclusion_mapuse_distancec                 C  sL   t | | || _|| _|| _|| _|| _|| _|| _|| _	|	| _
|
| _d S r5   )r   r7   r   r   r   r   r   r   r   r   r   r   )r8   r   r   r   r   r   r   r   r   r   r   r,   r,   r-   r7     s    z AddPointGuidanceSignald.__init__c                 C  s4  t |}t|| j tjr$|| j nt|| j }t|| j tjrP|| j nt|| j }| j|d |jd}| 	|}| j
rt|| j tjr|| j nt|| j }| j|d |j| j| jd}| 	|}tj||d  |d  fdd}ntj||d  fdd}t|| j tjr"|nt||| j< |S )Nr   dtype)r   r   r   dim)r<   r=   r   r>   r?   rk   r   inclusion_mapr   _apply_gaussianr   r   exclusion_mapr   r   catr   )r8   rE   rF   r   ro   Zinc_sigr   Zexc_sigr,   r,   r-   rI   8  s$    ,,
,   
&z AddPointGuidanceSignald.__call__c                 C  sJ   | j rt|dkr|S td| j| jd|dd}|ddS Nr   r\   )spatial_dimsr   r   r   r>   rm   r   r   r   	unsqueezesqueezer8   tr_   r,   r,   r-   r   M  s    "z'AddPointGuidanceSignald._apply_gaussianc           	      C  s   t d ks| jstttdr(t|dk}ntt|dk}t|dkrp| j	dt|}||df ||df fS d S t |
 }t|d }t|
 dkd }| jj|d|| t||  d}tt||j  d }|d |d fS )Nargwherer   r/   r   )r^   prR   )r   r   hasattrr>   r   rB   r   ri   Rrandintflattenexpwherechoicesumasarrayunravel_indexrU   	transposetolist)	r8   r   indicesindexdistanceprobabilityidxseedr   r,   r,   r-   _seed_pointS  s    
$ z#AddPointGuidanceSignald._seed_pointc                 C  s8   t j||d}| |}|d k	r4d||d |d f< |S )Nr   r/   r   )r>   
zeros_liker   )r8   ro   r   
point_maskptr,   r,   r-   r   h  s
    
z%AddPointGuidanceSignald.inclusion_mapc                 C  s  t j||d}tjjddg|d| gdr.|S |jd d }|jd d }tt|}|D ]}	tjjddg|d| gdr~q\|	j	\}
}t
t|
}
t
t|}|r|
| jj| |d }
|| jj| |d }ttd|
|}
ttd||}d||
|f< q\|S )Nr   TFr/   )r   r   )lowhigh)r>   r   rB   randomr   rU   r@   regionpropsr   r   r0   r~   floorr   r   minr]   )r8   r   r   r   r   r   max_xmax_ystatsstatr_   r`   r,   r,   r-   r   p  s&    
z%AddPointGuidanceSignald.exclusion_mapN)r   r   r   r   r   r    r!   r"   r7   rI   r   r   r   r   r,   r,   r,   r-   r     s"   "r   c                	   @  s`   e Zd ZdZejejdddddfdddd	d
d
d	dddZdd ZdddZ	dd Z
dd ZdS )AddClickSignalsda  
    Adds Click Signal to the input image

    Args:
        image: source image, defaults to ``"image"``
        foreground: 2D click indices as list, defaults to ``"foreground"``
        bb_size: single integer size, defines a bounding box like (bb_size, bb_size)
        gaussian: add gaussian
        sigma: sigma value for gaussian
        truncated: spreads how many stds for gaussian
        add_exclusion_map: add exclusion map/signal
    rM   Fr   r   TrN   r0   r1   r   r   r   bb_sizer   r   r   r   c                 C  s.   || _ || _|| _|| _|| _|| _|| _d S r5   r   )r8   r   r   r   r   r   r   r   r,   r,   r-   r7     s    
zAddClickSignalsd.__init__c              	   C  sD  t |}t|| j tjr$|| j nt|| j }|jd }|jd }|tj	j
d}|d |d  }}|| j}	|	rt|	||f t ng }	dd |	D }
dd |	D }| j||
|||| jd	\}}|std
| j||||
|||d}||tjj
< ||tjj
< ||tjj
< t|| j tjr2|nt||| j< |S )NrR   r   )r   r   r   r/   c                 S  s   g | ]}|d  qS r   r,   .0xyr,   r,   r-   
<listcomp>  s     z-AddClickSignalsd.__call__.<locals>.<listcomp>c                 S  s   g | ]}|d  qS )r/   r,   r   r,   r,   r-   r     s     )cxcyr_   r`   bbz0Failed to create patches from given click points)rH   	click_mapr   r   r   r_   r`   )r<   r=   r   r>   r?   rk   rU   getr   r&   valuer   rB   arrayrA   r0   r   get_clickmap_boundingboxr   
ValueErrorget_patches_and_signalsr(   r*   r)   r   )r8   rE   rF   rH   r_   r`   r   txtyposr   r   r   r   patchesr,   r,   r-   rI     s6    ,

$      &zAddClickSignalsd.__call__c                   sd  t |d } fddtt D }fddttD }	t||	}
t |
 t|
d| f< g }tt D ]}td | |d  }td| |d  }t	|| }t	|| }|| |kr|| }|| |kr|| }|| |kr,|| |kr,|
||||g qtd| d| d| d| d	| d
| d q||fS )Nr   c                   s(   h | ] } | ks  | d k r|qS r   r,   r   ir   r_   r,   r-   	<setcomp>  s       z<AddClickSignalsd.get_clickmap_boundingbox.<locals>.<setcomp>c                   s(   h | ] } | ks  | d k r|qS r   r,   r   r   r`   r,   r-   r     s       r/   r\   zIgnore smaller sized bbox (z, z) (Min size: r_   ))r>   r   rangeri   listunionrB   deleter]   r   appendrj   )r8   rH   r   r   r_   r`   r   r   x_del_indicesy_del_indicesdel_indicesr   r   rW   rY   rX   rZ   r,   r   r   r_   r`   r-   r     s*    0z)AddClickSignalsd.get_clickmap_boundingboxc              	     sj  g } fddt t D }	fddt tD }
t|	|
}t | t|t|D ]\}}|d }|d }|d }|d }|d d ||||f }t|d }d| | | f< |||||f }| 	|}| j
rF|| dk|j}|||||f }| 	|}|t||d  |d  g qj|t||d  g qjt|S )Nc                   s(   h | ] } | ks  | d k r|qS r   r,   r   r   r,   r-   r     s       z;AddClickSignalsd.get_patches_and_signals.<locals>.<setcomp>c                   s(   h | ] } | ks  | d k r|qS r   r,   r   r   r,   r-   r     s       r   r/   r\   r   )r   ri   r   r   rB   r   	enumerater>   r   r   r   rn   r   r   r   stack)r8   rH   r   r   r   r   r_   r`   r   r   r   r   r   bounding_boxrW   rY   rX   rZ   patchZthis_click_mapr   Zothers_click_mapother_pointsr,   r   r-   r     s.    

 z(AddClickSignalsd.get_patches_and_signalsc                 C  sJ   | j rt|dkr|S td| j| jd|dd}|ddS r   r   r   r,   r,   r-   r     s    "z AddClickSignalsd._apply_gaussianN)rM   )r   r   r   r   r   r    r#   r7   rI   r   r   r   r,   r,   r,   r-   r     s   
!r   c                      sv   e Zd ZdZejejejejdddddej	f
ddddddd	d	d
d
dd fddZ
dd ZdddZdddZ  ZS )PostFilterLabeldaV  
    Performs Filtering of Labels on the predicted probability map

    Args:
        thresh: probability threshold for classifying a pixel as a mask
        min_size: min_size objects that will be removed from the image, refer skimage remove_small_objects
        min_hole: min_hole that will be removed from the image, refer skimage remove_small_holes
        do_reconstruction: Boolean Flag, Perform a morphological reconstruction of an image, refer skimage
        allow_missing_keys: don't raise exception if key is missing.
        pred_classes: List of Predicted class for each instance
    Q?
      Fr   rN   r   r0   r1   )r2   r   r   r   r   threshrr   min_holedo_reconstructionr4   r   c                   sH   t  ||
 || _|| _|| _|| _|| _|| _|| _|	| _	|| _
d S r5   )r6   r7   r   r   r   r   r   rr   r   r   r   )r8   r2   r   r   r   r   r   rr   r   r   r4   r   r9   r,   r-   r7     s    zPostFilterLabeld.__init__c           
      C  s   t |}|| j}|| j }|| j }|| j }| jD ]J}|| tj	}| 
|| j| j| j}	| j|	||||dtj	||< q8|S )N)r   )r<   r   r   r   r   r   r2   rA   rB   rC   post_processingr   rr   r   gen_instance_map)
r8   rE   rF   r   r   r_   r`   rG   r   masksr,   r,   r-   rI   2  s    



"zPostFilterLabeld.__call__c                 C  sL   ||k}t |jd D ]0}tj|| |d||< tj|| |d||< q|S )Nr   rt   )area_threshold)r   rU   r   r   remove_small_holes)r8   predsr   rr   r   r   r   r,   r,   r-   r   @  s
    z PostFilterLabeld.post_processingTNc                 C  s   t j||ft jd}t|D ]\}}	|| }
|rD|t|k rD|| nd}|rP|n|d }||
d |
d |
d |
d f }t |	dk||}|||
d |
d |
d |
d f< q|S )Nr   r/   r   r\   r   )rB   zerosuint16r   ri   r   )r8   r   r   r_   r`   r   r   instance_mapr   ro   r   cZthis_mapr,   r,   r-   r   G  s    $&z!PostFilterLabeld.gen_instance_map)r   r   r   )TN)r   r   r   r   r   r'   r(   r)   r*   r+   r7   rI   r   r   rK   r,   r,   r9   r-   r     s   (
r   c                      s4   e Zd ZdZddddd fddZd	d
 Z  ZS )AddLabelAsGuidancedz
    Add Label as new guidance channel

    Args:
        source: label/source key which gets added as additional guidance channel
    r   r   rN   None)r2   sourcereturnc                   s   t  j|dd || _d S rh   )r6   r7   r
  )r8   r2   r
  r9   r,   r-   r7   ]  s    zAddLabelAsGuidanced.__init__c                 C  s   t |}| jD ]}t|| tjr*|| nt|| }t|| j tjrT|| j nt|| j }|dk}t|jt|jk r|d  }tj	||
|jgt|jd d}t|| tjr|nt|||< q|S )Nr   r   r   )r<   r2   r=   r>   r?   rk   r
  ri   rU   r   rn   r   r   )r8   rE   rF   rG   r   r   r,   r,   r-   rI   a  s    
&,$"zAddLabelAsGuidanced.__call__)r   rJ   r,   r,   r9   r-   r  U  s   r  c                      s4   e Zd ZdZddddd fddZd	d
 Z  ZS )SetLabelClassdz
    Assign class value from the labelmap.  This converts multi-dimension tensor to single scalar tensor.

    Args:
        offset: offset value to be added to the mask value to determine the final class
    r   r   r0   r	  )r2   offsetr  c                   s   t  j|dd || _d S rh   )r6   r7   r  )r8   r2   r  r9   r,   r-   r7   w  s    zSetLabelClassd.__init__c                 C  sZ   t |}| jD ]F}t|| tjr*|| nt|| }tt|}|| j ||< q|S r5   )	r<   r2   r=   r>   r?   rk   r0   r]   r  )r8   rE   rF   rG   r   r   r,   r,   r-   rI   {  s    
&zSetLabelClassd.__call__)r   rJ   r,   r,   r9   r-   r  o  s   r  )#
__future__r   r~   typingr   numpyrB   r>   monai.configr   r   monai.networks.layersr   monai.transformsr   r   r	   monai.utilsr
   r   r   r@   _r   r   r   r.   rL   rc   rp   r   r   r   r  r  r,   r,   r,   r-   <module>   s,   :CG} J