U
    Ph#=                     @  s  d Z ddlmZ ddlZddlmZmZ ddlmZ ddl	m
Z
mZ ddlmZmZ ddlZddlZddlmZ dd	lmZmZ zddlmZ d
ZW n ek
r   dZY nX erddlmZmZ n(edejed\ZZ edejed\ZZ dZ!de!fdfddddddddddZ"de!fdfdddddddddd d!Z#dd"d#d$d%d&Z$d'de!fde$ddfdd(ddddd)d*dd+d,
d-d.Z%d/d"d0d1d2Z&G d3d4 d4eZ'G d5d6 d6eZ(dS )7z
This set of utility function is meant to make using Jupyter notebooks easier with MONAI. Plotting functions using
Matplotlib produce common plots for metrics and images.
    )annotationsN)CallableMapping)Enum)RLockThread)TYPE_CHECKINGAny)
IgniteInfo)min_versionoptional_importTF)EngineEventszignite.enginer   r   losslog   zplt.Axesstrz;Mapping[str, list[float] | tuple[list[float], list[float]]]z
tuple[str]intNone)axtitlegraphmapyscaleavg_keyswindow_fractionreturnc                 C  sH  ddl m} | D ]\}}t|dkrt|d ttfrHt| \}	}
nttt|t| }	}
| j	|	|
| d|
d dd ||krt||krt|| }t
|f| }t
j|
d f|d  |
 |dd	}| j	|	|| d
|d dd q| | | | | d | jdddd | ddd | j|dd dS )ar  
    Plot metrics on a single graph with running averages plotted for selected keys. The values in `graphmap`
    should be lists of (timepoint, value) pairs as stored in MetricLogger objects.

    Args:
        ax: Axes object to plot into
        title: graph title
        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs
        yscale: scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: tuple of keys in `graphmap` to provide running average plots for
        window_fraction: what fraction of the graph value length to use as the running average window
    r   )MaxNLocatorz = z.5g)label   valid)modez Avg = on)r   r   g        )Zbbox_to_anchorlocZborderaxespadTboth)integerN)Zmatplotlib.tickerr   itemslen
isinstancetuplelistziprangeplotnponesconvolve	set_title
set_yscaleaxislegendgridxaxisZset_major_locator)r   r   r   r   r   r   r   nvindsvalswindowkernelra r>   N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/utils/jupyter_utils.pyplot_metric_graph.   s$     ""


r@   z
plt.Figurezdict[str, np.ndarray]r*   )figr   r   imagemapr   r   r   r   c              	   C  s   dt dt|f}tj|d|d | d}t|||||| |g}	t|D ]\}
}tj|d|
fd| d}|| jd dkr||| dddg n|jt	
|| d	d
 || d||  dd||   d |d |	| qH|	S )a  
    Plot metric graph data with images below into figure `fig`. The intended use is for the graph data to be
    metrics from a training run and the images to be the batch and output from the last iteration. This uses
    `plot_metric_graph` to plot the metric graph.

    Args:
        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing
        title: graph title
        graphmap: dictionary of named graph values, which are lists of values or (index, value) pairs
        imagemap: dictionary of named images to show with metric plot
        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for
        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window

    Returns:
        list of Axes objects for graph followed by images
       r   )r   r   )colspanrA      )rowspanrA   r      Zgray)cmap
