o
    iI                     @  s   d Z ddlmZ ddlmZmZ ddlZddlmZmZ ddl	m
Z
 ddlmZ ddlmZ G d	d
 d
ejZG dd deZdS )z~
This script is adapted from
https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py
    )annotations)ListSequenceN)Tensornn)ensure_tuple)issequenceiterable)look_up_optionc                      sp   e Zd ZdZdeej iZ			d*d+ fddZej	dfd,ddZ
d-ddZdd Zd.d#d$Zd/d(d)Z  ZS )0AnchorGeneratora
  
    This module is modified from torchvision to support both 2D and 3D images.

    Module that generates anchors for a set of feature maps and
    image sizes.

    The module support computing anchors at multiple sizes and aspect ratios
    per feature map.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements.
    For 2D images, anchor width and height w:h = 1:aspect_ratios[i,j]
    For 3D images, anchor width, height, and depth w:h:d = 1:aspect_ratios[i,j,0]:aspect_ratios[i,j,1]

    AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Args:
        sizes: base size of each anchor.
            len(sizes) is the number of feature maps, i.e., the number of output levels for
            the feature pyramid network (FPN).
            Each element of ``sizes`` is a Sequence which represents several anchor sizes for each feature map.
        aspect_ratios: the aspect ratios of anchors. ``len(aspect_ratios) = len(sizes)``.
            For 2D images, each element of ``aspect_ratios[i]`` is a Sequence of float.
            For 3D images, each element of ``aspect_ratios[i]`` is a Sequence of 2 value Sequence.
        indexing: choose from {``'ij'``, ``'xy'``}, optional,
            Matrix (``'ij'``, default and recommended) or Cartesian (``'xy'``) indexing of output.

            - Matrix (``'ij'``, default and recommended) indexing keeps the original axis not changed.
            - To use other monai detection components, please set ``indexing = 'ij'``.
            - Cartesian (``'xy'``) indexing swaps axis 0 and 1.
            - For 2D cases, monai ``AnchorGenerator(sizes, aspect_ratios, indexing='xy')`` and
              ``torchvision.models.detection.anchor_utils.AnchorGenerator(sizes, aspect_ratios)`` are equivalent.


    Reference:.
        https://github.com/pytorch/vision/blob/release/0.12/torchvision/models/detection/anchor_utils.py

    Example:
        .. code-block:: python

            # 2D example inputs for a 2-level feature maps
            sizes = ((10,12,14,16), (20,24,28,32))
            base_aspect_ratios = (1., 0.5,  2.)
            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)
            anchor_generator = AnchorGenerator(sizes, aspect_ratios)

            # 3D example inputs for a 2-level feature maps
            sizes = ((10,12,14,16), (20,24,28,32))
            base_aspect_ratios = ((1., 1.), (1., 0.5), (0.5, 1.), (2., 2.))
            aspect_ratios = (base_aspect_ratios, base_aspect_ratios)
            anchor_generator = AnchorGenerator(sizes, aspect_ratios)
    cell_anchors)      (   ))      ?   )r   r   ijsizesSequence[Sequence[int]]aspect_ratiosr   indexingstrreturnNonec                   s   t    t|d stdd |D  _nt| _t|d s)|ft j }t jt|kr6tdtt|d d d }t|ddg}| _	t|dd	g _
| _ fd
dt j|D  _d S )Nr   c                 s  s    | ]}|fV  qd S N .0sr   r   i/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/detection/utils/anchor_utils.py	<genexpr>x   s    z+AnchorGenerator.__init__.<locals>.<genexpr>zllen(sizes) and len(aspect_ratios) should be equal.                 It represents the number of feature maps.r         r   xyc                   s   g | ]
\}}  ||qS r   )generate_anchors)r   sizeaspect_ratioselfr   r    
<listcomp>   s    z,AnchorGenerator.__init__.<locals>.<listcomp>)super__init__r   tupler   r   len
ValueErrorr	   spatial_dimsr   r   zipr   )r)   r   r   r   r0   	__class__r(   r    r,   o   s$   



zAnchorGenerator.__init__Nscalesdtypetorch.dtypedevicetorch.device | Nonetorch.Tensorc                 C  s  t j|||d}t j|||d}| jdkr1t|jdkr1td| j dt|jd  d|j d| jdkrS|jd | jd krStd| j d	| jd  d
|j d| jdkrdt |}d| }|}	n,t |dddf |dddf  d}d| }|dddf | }	|dddf | }
|dddf |dddf  d}|	dddf |dddf  d}| jdkrt j	| | ||gddd }|
 S |
