o
    &ik                     @  s  d dl mZ d dlZd dlZd dlZd dlZd dlmZ d dlm	Z	 d dl
mZ d dlmZmZmZmZmZ d dlZd dlmZ d dlmZmZmZ d d	lmZ ed
\ZZerld dlmZ d dlm Z m!Z!m"Z"m#Z#m$Z$ ed\Z%Z&edd\Z'Z(ed\Z)Z(e* Z+dd Z,dd Z-dd Z.G dd de/Z0G dd dZ1dd Z2dd Z3d7d%d&Z4G d'd( d(Z5d)d* Z6			d8d9d5d6Z7dS ):    )annotationsN)OrderedDict)Path)
MethodType)AnyDictListTupleUnion)
get_logger)add_casts_around_normsconvert_to_onnxget_profile_shapes)optional_import
polygraphy)bytes_from_path)CreateConfigProfileengine_bytes_from_networkengine_from_bytesnetwork_from_onnx_pathZtensorrttorch_tensorrtz1.4.0zcuda.cudartc                   C  s<   t jtjt jtjt jtjt jtjt jtjt jtjt jtjiS N)	trtint32torchfloat32float16bfloat16int64int8bool r"   r"   ]/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/networks/trt_compiler.pytrt_to_torch_dtype_dict1   s   r$   c                 C  s|   i }| s|S | D ]3}|D ].}g }|| }t t|d D ]}|d | |d | kr/|| qt|dkr:|||< qq|S )z
    This method calculates dynamic_axes to use in onnx.export().
    Args:
       profiles: [[min,opt,max],...] list of profile dimensions
    r      )rangelenappend)profilesdynamic_axesprofilekeyaxesvalsir"   r"   r#   get_dynamic_axes=   s    
r0   c                 C  s6   | d }|dkrt d| t| dkr| d S dS )z[
    Error reporting method for CUDA calls.
    Args:
     cuda_ret: CUDA return code.
    r   zCUDA ERROR:    N)RuntimeErrorr'   )Zcuda_reterrr"   r"   r#   cuassertR   s   r4   c                   @  s   e Zd ZdZdS )
ShapeErrorzM
    Exception class to report errors from setting TRT plan input shapes
    N)__name__
__module____qualname____doc__r"   r"   r"   r#   r5   `   s    r5   c                   @  s4   e Zd ZdZdddZdd Zdd Zdd
dZdS )	TRTEnginezK
    An auxiliary class to implement running of TRT optimized engines

    Nc                 C  s  || _ |ptd| _| jd| j   tt| j | _t | _d| _	| j
 | _g | _g | _g | _d| _i | _t }t| jjD ]6}| j| }| j|tjjkrY| j| qA| j|tjjkrw| j| || j| }| j| qA| jd| j  d| j d| j  dS )z
        Loads serialized engine, creates execution context and activates it
        Args:
          plan_path: path to serialized TRT engine.
          logger: optional logger object
        monai.networks.trt_compilerzLoading TensorRT engine: Nr   zLoaded TensorRT engine: z
.
Inputs: z

Outputs: )	plan_pathr   loggerinfor   r   enginer   tensorscuda_graph_instanceZcreate_execution_contextcontextinput_namesoutput_namesdtypescur_profileinput_tabler$   r&   Znum_io_tensorsZget_tensor_moder   ZTensorIOModeINPUTr(   ZOUTPUTZget_tensor_dtype)selfr<   r=   
dtype_dictidxbindingdtyper"   r"   r#   __init__n   s2   
zTRTEngine.__init__c                 C  s~   | j }t| jD ]4\}}t||}|| jvs"t| j| j|kr<tj|| j	| |d
 }|| j|< |||  qdS )zx
        Allocates outputs to run TRT engine
        Args:
            device: GPU device to allocate memory on
        )rM   deviceN)rB   	enumeraterD   listZget_tensor_shaper@   shaper   emptyrE   
