o
    -io                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZmZ d dl	m
Z
 d dlmZmZmZmZmZmZmZmZ d dlmZ d	gZG d
d dejZG dd dejZG dd dejZG dd dejZ					d3d4d0d1ZG d2d	 d	ejZdS )5    )annotations)SequenceN)nn)ConvolutionSpatialAttentionBlock)SPADE)DiffusionUnetDownsampleDiffusionUNetResnetBlockSpatialTransformerWrappedUpsampleget_down_blockget_mid_blockget_timestep_embeddingzero_module)ensure_tuple_repSPADEDiffusionModelUNetc                      s:   e Zd ZdZ						dd  fddZd!ddZ  ZS )"SPADEDiffResBlocka  
    Residual block with timestep conditioning and SPADE norm.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding  channels.
        label_nc: number of semantic channels for SPADE normalisation.
        out_channels: number of output channels.
        up: if True, performs upsampling.
        down: if True, performs downsampling.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer
    NF    ư>   spatial_dimsintin_channelstemb_channelslabel_ncout_channels
int | Noneupbooldownnorm_num_groupsnorm_epsfloatspade_intermediate_channelsreturnNonec              
     sB  t    || _|| _|| _|p|| _|| _|| _t||d||	dd|
d|d| _	t
 | _t||| jddddd| _d  | _| _| jrQt|d||d	d
d d| _n
|r[t||dd| _t
|| j| _t|| jd||	dd|
d|d| _tt|| j| jddddd| _|  | j|krt
 | _d S t||| jddddd| _d S )NGROUPT)
num_groupsepsaffine   )r   norm_ncnormnorm_paramshidden_channelskernel_sizer      r   r   r   stridesr/   padding	conv_onlynontrainablenearest       @)r   moder   r   interp_modescale_factoralign_cornersF)use_convr   )super__init__r   channelsemb_channelsr   r   r   r   norm1r   SiLUnonlinearityr   conv1upsample
downsampler   r   Lineartime_emb_projnorm2r   conv2Identityskip_connection)selfr   r   r   r   r   r   r   r    r!   r#   	__class__ p/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/spade_diffusion_model_unet.pyr>   J   s   






	
	
zSPADEDiffResBlock.__init__xtorch.Tensorembsegc                 C  s   |}|  ||}| |}| jd ur| |}| |}n| jd ur,| |}| |}| |}| jdkrI| | |d d d d d d f }n| | |d d d d d d d f }|| }| ||}| |}| |}| 	|| }|S )N   )
rA   rC   rE   rF   rD   r   rH   rI   rJ   rL   )rM   rR   rT   rU   htemboutputrP   rP   rQ   forward   s&   







&&

zSPADEDiffResBlock.forward)NFFr   r   r   )r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r   r!   r"   r#   r   r$   r%   )rR   rS   rT   rS   rU   rS   r$   rS   __name__
__module____qualname____doc__r>   rZ   __classcell__rP   rP   rN   rQ   r   8   s    ]r   c                      s>   e Zd ZdZ						d&d' fddZ	d(d)d$d%Z  ZS )*SPADEUpBlocka  
    Unet's up block containing resnet and upsamplers blocks.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        label_nc: number of semantic channels for SPADE normalisation.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
    r0   r   r   TFr   r   r   r   prev_output_channelr   r   r   num_res_blocksr    r!   r"   add_upsampler   resblock_updownr#   r$   r%   c                   s   t    || _g }t|D ]%}||d kr|n|}|dkr |n|}|t||| |||||	|d qt|| _|  |
