o
    -i@                    @  s  d dl mZ d dlZd dlmZ d dlmZ d dlZd dl	Z	d dl	m
Z
 d dlmZmZmZmZmZmZ d dlmZ d dlmZmZ ed	d
d\ZZdgZdYddZG dd de
jZG dd de
jZdZd[ddZG dd de
jZG d d! d!eZ G d"d# d#e
jZ!G d$d% d%e
jZ"G d&d' d'e
jZ#G d(d) d)e
jZ$G d*d+ d+e
jZ%G d,d- d-e
jZ&G d.d/ d/e
jZ'G d0d1 d1e
jZ(G d2d3 d3e
jZ)	4	5	6	4	4d\d]dMdNZ*	4	5	6	4	4d\d^dPdQZ+	4	5	6	4	4d\d_dTdUZ,G dVd de
jZ-G dWdX dXe
jZ.dS )`    )annotationsN)Sequence)reduce)nn)ConvolutionCrossAttentionBlockMLPBlockSABlockSpatialAttentionBlockUpsample)Pool)ensure_tuple_repoptional_importzeinops.layers.torch	Rearrange)nameDiffusionModelUNetmodule	nn.Modulereturnc                 C  s   |   D ]}|   q| S )z<
    Zero out the parameters of a module and return it.
    )
parametersdetachzero_)r   p r   j/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/diffusion_model_unet.pyzero_module3   s   r   c                      s<   e Zd ZdZ						dd fddZdd ddZ  ZS )!DiffusionUNetTransformerBlocka  
    A Transformer block that allows for the input dimension to differ from the hidden dimension.

    Args:
        num_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        dropout: dropout probability to use.
        cross_attention_dim: size of the context vector for cross attention.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.

            NFTnum_channelsintnum_attention_headsnum_head_channelsdropoutfloatcross_attention_dim
int | Noneupcast_attentionbooluse_flash_attention
include_fcuse_combined_linearr   Nonec
           
        s   t    t|| |||||rtjnd ||	|d	| _t||d d|d| _t|| ||||||r3tjnd |d| _	t
|| _t
|| _t
|| _d S )N)	hidden_sizehidden_input_size	num_headsdim_headdropout_rateattention_dtyper)   r*   r(      GEGLU)r,   mlp_dimactr0   )r,   r.   r-   context_input_sizer/   r0   r1   r(   )super__init__r	   torchr#   attn1r   ffr   attn2r   	LayerNormnorm1norm2norm3)
selfr   r    r!   r"   r$   r&   r(   r)   r*   	__class__r   r   r8   N   s4   

z&DiffusionUNetTransformerBlock.__init__xtorch.Tensorcontexttorch.Tensor | Nonec                 C  sD   |  | || }| j| ||d| }| | || }|S NrF   )r:   r>   r<   r?   r;   r@   )rA   rD   rF   r   r   r   forwardu   s   z%DiffusionUNetTransformerBlock.forward)r   NFFTF)r   r   r    r   r!   r   r"   r#   r$   r%   r&   r'   r(   r'   r)   r'   r*   r'   r   r+   NrD   rE   rF   rG   r   rE   __name__
__module____qualname____doc__r8   rJ   __classcell__r   r   rB   r   r   <   s    'r   c                      sB   e Zd ZdZ									d$d% fddZd&d'd"d#Z  ZS )(SpatialTransformera  
    Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
    standard transformer action. Finally, reshape to image.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of channels in the input and output.
        num_attention_heads: number of heads to use for multi-head attention.
        num_head_channels: number of channels in each attention head.
        num_layers: number of layers of Transformer blocks to use.
        dropout: dropout probability to use.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).

       r       ư>NFTspatial_dimsr   in_channelsr    r!   
num_layersr"   r#   norm_num_groupsnorm_epsr$   r%   r&   r'   r)   r*   r(   r   r+   c                   s   t    || _|| _ tj|||dd| _t||ddddd| _t	 f	ddt
|D | _tt||ddddd| _d S )NT
num_groupsr   epsaffinerT   r   rW   rX   out_channelsstrideskernel_sizepadding	conv_onlyc                   s&   g | ]}t  d 	qS ))	r   r    r!   r"   r$   r&   r)   r*   r(   )r   ).0_	r$   r"   r)   	inner_dimr    r!   r&   r*   r(   r   r   
<listcomp>   s    z/SpatialTransformer.__init__.<locals>.<listcomp>)r7   r8   rW   rX   r   	GroupNormnormr   proj_in
ModuleListrangetransformer_blocksr   proj_out)rA   rW   rX   r    r!   rY   r"   rZ   r[   r$   r&   r)   r*   r(   rB   rh   r   r8      s<   


zSpatialTransformer.__init__rD   rE   rF   rG   c                 C  s@  d } } } }}| j dkr|j\}}}}| j dkr#|j\}}}}}|}| |}| |}|jd }	| j dkrH|dddd||| |	}| j dkr_|ddddd||| | |	}| jD ]}
|
||d}qb| j dkr|||||	dddd }| j dkr||||||	ddddd }| |}|| S )N      rT   r   r2   rI   )	rW   shaperl   rm   permutereshaperp   
contiguousrq   )rA   rD   rF   batchchannelheightwidthdepthresidualri   blockr   r   r   rJ      s*   






$

 
$
zSpatialTransformer.forward)	rT   r   rU   rV   NFTFF)rW   r   rX   r   r    r   r!   r   rY   r   r"   r#   rZ   r   r[   r#   r$   r%   r&   r'   r)   r'   r*   r'   r(   r'   r   r+   rK   rL   rM   r   r   rB   r   rS      s    >rS   '  	timestepsrE   embedding_dimr   
max_periodc                 C  s   | j dkr	td|d }t| tjd|tj| jd }t|| }| dddf 	 |dddf  }tj
t|t|gdd}|d dkrTtjj|d	}|S )
at  
    Create sinusoidal timestep embeddings following the implementation in Ho et al. "Denoising Diffusion Probabilistic
    Models" https://arxiv.org/abs/2006.11239.

    Args:
        timesteps: a 1-D Tensor of N indices, one per batch element.
        embedding_dim: the dimension of the output.
        max_period: controls the minimum frequency of the embeddings.
    rT   zTimesteps should be a 1d-arrayrs   r   )startenddtypedeviceNrr   dim)r   rT   r   r   )ndim
ValueErrormathlogr9   arangefloat32r   expr#   catcossinr   
functionalpad)r   r   r   Zhalf_dimexponentfreqsargs	embeddingr   r   r   get_timestep_embedding   s   

"$r   c                      s2   e Zd ZdZ	dd fddZddddZ  ZS )DiffusionUnetDownsamplea  
    Downsampling layer.

    Args:
        spatial_dims: number of spatial dimensions.
        num_channels: number of input channels.
        use_conv: if True uses Convolution instead of Pool average to perform downsampling. In case that use_conv is
            False, the number of output channels must be the same as the number of input channels.
        out_channels: number of output channels.
        padding: controls the amount of implicit zero-paddings on both sides for padding number of points
            for each dimension.
    NrT   rW   r   r   use_convr'   ra   r%   rd   r   r+   c              	     st   t    || _|p|| _|| _|r"t|| j| jdd|dd| _d S | j| jkr,tdttj	|f ddd| _d S )Nrs   rt   Tr`   z?num_channels and out_channels must be equal when use_conv=False)rc   stride)
