o
    i                     @  sL   d dl mZ d dlmZ d dlZd dlmZ d dlmZ G dd deZ	dS )    )annotations)SequenceN)
ControlNet)get_timestep_embeddingc                      s   e Zd ZdZ											
	
								dCdD fd+d,Z	-	
	
dEdFd7d8Zd9d: Zd;d< Zd=d> Zd?d@ Z	dAdB Z
  ZS )GControlNetMaisia  
    Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
    Diffusion Models" (https://arxiv.org/abs/2302.05543)

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
        use_checkpointing: if True, use activation checkpointing to save memory.
        include_fc: whether to include the final linear layer. Default to False.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
       r   r   r       @   r   r   FFTTr
   ư>F      N   r
   `      Tspatial_dimsintin_channelsnum_res_blocksSequence[int] | intnum_channelsSequence[int]attention_levelsSequence[bool]norm_num_groupsnorm_epsfloatresblock_updownboolnum_head_channelsint | Sequence[int]with_conditioningtransformer_num_layerscross_attention_dim
int | Nonenum_class_embedsupcast_attention"conditioning_embedding_in_channels#conditioning_embedding_num_channelsuse_checkpointing
include_fcuse_combined_linearuse_flash_attentionreturnNonec                   s:   t  |||||||||	|
||||||||| || _d S N)super__init__r,   )selfr   r   r   r   r   r   r   r    r"   r$   r%   r&   r(   r)   r*   r+   r,   r-   r.   r/   	__class__ w/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/generation/maisi/networks/controlnet_maisi.pyr4   3   s,   
zControlNetMaisi.__init__      ?xtorch.Tensor	timestepscontrolnet_condconditioning_scalecontexttorch.Tensor | Noneclass_labels'tuple[list[torch.Tensor], torch.Tensor]c                   s   |  |||}| |}| jrtjjj| j|dd}n| |}||7 }| |||\}	}| |||}| 	||	\}	}
 fdd|	D }	|
 9 }
|	|
fS )NF)use_reentrantc                   s   g | ]}|  qS r8   r8   ).0hr?   r8   r9   
<listcomp>w   s    z+ControlNetMaisi.forward.<locals>.<listcomp>)
!_prepare_time_and_class_embedding_apply_initial_convolutionr,   torchutils
checkpointcontrolnet_cond_embedding_apply_down_blocks_apply_mid_block_apply_controlnet_blocks)r5   r;   r=   r>   r?   r@   rB   embrF   down_block_res_samplesmid_block_res_sampler8   rG   r9   forwarda   s   	

zControlNetMaisi.forwardc                 C  sf   t || jd }|j|jd}| |}| jd ur1|d u r!td| |}|j|jd}|| }|S )Nr   )dtypez9class_labels should be provided when num_class_embeds > 0)r   block_out_channelstorV   
time_embedr(   
ValueErrorclass_embedding)r5   r;   r=   rB   t_embrR   	class_embr8   r8   r9   rI   |   s   


z1ControlNetMaisi._prepare_time_and_class_embeddingc                 C  s   |  |}|S r2   )conv_in)r5   r;   rF   r8   r8   r9   rJ      s   
z*ControlNetMaisi._apply_initial_convolutionc                 C  sZ   |d ur| j du rtd|g}| jD ]}||||d\}}|D ]}|| q q||fS )NFzAmodel should have with_conditioning = True if context is providedhidden_statestembr@   )r$   rZ   down_blocksappend)r5   rR   r@   rF   rS   downsample_blockres_samplesresidualr8   r8   r9   rO      s   
z"ControlNetMaisi._apply_down_blocksc                 C  s   | j |||d}|S )Nr_   )middle_block)r5   rR   r@   rF   r8   r8   r9   rP      s   z ControlNetMaisi._apply_mid_blockc                 C  s>   g }t || jD ]\}}||}|| q| |}||fS r2   )zipcontrolnet_down_blocksrc   controlnet_mid_block)r5   rF   rS   !controlnet_down_block_res_samplesdown_block_res_samplecontrolnet_blockrT   r8   r8   r9   rQ      s   
z(ControlNetMaisi._apply_controlnet_blocks)r   r	   r   r
   r   Fr   Fr   NNFr   r   TFFF)*r   r   r   r   r   r   r   r   r   r   r   r   r   r   r    r!   r"   r#   r$   r!   r%   r   r&   r'   r(   r'   r)   r!   r*   r   r+   r   r,   r!   r-   r!   r.   r!   r/   r!   r0   r1   )r:   NN)r;   r<   r=   r<   r>   r<   r?   r   r@   rA   rB   rA   r0   rC   )__name__
__module____qualname____doc__r4   rU   rI   rJ   rO   rP   rQ   __classcell__r8   r8   r6   r9   r      s<     3r   )

__future__r   collections.abcr   rK   Zmonai.networks.nets.controlnetr   (monai.networks.nets.diffusion_model_unetr   r   r8   r8   r8   r9   <module>   s   