o
    $ir                     @  sp  d Z ddlmZ ddlZddlZddlZddlZddlmZ ddl	m
Z
mZmZ ddlmZ ddlmZ ddlmZmZ ddlZddlZddlmZ dd	lmZ dd
lmZ ddlmZmZm Z  ddl!m"Z"m#Z# ddl$m%Z%m&Z& e#d\Z'Z(e#d\Z)Z(e#d\Z*Z(e#d\Z+Z,e#dd\Z-Z(g dZ.ee/dZ0da1dddZ2dd Z3ddd!d"Z4dd#d$Z5ej6d%fdd.d/Z7ddd7d8Z8				ddd>d?Z9		dddCdDZ:dEejj;j<fddJdKZ=ej;j>fdLdMZ?ddQdRZ@ddSdTZAeddWdXZBeddYdZZCdd]d^ZD	_			`	dddcddZEddhdiZF										j	0	`	`	kddddZG						j	0	ddddZHdddZI										0ddddZJdd ZKdd ZL	`	`ddddZM	`	`ddddZNe	`	`ddddZOddddZPG dd dejQZRejSejTfddZUejSejTfddZVG dd dejjQZWG dd dejjQZXdddZYdddZZdddZ[dddĄZ\dddƄZ]dS )zE
Utilities and types for defining networks, these depend on PyTorch.
    )annotationsN)OrderedDict)CallableMappingSequence)contextmanager)deepcopy)AnyIterable)
get_logger)PathLike)ensure_tuplesave_objset_determinism)look_up_optionoptional_import)convert_to_dst_typeconvert_to_tensoronnxzonnx.referenceonnxruntime
polygraphytorch_tensorrtz1.4.0)one_hotpredict_segmentationnormalize_transformto_norm_affineCastTempTypenormal_init	icnr_initpixelshufflepixelunshuffle	eval_mode
train_modeget_state_dictcopy_model_state
save_stateconvert_to_onnxconvert_to_torchscriptconvert_to_trtmeshgrid_ijmeshgrid_xyreplace_modulesreplace_modules_templook_up_named_moduleset_named_modulehas_nvfuser_instance_normget_profile_shapes)module_nameinput_shapeSequence[int]dynamic_batchsizeSequence[int] | Nonec                 C  sP   ddd}|r|| |d }|| |d }|| |d	 }n|  } }}|||fS )zb
    Given a sample input shape, calculate min/opt/max shapes according to dynamic_batchsize.
    r2   r3   	scale_numintc                 S  s   g | }||d< |S )Nr    )r2   r6   Zscale_shaper8   r8   V/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/utils.pyscale_batch_sizeQ   s   z,get_profile_shapes.<locals>.scale_batch_sizer         N)r2   r3   r6   r7   r8   )r2   r4   r:   min_input_shapeopt_input_shapemax_input_shaper8   r8   r9   r0   L   s   

r0   c                  C  sV   t durt S tddd\} a t sdS zddl}|d W t S  ty*   da Y t S w )zwhether the current environment has InstanceNorm3dNVFuser
    https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
    Nzapex.normalizationZInstanceNorm3dNVFusernameFr   Zinstance_norm_nvfuser_cuda)_has_nvfuserr   	importlibimport_moduleImportError)_rC   r8   r8   r9   r/   `   s   r/   FrA   strc                 C  s   t | dd | D d|d}|du rdS |dkr|S |dD ](}| r-|t| }q t |dd | D ddd}|du rC dS t||}q |S )	a  
    get the named module in `mod` by the attribute name,
    for example ``look_up_named_module(net, "features.3.1.attn")``

    Args:
        name: a string representing the module attribute.
        mod: a pytorch module to be searched (in ``mod.named_modules()``).
        print_all_options: whether to print all named modules when `name` is not found in `mod`. Defaults to False.

    Returns:
        the corresponding pytorch module's subcomponent such as ``net.features[3][1].attn``
    c                 S     h | ]}|d  qS r   r8   .0nr8   r8   r9   	<setcomp>       z'look_up_named_module.<locals>.<setcomp>N)defaultprint_all_options .c                 S  rH   rI   r8   )rK   itemr8   r8   r9   rM      rN   F)r   named_modulessplitisdigitr7   getattr)rA   modrP   name_strrL   r8   r8   r9   r-   t   s   r-   c                 C  sJ   | dd}t|dkr|nd|f\}}|s|S t|| }t||| | S )a  
    look up `name` in `mod` and replace the layer with `new_layer`, return the updated `mod`.

    Args:
        mod: a pytorch module to be updated.
        name: a string representing the target module attribute.
        new_layer: a new module replacing the corresponding layer at ``mod.name``.

    Returns:
        an updated ``mod``

    See also: :py:func:`monai.networks.utils.look_up_named_module`.
    rR   r;   r<   rQ   )rsplitlenr-   setattr)rX   rA   Z	new_layerZ	mods_attrsubmodsattr_modr8   r8   r9   r.      s   
r.   r;   labelstorch.Tensornum_classesr7   dtypetorch.dtypedimreturnc                 C  s   | j |d k rt| jdg|d t| j   }t| |} t| j}|| dkr-td|||< tj||| jd}|j	|| 
 dd} | S )a  
    For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th
    dimension has the "one-hot" format, i.e., it has a total length of `num_classes`,
    with a one and `num_class-1` zeros.
    Note that this will include the background label, thus a binary mask should be treated as having two classes.

    Args:
        labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
            converted into integers `labels.long()`.
        num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to
            `num_classes` from `1`.
        dtype: the data type of the output one_hot label.
        dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number.

    Example:

    For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
    when `num_classes=N` number of classes and `dim=1`.

    .. code-block:: python

        from monai.networks.utils import one_hot
        import torch

        a = torch.randint(0, 2, size=(1, 2, 2, 2))
        out = one_hot(a, num_classes=2, dim=0)
        print(out.shape)  # torch.Size([2, 2, 2, 2])

        a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))
        out = one_hot(a, num_classes=2, dim=1)
        print(out.shape)  # torch.Size([2, 2, 2, 2, 2])

    r;   z6labels should have a channel with length equal to one.)sizerc   device)re   indexvalue)ndimlistshaper[   torchreshapeAssertionErrorzerosrh   scatter_long)r`   rb   rc   re   rm   shor8   r8   r9   r      s   $"
