U
    Ph                     @  s  d dl mZ d dlZd dlZd dlmZ d dlZd dlZd dl	m
Z
 d dlm
  mZ d dlmZmZmZmZ d dlmZ d dlmZmZ d dlmZ edd	d
\ZZeddd
\ZZddddgZejjG dd dej
j Z!ejjG dd dej
j Z"G dd de"Z#dd Z$G dd de
j%Z&G dd deZ'G dd deZ(G dd  d eZ)G d!d" d"eZ*G d#d$ d$e
j Z+G d%d& d&e!Z,G d'd de
j Z-G d(d de
j Z.G d)d de.Z/G d*d de.Z0dS )+    )annotationsN)Optional)ActiConvNormBlockFactorizedIncreaseBlockFactorizedReduceBlockP3DActiConvNormBlock)Conv)get_act_layerget_norm_layer)optional_importzscipy.sparse
csr_matrixnamezscipy.sparse.csgraphdijkstraDiNTSTopologyConstructionTopologyInstanceTopologySearchc                   @  s"   e Zd ZdZddddddZdS )CellInterfacez"interface for torchscriptable Celltorch.TensorOptional[torch.Tensor]xweightreturnc                 C  s   d S N selfr   r   r   r   N/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/nets/dints.pyforward,   s    zCellInterface.forwardN__name__
__module____qualname____doc__r    r   r   r   r   r   (   s   r   c                   @  s    e Zd ZdZdddddZdS )StemInterfacez"interface for torchscriptable Stemr   r   r   c                 C  s   d S r   r   r   r   r   r   r   r    4   s    zStemInterface.forwardNr!   r   r   r   r   r&   0   s   r&   c                      s0   e Zd ZdZ fddZdddddZ  ZS )StemTSz wrapper for torchscriptable Stemc                   s   t    tjj| | _d S r   )super__init__torchnn
Sequentialmod)r   r/   	__class__r   r   r+   ;   s    
zStemTS.__init__r   r'   c                 C  s
   |  |S r   )r/   r(   r   r   r   r    ?   s    zStemTS.forwardr"   r#   r$   r%   r+   r    __classcell__r   r   r0   r   r)   8   s   r)   c                 C  s>   | |krdgdggS t | d |}dd |D dd |D  S )z>use depth first search to find all path activation combinationr      c                 S  s   g | ]}d g| qS r   r   .0_r   r   r   
<listcomp>H   s     z_dfs.<locals>.<listcomp>c                 S  s   g | ]}d g| qS )r4   r   r6   r   r   r   r9   H   s     )_dfs)nodepathschildr   r   r   r:   C   s    r:   c                      s   e Zd Z fddZ  ZS )_IdentityWithRAMCostc                   s   t  j|| d| _d S Nr   r*   r+   ram_cost)r   argskwargsr0   r   r   r+   M   s    z_IdentityWithRAMCost.__init__r"   r#   r$   r+   r3   r   r   r0   r   r>   K   s   r>   c                	      sB   e Zd ZdZdddddiffdddddddd	 fd
dZ  ZS )_ActiConvNormBlockWithRAMCosta!  The class wraps monai layers with ram estimation. The ram_cost = total_ram/output_size is estimated.
    Here is the estimation:
     feature_size = output_size/out_channel
     total_ram = ram_cost * output_size
     total_ram = in_channel * feature_size (activation map) +
                 in_channel * feature_size (convolution map) +
                 out_channel * feature_size (normalization)
               = (2*in_channel + out_channel) * output_size/out_channel
     ram_cost = total_ram/output_size = 2 * in_channel/out_channel + 1
       RELUINSTANCEaffineTinttuple | str)
in_channelout_channelkernel_sizepaddingspatial_dimsact_name	norm_namec              	     s.   t  ||||||| d|| d  | _d S )Nr4      r@   )r   rL   rM   rN   rO   rP   rQ   rR   r0   r   r   r+   ^   s    
z&_ActiConvNormBlockWithRAMCost.__init__)r"   r#   r$   r%   r+   r3   r   r   r0   r   rE   R   s
   
