o
    %i?m                     @  s  d dl mZ d dlZd dlmZ d dlmZ d dlZd dlm	  m
Z d dlm	Z	 d dlmZ d dlmZ d dlmZ d d	lmZ d d
lmZmZmZmZmZmZmZ ed\ZZed\ZZg dZG dd de	j Z!G dd de	j Z"G dd de	j Z#G dd de	j Z$dOd#d$Z%dPdQd(d)Z&dRd+d,Z'G d-d. d.e	j Z(G d/d0 d0e	j Z)ej*dfdSd3d4Z+	5	6	dTdUd;d<Z,G d=d> d>e	j Z-G d?d@ d@e	j Z.G dAdB dBeZ/G dCdD dDe	j Z0G dEdF dFe	j Z1G dGdH dHe1Z2G dIdJ dJe1Z3G dKdL dLe1Z4G dMdN dNe4Z5dS )V    )annotationsNdeepcopy)Sequence)nn)Function)NdarrayOrTensor)gaussian_1d)Conv)ChannelMatchingSkipModeconvert_to_tensorensure_tuple_repissequenceiterablelook_up_optionoptional_importzmonai._Cz	torch.fft)
ChannelPadFlattenGaussianFilterHilbertTransformLLTMMedianFilterReshapeSavitzkyGolayFilterSkipConnectionapply_filtermedian_filterseparable_filteringc                      s2   e Zd ZdZejfd fdd	ZdddZ  ZS )r   z
    Expand the input tensor's channel dimension from length `in_channels` to `out_channels`,
    by padding or a projection.
    spatial_dimsintin_channelsout_channelsmodeChannelMatching | strc           	        s   t    d| _d| _||krdS t|t}|tjkr,ttj|f }|||dd| _dS |tj	krZ||kr9t
d|| d }|| | }ddg| ||g ddg }t|| _dS dS )a  

        Args:
            spatial_dims: number of spatial dimensions of the input image.
            in_channels: number of input channels.
            out_channels: number of output channels.
            mode: {``"pad"``, ``"project"``}
                Specifies handling residual branch and conv branch channel mismatches. Defaults to ``"pad"``.

                - ``"pad"``: with zero padding.
                - ``"project"``: with a trainable conv with kernel size one.
        N   )kernel_sizezKIncompatible values: channel_matching="pad" and in_channels > out_channels.   r   )super__init__projectpadr   r   PROJECTr
   CONVPAD
ValueErrortuple)	selfr   r    r!   r"   	conv_typeZpad_1Zpad_2r*   	__class__ d/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/layers/simplelayers.pyr(   =   s&   




zChannelPad.__init__xtorch.Tensorreturnc                 C  s6   | j d urt|  |S | jd urt|| jS |S N)r)   torch	as_tensorr*   Fr0   r6   r4   r4   r5   forward_   s
   

zChannelPad.forward)r   r   r    r   r!   r   r"   r#   r6   r7   r8   r7   )	__name__
__module____qualname____doc__r   r-   r(   r>   __classcell__r4   r4   r2   r5   r   7   s
    "r   c                      .   e Zd ZdZdd fd
dZdddZ  ZS )r   z
    Combine the forward pass input with the result from the given submodule::

        --+--submodule--o--
          |_____________|

    The available modes are ``"cat"``, ``"add"``, ``"mul"``.
    r$   catdimr   r"   str | SkipModer8   Nonec                   s(   t    || _|| _t|tj| _dS )a  

        Args:
            submodule: the module defines the trainable branch.
            dim: the dimension over which the tensors are concatenated.
                Used when mode is ``"cat"``.
            mode: ``"cat"``, ``"add"``, ``"mul"``. defaults to ``"cat"``.
        N)r'   r(   	submodulerG   r   r   valuer"   )r0   rJ   rG   r"   r2   r4   r5   r(   q   s   
	zSkipConnection.__init__r6   r7   c                 C  sf   |  |}| jdkrtj||g| jdS | jdkrt||S | jdkr*t||S td| j d)NrF   rG   addmulzUnsupported mode .)rJ   r"   r:   rF   rG   rM   rN   NotImplementedError)r0   r6   yr4   r4   r5   r>      s   



