U
    Ph8o                     @  sZ  d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlm	Z	 d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZmZmZmZmZmZ ed\ZZed\ZZddddddddddddgZ G dd de	j!Z"G dd de	j!Z#G dd de	j!Z$G dd de	j!Z%dddd d d!d dd"d#d$Z&dDddddd&d'dZ'dddd(d)dZ(G d*d de	j!Z)G d+d de	j!Z*ej+dfd,dd-d.d/Z,dEdd,d d2dd3d4dZ-G d5d de	j!Z.G d6d de	j!Z/G d7d8 d8eZ0G d9d de	j!Z1G d:d; d;e	j!Z2G d<d= d=e2Z3G d>d? d?e2Z4G d@dA dAe2Z5G dBdC dCe5Z6dS )F    )annotationsNdeepcopy)Sequence)nn)Function)NdarrayOrTensor)gaussian_1d)Conv)ChannelMatchingSkipModeconvert_to_tensorensure_tuple_repissequenceiterablelook_up_optionoptional_importpytorch_afterzmonai._Cz	torch.fft
ChannelPadFlattenGaussianFilterHilbertTransformLLTMMedianFilterReshapeSavitzkyGolayFilterSkipConnectionapply_filtermedian_filterseparable_filteringc                      sB   e Zd ZdZejfddddd fddZdddd	d
Z  ZS )r   z
    Expand the input tensor's channel dimension from length `in_channels` to `out_channels`,
    by padding or a projection.
    intzChannelMatching | str)spatial_dimsin_channelsout_channelsmodec           	        s   t    d| _d| _||kr"dS t|t}|tjkrXttj|f }|||dd| _dS |tj	kr||krrt
d|| d }|| | }ddg| ||g ddg }t|| _dS dS )a  

        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels.
            mode: {``"pad"``, ``"project"``}
                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

                - ``"pad"``: with zero padding.
                - ``"project"``: with a trainable conv with kernel size one.
        N   )kernel_sizezKIncompatible values: channel_matching="pad" and in_channels > out_channels.   r   )super__init__projectpadr   r   PROJECTr
   ZCONVPAD
ValueErrortuple)	selfr    r!   r"   r#   	conv_typeZpad_1Zpad_2r*   	__class__ W/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/layers/simplelayers.pyr(   >   s$    




