o
    iJ                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZmZmZ d dlmZ d dlmZ d	gZG d
d	 d	ejZdS )    )annotations)SequenceN)nn)Convolution)get_down_blockget_mid_blockget_timestep_embeddingget_up_blockzero_module)ensure_tuple_rep)convert_to_tensorDiffusionModelUNetMaisic                      s   e Zd ZdZ											
	
								dFdG fd,d-Zd.d/ Zd0d1 Zd2d3 Zd4d5 Zd6d7 Z		
	
	
	
	
	
	
dHdIdDdEZ
  ZS )Jr   a,  
    U-Net network with timestep embedding and attention mechanisms for conditioning based on
    Rombach et al. "High-Resolution Image Synthesis with Latent Diffusion Models" https://arxiv.org/abs/2112.10752
    and Pinaya et al. "Brain Imaging Generation with Latent Diffusion Models" https://arxiv.org/abs/2209.07162

    Args:
        spatial_dims: Number of spatial dimensions.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        num_res_blocks: Number of residual blocks (see ResnetBlock) per level. Can be a single integer or a sequence of integers.
        num_channels: Tuple of block output channels.
        attention_levels: List of levels to add attention.
        norm_num_groups: Number of groups for the normalization.
        norm_eps: Epsilon for the normalization.
        resblock_updown: If True, use residual blocks for up/downsampling.
        num_head_channels: Number of channels in each attention head. Can be a single integer or a sequence of integers.
        with_conditioning: If True, add spatial transformers to perform conditioning.
        transformer_num_layers: Number of layers of Transformer blocks to use.
        cross_attention_dim: Number of context dimensions to use.
        num_class_embeds: If specified (as an int), then this model will be class-conditional with `num_class_embeds` classes.
        upcast_attention: If True, upcast attention operations to full precision.
        include_fc: whether to include the final linear layer. Default to False.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
        dropout_cattn: If different from zero, this will be the dropout value for the cross-attention layers.
        include_top_region_index_input: If True, use top region index input.
        include_bottom_region_index_input: If True, use bottom region index input.
        include_spacing_input: If True, use spacing input.
       r   r   r       @   r   r   FFTTr   ư>F      N        spatial_dimsintin_channelsout_channelsnum_res_blocksSequence[int] | intnum_channelsSequence[int]attention_levelsSequence[bool]norm_num_groupsnorm_epsfloatresblock_updownboolnum_head_channelsint | Sequence[int]with_conditioningtransformer_num_layerscross_attention_dim