zSkipConnection.forward)r$   rF   )rG   r   r"   rH   r8   rI   r?   r@   rA   rB   rC   r(   r>   rD   r4   r4   r2   r5   r   g   s    	r   c                   @  s   e Zd ZdZdddZdS )	r   zM
    Flattens the given input in the forward pass to be [B,-1] in shape.
    r6   r7   r8   c                 C  s   | |ddS )Nr   )viewsizer=   r4   r4   r5   r>      s   zFlatten.forwardNr?   )r@   rA   rB   rC   r>   r4   r4   r4   r5   r      s    r   c                      ,   e Zd ZdZd fddZdd
dZ  ZS )r   zk
    Reshapes input tensors to the given shape (minus batch dimension), retaining original batch size.
    shaper   r8   rI   c                   s   t    dt| | _dS )a  
        Given a shape list/tuple `shape` of integers (s0, s1, ... , sn), this layer will reshape input tensors of
        shape (batch, s0 * s1 * ... * sn) to shape (batch, s0, s1, ... , sn).

        Args:
            shape: list/tuple of integer shape dimensions
        r$   N)r'   r(   r/   rW   )r0   rW   r2   r4   r5   r(      s   
zReshape.__init__r6   r7   c                 C  s"   t | j}|jd |d< ||S )Nr   )listrW   reshape)r0   r6   rW   r4   r4   r5   r>      s   

zReshape.forward)rW   r   r8   rI   r?   rR   r4   r4   r2   r5   r      s    r   input_r7   kernelslist[torch.Tensor]pad_modestrdr   r   paddings	list[int]num_channelsr8   c              	   C  s   |dk r| S dgt | j }d||d < || |}| dkr3|d dkr3t| |||d |||S ||dgdg|  }dg| }	|| |	|< tjtjtj	g|d  }
dd t
|	D }t|g }tj| ||d}|
t||||d |||||dS )	Nr   r$   rS   r&   c                 S  s   g | ]}||gqS r4   r4   ).0pr4   r4   r5   
<listcomp>       z-_separable_filtering_conv.<locals>.<listcomp>r"   )inputweightgroups)lenrW   rZ   numel_separable_filtering_convrepeatr<   conv1dconv2dconv3dreversedsumr*   )r[   r\   r^   r`   r   ra   rc   s_kernel_paddingr1    _reversed_padding_repeated_twiceZ$_sum_reversed_padding_repeated_twicepadded_inputr4   r4   r5   rn      s&   	

rn   zerosr6   r"   c                   s   t  tjstdt j dt jd }t |tjr#|g| } fdd|D }dd |D } jd }|dkr>d	n|}t |||d |||S )
a1  
    Apply 1-D convolutions along each spatial dimension of `x`.

    Args:
        x: the input image. must have shape (batch, channels, H[, W, ...]).
        kernels: kernel along each spatial dimension.
            could be a single kernel (duplicated for all spatial dimensions), or
            a list of `spatial_dims` number of kernels.
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'``
            or ``'circular'``. Default: ``'zeros'``. See ``torch.nn.Conv1d()`` for more information.

    Raises:
        TypeError: When ``x`` is not a ``torch.Tensor``.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import separable_filtering
        >>> img = torch.randn(2, 4, 32, 32)  # batch_size 2, channels 4, 32x32 2D images
        # applying a [-1, 0, 1] filter along each of the spatial dimensions.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, torch.tensor((-1., 0., 1.)))
        # applying `[-1, 0, 1]`, `[1, 0, -1]` filters along two spatial dimensions respectively.
        # the output shape is the same as the input shape.
        >>> out = separable_filtering(img, [torch.tensor((-1., 0., 1.)), torch.tensor((1., 0., -1.))])

     x must be a torch.Tensor but is rO   r&   c                   s   g | ]}|  qS r4   )tord   ru   r6   r4   r5   rf      s    z'separable_filtering.<locals>.<listcomp>c                 S  s   g | ]}|j d  d d qS )r   r$   r&   )rW   )rd   kr4   r4   r5   rf          r$   rz   constant)	
