o
    )i_@                     @  s   d dl mZ d dlmZ d dlZd dlZd dlmZ d dl	m  m
Z d dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d	gZG d
d deZG dd dejZG dd dejZG dd dejZG dd	 d	ejZdS )    )annotations)SequenceN)Convolution)SPADE)Act)get_act_layer)StrEnumSPADENetc                   @  s   e Zd ZdZdZdZdS )UpsamplingModesbicubicnearestbilinearN)__name__
__module____qualname__r   r   r    r   r   c/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/spade_network.pyr
      s    r
   c                      sF   e Zd ZdZddejddifdfd fddZdd Zdd Z  Z	S )SPADENetResBlocka,  
    Creates a Residual Block with SPADE normalisation.

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        out_channels: number of output channels
        label_nc: number of semantic channels that will be taken into account in SPADE normalisation blocks
        spade_intermediate_channels: number of intermediate channels in the middle conv. layers in SPADE normalisation blocks
        norm: base normalisation type used on top of SPADE
        kernel_size: convolutional kernel size
       INSTANCEnegative_slope皙?   spatial_dimsintin_channelsout_channelslabel_ncspade_intermediate_channelsnormstr | tupleactkernel_sizec	           	        s   t    || _|| _t| j| j| _| j| jk| _t|| j| jd d d| _t|| j| jd d d| _	t
|| _t|| j||||d| _t|| j||||d| _| jrmt|| j| jd d dd| _t|| j||||d| _d S d S )N)r   r   r   r!   r   )r   norm_ncr"   r   hidden_channelsr      )r   r   r   r!   r   r"   )super__init__r   r   minZint_channelslearned_shortcutr   conv_0conv_1r   
activationr   norm_0norm_1conv_snorm_s)	selfr   r   r   r   r   r   r!   r"   	__class__r   r   r'   2   sd   

	zSPADENetResBlock.__init__c                 C  sH   |  ||}| | | ||}| | | ||}|| }|S N)shortcutr*   r,   r-   r+   r.   )r1   xsegx_sdxoutr   r   r   forwardp   s
   zSPADENetResBlock.forwardc                 C  s$   | j r| | ||}|S |}|S r4   )r)   r/   r0   )r1   r6   r7   r8   r   r   r   r5   w   s
   zSPADENetResBlock.shortcut)r   r   r   r   r   r   r   r   r   r   r   r    r!   r    r"   r   )
r   r   r   __doc__r   	LEAKYRELUr'   r;   r5   __classcell__r   r   r2   r   r   $   s    >r   c                      sL   e Zd ZdZddejddiffd fddZdd Zdd Zdd Z	  Z
S )SPADEEncoderaT  
    Encoding branch of a VAE compatible with a SPADE-like generator

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        z_dim: latent space dimension of the VAE containing the image sytle information
        channels: number of output after each downsampling block
        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
        of the autoencoder (HxWx[D])
        kernel_size: convolutional kernel size
        norm: normalisation layer type
        act: activation type
    r   r   r   r   r   r   r   z_dimchannelsSequence[int]input_shaper"   r   r    r!   c	                   s,  t    | _| _| _t||krtd| t|D ]\}	}
|
dt|  |
dt|  kr=td|	|
t|f q| _ fdd jD  _	g } j}t|D ]\}}|
t|||d|||d |}qUt| _tjt j	 jd   jd _tjt j	 jd   jd _d S )	N?Length of parameter input shape must match spatial_dims; got %s   Each dimension of your input must be divisible by 2 ** (autoencoder depth).The shape in position %d, %d is not divisible by %d. c                      g | ]}|d t  j  qS rE   )lenrA   .0s_r1   r   r   
<listcomp>       z)SPADEEncoder.__init__.<locals>.<listcomp>)r   r   r   stridesr"   r   r!   )in_featuresout_features)r&   r'   r   r@   rA   rI   
ValueError	enumeraterC   latent_spatial_shapeappendr   nn
ModuleListblocksLinearnpprodfc_mufc_var)r1   r   r   r@   rA   rC   r"   r   r!   s_indrL   rZ   Zch_init_ch_valuer2   rM   r   r'      sL   
 zSPADEEncoder.__init__c                 C  sB   | j D ]}||}q||dd}| |}| |}||fS Nr   rQ   )rZ   viewsizer^   r_   r1   r6   blockmulogvarr   r   r   r;      s   



