o
    iD                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlZd dl	Z	d dl
mZmZ d dlmZmZmZmZmZmZmZ d dlmZ d dlmZmZmZmZmZmZ d d	lmZm Z m!Z! d d
l"m#Z# d dl$m%Z%m&Z& e!ddd\Z'Z(e!ddd\Z)Z(e!ddd\Z*Z(e!ddd\Z+Z(e!ddd\Z,Z(e!ddd\Z-Z(g dZ.G dd deZ/G dd deZ0G dd deZ1G dd deZ2G d d! d!eZ3G d"d# d#eZ4G d$d% d%eZ5G d&d' d'eZ6G d(d) d)eZ7G d*d+ d+eZ8G d,d- d-eZ9dS ).    )annotationsN)Sequence)Callable)	DtypeLikeNdarrayOrTensor)Activations
AsDiscreteBoundingRect	FillHolesGaussianSmoothRemoveSmallObjectsSobelGradients)	Transform)maxmaximumminsumuniquewhere)TransformBackendsconvert_to_numpyoptional_import)ensure_tuple_rep)convert_to_dst_typeconvert_to_tensorzscipy.ndimagelabel)namezskimage.morphologydiskopeningzskimage.segmentation	watershedzskimage.measurefind_contourscentroid)	WatershedGenerateWatershedMaskGenerateInstanceBorderGenerateDistanceMapGenerateWatershedMarkersGenerateSuccinctContourGenerateInstanceContourGenerateInstanceCentroidGenerateInstanceType!HoVerNetInstanceMapPostProcessing!HoVerNetNuclearTypePostProcessingc                   @  s8   e Zd ZdZejgZdejfdd	d
Z		ddddZ
dS )r"   a  
    Use `skimage.segmentation.watershed` to get instance segmentation results from images.
    See: https://scikit-image.org/docs/stable/api/skimage.segmentation.html#skimage.segmentation.watershed.

    Args:
        connectivity: an array with the same number of dimensions as image whose non-zero elements indicate
            neighbors for connection. Following the scipy convention, default is a one-connected array of
            the dimension of the image.
        dtype: target data content type to convert, default is np.int64.

       connectivity
int | Nonedtyper   returnNonec                 C     || _ || _d S N)r.   r0   )selfr.   r0    r6   l/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/pathology/transforms/post/array.py__init__J      
zWatershed.__init__Nimager   maskNdarrayOrTensor | Nonemarkersc                 C  s>   t |}t |}t |}t|||| jd}t||| jdd S )a\  
        Args:
            image: image where the lowest value points are labeled first. Shape must be [1, H, W, [D]].
            mask: optional, the same shape as image. Only points at which mask == True will be labeled.
                If None (no mask given), it is a volume of all 1s.
            markers: optional, the same shape as image. The desired number of markers, or an array marking
                the basins with the values to be assigned in the label matrix. Zero means not a marker.
                If None (no markers given), the local minima of the image are used as markers.
        )r=   r;   r.   r0   r   )r   r   r.   r   r0   )r5   r:   r;   r=   Zinstance_segr6   r6   r7   __call__N   s
   zWatershed.__call__)r.   r/   r0   r   r1   r2   )NN)r:   r   r;   r<   r=   r<   r1   r   __name__
__module____qualname____doc__r   NUMPYbackendnpint64r8   r?   r6   r6   r6   r7   r"   ;   s    r"   c                   @  s8   e Zd ZdZejgZdddejfdddZ	dddZ
dS )r#   az  
    generate mask used in `watershed`. Only points at which mask == True will be labeled.

    Args:
        activation: the activation layer to be applied on the input probability map.
            It can be "softmax" or "sigmoid" string, or any callable. Defaults to "softmax".
        threshold: an optional float value to threshold to binarize probability map.
            If not provided, defaults to 0.5 when activation is not "softmax", otherwise None.
        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.
        dtype: target data content type to convert, default is np.uint8.

    softmaxN
   