rh|rNt	||||||	dd| _
d S t|||ddddd}t|d||d	d
|d d| _
d S d | _
d S )Nr0   r   r   r   r   r   r   r    r!   r#   Tr   r   r   r   r    r!   r   r*   r1   r5   r6   r7   r   r8   r   r   r9   r:   	post_convr;   )r=   r>   re   rangeappendr   r   
ModuleListresnetsr	   	upsamplerr   r   )rM   r   r   rb   r   r   r   rc   r    r!   rd   re   r#   rm   ires_skip_channelsresnet_in_channelsri   rN   rP   rQ   r>      sf   

	
zSPADEUpBlock.__init__Nhidden_statesrS   res_hidden_states_listlist[torch.Tensor]rX   rU   contexttorch.Tensor | Nonec                 C  sZ   ~| j D ]}|d }|d d }tj||gdd}||||}q| jd ur+| ||}|S Nr0   dim)rm   torchcatrn   )rM   rr   rs   rX   rU   ru   resnetres_hidden_statesrP   rP   rQ   rZ     s   

zSPADEUpBlock.forward)r0   r   r   TFr   )r   r   r   r   rb   r   r   r   r   r   r   r   rc   r   r    r   r!   r"   rd   r   re   r   r#   r   r$   r%   Nrr   rS   rs   rt   rX   rS   rU   rS   ru   rv   r$   rS   r[   rP   rP   rN   rQ   ra      s    Ora   c                      sF   e Zd ZdZ										d*d+ fddZ	d,d-d(d)Z  ZS ).SPADEAttnUpBlocka  
    Unet's up block containing resnet, upsamplers, and self-attention blocks.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        label_nc: number of semantic channels for SPADE normalisation
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    r0   r   r   TFr   r   r   r   rb   r   r   r   rc   r    r!   r"   rd   r   re   num_head_channelsr#   
include_fcuse_combined_linearuse_flash_attentionr$   r%   c                   s  t    || _g }g }t|D ]4}||d kr|n|}|dkr"|n|}|t||| |||||	|d |t|||||	|||d qt|| _	t|| _
|  |
r|ret||||||	dd| _d S t|||ddddd}t|d	||d
d|d d| _d S d | _d S )Nr0   r   rf   )r   num_channelsr   r    r!   r   r   r   Trg   r*   r1   r5   r6   r7   rh   )r=   r>   re   rj   rk   r   r   r   rl   rm   
attentionsr	   rn   r   r   )rM   r   r   rb   r   r   r   rc   r    r!   rd   re   r   r#   r   r   r   rm   r   ro   rp   rq   ri   rN   rP   rQ   r>   N  s   

	
zSPADEAttnUpBlock.__init__Nrr   rS   rs   rt   rX   rU   ru   rv   c           	      C  sr   ~t | j| jD ]#\}}|d }|d d }tj||gdd}||||}|| }q| jd ur7| ||}|S rw   ziprm   r   r{   r|   
contiguousrn   	rM   rr   rs   rX   rU   ru   r}   attnr~   rP   rP   rQ   rZ     s   
zSPADEAttnUpBlock.forward)
r0   r   r   TFr0   r   TFF)"r   r   r   r   rb   r   r   r   r   r   r   r   rc   r   r    r   r!   r"   rd   r   re   r   r   r   r#   r   r   r   r   r   r   r   r$   r%   r   r   r[   rP   rP   rN   rQ   r   5  s     ar   c                      sN   e Zd ZdZ													d.d/ fd"d#Z		d0d1d,d-Z  ZS )2SPADECrossAttnUpBlocka  
    Unet's up block containing resnet, upsamplers, and self-attention blocks.
    Enables SPADE normalisation for semantic conditioning (Park et. al (2019): https://github.com/NVlabs/SPADE)

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        label_nc: number of semantic channels for SPADE normalisation.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism.
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    r0   r   r   TFNr   r   r   r   rb   r   r   r   rc   r    r!   r"   rd   r   re   r   transformer_num_layerscross_attention_dimr   upcast_attentionr#   r   r   r   r$   r%   c                   s  t    || _g }g }t|D ]:}||d kr|n|}|dkr"|n|}|t||| ||||	||d |t|||| |||	||||||d qt|| _	t|| _
|  |
r|rkt||||||	dd| _d S t|||ddddd}t|d	||d
d|d d| _d S d | _d S )Nr0   r   )r   r   r   r   r    r!   r   r#   )r   r   num_attention_headsr   r    r!   
num_layersr   r   r   r   r   Trg   r*   r1   r5   r6   r7   rh   )r=   r>   re   rj   rk   r   r
   r   rl   r   rm   r	   rn   r   r   )rM   r   r   rb   r   r   r   rc   r    r!   rd   re   r   r   r   r   r#   r   r   r   rm   r   ro   rp   rq   ri   rN   rP   rQ   r>     s   

	
