U
    {Ph%                     @  s^  d dl mZ d dlZd dlmZmZ d dlmZmZm	Z	 d dl
Z
d dlmZ d dlmZ d dlmZmZ d dlmZmZmZmZ ed	ejed
\ZZerz4ed	\ZZd dlmZ d dlmZ d dlm Z  W n e!k
r   dZY nX nDedejed\ZZedejeddd\ZZedejeddd\Z ZG dd deZ"eddddG dd de"Z#dS )     )annotationsN)CallableSequence)TYPE_CHECKINGAnycast)_Loss)
IgniteInfo)CumulativeIterationMetric
LossMetric)MetricReduction
deprecatedmin_versionoptional_importignitedistributed)Engine)Metric)reinit__is_reducedFzignite.enginer   zignite.metricsr   base)as_typezignite.metrics.metricr   	decoratorc                	      s   e Zd ZdZdddd dejdfddd	d
dd
dd fddZeddddZedddddZ	ddddZ
dddd fddZ  ZS )IgniteMetricHandlera  
    Base Metric class based on ignite event handler mechanism.
    The input `prediction` or `label` data can be a PyTorch Tensor or numpy array with batch dim and channel dim,
    or a list of PyTorch Tensor or numpy array without batch dim.

    Args:
        metric_fn: callable function or class to compute raw metric results after every iteration.
            expect to return a Tensor with shape (batch, channel, ...) or tuple (Tensor, not_nans).
        loss_fn: A torch _Loss function which is used to generate the LossMetric
        output_transform: callable to extract `y_pred` and `y` from `ignite.engine.state.output` then
            construct `(y_pred, y)` pair, where `y_pred` and `y` can be `batch-first` Tensors or
            lists of `channel-first` Tensors. the form of `(y_pred, y)` is required by the `update()`.
            `engine.state` and `output_transform` inherit from the ignite concept:
            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
        save_details: whether to save metric computation details per image, for example: mean_dice of every image.
            default to True, will save to `engine.state.metric_details` dict with the metric name as key.
        reduction: Argument for the LossMetric, look there for details
        get_not_nans: Argument for the LossMetric, look there for details

    Nc                 C  s   | S N xr   r   Q/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/ignite_metric.py<lambda>G       zIgniteMetricHandler.<lambda>TF CumulativeIterationMetric | None_Loss | Noner   boolMetricReduction | strNone	metric_fnloss_fnoutput_transformsave_details	reductionget_not_nansreturnc                   s   d| _ tt|| _|| _|| _g | _d | _d | _| jd krL| jd krLt	d| jd k	rh| jd k	rht	d| jrt
| j||d| _t | d S )NFz.Either metric_fn or loss_fn have to be passed.z<Either metric_fn or loss_fn have to be passed, but not both.)r'   r*   r+   )_is_reducedr   r
   r&   r'   r)   Z_scores_engine_name
ValueErrorr   super__init__selfr&   r'   r(   r)   r*   r+   	__class__r   r   r2   C   s    	zIgniteMetricHandler.__init__)r,   c                 C  s   | j   d S r   )r&   reset)r4   r   r   r   r7   ]   s    zIgniteMetricHandler.resetzSequence[torch.Tensor])outputr,   c                 C  s8   t |dkr tdt | d|\}}| || dS )z
        Args:
            output: sequence with contents [y_pred, y].

        Raises:
            ValueError: When ``output`` length is not 2. metric_fn can only support y_pred and y.

           zoutput must have length 2, got .N)lenr0   r&   )r4   r8   y_predyr   r   r   updatea   s    
zIgniteMetricHandler.updater   c                 C  s   | j  }t|ttfr6t|dkr.td |d }d| _| j	rt| j
dksV| jdkr^td| j  | j
jj| j< t|tjr| }|jdkr| }|S )zr
        Raises:
            NotComputableError: When ``compute`` is called before an ``update`` occurs.

           z>metric handler can only record the first value of result list.r   TNzCplease call the attach() function to connect expected engine first.)r&   	aggregate
isinstancetuplelistr;   warningswarnr-   r)   r.   r/   RuntimeError
get_bufferstatemetric_detailstorchTensorsqueezendimitem)r4   resultr   r   r   computer   s    


zIgniteMetricHandler.computer   str)enginenamer,   c                   s:   t  j||d || _|| _| jr6t|jds6i |j_dS )aF  
        Attaches current metric to provided engine. On the end of engine's run,
        `engine.state.metrics` dictionary will contain computed metric's value under provided name.

        Args:
            engine: the engine to which the metric must be attached.
            name: the name of the metric to attach.

        )rR   rS   rI   N)r1   attachr.   r/   r)   hasattrrH   rI   )r4   rR   rS   r5   r   r   rT      s
    
zIgniteMetricHandler.attach)__name__
__module____qualname____doc__r   MEANr2   r   r7   r>   rP   rT   __classcell__r   r   r5   r   r   ,   s    r   z1.2z1.4z0Use IgniteMetricHandler instead of IgniteMetric.)sinceremoved
msg_suffixc                	      sB   e Zd Zdddd dejdfdddd	d
d	dd fddZ  ZS )IgniteMetricNc                 C  s   | S r   r   r   r   r   r   r      r   zIgniteMetric.<lambda>TFr    r!   r   r"   r#   r$   r%   c                   s   t  j||||||d d S )N)r&   r'   r(   r)   r*   r+   )r1   r2   r3   r5   r   r   r2      s    	zIgniteMetric.__init__)rV   rW   rX   r   rZ   r2   r[   r   r   r5   r   r_      s   r_   )$
__future__r   rD   collections.abcr   r   typingr   r   r   rJ   torch.nn.modules.lossr   monai.configr	   monai.metricsr
   r   monai.utilsr   r   r   r   OPT_IMPORT_VERSIONidist_
has_igniteignite.enginer   ignite.metricsr   Zignite.metrics.metricr   ImportErrorr   r_   r   r   r   r   <module>   s:       
r