isinstancer:   Tensor	TypeErrortyper@   rl   rW   rn   )r6   r\   r"   r   _kernelsZ	_paddingsZn_chsr^   r4   r~   r5   r      s   

r   kernelc           
      K  sf  t | tjstdt| j d| j^}}}t|}|dkr'td| dt|j}||k s6||d krGt	d| d|d  d| j d|
| }|j||g|j|| d	 R  }|jd
dg|jdd	 R  }| jd|jd g|R  } tjtjtjg|d  }d|vrd|d< d|vrd|d< || |f|jd d	d|}	|	j||g|	jdd	 R  S )a  
    Filtering `x` with `kernel` independently for each batch and channel respectively.

    Args:
        x: the input image, must have shape (batch, channels, H[, W, D]).
        kernel: `kernel` must at least have the spatial shape (H_k[, W_k, D_k]).
            `kernel` shape must be broadcastable to the `batch` and `channels` dimensions of `x`.
        kwargs: keyword arguments passed to `conv*d()` functions.

    Returns:
        The filtered `x`.

    Examples:

    .. code-block:: python

        >>> import torch
        >>> from monai.networks.layers import apply_filter
        >>> img = torch.rand(2, 5, 10, 10)  # batch_size 2, channels 5, 10x10 2D images
        >>> out = apply_filter(img, torch.rand(3, 3))   # spatial kernel
        >>> out = apply_filter(img, torch.rand(5, 3, 3))  # channel-wise kernels
        >>> out = apply_filter(img, torch.rand(2, 5, 3, 3))  # batch-, channel-wise kernels

    r{   rO      z6Only spatial dimensions up to 3 are supported but got r&   zkernel must have z ~ z% dimensions to match the input shape NrS   r$   r   paddingsamestride)rk   bias)r   r:   r   r   r   r@   rW   rl   rP   r.   r|   expandrZ   rT   r<   rp   rq   rr   )
r6   r   kwargsbatchZchnsZspatialsZ	n_spatialZk_sizeconvoutputr4   r4   r5   r      s,   

"r   c                      s:   e Zd ZdZdd fd
dZdddZedd Z  ZS )r   aR  
    Convolve a Tensor along a particular axis with a Savitzky-Golay kernel.

    Args:
        window_length: Length of the filter window, must be a positive odd integer.
        order: Order of the polynomial to fit to each window, must be less than ``window_length``.
        axis (optional): Axis along which to apply the filter kernel. Default 2 (first spatial dimension).
        mode (string, optional): padding mode passed to convolution class. ``'zeros'``, ``'reflect'``, ``'replicate'`` or
        ``'circular'``. Default: ``'zeros'``. See torch.nn.Conv1d() for more information.
    r&   rz   window_lengthr   orderaxisr"   r_   c                   s8   t    ||krtd|| _|| _| ||| _d S )Nz&order must be less than window_length.)r'   r(   r.   r   r"   _make_coeffscoeffs)r0   r   r   r   r"   r2   r4   r5   r(   ;  s   
zSavitzkyGolayFilter.__init__r6   r7   r8   c              	   C  s  t j|t|t jr|jndd}t |rtd|jt jd}| j	dk s/| j	t
|jd kr<td| j	 d|j d	t
|jd
 }| j	d
 }|}|| d }| jj|j|jdg}t|D ]}|dt jd|j|jd q_t|D ]}|t jd|j|jd qtt||| jdS )a  
        Args:
            x: Tensor or array-like to filter. Must be real, in shape ``[Batch, chns, spatial1, spatial2, ...]`` and
                have a device type of ``'cpu'``.
        Returns:
            torch.Tensor: ``x`` filtered by Savitzky-Golay kernel with window length ``self.window_length`` using
            polynomials of order ``self.order``, along axis specified in ``self.axis``.
        Ndevicex must be real.dtyper   r$   &Invalid axis for shape of x, got axis  and shape rO   r&   r   r   rh   )r:   r;   r   r   r   
