o
    ,i                     @  s  d dl mZ d dlZd dlZd dlmZ d dlZd dlZd dl	m
Z
 d dlm
  mZ d dlmZmZmZmZ d dlmZ d dlmZmZ d dlmZ edd	d
\ZZeddd
\ZZg dZejjG dd dej
j Z!ejjG dd dej
j Z"G dd de"Z#dd Z$G dd de
j%Z&G dd deZ'G dd deZ(G dd deZ)G dd deZ*G d d! d!e
j Z+G d"d# d#e!Z,G d$d% d%e
j Z-G d&d' d'e
j Z.G d(d) d)e.Z/G d*d+ d+e.Z0dS ),    )annotationsN)Optional)ActiConvNormBlockFactorizedIncreaseBlockFactorizedReduceBlockP3DActiConvNormBlock)Conv)get_act_layerget_norm_layer)optional_importzscipy.sparse
csr_matrixnamezscipy.sparse.csgraphdijkstra)DiNTSTopologyConstructionTopologyInstanceTopologySearchc                   @  s   e Zd ZdZd
ddZd	S )CellInterfacez"interface for torchscriptable Cellxtorch.TensorweightOptional[torch.Tensor]returnc                 C     d S N selfr   r   r   r   [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/nets/dints.pyforward,      zCellInterface.forwardNr   r   r   r   r   r   __name__
__module____qualname____doc__r    r   r   r   r   r   (       r   c                   @  s   e Zd ZdZdddZdS )	StemInterfacez"interface for torchscriptable Stemr   r   r   c                 C  r   r   r   r   r   r   r   r   r    4   r!   zStemInterface.forwardNr   r   r   r   r#   r   r   r   r   r)   0   r(   r)   c                      s*   e Zd ZdZ fddZd	ddZ  ZS )
StemTSz wrapper for torchscriptable Stemc                   s   t    tjj| | _d S r   )super__init__torchnn
Sequentialmod)r   r2   	__class__r   r   r.   ;   s   
zStemTS.__init__r   r   r   c                 C  s
   |  |S r   )r2   r*   r   r   r   r    ?   s   
zStemTS.forwardr+   r$   r%   r&   r'   r.   r    __classcell__r   r   r3   r   r,   8   s    r,   c                 C  s>   | |kr
dgdggS t | d |}dd |D dd |D  S )z>use depth first search to find all path activation combinationr      c                 S     g | ]}d g| qS r   r   .0_r   r   r   
<listcomp>H       z_dfs.<locals>.<listcomp>c                 S  r8   )r7   r   r:   r   r   r   r=   H   r>   )_dfs)nodepathschildr   r   r   r?   C   s   r?   c                      s   e Zd Z fddZ  ZS )_IdentityWithRAMCostc                   s   t  j|i | d| _d S Nr   r-   r.   ram_cost)r   argskwargsr3   r   r   r.   M   s   
z_IdentityWithRAMCost.__init__r$   r%   r&   r.   r6   r   r   r3   r   rC   K   s    rC   c                      s2   e Zd ZdZdddddiffd fddZ  ZS )_ActiConvNormBlockWithRAMCosta!  The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated.
    Here is the estimation:
     feature_size = output_size/out_channel
     total_ram = ram_cost * output_size
     total_ram = in_channel * feature_size (activation map) +
                 in_channel * feature_size (convolution map) +
                 out_channel * feature_size (normalization)
               = (2*in_channel + out_channel) * output_size/out_channel
     ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1
       RELUINSTANCEaffineT
in_channelintout_channelkernel_sizepaddingspatial_dimsact_nametuple | str	norm_namec              	     s.   t  ||||||| d|| d  | _d S )Nr7      rE   )r   rO   rQ   rR   rS   rT   rU   rW   r3   r   r   r.   ^   s   
z&_ActiConvNormBlockWithRAMCost.__init__)rO   rP   rQ   rP   rR   rP   rS   rP   rT   rP   rU   rV   rW   rV   )r$   r%   r&   r'   r.   r6   r   r   r3   r   rJ   R   s    
rJ   c                      s.   e Zd Zdddddiffd fddZ  ZS ) _P3DActiConvNormBlockWithRAMCostr   rL   rM   rN   TrO   rP   rQ   rR   rS   p3dmoderU   rV   rW   c              	     s.   t  ||||||| dd| |  | _d S NrX   rE   )r   rO   rQ   rR   rS   rZ   rU   rW   r3   r   r   r.   n   s   
z)_P3DActiConvNormBlockWithRAMCost.__init__)rO   rP   rQ   rP   rR   rP   rS   rP   rZ   rP   rU   rV   rW   rV   rI   r   r   r3   r   rY   l   s
    