zSPADEEncoder.forwardc                 C  sF   | j D ]}||}q||dd}| |}| |}| ||S rc   )rZ   rd   re   r^   r_   reparameterizerf   r   r   r   encode   s   



zSPADEEncoder.encodec                 C  s&   t d| }t |}||| S )Ng      ?)torchexp
randn_likemul)r1   rh   ri   stdepsr   r   r   rj      s   
zSPADEEncoder.reparameterize)r   r   r   r   r@   r   rA   rB   rC   rB   r"   r   r   r    r!   r    )r   r   r   r<   r   r=   r'   r;   rk   rj   r>   r   r   r2   r   r?      s    1r?   c                      sX   e Zd ZdZddddejddifejddifdejjfd$ fddZ	d%d&d"d#Z
  ZS )'SPADEDecodera  
    Decoder branch of a SPADE-like generator. It can be used independently, without an encoding branch,
    behaving like a GAN, or coupled to a SPADE encoder.

    Args:
        label_nc: number of semantic labels
        spatial_dims: number of spatial dimensions
        out_channels: number of output channels
        label_nc: number of semantic channels used for the SPADE normalisation blocks
        input_shape: spatial input shape of the tensor, necessary to do the reshaping after the linear layers
        channels: number of output after each downsampling block
        z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)
        is_vae: whether the decoder is going to be coupled to an autoencoder or not (true: yes, false: no)
        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
        norm: base normalisation type
        act:  activation layer type
        last_act: activation layer type for the last layer of the network (can differ from previous)
        kernel_size: convolutional kernel size
        upsampling_mode: upsampling mode (nearest, bilinear etc.)
    NTr   r   r   r   r   r   r   r   r   rC   rB   rA   	list[int]r@   
int | Noneis_vaeboolr   r   r    r!   last_actstr | tuple | Noner"   upsampling_modestrc                   s  t    | _| _| _| _t||krtd| t|D ]\}}|dt|  |dt|  kr@td||t|f q! fdd|D  _	 jsZt
|||d |d _n jre|d u retdt|t j	|d   _| _g }| j tjjd|d	 _t|d d
 D ]\}}|t||||d  |||	||
d qtj| _t
||d
 ||d d |d |d _d S )NrD   rE   rF   c                   rG   rH   )rI   num_channelsrJ   rM   r   r   rN   	  rO   z)SPADEDecoder.__init__.<locals>.<listcomp>r   )r   r   r   r"   zqIf the network is used in VAE-GAN mode, parameter z_dim (number of latent channels in the VAE) must be populated.)scale_factormoderQ   r%   )r   r   r   r   r   r   r"   r!   )r   r   r   paddingr"   r   r!   )r&   r'   ru   r   r   r{   rI   rT   rU   rV   r   	conv_initrX   r[   r\   r]   fcr@   rW   rl   Upsample
upsamplingr   rY   rZ   	last_conv)r1   r   r   r   rC   rA   r@   ru   r   r   r!   rw   r"   ry   r`   rL   rZ   Zch_indrb   r2   rM   r   r'      sh   
 


zSPADEDecoder.__init__ztorch.Tensor | Nonec                 C  s   | j stj|t| jd}| |}n+|du r,| jdur,tj|	d| jtj
| d}| |}|jd| jd g| j  }| jD ]}|||}| |}qA| |}|S )a	  
        Args:
            seg: input BxCxHxW[xD] semantic map on which the output is conditioned on
            z: latent vector output by the encoder if self.is_vae is True. When is_vae is
            False, z is a random noise vector.

        Returns:

        )re   Nr   )dtypedevicerQ   )ru   FinterpolatetuplerV   r   r@   rl   randnre   float32
