o
    i#                     @  s0  d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	 d dl
Z
d dlmZ d dlmZmZ d dlmZmZmZmZ edejed	\ZZerlzed\ZZd d
lmZ d dlmZ d dlmZ W n- eyk   dZY n#w edejed\ZZedejeddd\ZZedejeddd\ZZG dd deZ dS )    )annotationsN)CallableSequence)TYPE_CHECKINGAnycast)_Loss)CumulativeIterationMetric
LossMetric)
IgniteInfoMetricReductionmin_versionoptional_importignitedistributed)Engine)Metric)reinit__is_reducedFzignite.enginer   zignite.metricsr   base)as_typezignite.metrics.metricr   	decoratorc                      sj   e Zd ZdZdddd dejdfd% fddZed&ddZed'ddZ	d(ddZ
d) fd#d$Z  ZS )*IgniteMetricHandlera  
    Base Metric class based on ignite event handler mechanism.
    The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
    or a list of PyTorch Tensor or numpy array without batch dim.

    Args:
        metric_fn: callable function or class to compute raw metric results after every iteration.
            expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
        loss_fn: A torch _Loss function which is used to generate the LossMetric
        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
            lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
            `engine.state` and `output_transform` inherit from the ignite concept:
            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
        save_details: whether to save metric computation details per image, for example: mean_dice of every image.
            default to True, will save to `engine.state.metric_details` dict with the metric name as key.
        reduction: Argument for the LossMetric, look there for details
        get_not_nans: Argument for the LossMetric, look there for details

    Nc                 C  s   | S N )xr   r   ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/ignite_metric.py<lambda>F   s    zIgniteMetricHandler.<lambda>TF	metric_fn CumulativeIterationMetric | Noneloss_fn_Loss | Noneoutput_transformr   save_detailsbool	reductionMetricReduction | strget_not_nansreturnNonec                   s   d| _ tt|| _|| _|| _g | _d | _d | _| jd u r&| jd u r&t	d| jd ur4| jd ur4t	d| jr@t
| j||d| _t | d S )NFz.Either metric_fn or loss_fn have to be passed.z<Either metric_fn or loss_fn have to be passed, but not both.)r   r$   r&   )_is_reducedr   r	   r   r   r"   Z_scores_engine_name
ValueErrorr
   super__init__)selfr   r   r!   r"   r$   r&   	__class__r   r   r.   B   s   	zIgniteMetricHandler.__init__c                 C  s   | j   d S r   )r   reset)r/   r   r   r   r2   \   s   zIgniteMetricHandler.resetoutputSequence[torch.Tensor]c                 C  s8   t |dkrtdt | d|\}}| || dS )z
        Args:
            output: sequence with contents [y_pred, y].

        Raises:
            ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y.

           zoutput must have length 2, got .N)lenr,   r   )r/   r3   y_predyr   r   r   update`   s   
zIgniteMetricHandler.updater   c                 C  s   | j  }t|ttfrt|dkrtd |d }d| _| j	r:| j
du s+| jdu r/td| j  | j
jj| j< t|tjrM| }|jdkrM| }|S )zr
        Raises:
            NotComputableError: When ``compute`` is called before an ``update`` occurs.

           z>metric handler can only record the first value of result list.r   TNzCplease call the attach() function to connect expected engine first.)r   	aggregate
isinstancetuplelistr7   warningswarnr)   r"   r*   r+   RuntimeError
get_bufferstatemetric_detailstorchTensorsqueezendimitem)r/   resultr   r   r   computeq   s   


zIgniteMetricHandler.computeenginer   namestrc                   sB   t  j||d || _|| _| jrt|jdsi |j_dS dS dS )aF  
        Attaches current metric to provided engine. On the end of engine's run,
        `engine.state.metrics` dictionary will contain computed metric's value under provided name.

        Args:
            engine: the engine to which the metric must be attached.
            name: the name of the metric to attach.

        )rM   rN   rE   N)r-   attachr*   r+   r"   hasattrrD   rE   )r/   rM   rN   r0   r   r   rP      s   
zIgniteMetricHandler.attach)r   r   r   r    r!   r   r"   r#   r$   r%   r&   r#   r'   r(   )r'   r(   )r3   r4   r'   r(   )r'   r   )rM   r   rN   rO   r'   r(   )__name__
__module____qualname____doc__r   MEANr.   r   r2   r:   rL   rP   __classcell__r   r   r0   r   r   +   s    
r   )!
__future__r   r@   collections.abcr   r   typingr   r   r   rF   torch.nn.modules.lossr   monai.metricsr	   r
   monai.utilsr   r   r   r   OPT_IMPORT_VERSIONidist_
has_igniteignite.enginer   ignite.metricsr   Zignite.metrics.metricr   ImportErrorr   r   r   r   r   <module>   s0   