zSPADECrossAttnUpBlock.__init__rr   rS   rs   rt   rX   rU   rv   ru   c           	      C  st   t | j| jD ]%\}}|d }|d d }tj||gdd}||||}|||d }q| jd ur8| ||}|S )Nrx   r0   ry   )ru   r   r   rP   rP   rQ   rZ   <  s   
zSPADECrossAttnUpBlock.forward)r0   r   r   TFr0   r0   NFr   TFF)(r   r   r   r   rb   r   r   r   r   r   r   r   rc   r   r    r   r!   r"   rd   r   re   r   r   r   r   r   r   r   r   r   r#   r   r   r   r   r   r   r   r$   r%   )NN)rr   rS   rs   rt   rX   rS   rU   rv   ru   rv   r$   rS   r[   rP   rP   rN   rQ   r     s&    !gr   Fr   Tr   r   r   rb   r   r   rc   r    r!   r"   rd   r   re   	with_attnwith_cross_attnr   r   r   r   r   r   r#   r   r   r   r$   	nn.Modulec                 C  s  |
r7t di d| d|d|d|d|d|d|d|d	|d
|d|	d|d|d|d|d|S |rqtdi d| d|d|d|d|d|d|d|d	|d
|d|	d|d|d|d|d|d|S t| ||||||||||	|dS )Nr   r   rb   r   r   r   rc   r    r!   rd   re   r   r#   r   r   r   r   r   r   )r   r   rb   r   r   r   rc   r    r!   rd   re   r#   rP   )r   r   ra   )r   r   rb   r   r   rc   r    r!   rd   re   r   r   r   r   r   r   r   r#   r   r   r   rP   rP   rQ   get_spade_up_blockR  s   	
	
r   c                      sX   e Zd ZdZ											
	
					d9d: fd+d,Z	
	
	
	
d;d<d7d8Z  ZS )=r   a  
    UNet network with timestep embedding and attention mechanisms for conditioning, with added SPADE normalization for
    semantic conditioning (Park et.al (2019): https://github.com/NVlabs/SPADE). An example tutorial can be found at
    https://github.com/Project-MONAI/GenerativeModels/tree/main/tutorials/generative/2d_spade_ldm

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        label_nc: number of semantic channels for SPADE normalisation.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        spade_intermediate_channels: number of intermediate channels for SPADE block layer.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rV   rV   rV   rV   r   @   r   r   FFTTr   r   F   r0   Nr   Tr   r   r   r   r   rc   Sequence[int] | intr?   Sequence[int]attention_levelsSequence[bool]r    r!   r"   re   r   r   int | Sequence[int]with_conditioningr   r   r   num_class_embedsr   r#   r   r   r   r$   r%   c           !        s  t    |du r|d u rtd|d ur|du rtdt fdd|D r,tdt|t|kr8tdt|trDt|t|}t|t|krPtd	t|tr\t|t|}t|t|krhtd
|| _|| _	|| _
|| _|| _|| _|| _|| _t|||d ddddd| _|d d }tt|d |t t||| _|| _|d urt||| _tg | _|d }tt|D ]`}|}|| }|t|d k}td)i d|d|d|d|d|| d d|	d| d|
d|| o| d|| o|d|| d|d|d|d|d |d!|}| j| qt||d" | |	||d" ||||||d#| _ tg | _!t"t#|}t"t#|}t"t#|}t"t#|}|d }tt|D ]t}|}|| }|t$|d t|d  }|t|d k}t%d)i d|d|d$|d|d|d|| d d d|	d%| d|
d|| o| d|| o|d|| d|d|d|d&|d'|d!|} | j!|  qettj& |d |	dd(t t't||d |ddddd| _(d S )*NTzSPADEDiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) when using with_conditioning.Fz_SPADEDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.c                 3  s    | ]	}|  d kV  qdS )r   NrP   ).0out_channelr    rP   rQ   	<genexpr>  s    z3SPADEDiffusionModelUNet.__init__.<locals>.<genexpr>zRSPADEDiffusionModelUNet expects all num_channels being multiple of norm_num_groupszPSPADEDiffusionModelUNet expects num_channels being same size of attention_levelsznum_head_channels should have the same length as attention_levels. For the i levels without attention, i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.zj`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.r   r0   r*   r1      r   r   r   r   rc   r    r!   add_downsamplere   r   r   r   r   r   r   r   r   r   rx   )r   r   r   r    r!   r   r   r   r   r   r   r   r   rb   rd   r   r#   )r'   r   r(   r)   rP   ))r=   r>   
ValueErroranylen
isinstancer   r   r   block_out_channelsr   rc   r   r   r   r   r   conv_inr   
SequentialrG   rB   
time_embedr   	Embeddingclass_embeddingrl   down_blocksrj   r   rk   r   middle_block	up_blockslistreversedminr   	GroupNormr   out)!rM   r   r   r   r   rc   r?   r   r    r!   re   r   r   r   r   r   r   r#   r   r   r   time_embed_dimoutput_channelro   input_channelis_final_block
down_blockreversed_block_out_channelsreversed_num_res_blocksreversed_attention_levelsreversed_num_head_channelsrb   up_blockrN   r   rQ   r>     sd  


	
	

