o
    i3$                     @  sD  d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	 d dl
mZ d dlmZmZ ed\ZZedd	d
\ZZerUd dlmZ d dlmZ d dlmZ d dlmZ dZn eddd
\ZZeddd
\ZZeddd
\ZZeddd
\ZZg dZ	 	d;d<d"d#Z		$	%	d=d>d'd(Z	$	%		d?d@d.d/Z	 	0	%	1	2dAdBd9d:ZdS )C    )annotations)TYPE_CHECKINGAnyN)NdarrayTensorrescale_array)convert_data_typeoptional_importPILzPIL.GifImagePluginImage)name)Summary)SummaryWriterTz$tensorboard.compat.proto.summary_pb2r   ztensorboardX.proto.summary_pb2ztorch.utils.tensorboardr   tensorboardX)make_animated_gif_summaryadd_animated_gifplot_2d_or_3d_image      ?tagstrimagenp.ndarray | torch.Tensorwriter%SummaryWriter | SummaryWriterX | None	frame_dimintscale_factorfloatreturnr   c                   s   t |jdkrtdt|tjd^}} fddt||dD }dd |D }d}tj	|d d D ]}	||	7 }q5|d	7 }|D ]}
tj
|
D ]}	||	7 }qJqB|d
7 }tr_t|tr_tnt}|jddd|d}|j| |d}||gdS )a  Function to actually create the animated gif.

    Args:
        tag: Data identifier
        image: 3D image tensors expected to be in `HWD` format
        writer: the tensorboard writer to plot image
        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
        scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will
            scale it to displayable range
       zF3D image tensors expected to be in `HWD` format, len(image.shape) != 3)output_typec                   s    g | ]}|  j tjd dqS )F)copy)astypenpuint8.0ir    a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/visualize/img2tensorboard.py
<listcomp>?   s     z(_image3_animated_gif.<locals>.<listcomp>r   c                 S  s   g | ]}t |qS r)   )GifImage	fromarray)r&   imr)   r)   r*   r+   @   s        s   !NETSCAPE2.0      ;
      )heightwidth
colorspaceencoded_image_string)r   r   )value)lenshapeAssertionErrorr   r#   ndarraymoveaxisr
   GifImagePlugin	getheadergetdatahas_tensorboardx
isinstanceSummaryWriterXSummaryXr   r   Value)r   r   r   r   r   image_np_imsZimg_strb_datar'   summaryZsummary_image_strimage_summaryr)   r(   r*   _image3_animated_gif*   s$   

rK   r   max_outc           
   	   C  s   |dkrdnd}|dkr|d n|}g }t t|t|jd D ]8}t|tjr:||ddddddf jddn||ddddddf }	|t	| |
| |	||| q |S )am  Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.

    Args:
        tag: Data identifier
        image: The image, expected to be in `CHWD` format
        writer: the tensorboard writer to plot image
        max_out: maximum number of image channels to animate through
        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
            default to `-3` (the first spatial dim)
        scale_factor: amount to multiply values by.
            if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
    r2   z/imagez	/image/{}r   N)dim)rangeminlistr9   rA   torchTensorsqueezeappendrK   format)
r   r   r   rM   r   r   suffixZ
summary_opZit_iZone_channel_imgr)   r)   r*   r   P   s   Jr   SummaryWriter | SummaryWriterXimage_tensorglobal_step
int | NoneNonec           	      C  s2   t ||| |||d}|D ]
}|  || qdS )a  Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.

    Args:
        writer: Tensorboard SummaryWriter to write to
        tag: Data identifier
        image_tensor: tensor for the image to add, expected to be in `CHWD` format
        max_out: maximum number of image channels to animate through
        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
            default to `-3` (the first spatial dim)
        scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
            scale it to displayable range
        global_step: Global step value to record
    )r   r   r   rM   r   r   N)r   _get_file_writeradd_summary)	r   r   rY   rM   r   r   rZ   rI   sr)   r)   r*   r   t   s   r   r2      outputdata#NdarrayTensor | list[NdarrayTensor]stepindexmax_channels
max_framesc                 C  s  | | }|dkr|d n|}t |tjr|   nt|}	|	jdkr>t	|	dd}	d}
|j
| d|
 |	||
d dS |	jdkr|	jd dkr`|dkr`d	}
|j
| d|
 |	||
d dS d}
t|	d| D ]\}}t	|dd}|j
| d|
 d| |||
d qjdS |	jd
kr|	jdd }|	dgt| }	|	jd }|dkr|dkrtrt |trt|	|d}	|j||	d ||dd dS t||}tjdd |	d| D dd}	t|| d|	|||d dS dS )aX  Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

    Note:
        Plot 3D or 2D image(with more than 3 channels) as separate images.
        And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

    Args:
        data: target data to be plotted as image on the TensorBoard.
            The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions,
            and only plot the first in the batch.
        step: current step to plot in a chart.
        writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
        index: plot which element in the input data batch, default is the first element.
        max_channels: number of channels to plot.
        frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
        max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
        tag: tag of the plotted image on TensorBoard.
    r   r2      HWrF   )dataformatsNr   CHW   rL   ZNCHWT)fpsrj   c                 S  s   g | ]}t |d dqS )r      r   r%   r)   r)   r*   r+      s    z'plot_2d_or_3d_image.<locals>.<listcomp>)axisZ_HWD)rM   r   rZ   )rA   rR   rS   detachcpunumpyr#   asarrayndimr   	add_imager9   	enumeratereshaperQ   r@   rB   r<   	add_videorP   stackr   )rb   rd   r   re   rf   r   rg   r   
data_indexdrj   jd2spatialZd_chansr)   r)   r*   r      s@   $

$


 r   )r   r   )r   r   r   r   r   r   r   r   r   r   r   r   )Nr   rL   r   )r   r   r   r   r   r   rM   r   r   r   r   r   r   r   )r   rL   r   N)r   rX   r   r   rY   r   rM   r   r   r   r   r   rZ   r[   r   r\   )r   r2   rL   r`   ra   )rb   rc   rd   r   r   rX   re   r   rf   r   r   r   rg   r   r   r   r   r\   ) 
__future__r   typingr   r   rs   r#   rR   monai.configr   monai.transformsr   monai.utilsr   r	   r
   rF   r,   Z$tensorboard.compat.proto.summary_pb2r   r   r   rB   ZtensorboardX.proto.summary_pb2rC   Ztorch.utils.tensorboardr@   __all__rK   r   r   r   r)   r)   r)   r*   <module>   sN   )("