rY   c                      .   e Zd Zdddddiffd fddZ  ZS )#_FactorizedIncreaseBlockWithRAMCostrK   rL   rM   rN   TrO   rP   rQ   rT   rU   rV   rW   c                   s*   t  ||||| d| | d | _d S r[   rE   r   rO   rQ   rT   rU   rW   r3   r   r   r.      s   z,_FactorizedIncreaseBlockWithRAMCost.__init__
rO   rP   rQ   rP   rT   rP   rU   rV   rW   rV   rI   r   r   r3   r   r]   ~   
    
r]   c                      r\   )!_FactorizedReduceBlockWithRAMCostrK   rL   rM   rN   TrO   rP   rQ   rT   rU   rV   rW   c                   s0   t  ||||| || d| j  d | _d S )NrX   rK   )r-   r.   _spatial_dimsrF   r^   r3   r   r   r.      s   z*_FactorizedReduceBlockWithRAMCost.__init__r_   rI   r   r   r3   r   ra      r`   ra   c                      s0   e Zd ZdZdd fddZddddZ  ZS )MixedOpa#  
    The weighted averaging of cell operations.
    Args:
        c: number of output channels.
        ops: a dictionary of operations. See also: ``Cell.OPS2D`` or ``Cell.OPS3D``.
        arch_code_c: binary cell operation code. It represents the operation results added to the output.
    NcrP   opsdictc                   s^   t    |d u rtt|}t | _t||D ]\}}|dkr,| j	|| | qd S rD   )
r-   r.   nponeslenr0   
ModuleListre   zipappend)r   rd   re   arch_code_cZarch_cop_namer3   r   r   r.      s   

zMixedOp.__init__r   r   r   r   c                 C  sV   d}|dur| |}t| jD ]\}}|du r||| n	|||||   }q|S )z
        Args:
            x: input tensor.
            weight: learnable architecture weights for cell operations. arch_code_c are derived from it.
        Return:
            out: weighted average of the operation results.
                N)to	enumeratere   )r   r   r   outidx_opr   r   r   r       s   
*zMixedOp.forwardr   )rd   rP   re   rf   )r   r   r   r   r5   r   r   r3   r   rc      s    	rc   c                      s   e Zd ZdZdZdd dd dZdd dd d	d d
d dd dZeee	e
dZddddddiffd$ fddZd%d"d#Z  ZS )&Cella  
    The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation.
    Each cell is defined on a `path` in the topology search space.
    Args:
        c_prev: number of input channels
        c: number of output channels
        rate: resolution change rate. It represents the preprocessing operation before the mixed cell operation.
            ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution.
        arch_code_c: cell operation code
    rK   c                 C     t  S r   rC   _cr   r   r   <lambda>       zCell.<lambda>c                 C     t | | ddddS )NrK   r7   rX   rS   rT   rJ   rd   r   r   r   rz          skip_connectZconv_3x3c                 C  rv   r   rw   rx   r   r   r   rz      r{   c                 C  s   t | | ddddS )NrK   r7   r}   r~   r   r   r   r   rz      r   c                 C  r|   )NrK   r7   r   rS   rZ   rY   r   r   r   r   rz      r   c                 C  s   t | | ddddS )NrK   r7   r   r   r   r   r   r   rz      r   c                 C  r|   )NrK   r7   rX   r   r   r   r   r   r   rz      r   r   Z
conv_3x3x3Z
conv_3x3x1Z
conv_3x1x3Z
conv_1x3x3)updownidentityalign_channelsNrL   rM   rN   Tc_prevrP   rd   raterT   rU   rV   rW   c              	     sN  t    | _| _| _|dkr# jd || j j jd _n3|dkr8 jd || j j jd _n||krD jd   _n jd ||dd j j jd _d	d
  fdd
d _dd
  fdd
 fdd
 fdd
 fdd
d _i  _	 jdkr j _	n jdkr j _	n	t
d j dt| j	| _d S )Nr   )rT   rU   rW   r7   r   r   r   r   c                 S  rv   r   rw   rx   r   r   r   rz     r{   zCell.__init__.<locals>.<lambda>c              	        t | | ddd j jdS )NrK   r7   rX   rS   rT   rU   rW   rJ   	_act_name
_norm_namer   r   r   r   rz         r   c                 S  rv   r   rw   rx   r   r   r   rz     r{   c              	     s   t | | ddd j jdS )NrK   r7   r   r   r   r   r   r   rz     r   c              	     r   )NrK   r7   r   rS   rZ   rU   rW   rY   r   r   r   r   r   r   rz     r   c              	     s   t | | ddd j jdS )NrK   r7   r   r   r   r   r   r   rz      r   c              	     r   )NrK   r7   rX   r   r   r   r   r   r   rz   #  r   r   rX   rK   Spatial dimensions  is not supported.)r-   r.   rb   r   r   ConnOPS
preprocessOPS2DOPS3DZOPSNotImplementedErrorrc   op)r   r   rd   r   rm   rT   rU   rW   r3   r   r   r.      sB   




	







zCell.__init__r   r   r   r   r   c                 C  s   |  |}| ||}|S )zi
        Args:
            x: input tensor
            weight: weights for different operations.
        )r   r   r   r   r   r   r    2  s   
zCell.forward)r   rP   rd   rP   r   rP   rT   rP   rU   rV   rW   rV   r"   )r$   r%   r&   r'   
DIRECTIONSr   r   r]   ra   rC   rJ   r   r.   r    r6   r   r   r3   r   ru      s.    

Bru   c                      sH   e Zd ZdZddddifdddfd fddZdd ZdddZ  ZS )r   a  
    Reimplementation of DiNTS based on
    "DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation
    <https://arxiv.org/abs/2103.15954>".

    The model contains a pre-defined multi-resolution stem block (defined in this class) and a
    DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and
    :py:class:`monai.networks.nets.TopologySearch`).

    The stem block is for: 1) input downsample and 2) output upsample to original size.
    The model downsamples the input image by 2 (if ``use_downsample=True``).
    The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the
    DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``).

        - ``TopologyInstance`` is the final searched model. The initialization requires the searched architecture codes.
        - ``TopologySearch`` is a multi-path topology and cell operation search space.
          The architecture codes will be initialized as one.
        - ``TopologyConstruction`` is the parent class which constructs the instance and search space.

    To meet the requirements of the structure, the input size for each spatial dimension should be:
    divisible by 2 ** (num_depths + 1).

    Args:
        dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.
        in_channels: number of input image channels.
        num_classes: number of output segmentation classes.
        act_name: activation name, default to 'RELU'.
        norm_name: normalization used in convolution blocks. Default to `InstanceNorm`.
        spatial_dims: spatial 2D or 3D inputs.
        use_downsample: use downsample in the stem.
            If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],
            if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
        node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.
            +1 for multi-resolution inputs.
            In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None.
    rL   rM   rN   TrK   Nin_channelsrP   num_classesrU   rV   rW   rT   use_downsampleboolc	                   s  t    || _|j| _|j| _|j| _|dvr td| d|| _|d u r4t	| jd | jf| _
n|| _
ttj|f }	t | _t | _tt| jd | jd |||d|	| jd |ddddddd| _| jd	krpd
nd}
t| jD ]}|rttjdd|  |
dd|	|| j| d	ddddddt||| j| dt|d|	| j| | j|d  d	ddddddt||| j|d  d| jt|< tt|d|	| j|d  | j| d	ddddddt||| j| dtjd|
dd| jt|< qwttjdd|  |
dd|	|| j| d	ddddddt||| j| d| jt|< tt|d|	| j| | jt|d d d	ddddddt||| jt|d d dtjd|dk |
dd| jt|< qwd S )N)rX   rK   r   r   r7   r   )rU   rW   rT   T)r   out_channelsrR   striderS   groupsbiasdilationrK   	trilinearbilinearrX   )scale_factormodealign_cornersF)r   rT   channelsr   )r-   r.   dints_spacefilter_nums
num_blocks
num_depthsr   rb   r/   rh   node_ar   CONVr0   
ModuleDict	stem_downstem_upr1   r   stem_finalsranger,   Upsampler
   r	   strmax)r   r   r   r   rU   rW   rT   r   r   	conv_typer   res_idxr3   r   r   r.   c  s   







zDiNTS.__init__c                 C  s   dd |   D S )Nc                 S  s   g | ]\}}|qS r   r   r;   r   paramr   r   r   r=     s    z+DiNTS.weight_parameters.<locals>.<listcomp>named_parametersr   r   r   r   weight_parameters  s   zDiNTS.weight_parametersr   r   c                 C  s   g }t | jD ]#}| jt| }||}| jd | r"|| q|t| q| 	|}| j
d }d}td}	t | jd ddD ]'}
| jt|
 }|rZ|||
 |	 }	qE| j|d  |
 rld}|||
 }	qE| |	}|S )zd
        Prediction based on dynamic arch_code.

        Args:
            x: input tensor.
        r   r7   Fr   T)r   r   r   r   r    r   rl   r/   
zeros_liker   r   emptyr   r   )r   r   inputsdZ_mod_wx_outoutputsblk_idxstart_tempr   Z_mod_up
predictionr   r   r   r      s*   




zDiNTS.forward)r   rP   r   rP   rU   rV   rW   rV   rT   rP   r   r   )r   r   )r$   r%   r&   r'   r.   r   r    r6   r   r   r3   r   r   =  s    *
r   c                
      sH   e Zd ZdZddeddddddd	ifd	d
f
d fddZdd Z  ZS )r   a?	  
    The base class for `TopologyInstance` and `TopologySearch`.

    Args:
        arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model.
            For example, for a ``num_depths=4, num_blocks=12`` search space:

            - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated.
            - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.
            - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None,

            all paths and cells operations will be used, and must be in the searching stage (is_search=True).
        channel_mul: adjust intermediate channel number, default is 1.
        cell: operation of each node.
        num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space.
        num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension.
        use_downsample: use downsample in the stem. If False, the search space will be in resolution [1, 1/2, 1/4, 1/8],
            if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
        device: `'cpu'`, `'cuda'`, or device ID.


    Predefined variables:
        `filter_nums`: default to 32. Double the number of channels after downsample.
        topology related variables:

            - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4,
              arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution),
              the second path outputs from node 1 (second resolution in the search space),
              the third path outputs from node 0, etc.
            - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4,
              arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change
              resolution, the second path perform upsample, the third perform downsample, etc.
            - `arch_code2out`: path activation to its output node index.
              For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
              the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc.
    N      ?   rK   rL   rM   rN   Tcpu	arch_codelist | Nonechannel_mulfloatr   rP   r   rT   rU   rV   rW   r   r   devicer   c              
     s  t    tdd t|d D } fdd|D | _|| _|| _tdt	j	
 d|t|	   || _|| _|| _|	| _|
| _d| _| jdkrQt|j| _n| jdkr\t|j| _g g }}ttj| j d D ]}||d tj d |d tj   qkg d	| j dd
 }t| jD ]
}||||g q|dd
 }|| _|| _|| _|d u rt| jt| jf| j}t| jt| j| jf| j}nt|d | j}t !t|d tj"| j| j}|| _#|| _$t%& | _'t| jD ]G}tt| jD ]=}| j#||f dkrD|| j| j| t|	  | j| j| t|	  | j| | j$||f | j| j| j| j't(||f< qqd S )Nc                 S  s   g | ]}d d|  qS )    rX   r   )r;   _ir   r   r   r=   ;      z1TopologyConstruction.__init__.<locals>.<listcomp>r7   c                   s   g | ]}t |  qS r   )rP   )r;   Zn_featr   r   r   r=   <  r   zC{} - Length of input patch is recommended to be a multiple of {:d}.rX   r   rK   )r   r   r7   r   ))r-   r.   tupler   r   r   r   printformatdatetimenowrP   rb   r   r   r   r   num_cell_opsri   r   r   ru   r   rl   extendarch_code2inarch_code2opsarch_code2outr/   rh   rp   
from_numpyFone_hotint64arch_code_arm   r0   r   	cell_treer   )r   r   r   cellr   r   rT   rU   rW   r   r   Zn_featsr   r   ir   mr   rm   r   r   r3   r   r   r.   ,  sj   



($(
zTopologyConstruction.__init__c                 C  s   dS )zOThis function to be implemented by the architecture instances or search spaces.Nr   r*   r   r   r   r    v  s   zTopologyConstruction.forward)r   r   r   r   r   rP   r   rP   rT   rP   rU   rV   rW   rV   r   r   r   r   r$   r%   r&   r'   ru   r.   r    r6   r   r   r3   r   r     s    '
Jr   c                
      sJ   e Zd ZdZddeddddddd	ifd	d
f
d fddZd ddZ  ZS )!r   z`
    Instance of the final searched architecture. Only used in re-training/inference stage.
    Nr   r   rK   rL   rM   rN   Tr   r   r   r   rP   r   rT   rU   rV   rW   r   r   r   r   c                   s6   |du r	t d t j|||||||||	|