is_complexr.   r|   floatr   rl   rW   r   r   rangeinsertonesappendr   r"   )r0   r6   Zn_spatial_dimsZspatial_processing_axisnew_dims_beforenew_dims_afterkernel_list_r4   r4   r5   r>   D  s     

zSavitzkyGolayFilter.forwardc                 C  s   t | d\}}|dkrtdtj| | d | d dtjdd}|tj|d tjdddd }tj|d tjdd}d|d< tj||j	
 S )	Nr&   r   zwindow_length must be odd.r$   rS   cpur   r         ?)divmodr.   r:   aranger   rZ   rz   linalglstsqsolutionsqueeze)r   r   Zhalf_lengthremidxarQ   r4   r4   r5   r   e  s   $"z SavitzkyGolayFilter._make_coeffs)r&   rz   )r   r   r   r   r   r   r"   r_   r?   )	r@   rA   rB   rC   r(   r>   staticmethodr   rD   r4   r4   r2   r5   r   /  s    
	!r   c                      rE   )r   a  
    Determine the analytical signal of a Tensor along a particular axis.

    Args:
        axis: Axis along which to apply Hilbert transform. Default 2 (first spatial dimension).
        n: Number of Fourier components (i.e. FFT size). Default: ``x.shape[axis]``.
    r&   Nr   r   n
int | Noner8   rI   c                   s   t    || _|| _d S r9   )r'   r(   r   r   )r0   r   r   r2   r4   r5   r(   {  s   

zHilbertTransform.__init__r6   r7   c           
   
   C  s  t j|t|t jr|jndd}t |rtd|jt jd}| j	dk s/| j	t
|jd kr<td| j	 d|j d	| jdu rG|j| j	 n| j}|dkrRtd
t j|t jd}t t t jd|d d d |jdt|t t j|d  d|jdt|g}tj||| j	d}t |t jdg|jd}t j||j|jd}| j	}t
|j| j	 d }t|D ]}|d qt|D ]}|d qtj|d | | j	d}	t j|	|	j|	jdS )a  
        Args:
            x: Tensor or array-like to transform. Must be real and in shape ``[Batch, chns, spatial1, spatial2, ...]``.
        Returns:
            torch.Tensor: Analytical signal of ``x``, transformed along axis specified in ``self.axis`` using
            FFT of size ``self.N``. The absolute value of ``x_ht`` relates to the envelope of ``x`` along axis ``self.axis``.
        Nr   r   r   r   r$   r   r   rO   zN must be positive.r&   )r   rG   g      ?r   rS   rL   r   )r:   r;   r   r   r   r   r.   r|   r   r   rl   rW   r   	complex64rF   true_divider   fft	heavisidetensorr   r   
unsqueeze_ifft)
r0   r6   r   fxfur   r   r   Zhtr4   r4   r5   r>     s4    

("zHilbertTransform.forward)r&   N)r   r   r   r   r8   rI   r?   rR   r4   r4   r2   r5   r   r  s    r   window_sizeSequence[int]c                 C  s@   t | tdd}t|}|dg|}ttj|||d|S )zv
    Create a binary kernel to extract the patches.
    The window size HxWxD will create a (H*W*D)xHxWxD kernel.
    Twrap_sequencer$   r   )r   r   r:   proddiagr   rT   )r   r   r   win_sizer   ru   r4   r4   r5   get_binary_kernel  s   
