U
    |Ph'                     @  s(  d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	m
Z
 d dlZd dlZd dlmZmZmZ d dlmZmZmZmZ edejed	\ZZe	rd d
lmZ nedejed\ZZddddgZdddddZddddZd+dddddddddd	d dZd,d"d#d$d%d&dZ d'dd(d)d*Z!dS )-    )annotationsN)OrderedDict)CallableSequence)TYPE_CHECKINGAny)
IgniteInfoKeysCollectionPathLike)ensure_tuplelook_up_optionmin_versionoptional_importignitedistributed)Enginezignite.enginer   stopping_fn_from_metricstopping_fn_from_losswrite_metrics_reportsfrom_enginestrzCallable[[Engine], Any])metric_namereturnc                   s   ddd fdd}|S )zd
    Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.
    r   r   enginer   c                   s   | j j  S N)statemetricsr   r    I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/handlers/utils.pystopping_fn'   s    z,stopping_fn_from_metric.<locals>.stopping_fnr    )r   r"   r    r   r!   r   "   s    )r   c                  C  s   ddddd} | S )z]
    Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.
    r   r   r   c                 S  s
   | j j S r   )r   outputr   r    r    r!   r"   2   s    z*stopping_fn_from_loss.<locals>.stopping_fnr    )r"   r    r    r!   r   -   s    ,csvr
   zSequence[str] | Nonez+dict[str, torch.Tensor | np.ndarray] | Nonezstr | Sequence[str] | Nonezlist[str] | NoneNone)	save_dirimagesr   metric_detailssummary_opsdelioutput_typeclass_labelsr   c                   s  |  dkrtd| dtj| s2t|  |dk	rt|dkrttj| dd4}|	 D ]$\}	}
|
|	 | t|
 d qdW 5 Q R X |dk	rt|dkr|	 D ].\}	}
t|
tjr|
  }
|
jdkr|
d	}
n|
jd
kr|
d}
|dkr&dd t|
jd
 D }ndd |D }|dg7 }tj|
tj|
d
ddgd
d}
ttj| |	 ddt}|
d| || d t|
D ]H\}}|
|dk	r|| nt| | |dd |D  d qW 5 Q R X |dk	rttjtjtjtjdd tjdd dt|}d|kr@t  }ddddfdd  ttj| |	 d!dn}|
d"| || d tt!|
D ]<\}|
||  | | fd#d|D  d qW 5 Q R X qdS )$a
  
    Utility function to write the metrics into files, contains 3 parts:
    1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair.
    2. if `metric_details` dict is not None,  write raw metric data of every image into file, every line for 1 image.
    3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file.

    Args:
        save_dir: directory to save all the metrics reports.
        images: name or path of every input image corresponding to the metric_details data.
            if None, will use index number as the filename of every input image.
        metrics: a dictionary of (metric name, metric value) pairs.
        metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics
            computation, for example, the raw value can be the mean_dice of every channel of every input image.
        summary_ops: expected computation operations to generate the summary report.
            it can be: None, "*" or list of strings, default to None.
            None - don't generate summary report for every expected metric_details.
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within list: ["mean", "median", "max", "min", "<int>percentile", "std", "notnans"].
            the number in "<int>percentile" should be [0, 100], like: "15percentile". default: "90percentile".
            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.
            note that: for the overall summary, it computes `nanmean` of all classes for each image first,
            then compute summary. example of the generated summary report::

                class    mean    median    max    5percentile 95percentile  notnans
                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000
                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000
                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000

        deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
            to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
        output_type: expected output file type, supported types: ["csv"], default to "csv".
        class_labels: list of class names used to name the classes in the output report, if None,
            "class0", ..., "classn" are used, default to None.

    r%   zunsupported output type: .Nr   zmetrics.csvw
)   r1   r1   )r1   c                 S  s   g | ]}d t | qS )classr   .0ir    r    r!   
<listcomp>}   s     z)write_metrics_reports.<locals>.<listcomp>c                 S  s   g | ]}t |qS r    r4   r5   r    r    r!   r8      s     meanT)axiskeepdims)r:   z_raw.csvfilenamec                 S  s*   g | ]"}t |ttfr|d nt|qS z.4f)
isinstanceintfloatr   )r6   cr    r    r!   r8      s     c                 S  s   t | d | d S )Nr   r1   )npnanpercentilexr    r    r!   <lambda>       z'write_metrics_reports.<locals>.<lambda>c                 S  s   t |   S r   )rB   isnansumrD   r    r    r!   rF      rG   )r9   medianmaxmin90percentilestdZnotnans*r   z
np.ndarrayr   )opdr   c                   s>   |  dst|  }||S t| dd } d ||fS )N
percentiler   rM   )endswithr   r?   split)rP   rQ   Zc_op	threshold)supported_opsr    r!   _compute_op   s
    