get_devicer   rd   r{   rZ   r   r   )r1   r7   r   r6   	res_blockr   r   r   r;   4  s   
 



zSPADEDecoder.forward)r   r   r   r   r   r   rC   rB   rA   rs   r@   rt   ru   rv   r   r   r   r    r!   r    rw   rx   r"   r   ry   rz   r4   )r   r   )r   r   r   r<   r   r=   r
   r   valuer'   r;   r>   r   r   r2   r   rr      s    Hrr   c                      sn   e Zd ZdZddddejddifejddifdejjfd, fdd Z	d-d.d%d&Z
d/d'd(Zd-d0d*d+Z  ZS )1r	   a  
    SPADE Network, implemented based on the code by Park, T et al. in
    "Semantic Image Synthesis with Spatially-Adaptive Normalization"
    (https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: number of spatial dimensions
        in_channels: number of input channels
        out_channels: number of output channels
        label_nc: number of semantic channels used for the SPADE normalisation blocks
        input_shape:  spatial input shape of the tensor, necessary to do the reshaping after the linear layers
        channels: number of output after each downsampling block
        z_dim: latent space dimension of the VAE containing the image sytle information (None if encoder is not used)
        is_vae: whether the decoder is going to be coupled to an autoencoder (true) or not (false)
        spade_intermediate_channels: number of channels in the intermediate layers of the SPADE normalisation blocks
        norm: base normalisation type
        act: activation layer type
        last_act: activation layer type for the last layer of the network (can differ from previous)
        kernel_size: convolutional kernel size
        upsampling_mode: upsampling mode (nearest, bilinear etc.)
    NTr   r   r   r   r   r   r   r   r   r   rC   rB   rA   rs   r@   rt   ru   rv   r   r   r    r!   rw   rx   r"   ry   rz   c                   s   t    || _|| _|| _|| _|| _|| _| jr0|d u r#td nt	|||||||
|d| _
|}|  t||||||||	|
||||d| _d S )NzVThe latent space dimension mapped by parameter z_dim cannot be None is is_vae is True.)r   r   r@   rA   rC   r"   r   r!   )r   r   r   rC   rA   r@   ru   r   r   r!   rw   r"   ry   )r&   r'   ru   r   r   rA   r   rC   rT   r?   encoderreverserr   decoder)r1   r   r   r   r   rC   rA   r@   ru   r   r   r!   rw   r"   ry   decoder_channelsr2   r   r   r'   h  sJ   

zSPADENet.__init__r7   torch.Tensorr6   r   c                 C  sF   d }| j r| |\}}| j||}| ||||fS | ||fS r4   )ru   r   rj   r   )r1   r7   r6   r   z_muZz_logvarr   r   r   r;     s   zSPADENet.forwardc                 C  s   | j r	| j|S d S r4   )ru   r   rk   )r1   r6   r   r   r   rk     s   zSPADENet.encoder   c                 C  s   |  ||S r4   )r   )r1   r7   r   r   r   r   decode  s   zSPADENet.decode)r   r   r   r   r   r   r   r   rC   rB   rA   rs   r@   rt   ru   rv   r   r   r   r    r!   r    rw   rx   r"   r   ry   rz   r4   )r7   r   r6   r   )r6   r   )r7   r   r   r   )r   r   r   r<   r   r=   r
   r   r   r'   r;   rk   r   r>   r   r   r2   r   r	   Q  s    ;
	)
__future__r   typingr   numpyr\   rl   torch.nnrX   torch.nn.functional
functionalr   monai.networks.blocksr   Z monai.networks.blocks.spade_normr   monai.networks.layersr   monai.networks.layers.utilsr   monai.utils.enumsr   __all__r
   Moduler   r?   rr   r	   r   r   r   r   <module>   s"   [W{