z SPADEDiffusionModelUNet.__init__rR   rS   	timestepsrU   ru   rv   class_labelsdown_block_additional_residualstuple[torch.Tensor] | Nonemid_block_additional_residualc                 C  sf  t || jd }|j|jd}| |}	| jdur1|du r!td| |}
|
j|jd}
|	|
 }	| |}|durC| j	du rCtd|g}| j
D ]}|||	|d\}}|D ]}|| qVqI|dur{|g}t||D ]\}}|| }|| qk|}| j||	|d}|dur|| }| jD ]}t|j }||d }|d| }|||||	|d}q| |}|S )	a  
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            seg: Bx[LABEL_NC]x[SPATIAL DIMENSIONS] tensor of segmentations for SPADE norm.
            context: context tensor (N, 1, ContextDim).
            class_labels: context tensor (N, ).
            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
        r   )dtypeNz9class_labels should be provided when num_class_embeds > 0FzAmodel should have with_conditioning = True if context is provided)rr   rX   ru   )rr   rs   rU   rX   ru   )r   r   tor   r   r   r   r   r   r   r   rk   r   r   r   r   rm   r   )rM   rR   r   rU   ru   r   r   r   t_embrT   	class_embrW   down_block_res_samplesdownsample_blockres_samplesresidualnew_down_block_res_samplesdown_block_res_sampledown_block_additional_residualupsample_blockidxrY   rP   rP   rQ   rZ     sH   






zSPADEDiffusionModelUNet.forward)r   r   r   r   r   Fr   Fr0   NNFr   TFF)*r   r   r   r   r   r   r   r   rc   r   r?   r   r   r   r    r   r!   r"   re   r   r   r   r   r   r   r   r   r   r   r   r   r   r#   r   r   r   r   r   r   r   r$   r%   )NNNN)rR   rS   r   rS   rU   rS   ru   rv   r   rv   r   r   r   rv   r$   rS   r[   rP   rP   rN   rQ   r     s2    " G)Fr   TFF),r   r   r   r   rb   r   r   r   r   r   rc   r   r    r   r!   r"   rd   r   re   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r#   r   r   r   r   r   r   r   r$   r   )
__future__r   collections.abcr   r{   r   monai.networks.blocksr   r   Z monai.networks.blocks.spade_normr   (monai.networks.nets.diffusion_model_unetr   r	   r
   r   r   r   r   r   monai.utilsr   __all__Moduler   ra   r   r   r   r   rP   rP   rP   rQ   <module>   s.   (
 s  $O