r7   r8   r   ra   r   r   opr   r   AVG)rA   rW   r   r   ra   rd   rB   r   r   r8     s"   


z DiffusionUnetDownsample.__init__rD   rE   embrG   c                 C  s>   ~|j d | jkrtd|j d  d| j d| |}|S )NrT   zInput number of channels (z/) is not equal to expected number of channels ())ru   r   r   r   )rA   rD   r   outputr   r   r   rJ   4  s   
zDiffusionUnetDownsample.forward)NrT   )rW   r   r   r   r   r'   ra   r%   rd   r   r   r+   rK   rD   rE   r   rG   r   rE   rM   r   r   rB   r   r     s
    r   c                      s$   e Zd ZdZd
d fdd	Z  ZS )WrappedUpsamplezS
    Wraps MONAI upsample block to allow for calling with timestep embeddings.
    NrD   rE   r   rG   r   c                   s   ~t  |}|S rK   )r7   rJ   )rA   rD   r   Z	upsampledrB   r   r   rJ   D  s   zWrappedUpsample.forwardrK   r   )rN   rO   rP   rQ   rJ   rR   r   r   rB   r   r   ?  s    r   c                      s8   e Zd ZdZ					dd fddZdddZ  ZS )DiffusionUNetResnetBlocka  
    Residual block with timestep conditioning.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding  channels.
        out_channels: number of output channels.
        up: if True, performs upsampling.
        down: if True, performs downsampling.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
    NFrU   rV   rW   r   rX   temb_channelsra   r%   upr'   downrZ   r[   r#   r   r+   c	           	   
     s*  t    || _|| _|| _|p|| _|| _|| _tj	|||dd| _