r           logitsmutually_exclusivebool	thresholdfloatr	   c                 C  sB   |s| |k  S | jd dkrtd | |k  S | jdddS )a%  
    Given the logits from a network, computing the segmentation by thresholding all values above 0
    if multi-labels task, computing the `argmax` along the channel axis if multi-classes task,
    logits has shape `BCHW[D]`.

    Args:
        logits: raw data of model output.
        mutually_exclusive: if True, `logits` will be converted into a binary matrix using
            a combination of argmax, which is suitable for multi-classes task. Defaults to False.
        threshold: thresholding the prediction values if multi-labels task.
    r;   zTsingle channel prediction, `mutually_exclusive=True` ignored, use threshold instead.T)keepdim)r7   rm   warningswarnargmax)rw   rx   rz   r8   r8   r9   r      s   
r   rh   torch.device | str | Nonetorch.dtype | Nonealign_cornerszero_centeredc              
   C  s  t | tj|ddd} |   jtj|d}|rDd||dk< d|r#|n|d  }tt|tjdtj|df}|sCd|d	d
d
f< n/d||dk< d|rQ|d n| }tt|tjdtj|df}|ssd|  d |d	d
d
f< |	dj|d}d|_
|S )a  
    Compute an affine matrix according to the input shape.
    The transform normalizes the homogeneous image coordinates to the
    range of `[-1, 1]`.  Currently the following source coordinates are supported:

        - `align_corners=False`, `zero_centered=False`, normalizing from ``[-0.5, d-0.5]``.
        - `align_corners=True`, `zero_centered=False`, normalizing from ``[0, d-1]``.
        - `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``.
        - `align_corners=True`, `zero_centered=True`, normalizing from ``[-d/2, d/2]``.

    Args:
        shape: input spatial shape, a sequence of integers.
        device: device on which the returned affine will be allocated.
        dtype: data type of the returned affine
        align_corners: if True, consider -1 and 1 to refer to the centers of the
            corner pixels rather than the image corners.
            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
            Setting this flag and `align_corners` will jointly specify the normalization source range.
    TF)rh   wrap_sequence
track_meta)rc   rh   g       @      ?)r;   g      Nrv   r   rc   )r   rn   float64clonedetachtodiagcatones	unsqueezerequires_grad)rm   rh   rc   r   r   normr8   r8   r9   r      s"   ""r   affinesrc_sizedst_sizec                 C  s   t | tjstdt| j d|  dks!| jd | jd kr,tdt	| j d| jd d }|t
|ks?|t
|krQtd| dt
| d	t
| d
t|| j| j||}t|d| j||}||  ttj| | dd  S )a  
    Given ``affine`` defined for coordinates in the pixel space, compute the corresponding affine
    for the normalized coordinates.

    Args:
        affine: Nxdxd batched square matrix
        src_size: source image spatial shape
        dst_size: target image spatial shape
        align_corners: if True, consider -1 and 1 to refer to the centers of the
            corner pixels rather than the image corners.
            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
            See also: :py:func:`monai.networks.utils.normalize_transform`.

    Raises:
        TypeError: When ``affine`` is not a ``torch.Tensor``.
        ValueError: When ``affine`` is not Nxdxd.
        ValueError: When ``src_size`` or ``dst_size`` dimensions differ from ``affine``.

    z%affine must be a torch.Tensor but is rR      r;   r<   zaffine must be Nxdxd, got zaffine suggests zD, got src=zD, dst=zD.cpu)dstr   )
isinstancern   Tensor	TypeErrortype__name__
ndimensionrm   
ValueErrortupler[   r   rh   rc   r   nplinalginvnumpy)r   r   r   r   r   srZ	src_xformZ	dst_xformr8   r8   r9   r   !  s    $$r   g{Gz?stdnormal_func+Callable[[torch.Tensor, float, float], Any]Nonec                 C  s   | j j}t| dddur7|ddks|ddkr7|| jjd| t| dddur5tj| j	jd dS dS |ddkrQ|| jjd	| tj| j	jd
 dS dS )a  
    Initialize the weight and bias tensors of `m' and its submodules to values from a normal distribution with a
    stddev of `std'. Weight tensors of convolution and linear modules are initialized with a mean of 0, batch
    norm modules with a mean of 1. The callable `normal_func', used to assign values, should have the same arguments
    as its default normal_(). This can be used with `nn.Module.apply` to visit submodules of a network.
    weightNZConvr   Linearrv   biasZ	BatchNormr   r   )
	__class__r   rW   findr   datanninit	constant_r   )mr   r   cnamer8   r8   r9   r   I  s   	,r   c           	      C  s   | j j^}}}|t| }t|| }t||g| }||}|dd}|||d}|dd|}|||g| }|dd}| j j	