rE   c                	      s>   e Zd Zdddddiffdddddddd fd	d
Z  ZS ) _P3DActiConvNormBlockWithRAMCostr   rG   rH   rI   TrJ   rK   )rL   rM   rN   rO   p3dmoderQ   rR   c              	     s.   t  ||||||| dd| |  | _d S NrS   r@   )r   rL   rM   rN   rO   rU   rQ   rR   r0   r   r   r+   n   s    
z)_P3DActiConvNormBlockWithRAMCost.__init__rD   r   r   r0   r   rT   l   s   
rT   c                      s:   e Zd Zdddddiffdddddd fd	d
Z  ZS )#_FactorizedIncreaseBlockWithRAMCostrF   rG   rH   rI   TrJ   rK   rL   rM   rP   rQ   rR   c                   s*   t  ||||| d| | d | _d S rV   r@   r   rL   rM   rP   rQ   rR   r0   r   r   r+      s    z,_FactorizedIncreaseBlockWithRAMCost.__init__rD   r   r   r0   r   rW   ~   s   
rW   c                      s:   e Zd Zdddddiffdddddd fd	d
Z  ZS )!_FactorizedReduceBlockWithRAMCostrF   rG   rH   rI   TrJ   rK   rX   c                   s0   t  ||||| || d| j  d | _d S )NrS   rF   )r*   r+   _spatial_dimsrA   rY   r0   r   r   r+      s    z*_FactorizedReduceBlockWithRAMCost.__init__rD   r   r   r0   r   rZ      s   
rZ   c                      s<   e Zd ZdZdddd fddZddd	d
ddZ  ZS )MixedOpa#  
    The weighted averaging of cell operations.
    Args:
        c: number of output channels.
        ops: a dictionary of operations. See also: ``Cell.OPS2D`` or ``Cell.OPS3D``.
        arch_code_c: binary cell operation code. It represents the operation results added to the output.
    NrJ   dict)copsc                   s^   t    |d kr tt|}t | _t||D ]$\}}|dkr4| j	|| | q4d S r?   )
r*   r+   nponeslenr-   
ModuleListr_   zipappend)r   r^   r_   arch_code_cZarch_cop_namer0   r   r   r+      s    

zMixedOp.__init__r   r   r   r   c                 C  sV   d}|dk	r| |}t| jD ]0\}}|dkr<||| n|||||   }q |S )z
        Args:
            x: input tensor.
            weight: learnable architecture weights for cell operations. arch_code_c are derived from it.
        Return:
            out: weighted average of the operation results.
                N)to	enumerater_   )r   r   r   outidx_opr   r   r   r       s    
*zMixedOp.forward)N)Nr2   r   r   r0   r   r\      s   	r\   c                      s   e Zd ZdZdZdd dd dZdd dd d	d d
d dd dZeee	e
dZddddddiffddddddd fddZddddddZ  ZS )Cella  
    The basic class for cell operation search, which contains a preprocessing operation and a mixed cell operation.
    Each cell is defined on a `path` in the topology search space.
    Args:
        c_prev: number of input channels
        c: number of output channels
        rate: resolution change rate. It represents the preprocessing operation before the mixed cell operation.
            ``-1`` for 2x downsample, ``1`` for 2x upsample, ``0`` for no change of resolution.
        arch_code_c: cell operation code
    rF   c                 C  s   t  S r   r>   _cr   r   r   <lambda>       zCell.<lambda>c                 C  s   t | | ddddS )NrF   r4   rS   rO   rP   rE   r^   r   r   r   rs      rt   skip_connectZconv_3x3c                 C  s   t  S r   rp   rq   r   r   r   rs      rt   c                 C  s   t | | ddddS )NrF   r4   ru   rv   rw   r   r   r   rs      rt   c                 C  s   t | | ddddS )NrF   r4   r   rO   rU   rT   rw   r   r   r   rs      rt   c                 C  s   t | | ddddS )NrF   r4   rz   r{   rw   r   r   r   rs      rt   c                 C  s   t | | ddddS )NrF   r4   rS   rz   r{   rw   r   r   r   rs      rt   ry   Z
conv_3x3x3Z
conv_3x3x1Z
conv_3x1x3Z
conv_1x3x3)updownidentityalign_channelsNrG   rH   rI   TrJ   rK   )c_prevr^   raterP   rQ   rR   c              	     sR  t    | _| _| _|dkrF jd || j j jd _nf|dkrp jd || j j jd _n<||kr jd   _n$ jd ||dd j j jd _d	d
  fdd