activationstr | Callable	thresholdfloat | Nonemin_object_sizeintr0   r   r1   r2   c                 C  s   || _ d}d}d }t|tr(| dkrd}n"| dkr d}ntd| dt|r/|}n
tdt| dt|||d| _|sI|d u rId	}t	||d
| _
|dkr\t|d| _d S d | _d S )NFrI   TsigmoidJThe activation should be 'softmax' or 'sigmoid' string, or any callable. '' was given.7The activation type should be either str or callable. 'rI   rQ   other      ?rM   argmaxr   min_size)r0   
isinstancestrlower
ValueErrorcallabletyper   rK   r   as_discreter   remove_small_objects)r5   rK   rM   rO   r0   use_softmaxuse_sigmoidactivation_fnr6   r6   r7   r8   t   s(   

"zGenerateWatershedMask.__init__prob_mapr   c                 C  sf   |  |}| |}t|}t|d }| jdur| |}t|dk}d||< t||| jdd S )zk
        Args:
            prob_map: probability map of segmentation, shape must be [C, H, W, [D]]
        r   Nr-   r>   )	rK   rb   r   r   rc   rG   r   r   r0   )r5   rg   predZpred_indicesr6   r6   r7   r?      s   



zGenerateWatershedMask.__call__)
rK   rL   rM   rN   rO   rP   r0   r   r1   r2   )rg   r   r1   r   )rA   rB   rC   rD   r   rE   rF   rG   uint8r8   r?   r6   r6   r6   r7   r#   d   s    $r#   c                   @  s4   e Zd ZdZejgZdejfdd	d
Z	dddZ
dS )r$   a  
    Generate instance border by hover map. The more parts of the image that cannot be identified as foreground areas,
    the larger the grey scale value. The grey value of the instance's border will be larger.

    Args:
        kernel_size: the size of the Sobel kernel. Defaults to 5.
        dtype: target data type to convert to. Defaults to np.float32.


    Raises:
        ValueError: when the `mask` shape is not [1, H, W].
        ValueError: when the `hover_map` shape is not [2, H, W].

       kernel_sizerP   r0   r   r1   r2   c                 C  s   || _ t|d| _d S )Nrk   )r0   r   sobel_gradient)r5   rk   r0   r6   r6   r7   r8      s   zGenerateInstanceBorder.__init__r;   r   	hover_mapc                 C  s  t |jdkrtd|j dt |jdkr*|jd dkr)td|jd  dnt |jdkr6|d }n	td	|j d|jd dkrPtd
|jd  |dddf }|dddf }t|t|}}t|t|}}|| dks~|| dkrtd|| ||  }|| ||  }| |d }	| |d }
t|	t|	}}t|
t|
}}|| dks|| dkrtdd|	| ||   }	d|
| ||   }
t|	|
}|d|  }d||dk < t||| jdd S )a  
        Args:
            mask: binary segmentation map, the output of :py:class:`GenerateWatershedMask`.
                Shape must be [1, H, W] or [H, W].
            hover_map:  horizontal and vertical distances of nuclear pixels to their centres of mass. Shape must be [2, H, W].
                The first and second channel represent the horizontal and vertical maps respectively. For more details refer
                to papers: https://arxiv.org/abs/1812.06499.
           z:The hover map should have the shape of [C, H, W], but got .r   r-   /The mask should have only one channel, but got    N?The mask should have the shape of [1, H, W] or [H, W], but got z5Suppose the hover map only has two channels, but got .z.Not a valid hover map, please check your input)r-   .)r   .zNot a valid sobel gradient mapr>   )	lenshaper_   r   r   rm   r   r   r0   )r5   r;   rn   Zhover_hZhover_vZhover_h_minZhover_h_maxZhover_v_minZhover_v_maxZsobelhZsobelvZ
sobelh_minZ
sobelh_maxZ
sobelv_minZ
sobelv_maxZoverallr6   r6   r7   r?      s>   	

zGenerateInstanceBorder.__call__N)rk   rP   r0   r   r1   r2   )r;   r   rn   r   r1   r   rA   rB   rC   rD   r   rE   rF   rG   float32r8   r?   r6   r6   r6   r7   r$      s
    r$   c                   @  s4   e Zd ZdZejgZdejfdd	d