z.3gz -> off)maxr'   pltZsubplot2gridr@   	enumerateshapeimshow	transposer.   squeezer1   minr3   append)rA   r   r   rB   r   r   r   Z	gridshapegraphaxesir7   imr>   r>   r?   plot_metric_images]   s    .
rX   torch.Tensorznp.ndarray | None)nametensorr   c                 C  s   |j dkr4|jd dkr4|jd dkr4| j S |j dkr|jd dkr|jd dkr|jd d }|dd|f  j S dS )a  
    Return an tuple of images derived from the given tensor. The `name` value indices which key from the
    output or batch value the tensor was stored as, or is "Batch" or "Output" if these were single tensors
    instead of dictionaries. Returns a tuple of 2D images of shape HW, or 3D images of shape CHW where C is
    color channels RGB or RGBA. This allows multiple images to be created from a single tensor, ie. to show
    each channel separately.
    rG   r   rE   rC   N)ndimrN   cpudatanumpy)rZ   r[   Zdmidr>   r>   r?   tensor_to_images   s    &&r`   zTraining Logr	   z)Callable[[str, torch.Tensor], Any] | Noneplt.Figure | Noneztuple[plt.Figure, list])
engineloggerr   r   r   r   image_fnrA   selected_instr   c	                 C  s  |dk	r|   ntjdddd}t|ji}	|	|j i }
|dk	rX| jdk	rX| jjdk	rX| jj| jj	fD ]}|| jjkrdnd}|}t
|tr|| }d}d	d
 | D }t
|tr&| D ]Z\}}t
|tjr|jdkr|||| }|dk	rt|D ]\}}||
| d| < qqqpt
|tjrp|||}|dk	rp||
| d| < qpt|||	|
|||}|jr|d j|jd d ddd ||fS )a  
    Plot the status of the given Engine with its logger. The plot will consist of a graph of loss values and metrics
    taken from the logger, and images taken from the `output` and `batch` members of `engine.state`. The images are
    converted to Numpy arrays suitable for input to `Axes.imshow` using `image_fn`, if this is None then no image
    plotting is done.

    Args:
        engine: Engine to extract images from
        logger: MetricLogger to extract loss and metric data from
        title: graph title
        yscale: for metric plot, scale for y-axis compatible with `Axes.set_yscale`
        avg_keys: for metric plot, tuple of keys in `graphmap` to provide running average plots for
        window_fraction: for metric plot, what fraction of the graph value length to use as the running average window
        image_fn: callable converting tensors keyed to a name in the Engine to a tuple of images to plot
        fig: Figure object to plot into, reuse from previous plotting for flicker-free refreshing
        selected_inst: index of the instance to show in the image plot

    Returns:
        Figure object (or `fig` if given), list of Axes objects for graph and images
    N)r   
   Twhite)figsizetight_layout	facecolorZBatchOutputr   c                 S  s0   i | ](\}}t |tjr|jd kr||d qS )rG   N)r(   torchTensorr\   ).0kr8   r>   r>   r?   
<dictcomp>   s
      
  z&plot_engine_status.<locals>.<dictcomp>rC   _r   r   ro   :)cls)ZclfrL   Figure	LOSS_NAMEr   updatemetricsstatebatchoutputr(   r*   r&   dictrl   rm   r\   rM   rX   Zaxhline)rb   rc   r   r   r   r   rd   rA   re   r   rB   srcr   Zbatch_selected_instZselected_dictro   r8   imagerV   rW   rU   r>   r>   r?   plot_engine_status   s:    

$

r   zUlist[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor)r{   r   c                 C  s.   ddddd}t | tr&|| d S || S )zJReturns a single value from the network output, which is a dict or tensor.z&torch.Tensor | dict[str, torch.Tensor]rY   )r^   r   c                 S  s   t | tr| d S | S )Nr   )r(   r|   )r^   r>   r>   r?   	_get_loss   s    
z(_get_loss_from_output.<locals>._get_lossr   )r(   r*   )r{   r   r>   r>   r?   _get_loss_from_output   s    
r   c                   @  s    e Zd ZdZdZdZdZdZdS )StatusMembersz`
    Named members of the status dictionary, others may be present for named metric values.
    StatusZEpochsZItersZLossN)__name__
__module____qualname____doc__STATUSEPOCHSITERSLOSSr>   r>   r>   r?   r      s
   r   c                      s   e Zd ZdZedd dfddddd fd	d