d _dd
  fdd
 fdd
 fdd
 fdd
d _i  _	 jdkr j _	n( jdkr, j _	nt
d j dt| j	| _d S )Nr~   )rP   rQ   rR   r4   r}   r   r   r   c                 S  s   t  S r   rp   rq   r   r   r   rs     rt   zCell.__init__.<locals>.<lambda>c              	     s   t | | ddd j jdS )NrF   r4   rS   rO   rP   rQ   rR   rE   	_act_name
_norm_namerw   r   r   r   rs     s         rx   c                 S  s   t  S r   rp   rq   r   r   r   rs     rt   c              	     s   t | | ddd j jdS )NrF   r4   r   r   rw   r   r   r   rs     s         c              	     s   t | | ddd j jdS )NrF   r4   r   rO   rU   rQ   rR   rT   r   r   rw   r   r   r   rs     s         c              	     s   t | | ddd j jdS )NrF   r4   r   r   rw   r   r   r   rs      s         c              	     s   t | | ddd j jdS )NrF   r4   rS   r   r   rw   r   r   r   rs   #  s         r|   rS   rF   Spatial dimensions  is not supported.)r*   r+   r[   r   r   ConnOPS
preprocessOPS2DOPS3DZOPSNotImplementedErrorr\   op)r   r   r^   r   rf   rP   rQ   rR   r0   r   r   r+      s^    

    
    
      
	





zCell.__init__r   r   r   c                 C  s   |  |}| ||}|S )zi
        Args:
            x: input tensor
            weight: weights for different operations.
        )r   r   r   r   r   r   r    2  s    
zCell.forward)r"   r#   r$   r%   
DIRECTIONSr   r   rW   rZ   r>   rE   r   r+   r    r3   r   r   r0   r   ro      s,   

Bro   c                      sZ   e Zd ZdZddddifdddfddd	d	dd
d fddZdd ZddddZ  ZS )r   a  
    Reimplementation of DiNTS based on
    "DiNTS: Differentiable Neural Network Topology Search for 3D Medical Image Segmentation
    <https://arxiv.org/abs/2103.15954>".

    The model contains a pre-defined multi-resolution stem block (defined in this class) and a
    DiNTS space (defined in :py:class:`monai.networks.nets.TopologyInstance` and
    :py:class:`monai.networks.nets.TopologySearch`).

    The stem block is for: 1) input downsample and 2) output upsample to original size.
    The model downsamples the input image by 2 (if ``use_downsample=True``).
    The downsampled image is downsampled by [1, 2, 4, 8] times (``num_depths=4``) and used as input to the
    DiNTS search space (``TopologySearch``) or the DiNTS instance (``TopologyInstance``).

        - ``TopologyInstance`` is the final searched model. The initialization requires the searched architecture codes.
        - ``TopologySearch`` is a multi-path topology and cell operation search space.
          The architecture codes will be initialized as one.
        - ``TopologyConstruction`` is the parent class which constructs the instance and search space.

    To meet the requirements of the structure, the input size for each spatial dimension should be:
    divisible by 2 ** (num_depths + 1).

    Args:
        dints_space: DiNTS search space. The value should be instance of `TopologyInstance` or `TopologySearch`.
        in_channels: number of input image channels.
        num_classes: number of output segmentation classes.
        act_name: activation name, default to 'RELU'.
        norm_name: normalization used in convolution blocks. Default to `InstanceNorm`.
        spatial_dims: spatial 2D or 3D inputs.
        use_downsample: use downsample in the stem.
            If ``False``, the search space will be in resolution [1, 1/2, 1/4, 1/8],
            if ``True``, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
        node_a: node activation numpy matrix. Its shape is `(num_depths, num_blocks + 1)`.
            +1 for multi-resolution inputs.
            In model searching stage, ``node_a`` can be None. In deployment stage, ``node_a`` cannot be None.
    rG   rH   rI   TrF   NrJ   rK   bool)in_channelsnum_classesrQ   rR   rP   use_downsamplec	                   s  t    || _|j| _|j| _|j| _|dkr@td| d|| _|d krht	| jd | jf| _
