o
    i'                     @  s   d dl mZ d dlZd dlmZ d dlmZmZ d dlm	Z	m
Z
 d dlZd dlZd dlmZmZ d dlmZmZmZmZmZ edejed	\ZZe	rQd d
lmZ n
edejed\ZZg dZd4ddZd5ddZ			d6d7d&d'Zd8d9d.d/Z d:d2d3Z!dS );    )annotationsN)OrderedDict)CallableSequence)TYPE_CHECKINGAny)KeysCollectionPathLike)
IgniteInfoensure_tuplelook_up_optionmin_versionoptional_importignitedistributed)Enginezignite.enginer   )stopping_fn_from_metricstopping_fn_from_losswrite_metrics_reportsfrom_enginemetric_namestrreturnCallable[[Engine], Any]c                   s   d fdd}|S )	zd
    Returns a stopping function for ignite.handlers.EarlyStopping using the given metric name.
    enginer   r   r   c                   s   | j j  S N)statemetricsr   r    V/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/utils.pystopping_fn'   s   z,stopping_fn_from_metric.<locals>.stopping_fnNr   r   r   r   r    )r   r"   r    r   r!   r   "   s   r   c                  C  s   ddd} | S )	z]
    Returns a stopping function for ignite.handlers.EarlyStopping using the loss value.
    r   r   r   r   c                 S  s
   | j j S r   )r   outputr   r    r    r!   r"   2   s   
z*stopping_fn_from_loss.<locals>.stopping_fnNr#   r    )r"   r    r    r!   r   -   s   
r   ,csvsave_dirr	   imagesSequence[str] | Noner   +dict[str, torch.Tensor | np.ndarray] | Nonemetric_detailssummary_opsstr | Sequence[str] | Nonedelioutput_typeclass_labelslist[str] | NoneNonec                   s$  |  dkrtd| dtj| st|  |durTt|dkrTttj| dd}|	 D ]\}	}
|
|	 | t|
 d q2W d   n1 sOw   Y  |durt|dkr|	 D ]*\}	}
t|
tjru|
  }
|
jdkr|
d	}
n
|
jd
kr|
d}
|du rdd t|
jd
 D }ndd |D }|dg7 }tj|
tj|
d
ddgd
d}
ttj| |	 dd=}|
d| || d t|
D ]"\}}|
|dur|| nt| | |dd |D  d qW d   n	1 sw   Y  |durttjtjtjtjdd tjdd dt|}d|v r2t  }d&fd!d" ttj| |	 d#d<}|
d$| || d tt!|
D ]\}|
||  | | fd%d|D  d q\W d   n	1 sw   Y  qddS dS dS )'a
  
    Utility function to write the metrics into files, contains 3 parts:
    1. if `metrics` dict is not None, write overall metrics into file, every line is a metric name and value pair.
    2. if `metric_details` dict is not None,  write raw metric data of every image into file, every line for 1 image.
    3. if `summary_ops` is not None, compute summary based on operations on `metric_details` and write to file.

    Args:
        save_dir: directory to save all the metrics reports.
        images: name or path of every input image corresponding to the metric_details data.
            if None, will use index number as the filename of every input image.
        metrics: a dictionary of (metric name, metric value) pairs.
        metric_details: a dictionary of (metric name, metric raw values) pairs, usually, it comes from metrics
            computation, for example, the raw value can be the mean_dice of every channel of every input image.
        summary_ops: expected computation operations to generate the summary report.
            it can be: None, "*" or list of strings, default to None.
            None - don't generate summary report for every expected metric_details.
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within list: ["mean", "median", "max", "min", "<int>percentile", "std", "notnans"].
            the number in "<int>percentile" should be [0, 100], like: "15percentile". default: "90percentile".
            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.
            note that: for the overall summary, it computes `nanmean` of all classes for each image first,
            then compute summary. example of the generated summary report::

                class    mean    median    max    5percentile 95percentile  notnans
                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000
                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000
                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000

        deli: the delimiter character in the saved file, default to "," as the default output type is `csv`.
            to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
        output_type: expected output file type, supported types: ["csv"], default to "csv".
        class_labels: list of class names used to name the classes in the output report, if None,
            "class0", ..., "classn" are used, default to None.

    r&   zunsupported output type: .Nr   zmetrics.csvw
)   r6   r6   )r6   c                 S  s   g | ]}d t | qS )classr   .0ir    r    r!   
<listcomp>}   s    z)write_metrics_reports.<locals>.<listcomp>c                 S  s   g | ]}t |qS r    r9   r:   r    r    r!   r=          meanT)axiskeepdims)r@   z_raw.csvfilenamec                 S  s*   g | ]}t |ttfr|d nt|qS z.4f)
isinstanceintfloatr   )r;   cr    r    r!   r=      s   * c                 S  s   t | d | d S )Nr   r6   )npnanpercentilexr    r    r!   <lambda>   r>   z'write_metrics_reports.<locals>.<lambda>c                 S  s   t |   S r   )rH   isnansumrJ   r    r    r!   rL      s    )r?   medianmaxmin90percentilestdZnotnans*opr   d
np.ndarrayr   r   c                   s>   |  dst|  }||S t| dd } d ||fS )N
percentiler   rR   )endswithr   rE   split)rU   rV   Zc_op	threshold)supported_opsr    r!   _compute_op   s
   

