o
    -iI                     @  s   d dl mZ d dlmZ d dlZd dlmZ d dlmZ d dlm	Z	m
Z
mZ d dlmZ G dd	 d	ejZd
d ZG dd dejZdS )    )annotations)SequenceN)nn)Convolution)get_down_blockget_mid_blockget_timestep_embedding)ensure_tuple_repc                      s*   e Zd ZdZd fdd	Zd
d Z  ZS )ControlNetConditioningEmbeddingzA
    Network to encode the conditioning into a latent space.
    spatial_dimsintin_channelsout_channelschannelsSequence[int]c                   s   t    t|||d dddddd| _tg | _tt|d D ],}|| }||d  }| j	t|||dddddd | j	t|||dddddd q"t
t||d |dddd	d
| _d S )Nr         AZSWISH)r   r   r   strideskernel_sizepaddingadn_orderingact   Tr   r   r   r   r   r   	conv_only)super__init__r   conv_inr   
ModuleListblocksrangelenappendzero_moduleconv_out)selfr   r   r   r   iZ
channel_inZchannel_out	__class__ `/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/controlnet.pyr   1   sd   

z(ControlNetConditioningEmbedding.__init__c                 C  s,   |  |}| jD ]}||}q| |}|S N)r   r!   r&   )r'   Zconditioning	embeddingblockr+   r+   r,   forwardj   s
   



z'ControlNetConditioningEmbedding.forward)r   r   r   r   r   r   r   r   )__name__
__module____qualname____doc__r   r0   __classcell__r+   r+   r)   r,   r
   ,   s    9r
   c                 C  s   |   D ]}tj| q| S r-   )
parametersr   initzeros_)modulepr+   r+   r,   r%   u   s   r%   c                      sd   e Zd ZdZ											
	
							d<d= fd*d+Z	,	
	
d>d?d6d7Zd@dAd:d;Z  ZS )B
ControlNeta  
    Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
    Diffusion Models" (https://arxiv.org/abs/2302.05543)

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
        include_fc: whether to include the final linear layer. Default to True.
        use_combined_linear: whether to use a single linear layer for qkv projection, default to True.
        use_flash_attention: if True, use Pytorch's inbuilt flash attention for a memory efficient attention mechanism
            (see https://pytorch.org/docs/2.2/generated/torch.nn.functional.scaled_dot_product_attention.html).
    r   r   r   r       @   r?   r?   FFTTr>   ư>F   r   N   r>   `      Tr   r   r   num_res_blocksSequence[int] | intr   r   attention_levelsSequence[bool]norm_num_groupsnorm_epsfloatresblock_updownboolnum_head_channelsint | Sequence[int]with_conditioningtransformer_num_layerscross_attention_dim
int | Nonenum_class_embedsupcast_attention"conditioning_embedding_in_channels#conditioning_embedding_num_channels
include_fcuse_combined_linearuse_flash_attentionreturnNonec                   s  t    |
du r|d u rtd|d ur|
du rtdt fdd|D r2td| d  t|t|krDtd	| d
| t|	trPt|	t|}	t|	t|krctd| d
| dt|trot|t|}t|t|krtd| d| d|| _|| _	|| _
|| _|	| _|
| _t|||d ddddd| _|d d }tt|d |t t||| _|| _|d urt||| _t||||d d| _tg | _tg | _|d }t|||ddddd}t|j}| j| t t|D ]}|}|| }|t|d k}t!d*i d|d|d|d|d|| d d|d| d|d|| o:|
 d || oB|
d!|	| d"|d#|d$|d%|d&|d'|}| j| t || D ]}t|||ddddd}t|}| j| qh|st|||ddddd}t|}| j| q|d( }t"||| ||
|	d( ||||||d)| _#t|||ddddd}t|}|| _$d S )+NTzControlNet expects dimension of the cross-attention conditioning (cross_attention_dim) to be specified when with_conditioning=True.FzRControlNet expects with_conditioning=True when specifying the cross_attention_dim.c                 3  s    | ]	}|  d kV  qdS )r   Nr+   ).0out_channelrK   r+   r,   	<genexpr>   s    z&ControlNet.__init__.<locals>.<genexpr>zVControlNet expects all channels to be a multiple of norm_num_groups, but got channels=z and norm_num_groups=zZControlNet expects channels to have the same length as attention_levels, but got channels=z and attention_levels=zTnum_head_channels should have the same length as attention_levels, but got channels=zq . For the i levels without attention, i.e. `attention_level[i]=False`, the num_head_channels[i] will be ignored.z`num_res_blocks` should be a single integer or a tuple of integers with the same length as `num_channels`, but got num_res_blocks=z and channels=.r   r   r   r      )r   r   r   r   r   r   r   temb_channelsrG   rK   rL   Zadd_downsamplerN   Z	with_attnZwith_cross_attnrP   rS   rT   rW   rZ   r[   r\   r   )r   r   re   rK   rL   rR   rP   rS   rT   rW   rZ   r[   r\   r+   )%r   r   
ValueErroranyr#   
isinstancer   r	   r   block_out_channelsrG   rI   rP   rR   r   r   r   
SequentialLinearSiLU
time_embedrV   	Embeddingclass_embeddingr
   controlnet_cond_embeddingr    down_blockscontrolnet_down_blocksr%   convr$   r"   r   r   middle_blockcontrolnet_mid_block)r'   r   r   rG   r   rI   rK   rL   rN   rP   rR   rS   rT   rV   rW   rX   rY   rZ   r[   r\   Ztime_embed_dimoutput_channelcontrolnet_blockr(   input_channelis_final_blockZ