n|| _
ttj|f }	t | _t | _tt| jd | jd |||d|	| jd |ddddddd| _| jd	krd
nd}
t| jD ]}|rttjdd|  |
dd|	|| j| d	ddddddt||| j| dt|d|	| j| | j|d  d	ddddddt||| j|d  d| jt|< tt|d|	| j|d  | j| d	ddddddt||| j| dtjd|
dd| jt|< qttjdd|  |
dd|	|| j| d	ddddddt||| j| d| jt|< tt|d|	| j| | jt|d d d	ddddddt||| jt|d d dtjd|dk |
dd| jt|< qd S )N)rS   rF   r   r   r4   r   )rQ   rR   rP   T)r   out_channelsrN   striderO   groupsbiasdilationrF   	trilinearbilinearrS   )scale_factormodealign_cornersF)r   rP   channelsr   )r*   r+   dints_spacefilter_nums
num_blocks
num_depthsr   r[   r,   ra   node_ar   CONVr-   
ModuleDict	stem_downstem_upr.   r   stem_finalsranger)   Upsampler
   r	   strmax)r   r   r   r   rQ   rR   rP   r   r   	conv_typer   res_idxr0   r   r   r+   c  s    







  zDiNTS.__init__c                 C  s   dd |   D S )Nc                 S  s   g | ]\}}|qS r   r   r7   r   paramr   r   r   r9     s     z+DiNTS.weight_parameters.<locals>.<listcomp>named_parametersr   r   r   r   weight_parameters  s    zDiNTS.weight_parametersr   )r   c                 C  s   g }t | jD ]F}| jt| }||}| jd | rD|| q|t| q| 	|}| j
d }d}td}	t | jd ddD ]N}
| jt|
 }|r|||
 |	 }	q| j|d  |
 rd}|||
 }	q| |	}|S )zd
        Prediction based on dynamic arch_code.

        Args:
            x: input tensor.
        r   r4   Fr   T)r   r   r   r   r    r   re   r,   
zeros_liker   r   emptyr   r   )r   r   inputsdZ_mod_wx_outoutputsblk_idxstart_tempr   Z_mod_up
predictionr   r   r   r      s(    




zDiNTS.forward)r"   r#   r$   r%   r+   r   r    r3   r   r   r0   r   r   =  s   *
c                      s\   e Zd ZdZddeddddddd	ifd	d
f
dddddddddd	 fddZdd Z  ZS )r   a?	  
    The base class for `TopologyInstance` and `TopologySearch`.

    Args:
        arch_code: `[arch_code_a, arch_code_c]`, numpy arrays. The architecture codes defining the model.
            For example, for a ``num_depths=4, num_blocks=12`` search space:

            - `arch_code_a` is a 12x10 (10 paths) binary matrix representing if a path is activated.
            - `arch_code_c` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.
            - `arch_code` in ``__init__()`` is used for creating the network and remove unused network blocks. If None,

            all paths and cells operations will be used, and must be in the searching stage (is_search=True).
        channel_mul: adjust intermediate channel number, default is 1.
        cell: operation of each node.
        num_blocks: number of blocks (depth in the horizontal direction) of the DiNTS search space.
        num_depths: number of image resolutions of the DiNTS search space: 1, 1/2, 1/4 ... in each dimension.
        use_downsample: use downsample in the stem. If False, the search space will be in resolution [1, 1/2, 1/4, 1/8],
            if True, the search space will be in resolution [1/2, 1/4, 1/8, 1/16].
        device: `'cpu'`, `'cuda'`, or device ID.


    Predefined variables:
        `filter_nums`: default to 32. Double the number of channels after downsample.
        topology related variables:

            - `arch_code2in`: path activation to its incoming node index (resolution). For depth = 4,
              arch_code2in = [0, 1, 0, 1, 2, 1, 2, 3, 2, 3]. The first path outputs from node 0 (top resolution),
              the second path outputs from node 1 (second resolution in the search space),
              the third path outputs from node 0, etc.
            - `arch_code2ops`: path activation to operations of upsample 1, keep 0, downsample -1. For depth = 4,
              arch_code2ops = [0, 1, -1, 0, 1, -1, 0, 1, -1, 0]. The first path does not change
              resolution, the second path perform upsample, the third perform downsample, etc.
            - `arch_code2out`: path activation to its output node index.
              For depth = 4, arch_code2out = [0, 0, 1, 1, 1, 2, 2, 2, 3, 3],
              the first and second paths connects to node 0 (top resolution), the 3,4,5 paths connects to node 1, etc.
    N      ?   rF   rG   rH   rI   Tcpulist | NonefloatrJ   rK   r   r   )		arch_codechannel_mulr   r   rP   rQ   rR   r   devicec              
     s  t    tdd t|d D } fdd|D | _|| _|| _tdt	j	
 d|t|	   || _|| _|| _|	| _|