| dS )z
    ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , "Checkerboard artifact free
    sub-pixel convolution".
    r   r;   r   N)r   rm   r[   r7   rn   rq   	transposero   repeatr   copy_)	convZupsample_factorr   out_channelsin_channelsdimsscale_factorZoc2kernelr8   r8   r9   r   ^  s   r   xspatial_dimsr   c              
     s  ||} t |  }|dd \}} | }|| dkr.td| d  d| d| d	t|| }||g fd	d
|dd D  }	t tddd|  }
|
|d |
d|  }
ddg}t|D ]}||
|d|  qd| ||g g|  |dd  } | ||	} | S )a  
    Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using a nEfficient Sub-Pixel Convolutional Neural Network."

    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

    Args:
        x: Input tensor with shape BCHW[D]
        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
        scale_factor: factor to rescale the spatial dimensions by, must be >=1

    Returns:
        Reshuffled version of `x`.

    Raises:
        ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims)
    Nr<   r   zNumber of input channels (z:) must be evenly divisible by scale_factor ** dimensions (z**=z).c                   s   g | ]}|  qS r8   r8   rK   dfactorr8   r9   
<listcomp>  rN   z pixelshuffle.<locals>.<listcomp>r;   )rl   rg   r   r7   rangeextendro   permute)r   r   r   re   
input_size
batch_sizechannelsZscale_divisorZorg_channelsoutput_sizeindicespermute_indicesidxr8   r   r9   r   r  s0   
"$r   c                   s  ||} t |  }|dd \}} | }|| }t fdd|dd D r8td  d|dd  ||g fdd|dd D  }	||gt fd	d|dd D g  }
d
dgdd t|D  dd t|D  }| |
|} | |	} | S )az  
    Apply pixel unshuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.
    Inverse operation of pixelshuffle.

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using an Efficient Sub-Pixel Convolutional Neural Network."

    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

    Args:
        x: Input tensor with shape BCHW[D]
        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
        scale_factor: factor to reduce the spatial dimensions by, must be >=1

    Returns:
        Unshuffled version of `x` with shape (B, C*(r**d), H/r, W/r) for 2D
        or (B, C*(r**d), D/r, H/r, W/r) for 3D, where r is the scale_factor
        and d is spatial_dims.

    Raises:
        ValueError: When spatial dimensions are not divisible by scale_factor
    Nr<   c                 3  s    | ]	}|  d kV  qdS )r   Nr8   r   r   r8   r9   	<genexpr>  s    z!pixelunshuffle.<locals>.<genexpr>z3All spatial dimensions must be divisible by factor z. , spatial shape is: c                   s   g | ]}|  qS r8   r8   r   r   r8   r9   r     rN   z"pixelunshuffle.<locals>.<listcomp>c                   s   g | ]}|   gqS r8   r8   r   r   r8   r9   r         r   r;   c                 S  s   g | ]}d | d qS )r<   r   r8   rK   ir8   r8   r9   r     r   c                 S  s   g | ]}d | d  qS )r<   r8   r   r8   r8   r9   r     r   )rl   rg   anyr   sumr   ro   r   )r   r   r   re   r   r   r   Zscale_factor_multZnew_channelsr   Zreshaped_sizer   r8   r   r9   r      s   
"(,
r    nets	nn.Modulec               
   g  s    dd | D }z-t   dd | D V  W d   n1 s w   Y  W |D ]}t|dr3|  q(dS |D ]}t|drC|  q8w )a  
    Set network(s) to eval mode and then return to original state at the end.

    Args:
        nets: Input network(s)

    Examples

    .. code-block:: python

        t=torch.rand(1,1,16,16)
        p=torch.nn.Conv2d(1,1,3)
        print(p.training)  # True
        with eval_mode(p):
            print(p.training)  # False
            print(p(t).sum().backward())  # will correctly raise an exception as gradients are calculated
    c                 S  s    g | ]}t |d r|jr|qS traininghasattrr   rJ   r8   r8   r9   r          zeval_mode.<locals>.<listcomp>c                 S  "   g | ]}t |d r| n|qS )eval)r   r   rJ   r8   r8   r9   r        " Ntrain)rn   no_gradr   r   )r   r   rL   r8   r8   r9   r!     s    


r!   c               
   g  s    dd | D }z.t d dd | D V  W d   n1 s!w   Y  W |D ]}t|dr4|  q)dS |D ]}t|drD|  q9w )a  
    Set network(s) to train mode and then return to original state at the end.

    Args:
        nets: Input network(s)

    Examples

    .. code-block:: python

        t=torch.rand(1,1,16,16)
        p=torch.nn.Conv2d(1,1,3)
        p.eval()
        print(p.training)  # False
        with train_mode(p):
            print(p.training)  # True
            print(p(t).sum().backward())  # No exception
    c                 S  s    g | ]}t |d r|js|qS r   r   rJ   r8   r8   r9   r     r   ztrain_mode.<locals>.<listcomp>Tc                 S  r   )r   )r   r   rJ   r8   r8   r9   r     r   Nr   )rn   set_grad_enabledr   r   )r   Z	eval_listrL   r8   r8   r9   r"     s    

r"   objtorch.nn.Module | Mappingc                 C  s0   t | tjtjjfr| j} t| dr|  S | S )z
    Get the state dict of input object if has `state_dict`, otherwise, return object directly.
    For data parallel model, automatically convert it to regular model first.

    Args:
        obj: input object to check and get the state_dict.

    