t | _t||| jddddd| _d  | _| _| jrKt|d||ddd d	| _n
|rUt||d
d| _t|| j| _tj	|| j|dd| _tt|| j| jddddd| _|  | j|krt | _d S t||| jddddd| _d S )NTr\   rT   rt   r`   nontrainablenearest       @)rW   moderX   ra   interp_modescale_factoralign_cornersF)r   r   )r7   r8   rW   channelsZemb_channelsra   r   r   r   rk   r>   SiLUnonlinearityr   conv1upsample
downsampler   r   Lineartime_emb_projr?   r   conv2Identityskip_connection)	rA   rW   rX   r   ra   r   r   rZ   r[   rB   r   r   r8   Y  sp   




	
z!DiffusionUNetResnetBlock.__init__rD   rE   r   c                 C  s   |}|  |}| |}| jd ur| |}| |}n| jd ur+| |}| |}| |}| jdkrH| | |d d d d d d f }n| | |d d d d d d d f }|| }| |}| |}| |}| 	|| }|S )Nrs   )
r>   r   r   r   r   rW   r   r?   r   r   )rA   rD   r   htembr   r   r   r   rJ     s&   








&&


z DiffusionUNetResnetBlock.forward)NFFrU   rV   )rW   r   rX   r   r   r   ra   r%   r   r'   r   r'   rZ   r   r[   r#   r   r+   )rD   rE   r   rE   r   rE   rM   r   r   rB   r   r   J  s    Ir   c                      s>   e Zd ZdZ						d!d" fddZ	d#d$dd Z  ZS )%	DownBlocka  
    Unet's down block containing resnet and downsamplers blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
    rT   rU   rV   TFrW   r   rX   ra   r   num_res_blocksrZ   r[   r#   add_downsampler'   resblock_updowndownsample_paddingr   r+   c                   s   t    |	| _g }t|D ]}|dkr|n|}|t||||||d qt|| _|rL|  |	r@t||||||dd| _	d S t
||d||
d| _	d S d | _	d S )Nr   rW   rX   ra   r   rZ   r[   TrW   rX   ra   r   rZ   r[   r   rW   r   r   ra   rd   )r7   r8   r   ro   appendr   r   rn   resnetsdownsamplerr   )rA   rW   rX   ra   r   r   rZ   r[   r   r   r   r   irB   r   r   r8     sH   


zDownBlock.__init__Nhidden_statesrE   r   rF   rG   'tuple[torch.Tensor, list[torch.Tensor]]c                 C  sN   ~g }| j D ]}|||}|| q| jd ur#| ||}|| ||fS rK   )r   r   r   )rA   r   r   rF   output_statesresnetr   r   r   rJ     s   



zDownBlock.forward)rT   rU   rV   TFrT   )rW   r   rX   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r   r   r+   rK   r   rE   r   rE   rF   rG   r   r   rM   r   r   rB   r   r     s    9r   c                      sF   e Zd ZdZ										d%d& fddZ	d'd(d#d$Z  ZS ))AttnDownBlocka  
    Unet's down block containing resnet, downsamplers and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding  channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rT   rU   rV   TFrW   r   rX   ra   r   r   rZ   r[   r#   r   r'   r   r   r!   r)   r*   r(   r   r+   c                   s   t    |	| _g }g }t|D ]&}|dkr|n|}|t||||||d |t||||||||d qt|| _	t|| _
|  |rc|	rWt||||||dd| _d S t||d||
d| _d S d | _d S )Nr   r   rW   r   r!   rZ   r[   r)   r*   r(   Tr   r   )r7   r8   r   ro   r   r   r
   r   rn   
attentionsr   r   r   )rA   rW   rX   ra   r   r   rZ   r[   r   r   r   r!   r)   r*   r(   r   r   r   rB   r   r   r8   -  sd   



zAttnDownBlock.__init__Nr   rE   r   rF   rG   r   c                 C  sf   ~g }t | j| jD ]\}}|||}|| }|| q
| jd ur/| ||}|| ||fS rK   zipr   r   rx   r   r   rA   r   r   rF   r   r   attnr   r   r   rJ   w  s   


zAttnDownBlock.forward)
rT   rU   rV   TFrT   rT   TFF)rW   r   rX   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r   r!   r   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rM   r   r   rB   r   r     s    Kr   c                      sN   e Zd ZdZ														d+d, fd!d"Z	d-d.d)d*Z  ZS )/CrossAttnDownBlocka  
    Unet's down block containing resnet, downsamplers and cross-attention blocks.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_downsample: if True add downsample block.
        resblock_updown: if True use residual blocks for downsampling.
        downsample_padding: padding used in the downsampling block.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rT   rU   rV   TFNr   rW   r   rX   ra   r   r   rZ   r[   r#   r   r'   r   r   r!   transformer_num_layersr$   r%   r&   dropout_cattnr)   r*   r(   r   r+   c                   s   t    |	| _g }g }t|D ]-}|dkr|n|}|t||||||d |t|||| ||||||||||d qt|| _	t|| _
|  |rj|	r^t||||||dd| _d S t||d||
d| _d S d | _d S )Nr   r   rW   rX   r    r!   rY   rZ   r[   r$   r&   r"   r)   r*   r(   Tr   r   )r7   r8   r   ro   r   r   rS   r   rn   r   r   r   r   )rA   rW   rX   ra   r   r   rZ   r[   r   r   r   r!   r   r$   r&   r   r)   r*   r(   r   r   r   rB   r   r   r8     sn   


zCrossAttnDownBlock.__init__r   rE   r   rF   rG   r   c                 C  sh   g }t | j| jD ]\}}|||}|||d }|| q	| jd ur0| ||}|| ||fS rH   r   r   r   r   r   rJ     s   


zCrossAttnDownBlock.forward)rT   rU   rV   TFrT   rT   rT   NFr   TFF)&rW   r   rX   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rM   r   r   rB   r   r     s&    Ur   c                      s>   e Zd ZdZ						dd  fddZ	d!d"ddZ  ZS )#AttnMidBlockaZ  
    Unet's mid block containing resnet and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rU   rV   rT   TFrW   r   rX   r   rZ   r[   r#   r!   r)   r'   r*   r(   r   r+   c
           
   
     sT   t    t||||||d| _t||||||||	d| _t||||||d| _d S )Nr   r   )r7   r8   r   resnet_1r
   	attentionresnet_2)
rA   rW   rX   r   rZ   r[   r!   r)   r*   r(   rB   r   r   r8     s6   
zAttnMidBlock.__init__Nr   rE   r   rF   rG   c                 C  s,   ~|  ||}| | }| ||}|S rK   )r   r   rx   r   rA   r   r   rF   r   r   r   rJ   C  s
   zAttnMidBlock.forward)rU   rV   rT   TFF)rW   r   rX   r   r   r   rZ   r   r[   r#   r!   r   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rE   r   rE   rF   rG   r   rE   rM   r   r   rB   r   r     s    +r   c                      sF   e Zd ZdZ										d%d& fddZ	d'd(d#d$Z  ZS ))CrossAttnMidBlocka=  
    Unet's mid block containing resnet and cross-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        temb_channels: number of timestep embedding channels
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rU   rV   rT   NFr   TrW   r   rX   r   rZ   r[   r#   r!   r   r$   r%   r&   r'   r   r)   r*   r(   r   r+   c                   sb   t    t||||||d| _t|||| ||||||	|
|||d| _t||||||d| _d S )Nr   r   )r7   r8   r   r   rS   r   r   )rA   rW   rX   r   rZ   r[   r!   r   r$   r&   r   r)   r*   r(   rB   r   r   r8   b  s@   
zCrossAttnMidBlock.__init__r   rE   r   rF   rG   c                 C  s*   |  ||}| j||d}| ||}|S rH   )r   r   r   r   r   r   r   rJ     s   zCrossAttnMidBlock.forward)
rU   rV   rT   rT   NFr   TFF)rW   r   rX   r   r   r   rZ   r   r[   r#   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rM   r   r   rB   r   r   N  s    3r   c                      s<   e Zd ZdZ					d"d# fddZ	d$d%d d!Z  ZS )&UpBlocka  
    Unet's up block containing resnet and upsamplers blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
    rT   rU   rV   TFrW   r   rX   prev_output_channelra   r   r   rZ   r[   r#   add_upsampler'   r   r   r+   c                   s   t    |
| _g }t|D ]#}||d kr|n|}|dkr |n|}|t||| ||||d qt|| _|  |	rf|
rLt||||||dd| _	d S t
|||ddddd}t|d||d	d
|d d| _	d S d | _	d S )NrT   r   r   TrW   rX   ra   r   rZ   r[   r   rt   r`   r   r   r   rW   r   rX   ra   r   r   	post_convr   )r7   r8   r   ro   r   r   r   rn   r   	upsamplerr   r   )rA   rW   rX   r   ra   r   r   rZ   r[   r   r   r   r   res_skip_channelsresnet_in_channelsr   rB   r   r   r8     sb   

	
zUpBlock.__init__Nr   rE   res_hidden_states_listlist[torch.Tensor]r   rF   rG   c                 C  sX   ~| j D ]}|d }|d d }tj||gdd}|||}q| jd ur*| ||}|S Nrr   rT   r   )r   r9   r   r   )rA   r   r   r   rF   r   res_hidden_statesr   r   r   rJ     s   