contiguousset_tensor_addressdata_ptr)rI   rO   ctxr/   rL   rR   tr"   r"   r#   allocate_buffers   s   
zTRTEngine.allocate_buffersc                   s   j }j j} fdd}	 z|  W n( ty7   jd |j }||kr+ |_ j| Y n ty>    w q  }t|dksLJ dS )z
        Sets input bindings for TRT engine according to feed_dict
        Args:
           feed_dict: a dictionary [str->Tensor]
           stream: CUDA stream to use
        c                    sT   j D ]$} j|  d }|d ur'| }|j} | |  | |  qd S r   )rC   getrG   rT   rR   Zset_input_shaperU   rV   )rL   rX   rR   rW   	feed_dictrI   r"   r#   try_set_inputs   s   
z,TRTEngine.set_inputs.<locals>.try_set_inputsTr1   r   N)	r?   rB   rF   r5   Znum_optimization_profilesZset_optimization_profile_async	ExceptionZinfer_shapesr'   )rI   r\   streameZlast_profiler]   Znext_profileleftr"   r[   r#   
set_inputs   s(   	zTRTEngine.set_inputsFc                 C  s   |rO| j durtt| j | tt| | jS | j|}|s&tdtt|tj	j
 | j| tt|}tt|d| _ | jd | jS | j|}tt| |sbtd| jS )z
        Runs TRT engine.
        Args:
            stream: CUDA stream to run on
            use_cuda_graph: use CUDA graph. Note: requires all inputs to be the same GPU memory between calls.
        NzERROR: inference failed.r   zCUDA Graph captured!)rA   r4   cudartZcudaGraphLaunchZcudaStreamSynchronizerB   Zexecute_async_v3
ValueErrorZcudaStreamBeginCaptureZcudaStreamCaptureModeZ cudaStreamCaptureModeThreadLocalZcudaStreamEndCaptureZcudaGraphInstantiater=   r>   r@   )rI   r_   use_cuda_graphZnoerrorgraphr"   r"   r#   infer   s*   
zTRTEngine.inferr   )F)r6   r7   r8   r9   rN   rY   rb   rg   r"   r"   r"   r#   r:   h   s    
 $r:   c                 C  s   t | tjr| S t|  S r   )
isinstancer   Tensortensorcuda)dr"   r"   r#   make_tensor   s   rm   c                 C  sp   i }| D ]1}|| }|d ur5t |tst |tr/tt|D ]}t|| || d| < qqt|||< q|S )N_)rh   rQ   tupler&   r'   rm   )rC   input_exampleunrolled_inputnamevalr/   r"   r"   r#   unroll_input   s   rt   retList[torch.Tensor]output_listsList[List[int]]return3Tuple[Union[torch.Tensor, List[torch.Tensor]], ...]c           
      C  s  t  }d}tt|D ]}|| }t|dkst|dksJ t|dks+|d dkr9g || | R }|d }q|d dkrUg || |||d   R }||d  }q|d dkrt  }t| }tt|d |dD ]M}|| }	t|	dkst|	dksJ t|	dks|	d dkr|d }g || | R }ql|	d dkr||	d  }g || |||	d   R }qltdg || || |ddd R } |S q|S )a)  
    Implements parsing of 'output_lists' arg of trt_compile().

    Args:
      ret: plain list of Tensors

      output_lists: list of output group sizes: to form some Lists/Tuples out of 'ret' List, this will be a list
                    of group dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
        Format: [[group_n] | [], ...]
          [] or group_n == 0 : next output from ret is a scalar
          group_n > 0  :       next output from ret is a list of group_n length
          group_n == -1:       next output is a dynamic list. This entry can be at any
                               position in output_lists, but can appear only once.
    Returns:
       Tuple of Union[torch.Tensor, List[torch.Tensor]], according to the grouping in output_lists

    r   r1   zTwo -1 lists in outputN)ro   r&   r'   rd   )
ru   rw   groupscurlglZ
rev_groupsZrcurrlZrglr"   r"   r#   parse_groups   s:   
 $r   c                   @  s^   e Zd ZdZ														dddZdd	 Zd
d Zdd Zdd Zdd Z	dS )TrtCompilerz
    This class implements:
      - TRT lazy persistent export
      - Running TRT with optional fallback to Torch
        (for TRT engines with limited profiles)
    fp16onnxNFc                 C  s  ddg}||vrt d| d| dg d}||vr&t d| d| d|| _|| _|| _|du| _|p7g | _|p<g | _|