state_dict)r   r   DataParallelparallelDistributedDataParallelmoduler   r   )r   r8   r8   r9   r#     s   	r#   rQ   Tr   srcc                   s  t |}tt | } fdd|D }	t|t }
}| D ]#\}}| | }||v rB||	vrB|| j|jkrB|||< || q|rG|ni D ]9}| ||  }||v r||	vr|| j|| jkrwtd|| j d|| j d || ||< || qI|dur| D ]"\}}|||}|dur|d |	vr|d ||d < ||d  qtt	|}tt	|

|}td	t| d
t| d |rt| tjjrt| tjtjjfr| j} | | |||fS )a  
    Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
    by the ones from `src` whenever their keys match. The method provides additional `dst_prefix` for
    the `dst` key when matching them. `mapping` can be a `{"src_key": "dst_key"}` dict, indicating
    `dst[dst_prefix + dst_key] = src[src_key]`.
    This function is mainly to return a model state dict
    for loading the `src` model state into the `dst` model, `src` and `dst` can have different dict keys, but
    their corresponding values normally have the same shape.

    Args:
        dst: a pytorch module or state dict to be updated.
        src: a pytorch module or state dict used to get the values used for the update.
        dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
            will be assigned to the value of `src[src_key]`.
        mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
            to be assigned to the value of `src[src_key]`.
        exclude_vars: a regular expression to match the `dst` variable names,
            so that their values are not overwritten by `src`.
        inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
            This option is only available when `dst` is a `torch.nn.Module`.
        filter_func: a filter function used to filter the weights to be loaded.
            See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".

    Examples:
        .. code-block:: python

            from monai.networks.nets import BasicUNet
            from monai.networks.utils import copy_model_state

            model_a = BasicUNet(in_channels=1, out_channels=4)
            model_b = BasicUNet(in_channels=1, out_channels=2)
            model_a_b, changed, unchanged = copy_model_state(
                model_a, model_b, exclude_vars="conv_0.conv_0", inplace=False)
            # dst model updated: 76 of 82 variables.
            model_a.load_state_dict(model_a_b)
            # <All keys matched successfully>

    Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys.

    c                   $   h | ]} rt  |r|qS r8   recompilesearchrK   Zs_keyexclude_varsr8   r9   rM   R     $ z#copy_model_state.<locals>.<setcomp>zParam. shape changed from z to rR   Nr   r;   z'dst' model updated:  of z variables.)r#   r   rl   itemsrm   appendr}   r~   sortedset
differenceloggerinfor[   r   rn   r   Moduler   r   r   r   load_state_dict)r   r   Z
dst_prefixmappingr   inplacefilter_funcsrc_dictZdst_dictZto_skipall_keysZupdated_keyssvalZdst_keykeyrj   Znew_pairZunchanged_keysr8   r   r9   r$     sB   1 
$

 

r$   torch.nn.Module | dictpathr   c                 K  sN   i }t | tr|  D ]
\}}t|||< qnt| }td||d| dS )a  
    Save the state dict of input source data with PyTorch `save`.
    It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.
    And automatically convert the data parallel module to regular module.
    For example::

        save_state(net, path)
        save_state(net.state_dict(), path)
        save_state({"net": net, "opt": opt}, path)
        net_dp = torch.nn.DataParallel(net)
        save_state(net_dp, path)

    Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html.

    Args:
        src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.
        path: target file path to save the input object.
        kwargs: other args for the `save_obj` except for the `obj` and `path`.
            default `func` is `torch.save()`, details of the args:
            https://pytorch.org/docs/stable/generated/torch.save.html.

    )r   r  Nr8   )r   dictr   r#   r   )r   r  kwargsZckptkvr8   r8   r9   r%   s  s   
r%   -C6?        modelinputsSequence[Any]input_namesSequence[str] | Noneoutput_namesopset_version
int | Nonedynamic_axesDMapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | Nonefilename
Any | Noneverifytorch.device | Noneuse_ortort_providerrtolatol	use_tracedo_constant_foldingconstant_size_thresholdc           !   
     s  |    t m i }|r"| }|}d|v r!|d r!|r!||d< d}n