Z	dddZ
dS )r%   a  
    Generate distance map.
    In general, the instance map is calculated from the distance to the background.
    Here, we use 1 - "instance border map" to generate the distance map.
    Nuclei values form mountains so invert them to get basins.

    Args:
        smooth_fn: smoothing function for distance map, which can be any callable object.
            If not provided :py:class:`monai.transforms.GaussianSmooth()` is used.
        dtype: target data type to convert to. Defaults to np.float32.
    N	smooth_fnCallable | Noner0   r   r1   r2   c                 C  s   |d ur|nt  | _|| _d S r4   )r   rx   r0   )r5   rx   r0   r6   r6   r7   r8      s   
zGenerateDistanceMap.__init__r;   r   instance_borderc                 C  s   t |jdkr|jd dkrtd|jd  dnt |jdkr&|d }n	td|j d|jd dks;|jdkrCtd	|j d
| | }| |}t| || jdd S )3  
        Args:
            mask: binary segmentation map, the output of :py:class:`GenerateWatershedMask`.
                Shape must be [1, H, W] or [H, W].
            instance_border: instance border map, the output of :py:class:`GenerateInstanceBorder`.
                Shape must be [1, H, W].
        ro   r   r-   rq   rp   rr   Nrs   @Input instance_border should be with size of [1, H, W], but got       ?r>   )rt   ru   r_   ndimrx   r   r0   )r5   r;   rz   distance_mapr6   r6   r7   r?     s   

zGenerateDistanceMap.__call__)rx   ry   r0   r   r1   r2   r;   r   rz   r   r1   r   rv   r6   r6   r6   r7   r%      s
    r%   c                   @  s:   e Zd ZdZejgZddddejfdddZ	dddZ
dS )r&   a"  
    Generate markers to be used in `watershed`. The watershed algorithm treats pixels values as a local topography
    (elevation). The algorithm floods basins from the markers until basins attributed to different markers meet on
    watershed lines. Generally, markers are chosen as local minima of the image, from which basins are flooded.
    Here is the implementation from HoVerNet paper.
    For more details refer to papers: https://arxiv.org/abs/1812.06499.

    Args:
        threshold: a float value to threshold to binarize instance border map.
            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.
        radius: the radius of the disk-shaped footprint used in `opening`. Defaults to 2.
        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.
        postprocess_fn: additional post-process function on the markers.
            If not provided, :py:class:`monai.transforms.post.FillHoles()` will be used.
        dtype: target data type to convert to. Defaults to np.int64.

    皙?rr   rJ   NrM   floatradiusrP   rO   postprocess_fnry   r0   r   r1   r2   c                 C  sH   || _ || _|| _|d u rt }|| _|dkrt|d| _d S d | _d S )Nr   rZ   )rM   r   r0   r
   r   r   rc   )r5   rM   r   rO   r   r0   r6   r6   r7   r8   1  s   "z!GenerateWatershedMarkers.__init__r;   r   rz   c                 C  s  t |jdkr|jd dkrtd|jd  dnt |jdkr&|d }n	td|j d|jd dks;|jdkrCtd	|j || jk}|t||d  }t|dk }d||< | |}t|}t	|
 t| j}t|d d }| jdur| |}t||| jd
d S )r{   ro   r   r-   rq   rp   rr   Nrs   r|   r>   )rt   ru   r_   r~   rM   r   r   r   r   r   squeezer   r   r   rc   r0   )r5   r;   rz   markerZmarker_indicesr6   r6   r7   r?   B  s(   




z!GenerateWatershedMarkers.__call__)rM   r   r   rP   rO   rP   r   ry   r0   r   r1   r2   r   r@   r6   r6   r6   r7   r&     s    r&   c                   @  s8   e Zd ZdZdddZdddZdddZdddZdS )r'   aW  
    Converts SciPy-style contours (generated by skimage.measure.find_contours) to a more succinct version which only includes
    the pixels to which lines need to be drawn (i.e. not the intervening pixels along each line).

    Args:
        height: height of bounding box, used to detect direction of line segment.
        width: width of bounding box, used to detect direction of line segment.

    Returns:
        the pixels that need to be joined by straight lines to describe the outmost pixels of the foreground similar to
            OpenCV's cv.CHAIN_APPROX_SIMPLE (counterclockwise)
    heightrP   widthr1   r2   c                 C  r3   r4   )r   r   )r5   r   r   r6   r6   r7   r8   r  r9   z GenerateSuccinctContour.__init__current