zUpBlock.forward)rT   rU   rV   TF)rW   r   rX   r   r   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r+   rK   
r   rE   r   r   r   rE   rF   rG   r   rE   rM   r   r   rB   r   r     s    Kr   c                      sD   e Zd ZdZ									d&d' fddZ	d(d)d$d%Z  ZS )*AttnUpBlocka  
    Unet's up block containing resnet, upsamplers, and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rT   rU   rV   TFrW   r   rX   r   ra   r   r   rZ   r[   r#   r   r'   r   r!   r)   r*   r(   r   r+   c                   s  t    |
| _g }g }t|D ]2}||d kr|n|}|dkr"|n|}|t||| ||||d |t||||||||d qt|| _	t|| _
|  |	r}|
rct||||||dd| _d S t|||ddddd}t|d	||d
d|d d| _d S d | _d S )NrT   r   r   r   Tr   rt   r`   r   r   r   r   )r7   r8   r   ro   r   r   r
   r   rn   r   r   r   r   r   )rA   rW   rX   r   ra   r   r   rZ   r[   r   r   r!   r)   r*   r(   r   r   r   r   r   r   rB   r   r   r8   !  s~   

	
zAttnUpBlock.__init__Nr   rE   r   r   r   rF   rG   c                 C  sp   ~t | j| jD ]"\}}|d }|d d }tj||gdd}|||}|| }q| jd ur6| ||}|S r   )r   r   r   r9   r   rx   r   rA   r   r   r   rF   r   r   r   r   r   r   rJ   z  s   

zAttnUpBlock.forward)	rT   rU   rV   TFrT   TFF)rW   r   rX   r   r   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r!   r   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rM   r   r   rB   r   r     s    ^r   c                      sL   e Zd ZdZ													d,d- fd!d"Z	d.d/d*d+Z  ZS )0CrossAttnUpBlocka  
    Unet's up block containing resnet, upsamplers, and self-attention blocks.

    Args:
        spatial_dims: The number of spatial dimensions.
        in_channels: number of input channels.
        prev_output_channel: number of channels from residual connection.
        out_channels: number of output channels.
        temb_channels: number of timestep embedding channels.
        num_res_blocks: number of residual blocks.
        norm_num_groups: number of groups for the group normalization.
        norm_eps: epsilon for the group normalization.
        add_upsample: if True add downsample block.
        resblock_updown: if True use residual blocks for upsampling.
        num_head_channels: number of channels in each attention head.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rT   rU   rV   TFNr   rW   r   rX   r   ra   r   r   rZ   r[   r#   r   r'   r   r!   r   r$   r%   r&   r   r)   r*   r(   r   r+   c                   s  t    |
| _g }g }t|D ]9}||d kr|n|}|dkr"|n|}|t||| ||||d |t|||| ||||||||||d qt|| _	t|| _
|  |	r|
rjt||||||dd| _d S t|||ddddd}t|d	||d
d|d d| _d S d | _d S )NrT   r   r   )rW   rX   r    r!   rZ   r[   rY   r$   r&   r"   r)   r*   r(   Tr   rt   r`   r   r   r   r   )r7   r8   r   ro   r   r   rS   r   rn   r   r   r   r   r   )rA   rW   rX   r   ra   r   r   rZ   r[   r   r   r!   r   r$   r&   r   r)   r*   r(   r   r   r   r   r   r   rB   r   r   r8     s   

	
zCrossAttnUpBlock.__init__r   rE   r   r   r   rF   rG   c                 C  sn   t | j| jD ]"\}}|d }|d d }tj||gdd}|||}|||d}q| jd ur5| ||}|S )Nrr   rT   r   rI   )r   r   r   r9   r   r   r   r   r   r   rJ     s   