d
 dS )Q
        Initialize DiNTS topology search space of neural architectures.
        Nz*arch_code not provided when not searching.
r   r   r   r   r   rT   rU   rW   r   r   )warningswarnr-   r.   )r   r   r   r   r   r   rT   rU   rW   r   r   r3   r   r   r.     s   

zTopologyInstance.__init__r   list[torch.Tensor]r   c           	      C  s   |}t | jD ]H}tjd|d j|d jdg| j }t| j| j	D ])\}}|rL| j
t||f }|j|| j|  dd}|| j|  | || j| < q#|}q|S )z4
        Args:
            x: input tensor.
        ro   r   dtyper   N)r   r   )r   r   r/   tensorr   r   r   rq   r   datar   r   r    r   r   )	r   r   r   r   r   r   
activationr2   _outr   r   r   r      s   $zTopologyInstance.forward)r   r   r   rP   r   rP   rT   rP   rU   rV   rW   rV   r   r   r   r   )r   r   r   r   r   r   r   r3   r   r   {  s    
 r   c                
      s   e Zd ZU dZded< ded< deddddd	d
ddifddf
d1 fddZd2d d!Zd"d# Zd3d4d&d'Z	d3d5d)d*Z
d+d, Zd-d. Zd/d0 Z  ZS )6r   a  
    DiNTS topology search space of neural architectures.

    Examples:

    .. code-block:: python

        from monai.networks.nets.dints import TopologySearch

        topology_search_space = TopologySearch(
            channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3)
        topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True)
        multi_res_images = [
            torch.randn(2, 16, 80, 80, 80),
            torch.randn(2, 32, 40, 40, 40),
            torch.randn(2, 64, 20, 20, 20),
            torch.randn(2, 128, 10, 10, 10)]
        prediction = topology_search_space(image)
        for x in prediction: print(x.shape)
        # torch.Size([2, 16, 80, 80, 80])
        # torch.Size([2, 32, 40, 40, 40])
        # torch.Size([2, 64, 20, 20, 20])
        # torch.Size([2, 128, 10, 10, 10])

    Class method overview:

        - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities.
        - ``get_ram_cost_usage()``: get estimated ram cost.
        - ``get_topology_entropy()``: get topology entropy loss in searching stage.
        - ``decode()``: get final binarized architecture code.
        - ``gen_mtx()``: generate variables needed for topology search.

    Predefined variables:
        - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to
          path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15],
          A tidx (10 binary values) represents the path activation.
        - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.
          It is used to convert path activation pattern (1, paths) to node activation (1, nodes)
        - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation
          patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).
        - `all_connect`: All possible path activations. For depth = 4,
          all_connection has 1024 vectors of length 10 (10 paths).
          The return value will exclude path activation of all 0.
    z