pAg | _|| _|pIi | _	|	pNi | _
d| _|| _|| _d	| _|patd
| _t|j| _|du rv| jjdd }i | _| jjdurtt| jjD ]}| jj| d  }|durt|}|| j| jj| d  < q|| _|j| _|durtj| jrtj| j|k rt | j dS dS dS dS )a  
        Initialization method:
         Tries to load persistent serialized TRT engine
         Saves its arguments for lazy TRT build on first forward() call
        Args:
            model: Model to "wrap".
            plan_path : Path where to save persistent serialized TRT engine.
            precision: TRT builder precision o engine model. Should be 'fp32'|'tf32'|'fp16'|'bf16'.
            method: One of 'onnx'|'torch_trt'.
                    Default is 'onnx' (torch.onnx.export()->TRT). This is the most stable and efficient option.
                    'torch_trt' may not work for some nets. Also AMP must be turned off for it to work.
            input_names: Optional list of input names. If None, will be read from the function signature.
            output_names: Optional list of output names. Note: If not None, patched forward() will return a dictionary.
            output_lists: Optional list of output group sizes: when forward() returns Lists/Tuples, this will be a list
                          of their dimensions, like [[], [5], [-1]] for returning Tensor, list of 5 items and dynamic list.
            export_args: Optional args to pass to export method. See onnx.export() and Torch-TensorRT docs for details.
            build_args: Optional args to pass to TRT builder. See polygraphy.Config for details.
            input_profiles: Optional list of profiles for TRT builder and ONNX export.
                            Each profile is a map of the form : {"input id" : [min_shape, opt_shape, max_shape], ...}.
            dynamic_batchsize: A sequence with three elements to define the batch size range of the input for the model to be
                               converted. Should be a sequence like [MIN_BATCH, OPT_BATCH, MAX_BATCH].
            [note]: If neither input_profiles nor dynamic_batchsize specified, static shapes will be used to build TRT engine.
            use_cuda_graph: Use CUDA Graph for inference. Note: all inputs have to be the same GPU memory between calls!
            timestamp: Optional timestamp to rebuild TRT engine (e.g. if config file changes).
            fallback: Allow to fall back to Pytorch when TRT inference fails (e.g, shapes exceed max profile).
        r   	torch_trtz)trt_compile(): 'method' should be one of z, got: .)fp32tf32r   bf16z,trt_compile(): 'precision' should be one of NFr;   r1   )!rd   r<   	precisionmethodreturn_dictrD   rw   r)   dynamic_batchsizeexport_args
build_argsr?   re   fallbackdisabledr   r=   inspectgetfullargspecforwardargspecargsdefaultsr&   r'   rm   rC   old_forwardospathexistsgetmtimeremove)rI   modelr<   r   r   rC   rD   rw   r   r   Zinput_profilesr   re   	timestampr   Zforward_overrider=   Zmethod_valsZprecision_valsr/   rl   r"   r"   r#   rN   .  sH   .





(zTrtCompiler.__init__c                 C  s,   i }t |D ]\}}| j| }|||< q|S r   )rP   rC   )rI   rp   Z
trt_inputsr/   inp
input_namer"   r"   r#   _inputs_to_dict  s
   