zCrossAttnUpBlock.forward)rT   rU   rV   TFrT   rT   NFr   TFF)&rW   r   rX   r   r   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r+   rK   r   rM   r   r   rB   r   r     s$     gr   Fr   TrW   rX   ra   r   r   rZ   r[   r#   r   r'   r   	with_attnwith_cross_attnr!   r   r$   r%   r&   r   r)   r*   r(   c                 C  s   |	rt | ||||||||||||dS |
rMtdi d| d|d|d|d|d|d|d	|d
|d|d|d|d|d|d|d|d|S t| ||||||||d	S )N)rW   rX   ra   r   r   rZ   r[   r   r   r!   r)   r*   r(   rW   rX   ra   r   r   rZ   r[   r   r   r!   r   r$   r&   r   r)   r*   r(   )	rW   rX   ra   r   r   rZ   r[   r   r   r   )r   r   r   )rW   rX   ra   r   r   rZ   r[   r   r   r   r   r!   r   r$   r&   r   r)   r*   r(   r   r   r   get_down_block#  s~   	
r   with_conditioningc                 C  s@   |rt | ||||||||	|
|||dS t| ||||||||d	S )N)rW   rX   r   rZ   r[   r!   r   r$   r&   r   r)   r*   r(   )	rW   rX   r   rZ   r[   r!   r)   r*   r(   )r   r   rW   rX   r   rZ   r[   r   r!   r   r$   r&   r   r)   r*   r(   r   r   r   get_mid_blockj  s6   r   r   r   c                 C  s   |
rt | |||||||||	||||dS |rQtdi d| d|d|d|d|d|d|d	|d
|d|	d|d|d|d|d|d|d|d|S t| |||||||||	d
S )N)rW   rX   r   ra   r   r   rZ   r[   r   r   r!   r)   r*   r(   rW   rX   r   ra   r   r   rZ   r[   r   r   r!   r   r$   r&   r   r)   r*   r(   )
rW   rX   r   ra   r   r   rZ   r[   r   r   r   )r   r   r   )rW   rX   r   ra   r   r   rZ   r[   r   r   r   r   r!   r   r$   r&   r   r)   r*   r(   r   r   r   get_up_block  s   	
r   c                      sd   e Zd ZdZ											
	
					d;d< fd*d+Z	
	
	
	
d=d>d5d6Zd?d@d9d:Z  ZS )Ar   ap  
    Unet network with timestep embedding and attention mechanisms for conditioning based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        dropout_cattn: if different from zero, this will be the dropout value for the cross-attention layers.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    rs   rs   rs   rs   rU   @   r   r   FFTTrU   rV   F   rT   Nr   TrW   r   rX   ra   r   Sequence[int] | intr   Sequence[int]attention_levelsSequence[bool]rZ   r[   r#   r   r'   r!   int | Sequence[int]r   r   r$   r%   num_class_embedsr&   r   r)   r*   r(   r   r+   c                    s  t    |du r|d u rtd|d ur|du rtd|dks%|dk r)tdt fdd	|D r8td
t|t|krDtdt|
trPt|
t|}
t|
t|kr\tdt|trht|t|}t|t|krttd|| _|| _	|| _
|| _|| _|
| _|| _t|||d ddddd| _|d d }tt|d |t t||| _|| _|d urt||| _tg | _|d }tt|D ]d}|}|| }|t|d k}td+i d|d|d|d|d|| d d|d| d|	d|| o| d|| o|d|
| d|d |d!|d"|d#|d$|d%|}| j| qt||d& | |||
d& |||||||d'| _tg | _ t!t"|}t!t"|}t!t"|}t!t"|
}|d }tt|D ]w}|}|| }|t#|d t|d  }|t|d k}t$d+i d|d|d(|d|d|d|| d d d|d)| d|	d|| o| d|| o|d|| d|d |d!|d"|d#|d$|d%|}| j | qsttj% |d |dd*t t&t||d |ddddd| _'d S ),NTz|DiffusionModelUNet expects dimension of the cross-attention conditioning (cross_attention_dim) when using with_conditioning.FzZDiffusionModelUNet expects with_conditioning=True when specifying the cross_attention_dim.g      ?r   z#Dropout cannot be negative or >1.0!c                 3      | ]	}|  d kV  qdS r   Nr   rf   out_channelrZ   r   r   	<genexpr>%      z.DiffusionModelUNet.__init__.<locals>.<genexpr>zMDiffusionModelUNet expects all num_channels being multiple of norm_num_groupszKDiffusionModelUNet expects num_channels being same size of attention_levelsnum_head_channels should have the same length as attention_levels. For the i levels without attention, i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.zj`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.r   rT   rt   r`   r2   rW   rX   ra   r   r   rZ   r[   r   r   r   r   r!   r   r$   r&   r   r)   r*   r(   rr   r   r   r   r\   r   )(r7   r8   r   anylen
isinstancer   r   rX   block_out_channelsra   r   r   r!   r   r   conv_inr   
Sequentialr   r   
time_embedr  	Embeddingclass_embeddingrn   down_blocksro   r   r   r   middle_block	up_blockslistreversedminr   rk   r   out) rA   rW   rX   ra   r   r   r   rZ   r[   r   r!   r   r   r$   r  r&   r   r)   r*   r(   time_embed_dimoutput_channelr   input_channelis_final_block
down_blockreversed_block_out_channelsreversed_num_res_blocksreversed_attention_levelsZreversed_num_head_channelsr   Zup_blockrB   r  r   r8     sp  


	
	

zDiffusionModelUNet.__init__rD   rE   r   rF   rG   class_labelsdown_block_additional_residualstuple[torch.Tensor] | Nonemid_block_additional_residualc                 C  sb  t || jd }|j|jd}| |}| jdur1|du r!td| |}	|	j|jd}	||	 }| |}
|durC| j	du rCtd|
g}| j
D ]}||
||d\}
}|D ]}|| qVqI|durzg }t||D ]\}}|| }||g7 }qj|}| j|
||d}
|dur|
| }
| jD ]}t|j }||d }|d| }||
|||d}
q| |
}|S )	a  
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: context tensor (N, ).
            down_block_additional_residuals: additional residual tensors for down blocks (N, C, FeatureMapsDims).
            mid_block_additional_residual: additional residual tensor for mid block (N, C, FeatureMapsDims).
        r   r   N9class_labels should be provided when num_class_embeds > 0FAmodel should have with_conditioning = True if context is providedr   r   rF   )r   r   r   rF   )r   r  tor   r  r  r   r  r  r   r  r   r   r  r  r  r   r  )rA   rD   r   rF   r#  r$  r&  t_embr   	class_embr   down_block_res_samplesdownsample_blockres_samplesr~   Znew_down_block_res_samplesdown_block_res_sampleZdown_block_additional_residualZupsample_blockidxr   r   r   r   rJ     sH   