tjj| fi |}t|s6t|tr:|f}nt|}d}|du rLt	
 }|j}n|}td|  tjj||f|||p`d|||d| t|}W d   n1 sxw   Y  |rtrddlm}m} |||d	}||| |rFt|trt| } du rttj rd
nd  fdd|D }|  } t  tdd t| | d}W d   n1 sw   Y  tdd dd |jjD }tt|dd |D }|	rtj |! |
r|
ndgd}|"d|}nt#$|}|"d|}tdd t||D ]\}} t|tj%rDtj&j'|( t)| |j*d||d q(|S )a  
    Utility to convert a model into ONNX model and optionally verify with ONNX or onnxruntime.
    See also: https://pytorch.org/docs/stable/onnx.html for how to convert a PyTorch model to ONNX.

    Args:
        model: source PyTorch model to save.
        inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.
        input_names: optional input names of the ONNX model.
        output_names: optional output names of the ONNX model.
        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed
        the latest opset version supported by PyTorch, for more details:
            https://github.com/onnx/onnx/blob/main/docs/Operators.md and
            https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py
        dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). If set to None,
            the exported model will have the shapes of all input and output tensors set to match given
            ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.
        filename: optional filename to save the ONNX model, if None, don't save the ONNX model.
        verify: whether to verify the ONNX model with ONNX or onnxruntime.
        device: target PyTorch device to verify the model, if None, use CUDA if available.
        use_ort: whether to use onnxruntime to verify the model.
        ort_provider": onnxruntime provider to use, default is ["CPUExecutionProvider"].
        rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
        atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
        use_trace: whether to use `torch.jit.trace` to export the torchscript model.
        do_constant_folding: passed to onnx.export(). If True, extra polygraphy folding pass is done.
        constant_size_threshold: passed to polygrapy conatant forling, default = 16M
        kwargs: if use_trace=True: additional arguments to pass to torch.onnx.export()
            else: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
            https://pytorch.org/docs/master/generated/torch.jit.script.html.

    dynamor  FNztorch_versioned_kwargs=)fr  r  r  r  r!  r   )fold_constants	save_onnx)Zsize_thresholdcudar   c                   &   g | ]}t |tjr| n|qS r8   r   rn   r   r   r   rh   r8   r9   r        & z#convert_to_onnx.<locals>.<listcomp>seedTc                 S  s   g | ]}|j qS r8   r@   r   r8   r8   r9   r         c                 S  s   g | ]}|   qS r8   )r   r   r   r8   r8   r9   r     r   ZCPUExecutionProvider)Z	providersr   r  r  )+r   rn   r   jitscript	is_tensorr   r  r   tempfileNamedTemporaryFilerA   printr   exportloadpolygraphy_importedZpolygraphy.backend.onnx.loaderr%  r&  rl   valuesrh   r'  is_availabler   r   r   graphinputzipr   ZInferenceSessionSerializeToStringrunonnxreferenceZReferenceEvaluatorr   testingassert_closer   r   rc   )!r  r  r  r  r  r  r  r  rh   r  r  r  r  r   r!  r"  r	  Ztorch_versioned_kwargsZmode_to_exportZonnx_inputs	temp_filer$  
onnx_modelr%  r&  	torch_outZmodel_input_names
input_dictZort_sessZonnx_outZsessr1r2r8   r*  r9   r&     s   2
	$







"r&   filename_or_objextra_filesdict | NoneSequence[Any] | Nonec	                   s  |    t 6 |r |du rtdtjj| fd|i|	}
n
tjj| fi |	}
|dur7tjj|
||d W d   n1 sAw   Y  |rЈ du rXttj	
 rUdnd |du r`td fdd	|D }|durstj|n|
}|    |  } t # td
d t| | }td
d t|| }tdd W d   n1 sw   Y  t||D ]\}}t|tjst|tjrtjj||||d q|
S )a  
    Utility to convert a model into TorchScript model and save to file,
    with optional input / output data verification.

    Args:
        model: source PyTorch model to save.
        filename_or_obj: if not None, specify a file-like object (has to implement write and flush)
            or a string containing a file path name to save the TorchScript model.
        extra_files: map from filename to contents which will be stored as part of the save model file.
            for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.
        verify: whether to verify the input and output of TorchScript model.
            if `filename_or_obj` is not None, load the saved TorchScript model and verify.
        inputs: input test data to verify model, should be a sequence of data, every item maps to a argument
            of `model()` function.
        device: target device to verify the model, if None, use CUDA if available.
        rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
        atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
        use_trace: whether to use `torch.jit.trace` to export the TorchScript model.
        kwargs: other arguments except `obj` for `torch.jit.script()` or `torch.jit.trace()` (if use_trace is True)
            to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html.

    Nz'Missing input data for tracing convert.example_inputs)r   r$  _extra_filesr'  r   $Missing input data for verification.c                   r(  r8   r)  r   r*  r8   r9   r   G  r+  z*convert_to_torchscript.<locals>.<listcomp>r   r,  r/  )r   rn   r   r   r0  tracer1  saverh   r'  r:  r7  r   r   r   r=  r   r   rA  rB  )r  rI  rJ  r  r  rh   r  r  r   r	  script_moduleZts_modelrE  Ztorchscript_outrG  rH  r8   r*  r9   r'     s@   "





r'   	min_shape	opt_shape	max_shape	precisionc                 C  sP  t dd\}}	|||f}
|sg n|}|sg n|}tj| ||jj}||}|dt|j	j
> }| }|rF|j|d g|
R   |||}||  }|sqd}t|jD ]}||| d 7 }q\td| | }|| |dkr||jj |||}t }|| tjj |! t"d	| ||d
}|S )af  
    This function takes an ONNX model as input, exports it to a TensorRT engine, wraps the TensorRT engine
    to a TensorRT engine-based TorchScript model and return the TorchScript model.

    Args:
        onnx_model: the source ONNX model to compile.
        min_shape: the minimum input shape of the converted TensorRT model.
        opt_shape: the optimization input shape of the model, on which the TensorRT optimizes.
        max_shape: the maximum input shape of the converted TensorRT model.
        device: the target GPU index to convert and verify the model.
        precision: the weight precision of the converted TensorRT engine-based TorchScript model.
            Should be 'fp32' or 'fp16'.
        input_names: optional input names of the ONNX model. Should be a sequence like
            `['input_0', 'input_1', ..., 'input_N']` where N equals to the number of the
            model inputs.
        output_names: optional output names of the ONNX model. Should be a sequence like
            `['output_0', 'output_1', ..., 'output_N']` where N equals to the number of
            the model outputs.

    tensorrtz8.5.3r;   r   rQ   
