o
    io!                     @  s   d dl mZ d dlmZmZ d dlmZ d dlmZ d dl	m
Z
 d dlmZ d dlmZ d dlmZmZmZmZ ed	ejed
\ZZedejed\ZZerUd dlmZ n
ed	ejed\ZZG dd dZdS )    )annotations)CallableSequence)TYPE_CHECKING)decollate_batch)write_metrics_reports)
IgniteInfo)ImageMetaKey)ensure_tuplemin_versionoptional_importstring_list_all_gatherzignite.engineEventsignitedistributed)Enginer   c                   @  sV   e Zd ZdZdddd ddddfd$ddZd%ddZd&ddZd%d d!Zd%d"d#ZdS )'MetricsSavera  
    ignite handler to save metrics values and details into expected files.

    Args:
        save_dir: directory to save the metrics and metric details.
        metrics: expected final metrics to save into files, can be: None, "*" or list of strings.
            None - don't save any metrics into files.
            "*" - save all the existing metrics in `engine.state.metrics` dict into separate files.
            list of strings - specify the expected metrics to save.
            default to "*" to save all the metrics into `metrics.csv`.
        metric_details: expected metric details to save into files, the data comes from
            `engine.state.metric_details`, which should be provided by different `Metrics`,
            typically, it's some intermediate values in metric computation.
            for example: mean dice of every channel of every image in the validation dataset.
            it must contain at least 2 dims: (batch, classes, ...),
            if not, will unsqueeze to 2 dims.
            this arg can be: None, "*" or list of strings.
            None - don't save any metric_details into files.
            "*" - save all the existing metric_details in `engine.state.metric_details` dict into separate files.
            list of strings - specify the metric_details of expected metrics to save.
            if not None, every metric_details array will save a separate `{metric name}_raw.csv` file.
        batch_transform: a callable that is used to extract the `meta_data` dictionary of
            the input images from `ignite.engine.state.batch` if saving metric details. the purpose is to get the
            input filenames from the `meta_data` and store with metric details together.
            `engine.state` and `batch_transform` inherit from the ignite concept:
            https://pytorch.org/ignite/concepts.html#state, explanation and usage example are in the tutorial:
            https://github.com/Project-MONAI/tutorials/blob/master/modules/batch_output_transform.ipynb.
        summary_ops: expected computation operations to generate the summary report.
            it can be: None, "*" or list of strings, default to None.
            None - don't generate summary report for every expected metric_details.
            "*" - generate summary report for every metric_details with all the supported operations.
            list of strings - generate summary report for every metric_details with specified operations, they
            should be within list: ["mean", "median", "max", "min", "<int>percentile", "std", "notnans"].
            the number in "<int>percentile" should be [0, 100], like: "15percentile". default: "90percentile".
            for more details, please check: https://numpy.org/doc/stable/reference/generated/numpy.nanpercentile.html.
            note that: for the overall summary, it computes `nanmean` of all classes for each image first,
            then compute summary. example of the generated summary report::

                class    mean    median    max    5percentile 95percentile  notnans
                class0  6.0000   6.0000   7.0000   5.1000      6.9000       2.0000
                class1  6.0000   6.0000   6.0000   6.0000      6.0000       1.0000
                mean    6.2500   6.2500   7.0000   5.5750      6.9250       2.0000

        save_rank: only the handler on specified rank will save to files in multi-gpus validation, default to 0.
        delimiter: the delimiter character in the saved file, default to "," as the default output type is `csv`.
            to be consistent with: https://docs.python.org/3/library/csv.html#csv.Dialect.delimiter.
        output_type: expected output file type, supported types: ["csv"], default to "csv".

    *Nc                 C  s   | S N )xr   r   ^/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/handlers/metrics_saver.py<lambda>W   s    zMetricsSaver.<lambda>r   ,csvsave_dirstrmetricsstr | Sequence[str] | Nonemetric_detailsbatch_transformr   summary_ops	save_rankint	delimiteroutput_typereturnNonec	           	      C  sj   || _ |d urt|nd | _|d urt|nd | _|| _|d ur$t|nd | _|| _|| _|| _g | _	d S r   )
r   r
   r   r   r    r!   r"   delir%   
_filenames)	selfr   r   r   r    r!   r"   r$   r%   r   r   r   __init__R   s   
zMetricsSaver.__init__enginer   c                 C  s2   | tj| j | tj| j | tj|  dS )g
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        N)add_event_handlerr   EPOCH_STARTED_startedITERATION_COMPLETED_get_filenamesEPOCH_COMPLETED)r*   r,   r   r   r   attachg   s   zMetricsSaver.attach_enginec                 C  s
   g | _ dS )zs
        Initialize internal buffers.

        Args:
            _engine: Ignite Engine, unused argument.

        N)r)   )r*   r5   r   r   r   r0   p   s   
zMetricsSaver._startedc                 C  sR   | j d ur%| |jj}t|trt|}|D ]}| j|	t
j  qd S d S r   )r   r    statebatch
isinstancedictr   r)   appendgetKeyFILENAME_OR_OBJ)r*   r,   	meta_datamr   r   r   r2   z   s   

zMetricsSaver._get_filenamesc           	   	     s  t  } j|krtd|dkrt jdn j}t   jkri } jdur=t|j	jdkr= fdd|j	j
 D }i }t|j	drk|j	j} jdurkt|dkrk|
 D ]\}}| jv sfd	 jv rj|||< qXt jt|dkrvdn||| j j jd
 dS dS )r-   z<target save rank is greater than the distributed group size.   )stringsNr   c                   s*   i | ]\}}| j v sd  j v r||qS )r   )r   ).0kvr*   r   r   
<dictcomp>   s   * z)MetricsSaver.__call__.<locals>.<dictcomp>r   r   )r   imagesr   r   r!   r(   r%   )idistget_world_sizer"   
ValueErrorr   r)   get_rankr   lenr6   itemshasattrr   r   r   r!   r(   r%   )	r*   r,   ws_images_metricsZ_metric_detailsdetailsrC   rD   r   rE   r   __call__   s4   

zMetricsSaver.__call__)r   r   r   r   r   r   r    r   r!   r   r"   r#   r$   r   r%   r   r&   r'   )r,   r   r&   r'   )r5   r   r&   r'   )	__name__
__module____qualname____doc__r+   r4   r0   r2   rS   r   r   r   r   r      s    5

	

	r   N)
__future__r   collections.abcr   r   typingr   
monai.datar   Zmonai.handlers.utilsr   monai.utilsr   r	   r<   r
   r   r   r   OPT_IMPORT_VERSIONr   _rH   ignite.enginer   r   r   r   r   r   <module>   s   