zDiffusionModelUNet.forwardold_state_dictdictc                   sD  |    t fdd|D rtd | | dS |rB D ]}||vr,td| d qtd |D ]}| vrAtd| d q3 D ]}||v rQ|| |< qDd	d
  D }|D ]x}|| d | d< || d | d< || d | d< || d | d< || d | d< || d | d< | d v r| d v r|| d | d< || d | d< q[dd
  D }|D ]}|| d | d< || d | d< q݈ D ]}d |v r|d d!}|| |< q|rtd"|  |   dS )#z
        Load a state dict from a DiffusionModelUNet trained with
        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old DecoderOnlyTransformer  model.
        c                 3  s    | ]}| v V  qd S rK   r   rf   knew_state_dictr   r   r    s    z9DiffusionModelUNet.load_old_state_dict.<locals>.<genexpr>z#All keys match, loading state dict.Nzkey z not found in old state dictz.----------------------------------------------z not found in new state dictc                 S  s    g | ]}d |v r| ddqS )zattn.to_k.weight.attn.to_k.weight replacer5  r   r   r   rj   3  s     z:DiffusionModelUNet.load_old_state_dict.<locals>.<listcomp>z.to_q.weightz.attn.to_q.weightz.to_k.weightr9  z.to_v.weightz.attn.to_v.weightz
