o
    i5                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZm	Z	 d dl
mZmZmZ eddd	\ZZd
dgZG dd
 d
eZ					d.d/ddZd0d d!Z	"d1d2d&d'Z	"d3d4d*d+Zd5d,d-ZdS )6    )annotations)SequenceN)CumulativeIterationMetric)do_metric_reductionremap_instance_id)MetricReductionensure_tupleoptional_importzscipy.optimizelinear_sum_assignment)namePanopticQualityMetriccompute_panoptic_qualityc                      sD   e Zd ZdZdejddfd fddZdddZddddZ  Z	S ) r   a  
    Compute Panoptic Quality between two instance segmentation masks. If specifying `metric_name` to "SQ" or "RQ",
    Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.

    Panoptic Quality is a metric used in panoptic segmentation tasks. This task unifies the typically distinct tasks
    of semantic segmentation (assign a class label to each pixel) and
    instance segmentation (detect and segment each object instance). Compared with semantic segmentation, panoptic
    segmentation distinguish different instances that belong to same class.
    Compared with instance segmentation, panoptic segmentation does not allow overlap and only one semantic label and
    one instance id can be assigned to each pixel.
    Please refer to the following paper for more details:
    https://openaccess.thecvf.com/content_CVPR_2019/papers/Kirillov_Panoptic_Segmentation_CVPR_2019_paper.pdf

    This class also refers to the following implementation:
    https://github.com/TissueImageAnalytics/CoNIC

    Args:
        num_classes: number of classes. The number should not count the background.
        metric_name: output metric. The value can be "pq", "sq" or "rq".
            Except for input only one metric, multiple metrics are also supported via input a sequence of metric names
            such as ("pq", "sq", "rq"). If input a sequence, a list of results with the same order
            as the input names will be returned.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.
        match_iou_threshold: IOU threshold to determine the pairing between `y_pred` and `y`. Usually,
            it should >= 0.5, the pairing between instances of `y_pred` and `y` are identical.
            If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
            maximal amount of unique pairing.
        smooth_numerator: a small constant added to the numerator to avoid zero.

    pq      ?ư>num_classesintmetric_nameSequence[str] | str	reductionMetricReduction | strmatch_iou_thresholdfloatsmooth_numeratorreturnNonec                   s0   t    || _|| _|| _|| _t|| _d S N)super__init__r   r   r   r   r   r   )selfr   r   r   r   r   	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/metrics/panoptic_quality.pyr   =   s   
zPanopticQualityMetric.__init__y_predtorch.Tensoryc              	   C  s  |j |j krtd|j  d|j  d|j d dkr%td|j d  d| }|dkr5td| d|j d	 }tj|| jdg|jd
}t|D ]B}||d	f ||d	f }}||df ||df }	}
t| jD ] }|
|d k| }|	|d k| }t||d| j	dd|||f< qkqJ|S )aV  
        Args:
            y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
                second channel represent the instance predictions and classification predictions respectively.
            y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
                second channel represent the instance labels and classification labels respectively.
                Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
                where 0 represents the background.

        Raises:
            ValueError: when `y_pred` and `y` have different shapes.
            ValueError: when `y_pred` and `y` have != 2 channels.
            ValueError: when `y_pred` and `y` have != 4 dimensions.

        z*y_pred and y should have same shapes, got  and .      zJfor panoptic quality calculation, only 2 channels input is supported, got    z6y_pred should have 4 dimensions (batch, 2, h, w), got r   deviceT)predgtremapr   output_confusion_matrix)
shape
ValueError
ndimensiontorchzerosr   r-   ranger   r   )r   r$   r&   dims
batch_sizeoutputsbZtrue_instanceZpred_instanceZ
true_classZ
pred_classcZpred_instance_cZtrue_instance_cr"   r"   r#   _compute_tensorL   s4   
z%PanopticQualityMetric._compute_tensorNMetricReduction | str | None!torch.Tensor | list[torch.Tensor]c                 C  s   |   }t|tjstdt||p| j\}}|d |d |d |d f\}}}}g }	| jD ]>}
t|
}
|
dkrL|		||d|  d|  | j
   q/|
dkr[|		||| j
   q/|		||d|  d|  | j
   q/t|	d	krx|	d
 S |	S )a  
        Execute reduction logic for the output of `compute_panoptic_quality`.

        Args:
            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.

        z-the data to aggregate must be PyTorch Tensor.).r   ).r)   ).r*   ).   rqr   sqr)   r   )
get_buffer
isinstancer5   Tensorr3   r   r   r   _check_panoptic_metric_nameappendr   len)r   r   dataf_tpfpfniou_sumresultsr   r"   r"   r#   	aggregate}   s   
$
&&zPanopticQualityMetric.aggregate)r   r   r   r   r   r   r   r   r   r   r   r   )r$   r%   r&   r%   r   r%   r   )r   r>   r   r?   )
__name__
__module____qualname____doc__r   
MEAN_BATCHr   r=   rQ   __classcell__r"   r"   r    r#   r      s    $
1r   Tr   r   Fr.   r%   r/   r   strr0   boolr   r   r   r1   r   c                   s  |j | j krtd| j  d|j  d|dks|dkr#td| d| }|  } |du r7t|}t| } t| || jd\}}}	t|||jd\}
 fd	d