| _d| _| jdkrt|j| _n| jdkrt|j| _g g  }}ttj| j d D ]*}||d tj d |d tj   qd	ddg| j dd	 }t| jD ]}||||g q$|dd	 }|| _|| _|| _|d krt| jt| jf| j}t| jt| j| jf| j}n>t|d | j}t !t|d tj"| j| j}|| _#|| _$t%& | _'t| jD ]}tt| jD ]z}| j#||f dkr|| j| j| t|	  | j| j| t|	  | j| | j$||f | j| j| j| j't(||f< qqd S )
Nc                 S  s   g | ]}d d|  qS )    rS   r   )r7   _ir   r   r   r9   ;  s     z1TopologyConstruction.__init__.<locals>.<listcomp>r4   c                   s   g | ]}t |  qS r   )rJ   )r7   Zn_featr   r   r   r9   <  s     zC{} - Length of input patch is recommended to be a multiple of {:d}.rS   r   rF   r   ))r*   r+   tupler   r   r   r   printformatdatetimenowrJ   r[   r   r   r   r   num_cell_opsrb   r   r   ro   r   re   extendarch_code2inarch_code2opsarch_code2outr,   ra   rj   
from_numpyFone_hotint64arch_code_arf   r-   r   	cell_treer   )r   r   r   cellr   r   rP   rQ   rR   r   r   Zn_featsr   r   ir   mr   rf   r   r   r0   r   r   r+   ,  sf    
 