Zdd Zdd Zdd Ze	ddddZ
ddddZefddddddZ  ZS )ThreadContainera  
    Contains a running `Engine` object within a separate thread from main thread in a Jupyter notebook. This
    allows an engine to begin a run in the background and allow the starting notebook cell to complete. A
    user can thus start a run and then navigate away from the notebook without concern for loosing connection
    with the running cell. All output is acquired through methods which synchronize with the running engine
    using an internal `lock` member, acquiring this lock allows the engine to be inspected while it's prevented
    from starting the next iteration.

    Args:
        engine: wrapped `Engine` object, when the container is started its `run` method is called
        loss_transform: callable to convert an output dict into a single numeric value
        metric_transform: callable to convert a named metric value into a single numeric value
        status_format: format string for status key-value pairs.
    c                 C  s   |S Nr>   )rZ   valuer>   r>   r?   <lambda>      zThreadContainer.<lambda>z	{}: {:.4}r   r   r   )rb   loss_transformmetric_transformstatus_formatc                   sL   t    t | _|| _i | _|| _|| _d | _|| _	| j
tj| j d S r   )super__init__r   lockrb   _status_dictr   r   rA   r   Zadd_event_handlerr   ZITERATION_COMPLETED_update_status)selfrb   r   r   r   	__class__r>   r?   r     s    
zThreadContainer.__init__c                 C  s   | j   dS )z-Calls the `run` method of the wrapped engine.N)rb   runr   r>   r>   r?   r   (  s    zThreadContainer.runc                 C  s   | j   |   dS )z$Stop the engine and join the thread.N)rb   	terminatejoinr   r>   r>   r?   stop,  s    
zThreadContainer.stopc              	   C  s   | j  | jj}tjjdtjjdtjjtdi}|dk	r|j	dk	rd|j	dkrd|j
 d|j	 }n
t|j
}|jdk	r|j|j  d|j }n
t|j}||tjj< ||tjj< | |j|tjj< |jpi }| D ]*\}}| ||}|dk	r|| | q| j| W 5 Q R X dS )zNCalled as an event, updates the internal status dict at the end of iterations.r   nanNr   /)r   rb   ry   r   r   r   r   r   floatZ
max_epochsepochr   Zepoch_length	iterationr   r{   rx   r&   r   rS   r   rw   )r   ry   statsr   itersrx   mr8   r>   r>   r?   r   1  s2    
   




zThreadContainer._update_statuszdict[str, str])r   c              
   C  sF   | j 6 tjj|  rdndi}|| j |W  5 Q R  S Q R X dS )zTA dictionary containing status information, current loss, and current metric values.ZRunningZStoppedN)r   r   r   r   is_aliverw   r   )r   r   r>   r>   r?   status_dictR  s    zThreadContainer.status_dictc                 C  s   t | j}|tjjdt|tjjd g}|	 D ]:\}}t
|tr\| j||}n| d| }|| q:d|S )z<Returns a status string for the current state of the engine.zIters: r   z: z, )copydeepcopyr   popr   r   r   r   r   r&   r(   r   r   formatrS   r   )r   r   Zmsgskeyvalmsgr>   r>   r?   statusZ  s    &
zThreadContainer.statusr	   ra   )rc   	plot_funcr   c              
   C  sB   | j 2 ||  | j|| jd\| _}| jW  5 Q R  S Q R X dS )a$  
        Generate a plot of the current status of the contained engine whose loss and metrics were tracked by `logger`.
        The function `plot_func` must accept arguments `title`, `engine`, `logger`, and `fig` which are the plot title,
        `self.engine`, `logger`, and `self.fig` respectively. The return value must be a figure object (stored in
        `self.fig`) and a list of Axes objects for the plots in the figure. Only the figure is returned by this method,
        which holds the internal lock during the plot generation.
        )r   rb   rc   rA   N)r   r   rb   rA   )r   rc   r   rq   r>   r>   r?   plot_statusj  s    zThreadContainer.plot_status)r   r   r   r   r   r   r   r   r   propertyr   r   r   r   __classcell__r>   r>   r   r?   r     s   !r   ))r   
__future__r   r   collections.abcr   r   enumr   	threadingr   r   typingr   r	   r_   r.   rl   monai.configr
   monai.utils.moduler   r   matplotlib.pyplotpyplotrL   Zhas_matplotlibImportErrorignite.enginer   r   OPT_IMPORT_VERSIONrq   rv   r@   rX   r`   r   r   r   r   r>   r>   r>   r?   <module>   sP   
40"N