r   r   r   r   r   	in_tensorr%   Sequence[int] | inttorch.Tensor | Nonec                 K  s.  t | tjstdt|  | j}|dt||  || d }}tt|t	dd}|du r@t
||}t|| j| j}n|| }tjtjtjg|d  }	| j|dg|R  }
dd t|jdd D }tj|
|d	d
}|	||fddd|}|j|dg|R  }tj|ddd }||}|S )a  
    Apply median filter to an image.

    Args:
        in_tensor: input tensor; median filtering will be applied to the last `spatial_dims` dimensions.
        kernel_size: the convolution kernel size.
        spatial_dims: number of spatial dimensions to apply median filtering.
        kernel: an optional customized kernel.
        kwargs: additional parameters to the `conv`.

    Returns:
        the filtered input tensor, shape remains the same as ``in_tensor``

    Example::

        >>> from monai.networks.layers import median_filter
        >>> import torch
        >>> x = torch.rand(4, 5, 7, 6)
        >>> output = median_filter(x, (3, 3, 3))
        >>> output.shape
        torch.Size([4, 5, 7, 6])

    z&Input type is not a torch.Tensor. Got NTr   r$   c                 S  s&   g | ]}t d D ]}|d d  qqS )r&   r$   )r   )rd   r   r   r4   r4   r5   rf     s   & z!median_filter.<locals>.<listcomp>r&   	replicate)r*   r"   r   )r   r   rS   rL   )r   r:   r   r   r   rW   rl   r   r   r   r   r   r   r   r|   r<   rp   rq   rr   rZ   rs   r*   rT   median)r   r%   r   r   r   original_shapeZoshapeZsshapeZoprodr   Zreshaped_inputr   ry   featuresr   r4   r4   r5   r     s$   $


r   c                      s0   e Zd ZdZdd fd
dZddddZ  ZS )r   a  
    Apply median filter to an image.

    Args:
        radius: the blurring kernel radius (radius of 1 corresponds to 3x3x3 kernel when spatial_dims=3).

    Returns:
        filtered input tensor.

    Example::

        >>> from monai.networks.layers import MedianFilter
        >>> import torch
        >>> in_tensor = torch.rand(4, 5, 7, 6)
        >>> blur = MedianFilter([1, 1, 1])  # 3x3x3 kernel
        >>> output = blur(in_tensor)
        >>> output.shape
        torch.Size([4, 5, 7, 6])

    r   r   radiusr   r   r   r8   rI   c                   sB   t    || _t||| _dd | jD | _t| j|d| _d S )Nc                 S  s   g | ]
}d dt |  qS )r$   r&   r   )rd   rr4   r4   r5   rf     s    z)MedianFilter.__init__.<locals>.<listcomp>r   )r'   r(   r   r   r   windowr   r   )r0   r   r   r   r2   r4   r5   r(     s
   
zMedianFilter.__init__r$   r   r7   c                 C  s(   |}t |D ]}t|| j| jd}q|S )z
        Args:
            in_tensor: input tensor, median filtering will be applied to the last `spatial_dims` dimensions.
            number_of_passes: median filtering will be repeated this many times
        )r   r   )r   r   r   r   )r0   r   Znumber_of_passesr6   r   r4   r4   r5   r>     s   zMedianFilter.forward)r   r   )r   r   r   r   r8   rI   rX   )r   r7   r8   r7   rR   r4   r4   r2   r5   r     s    r   c                      s0   e Zd Z			dd fddZdddZ  ZS )r         @erfFr   r   sigma?Sequence[float] | float | Sequence[torch.Tensor] | torch.Tensor	truncatedr   approxr_   requires_gradboolr8   rI   c                   s   t rt|krtnfddt|D t    fddD | _|| _|| _t	| jD ]\}}| 