zChannelPad.__init__torch.Tensorxreturnc                 C  s6   | j d k	rt|  |S | jd k	r2t|| jS |S N)r)   torch	as_tensorr*   Fr/   r7   r3   r3   r4   forward`   s
    

zChannelPad.forward)	__name__
__module____qualname____doc__r   r,   r(   r>   __classcell__r3   r3   r1   r4   r   8   s   "c                      s<   e Zd ZdZddddd fdd	Zd
d
dddZ  ZS )r   z
    Combine the forward pass input with the result from the given submodule::

        --+--submodule--o--
          |_____________|

    The available modes are ``"cat"``, ``"add"``, ``"mul"``.
    r$   catr   zstr | SkipModeNone)dimr#   r8   c                   s(   t    || _|| _t|tj| _dS )a  

        Args:
            submodule: the module defines the trainable branch.
            dim: the dimension over which the tensors are concatenated.
                Used when mode is ``"cat"``.
            mode: ``"cat"``, ``"add"``, ``"mul"``. defaults to ``"cat"``.
        N)r'   r(   	submodulerF   r   r   valuer#   )r/   rG   rF   r#   r1   r3   r4   r(   r   s    	
zSkipConnection.__init__r5   r6   c                 C  sj   |  |}| jdkr(tj||g| jdS | jdkr>t||S | jdkrTt||S td| j dd S )NrD   rF   addmulzUnsupported mode .)rG   r#   r:   rD   rF   rJ   rK   NotImplementedError)r/   r7   yr3   r3   r4   r>      s    



zSkipConnection.forward)r$   rD   r?   r@   rA   rB   r(   r>   rC   r3   r3   r1   r4   r   h   s   	c                   @  s    e Zd ZdZdddddZdS )r   zM
    Flattens the given input in the forward pass to be [B,-1] in shape.
    r5   r6   c                 C  s   | |ddS )Nr   )viewsizer=   r3   r3   r4   r>      s    zFlatten.forwardN)r?   r@   rA   rB   r>   r3   r3   r3   r4   r      s   c                      s8   e Zd ZdZddd fddZdddd	d
Z  ZS )r   zk
    Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.
    r   rE   )shaper8   c                   s   t    dt| | _dS )a  
        Given a shape list/tuple `shape` of integers (s0, s1, ... , sn), this layer will reshape input tensors of
        shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ... , sn).

        Args:
            shape: list/tuple of integer shape dimensions
        )r$   N)r'   r(   r.   rS   )r/   rS   r1   r3   r4   r(      s    
zReshape.__init__r5   r6   c                 C  s"   t | j}|jd |d< ||S )Nr   )listrS   reshape)r/   r7   rS   r3   r3   r4   r>      s    
zReshape.forwardrO   r3   r3   r1   r4   r      s   r5   zlist[torch.Tensor]strr   z	list[int])input_kernelspad_modedr    paddingsnum_channelsr8   c              	   C  s   |dk r| S dgt | j }d||d < || |}| dkrf|d dkrft| |||d |||S ||dgdg|  }dg| }	|| |	|< tjtjtj	g|d  }
dd t
|	D }t|g }tj| ||d}|
t||||d |||||dS )	Nr   r$   rP   r&   c                 S  s   g | ]}||gqS r3   r3   ).0pr3   r3   r4   
<listcomp>   s     z-_separable_filtering_conv.<locals>.<listcomp>r#   )inputweightgroups)lenrS   rU   numel_separable_filtering_convrepeatr<   conv1dconv2dconv3dreversedsumr*   )rW   rX   rY   rZ   r    r[   r\   s_kernel_paddingr0    _reversed_padding_repeated_twiceZ$_sum_reversed_padding_repeated_twicepadded_inputr3   r3   r4   rf      s&    	

rf   zeros)r7   rX   r#   r8   c                   s   t  tjs"tdt j dt jd }t |tjrF|g| } fdd|D }dd |D } jd }|dkr|d	n|}t |||d |||S )
a1  
    Apply 1-D convolutions along each spatial dimension of `x`.

    Args:
        x: the input image. must have shape (batch, channels, H[, W, ...]).
        kernels: kernel along each spatial dimension.
            could be a single kernel (duplicated for all spatial dimensions), or
            a list of `spatial_dims` number of kernels.
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.

    Raises:
        TypeError: When ``x`` is not a ``torch.Tensor``.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import separable_filtering
        >>> img = torch.randn(2, 4, 32, 32)  # batch_size 2, channels 4, 32x32 2D images
        # applying a [-1, 0, 1] filter along each of the spatial dimensions.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
        # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])

     x must be a torch.Tensor but is rL   r&   c                   s   g | ]}|  qS r3   )tor]   rm   r7   r3   r4   r_      s     z'separable_filtering.<locals>.<listcomp>c                 S  s   g | ]}|j d  d d qS )r   r$   r&   )rS   r]   kr3   r3   r4   r_      s     r$   rr   constant)	
isinstancer:   Tensor	TypeErrortyper?   rd   rS   rf   )r7   rX   r#   r    _kernelsZ	_paddingsZn_chsrY   r3   rv   r4   r      s    

)r7   kernelr8   c           
      K  s  t | tjs"tdt| j d| j^}}}t|}|dkrNtd| dt|j}||k sl||d krt	d| d|d  d| j d|
| }|j||f|j|| d	  }|jd|jdd	  }| jd|jd f| } tjtjtjg|d  }d|kr>tddr d|d< ndd |jdd	 D |d< n6|d dkrttddstdd |jdd	 D |d< d|krd|d< || |f|jd d	d|}	|	j||f|	jdd	  S )a  
    Filtering `x` with `kernel` independently for each batch and channel respectively.

    Args:
        x: the input image, must have shape (batch, channels, H[, W, D]).
        kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
            `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
        kwargs: keyword arguments passed to `conv*d()` functions.

    Returns:
        The filtered `x`.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import apply_filter
        >>> img = torch.rand(2, 5, 10, 10)  # batch_size 2, channels 5, 10x10 2D images
        >>> out = apply_filter(img, torch.rand(3, 3))   # spatial kernel
        >>> out = apply_filter(img, torch.rand(5, 3, 3))  # channel-wise kernels
        >>> out = apply_filter(img, torch.rand(2, 5, 3, 3))  # batch-, channel-wise kernels

    rs   rL      z6Only spatial dimensions up to 3 are supported but got r&   zkernel must have z ~ z% dimensions to match the input shape NrP   r$   r   padding
   samec                 S  s   g | ]}|d  d qS r$   r&   r3   rw   r3   r3   r4   r_   ,  s     z apply_filter.<locals>.<listcomp>c                 S  s   g | ]}|d  d qS r   r3   rw   r3   r3   r4   r_   /  s     stride)rc   bias)rP   r$   )rz   r:   r{   r|   r}   r?   rS   rd   rM   r-   rt   expandrU   rQ   r<   rh   ri   rj   r   )