zTrtCompiler._inputs_to_dictc              
   C  s   z:t | j| j| _i }| jjD ]}|dr"|| jvr"|dd }n|}|||< q|| j_| jd| jj  W dS  tyV } z| jd|  W Y d}~dS d}~ww )zO
        Loads TRT plan from disk and activates its execution context.
        __r%   NzEngine loaded, inputs:z$Exception while loading the engine:
)	r:   r<   r=   r?   rC   
startswithrG   r>   r^   )rI   rG   rr   	orig_namer`   r"   r"   r#   _load_engine  s   
 zTrtCompiler._load_enginec              
   C  sV  | j }|| t|dkr|| | | jdu r| js|j}| j|_z4|   | jdu rX|	 }t
  | || W d   n1 sHw   Y  |   | jdusXJ W n$ ty} } z| jrq| jd|  d| _n|W Y d}~nd}~ww | js| js| D ]}~qt
j  ||_zj| jdurtY t
j }	t
jj|	d}
| jt| j||
j | jj|	d |
t
j  | jj|
j| jd}| j st!|" }| j#rt$|| j#}n
t|dkr|d }|W  d   W S 1 sw   Y  W n$ ty" } z| jr| jd| d	 n|W Y d}~nd}~ww | j|i |S )
af  
        Main forward method:
         Builds TRT engine if not available yet.
         Tries to run TRT engine
         If exception thrown and self.callback==True: falls back to original Pytorch

        Args: Passing through whatever args wrapped module's forward() has
        Returns: Passing through wrapped module's forward() return value(s)

        r   NzFailed to build engine: T)rO   )re   r1   zException: z
Falling back to Pytorch ...)%r   updater'   r   r?   r   r   r   r   copyr   no_grad_build_and_saver^   r   r=   r>   
parametersrk   empty_cachelock_smcurrent_deviceStreamrb   rt   rC   cuda_streamrY   wait_streamcurrent_streamrg   re   r   rQ   valuesrw   r   )rI   r   argvkwargsr   new_forwardr   r`   paramrO   r_   ru   r"   r"   r#   r     sp   





"zTrtCompiler.forwardc           	      C  s   g }| j D ]"}t }| D ]\}}|j||d |d |d d q|| q| j }| jdk|d< | jdkr>d|d< n	| jd	krGd|d	< | j	d
| d| j
  t|tjjgd}t|tdd|i|dS )z[
        Builds TRT engine from ONNX file at onnx_path and saves to self.plan_path
        r   r1   r%   )minoptmaxr   r   r   Tr   zBuilding TensorRT engine for z: )flagsr)   )configNr"   )r)   r   itemsaddr(   r   r   r   r=   r>   r<   r   r   ZOnnxParserFlagZNATIVE_INSTANCENORMr   r   )	rI   	onnx_pathr)   r+   pidrs   r   networkr"   r"   r#   _onnx_to_trt  s   
 