dddf |dddf  d}t j	| | | |||gddd }|
 S )a  
        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.

        Args:
            scales: a sequence which represents several anchor sizes for the current feature map.
            aspect_ratios: a sequence which represents several aspect_ratios for the current feature map.
                For 2D images, it is a Sequence of float aspect_ratios[j],
                anchor width and height w:h = 1:aspect_ratios[j].
                For 3D images, it is a Sequence of 2 value Sequence aspect_ratios[j,0] and aspect_ratios[j,1],
                anchor width, height, and depth w:h:d = 1:aspect_ratios[j,0]:aspect_ratios[j,1]
            dtype: target data type of the output Tensor.
            device: target device to put the output Tensor data.

            Returns:
                For each s in scales, returns [s, s*aspect_ratios[j]] for 2D images,
                and [s, s*aspect_ratios[j,0],s*aspect_ratios[j,1]] for 3D images.
        r5   r7   r#   r"   zIn zA-D image, aspect_ratios for each level should be                 r   z%-D. But got aspect_ratios with shape .zK-D image, aspect_ratios for each level should has                 shape (_,z$). But got aspect_ratios with shape Nr   gUUUUUU?dim       @)torch	as_tensorr0   r.   shaper/   sqrtpowviewstackround)r)   r4   r   r5   r7   Zscales_tZaspect_ratios_tZ
area_scalew_ratiosh_ratiosZd_ratioswshsbase_anchorsdsr   r   r    r%      sD   



(&&
&$z AnchorGenerator.generate_anchorstorch.devicec                   s    fdd| j D | _ dS )z`
        Convert each element in self.cell_anchors to ``dtype`` and send to ``device``.
        c                   s   g | ]	}|j  d qS r:   )to)r   cell_anchorr7   r5   r   r    r*          z4AnchorGenerator.set_cell_anchors.<locals>.<listcomp>Nr   )r)   r5   r7   r   rR   r    set_cell_anchors   s   z AnchorGenerator.set_cell_anchorsc                 C  s   dd | j D S )zF
        Return number of anchor shapes for each feature map.
        c                 S  s   g | ]}|j d  qS )r   )rB   )r   cr   r   r    r*      s    z<AnchorGenerator.num_anchors_per_location.<locals>.<listcomp>rT   r(   r   r   r    num_anchors_per_location   s   z(AnchorGenerator.num_anchors_per_location
grid_sizeslist[list[int]]strideslist[list[Tensor]]list[Tensor]c           	   
     s4  g }| j }|du rtt|t|  krt|ks#td tdt|||D ]n\}|j  fddt| jD }tt	j
|d| j dd}t| jD ]}|| d||< qR| jdkrp|d	 |d
 |d
< |d	< t	j|d d
d}||dd
| jd |d
d| jd  d| jd  q)|S )ai  
        Every combination of (a, (g, s), i) in (self.cell_anchors, zip(grid_sizes, strides), 0:spatial_dims)
        corresponds to a feature map.
        It outputs g[i] anchors that are s[i] distance apart in direction i, with the same dimensions as a.

        Args:
            grid_sizes: spatial size of the feature maps
            strides: strides of the feature maps regarding to the original image

        Example:
            .. code-block:: python

                grid_sizes = [[100,100],[50,50]]
                strides = [[torch.tensor(2),torch.tensor(2)], [torch.tensor(4),torch.tensor(4)]]
        NzAnchors should be Tuple[Tuple[int]] because each feature map could potentially have different sizes and aspect ratios. There needs to be a match between the number of feature maps passed and the number of sizes / aspect ratios specified.c                   s,   g | ]}t jd | t j d|  qS )r   r:   )r@   arangeint32r   axisr7   r&   strider   r    r*      s    z0AnchorGenerator.grid_anchors.<locals>.<listcomp>r   )r   r<   r$   r   r   r"   r=   )r   AssertionErrorr.   r/   r1   r7   ranger0   listr@   meshgridreshaper   rF   appendrE   )	r)   rX   rZ   anchorsr   rL   Zshifts_centersr`   shiftsr   ra   r    grid_anchors   s:   
(
zAnchorGenerator.grid_anchorsimagesr   feature_mapsc           	        s   fdd|D }|j j d |j d }|d j|d j}  fdd|D }|  ||}tt|}|g| S )aF  
        Generate anchor boxes for each image.

        Args:
            images: sized (B, C, W, H) or (B, C, W, H, D)
            feature_maps: for FPN level i, feature_maps[i] is sized (B, C_i, W_i, H_i) or (B, C_i, W_i, H_i, D_i).
                This input argument does not have to be the actual feature maps.
                Any list variable with the same (C_i, W_i, H_i) or (C_i, W_i, H_i, D_i) as feature maps works.

        Return:
            A list with length of B. Each element represents the anchors for this image.
            The B elements are identical.

        Example:
            .. code-block:: python

                images = torch.zeros((3,1,128,128,128))
                feature_maps = [torch.zeros((3,6,64,64,32)), torch.zeros((3,6,32,32,16))]
                anchor_generator(images, feature_maps)
        c                   s"   g | ]}t |j j d  qS r   )re   rB   r0   )r   feature_mapr(   r   r    r*   1  s   " z+AnchorGenerator.forward.<locals>.<listcomp>Nr   c                   s(   g | ]  fd dt jD qS )c                   s*   g | ]}t j| |  t j d qS rO   )r@   tensorint64r_   )r7   g