np.ndarrayprevioustuple[int, int]c                 C  s   |d |d  |d |d  f}d\}}|dv r*t |d d }t |d }||fS |dv r>t |d }t |d }||fS |dv rTt |d }t |d d }||fS |dkrht |d d }t |d d }||fS )	aR  
        Generate contour coordinates. Given the previous and current coordinates of border positions,
        returns the int pixel that marks the extremity of the segmented pixels.

        Args:
            current: coordinates of the current border position.
            previous: coordinates of the previous border position.
        r   r-   r   ))        r}   )rW   rW   )r}   r   rW   ))r   g      )rW         ))r   r   )r   r   )r   rW   )rP   )r5   r   r   Zp_deltarowcolr6   r6   r7   _generate_contour_coordv  s$    
z/GenerateSuccinctContour._generate_contour_coordsequenceSequence[tuple[int, int]]c                 C  s   |d }|d dkr|d }|S |d | j d kr"| j |d  }|S |d | jd kr9d| j  | j |d  }|S d| j | j  |d  }|S )a  
        Each sequence of coordinates describes a boundary between foreground and background starting and ending at two sides
        of the bounding box. To order the sequences correctly, we compute the distance from the top-left of the bounding box
        around the perimeter in a clockwise direction.

        Args:
            sequence: list of border points coordinates.

        Returns:
            the distance round the perimeter of the bounding box from the top-left origin
        r   r-   rr   )r   r   )r5   r   Zfirst_coorddistancer6   r6   r7   !_calculate_distance_from_top_left  s   z9GenerateSuccinctContour._calculate_distance_from_top_leftcontourslist[np.ndarray]c                 C  s  g }g }g d}|D ]G}g }d}d}d}	t |D ]*\}
}|
dkr|d dkrMd}	dt|d d f}|d | jd krBd|d< nl|d dkrLd|d< na|d dkr`d}	t|d d df}nN|d | jd krd	}	t|d t|d d f}|d | jd krd|d
< n'|d | jd krd
}	t|d d t|d f}ntd| d   dS || |}n:|
t|d kr| ||}||kr|| |}n t	
|| ||
d  | kr| ||}||kr|| |}|
t|d krB|	dkr
|d dkr	d||	< n8|	dkr|d | jd krd||	< n$|	d
kr2|d | jd kr1d||	< n|	d	krB|d dkrBd||	< |}q| |}|||d q
|d du rc|ddgd |d du ry|| jd| jd fgd |d
 du r|| j| j | jd | jd fgd |d	 du r|d
| j | j | jd dfgd |jdd d d}|D ]$}|d d |kr|  |rg ||d }n|d }|d }q|d |kr|d |d dkr||d t	t|t	jdS )a  
        Args:
            contours: list of (n, 2)-ndarrays, scipy-style clockwise line segments, with lines separating foreground/background.
                Each contour is an ndarray of shape (n, 2), consisting of n (row, column) coordinates along the contour.
        )FFFFNr   r   r   r-   rW   Tro   rr   zInvalid contour coord z" is generated, skip this instance.)r   r   Fr   r   c                 S  s
   |  dS )Nr   )getxr6   r6   r7   <lambda>  s   
 z2GenerateSuccinctContour.__call__.<locals>.<lambda>)keyr   r   r>   )	enumeraterP   r   r   warningswarnappendrt   r   rG   anyr   sortpopflipr   int32)r5   r   pixels	sequencescornersgroupr   Z
last_addedprevcornericoordpixeldistlastZ	_sequencer6   r6   r7   r?     s   









*(
z GenerateSuccinctContour.__call__N)r   rP   r   rP   r1   r2   )r   r   r   r   r1   r   )r   r   r1   rP   )r   r   r1   r   )rA   rB   rC   rD   r8   r   r   r?   r6   r6   r6   r7   r'   d  s    


r'   c                   @  s0   e Zd ZdZejgZddd
dZddddZdS )r(   a  
    Generate contour for each instance in a 2D array. Use `GenerateSuccinctContour` to only include
    the pixels to which lines need to be drawn

    Args:
        min_num_points: assumed that the created contour does not form a contour if it does not contain more points
            than the specified value. Defaults to 3.
        contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
            If not provided, the level is set to `(max(image) + min(image)) / 2`.

    ro   Nmin_num_pointsrP   contour_levelrN   r1   r2   c                 C  s   || _ || _d S r4   )r   r   )r5   r   r   r6   r6   r7   r8   (  r9   z GenerateInstanceContour.__init__r   	inst_maskr   offsetSequence[int] | Nonenp.ndarray | Nonec                 C  s   |  }t|}t|| jd}t|jd |jd }||}|du r$dS |jd | jk r7td| j d dS t|jdkrJtt|j d dS |dddf  |d 7  < |dddf  |d 7  < |S )	z
        Args:
            inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]]
            offset: optional offset of starting position of the instance mask in the original array. Default to 0 for each dim.
        )levelr   r-   Nz< z) points don't make a contour, so skipped!rr   z != 2, check for tricky shapes!)	r   r   r    r   r'   ru   r   printrt   )r5   r   r   Zinst_contour_cvZgenerate_contourZinst_contourr6   r6   r7   r?   ,  s    z GenerateInstanceContour.__call__)ro   N)r   rP   r   rN   r1   r2   )r   )r   r   r   r   r1   r   )	rA   rB   rC   rD   r   rE   rF   r8   r?   r6   r6   r6   r7   r(     s
    r(   c                   @  s2   e Zd ZdZejgZefdddZddddZ	dS )r)   z
    Generate instance centroid using `skimage.measure.centroid`.

    Args:
        dtype: the data type of output centroid.

    r0   DtypeLike | Noner1   r2   c                 C  s
   || _ d S r4   r>   )r5   r0   r6   r6   r7   r8   S  s   
z!GenerateInstanceCentroid.__init__r   r   r   r   Sequence[int] | intc                 C  sd   t |}|d}t|j}t||}t|}t|D ]}||  || 7  < qt||| jdd S )z
        Args:
            inst_mask: segmentation mask for a single instance. Shape should be [1, H, W, [D]]
            offset: optional offset of starting position of the instance mask in the original array. Default to 0 for each dim.

        r   r>   )	r   r   rt   ru   r   r!   ranger   r0   )r5   r   r   r~   Zinst_centroidr   r6   r6   r7   r?   V  s   


z!GenerateInstanceCentroid.__call__N)r0   r   r1   r2   )r   )r   r   r   r   r1   r   )
rA   rB   rC   rD   r   rE   rF   rP   r8   r?   r6   r6   r6   r7   r)   H  s
    r)   c                   @  s"   e Zd ZdZejgZdddZdS )r*   zC
    Generate instance type and probability for each instance.
    	type_predr   seg_predbboxr   instance_idrP   r1   tuple[int, float]c                 C  s   |  \}}}}|d||||f }	|d||||f }
t|	|k|
tdd }	|
|	 }t|dd\}}tt||}t|dd dd}|d d }|dkr[t|dkr[|d d }d	d
 |D }|| t|	d  }t	|t
|fS )ak  
        Args:
            type_pred: pixel-level type prediction map after activation function.
            seg_pred: pixel-level segmentation prediction map after activation function.
            bbox: bounding box coordinates of the instance, shape is [channel, 2 * spatial dims].
            instance_id: get instance type from specified instance id.
        r   r>   T)return_countsc                 S  s   | d S )Nr-   r6   r   r6   r6   r7   r     s    z/GenerateInstanceType.__call__.<locals>.<lambda>)r   reverser-   c                 S  s   i | ]	}|d  |d qS )r   r-   r6   ).0vr6   r6   r7   
<dictcomp>  s    z1GenerateInstanceType.__call__.<locals>.<dictcomp>gư>)flattenr   boolr   listzipsortedrt   r   rP   r   )r5   r   r   r   r   rminrmaxcmincmaxZseg_map_cropZtype_map_cropZ	inst_typeZ	type_listZtype_pixels	type_dict	type_probr6   r6   r7   r?   p  s   zGenerateInstanceType.__call__N)
r   r   r   r   r   r   r   rP   r1   r   )rA   rB   rC   rD   r   rE   rF   r?   r6   r6   r6   r7   r*   i  s    r*   c                      sF   e Zd ZdZ													d'd( fdd Zd)d%d&Z  ZS )*r+   a  
    The post-processing transform for HoVerNet model to generate instance segmentation map.
    It generates an instance segmentation map as well as a dictionary containing centroids, bounding boxes, and contours
    for each instance.

    Args:
        activation: the activation layer to be applied on the input probability map.
            It can be "softmax" or "sigmoid" string, or any callable. Defaults to "softmax".
        mask_threshold: a float value to threshold to binarize probability map to generate mask.
        min_object_size: objects smaller than this size (in pixel) are removed. Defaults to 10.
        sobel_kernel_size: the size of the Sobel kernel used in :py:class:`GenerateInstanceBorder`. Defaults to 5.
        distance_smooth_fn: smoothing function for distance map.
            If not provided, :py:class:`monai.transforms.intensity.GaussianSmooth()` will be used.
        marker_threshold: a float value to threshold to binarize instance border map for markers.
            It turns uncertain area to 1 and other area to 0. Defaults to 0.4.
        marker_radius: the radius of the disk-shaped footprint used in `opening` of markers. Defaults to 2.
        marker_postprocess_fn: post-process function for watershed markers.
            If not provided, :py:class:`monai.transforms.post.FillHoles()` will be used.
        watershed_connectivity: `connectivity` argument of `skimage.segmentation.watershed`.
        min_num_points: minimum number of points to be considered as a contour. Defaults to 3.
        contour_level: an optional value for `skimage.measure.find_contours` to find contours in the array.
            If not provided, the level is set to `(max(image) + min(image)) / 2`.
        device: target device to put the output Tensor data.
    rI   NrJ   rj   r   rr   r-   ro   rK   rL   mask_thresholdrN   rO   rP   sobel_kernel_sizedistance_smooth_fnry   marker_thresholdr   marker_radiusmarker_postprocess_fnwatershed_connectivityr/   r   r   devicestr | torch.device | Noner1   r2   c                   sp   t    || _t|||d| _t|d| _t|d| _t	||||d| _
t|	d| _t|
|d| _t | _d S )N)rK   rM   rO   rl   )rx   )rM   r   r   rO   )r.   )r   r   )superr8   r   r#   generate_watershed_maskr$   generate_instance_borderr%   generate_distance_mapr&   generate_watershed_markersr"   r   r(   generate_instance_contourr)   generate_instance_centroid)r5   rK   r   rO   r   r   r   r   r   r   r   r   r   	__class__r6   r7   r8     s$   
z*HoVerNetInstanceMapPostProcessing.__init__nuclear_predictionr   rn   tuple[dict, NdarrayOrTensor]c                 C  s  |  |}| ||}| ||}| ||}| |||}tt|dh }i }	|D ]O}
||
k}t |}|dd|d d |d d |d d |d d f }|d d |d d g}| 	t
 ||}|dur{| ||}|||d|	|
< q,t|| jd}|	|fS )a"  post-process instance segmentation branches (NP and HV) to generate instance segmentation map.

        Args:
            nuclear_prediction: the output of NP (nuclear prediction) branch of HoVerNet model
            hover_map: the output of HV (hover map) branch of HoVerNet model
        r   Nr-   rr   ro   )bounding_boxr!   contourr   )r   r   r   r   r   setrG   r   r	   r   r
   r   r   r   )r5   r   rn   Zwatershed_maskZinstance_bordersr   Zwatershed_markersinstance_mapZinstance_idsinstance_infoinst_idinstance_maskZinstance_bboxr   Zinstance_contourZinstance_centroidr6   r6   r7   r?     s0   

4
z*HoVerNetInstanceMapPostProcessing.__call__)rI   NrJ   rj   Nr   rr   Nr-   ro   NN)rK   rL   r   rN   rO   rP   r   rP   r   ry   r   r   r   rP   r   ry   r   r/   r   rP   r   rN   r   r   r1   r2   )r   r   rn   r   r1   r   rA   rB   rC   rD   r8   r?   __classcell__r6   r6   r   r7   r+     s     "r+   c                      s6   e Zd ZdZ				dd fddZdddZ  ZS )r,   a#  
    The post-processing transform for HoVerNet model to generate nuclear type information.
    It updates the input instance info dictionary with information about types of the nuclei (value and probability).
    Also if requested (`return_type_map=True`), it generates a pixel-level type map.

    Args:
        activation: the activation layer to be applied on nuclear type branch. It can be "softmax" or "sigmoid" string,
            or any callable. Defaults to "softmax".
        threshold: an optional float value to threshold to binarize probability map.
            If not provided, defaults to 0.5 when activation is not "softmax", otherwise None.
        return_type_map: whether to calculate and return pixel-level type map.
        device: target device to put the output Tensor data.

    rI   NTrK   rL   rM   rN   return_type_mapr   r   r   r1   r2   c                   s   t    || _|| _t | _d}d}d }t|tr4| dkr#d}n"| dkr,d}nt	d| dt
|r;|}n
t	dt| dt|||d| _|sU|d u rUd	}t||d
| _d S )NFrI   TrQ   rR   rS   rT   rU   rW   rX   )r   r8   r   r   r*   generate_instance_typer\   r]   r^   r_   r`   ra   r   rK   r   rb   )r5   rK   rM   r   r   rd   re   rf   r   r6   r7   r8     s,   