(
$(
zTopologyConstruction.__init__c                 C  s   dS )zOThis function to be implemented by the architecture instances or search spaces.Nr   r(   r   r   r   r    v  s    zTopologyConstruction.forwardr"   r#   r$   r%   ro   r+   r    r3   r   r   r0   r   r     s   '
$Jc                
      sb   e Zd ZdZddeddddddd	ifd	d
f
ddddddddd fddZdddddZ  ZS )r   z`
    Instance of the final searched architecture. Only used in re-training/inference stage.
    Nr   r   rF   rG   rH   rI   Tr   r   rJ   rK   r   r   )r   r   r   rP   rQ   rR   r   r   c                   s6   |dkrt d t j|||||||||	|
d
 dS )Q
        Initialize DiNTS topology search space of neural architectures.
        Nz*arch_code not provided when not searching.
r   r   r   r   r   rP   rQ   rR   r   r   )warningswarnr*   r+   )r   r   r   r   r   r   rP   rQ   rR   r   r   r0   r   r   r+     s    
zTopologyInstance.__init__zlist[torch.Tensor]r'   c           	      C  s   |}t | jD ]}tjd|d j|d jdg| j }t| j| j	D ]R\}}|rF| j
t||f }|j|| j|  dd}|| j|  | || j| < qF|}q|S )z4
        Args:
            x: input tensor.
        ri   r   dtyper   Nrh   )r   r   r,   tensorr   r   r   rk   r   datar   r   r    r   r   )	r   r   r   r   r   r   
activationr/   _outr   r   r   r      s    $zTopologyInstance.forwardr   r   r   r0   r   r   {  s   
" c                      s   e Zd ZU dZded< ded< deddddd	d
ddifddf
dddddddddd	 fddZddddZdd Zd)ddddZ	d*dd d!d"Z
d#d$ Zd%d& Zd'd( Z  ZS )+r   a  
    DiNTS topology search space of neural architectures.

    Examples:

    .. code-block:: python

        from monai.networks.nets.dints import TopologySearch

        topology_search_space = TopologySearch(
            channel_mul=0.5, num_blocks=8, num_depths=4, use_downsample=True, spatial_dims=3)
        topology_search_space.get_ram_cost_usage(in_size=(2, 16, 80, 80, 80), full=True)
        multi_res_images = [
            torch.randn(2, 16, 80, 80, 80),
            torch.randn(2, 32, 40, 40, 40),
            torch.randn(2, 64, 20, 20, 20),
            torch.randn(2, 128, 10, 10, 10)]
        prediction = topology_search_space(image)
        for x in prediction: print(x.shape)
        # torch.Size([2, 16, 80, 80, 80])
        # torch.Size([2, 32, 40, 40, 40])
        # torch.Size([2, 64, 20, 20, 20])
        # torch.Size([2, 128, 10, 10, 10])

    Class method overview:

        - ``get_prob_a()``: convert learnable architecture weights to path activation probabilities.
        - ``get_ram_cost_usage()``: get estimated ram cost.
        - ``get_topology_entropy()``: get topology entropy loss in searching stage.
        - ``decode()``: get final binarized architecture code.
        - ``gen_mtx()``: generate variables needed for topology search.

    Predefined variables:
        - `tidx`: index used to convert path activation matrix T = (depth,depth) in transfer_mtx to
          path activation arch_code (1,3*depth-2), for depth = 4, tidx = [0, 1, 4, 5, 6, 9, 10, 11, 14, 15],
          A tidx (10 binary values) represents the path activation.
        - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.
          It is used to convert path activation pattern (1, paths) to node activation (1, nodes)
        - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation
          patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).
        - `all_connect`: All possible path activations. For depth = 4,
          all_connection has 1024 vectors of length 10 (10 paths).
          The return value will exclude path activation of all 0.
    z
list[list]node2outnode2inr   Nr   rF   rG   rH   rI   Tr   r   r   rJ   rK   r   r   )	r   r   r   r   rP   rQ   rR   r   r   c                   s  t  j|||||||||	|
d
 g }tj}t|j d D ]8}||d | j |d |  d |d |   q<|_|\}}}t	
|_fddttjD _|_t	
|_t	jtjjf_tjD ]p ttjD ]\j f dkr t	 fddjt f jjdj D j f< q qttjtjjdd	j ! _"ttjtjd
d	j ! _#ddg_$dS )r   r   rS   r4   c                   s   i | ]}t  j| |qS r   )r   node_act_listr7   r   r   r   r   
<dictcomp>	  s      z+TopologySearch.__init__.<locals>.<dictcomp>c                   s(   g | ] }|j jt f jj  qS r   )rA   r   r   r   )r7   r   r   r   r   r   r   r9     s   z+TopologySearch.__init__.<locals>.<listcomp>Ng{Gz?r   log_alpha_alog_alpha_c)%r*   r+   ro   r   r   r   re   tidxgen_mtxr`   asarrayr   rb   node_act_dicttransfer_mtx
child_listzerosr   r   r   rA   r   arrayr   r   r   r_   r-   	Parameterr,   normal_rj   r   requires_grad_r   r   _arch_param_names)r   r   r   r   r   r   rP   rQ   rR   r   r   r   _dr   r   r   r   r0   r   r   r+     sV    6 &zTopologySearch.__init__)depthc                   s   t j| d }td|d }g }|D ]`}t||f}t|D ]:}|| ||d t j |d t j d |d t j  f< q>|| q$td|d dd }i }	|D ](  fdd|D }
|
|	tt < q|	||dd fS )a  
        Generate elements needed in decoding and topology.

            - `transfer_mtx`: feasible path activation matrix (denoted as T) given a node activation pattern.
               It is used to convert path activation pattern (1, paths) to node activation (1, nodes)
            - `node_act_list`: all node activation [2^num_depths-1, depth]. For depth = 4, there are 15 node activation
               patterns, each of length 4. For example, [1,1,0,0] means nodes 0, 1 are activated (with input paths).
            - `all_connect`: All possible path activations. For depth = 4,
              all_connection has 1024 vectors of length 10 (10 paths).
              The return value will exclude path activation of all 0.
        rS   r   r4   Nc                   s4   g | ],}t |d d ktt  k r|qS r5   )r`   sumastyperJ   r   allr6   r   r   r   r9   F  s     $ z*TopologySearch.gen_mtx.<locals>.<listcomp>)	ro   r   r:   r`   r   r   re   r   r   )r   r   r<   Zall_connectmtxr   mar   r   r   Zarch_code_mtxr   r  r   r   $  s    8zTopologySearch.gen_mtxc                   s    fdd   D S )Nc                   s   g | ]\}}| j kr|qS r   )r   r   r   r   r   r9   L  s     
 z4TopologySearch.weight_parameters.<locals>.<listcomp>r   r   r   r   r   r   K  s    z TopologySearch.weight_parametersFr=   c                   sz   t | j dd  d  d }|rrt | j| j fddt	| j
D }t |}||fS d|fS )a  
        Get final path and child model probabilities from architecture weights `log_alpha_a`.
        This is used in forward pass, getting training loss, and final decoding.

        Args:
            child: return child probability (used in decoding)
        Return:
            arch_code_prob_a: the path activation probability of size:
                `[number of blocks, number of paths in each block]`.
                For 12 blocks, 4 depths search space, the size is [12,10]
            probs_a: The probability of all child models (size 1023x10). Each child model is a path activation pattern
                 (1D vector of length 10 for 10 paths). In total 1023 child models (2^10 -1)
        r4   r   c                   s:   g | ]2} |  d  d  |     d|  qS )r4   r   )prod)r7   r   Z_arch_code_prob_anormZpath_activationr   r   r9   b  s   
z-TopologySearch.get_prob_a.<locals>.<listcomp>N)r,   sigmoidr   r  	unsqueezer   r   rj   r   r   r   stack)r   r=   arch_code_prob_aprobs_ar   r	  r   
get_prob_aN  s    
zTopologySearch.get_prob_a)fullc              	   C  s>  |d }t || j d }g }t| jD ](}||| j|  |d|     q,tj	|tj
| jddt| j  }| jdd\}}tj| jdd}	|r| }|d	 t| jjtj
| jd}
d
}t| jD ]V}tt| jD ]B}||||f d	|
||f |	||f     || j|   7 }qq|d d d S )a  
        Get estimated output tensor size to approximate RAM consumption.

        Args:
            in_size: input image shape (4D/5D, ``[BCHW[D]]``) at the highest resolution level.
            full: full ram cost usage with all probability of 1.
        r   NrS   r   Fr  r   dimr4   ri   r      i   )r`   r   r[   r   r   re   r   r  r,   r   float32r   rJ   r   r  r   softmaxr   detachfill_r   rA   rj   r   rb   r   r  )r   in_sizer  
