o
    &i                     @  s^   d dl mZ d dlZd dlmZ d dlmZ d dlmZ dgZ	G dd dej
jZdd	 ZdS )
    )annotationsN)softmax)	PHLFilter)meshgrid_ijCRFc                      s>   e Zd ZdZ								dd fddZdddZ  ZS )r   a  
    Conditional Random Field: Combines message passing with a class
    compatibility convolution into an iterative process designed
    to successively minimise the energy of the class labeling.

    In this implementation, the message passing step is a weighted
    combination of a gaussian filter and a bilateral filter.
    The bilateral term is included to respect existing structure
    within the reference tensor.

    See:
        https://arxiv.org/abs/1502.03240
             ?      @      ?      @N
iterationsintbilateral_weightfloatgaussian_weightbilateral_spatial_sigmabilateral_color_sigmagaussian_spatial_sigmaupdate_factorcompatibility_matrixtorch.Tensor | Nonec	           	        s>   t    || _|| _|| _|| _|| _|| _|| _|| _	dS )a  
        Args:
            iterations: the number of iterations.
            bilateral_weight: the weighting of the bilateral term in the message passing step.
            gaussian_weight: the weighting of the gaussian term in the message passing step.
            bilateral_spatial_sigma: standard deviation in spatial coordinates for the bilateral term.
            bilateral_color_sigma: standard deviation in color space for the bilateral term.
            gaussian_spatial_sigma: standard deviation in spatial coordinates for the gaussian term.
            update_factor: determines the magnitude of each update.
            compatibility_matrix: a matrix describing class compatibility,
                should be NxN where N is the number of classes.
        N)
super__init__r   r   r   r   r   r   r   r   )	selfr   r   r   r   r   r   r   r   	__class__ [/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/blocks/crf.pyr   &   s   

zCRF.__init__input_tensortorch.Tensorreference_tensorc                 C  s   t |}tj|| j || j gdd}|| j }t|dd}t| jD ]E}t	
||}t	
||}	| j| | j|	  }
| jdur]|
jddddd}t|| j}|ddd|
j}
t|| j|
  dd}q#|S )z
        Args:
            input_tensor: tensor containing initial class logits.
            reference_tensor: the reference tensor used to guide the message passing.

        Returns:
            output (torch.Tensor): output tensor.
           dimN   )	start_dimr   )_create_coordinate_tensortorchcatr   r   r   r   ranger   r   applyr   r   r   flattenpermutematmulreshapeshaper   )r   r   r    Zspatial_featuresZbilateral_featuresZgaussian_featuresoutput_tensor_Zbilateral_outputZgaussian_outputZcombined_outputflatr   r   r   forwardG   s    

zCRF.forward)r   r   r   r	   r
   r	   r   N)r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   )r   r   r    r   )__name__
__module____qualname____doc__r   r3   __classcell__r   r   r   r   r      s    !c                   sV    fddt d  D }t|}t|j j jd}tj d|g ddS )Nc                   s   g | ]
}t  |qS r   )r'   arangesize).0itensorr   r   
<listcomp>t   s    z-_create_coordinate_tensor.<locals>.<listcomp>r$   )devicedtyper   r"   )	r)   r#   r   r'   stacktor@   rA   r:   )r>   axesgridscoordsr   r=   r   r&   s   s   r&   )
__future__r   r'   torch.nn.functionalr   Zmonai.networks.layers.filteringr   monai.networks.utilsr   __all__nnModuler   r&   r   r   r   r   <module>   s   \