z.TensorRT cannot parse the ONNX model, due to:
fp16cuda:)rh   Zinput_binding_namesZoutput_binding_names)#r   rn   r'  
set_deviceLoggerWARNINGBuilderZcreate_networkr7   ZNetworkDefinitionCreationFlagZEXPLICIT_BATCHZcreate_optimization_profile	set_shapeZ
OnnxParserparser>  r   Z
num_errorsZ	get_errordesc	ExceptionZcreate_builder_configZadd_optimization_profileZset_flagZBuilderFlagFP16Zbuild_serialized_networkioBytesIOwriter   tsZembed_engine_in_new_modulegetvalueDevice)rD  rS  rT  rU  rh   rV  r  r  trtrF   input_shapesr   buildernetworkprofileparsersuccessZparser_error_messager   configZserialized_enginer$  	trt_modelr8   r8   r9   _onnx_trt_compileZ  s@   



rs  Zinput_0Zoutput_0{Gz?use_onnxbool | Noneonnx_input_namesonnx_output_namesc              
   K  s  t j s	td|std|std| d |dur-t|dkr-td| d |r1|nd	}t d
| }|dkrBt j	nt j
}t t||g}|  |} t||\}}}|r|	rjdd |	D ni }||
rwdd |
D ni  t| ||	|
||d}t|||||||	|
d}nYt| |||d}|  || t  < t jj|d% tj|||dg}tj|f||td
| dd|}W d   n1 sw   Y  W d   n1 sw   Y  |rT|du rtd|durt j|n|}t  # td	d t| | }td	d t|| }tdd W d   n	1 s,w   Y  t||D ]\}}t|t jsHt|t jrRt jj ||||d q6|S )a  
    Utility to export a model into a TensorRT engine-based TorchScript model with optional input / output data verification.

    There are two ways to export a model:
    1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
    2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
    TensorRT engine-based TorchScript.

    When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
    may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
    the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
    not supported by the ONNX if exported through `torch.jit.script`.

    Args:
        model: a source PyTorch model to convert.
        precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
        input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
            [N, C, H, W, D].
        dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
            converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After converted, the batchsize of model
            input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize that the
            TensorRT tries to fit. The `OPT_BATCH` should be the most frequently used input batchsize in the application,
            default to None.
        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
            a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True), default to False.
        filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a
            file path name to load the TensorRT engine based TorchScript model for verifying.
        verify: whether to verify the input and output of the TensorRT engine based TorchScript model.
        device: the target GPU index to convert and verify the model. If None, use #0 GPU.
        use_onnx: whether to use the ONNX-TensorRT way to export the TensorRT engine-based TorchScript model.
        onnx_input_names: optional input names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
            a sequence like `('input_0', 'input_1', ..., 'input_N')` where N equals to the number of the model inputs. If not
            given, will use `('input_0',)`, which supposes the model only has one input.
        onnx_output_names: optional output names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
            a sequence like `('output_0', 'output_1', ..., 'output_N')` where N equals to the number of the model outputs. If
            not given, will use `('output_0',)`, which supposes the model only has one output.
        rtol: the relative tolerance when comparing the outputs between the PyTorch model and TensorRT model.
        atol: the absolute tolerance when comparing the outputs between the PyTorch model and TensorRT model.
        kwargs: other arguments except `module`, `inputs`, `enabled_precisions` and `device` for `torch_tensorrt.compile()`
            to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
    zCannot find any GPU devices.z*Missing the input shape for model convert.z@There is no dynamic batch range. The converted model only takes z shape input.Nr   zAThe dynamic batch range sequence should have 3 elements, but got z
 elements.r   rZ  fp32c                 S     i | ]}|d diqS r   Z	batchsizer8   rK   r
  r8   r8   r9   
<dictcomp>      z"convert_to_trt.<locals>.<dictcomp>c                 S  r|  r}  r8   r~  r8   r8   r9   r    r  )r   r  )rS  rT  rU  rh   rV  r  r  )rh   r  r   r*  )rS  rT  rU  Ztorchscript)r  enabled_precisionsrh   irrO  r,  r/  )!rn   r'  r:  rb  r   r}   r~   r[   rh   float32halfrandr   r   r   r0   updater&   rs  r'   r   r   Inputr   ri  r0  r7  r   r=  r   r   rA  rB  )r  rV  r2   r4   r   rI  r  rh   rw  ry  rz  r  r  r	  target_deviceZconvert_precisionr  r=   r>   r?   r  Zir_modelrr  Zinput_placeholderrE  Ztrt_outrG  rH  r8   r8   r9   r(     s   
:




r(   c                  G  s2   t jjd urdt jjv rt j| ddiS t j|  S )Nindexingijrn   meshgrid__kwdefaults__tensorsr8   r8   r9   r)   /  s   
r)   c                  G  sL   t jjd urdt jjv rt j| ddiS t j| d | d g| dd  R  S )Nr  xyr;   r   r<   r  r  r8   r8   r9   r*   6  s   $r*   parenttorch.nn.Module
new_moduleout!list[tuple[str, torch.nn.Module]]strict_matchmatch_devicec                   s   |rt dd |  D }t|dkr||d  |d}|dkrL|d|  t|  } ||d d }g }t| ||| | fdd	|D 7 }dS |rbt| |}	t| || |||	fg7 }dS |  D ]\}
}||
v ryt| |
t	||d
d qfdS )zO
    Helper function for :py:class:`monai.networks.utils.replace_modules`.
    c                 S  s   h | ]}|j qS r8   r*  r   r8   r8   r9   rM   I  r.  z#_replace_modules.<locals>.<setcomp>r;   r   rR   r   Nc                   s&   g | ]}  d |d  |d fqS )rR   r   r;   r8   )rK   rparent_namer8   r9   r   V  r+  z$_replace_modules.<locals>.<listcomp>T)r  )