z*HoVerNetNuclearTypePostProcessing.__init__type_predictionr   r   dict[int, dict]r   #tuple[dict, NdarrayOrTensor | None]c                 C  s   |  |}| |}d}| jrtt|j|d }|D ].}| j|||| d |d\}}||| d< ||| d< |durJ||||k< t|| j	d}q||fS )a  Process NC (type prediction) branch and combine it with instance segmentation
        It updates the instance_info with instance type and associated probability, and generate instance type map.

        Args:
            instance_info: instance information dictionary, the output of :py:class:`HoVerNetInstanceMapPostProcessing`
            instance_map: instance segmentation map, the output of :py:class:`HoVerNetInstanceMapPostProcessing`
            type_prediction: the output of NC (type prediction) branch of HoVerNet model
        Nr   r   )r   r   r   r   r   ra   r   )
rK   rb   r   r   torchzerosru   r   r   r   )r5   r   r   r   type_mapr   Zinstance_typeZinstance_type_probr6   r6   r7   r?   '  s&   



z*HoVerNetNuclearTypePostProcessing.__call__)rI   NTN)
rK   rL   rM   rN   r   r   r   r   r1   r2   )r   r   r   r   r   r   r1   r   r   r6   r6   r   r7   r,     s    $r,   ):
__future__r   r   collections.abcr   typingr   numpyrG   r   monai.config.type_definitionsr   r   monai.transformsr   r   r	   r
   r   r   r   monai.transforms.transformr   0monai.transforms.utils_pytorch_numpy_unificationr   r   r   r   r   r   monai.utilsr   r   r   monai.utils.miscr   monai.utils.type_conversionr   r   r   _r   r   r   r    r!   __all__r"   r#   r$   r%   r&   r'   r(   r)   r*   r+   r,   r6   r6   r6   r7   <module>   s@   $	 )HE+H 6/!&d