batch_size
image_sizesizesr   r  r  Z	cell_probrA   usager   path_idxr   r   r   get_ram_cost_usagen  s.    	&"

z!TopologySearch.get_ram_cost_usagec              	   C  s  t | dr| j}| j}n
dd tt| jD }dd tt| jD }tt| jD ]}t| j	t| j	 }}tt| j
D ]D}|| j
|   | j| | 7  < || j|   | j| | 7  < q|dkt}|dkt}|| jt|  | || jt|  | qX|| _|| _d}t| jd D ]}	d}
tt| jD ]d}||	|| f  }||	d || f  }|
|t|d  d| td| d    7 }
qL||
7 }q6|S )z
        Get topology entropy loss at searching stage.

        Args:
            probs: path activation probabilities
        r   c                 S  s   g | ]}g qS r   r   r6   r   r   r   r9     s     z7TopologySearch.get_topology_entropy.<locals>.<listcomp>c                 S  s   g | ]}g qS r   r   r6   r   r   r   r9     s     r4   r   h㈵>)hasattrr   r   r   rb   r   r   r`   r   r   r   r   r  rJ   r   r   re   r   r  r,   log)r   probsr   r   	child_idxZ_node_inZ	_node_outr   entr   Zblk_entnode_idxZ_node_pZ
_out_probsr   r   r   get_topology_entropy  s2    

 "6z#TopologySearch.get_topology_entropyc              	     s   j dd\}} jt|dj   }tt j	ddj  }|j  }t
dt j j  d dt j j  d f} fddtt jD }t
t jt jf}tt jD ]}t
 jt}	tt j| D ]$}
|	 j|
    j| |
 7  < q|	dkt}	 jt|	 D ].}|t|  j t }d|||f< q>qt
|d d  d	 |dddt j f< td jD ]}|t
t
|| d  d	 t jdf |d|d t j  d|t j  d|t j  d|d t j  f< qd	|d jd t j  d jt j  df< t|}t|ddddd
\}}}d\}}t
 jt jf}t
 jd  jf}|| }|dkrΐq:|d t j } j| ||ddf< tt jD ](}|| j| f  |||f 7  < q|d8 }qtt jD ](}|| j| f  |d|f 7  < qH|dkt}||||fS )a  
        Decode network log_alpha_a/log_alpha_c using dijkstra shortest path algorithm.

        `[node_a, arch_code_a, arch_code_c, arch_code_a_max]` is decoded when using ``self.decode()``.

        For example, for a ``num_depths=4``, ``num_blocks=12`` search space:

            - ``node_a`` is a 4x13 binary matrix representing if a feature node is activated
              (13 because of multi-resolution inputs).
            - ``arch_code_a`` is a 12x10 (10 paths) binary matrix representing if a path is activated.
            - ``arch_code_c`` is a 12x10x5 (5 operations) binary matrix representing if a cell operation is used.

        Return:
            arch_code with maximum probability
        Tr  r   r4   c                   s   i | ]}t  j| |qS r   )r   r   r   r   r   r   r     s      z)TopologySearch.decode.<locals>.<dictcomp>r   r   gMbP?)csgraphdirectedindicesmin_onlyreturn_predecessors)r   r   N)r  r   r,   argmaxr   r   numpyr   r  r   r`   r   rb   r   r   r   r  rJ   r   r   r   flattenr   r"  tiler   r   r   )r   r#  r  Zarch_code_a_maxrf   ZamtxZ