int | Nonenum_class_embedsupcast_attention
include_fcuse_combined_linearuse_flash_attentiondropout_cattninclude_top_region_index_input!include_bottom_region_index_inputinclude_spacing_inputreturnNonec           $        s  t    |du r|d u rtd|d ur|du rtd|dks%|dk r)tdt fdd	|D r>td
| d  t|t|krTtdt| dt| t|
tr`t|
t|}
t|
t|krltdt|trxt|t|}t|t|krtd|du rtj	
 std|| _|| _|| _|| _|| _|
| _|| _t|||d ddddd| _|d d }| |d || _|| _|d urt||| _|| _|| _|| _|}| jr| d|| _||7 }| jr| d|| _||7 }| jr| d|| _||7 }t g | _!|d }t"t|D ]e}|}|| }|t|d k}t#d.i d|d|d|d|d|| d d|d| d|	d|| oP| d || oX|d!|
| d"|d#|d$|d%|d&|d'|d(|}| j!$| qt%||d) | |||
d) |||||||d*| _&t g | _'t(t)|}t(t)|}t(t)|} t(t)|
}!|d }t"t|D ]w}|}"|| }|t*|d t|d  }|t|d k}t+d.i d|d|d+|"d|d|d|| d d d|d,| d|	d| | o| d | | o|d!|!| d"|d#|d$|d%|d&|d'|d(|}#| j'$|# qt,tj- |d |dd-t. t/t||d |ddddd| _0d S )/NTzDiffusionModelUNetMaisi expects dimension of the cross-attention conditioning (cross_attention_dim) when using with_conditioning.Fz_DiffusionModelUNetMaisi expects with_conditioning=True when specifying the cross_attention_dim.g      ?r   z#Dropout cannot be negative or >1.0!c                 3  s    | ]	}|  d kV  qdS )r   N ).0out_channelr"   r8   /home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/generation/maisi/networks/diffusion_model_unet_maisi.py	<genexpr>{   s    z3DiffusionModelUNetMaisi.__init__.<locals>.<genexpr>zjDiffusionModelUNetMaisi expects all num_channels being multiple of norm_num_groups, but get num_channels: z and norm_num_groups: zhDiffusionModelUNetMaisi expects num_channels being same size of attention_levels, but get num_channels: z and attention_levels: znum_head_channels should have the same length as attention_levels. For the i levels without attention, i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.zj`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`.zatorch.cuda.is_available() should be True but is False. Flash attention is only available for GPU.r   r      )r   r   r   strideskernel_sizepadding	conv_only   r   r   r   temb_channelsr   r"   r#   add_downsampler%   	with_attnwith_cross_attnr'   r*   r+   r.   r/   r0   r1   r2   )r   r   rD   r"   r#   r)   r'   r*   r+   r.   r/   r0   r1   r2   prev_output_channeladd_upsample)
num_groupsr   epsaffiner8   )1super__init__
ValueErroranylen
isinstancer   r   torchcudais_availabler   block_out_channelsr   r   r    r'   r)   r   conv_in_create_embedding_module
time_embedr-   r   	Embeddingclass_embeddingr3   r4   r5   top_region_index_layerbottom_region_index_layerspacing_layer
ModuleListdown_blocksranger   appendr   middle_block	up_blockslistreversedminr	   
Sequential	GroupNormSiLUr
   out)$selfr   r   r   r   r   r    r"   r#   r%   r'   r)   r*   r+   r-   r.   r/   r0   r1   r2   r3   r4   r5   time_embed_dimZnew_time_embed_dimoutput_channeliinput_channelis_final_block
down_blockreversed_block_out_channelsreversed_num_res_blocksreversed_attention_levelsreversed_num_head_channelsrI   up_block	__class__r;   r<   rO   T   s  


	
	

z DiffusionModelUNetMaisi.__init__c                 C  s&   t t ||t  t ||}|S )N)r   ri   Linearrk   )rm   	input_dim	embed_dimmodelr8   r8   r<   rY   2  s   "z0DiffusionModelUNetMaisi._create_embedding_modulec                 C  sf   t || jd }|j|jd}| |}| jd ur1|d u r!td| |}|j|jd}||7 }|S )Nr   )dtypez9class_labels should be provided when num_class_embeds > 0)r   rW   tor   rZ   r-   rP   r\   )rm   x	timestepsclass_labelst_embemb	class_embr8   r8   r<   _get_time_and_class_embedding6  s   


z5DiffusionModelUNetMaisi._get_time_and_class_embeddingc                 C  sj   | j r| |}tj||fdd}| jr"| |}tj||fdd}| jr3| |}tj||fdd}|S )Nr   )dim)r3   r]   rT   catr4   r^   r5   r_   )rm   r   Z	top_indexbottom_indexspacingZ_embr8   r8   r<   _get_input_embeddingsG  s   


z-DiffusionModelUNetMaisi._get_input_embeddingsc                 C  s   |d ur| j du rtd|g}| jD ]}||||d\}}|| q|d ur?g }t||D ]\}	}
|	|
7 }	||	 q/|}||fS )NFzAmodel should have with_conditioning = True if context is provided)hidden_statestembcontext)r)   rP   ra   extendziprc   )rm   hr   r   down_block_additional_residualsdown_block_res_samplesdownsample_blockres_samplesnew_down_block_res_samplesdown_block_res_sampledown_block_additional_residualr8   r8   r<   _apply_down_blocksS  s   
z*DiffusionModelUNetMaisi._apply_down_blocksc                 C  sD   | j D ]}t|j }||d  }|d | }|||||d}q|S )N)r   res_hidden_states_listr   r   )re   rR   resnets)rm   r   r   r   r   upsample_blockidxr   r8   r8   r<   _apply_up_blocksg  s   
z(DiffusionModelUNetMaisi._apply_up_blocksr   torch.Tensorr   r   torch.Tensor | Noner   r   tuple[torch.Tensor] | Nonemid_block_additional_residualtop_region_index_tensorbottom_region_index_tensorspacing_tensorc
                 C  s   |  |||}
| |
|||	}
| |}| ||
||\}}| ||
|}|dur-||7 }| ||
||}| |}t|}|S )a{  
        Forward pass through the UNet model.

        Args:
            x: Input tensor of shape (N, C, SpatialDims).
            timesteps: Timestep tensor of shape (N,).
            context: Context tensor of shape (N, 1, ContextDim).
            class_labels: Class labels tensor of shape (N,).
            down_block_additional_residuals: Additional residual tensors for down blocks of shape (N, C, FeatureMapsDims).
            mid_block_additional_residual: Additional residual tensor for mid block of shape (N, C, FeatureMapsDims).
            top_region_index_tensor: Tensor representing top region index of shape (N, 4).
            bottom_region_index_tensor: Tensor representing bottom region index of shape (N, 4).
            spacing_tensor: Tensor representing spacing of shape (N, 3).

        Returns:
            A tensor representing the output of the UNet model.
        N)r   r   rX   r   rd   r   rl   r   )rm   r   r   r   r   r   r   r   r   r   r   r   Z_updated_down_block_res_samplesZh_tensorr8   r8   r<   forwardp  s   

zDiffusionModelUNetMaisi.forward)r   r   r   r   r   Fr   Fr   NNFFFFr   FFF).r   r   r   r   r   r   r   r   r   r   r    r!   r"   r   r#   r$   r%   r&   r'   r(   r)   r&   r*   r   r+   r,   r-   r,   r.   r&   r/   r&   r0   r&   r1   r&   r2   r$   r3   r&   r4   r&   r5   r&   r6   r7   )NNNNNNN)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r6   r   )__name__
__module____qualname____doc__rO   rY   r   r   r   r   r   __classcell__r8   r8   ry   r<   r   5   sH    # _)
__future__r   collections.abcr   rT   r   monai.networks.blocksr   (monai.networks.nets.diffusion_model_unetr   r   r   r	   r
   monai.utilsr   monai.utils.type_conversionr   __all__Moduler   r8   r8   r8   r<   <module>   s   