image_sizer   r    r*   6  s    z6AnchorGenerator.forward.<locals>.<listcomp>.<listcomp>)rd   r0   )r   r7   rr   r)   )rq   r    r*   5  s    )	rB   r0   r5   r7   rU   rk   r@   catre   )	r)   rl   rm   rX   	batchsizer5   rZ   anchors_over_all_feature_mapsanchors_per_imager   rs   r    forward  s   

zAnchorGenerator.forward)r   r   r   )r   r   r   r   r   r   r   r   )
r4   r   r   r   r5   r6   r7   r8   r   r9   )r5   r6   r7   rN   r   r   )rX   rY   rZ   r[   r   r\   )rl   r   rm   r\   r   r\   )__name__
__module____qualname____doc__r   r@   r   __annotations__r,   float32r%   rU   rW   rk   rx   __classcell__r   r   r2   r    r
   4   s    8)
<
@r
   c                   @  sF   e Zd ZdZdeej iZ			ddddZe	ej
dfdddZdS )AnchorGeneratorWithAnchorShapea  
    Module that generates anchors for a set of feature maps and
    image sizes, inherited from :py:class:`~monai.apps.detection.networks.utils.anchor_utils.AnchorGenerator`

    The module support computing anchors at multiple base anchor shapes
    per feature map.

    ``feature_map_scales`` should have the same number of elements with the number of feature maps.

    base_anchor_shapes can have an arbitrary number of elements.
    For 2D images, each element represents anchor width and height [w,h].
    For 2D images, each element represents anchor width, height, and depth [w,h,d].

    AnchorGenerator will output a set of ``len(base_anchor_shapes)`` anchors
    per spatial location for feature map ``i``.

    Args:
        feature_map_scales: scale of anchors for each feature map, i.e., each output level of
            the feature pyramid network (FPN). ``len(feature_map_scales)`` is the number of feature maps.
            ``scale[i]*base_anchor_shapes`` represents the anchor shapes for feature map ``i``.
        base_anchor_shapes: a sequence which represents several anchor shapes for one feature map.
            For N-D images, it is a Sequence of N value Sequence.
        indexing: choose from {'xy', 'ij'}, optional
            Cartesian ('xy') or matrix ('ij', default) indexing of output.
            Cartesian ('xy') indexing swaps axis 0 and 1, which is the setting inside torchvision.
            matrix ('ij', default) indexing keeps the original axis not changed.
            See also indexing in https://pytorch.org/docs/stable/generated/torch.meshgrid.html

    Example:
        .. code-block:: python

            # 2D example inputs for a 2-level feature maps
            feature_map_scales = (1, 2)
            base_anchor_shapes = ((10, 10), (6, 12), (12, 6))
            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)

            # 3D example inputs for a 2-level feature maps
            feature_map_scales = (1, 2)
            base_anchor_shapes = ((10, 10, 10), (12, 12, 8), (10, 10, 6), (16, 16, 10))
            anchor_generator = AnchorGeneratorWithAnchorShape(feature_map_scales, base_anchor_shapes)
    r   r   r"         )    r   r   )0   r   r   )r   r   r   )r   r   r   r   feature_map_scalesSequence[int] | Sequence[float]base_anchor_shapes3Sequence[Sequence[int]] | Sequence[Sequence[float]]r   r   r   r   c                   s`   t j t|d }t|ddg}|_t|ddg_t|  fdd|D _	d S )Nr   r"   r#   r   r$   c                   s   g | ]	} |  qS r   )generate_anchors_using_shaper   Zbase_anchor_shapes_tr)   r   r    r*     rS   z;AnchorGeneratorWithAnchorShape.__init__.<locals>.<listcomp>)
r   Moduler,   r.   r	   r0   r   r@   r   r   )r)   r   r   r   r0   r   r   r    r,   q  s   
z'AnchorGeneratorWithAnchorShape.__init__Nanchor_shapesr9   r5   r6   r7   r8   c                 C  s.   | d }t j| |gdd}| j||dS )a  
        Compute cell anchor shapes at multiple sizes and aspect ratios for the current feature map.

        Args:
            anchor_shapes: [w, h] or [w, h, d], sized (N, spatial_dims),
                represents N anchor shapes for the current feature map.
            dtype: target data type of the output Tensor.
            device: target device to put the output Tensor data.

        Returns:
            For 2D images, returns [-w/2, -h/2, w/2, h/2];
            For 3D images, returns [-w/2, -h/2, -d/2, w/2, h/2, d/2]
        r?   r   r=   r:   )r@   rt   rG   rP   )r   r5   r7   Zhalf_anchor_shapesrL   r   r   r    r     s   z;AnchorGeneratorWithAnchorShape.generate_anchors_using_shape)r   r   r   )r   r   r   r   r   r   r   r   )r   r9   r5   r6   r7   r8   r   r9   )ry   rz   r{   r|   r   r@   r   r}   r,   staticmethodr~   r   r   r   r   r    r   D  s    *r   )r|   
__future__r   typingr   r   r@   r   r   monai.utilsr   monai.utils.miscr   monai.utils.moduler	   r   r
   r   r   r   r   r    <module>   s   "  