path2childZsub_amtxr$  Z	_node_actr  r  Zconnect_child_idxr   graphdist_matrixpredecessorssourcesindexZa_idxr   r   r   r   r   r   decode  sf     .".,  
4    
&&zTopologySearch.decodec           
   	   C  s   | j dd\}}|}t| jD ]}dg| j }t| j| j  D ]f\}}|rFt	j
| j||f dd}	|| j|   | jt||f || j|  |	d|||f  7  < qF|}q|S )z
        Prediction based on dynamic arch_code.

        Args:
            x: a list of `num_depths` input tensors as a multi-resolution input.
                tensor is of shape `BCHW[D]` where `C` must match `self.filter_nums`.
        Fr  ri   r   r  )r   )r  r   r   r   rk   r   r   r   r.  r   r  r   r   r   r   r   )
r   r   r  r  r   r   r   r   r   _wr   r   r   r      s    	 "
zTopologySearch.forward)F)F)r"   r#   r$   r%   __annotations__ro   r+   r   r   r  r  r'  r6  r    r3   r   r   r0   r   r     s*   
-
$@' !&L)1
__future__r   r   r   typingr   r.  r`   r,   torch.nnr-   torch.nn.functional
functionalr   Z!monai.networks.blocks.dints_blockr   r   r   r   monai.networks.layers.factoriesr   monai.networks.layers.utilsr	   r
   monai.utilsr   r   r8   r   __all__jit	interfaceModuler   r&   r)   r:   Identityr>   rE   rT   rW   rZ   r\   ro   r   r   r   r   r   r   r   r   <module>   s@   "{ Ju8