r7   r   kwargsbatchZchnsZspatialsZ	n_spatialZk_sizeconvoutputr3   r3   r4   r      s4    

 


c                      sJ   e Zd ZdZdddddd fddZd	d	d
ddZedd Z  ZS )r   aR  
    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.

    Args:
        window_length: Length of the filter window, must be a positive odd integer.
        order: Order of the polynomial to fit to each window, must be less than ``window_length``.
        axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
        ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
    r&   rr   r   rV   )window_lengthorderaxisr#   c                   s8   t    ||krtd|| _|| _| ||| _d S )Nz&order must be less than window_length.)r'   r(   r-   r   r#   _make_coeffscoeffs)r/   r   r   r   r#   r1   r3   r4   r(   C  s    
zSavitzkyGolayFilter.__init__r5   r6   c              	   C  s  t j|t|t jr|jndd}t |r2td|jt jd}| j	dk s^| j	t
|jd krxtd| j	 d|j d	t
|jd
 }| j	d
 }|}|| d }| jj|j|jdg}t|D ] }|dt jd|j|jd qt|D ]}|t jd|j|jd qt||| jdS )a  
        Args:
            x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and
                have a device type of ``'cpu'``.
        Returns:
            torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using
            polynomials of order ``self.order``, along axis specified in ``self.axis``.
        Ndevicex must be real.dtyper   r$   &Invalid axis for shape of x, got axis  and shape rL   r&   r   r   r`   )r:   r;   rz   r{   r   
is_complexr-   rt   floatr   rd   rS   r   r   rangeinsertonesappendr   r#   )r/   r7   Zn_spatial_dimsZspatial_processing_axisnew_dims_beforenew_dims_afterkernel_list_r3   r3   r4   r>   L  s      

zSavitzkyGolayFilter.forwardc                 C  s   t | d\}}|dkrtdtj| | d | d dtjdd}|tj|d tjdddd }tj|d tjdd}d|d< tdd	st||j	
 S tj||j	
 S )
Nr&   r   zwindow_length must be odd.r$   rP   cpur   r         ?   )divmodr-   r:   aranger   rU   rr   r   lstsqsolutionsqueezelinalg)r   r   Zhalf_lengthremidxarN   r3   r3   r4   r   m  s    $"z SavitzkyGolayFilter._make_coeffs)r&   rr   )	r?   r@   rA   rB   r(   r>   staticmethodr   rC   r3   r3   r1   r4   r   7  s
   	!c                      s<   e Zd ZdZddddd fdd	Zd
d
dddZ  ZS )r   a  
    Determine the analytical signal of a Tensor along a particular axis.

    Args:
        axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
        n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.
    r&   Nr   z
int | NonerE   )r   nr8   c                   s   t    || _|| _d S r9   )r'   r(   r   r   )r/   r   r   r1   r3   r4   r(     s    
zHilbertTransform.__init__r5   r6   c           
   
   C  s  t j|t|t jr|jndd}t |r2td|jt jd}| j	dk s^| j	t
|jd krxtd| j	 d|j d	| jdkr|j| j	 n| j}|dkrtd
t j|t jd}t t t jd|d d d |jdt|t t j|d  d|jdt|g}tj||| j	d}t |t jdg|jd}t j||j|jd}| j	}t
|j| j	 d }t|D ]}|d qjt|D ]}|d qtj|d | | j	d}	t j|	|	j|	jdS )a  
        Args:
            x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.
        Returns:
            torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using
            FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``.
        Nr   r   r   r   r$   r   r   rL   zN must be positive.r&   )r   rF   g      ?r   rP   rI   r   )r:   r;   rz   r{   r   r   r-   rt   r   r   rd   rS   r   	complex64rD   true_divider   fft	heavisidetensorr   r   
unsqueeze_ifft)
r/   r7   r   fxfur   r   r   htr3   r3   r4   r>     s4    
 
("zHilbertTransform.forward)r&   NrO   r3   r3   r1   r4   r   ~  s   zSequence[int])window_sizer8   c                 C  s@   t | tdd}t|}|df|}ttj|||d|S )zv
    Create a binary kernel to extract the patches.
    The window size HxWxD will create a (H*W*D)xHxWxD kernel.
    Twrap_sequencer$   r   )r   r   r:   proddiagr   rQ   )r   r   r   win_sizer   rm   r3   r3   r4   get_binary_kernel  s    
r   r   r   r   r   ztorch.Tensor | None)	in_tensorr%   r    r   r8   c                 K  s*  t | tjstdt|  | j}|dt||  || d  }}tt|t	dd}|dkrt
||}t|| j| j}n
|| }tjtjtjg|d  }	| j|df| }
dd t|jdd D }tj|
|d	d
}|	||fddd|}|j|df| }tj|ddd }||}|S )a  
    Apply median filter to an image.

    Args:
        in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions.
        kernel_size: the convolution kernel size.
        spatial_dims: number of spatial dimensions to apply median filtering.
        kernel: an optional customized kernel.
        kwargs: additional parameters to the `conv`.

    Returns:
        the filtered input tensor, shape remains the same as ``in_tensor``

    Example::

        >>> from monai.networks.layers import median_filter
        >>> import torch
        >>> x = torch.rand(4, 5, 7, 6)
        >>> output = median_filter(x, (3, 3, 3))
        >>> output.shape
        torch.Size([4, 5, 7, 6])

    z&Input type is not a torch.Tensor. Got NTr   r$   c                 S  s&   g | ]}t d D ]}|d d  qqS )r&   r$   )r   )r]   rx   r   r3   r3   r4   r_     s     
  z!median_filter.<locals>.<listcomp>r&   	replicate)r*   r#   r   )r   r   rP   rI   )rz   r:   r{   r|   r}   rS   rd   r   r   r   r   r   r   r   rt   r<   rh   ri   rj   rU   rk   r*   rQ   median)r   r%   r    r   r   original_shapeZoshapeZsshapeZoprodr   Zreshaped_inputr   rq   featuresr   r3   r3   r4   r     s$    $


c                      s>   e Zd ZdZddddd fdd	ZddddddZ  ZS )r   a  
    Apply median filter to an image.

    Args:
        radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3).

    Returns:
        filtered input tensor.

    Example::

        >>> from monai.networks.layers import MedianFilter
        >>> import torch
        >>> in_tensor = torch.rand(4, 5, 7, 6)
        >>> blur = MedianFilter([1, 1, 1])  # 3x3x3 kernel
        >>> output = blur(in_tensor)
        >>> output.shape
        torch.Size([4, 5, 7, 6])

    r   r   zSequence[int] | intr   rE   )radiusr    r8   c                   sB   t    || _t||| _dd | jD | _t| j|d| _d S )Nc                 S  s   g | ]}d dt |  qS r   r   )r]   rr3   r3   r4   r_     s     z)MedianFilter.__init__.<locals>.<listcomp>r   )r'   r(   r    r   r   windowr   r   )r/   r   r    r   r1   r3   r4   r(     s
    
zMedianFilter.__init__r$   r5   )r   r8   c                 C  s(   |}t |D ]}t|| j| jd}q|S )z
        Args:
            in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions.
            number_of_passes: median filtering will be repeated this many times
        )r   r    )r   r   r   r    )r/   r   Znumber_of_passesr7   r   r3   r3   r4   r>     s    zMedianFilter.forward)r   r   )r$   rO   r3   r3   r1   r4   r     s   c                      s>   e Zd Zddddddd	d
 fddZdddddZ  ZS )r         @erfFr   z?Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensorr   rV   boolrE   )r    sigma	truncatedapproxrequires_gradr8   c                   s   t rt|kr0tnfddt|D t    fddD | _|| _|| _t	| jD ]\}}| 
d| | qddS )a>  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
                must have shape (Batch, channels, H[, W, ...]).
            sigma: std. could be a single value, or `spatial_dims` number of values.
            truncated: spreads how many stds.
            approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

                - ``erf`` approximation interpolates the error function;
                - ``sampled`` uses a sampled Gaussian kernel;
                - ``scalespace`` corresponds to
                  https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
                  based on the modified Bessel functions.

            requires_grad: whether to store the gradients for sigma.
                if True, `sigma` will be the initial value of the parameters of this module
                (for example `parameters()` iterator could be used to get the parameters);
                otherwise this module will fix the kernels using `sigma` as the std.
        c                   s   g | ]}t  qS r3   r   r]   r   )r   r3   r4   r_   K  s     z+GaussianFilter.__init__.<locals>.<listcomp>c              	     s<   g | ]4}t jjt j|t jt|t jr*|jnd d dqS )Nr   r   )r:   r   	Parameterr;   r   rz   r{   r   ru   r   r3   r4   r_   M  s
   "Zkernel_sigma_N)r   rd   r-   r   r'   r(   r   r   r   	enumerateregister_parameter)r/   r    r   r   r   r   r   paramr1   )r   r   r4   r(   ,  s    

zGaussianFilter.__init__r5   r6   c                   s     fdd j D }t||dS )zG
        Args:
            x: in shape [Batch, chns, H, W, D].
        c                   s   g | ]}t | j jd qS ))r   r   )r	   r   r   ru   r/   r3   r4   r_   ^  s     z*GaussianFilter.forward.<locals>.<listcomp>)r7   rX   )r   r   )r/   r7   rn   r3   r   r4   r>   Y  s    zGaussianFilter.forward)r   r   F)r?   r@   rA   r(   r>   rC   r3   r3   r1   r4   r   *  s
      -c                   @  s$   e Zd Zedd Zedd ZdS )LLTMFunctionc           
      C  sF   t |||||}|d d \}}|dd  |g }	| j|	  ||fS )Nr&   r$   )_CZlltm_forwardsave_for_backward)
ctxra   weightsr   Zold_hZold_celloutputsZnew_hZnew_cell	variablesr3   r3   r4   r>   d  s
    
zLLTMFunction.forwardc           	      C  s@   t j| | f| j }|d d \}}}}}|||||fS )N   )r   Zlltm_backward
contiguoussaved_tensors)	r   Zgrad_hZ	grad_cellr   Zd_old_hd_inputZ	d_weightsd_biasZ
d_old_cellr3   r3   r4   backwardm  s    zLLTMFunction.backwardN)r?   r@   rA   r   r>   r   r3   r3   r3   r4   r   b  s   
r   c                      s8   e Zd ZdZddd fddZdd Zdd	 Z  ZS )
r   aF  
    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
    Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.
    It has both C++ and CUDA implementation, automatically switch according to the
    target device where put this module to.

    Args:
        input_features: size of input feature data
        state_size: size of the state of recurrent unit

    Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html
    r   )input_features
state_sizec                   sV   t    || _|| _ttd| || | _ttdd| | _	| 
  d S )Nr   r$   )r'   r(   r   r   r   r   r:   emptyr   r   reset_parameters)r/   r   r   r1   r3   r4   r(     s    
zLLTM.__init__c                 C  s4   dt | j }|  D ]}|j| |
  qd S )Nr   )mathsqrtr   
parametersdatauniform_)r/   stdvrb   r3   r3   r4   r     s    zLLTM.reset_parametersc                 C  s   t j|| j| jf| S r9   )r   applyr   r   )r/   ra   stater3   r3   r4   r>     s    zLLTM.forward)r?   r@   rA   rB   r(   r   r>   rC   r3   r3   r1   r4   r   u  s   c                      s8   e Zd ZdZddd fddZdddd	d
Z  ZS )ApplyFilterz,Wrapper class to apply a filter to an image.r   rE   )filterr8   c                   s   t    t|tjd| _d S )Nr   )r'   r(   r   r:   float32r   )r/   r   r1   r3   r4   r(     s    
zApplyFilter.__init__r5   r6   c                 C  s   t || jS r9   )r   r   r=   r3   r3   r4   r>     s    zApplyFilter.forwardrO   r3   r3   r1   r4   r     s   r   c                      s*   e Zd ZdZdddd fddZ  ZS )
MeanFilterz
    Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image.
    The mean filter used, is a `torch.Tensor` of all ones.
    r   rE   r    rR   r8   c                   s&   t |g| }|}t j|d dS )
        Args:
            spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
            size: edge length of the filter
        r   N)r:   r   r'   r(   )r/   r    rR   r   r1   r3   r4   r(     s    zMeanFilter.__init__r?   r@   rA   rB   r(   rC   r3   r3   r1   r4   r     s   r   c                      s*   e Zd ZdZdddd fddZ  ZS )LaplaceFilterz
    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.
    The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value
    which is `size` ** `spatial_dims`
    r   rE   r   c                   sL   t |g|  d }t|d g| }|| d ||< t j|d dS )r   r$   r&   r   N)r:   rr   r   r.   r'   r(   )r/   r    rR   r   center_pointr1   r3   r4   r(     s    zLaplaceFilter.__init__r   r3   r3   r1   r4   r     s   r   c                      s*   e Zd ZdZdddd fddZ  ZS )EllipticalFilterz
    Elliptical filter, can be used to dilate labels or label-contours.
    The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1`
    r   rE   r   c                   sb   d  t jfddt|D  }t  fdd|D dd}| d k}t j|d dS )r   r&   c                   s   g | ]}t d  qS )r   )r:   r   r   )rR   r3   r4   r_     s     z-EllipticalFilter.__init__.<locals>.<listcomp>c                   s   g | ]}|  d  qS )r&   r3   )r]   r   )r   r3   r4   r_     s     r   r   N)r:   meshgridr   stackrl   r'   r(   )r/   r    rR   gridZsquared_distancesr   r1   )r   rR   r4   r(     s
     zEllipticalFilter.__init__r   r3   r3   r1   r4   r     s   r   c                      s*   e Zd ZdZdddd fddZ  ZS )SharpenFilterz
    Convolutional filter to sharpen a 2D or 3D image.
    The filter used contains a circle/sphere of `-1`, with the center value being
    the absolute sum of all non-zero elements in the kernel
    r   rE   r   c                   sH   t  j||d t|d g| }| j }|  jd9  _|| j|< dS )r   )r    rR   r&   rP   N)r'   r(   r.   r   rl   )r/   r    rR   r   Zcenter_valuer1   r3   r4   r(     s
    
zSharpenFilter.__init__r   r3   r3   r1   r4   r    s   r  )rr   )r   r   N)7
__future__r   r   copyr   typingr   r:   torch.nn.functionalr   
functionalr<   Ztorch.autogradr   monai.config.type_definitionsr   Zmonai.networks.layers.convutilsr	   Zmonai.networks.layers.factoriesr
   monai.utilsr   r   r   r   r   r   r   r   r   r   r   __all__Moduler   r   r   r   rf   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r  r3   r3   r3   r4   <module>   sb   (0$	%-:G<   <)8 