z*write_metrics_reports.<locals>._compute_opz_summary.csvr8   c                   s   g | ]	} |d qS rC   r    r;   k)r]   rG   r    r!   r=      s    )rU   r   rV   rW   r   r   )"lower
ValueErrorospathexistsmakedirslenopenjoinitemswriter   rD   torchTensorcpunumpyndimreshaperangeshaperH   concatenatenanmean	enumerater   	nanmediannanmaxnanminnanstdr   tuplekeys	transpose)r'   r(   r   r+   r,   r.   r/   r0   fr_   vr<   bopsr    )r]   rG   r\   r!   r   8   sr   .






6r   Fr{   r   firstboolr   c                   s   t |   fdd}|S )a  
    Utility function to simplify the `batch_transform` or `output_transform` args of ignite components
    when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).
    Users only need to set the expected keys, then it will return a callable function to extract data from
    dictionary and construct a tuple respectively.

    If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,
    for example, if data is `[{"A": 1, "B": 2}, {"A": 3, "B": 4}]`, from_engine(["A", "B"]): `([1, 3], [2, 4])`.

    It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.
    For example, set the first key as the prediction and the second key as label to get the expected data
    from `engine.state.output` for a metric::

        from monai.handlers import MeanDice, from_engine

        metric = MeanDice(
            include_background=False,
            output_transform=from_engine(["pred", "label"])
        )

    Args:
        keys: specified keys to extract data from dictionary or decollated list of dictionaries.
        first: whether only extract specified keys from the first item if input data is a list of dictionaries,
            it's used to extract the scalar data which doesn't have batch dim and was replicated into every
            dictionary when decollating, like `loss`, etc.


    c                   sp   t  trt fddD S t  tr4t  d tr6 fddD }t|dkr0t|S |d S d S d S )Nc                 3  s    | ]} | V  qd S r   r    r^   datar    r!   	<genexpr>   s    z0from_engine.<locals>._wrapper.<locals>.<genexpr>r   c                   s.   g | ] rd    n fddD qS )r   c                   s   g | ]}|  qS r    r    r:   r_   r    r!   r=      r>   z<from_engine.<locals>._wrapper.<locals>.<listcomp>.<listcomp>r    )r;   )r   r   r   r!   r=      s   . z1from_engine.<locals>._wrapper.<locals>.<listcomp>r6   )rD   dictrz   listrf   )r   ret_keysr   r   r!   _wrapper   s   
zfrom_engine.<locals>._wrapper)r   )r{   r   r   r    r   r!   r      s   	r   rK   r   c                 C  s   dS )z
    Always return `None` for any input data.
    A typical usage is to avoid logging the engine output of every iteration during evaluation.

    Nr    rJ   r    r    r!   ignore_data   s   r   )r   r   r   r   )r   r   )r%   r&   N)r'   r	   r(   r)   r   r*   r+   r*   r,   r-   r.   r   r/   r   r0   r1   r   r2   )F)r{   r   r   r   r   r   )rK   r   r   r2   )"
__future__r   rb   collectionsr   collections.abcr   r   typingr   r   rn   rH   rk   monai.configr   r	   monai.utilsr
   r   r   r   r   OPT_IMPORT_VERSIONidist_ignite.enginer   __all__r   r   r   r   r   r    r    r    r!   <module>   s,   

r+