U
    Ph                     @  s^  d Z ddlmZ ddlZddlZddlZddlmZ ddlm	Z	m
Z
mZ ddlmZ ddlmZ ddlmZ ddlZddlZddlmZ dd	lmZ dd
lmZ ddlmZmZmZ ddlm Z m!Z!m"Z" ddl#m$Z$m%Z% e!d\Z&Z'e!d\Z(Z'e!d\Z)Z'dddddddddddddddd d!d"d#d$d%d&gZ*ee+d'Z,da-d(d& Z.d}d*d+d,d$Z/d*d+d-d%Z0ej1d.fd/d0d1d0d/d2d3dZ2d~d/d5d6d7d8d9dZ3dd:d;d5d5d/d<d=dZ4dd/d>d>d5d5d/d?d@dZ5dAejj6j7fd6dBdCdDdEdZ8ej6j9fdFdZ:d/d0d0d/dGdHdZ;edIdJdKdZ<edIdJdLdZ=dMdNdOdZ>ddMdMdRdSdZ?dTdUdVdWdZ@ddIdYdZdZd[d\d]d5d^d5dZd6d6d5d_d`dZAddId]dad5dbd^d6d6d5dc	dddZBd>d>d>d0d*dZdZdedfdgZCddId*d>dkd5d]d5d[dldZdZd6d6dmdndZDdod  ZEdpd! ZFddqd*dqdrd5d5dCdsdtduZGddqd*dqd5d5drdvdwd"ZHeddqd*dqd5d5dxdyd#ZIddIdzd{d|ZJdS )zE
Utilities and types for defining networks, these depend on PyTorch.
    )annotationsN)OrderedDict)CallableMappingSequence)contextmanager)deepcopy)Any)
get_logger)PathLike)ensure_tuplesave_objset_determinism)look_up_optionoptional_importpytorch_after)convert_to_dst_typeconvert_to_tensoronnxzonnx.referenceonnxruntimeone_hotpredict_segmentationnormalize_transformto_norm_affinenormal_init	icnr_initpixelshuffle	eval_mode
train_modeget_state_dictcopy_model_state
save_stateconvert_to_onnxconvert_to_torchscriptconvert_to_trtmeshgrid_ijmeshgrid_xyreplace_modulesreplace_modules_templook_up_named_moduleset_named_modulehas_nvfuser_instance_norm)module_namec                  C  sX   t dk	rt S tddd\} a t s$dS zddl}|d W n tk
rR   da Y nX t S )zwhether the current environment has InstanceNorm3dNVFuser
    https://github.com/NVIDIA/apex/blob/23.05-devel/apex/normalization/instance_norm.py#L15-L16
    Nzapex.normalizationZInstanceNorm3dNVFusernameFr   Zinstance_norm_nvfuser_cuda)_has_nvfuserr   	importlibimport_moduleImportError)_r0    r4   I/home/dell461/cl/sdc2/HISourceFinder-master-l/src/monai/networks/utils.pyr+   F   s    
Fstrr-   c                 C  s   t | dd | D d|d}|dkr*dS |dkr6|S |dD ]P}| rZ|t| }q@t |dd | D ddd}|dkr dS t||}q@|S )	a  
    get the named module in `mod` by the attribute name,
    for example ``look_up_named_module(net, "features.3.1.attn")``

    Args:
        name: a string representing the module attribute.
        mod: a pytorch module to be searched (in ``mod.named_modules()``).
        print_all_options: whether to print all named modules when `name` is not found in `mod`. Defaults to False.

    Returns:
        the corresponding pytorch module's subcomponent such as ``net.features[3][1].attn``
    c                 S  s   h | ]}|d  qS r   r4   .0nr4   r4   r5   	<setcomp>h   s     z'look_up_named_module.<locals>.<setcomp>N)defaultprint_all_options .c                 S  s   h | ]}|d  qS r7   r4   )r9   itemr4   r4   r5   r;   r   s     F)r   named_modulessplitisdigitintgetattr)r.   modr=   name_strr:   r4   r4   r5   r)   Z   s$       c                 C  sJ   | dd}t|dkr|nd|f\}}|s0|S t|| }t||| | S )a  
    look up `name` in `mod` and replace the layer with `new_layer`, return the updated `mod`.

    Args:
        mod: a pytorch module to be updated.
        name: a string representing the target module attribute.
        new_layer: a new module replacing the corresponding layer at ``mod.name``.

    Returns:
        an updated ``mod``

    See also: :py:func:`monai.networks.utils.look_up_named_module`.
    r?         r>   )rsplitlenr)   setattr)rF   r.   Z	new_layerZ	mods_attrsubmodsattr_modr4   r4   r5   r*   y   s    