list[list]node2outnode2inr   Nr   rK   rL   rM   rN   Tr   r   r   r   r   r   rP   r   rT   rU   rV   rW   r   r   r   r   c                   s  t  j|||||||||	|
d
 g }tj}t|j d D ]}||d | j |d |  d |d |   q|_|\}}}t	
|_fddttjD _|_t	
|_t	jtjjf_tjD ]6}ttjD ],}j||f dkrjt||f  t	 fdd jjdj D j||f< qqwttjtjjdd	j ! _"ttjtjd
d	j ! _#ddg_$dS )r   r   rX   r7   c                      i | ]
}t  j| |qS r   )r   node_act_listr;   r   r   r   r   
<dictcomp>	      z+TopologySearch.__init__.<locals>.<dictcomp>c                   s   g | ]	}|j  jj  qS r   )rF   r   )r;   r   )
cell_interr   r   r=     s    z+TopologySearch.__init__.<locals>.<listcomp>Ng{Gz?r   log_alpha_alog_alpha_c)%r-   r.   ru   r   r   r   rl   tidxgen_mtxrg   asarrayr   ri   node_act_dicttransfer_mtx
child_listzerosr   r   r   rF   r   r   r   arrayr   re   r0   	Parameterr/   normal_rp   r   requires_grad_r  r   _arch_param_names)r   r   r   r   r   r   rT   rU   rW   r   r   r  _dr   r  r   r  r   r   r3   )r   r   r   r.     sT   6&zTopologySearch.__init__depthc                   s   t j| d }td|d }g }|D ]0}t||f}t|D ]}|| ||d t j |d t j d |d t j  f< q|| qtd|d dd }i }	|D ]  fdd|D }
|
|	tt < qR|	||dd fS )a  
        Generate elements needed in decoding and topology.

            - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.
               It is used to convert path activation pattern (1, paths) to node activation (1, nodes)
            - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation
               patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).
            - `all_connect`: All possible path activations. For depth = 4,
              all_connection has 1024 vectors of length 10 (10 paths).
              The return value will exclude path activation of all 0.
        rX   r   r7   Nc                   s4   g | ]}t |d d ktt  k r|qS r9   )rg   sumastyperP   r	  allr:   r   r   r   r=   D  s   4 z*TopologySearch.gen_mtx.<locals>.<listcomp>)	ru   r   r?   rg   r  r   rl   r   r	  )r   r  rA   Zall_connectmtxr   mar   r   r  Zarch_code_mtxr   r  r   r  "  s   8zTopologySearch.gen_mtxc                   s    fdd   D S )Nc                   s   g | ]\}}| j vr|qS r   )r  r   r   r   r   r=   J  s    z4TopologySearch.weight_parameters.<locals>.<listcomp>r   r   r   r   r   r   I  s   z TopologySearch.weight_parametersFrB   c                   sz   t | j dd  d  d }|r9t | j| j fddt	| j