down_block_Zmid_block_channelr)   ra   r,   r      sf  



		
			
zControlNet.__init__      ?xtorch.Tensor	timestepscontrolnet_condconditioning_scalecontexttorch.Tensor | Noneclass_labels'tuple[list[torch.Tensor], torch.Tensor]c                   s<  t || jd }|j|jd}| |}| jdur1|du r!td| |}	|	j|jd}	||	 }| |}
| 	|}|
|7 }
|durL| j
du rLtd|
g}| jD ]}||
||d\}
}|D ]}|| q_qR| j|
||d}
g }t|| jD ]\}}||}|| qx|}| |
} fdd	|D }| 9 }||fS )
a  
        Args:
            x: input tensor (N, C, H, W, [D]).
            timesteps: timestep tensor (N,).
            controlnet_cond: controlnet conditioning tensor (N, C, H, W, [D])
            conditioning_scale: conditioning scale.
            context: context tensor (N, 1, cross_attention_dim), where cross_attention_dim is specified in the model init.
            class_labels: context tensor (N, ).
        r   )dtypeNz9class_labels should be provided when num_class_embeds > 0FzAmodel should have with_conditioning = True if context is provided)Zhidden_statesZtembr   c                   s   g | ]}|  qS r+   r+   )r_   hr   r+   r,   
<listcomp>  s    z&ControlNet.forward.<locals>.<listcomp>)r   ri   tor   rm   rV   rf   ro   r   rp   rR   rq   r$   rt   ziprr   ru   )r'   r|   r~   r   r   r   r   Zt_embZembZ	class_embr   Zdown_block_res_samplesZdownsample_blockZres_samplesresidualZ!controlnet_down_block_res_samplesZdown_block_res_samplerw   Zmid_block_res_sampler+   r   r,   r0   b  s<   






zControlNet.forwardold_state_dictdictc                   s  |    t fdd|D rtd | | dS |rB D ]}||vr,td| d qtd |D ]}| vrAtd| d q3 D ]}||v rQ|| |< qDd	d
  D }|D ]}|| d | d< || d | d< q[|rtd|  |   dS )z
        Load a state dict from a ControlNet trained with
        [MONAI Generative](https://github.com/Project-MONAI/GenerativeModels).

        Args:
            old_state_dict: state dict from the old ControlNet model.
        c                 3  s    | ]}| v V  qd S r-   r+   r_   knew_state_dictr+   r,   rb     s    z1ControlNet.load_old_state_dict.<locals>.<genexpr>z#All keys match, loading state dict.Nzkey z not found in old state dictz.----------------------------------------------z not found in new state dictc                 S  s    g | ]}d |v r| ddqS )zout_proj.weight.out_proj.weight )replacer   r+   r+   r,   r     s     z2ControlNet.load_old_state_dict.<locals>.<listcomp>z.to_out.0.weightr   z.to_out.0.biasz.out_proj.biasz!remaining keys in old_state_dict:)
state_dictallprintload_state_dictpopkeys)r'   r   verboser   attention_blocksr/   r+   r   r,   load_old_state_dict  s4   	
zControlNet.load_old_state_dict)r<   r=   r@   r>   rA   FrB   Fr   NNFr   rC   TFF)(r   r   r   r   rG   rH   r   r   rI   rJ   rK   r   rL   rM   rN   rO   rP   rQ   rR   rO   rS   r   rT   rU   rV   rU   rW   rO   rX   r   rY   r   rZ   rO   r[   rO   r\   rO   r]   r^   )r{   NN)r|   r}   r~   r}   r   r}   r   rM   r   r   r   r   r]   r   )F)r   r   r]   r^   )r1   r2   r3   r4   r   r0   r   r5   r+   r+   r)   r,   r;   {   s4      PGr;   )
__future__r   collections.abcr   torchr   monai.networks.blocksr   Z(monai.networks.nets.diffusion_model_unetr   r   r   monai.utilsr	   Moduler
   r%   r;   r+   r+   r+   r,   <module>   s   I