rH   ztorch.TensorrD   ztorch.dtype)labelsnum_classesdtypedimreturnc                 C  s   | j |d k r<t| jdg|d t| j   }t| |} t| j}|| dkrZtd|||< tj||| jd}|j	|| 
 dd} | S )a  
    For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th
    dimension has the "one-hot" format, i.e., it has a total length of `num_classes`,
    with a one and `num_class-1` zeros.
    Note that this will include the background label, thus a binary mask should be treated as having two classes.

    Args:
        labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
            converted into integers `labels.long()`.
        num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to
            `num_classes` from `1`.
        dtype: the data type of the output one_hot label.
        dim: the dimension to be converted to `num_classes` channels from `1` channel, should be non-negative number.

    Example:

    For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
    when `num_classes=N` number of classes and `dim=1`.

    .. code-block:: python

        from monai.networks.utils import one_hot
        import torch

        a = torch.randint(0, 2, size=(1, 2, 2, 2))
        out = one_hot(a, num_classes=2, dim=0)
        print(out.shape)  # torch.Size([2, 2, 2, 2])

        a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))
        out = one_hot(a, num_classes=2, dim=1)
        print(out.shape)  # torch.Size([2, 2, 2, 2, 2])

    rH   z6labels should have a channel with length equal to one.)sizerR   device)rS   indexvalue)ndimlistshaperK   torchreshapeAssertionErrorzerosrV   scatter_long)rP   rQ   rR   rS   r[   shor4   r4   r5   r      s    $"
        boolfloatr	   )logitsmutually_exclusive	thresholdrT   c                 C  sB   |s| |k  S | jd dkr4td | |k  S | jdddS )a%  
    Given the logits from a network, computing the segmentation by thresholding all values above 0
    if multi-labels task, computing the `argmax` along the channel axis if multi-classes task,
    logits has shape `BCHW[D]`.

    Args:
        logits: raw data of model output.
        mutually_exclusive: if True, `logits` will be converted into a binary matrix using
            a combination of argmax, which is suitable for multi-classes task. Defaults to False.
        threshold: thresholding the prediction values if multi-labels task.
    rH   zTsingle channel prediction, `mutually_exclusive=True` ignored, use threshold instead.T)keepdim)rD   r[   warningswarnargmax)rg   rh   ri   r4   r4   r5   r      s    
ztorch.device | str | Noneztorch.dtype | None)rV   rR   align_cornerszero_centeredrT   c              
   C  s  t | tj|ddd} |   jtj|d}|rd||dk< d|rF|n|d  }tt|tjdtj|df}|sd|d	d
d
f< n^d||dk< d|r|d n| }tt|tjdtj|df}|sd|  d |d	d
d
f< |	dj|d}d|_
|S )a  
    Compute an affine matrix according to the input shape.
    The transform normalizes the homogeneous image coordinates to the
    range of `[-1, 1]`.  Currently the following source coordinates are supported:

        - `align_corners=False`, `zero_centered=False`, normalizing from ``[-0.5, d-0.5]``.
        - `align_corners=True`, `zero_centered=False`, normalizing from ``[0, d-1]``.
        - `align_corners=False`, `zero_centered=True`, normalizing from ``[-(d-1)/2, (d-1)/2]``.
        - `align_corners=True`, `zero_centered=True`, normalizing from ``[-d/2, d/2]``.

    Args:
        shape: input spatial shape, a sequence of integers.
        device: device on which the returned affine will be allocated.
        dtype: data type of the returned affine
        align_corners: if True, consider -1 and 1 to refer to the centers of the
            corner pixels rather than the image corners.
            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
            Setting this flag and `align_corners` will jointly specify the normalization source range.
    TF)rV   wrap_sequence
track_meta)rR   rV   g       @      ?)rH   g      Nrd   r   rR   )r   r\   float64clonedetachtodiagcatones	unsqueezerequires_grad)r[   rV   rR   rn   ro   normr4   r4   r5   r      s     ""Sequence[int])affinesrc_sizedst_sizern   ro   rT   c                 C  s   t | tjs"tdt| j d|  dksB| jd | jd krXtdt	| j d| jd d }|t
|ks~|t
|krtd| dt
| d	t
| d
t|| j| j||}t|d| j||}||  ttj| | dd  S )a  
    Given ``affine`` defined for coordinates in the pixel space, compute the corresponding affine
    for the normalized coordinates.

    Args:
        affine: Nxdxd batched square matrix
        src_size: source image spatial shape
        dst_size: target image spatial shape
        align_corners: if True, consider -1 and 1 to refer to the centers of the
            corner pixels rather than the image corners.
            See also: https://pytorch.org/docs/stable/nn.functional.html#torch.nn.functional.grid_sample
        zero_centered: whether the coordinates are normalized from a zero-centered range, default to `False`.
            See also: :py:func:`monai.networks.utils.normalize_transform`.

    Raises:
        TypeError: When ``affine`` is not a ``torch.Tensor``.
        ValueError: When ``affine`` is not Nxdxd.
        ValueError: When ``src_size`` or ``dst_size`` dimensions differ from ``affine``.

    z%affine must be a torch.Tensor but is r?      rH   rI   zaffine must be Nxdxd, got zaffine suggests zD, got src=zD, dst=zD.cpu)dstr   )