D }t |}||fS d|fS )a  
        Get final path and child model probabilities from architecture weights `log_alpha_a`.
        This is used in forward pass, getting training loss, and final decoding.

        Args:
            child: return child probability (used in decoding)
        Return:
            arch_code_prob_a: the path activation probability of size:
                `[number of blocks, number of paths in each block]`.
                For 12 blocks, 4 depths search space, the size is [12,10]
            probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern
                 (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1)
        r7   r   c                   s:   g | ]} |  d  d  |     d|  qS )r7   r   )prod)r;   r   Z_arch_code_prob_anormZpath_activationr   r   r=   `  s    
z-TopologySearch.get_prob_a.<locals>.<listcomp>N)r/   sigmoidr   r  	unsqueezer   r  rp   r   r   r   stack)r   rB   arch_code_prob_aprobs_ar   r  r   
get_prob_aL  s   
zTopologySearch.get_prob_afullc              	   C  s>  |d }t || j d }g }t| jD ]}||| j|  |d|     qtj	|tj
| jddt| j  }| jdd\}}tj| jdd}	|rW| }|d	 t| jjtj
| jd}
d
}t| jD ]+}tt| jD ]!}||||f d	|
||f |	||f     || j|   7 }qtqk|d d d S )a  
        Get estimated output tensor size to approximate RAM consumption.

        Args:
            in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level.
            full: full ram cost usage with all probability of 1.
        r   NrX   r   FrB   r   dimr7   ro   r      i   )rg   r	  rb   r   r   rl   r   r  r/   r   float32r   rP   r   r  r   softmaxr  detachfill_r   rF   rp   r   ri   r   r  )r   in_sizer  
batch_size
image_sizesizesr   r  r  Z	cell_probrF   usager   path_idxr   r   r   get_ram_cost_usagel  s0   	&"

z!TopologySearch.get_ram_cost_usagec              	   C  s  t | dr| j}| j}ndd tt| jD }dd tt| jD }tt| jD ]_}t| j	t| j	}}tt| j
D ]"}|| j
|   | j| | 7  < || j|   | j| | 7  < qA|dkt}|dkt}|| jt|  | || jt|  | q+|| _|| _d}t| jd D ]A}	d}
tt| jD ]1}||	|| f  }||	d || f  }|
|t|d  d| td| d    7 }
q||
7 }q|S )z
        Get topology entropy loss at searching stage.

        Args:
            probs: path activation probabilities
        r   c                 S     g | ]}g qS r   r   r:   r   r   r   r=         z7TopologySearch.get_topology_entropy.<locals>.<listcomp>c                 S  r/  r   r   r:   r   r   r   r=     r0  r7   r   h㈵>)hasattrr   r   r   ri   r   r  rg   r  r   r   r   r  rP   r  r   rl   r   r  r/   log)r   probsr   r   	child_idxZ_node_inZ	_node_outr   entr   Zblk_entnode_idxZ_node_pZ
_out_probsr   r   r   get_topology_entropy  s2   
 "4
z#TopologySearch.get_topology_entropyc              	     s   j dd\}} jt|dj   }tt j	ddj  }|j  }t
dt j j  d dt j j  d f} fddtt jD }t
t jt jf}tt jD ]L}t
 jt}	tt j| D ]}
|	 j|
    j| |
 7  < q~|	dkt}	 jt|	 D ]}|t|  j t }d|||f< qqjt
|d d  d	 |dddt j f< td jD ]A}|t
t
|| d  d	 t jdf |d|d t j  d|t j  d|t j  d|d t j  f< qd	|d jd t j  d jt j  df< t|}t|ddddd
\}}}d\}}t
 jt jf}t
 jd  jf}	 || }|dkren6|d t j } j| ||ddf< tt jD ]}|| j| f  |||f 7  < q|d8 }q[tt jD ]}|| j| f  |d|f 7  < q|dkt}||||fS )a  
        Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm.

        `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``.

        For example, for a ``num_depths=4``, ``num_blocks=12`` search space:

            - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated
              (13 because of multi-resolution inputs).
            - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated.
            - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.

        Return:
            arch_code with maximum probability
        Tr   r   r7   c                   r   r   )r   r  r   r   r   r   r     r   z)TopologySearch.decode.<locals>.<dictcomp>r   r1  gMbP?)csgraphdirectedindicesmin_onlyreturn_predecessors)r   r   N)r  r  r/   argmaxr   r   numpyr   r%  r  rg   r  ri   r   r   r   r  rP   r   r  r   flattenr  r3  tiler   r   r   )r   r4  r  Zarch_code_a_maxrm   ZamtxZ
