o
    %i`                      @  s   d dl mZ d dlmZ d dlZd dlZg dZd-d.ddZd/ddZ	d0ddZ
	d1d2d d!Zd3d"d#Zd4d%d&Zd4d'd(Zd5d+d,ZdS )6    )annotations)SequenceN)same_paddingstride_minus_kernel_paddingcalculate_out_shapegaussian_1dpolyval   kernel_sizeSequence[int] | intdilationreturntuple[int, ...] | intc                 C  s~   t | }t |}t |d | d dkr"td| d| d|d d | }tdd |D }t|dkr;|S |d S )	aS  
    Return the padding value needed to ensure a convolution using the given kernel size produces an output of the same
    shape as the input for a stride of 1, otherwise ensure a shape of the input divided by the stride rounded down.

    Raises:
        NotImplementedError: When ``np.any((kernel_size - 1) * dilation % 2 == 1)``.

    r	      z+Same padding not available for kernel_size=z and dilation=.c                 s      | ]}t |V  qd S Nint.0p r   a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/layers/convutils.py	<genexpr>)       zsame_padding.<locals>.<genexpr>r   )np
atleast_1danyNotImplementedErrortuplelen)r
   r   kernel_size_npZdilation_np
padding_nppaddingr   r   r   r      s   


r   stridec                 C  sF   t | }t |}|| }tdd |D }t|dkr|S |d S )Nc                 s  r   r   r   r   r   r   r   r   3   r   z.stride_minus_kernel_padding.<locals>.<genexpr>r	   r   )r   r   r    r!   )r
   r%   r"   	stride_npZout_padding_npZout_paddingr   r   r   r   .   s
   

r   in_shape Sequence[int] | int | np.ndarrayr$   c           
      C  sV   t | }t |}t |}t |}|| | | | d }tdd |D }	|	S )a-  
    Calculate the output tensor shape when applying a convolution to a tensor of shape `inShape` with kernel size
    `kernel_size`, stride value `stride`, and input padding value `padding`. All arguments can be scalars or multiple
    values, return value is a scalar if all inputs are scalars.
    r	   c                 s  r   r   r   )r   sr   r   r   r   I   r   z&calculate_out_shape.<locals>.<genexpr>)r   r   r    )
r'   r
   r%   r$   Zin_shape_npr"   r&   r#   Zout_shape_np	out_shaper   r   r   r   8   s   



r         @erfFsigmatorch.Tensor	truncatedfloatapproxstr	normalizeboolc                 C  s  t j| t jt| t jr| jndd} | j}|dkr!td| dttt| | dd }|	 dkrat j
| |d t j|d}d	t |  }d||d   ||d     }|jd
d}nx|	 dkrt j
| |d t j| jd}t d| |   |d  }|s|d|   }nN|	 dkr| |  }	dg|d  }
t|	|
d
< t|	|
d< tdt|
D ]	}t||	|
|< q|
dd
d }||
 t |t |	  }ntd| d|r||  S |S )a  
    one dimensional Gaussian kernel.

    Args:
        sigma: std of the kernel
        truncated: tail length
        approx: discrete Gaussian kernel type, available options are "erf", "sampled", and "scalespace".

            - ``erf`` approximation interpolates the error function;
            - ``sampled`` uses a sampled Gaussian kernel;
            - ``scalespace`` corresponds to
              https://en.wikipedia.org/wiki/Scale_space_implementation#The_discrete_Gaussian_kernel
              based on the modified Bessel functions.

        normalize: whether to normalize the kernel with `kernel.sum()`.

    Raises:
        ValueError: When ``truncated`` is non-positive.

    Returns:
        1D torch tensor

    Ndtypedevice        z truncated must be positive, got r         ?r,   r	   g'e?r   )minZsampledg      r   g@Z
scalespacezUnsupported option: approx='z'.)torch	as_tensorr0   
isinstanceTensorr7   
ValueErrorr   maxlowerarangeabsr,   clampexp_modified_bessel_0_modified_bessel_1ranger!   _modified_bessel_iextendstackr   sum)r-   r/   r1   r3   r7   tailxtoutsigma2Zout_poskr   r   r   r   N   s8   $$
r   c                 C  s   t |tjr	|jnd}tj| tj|d} | jdkst| dk r%t|j	S tj|tj|d}| d }| dd D ]}|| | }q8|S )a  
    Evaluates the polynomial defined by `coef` at `x`.

    For a 1D sequence of coef (length n), evaluate::

        y = coef[n-1] + x * (coef[n-2] + ... + x * (coef[1] + x * coef[0]))

    Args:
        coef: a sequence of floats representing the coefficients of the polynomial
        x: float or a sequence of floats representing the variable of the polynomial

    Returns:
        1D torch tensor
    Nr5   r   r	   )
r>   r<   r?   r7   r=   r0   ndimr!   zerosshape)coefrO   r7   anscr   r   r   r      s   r   rO   c                 C  s   t j| t jt| t jr| jnd d} t | dk r&| |  d }tg d|S t | }d| }g d}t||t | t 	| S )Nr5         @      ,@)gtHZr?gIx?g2t?g,?N?g03@g$@      ?)	g;^p?gUL+ߐgZ?g'gTPÂ?gJNYgՒ+Hub?g-5? e3E?
r<   r=   r0   r>   r?   r7   rD   r   rF   sqrt)rO   yax_coefr   r   r   rG      s   $
rG   c                 C  s   t j| t jt| t jr| jnd d} t | dk r-| |  d }g d}t | t|| S t | }d| }g d}t||t | t 	| }| dk rP| S |S )Nr5   rZ   r[   )gӰ٩=5?g.h?gZ9?g*O?g(z?gY?r9   )	g;PJ4qgqJ:N?gP⥝g'8`?g<Q gtZOZ?g?Vmg.kr]   r8   r^   )rO   r`   rb   ra   rX   r   r   r   rH      s   $
rH   nr   c           
   	   C  s6  | dk rt d|  dtj|tjt|tjr|jnd d}|dkr$|S |j}dt| }tjd|dtjd|dtjd|d}}}t	d| t
t
d	|    }t|d
dD ](}|t|| |  }	|}|	}t|dkr}|d }|d }|d }|| kr|}q[|t| | }|dk r| d dkr| S |S )Nr   z n must be greater than 1, got n=r   r5   r8   g       @)r7   r\   g      D@r   r;   g    _Bg|=r	   )r@   r<   r=   r0   r>   r?   r7   rD   tensorr   r   floorr_   rI   rG   )
rc   rO   r7   ZtoxrX   ZbipbimjZbimr   r   r   rJ      s,   $. rJ   )r	   )r
   r   r   r   r   r   )r
   r   r%   r   r   r   )
r'   r(   r
   r   r%   r   r$   r   r   r   )r+   r,   F)
r-   r.   r/   r0   r1   r2   r3   r4   r   r.   )r   r.   )rO   r.   r   r.   )rc   r   rO   r.   r   r.   )
__future__r   collections.abcr   numpyr   r<   __all__r   r   r   r   r   rG   rH   rJ   r   r   r   r   <module>   s   



8