isinstancer\   Tensor	TypeErrortype__name__
ndimensionr[   
ValueErrortuplerK   r   rV   rR   r   nplinalginvnumpy)r   r   r   rn   ro   srZ	src_xformZ	dst_xformr4   r4   r5   r     s     $g{Gz?z+Callable[[torch.Tensor, float, float], Any]None)stdnormal_funcrT   c                 C  s   | j j}t| dddk	rh|ddks4|ddkrh|| jjd| t| dddk	rtj| j	jd n0|ddkr|| jjd	| tj| j	jd
 dS )a  
    Initialize the weight and bias tensors of `m' and its submodules to values from a normal distribution with a
    stddev of `std'. Weight tensors of convolution and linear modules are initialized with a mean of 0, batch
    norm modules with a mean of 1. The callable `normal_func', used to assign values, should have the same arguments
    as its default normal_(). This can be used with `nn.Module.apply` to visit submodules of a network.
    weightNConvrs   Linearrd   biasZ	BatchNormrr   r   )
	__class__r   rE   findr   datanninit	constant_r   )mr   r   cnamer4   r4   r5   r   /  s    	,c           	      C  s   | j j^}}}|t| }t|| }t||g| }||}|dd}|||d}|dd|}|||g| }|dd}| j j	
| dS )z
    ICNR initialization for 2D/3D kernels adapted from Aitken et al.,2017 , "Checkerboard artifact free
    sub-pixel convolution".
    r   rH   rs   N)r   r[   rK   rD   r\   r_   	transposer]   repeatr   copy_)	convZupsample_factorr   out_channelsin_channelsdimsscale_factorZoc2kernelr4   r4   r5   r   D  s    )xspatial_dimsr   rT   c              
     s  || } t |  }|dd \}} | }|| dkr\td| d  d| d| d	t|| }||g fd	d
|dd D  }	t tddd|  }
|
|d |
d|  }
ddg}t|D ]}||
|d|  q| ||g g|  |dd  } | ||	} | S )a  
    Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`.

    See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution
    Using a nEfficient Sub-Pixel Convolutional Neural Network."

    See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution".

    Args:
        x: Input tensor
        spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D
        scale_factor: factor to rescale the spatial dimensions by, must be >=1

    Returns:
        Reshuffled version of `x`.

    Raises:
        ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims)
    NrI   r   zNumber of input channels (z:) must be evenly divisible by scale_factor ** dimensions (z**=z).c                   s   g | ]}|  qS r4   r4   )r9   dfactorr4   r5   
<listcomp>x  s     z pixelshuffle.<locals>.<listcomp>rH   )rZ   rU   r   rD   rangeextendr]   permute)r   r   r   rS   
input_size
batch_sizechannelsZscale_divisorZorg_channelsoutput_sizeindicesZpermute_indicesidxr4   r   r5   r   X  s$    
"$z	nn.Module)netsc               
   g  sZ   dd | D }z(t  dd | D V  W 5 Q R X W 5 |D ]}t |dr<|  q<X dS )a  
    Set network(s) to eval mode and then return to original state at the end.

    Args:
        nets: Input network(s)

    Examples

    .. code-block:: python

        t=torch.rand(1,1,16,16)
        p=torch.nn.Conv2d(1,1,3)
        print(p.training)  # True
        with eval_mode(p):
            print(p.training)  # False
            print(p(t).sum().backward())  # will correctly raise an exception as gradients are calculated
    c                 S  s    g | ]}t |d r|jr|qS traininghasattrr   r8   r4   r4   r5   r     s     
  zeval_mode.<locals>.<listcomp>trainc                 S  s"   g | ]}t |d r| n|qS )eval)r   r   r8   r4   r4   r5   r     s     N)r   r   r\   no_grad)r   r   r:   r4   r4   r5   r     s    

c               
   g  s\   dd | D }z*td dd | D V  W 5 Q R X W 5 |D ]}t |dr>|  q>X dS )a  
    Set network(s) to train mode and then return to original state at the end.

    Args:
        nets: Input network(s)

    Examples

    .. code-block:: python

        t=torch.rand(1,1,16,16)
        p=torch.nn.Conv2d(1,1,3)
        p.eval()
        print(p.training)  # False
        with train_mode(p):
            print(p.training)  # True
            print(p(t).sum().backward())  # No exception
    c                 S  s    g | ]}t |d r|js|qS r   r   r8   r4   r4   r5   r     s     
  ztrain_mode.<locals>.<listcomp>r   Tc                 S  s"   g | ]}t |d r| n|qS )r   )r   r   r8   r4   r4   r5   r     s     N)r   r   r\   set_grad_enabled)r   Z	eval_listr:   r4   r4   r5   r     s    