path2childZsub_amtxr5  Z	_node_actr-  r  Zconnect_child_idxr   graphdist_matrixpredecessorssourcesindexZa_idxr   r   r   r   r   r   decode  s`    .".,&&4

&	&zTopologySearch.decodec           
   	   C  s   | j dd\}}|}t| jD ]J}dg| j }t| j| j  D ]3\}}|rVt	j
| j||f dd}	|| j|   | jt||f || j|  |	d|||f  7  < q#|}q|S )z
        Prediction based on dynamic arch_code.

        Args:
            x: a list of `num_depths` input tensors as a multi-resolution input.
                tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`.
        Fr   ro   r   r!  )r   )r  r   r   r   rq   r   r   r   r?  r   r%  r  r   r   r   r   )
r   r   r  r  r   r   r   r   r   _wr   r   r   r      s   	 "
zTopologySearch.forward)r   r   r   r   r   rP   r   rP   rT   rP   rU   rV   rW   rV   r   r   r   r   )r  rP   )F)rB   r   )r  r   )r$   r%   r&   r'   __annotations__ru   r.   r  r   r  r.  r8  rG  r    r6   r   r   r3   r   r     s,   
 -

>' !&Lr   )1
__future__r   r   r   typingr   r?  rg   r/   torch.nnr0   torch.nn.functional
functionalr   Z!monai.networks.blocks.dints_blockr   r   r   r   monai.networks.layers.factoriesr   monai.networks.layers.utilsr	   r
   monai.utilsr   r   r<   r   __all__jit	interfaceModuler   r)   r,   r?   IdentityrC   rJ   rY   r]   ra   rc   ru   r   r   r   r   r   r   r   r   <module>   sB   "{ Ju8