.to_q.biasz.attn.to_q.biasz
.to_k.biasz.attn.to_k.biasz
.to_v.biasz.attn.to_v.biasz.attn.out_proj.weightz.attn.out_proj.biasz.proj_attn.weightz.proj_attn.biasc                 S  s(   g | ]}d |v rd|v r| ddqS )zout_proj.weightrp   .out_proj.weightr:  r;  r5  r   r   r   rj   A  s
    
z.to_out.0.weightr=  z.to_out.0.biasz.out_proj.biaspostconvconvz!remaining keys in old_state_dict:)
state_dictallprintload_state_dictpopr<  keys)rA   r3  verboser6  attention_blocksr   Zcross_attention_blocksold_namer   r7  r   load_old_state_dict  sZ   	

z&DiffusionModelUNet.load_old_state_dict)r   r   r   rU   rV   Fr   FrT   NNFr   TFF)(rW   r   rX   r   ra   r   r   r   r   r   r   r   rZ   r   r[   r#   r   r'   r!   r  r   r'   r   r   r$   r%   r  r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r+   )NNNN)rD   rE   r   rE   rF   rG   r#  rG   r$  r%  r&  rG   r   rE   )F)r3  r4  r   r+   )rN   rO   rP   rQ   r8   rJ   rI  rR   r   r   rB   r   r     s4    " IMc                      sT   e Zd ZdZ											
						d4d5 fd*d+Z		d6d7d2d3Z  ZS )8DiffusionModelEncoderaU  
    Classification Network based on the Encoder of the Diffusion Model, followed by fully connected layers. This network is based on
    Wolleb et al. "Diffusion Models for Medical Anomaly Detection" (https://arxiv.org/abs/2203.04306).

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        out_channels: number of output channels.
        input_shape: spatial shape of the input (without batch and channel dims).
        num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
        upcast_attention: if True, upcast attention operations to full precision.
    r   r   r   r   r   rU   rV   Fr   rT   NTrW   r   rX   ra   input_shaper   r   r   r   r   r   rZ   r[   r#   r   r'   r!   r  r   r   r$   r%   r  r&   r)   r*   r(   r   r+   c              	     s  t    |du r|d u rtd|d ur|du rtdt fdd|D r,tdt|t|kr8tdt|trDt|t|}t|trPt|t|}t|t|kr\td	|| _|| _	|| _
|| _|| _|| _|| _t|||d
 ddddd| _|d
 d }tt|d
 |t t||| _|| _|d urt||| _tg | _|d
 }tt|D ]]}|}|| }|t|k}td(i d|d|d|d|d|| d d|	d| d|
d|| o| d|| o|d|| d|d|d|d|d|d |}| j| q|D ]
}d!d" |D }qttd#d$ ||d%  }tt|d&t t d'td&| j
| _!d S ))NTzDiffusionModelEncoder expects dimension of the cross-attention conditioning (cross_attention_dim) when using with_conditioning.Fz]DiffusionModelEncoder expects with_conditioning=True when specifying the cross_attention_dim.c                 3  r  r  r   r  r  r   r   r    r	  z1DiffusionModelEncoder.__init__.<locals>.<genexpr>zPDiffusionModelEncoder expects all num_channels being multiple of norm_num_groupszNDiffusionModelEncoder expects num_channels being same size of attention_levelsr
  r   rT   rt   r`   r2   rW   rX   ra   r   r   rZ   r[   r   r   r   r   r!   r   r$   r&   r)   r*   r(   c                 S  s   g | ]}t t|d  qS )rs   )r   npceil)rf   i_r   r   r   rj     s    z2DiffusionModelEncoder.__init__.<locals>.<listcomp>c                 S  s   | | S rK   r   )rD   yr   r   r   <lambda>  s    z0DiffusionModelEncoder.__init__.<locals>.<lambda>rr   i   g?r   )"r7   r8   r   r  r  r  r   r   rX   r  ra   r   r   r!   r   r   r  r   r  r   r   r  r  r  r  rn   r  ro   r   r   r   ReLUDropoutr  )rA   rW   rX   ra   rL  r   r   r   rZ   r[   r   r!   r   r   r$   r  r&   r)   r*   r(   r  r  r   r  r  r  rg   Zlast_dim_flattenedrB   r  r   r8   m  s   


	

$
zDiffusionModelEncoder.__init__rD   rE   r   rF   rG   r#  c                 C  s   t || jd }|j|jd}| |}| jdur1|du r!td| |}|j|jd}|| }| |}|durC| j	du rCtd| j
D ]}	|	|||d\}}
qF||jd d}| |}|S )	z
        Args:
            x: input tensor (N, C, SpatialDims).
            timesteps: timestep tensor (N,).
            context: context tensor (N, 1, ContextDim).
            class_labels: context tensor (N, ).
        r   r'  Nr(  Fr)  r*  rr   )r   r  r+  r   r  r  r   r  r  r   r  rw   ru   r  )rA   rD   r   rF   r#  r,  r   r-  r   r/  rg   r   r   r   r   rJ     s"   





zDiffusionModelEncoder.forward)rK  r   r   r   rU   rV   Fr   FrT   NNFTFF)(rW   r   rX   r   ra   r   rL  r   r   r   r   r   r   r   rZ   r   r[   r#   r   r'   r!   r  r   r'   r   r   r$   r%   r  r%   r&   r'   r)   r'   r*   r'   r(   r'   r   r+   )NN)
rD   rE   r   rE   rF   rG   r#  rG   r   rE   rM   r   r   rB   r   rJ  U  s,    }rJ  )r   r   r   r   )r   )r   rE   r   r   r   r   r   rE   )Fr   TFF)(rW   r   rX   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r'   r   r'   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r   )rW   r   rX   r   r   r   rZ   r   r[   r#   r   r'   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r   )*rW   r   rX   r   r   r   ra   r   r   r   r   r   rZ   r   r[   r#   r   r'   r   r'   r   r'   r   r'   r!   r   r   r   r$   r%   r&   r'   r   r#   r)   r'   r*   r'   r(   r'   r   r   )/
__future__r   r   collections.abcr   	functoolsr   numpyrM  r9   r   monai.networks.blocksr   r   r   r	   r
   r   monai.networks.layers.factoriesr   monai.utilsr   r   r   rg   __all__r   Moduler   rS   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rJ  r   r   r   r   <module>   sf    
	Et/sZrFPm  "Q>K  t