rl   
parametersr[   r   r   rW   _replace_modulesr\   rT   r   )r  rA   r  r  r  r  devicesr   _out
old_modulemod_namerF   r8   r  r9   r  =  s*   


r  c                 C  s   g }t | ||||| |S )a  
    Replace sub-module(s) in a parent module.

    The name of the module to be replace can be nested e.g.,
    `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "."
    in the module name), then this function will recursively call itself.

    Args:
        parent: module that contains the module to be replaced
        name: name of module to be replaced. Can include ".".
        new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will
            be deep copied if `strict_match == False` multiple instances are independent.
        strict_match: if `True`, module name must `== name`. If false then
            `name in named_modules()` will be used. `True` can be used to change just
            one module, whereas `False` can be used to replace all modules with similar
            name (e.g., `relu`).
        match_device: if `True`, the device of the new module will match the model. Requires all
            of `parent` to be on the same device.

    Returns:
        List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module.

    Raises:
        AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`.
    r  )r  rA   r  r  r  r  r8   r8   r9   r+   c  s    r+   c                 c  sl    g }z t | ||||| dV  W |D ]\}}t | ||g d|d qdS |D ]\}}t | ||g d|d q&w )z
    Temporarily replace sub-module(s) in a parent module (context manager).

    See :py:class:`monai.networks.utils.replace_modules`.
    NT)r  r  r  )r  rA   r  r  r  replacedr   r8   r8   r9   r,     s   r,   c           	        s  dur durt dt| }t }durEfdd|D }|  D ] \}}||v r5d|_|| q$|jsDd|_td| d q$ durw fd	d|D }|  D ] \}}||vrgd|_|| qV|jsvd|_td
| d qVt	t
| dt
| d dS )a  
    A utilty function to help freeze specific layers.

    Args:
        model: a source PyTorch model to freeze layer.
        freeze_vars: a regular expression to match the `model` variable names,
            so that their `requires_grad` will set to `False`.
        exclude_vars: a regular expression to match the `model` variable names,
            except for matched variable names, other `requires_grad` will set to `False`.

    Raises:
        ValueError: when freeze_vars and exclude_vars are both specified.

    NzEIncompatible values: freeze_vars and exclude_vars are both specified.c                   r   r8   r   r   )freeze_varsr8   r9   rM     r   z freeze_layers.<locals>.<setcomp>FTz!The freeze_vars does not include z0, but requires_grad is False, change it to True.c                   r   r8   r   r   r   r8   r9   rM     r   zThe exclude_vars includes r   z variables frozen.)r   r#   rl   named_parametersr   r   r}   r~   r   r   r[   )	r  r  r   r  Zfrozen_keysZ	to_freezerA   paramZ
to_excluder8   )r   r  r9   freeze_layers  s6   
"r  c                      (   e Zd ZdZ fddZdd Z  ZS )r   z}
    Cast the input tensor to a temporary type before applying the submodule, and then cast it back to the initial type.
    c                   s    t    || _|| _|| _d S N)super__init__initial_typetemporary_type	submodule)selfr  r  r  r   r8   r9   r    s   

zCastTempType.__init__c                 C  s@   |j }|| jkr|| j}| |}|| jkr|| j}|S r  )rc   r  r   r  r  )r  r   rc   r8   r8   r9   forward  s   


zCastTempType.forwardr   
__module____qualname____doc__r  r  __classcell__r8   r8   r  r9   r     s    r   c                 C  s   | j |kr| j|dS | S )zN
    Utility function to cast a single tensor from from_dtype to to_dtype
    r   )rc   r   )r   
from_dtypeto_dtyper8   r8   r9   cast_tensor  s   r  c                   sv   t | tjrt|  dS t | tr(i }|  D ]}t| |  d||< q|S t | tr9t fdd| D S dS )zU
    Utility function to cast all tensors in a tuple from from_dtype to to_dtype
    r  r  c                 3  s    | ]
}t | d V  qdS )r  N)cast_all)rK   yr  r8   r9   r     s    zcast_all.<locals>.<genexpr>N)r   rn   r   r  r  keysr  r   )r   r  r  new_dictr
  r8   r  r9   r    s   

r  c                      r  )CastToFloatzo
    Class used to add autocast protection for ONNX export
    for forward methods with single return vaue
    c                      t    || _d S r  r  r  rX   r  rX   r  r8   r9   r       

