U
    Ph#                     @  sl  d dl mZ d dlmZmZ d dlZd dlZd dlm	Z	 d dl
mZ d dlmZmZ ed\ZZedd	d
\ZZerd dlmZ d dlmZ d dlmZ d dlmZ dZn@eddd
\ZZeddd
\ZZeddd
\ZZeddd
\ZZdddgZd0ddddddddd Zd1dddddddd#d$dZd2d%dddddd&d'd(d)dZd3d-dd%dddddd'd.	d/dZdS )4    )annotations)TYPE_CHECKINGAnyN)NdarrayTensorrescale_array)convert_data_typeoptional_importPILzPIL.GifImagePluginImage)name)Summary)SummaryWriterTz$tensorboard.compat.proto.summary_pb2r   ztensorboardX.proto.summary_pb2ztorch.utils.tensorboardr   tensorboardXmake_animated_gif_summaryadd_animated_gifplot_2d_or_3d_image      ?strznp.ndarray | torch.Tensorz%SummaryWriter | SummaryWriterX | Noneintfloatr   )tagimagewriter	frame_dimscale_factorreturnc                   s   t |jdkrtdt|tjd^}} fddt||dD }dd |D }d}tj	|d d D ]}	||	7 }qj|d	7 }|D ]}
tj
|
D ]}	||	7 }qq|d
7 }trt|trtnt}|jddd|d}|j| |d}||gdS )a  Function to actually create the animated gif.

    Args:
        tag: Data identifier
        image: 3D image tensors expected to be in `HWD` format
        writer: the tensorboard writer to plot image
        frame_dim: the dimension used as frames for GIF image, expect data shape as `HWD`, default to `0`.
        scale_factor: amount to multiply values by. if the image data is between 0 and 1, using 255 for this value will
            scale it to displayable range
       zF3D image tensors expected to be in `HWD` format, len(image.shape) != 3)output_typec                   s    g | ]}|  j tjd dqS )F)copy)astypenpuint8.0ir    T/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/visualize/img2tensorboard.py
<listcomp>?   s     z(_image3_animated_gif.<locals>.<listcomp>r   c                 S  s   g | ]}t |qS r'   )GifImage	fromarray)r$   imr'   r'   r(   r)   @   s         s   !NETSCAPE2.0      ;
      )heightwidth
colorspaceencoded_image_string)r   r   )value)lenshapeAssertionErrorr   r!   ndarraymoveaxisr
   GifImagePlugin	getheadergetdatahas_tensorboardx
isinstanceSummaryWriterXSummaryXr   r   Value)r   r   r   r   r   image_np_imsZimg_strb_datar%   summaryZsummary_image_strimage_summaryr'   r&   r(   _image3_animated_gif*   s"    
rI   r   )r   r   r   max_outr   r   r   c           
   	   C  s   |dkrdnd}|dkr |d n|}g }t t|t|jd D ]p}t|tjrt||ddddddf jddn||ddddddf }	|t	| |
| |	||| q@|S )am  Creates an animated gif out of an image tensor in 'CHWD' format and returns Summary.

    Args:
        tag: Data identifier
        image: The image, expected to be in `CHWD` format
        writer: the tensorboard writer to plot image
        max_out: maximum number of image channels to animate through
        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
            default to `-3` (the first spatial dim)
        scale_factor: amount to multiply values by.
            if the image data is between 0 and 1, using 255 for this value will scale it to displayable range
    r0   z/imagez	/image/{}r   N)dim)rangeminlistr7   r?   torchTensorsqueezeappendrI   format)
r   r   r   rK   r   r   suffixZ
summary_opZit_iZone_channel_imgr'   r'   r(   r   P   s    JzSummaryWriter | SummaryWriterXz
int | NoneNone)r   r   image_tensorrK   r   r   global_stepr   c           	      C  s2   t ||| |||d}|D ]}|  || qdS )a  Creates an animated gif out of an image tensor in 'CHWD' format and writes it with SummaryWriter.

    Args:
        writer: Tensorboard SummaryWriter to write to
        tag: Data identifier
        image_tensor: tensor for the image to add, expected to be in `CHWD` format
        max_out: maximum number of image channels to animate through
        frame_dim: the dimension used as frames for GIF image, expect input data shape as `CHWD`,
            default to `-3` (the first spatial dim)
        scale_factor: amount to multiply values by. If the image data is between 0 and 1, using 255 for this value will
            scale it to displayable range
        global_step: Global step value to record
    )r   r   r   rK   r   r   N)r   _get_file_writeradd_summary)	r   r   rW   rK   r   r   rX   rG   sr'   r'   r(   r   t   s         r0      outputz#NdarrayTensor | list[NdarrayTensor])	datastepr   indexmax_channelsr   
max_framesr   r   c                 C  s  | | }|dkr|d n|}t |tjr8|   n|}	|	jdkrvt|	dd}	d}
|j| d|
 |	||
d dS |	jdkr|	j	d dkr|dkrd	}
|j| d|
 |	||
d dS d}
t
|	d| D ]6\}}t|dd}|j| d|
 d| |||
d qdS |	jd
kr|	j	dd }|	dgt| }	|	j	d dkr|dkrtrt |trt|	|d}	|j||	d ||dd dS t||	j	d }tjdd |	d| D dd}	t|| d|	|||d dS dS )aX  Plot 2D or 3D image on the TensorBoard, 3D image will be converted to GIF image.

    Note:
        Plot 3D or 2D image(with more than 3 channels) as separate images.
        And if writer is from TensorBoardX, data has 3 channels and `max_channels=3`, will plot as RGB video.

    Args:
        data: target data to be plotted as image on the TensorBoard.
            The data is expected to have 'NCHW[D]' dimensions or a list of data with `CHW[D]` dimensions,
            and only plot the first in the batch.
        step: current step to plot in a chart.
        writer: specify TensorBoard or TensorBoardX SummaryWriter to plot the image.
        index: plot which element in the input data batch, default is the first element.
        max_channels: number of channels to plot.
        frame_dim: if plotting 3D image as GIF, specify the dimension used as frames,
            expect input data shape as `NCHWD`, default to `-3` (the first spatial dim)
        max_frames: if plot 3D RGB image as video in TensorBoardX, set the FPS to `max_frames`.
        tag: tag of the plotted image on TensorBoard.
    r   r0      HWrD   )dataformatsNr   CHW   rJ   ZNCHWT)fpsre   c                 S  s   g | ]}t |d dqS )r      r   r#   r'   r'   r(   r)      s     z'plot_2d_or_3d_image.<locals>.<listcomp>)axisZ_HWD)rK   r   rX   )r?   rP   rQ   detachcpunumpyndimr   	add_imager7   	enumeratereshaperO   r>   r@   r!   r:   	add_videorN   stackr   )r^   r_   r   r`   ra   r   rb   r   
data_indexdre   jd2spatialr'   r'   r(   r      s:     
$, )r   r   )Nr   rJ   r   )r   rJ   r   N)r   r0   rJ   r\   r]   ) 
__future__r   typingr   r   rn   r!   rP   monai.configr   monai.transformsr   monai.utilsr   r	   r
   rD   r*   Z$tensorboard.compat.proto.summary_pb2r   r   r   r@   ZtensorboardX.proto.summary_pb2rA   Ztorch.utils.tensorboardr>   __all__rI   r   r   r   r'   r'   r'   r(   <module>   sL   
  )    (    "     