d| | q2dS )a>  
        Args:
            spatial_dims: number of spatial dimensions of the input image.
                must have shape (Batch, channels, H[, W, ...]).
            sigma: std. could be a single value, or `spatial_dims` number of values.
            truncated: spreads how many stds.
            approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

                - ``erf`` approximation interpolates the error function;
                - ``sampled`` uses a sampled Gaussian kernel;
                - ``scalespace`` corresponds to
                  https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
                  based on the modified Bessel functions.

            requires_grad: whether to store the gradients for sigma.
                if True, `sigma` will be the initial value of the parameters of this module
                (for example `parameters()` iterator could be used to get the parameters);
                otherwise this module will fix the kernels using `sigma` as the std.
        c                   s   g | ]}t  qS r4   r   rd   r   )r   r4   r5   rf   ?  rg   z+GaussianFilter.__init__.<locals>.<listcomp>c              	     s<   g | ]}t jjt j|t jt|t jr|jnd d dqS )Nr   r   )r:   r   	Parameterr;   r   r   r   r   r}   r   r4   r5   rf   A  s    "Zkernel_sigma_N)r   rl   r.   r   r'   r(   r   r   r   	enumerateregister_parameter)r0   r   r   r   r   r   r   paramr2   )r   r   r5   r(      s   

zGaussianFilter.__init__r6   r7   c                   s     fdd j D }t||dS )zG
        Args:
            x: in shape [Batch, chns, H, W, D].
        c                   s   g | ]}t | j jd qS ))r   r   )r	   r   r   r}   r0   r4   r5   rf   R  r   z*GaussianFilter.forward.<locals>.<listcomp>)r6   r\   )r   r   )r0   r6   rv   r4   r   r5   r>   M  s   zGaussianFilter.forward)r   r   F)r   r   r   r   r   r   r   r_   r   r   r8   rI   r?   )r@   rA   rB   r(   r>   rD   r4   r4   r2   r5   r     s    -r   c                   @  s$   e Zd Zedd Zedd ZdS )LLTMFunctionc           
      C  sF   t |||||}|d d \}}|dd  |g }	| j|	  ||fS )Nr&   r$   )_CZlltm_forwardsave_for_backward)
ctxri   weightsr   Zold_hZold_celloutputsZnew_hZnew_cell	variablesr4   r4   r5   r>   X  s
   
zLLTMFunction.forwardc           	      C  sB   t j| | g| jR  }|d d \}}}}}|||||fS )N   )r   Zlltm_backward
contiguoussaved_tensors)	r   Zgrad_hZ	grad_cellr   Zd_old_hd_inputZ	d_weightsd_biasZ
d_old_cellr4   r4   r5   backwarda  s   zLLTMFunction.backwardN)r@   rA   rB   r   r>   r   r4   r4   r4   r5   r   V  s
    
r   c                      s2   e Zd ZdZd fddZdd Zd	d
 Z  ZS )r   aF  
    This recurrent unit is similar to an LSTM, but differs in that it lacks a forget
    gate and uses an Exponential Linear Unit (ELU) as its internal activation function.
    Because this unit never forgets, call it LLTM, or Long-Long-Term-Memory unit.
    It has both C++ and CUDA implementation, automatically switch according to the
    target device where put this module to.

    Args:
        input_features: size of input feature data
        state_size: size of the state of recurrent unit

    Referring to: https://pytorch.org/tutorials/advanced/cpp_extension.html
    input_featuresr   
state_sizec                   sV   t    || _|| _ttd| || | _ttdd| | _	| 
  d S )Nr   r$   )r'   r(   r   r   r   r   r:   emptyr   r   reset_parameters)r0   r   r   r2   r4   r5   r(   x  s   
zLLTM.__init__c                 C  s4   dt | j }|  D ]}|j| |
  qd S )Nr   )mathsqrtr   
parametersdatauniform_)r0   stdvrj   r4   r4   r5   r     s   zLLTM.reset_parametersc                 C  s   t j|| j| jg|R  S r9   )r   applyr   r   )r0   ri   stater4   r4   r5   r>     s   zLLTM.forward)r   r   r   r   )r@   rA   rB   rC   r(   r   r>   rD   r4   r4   r2   r5   r   i  s
    r   c                      rV   )ApplyFilterz,Wrapper class to apply a filter to an image.filterr   r8   rI   c                   s   t    t|tjd| _d S )Nr   )r'   r(   r   r:   float32r   )r0   r   r2   r4   r5   r(     s   