ztorch.nn.Module | Mappingobjc                 C  s0   t | tjtjjfr| j} t| dr,|  S | S )z
    Get the state dict of input object if has `state_dict`, otherwise, return object directly.
    For data parallel model, automatically convert it to regular model first.

    Args:
        obj: input object to check and get the state_dict.

    
state_dict)r   r   DataParallelparallelDistributedDataParallelmoduler   r   r   r4   r4   r5   r     s    	r>   T)r   srcc                   s  t |}tt | } fdd|D }	t|t  }
}| D ]F\}}| | }||kr>||	kr>|| j|jkr>|||< || q>|r|ni D ]r}| ||  }||kr||	kr|| j|| jkrtd|| j d|| j d || ||< || q|dk	rd| D ]J\}}|||}|dk	r|d |	kr|d ||d < ||d  qtt	|}tt	|

|}td	t| d
t| d |rt| tjjrt| tjtjjfr| j} | | |||fS )a  
    Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
    by the ones from `src` whenever their keys match. The method provides additional `dst_prefix` for
    the `dst` key when matching them. `mapping` can be a `{"src_key": "dst_key"}` dict, indicating
    `dst[dst_prefix + dst_key] = src[src_key]`.
    This function is mainly to return a model state dict
    for loading the `src` model state into the `dst` model, `src` and `dst` can have different dict keys, but
    their corresponding values normally have the same shape.

    Args:
        dst: a pytorch module or state dict to be updated.
        src: a pytorch module or state dict used to get the values used for the update.
        dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
            will be assigned to the value of `src[src_key]`.
        mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
            to be assigned to the value of `src[src_key]`.
        exclude_vars: a regular expression to match the `dst` variable names,
            so that their values are not overwritten by `src`.
        inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
            This option is only available when `dst` is a `torch.nn.Module`.
        filter_func: a filter function used to filter the weights to be loaded.
            See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".

    Examples:
        .. code-block:: python

            from monai.networks.nets import BasicUNet
            from monai.networks.utils import copy_model_state

            model_a = BasicUNet(in_channels=1, out_channels=4)
            model_b = BasicUNet(in_channels=1, out_channels=2)
            model_a_b, changed, unchanged = copy_model_state(
                model_a, model_b, exclude_vars="conv_0.conv_0", inplace=False)
            # dst model updated: 76 of 82 variables.
            model_a.load_state_dict(model_a_b)
            # <All keys matched successfully>

    Returns: an OrderedDict of the updated `dst` state, the changed, and unchanged keys.

    c                   s$   h | ]} rt  |r|qS r4   recompilesearchr9   Zs_keyexclude_varsr4   r5   r;     s       z#copy_model_state.<locals>.<setcomp>zParam. shape changed from z to r?   Nr   rH   z'dst' model updated:  of z variables.)r   r   rZ   itemsr[   appendrk   rl   sortedset