zCastToFloat.__init__c                 C  sT   |j }tjddd | j|tj|}W d    |S 1 s#w   Y  |S )Nr'  Fenabled)rc   rn   autocastrX   r  r   r  )r  r   rc   retr8   r8   r9   r    s   
zCastToFloat.forwardr  r8   r8   r  r9   r        r  c                      r  )CastToFloatAllzs
    Class used to add autocast protection for ONNX export
    for forward methods with multiple return values
    c                   r  r  r  r  r  r8   r9   r    r  zCastToFloatAll.__init__c                 G  s`   |d j }tjddd | jjt||tjd }W d    n1 s#w   Y  t|tj|dS )Nr   r'  Fr  r  )rc   rn   r  rX   r  r  r  )r  argsr  r  r8   r8   r9   r    s
   
zCastToFloatAll.forwardr  r8   r8   r  r9   r  	  r  r  base_ttype[nn.Module]dest_t'Callable[[nn.Module], nn.Module | None]c                   s   d fdd}|S )	z
    Generic function generator to replace base_t module with dest_t wrapper.
    Args:
        base_t : module type to replace
        dest_t : destination module type
    Returns:
        swap function to replace base_t module with dest_t
    rX   r   rf   nn.Module | Nonec                   s    | }|S r  r8   )rX   r  r  r8   r9   expansion_fn$  s   z!wrap_module.<locals>.expansion_fnNrX   r   rf   r  r8   r  r  r  r8   r  r9   wrap_module  s   
r  c                   s   d fdd}|S )	a7  
    Generic function generator to replace base_t module with dest_t.
    base_t and dest_t should have same atrributes. No weights are copied.
    Args:
        base_t : module type to replace
        dest_t : destination module type
    Returns:
        swap function to replace base_t module with dest_t
    rX   r   rf   r  c                   s2   t  sd S  j} fdd|D }| }|S )Nc                   s   g | ]}t  |d qS r  )rW   )rK   rA   rX   r8   r9   r   :  r   z8simple_replace.<locals>.expansion_fn.<locals>.<listcomp>)r   __constants__)rX   	constantsr  r  r  r  r  r9   r  6  s   
z$simple_replace.<locals>.expansion_fnNr  r8   r  r8   r  r9   simple_replace+  s   r  r   dict[str, nn.Module]c                 C  s^   |  D ](\}}|d}| }|dd D ]}|j| }|du r" n|}q||j|d < q| S )a  
    This function swaps nested modules as specified by "dot paths" in mod with a desired replacement. This allows
    for swapping nested modules through arbitrary levels if children

    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.

    rR   Nr   )r   rU   _modules)r  r   r  new_modZexpanded_pathZ
parent_modZsub_pathsubmodr8   r8   r9   _swap_modulesA  s   

r  
expansions2dict[str, Callable[[nn.Module], nn.Module | None]]c                 C  sb   i }|   D ]\}}t|j}||v r|| |}|r|||< qtdt| d t| | | S )a  
    Top-level function to replace modules in model, specified by class name with a desired replacement.
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
    Args:
        model : top level module
        expansions : replacement dictionary: module class name -> replacement function generator
    Returns:
        model, possibly modified in-place
    zSwapped z modules)rT   r   r   r5  r[   r  )r  r  r   rA   r   Zm_typeswappedr8   r8   r9   replace_modules_by_typeW  s   

r  c                 C  sX   t d ttjtttjtttjtttjtttjtttjtd}t	| | | S )a5  
    Top-level function to add cast wrappers around modules known to cause issues for FP16/autocast ONNX export
    NOTE: This occurs in place, if you want to preserve model then make sure to copy it first.
    Args:
        model : top level module
    Returns:
        model, possibly modified in-place
    zAdding casts around norms...)BatchNorm1dBatchNorm2dBatchNorm3d	LayerNormInstanceNorm1dInstanceNorm3d)
r5  r  r   r  r  r  r  r  r  r  )r  Zcast_replacementsr8   r8   r9   add_casts_around_normsq  s   	






r  )r2   r3   r4   r5   )F)rA   rG   )
r`   ra   rb   r7   rc   rd   re   r7   rf   ra   )Frv   )rw   ra   rx   ry   rz   r{   rf   r	   )NNFF)
rh   r   rc   r   r   ry   r   ry   rf   ra   )FF)r   ra   r   r3   r   r3   r   ry   r   ry   rf   ra   )r   r{   r   r   rf   r   )r   ra   r   r7   r   r7   rf   ra   )r   r   )r   r   )rQ   NNTN)r   r   r   r   )r   r  r  r   )NNNNNFNFNr  rv   TTr  ) r  r   r  r  r  r  r  r  r  r  r  r  r  r  r  ry   rh   r  r  ry   r  r  r  r{   r  r{   r   ry   r!  ry   r"  r7   )NNFNNr  rv   F)r  r   rI  r  rJ  rK  r  ry   r  rL  rh   r  r  r{   r  r{   r   ry   )rS  r3   rT  r3   rU  r3   rh   r7   rV  rG   r  r  r  r  )
NFNFNFrt  ru  rv  rv   )r  r   rV  rG   r2   r3   r4   r5   r   ry   rI  r  r  ry   rh   r  rw  rx  ry  r  rz  r  r  r{   r  r{   )TT)r  r  rA   rG   r  r  r  r  r  ry   r  ry   rf   r   )r  r  rA   rG   r  r  r  ry   r  ry   rf   r  )
r  r  rA   rG   r  r  r  ry   r  ry   )NN)r  r   )r  r  r  r  rf   r  )r  r   r   r  rf   r   )r  r   r  r  rf   r   )r  r   rf   r   )^r  
__future__r   rd  r   r3  r}   collectionsr   collections.abcr   r   r   
contextlibr   copyr   typingr	   r
   r   r   rn   torch.nnr   monai.apps.utilsr   monai.configr   monai.utils.miscr   r   r   monai.utils.moduler   r   monai.utils.type_conversionr   r   r   rF   r@  r   r   r8  r   __all__r   r   rB   r0   r/   r-   r.   r{   r   r   r   r   r   normal_r   kaiming_normal_r   r   r    r!   r"   r#   r$   r%   r&   r'   rs  r(   r)   r*   r  r+   r,   r  r   r   float16r  r  r  r  r  r  r  r  r  r  r8   r8   r8   r9   <module>   s   



52
)

-*"
#
U% 
FN *%,