|dd D } fdd
|	dd D }tt|t|}}}|
 }|rt	j
||||g| jdS t|}|dkrt	j
||d|  d|  |  | jdS |dkrt	j
|||  | jdS t	j
||d|  d|  |  | jdS )av  Computes Panoptic Quality (PQ). If specifying `metric_name` to "SQ" or "RQ",
    Segmentation Quality (SQ) or Recognition Quality (RQ) will be returned instead.

    In addition, if `output_confusion_matrix` is True, the function will return a tensor with shape 4, which
    represents the true positive, false positive, false negative and the sum of iou. These four values are used to
    calculate PQ, and returning them directly enables further calculation over all images.

    Args:
        pred: input data to compute, it must be in the form of HW and have integer type.
        gt: ground truth. It must have the same shape as `pred` and have integer type.
        metric_name: output metric. The value can be "pq", "sq" or "rq".
        remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
        match_iou_threshold: IOU threshold to determine the pairing between `pred` and `gt`. Usually,
            it should >= 0.5, the pairing between instances of `pred` and `gt` are identical.
            If set `match_iou_threshold` < 0.5, this function uses Munkres assignment to find the
            maximal amount of unique pairing.
        smooth_numerator: a small constant added to the numerator to avoid zero.

    Raises:
        ValueError: when `pred` and `gt` have different shapes.
        ValueError: when `match_iou_threshold` <= 0.0 or > 1.0.

    z)pred and gt should have same shapes, got r'   r(           g      ?z4'match_iou_threshold' should be within (0, 1], got: Tr,   c                      g | ]}| vr|qS r"   r"   .0idx)paired_truer"   r#   
<listcomp>       z,compute_panoptic_quality.<locals>.<listcomp>r)   Nc                   r[   r"   r"   r\   )paired_predr"   r#   r`      ra   rA   r   rB   )r2   r3   r   r   _get_pairwise_iour-   _get_paired_iourH   sumr5   	as_tensorrF   )r.   r/   r   r0   r   r   r1   pairwise_ioutrue_id_listpred_id_list
paired_iouZunpaired_trueZunpaired_predrL   rM   rN   rO   r"   )rb   r_   r#   r      s2   !((list[torch.Tensor]c                 C  s.   t |  }d|vr|dtd  |S )Nr   )listuniqueinsertr5   tensorr   )r/   Zid_listr"   r"   r#   _get_id_list   s   rp   cpur-   str | torch.device;tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]c                 C  s<  t | }t |}tjt|d t|d gtj|d}g }g }|dd  D ]}tj||k|d }	||	 q%|dd  D ]}
tj| |
k|d }|| q>tdt|D ]@}||d  }	| |	dk }t	|
 }|D ])}|dkruqn||d  }|	|  }|	|  }|||  }|||d |d f< qnqX|||fS )Nr)   )dtyper-   r,   r   )rp   r5   r6   rH   r   rf   r   rG   r7   rl   rm   re   )r.   r/   r-   ri   rh   rg   Z
true_masksZ
pred_maskstZt_maskpZp_maskZtrue_idZpred_true_overlapZpred_true_overlap_idpred_idtotalinteriour"   r"   r#   rc      s2   &
	rc   rg   /tuple[torch.Tensor, torch.Tensor, torch.Tensor]c                 C  s   |dkr4d| | |k< t | d d df t | d d df }}| ||f }|d7 }|d7 }|||fS |   } t|  \}}| ||f }t jt|||k d |d}t jt|||k d |d}|||k }|||fS )Nr   rZ   r   r)   r,   )r5   nonzerorq   numpyr
   rf   rl   )rg   r   r-   r_   rb   rj   r"   r"   r#   rd   	  s   .

rd   c                 C  sH   |  dd} |  } | dv rdS | dv rdS | dv rdS td	|  d
)N rK   )panoptic_qualityr   r   )Zsegmentation_qualityrB   rB   )Zrecognition_qualityrA   rA   zmetric name: z) is wrong, please use 'pq', 'sq' or 'rq'.)replacelowerr3   )r   r"   r"   r#   rF     s   rF   )r   Tr   r   F)r.   r%   r/   r%   r   rX   r0   rY   r   r   r   r   r1   rY   r   r%   )r/   r%   r   rk   )rq   )r.   r%   r/   r%   r-   rr   r   rs   )r   rq   )rg   r%   r   r   r-   rr   r   r{   )r   rX   r   rX   )
__future__r   collections.abcr   r5   Zmonai.metrics.metricr   monai.metrics.utilsr   r   monai.utilsr   r   r	   r
   rK   __all__r   r   rp   rc   rd   rF   r"   r"   r"   r#   <module>   s,    
C
#