differenceloggerinforK   r   r\   r   Moduler   r   r   r   load_state_dict)r   r   Z
dst_prefixmappingr   inplacefilter_funcsrc_dictZdst_dictZto_skipall_keysZupdated_keyssvalZdst_keykeyrX   Znew_pairZunchanged_keysr4   r   r5   r      s<    1 $

 
ztorch.nn.Module | dictr   )r   pathc                 K  sN   i }t | tr.|  D ]\}}t|||< qnt| }tf ||d| dS )a  
    Save the state dict of input source data with PyTorch `save`.
    It can save `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.
    And automatically convert the data parallel module to regular module.
    For example::

        save_state(net, path)
        save_state(net.state_dict(), path)
        save_state({"net": net, "opt": opt}, path)
        net_dp = torch.nn.DataParallel(net)
        save_state(net_dp, path)

    Refer to: https://pytorch.org/ignite/v0.4.8/generated/ignite.handlers.DiskSaver.html.

    Args:
        src: input data to save, can be `nn.Module`, `state_dict`, a dictionary of `nn.Module` or `state_dict`.
        path: target file path to save the input object.
        kwargs: other args for the `save_obj` except for the `obj` and `path`.
            default `func` is `torch.save()`, details of the args:
            https://pytorch.org/docs/stable/generated/torch.save.html.

    )r   r   N)r   dictr   r   r   )r   r   kwargsZckptkvr4   r4   r5   r!   /  s    
-C6?zSequence[Any]zSequence[str] | Nonez
int | NonezDMapping[str, Mapping[int, str]] | Mapping[str, Sequence[int]] | Nonez
Any | Noneztorch.device | None)modelinputsinput_namesoutput_namesopset_versiondynamic_axesfilenameverifyrV   use_ortort_providerrtolatol	use_tracec              	     s4  |    t  i }|r | }n<tddsLd|kr:td|d |d< |d= tjj| f|}|dkrt }tj	j
|t|f|||||d| t	| }n0tj	j
|t|f|||||d| t	|}W 5 Q R X |r0 dkrttj rdnd  fd	d
|D }|  } t  tdd t| | d}W 5 Q R X tdd dd
 |jjD }tt|dd
 |D }|	rtj| |
r|
ndgd}|d|}nt|}|d|}tdd t||D ]R\}}t|tj rtddrtj!j"ntj!j#}||$ t%||j&d||d q|S )a  
    Utility to convert a model into ONNX model and optionally verify with ONNX or onnxruntime.
    See also: https://pytorch.org/docs/stable/onnx.html for how to convert a PyTorch model to ONNX.

    Args:
        model: source PyTorch model to save.
        inputs: input sample data used by pytorch.onnx.export. It is also used in ONNX model verification.
        input_names: optional input names of the ONNX model.
        output_names: optional output names of the ONNX model.
        opset_version: version of the (ai.onnx) opset to target. Must be >= 7 and not exceed
        the latest opset version supported by PyTorch, for more details:
            https://github.com/onnx/onnx/blob/main/docs/Operators.md and
            https://github.com/pytorch/pytorch/blob/master/torch/onnx/_constants.py
        dynamic_axes: specifies axes of tensors as dynamic (i.e. known only at run-time). If set to None,
            the exported model will have the shapes of all input and output tensors set to match given
            ones, for more details: https://pytorch.org/docs/stable/onnx.html#torch.onnx.export.
        filename: optional filename to save the ONNX model, if None, don't save the ONNX model.
        verify: whether to verify the ONNX model with ONNX or onnxruntime.
        device: target PyTorch device to verify the model, if None, use CUDA if available.
        use_ort: whether to use onnxruntime to verify the model.
        ort_provider": onnxruntime provider to use, default is ["CPUExecutionProvider"].
        rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
        atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
        use_trace: whether to use `torch.jit.trace` to export the torchscript model.
        kwargs: other arguments except `obj` for `torch.jit.script()` to convert model, for more details:
            https://pytorch.org/docs/master/generated/torch.jit.script.html.

    rH   
   example_outputszexample_outputs is required in scripting mode before PyTorch 1.10.Please provide example outputs or use trace mode to export onnx model.N)fr   r   r   r   cudar   c                   s&   g | ]}t |tjr| n|qS r4   r   r\   r   rx   r9   irV   r4   r5   r     s     z#convert_to_onnx.<locals>.<listcomp>r   seedTc                 S  s   g | ]
}|j qS r4   r-   r   r4   r4   r5   r     s     c                 S  s   g | ]}|   qS r4   )r   r   r   r4   r4   r5   r     s     ZCPUExecutionProvider)Z	providers   rt   r   r   )'r   r\   r   r   r   jitscriptioBytesIOr   exportr   Zload_model_from_stringgetvalueloadrV   r   is_availablerx   r   r   graphinputr   zipr   ZInferenceSessionSerializeToStringrunonnxreferenceZReferenceEvaluatorr   r   testingassert_closeassert_allcloser   r   rR   )r   r   r   r   r   r   r   r   rV   r   r   r   r   r   r   Ztorch_versioned_kwargsZmode_to_exportr   
onnx_model	torch_outZmodel_input_namesZ
input_dictZort_sessZonnx_outZsessr1r2	assert_fnr4   r   r5   r"   Q  s    -








 

"zdict | NonezSequence[Any] | None)	r   filename_or_objextra_filesr   r   rV   r   r   r   c	              	     s  |    t ^ |r@|dkr&tdtjj| fd|i|	}
ntjj| f|	}
|dk	rjtjj|
||d W 5 Q R X |r dkrttj	
 rdnd |dkrtd fdd	|D }|dk	rtj|n|
}|    |  } t < td
d t| | }td
d t|| }tdd W 5 Q R X t||D ]R\}}t|tjsdt|tjr@tddrxtjjntjj}|||||d q@|
S )a  
    Utility to convert a model into TorchScript model and save to file,
    with optional input / output data verification.

    Args:
        model: source PyTorch model to save.
        filename_or_obj: if not None, specify a file-like object (has to implement write and flush)
            or a string containing a file path name to save the TorchScript model.
        extra_files: map from filename to contents which will be stored as part of the save model file.
            for more details: https://pytorch.org/docs/stable/generated/torch.jit.save.html.
        verify: whether to verify the input and output of TorchScript model.
            if `filename_or_obj` is not None, load the saved TorchScript model and verify.
        inputs: input test data to verify model, should be a sequence of data, every item maps to a argument
            of `model()` function.
        device: target device to verify the model, if None, use CUDA if available.
        rtol: the relative tolerance when comparing the outputs of PyTorch model and TorchScript model.
        atol: the absolute tolerance when comparing the outputs of PyTorch model and TorchScript model.
        use_trace: whether to use `torch.jit.trace` to export the TorchScript model.
        kwargs: other arguments except `obj` for `torch.jit.script()` or `torch.jit.trace()` (if use_trace is True)
            to convert model, for more details: https://pytorch.org/docs/master/generated/torch.jit.script.html.

    Nz'Missing input data for tracing convert.example_inputs)r   r   _extra_filesr   r   $Missing input data for verification.c                   s&   g | ]}t |tjr| n|qS r4   r   r   r   r4   r5   r     s     z*convert_to_torchscript.<locals>.<listcomp>r   r   rH   r  r  )r   r\   r   r   r  tracer  saverV   r   r  r
  rx   r   r   r  r   r   r   r  r  r  )r   r  r  r   r   rV   r   r   r   r   script_moduleZts_modelr  Ztorchscript_outr  r  r  r4   r   r5   r#     s:    "




	min_shape	opt_shape	max_shaperV   	precisionr   r   c                 C  s\  t dd\}}	t dd\}
}	|||f}|s.g n|}|s:g n|}|
| ||jj}||}|dt|jj> }|	 }|r|j
|d f|  |||}||  }|sd}t|jD ]}||| d 7 }qtd	| | }|| |d
kr||jj |||}t }|| |
jj| t d| ||d}|S )af  
    This function takes an ONNX model as input, exports it to a TensorRT engine, wraps the TensorRT engine
    to a TensorRT engine-based TorchScript model and return the TorchScript model.

    Args:
        onnx_model: the source ONNX model to compile.
        min_shape: the minimum input shape of the converted TensorRT model.
        opt_shape: the optimization input shape of the model, on which the TensorRT optimizes.
        max_shape: the maximum input shape of the converted TensorRT model.
        device: the target GPU index to convert and verify the model.
        precision: the weight precision of the converted TensorRT engine-based TorchScript model.
            Should be 'fp32' or 'fp16'.
        input_names: optional input names of the ONNX model. Should be a sequence like
            `['input_0', 'input_1', ..., 'input_N']` where N equals to the number of the
            model inputs.
        output_names: optional output names of the ONNX model. Should be a sequence like
            `['output_0', 'output_1', ..., 'output_N']` where N equals to the number of
            the model outputs.

    Ztensorrtz8.5.3torch_tensorrt1.4.0rH   r   r>   
z.TensorRT cannot parse the ONNX model, due to:
Zfp16cuda:)rV   Zinput_binding_namesZoutput_binding_names)!r   
set_deviceLoggerWARNINGBuilderZcreate_networkrD   ZNetworkDefinitionCreationFlagZEXPLICIT_BATCHZcreate_optimization_profile	set_shapeZ
OnnxParserparser  r   Z
num_errorsZ	get_errordesc	ExceptionZcreate_builder_configZadd_optimization_profileZset_flagZBuilderFlagZFP16Zbuild_serialized_networkr  r  writetsZembed_engine_in_new_moduler	  r\   rV   )r  r#  r$  r%  rV   r&  r   r   Ztrtr3   r'  input_shapesr   buildernetworkprofileparsersuccessZparser_error_messager   configZserialized_enginer   	trt_modelr4   r4   r5   _onnx_trt_compile  sB    





r=  Zinput_0Zoutput_0{Gz?zSequence[int] | Nonezbool | None)r   r&  input_shapedynamic_batchsizer   r  r   rV   use_onnxonnx_input_namesonnx_output_namesr   r   c               
   K  s  t ddd\}}tj s"td|s.td|sDtd| d |dk	rjt|d	krjtd
| d |rr|nd}|rt	d| nt	d}|dkrtj
ntj}tt||g}ddddd}|r|||d }|||d }|||d }n| } }}|  |} t| |||d}|  |r|	rJdd |	D ni }||
rfdd |
D ni  t| ||	|
||d}t|||||||	|
d}nf|| t N tjj	|d4 |j|||dg}|j|f|||dd |}W 5 Q R X W 5 Q R X |r|dkrtd!|dk	r.tj|n|}t < tdd" t| | }tdd" t|| }tdd" W 5 Q R X t||D ]R\}}t|tjst|tjrtdd#rtjjntjj }|||||d$ q|S )%a  
    Utility to export a model into a TensorRT engine-based TorchScript model with optional input / output data verification.

    There are two ways to export a model:
    1, Torch-TensorRT way: PyTorch module ---> TorchScript module ---> TensorRT engine-based TorchScript.
    2, ONNX-TensorRT way: PyTorch module ---> TorchScript module ---> ONNX model ---> TensorRT engine --->
    TensorRT engine-based TorchScript.

    When exporting through the first way, some models suffer from the slowdown problem, since Torch-TensorRT
    may only convert a little part of the PyTorch model to the TensorRT engine. However when exporting through
    the second way, some Python data structures like `dict` are not supported. And some TorchScript models are
    not supported by the ONNX if exported through `torch.jit.script`.

    Args:
        model: a source PyTorch model to convert.
        precision: the weight precision of the converted TensorRT engine based TorchScript models. Should be 'fp32' or 'fp16'.
        input_shape: the input shape that is used to convert the model. Should be a list like [N, C, H, W] or
            [N, C, H, W, D].
        dynamic_batchsize: a sequence with three elements to define the batch size range of the input for the model to be
            converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH]. After converted, the batchsize of model
            input should between `MIN_BATCH` and `MAX_BATCH` and the `OPT_BATCH` is the best performance batchsize that the
            TensorRT tries to fit. The `OPT_BATCH` should be the most frequently used input batchsize in the application,
            default to None.
        use_trace: whether using `torch.jit.trace` to convert the PyTorch model to a TorchScript model and then convert to
            a TensorRT engine based TorchScript model or an ONNX model (if `use_onnx` is True), default to False.
        filename_or_obj: if not None, specify a file-like object (has to implement write and flush) or a string containing a
            file path name to load the TensorRT engine based TorchScript model for verifying.
        verify: whether to verify the input and output of the TensorRT engine based TorchScript model.
        device: the target GPU index to convert and verify the model. If None, use #0 GPU.
        use_onnx: whether to use the ONNX-TensorRT way to export the TensorRT engine-based TorchScript model.
        onnx_input_names: optional input names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
            a sequence like `('input_0', 'input_1', ..., 'input_N')` where N equals to the number of the model inputs. If not
            given, will use `('input_0',)`, which supposes the model only has one input.
        onnx_output_names: optional output names of the ONNX model. This arg is only useful when `use_onnx` is True. Should be
            a sequence like `('output_0', 'output_1', ..., 'output_N')` where N equals to the number of the model outputs. If
            not given, will use `('output_0',)`, which supposes the model only has one output.
        rtol: the relative tolerance when comparing the outputs between the PyTorch model and TensorRT model.
        atol: the absolute tolerance when comparing the outputs between the PyTorch model and TensorRT model.
        kwargs: other arguments except `module`, `inputs`, `enabled_precisions` and `device` for `torch_tensorrt.compile()`
            to compile model, for more details: https://pytorch.org/TensorRT/py_api/torch_tensorrt.html#torch-tensorrt-py.
    r'  r(  )versionzCannot find any GPU devices.z*Missing the input shape for model convert.z@There is no dynamic batch range. The converted model only takes z shape input.Nr   zAThe dynamic batch range sequence should have 3 elements, but got z
 elements.r   r*  zcuda:0Zfp32r   rD   )rA  	scale_numc                 S  s   | }|d  |9  < |S )Nr   r4   )rA  rG  Zscale_shaper4   r4   r5   scale_batch_size  s    z(convert_to_trt.<locals>.scale_batch_sizerH   rI   )rV   r   r   c                 S  s   i | ]}|d diqS r   Z	batchsizer4   r9   r   r4   r4   r5   
<dictcomp>  s      z"convert_to_trt.<locals>.<dictcomp>c                 S  s   i | ]}|d diqS rI  r4   rJ  r4   r4   r5   rK    s      )r   r   r"  r   )r#  r$  r%  Ztorchscript)r   Zenabled_precisionsrV   irr  r   r  r  )!r   r\   r   r  r2  r   rk   rl   rK   rV   float32halfrandr   rx   r   r#   updater"   r=  r   Inputr   r  r
  r   r  r   r   r   r  r  r  ) r   r&  rA  rB  r   r  r   rV   rC  rD  rE  r   r   r   r'  r3   target_deviceZconvert_precisionr   rH  Zmin_input_shapeZopt_input_shapeZmax_input_shapeZir_modelr   r<  Zinput_placeholderr  Ztrt_outr  r  r  r4   r4   r5   r$   \  s    :
     

  




c                  G  s2   t jjd k	r(dt jjkr(t j| ddiS t j|  S )Nindexingijr\   meshgrid__kwdefaults__tensorsr4   r4   r5   r%     s    c                  G  sJ   t jjd k	r(dt jjkr(t j| ddiS t j| d | d f| dd   S )NrS  xyrH   r   rI   rU  rX  r4   r4   r5   r&     s    ztorch.nn.Modulez!list[tuple[str, torch.nn.Module]])parentr.   
new_moduleoutstrict_matchmatch_devicerT   c                   s   |r4t dd |  D }t|dkr4||d  |d}|dkr|d|  t|  } ||d d }g }t| ||| | fdd	|D 7 }nZ|rt| |}	t| || |||	fg7 }n0|  D ]&\}
}||
krt| |
t	||d
d qdS )zO
    Helper function for :py:class:`monai.networks.utils.replace_modules`.
    c                 S  s   h | ]
}|j qS r4   r   r   r4   r4   r5   r;     s     z#_replace_modules.<locals>.<setcomp>rH   r   r?   rs   Nc                   s&   g | ]}  d |d  |d fqS )r?   r   rH   r4   )r9   rparent_namer4   r5   r     s     z$_replace_modules.<locals>.<listcomp>T)r^  )
rZ   
parametersrK   rx   r   rE   _replace_modulesrL   rA   r   )r[  r.   r\  r]  r^  r_  devicesr   Z_outZ
old_modulemod_namer3   r4   ra  r5   rd    s&    


rd  )r[  r.   r\  r^  r_  rT   c                 C  s   g }t | ||||| |S )a  
    Replace sub-module(s) in a parent module.

    The name of the module to be replace can be nested e.g.,
    `features.denseblock1.denselayer1.layers.relu1`. If this is the case (there are "."
    in the module name), then this function will recursively call itself.

    Args:
        parent: module that contains the module to be replaced
        name: name of module to be replaced. Can include ".".
        new_module: `torch.nn.Module` to be placed at position `name` inside `parent`. This will
            be deep copied if `strict_match == False` multiple instances are independent.
        strict_match: if `True`, module name must `== name`. If false then
            `name in named_modules()` will be used. `True` can be used to change just
            one module, whereas `False` can be used to replace all modules with similar
            name (e.g., `relu`).
        match_device: if `True`, the device of the new module will match the model. Requires all
            of `parent` to be on the same device.

    Returns:
        List of tuples of replaced modules. Element 0 is module name, element 1 is the replaced module.

    Raises:
        AttributeError: if `strict_match` is `True` and `name` is not a named module in `parent`.
    rd  )r[  r.   r\  r^  r_  r]  r4   r4   r5   r'   +  s     )r[  r.   r\  r^  r_  c                 c  sJ   g }zt | ||||| dV  W 5 |D ]\}}t | ||g d|d q&X dS )z
    Temporarily replace sub-module(s) in a parent module (context manager).

    See :py:class:`monai.networks.utils.replace_modules`.
    T)r^  r_  Nrg  )r[  r.   r\  r^  r_  replacedr   r4   r4   r5   r(   P  s    
)r   c           	        s  dk	r dk	rt dt| }t }dk	rfdd|D }|  D ]@\}}||krjd|_|| qH|jsHd|_td| d qH dk	r fd	d|D }|  D ]@\}}||krd|_|| q|jsd|_td
| d qt	t
| dt
| d dS )a  
    A utilty function to help freeze specific layers.

    Args:
        model: a source PyTorch model to freeze layer.
        freeze_vars: a regular expression to match the `model` variable names,
            so that their `requires_grad` will set to `False`.
        exclude_vars: a regular expression to match the `model` variable names,
            except for matched variable names, other `requires_grad` will set to `False`.

    Raises:
        ValueError: when freeze_vars and exclude_vars are both specified.

    NzEIncompatible values: freeze_vars and exclude_vars are both specified.c                   s$   h | ]} rt  |r|qS r4   r   r   )freeze_varsr4   r5   r;   }  s       z freeze_layers.<locals>.<setcomp>FTz!The freeze_vars does not include z0, but requires_grad is False, change it to True.c                   s$   h | ]} rt  |r|qS r4   r   r   r   r4   r5   r;     s       zThe exclude_vars includes r   z variables frozen.)r   r   rZ   named_parametersr}   r   rk   rl   r   r   rK   )	r   ri  r   r   Zfrozen_keysZ	to_freezer.   paramZ
to_excluder4   )r   ri  r5   freeze_layersh  s2    
rl  )F)Frd   )NNFF)FF)r>   NNTN)NNNNNFNFNr   rd   T)NNFNNr   rd   F)
NFNFNFr>  r?  r@  rd   )TT)TT)TT)NN)K__doc__
__future__r   r  r   rk   collectionsr   collections.abcr   r   r   
contextlibr   copyr   typingr	   r   r   r\   torch.nnr   monai.apps.utilsr
   monai.configr   monai.utils.miscr   r   r   monai.utils.moduler   r   r   monai.utils.type_conversionr   r   r   r3   r  r   __all__r   r   r/   r+   r)   r*   rf   r   r   r   r   r   normal_r   kaiming_normal_r   r   r   r   r   r    r!   r"   r#   r=  r$   r%   r&   rd  r'   r(   rl  r4   r4   r4   r5   <module>   s   
5    2  ) -"#     U%            *{         GO          (   *  %  