zApplyFilter.__init__r6   r7   c                 C  s   t || jS r9   )r   r   r=   r4   r4   r5   r>     s   zApplyFilter.forward)r   r   r8   rI   r?   rR   r4   r4   r2   r5   r     s    r   c                      "   e Zd ZdZd	 fddZ  ZS )

MeanFilterz
    Mean filtering can smooth edges and remove aliasing artifacts in an segmentation image.
    The mean filter used, is a `torch.Tensor` of all ones.
    r   r   rU   r8   rI   c                   s&   t |g| }|}t j|d dS )
        Args:
            spatial_dims: `int` of either 2 for 2D images and 3 for 3D images
            size: edge length of the filter
        r   N)r:   r   r'   r(   )r0   r   rU   r   r2   r4   r5   r(     s   zMeanFilter.__init__r   r   rU   r   r8   rI   r@   rA   rB   rC   r(   rD   r4   r4   r2   r5   r        r  c                      r   )
LaplaceFilterz
    Laplacian filtering for outline detection in images. Can be used to transform labels to contours.
    The laplace filter used, is a `torch.Tensor` where all values are -1, except the center value
    which is `size` ** `spatial_dims`
    r   r   rU   r8   rI   c                   sL   t |g|  d }t|d g| }|| d ||< t j|d dS )r  r$   r&   r  N)r:   rz   r   r/   r'   r(   )r0   r   rU   r   center_pointr2   r4   r5   r(     s   zLaplaceFilter.__init__r  r  r4   r4   r2   r5   r        r  c                      r   )
EllipticalFilterz
    Elliptical filter, can be used to dilate labels or label-contours.
    The elliptical filter used here, is a `torch.Tensor` with shape (size, ) * ndim containing a circle/sphere of `1`
    r   r   rU   r8   rI   c                   sb   d  t jfddt|D  }t  fdd|D dd}| d k}t j|d dS )r  r&   c                   s   g | ]}t d  qS )r   )r:   r   r   )rU   r4   r5   rf         z-EllipticalFilter.__init__.<locals>.<listcomp>c                   s   g | ]}|  d  qS )r&   r4   )rd   r   )r   r4   r5   rf     r  r   r  N)r:   meshgridr   stackrt   r'   r(   )r0   r   rU   gridZsquared_distancesr   r2   )r   rU   r5   r(     s
    zEllipticalFilter.__init__r  r  r4   r4   r2   r5   r
    r  r
  c                      r   )
SharpenFilterz
    Convolutional filter to sharpen a 2D or 3D image.
    The filter used contains a circle/sphere of `-1`, with the center value being
    the absolute sum of all non-zero elements in the kernel
    r   r   rU   r8   rI   c                   sH   t  j||d t|d g| }| j }|  jd9  _|| j|< dS )r  )r   rU   r&   rS   N)r'   r(   r/   r   rt   )r0   r   rU   r  Zcenter_valuer2   r4   r5   r(     s
   
zSharpenFilter.__init__r  r  r4   r4   r2   r5   r    r	  r  )r[   r7   r\   r]   r^   r_   r`   r   r   r   ra   rb   rc   r   r8   r7   )rz   )r6   r7   r\   r]   r"   r_   r8   r7   )r6   r7   r   r7   r8   r7   )r   r   r8   r7   )r   r   N)
r   r7   r%   r   r   r   r   r   r8   r7   )6
__future__r   r   copyr   typingr   r:   torch.nn.functionalr   
functionalr<   Ztorch.autogradr   monai.config.type_definitionsr   Zmonai.networks.layers.convutilsr	   Zmonai.networks.layers.factoriesr
   monai.utilsr   r   r   r   r   r   r   r   r   r   __all__Moduler   r   r   r   rn   r   r   r   r   r   r   r   r   r   r   r   r   r  r  r
  r  r4   r4   r4   r5   <module>   sL   $
0$	
%
-3C<<)8 