z*write_metrics_reports.<locals>._compute_opz_summary.csvr3   c                   s   g | ]} |d qS r=   r    r6   k)rW   rA   r    r!   r8      s     )"lower
ValueErrorospathexistsmakedirslenopenjoinitemswriter   r>   torchTensorcpunumpyndimreshaperangeshaperB   concatenatenanmean	enumerater   	nanmediannanmaxnanminnanstdr   tuplekeys	transpose)r'   r(   r   r)   r*   r+   r,   r-   frY   vr7   bopsr    )rW   rA   rV   r!   r   8   s\    .
(



6
Fr	   boolr   )ru   firstr   c                   s   t |   fdd}|S )a  
    Utility function to simplify the `batch_transform` or `output_transform` args of ignite components
    when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).
    Users only need to set the expected keys, then it will return a callable function to extract data from
    dictionary and construct a tuple respectively.

    If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,
    for example, if data is `[{"A": 1, "B": 2}, {"A": 3, "B": 4}]`, from_engine(["A", "B"]): `([1, 3], [2, 4])`.

    It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.
    For example, set the first key as the prediction and the second key as label to get the expected data
    from `engine.state.output` for a metric::

        from monai.handlers import MeanDice, from_engine

        metric = MeanDice(
            include_background=False,
            output_transform=from_engine(["pred", "label"])
        )

    Args:
        keys: specified keys to extract data from dictionary or decollated list of dictionaries.
        first: whether only extract specified keys from the first item if input data is a list of dictionaries,
            it's used to extract the scalar data which doesn't have batch dim and was replicated into every
            dictionary when decollating, like `loss`, etc.


    c                   sl   t  tr t fddD S t  trht  d trh fddD }t|dkr`t|S |d S d S )Nc                 3  s   | ]} | V  qd S r   r    rX   datar    r!   	<genexpr>   s     z0from_engine.<locals>._wrapper.<locals>.<genexpr>r   c                   s.   g | ]& rd    n fddD qS )r   c                   s   g | ]}|  qS r    r    r5   rY   r    r!   r8      s     z<from_engine.<locals>._wrapper.<locals>.<listcomp>.<listcomp>r    )r6   )r~   r|   r   r!   r8      s     z1from_engine.<locals>._wrapper.<locals>.<listcomp>r1   )r>   dictrt   listr`   )r~   ret_keysr|   r}   r!   _wrapper   s
    
zfrom_engine.<locals>._wrapper)r   )ru   r|   r   r    r   r!   r      s    	r   )rE   r   c                 C  s   dS )z
    Always return `None` for any input data.
    A typical usage is to avoid logging the engine output of every iteration during evaluation.

    Nr    rD   r    r    r!   ignore_data   s    r   )r$   r%   N)F)"
__future__r   r\   collectionsr   collections.abcr   r   typingr   r   rh   rB   re   monai.configr   r	   r
   monai.utilsr   r   r   r   OPT_IMPORT_VERSIONidist_ignite.enginer   __all__r   r   r   r   r   r    r    r    r!   <module>   s*       r+