zTrtCompiler._onnx_to_trtc              
     s`  j durdS j}d}t| jdkrRtjg}jdkr%|tj njdkr0|tj	 t
| }dd fdd|D }tj|d	f||d
|}nΈj  rtjdkrbtdt dkrltdi | D ]6\}}	 fdd}
t|	t
st|	trtt|	D ]}|
| d| |	|  qqrt|	tjr|
||	 qrg_tj_tjdkr|dji t Q}tj|}tt|d }j !d| dt
|"  ddj# dj d|   t$||f|t
|" j#d| j !d %|}W d   n	1 sw   Y  |r.t&j'd(| dS dS )z
        If TRT engine is not ready, exports model to ONNX,
        builds TRT engine and saves serialized TRT engine to the disk.
        Args:
             input_example: passed to onnx.export()
        Nr   r   r   c                 S  s    t | |\}}}tj|||dS )N)Z	min_shapeZ	opt_shapeZ	max_shape)r   r   Input)input_shaper   Zmin_input_shapeZopt_input_shapeZmax_input_shaper"   r"   r#   get_torch_trt_input  s   z8TrtCompiler._build_and_save.<locals>.get_torch_trt_inputc                   s   g | ]	} |j jqS r"   )rR   r   ).0r/   )r   rI   r"   r#   
<listcomp>  s    z/TrtCompiler._build_and_save.<locals>.<listcomp>r   )Z
arg_inputsenabled_precisionsr   zEERROR: Both dynamic_batchsize and input_profiles set for TrtCompiler!   z&dynamic_batchsize has to have len ==3 c                   sR   |j }t|dkr'|dd  } d g| d g| d g|g| < d S d S )Nr   r1   r%   )rR   r'   )r   rs   sh)dbsr+   r"   r#   add_profile)  s
   0z0TrtCompiler._build_and_save.<locals>.add_profilern   r*   z
model.onnxzExporting to z:
unrolled_inputs=
zoutput_names=z
input_names=z
export args: )filenamerC   rD   zExport to ONNX successful.wb))r?   r   r   r   r   r   r   r(   r   r   rQ   r   r   Zconvert_method_to_trt_enginer   r'   r)   rd   r   rh   ro   r&   ri   r0   r*   r   tempfileTemporaryDirectoryrt   rC   strr   r=   r>   keysrD   r   r   openr<   write)rI   r   rp   r   Zengine_bytesr   inputsZ	tt_inputsr   rs   r   r/   tmpdirrq   r   r"   )r   r   r+   rI   r#   r     s   






zTrtCompiler._build_and_save)r   r   NNNNNNNFNFNN)
r6   r7   r8   r9   rN   r   r   r   r   r   r"   r"   r"   r#   r   &  s,    
YFr   c                 O  s   | j | ||S )zk
    Patch function to replace original model's forward() with.
    Redirects to TrtCompiler.forward()
    )_trt_compilerr   )rI   r   r   r"   r"   r#   trt_forwardQ  s   r   r   torch.nn.Module	base_pathr   r   Dict[str, Any] | None	submoduleUnion[str, List[str]] | Noner=   
Any | Nonec                   s  dddddd}|  pi  | trttrttj rttj|r:t	tj
|}d v r6tt	 d |}| d<  fdd	}fd
d|durmt|trS|g}|D ]}| |\}	}
|t|	|
|d |  qU| S || | | S pytdd | S )a  
    Instruments model or submodule(s) with TrtCompiler and replaces its forward() with TRT hook.
    Note: TRT 10.3 is recommended for best performance. Some nets may even fail to work with TRT 8.x.
        NVIDIA Volta support (GPUs with compute capability 7.0) has been removed starting with TensorRT 10.5.
        Review the TensorRT Support Matrix for which GPUs are supported.
    Args:
      model: module to patch with TrtCompiler object.
      base_path: TRT plan(s) saved to f"{base_path}[.{submodule}].plan" path.
                 dirname(base_path) must exist, base_path does not have to.
                 If base_path does point to existing file (e.g. associated checkpoint),
                 that file becomes a dependency - its mtime is added to args["timestamp"].
      args: Optional dict : unpacked and passed to TrtCompiler() - see TrtCompiler above for details.
      submodule: Optional hierarchical id(s) of submodule to patch, e.g. ['image_decoder.decoder']
                  If None, TrtCompiler patch is applied to the whole model.
                  Otherwise, submodule (or list of) is being patched.
      logger: Optional logger for diagnostics.
    Returns:
      Always returns same model passed in as argument. This is for ease of use in configs.
    r   r      Zobey)Zbuilder_optimization_levelZprecision_constraints)r   r   r   r   c                   sF   t | ds!| j| _t| |d fdi }|| _tt| | _d S d S )Nr   z.planr=   )hasattrr   orig_forwardr   r   r   r   )r   r   wrapper)r   r=   r"   r#   wrap  s   
ztrt_compile.<locals>.wrapc                   sJ   | d}|dkr!|d | }t| |} ||d d  } | |S | |fS )Nr   r{   r1   )findgetattr)parentr   rK   parent_name)find_subr"   r#   r     s   


ztrt_compile.<locals>.find_subNr   r;   zSTensorRT and/or polygraphy packages are not available! trt_compile() has no effect.)r   trt_importedpolygraphy_importedr   rk   is_availabler   r   r   intr   r   rh   r   r   r   warning)r   r   r   r   r=   Zdefault_argsr   r   sr   subr"   )r   r   r=   r#   trt_compileY  s4   



r   )ru   rv   rw   rx   ry   rz   )NNN)r   r   r   r   r   r   r   r   r=   r   ry   r   )8
__future__r   r   r   r   	threadingcollectionsr   pathlibr   typesr   typingr   r   r   r	   r
   r   monai.apps.utilsr   monai.networks.utilsr   r   r   monai.utils.moduler   r   r   Zpolygraphy.backend.commonr   Zpolygraphy.backend.trtr   r   r   r   r   r   r   r   rn   rc   Lockr   r$   r0   r4   r^   r5   r:   rm   rt   r   r   r   r   r"   r"   r"   r#   <module>   sJ   z
2  -