o
    0i&                    @  sL  d dl mZ d dlZd dlZd dlmZmZ d dlmZm	Z	m
Z
mZmZ d dlmZ d dlmZ d dlmZ d dlZd dlmZ d dlm  mZ d dlmZ d d	lmZ d d
lmZ d dlm Z  d dl!m"Z"m#Z# d dl$m%Z% d dl&m'Z'm(Z( d dl)m*Z*m+Z+m,Z,m-Z-m.Z.m/Z/m0Z0 d dl1m2Z2m3Z3 d dl4m5Z5m6Z6 d dl7m8Z8m9Z9m:Z:m;Z;m<Z<m=Z= d dl>m?Z?m@Z@mAZA e=ddd\ZBZCeeDZEg dZFG dd deZGG dd deGZHG dd deGZIG dd deGZJG dd  d eJZKG d!d" d"eGZLG d#d$ d$eJZMG d%d& d&eGZNG d'd( d(eNZOG d)d* d*eNZPG d+d, d,ePZQG d-d. d.ejRZSdS )/    )annotationsN)ABCabstractmethod)CallableIterableIteratorMappingSequence)partial)locate)Any)
get_logger)decollate_batch)
MetaTensor)ThreadBuffer)	AvgMergerMerger)Splitter)compute_importance_mapsliding_window_inference)VQVAEAutoencoderKL
ControlNetDecoderOnlyTransformerDiffusionModelUNetSPADEAutoencoderKLSPADEDiffusionModelUNet)RFlowScheduler	Scheduler)CenterSpatialCrop
SpatialPad)	BlendModeOrdering	PatchKeysPytorchPadModeensure_tupleoptional_import)CAMGradCAM	GradCAMpptqdmname)InfererPatchInfererSimpleInfererSlidingWindowInfererSaliencyInfererSliceInfererSlidingWindowInfererAdaptc                   @  s   e Zd ZdZedd
dZdS )r-   a  
    A base class for model inference.
    Extend this class to support operations during inference, e.g. a sliding window method.

    Example code::

        device = torch.device("cuda:0")
        transform = Compose([ToTensor(), LoadImage(image_only=True)])
        data = transform(img_path).to(device)
        model = UNet(...).to(device)
        inferer = SlidingWindowInferer(...)

        model.eval()
        with torch.no_grad():
            pred = inferer(inputs=data, network=model)
        ...

    inputstorch.Tensornetworkr   argsr   kwargsreturnc                 O  s   t d| jj d)a  
        Run inference on `inputs` with the `network` model.

        Args:
            inputs: input of the model inference.
            network: model for inference.
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        Raises:
            NotImplementedError: When the subclass does not override this method.

        z	Subclass z must implement this method.)NotImplementedError	__class____name__selfr4   r6   r7   r8    r?   X/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/inferers/inferer.py__call__R   s   zInferer.__call__N)
r4   r5   r6   r   r7   r   r8   r   r9   r   )r<   
__module____qualname____doc__r   rA   r?   r?   r?   r@   r-   >   s    r-   c                   @  sl   e Zd ZdZdeddddddfd5ddZd6ddZd7d!d"Zd8d)d*Zd+d, Z	d-d. Z
d/d0 Zd9d3d4ZdS ):r.   a&  
    Inference on patches instead of the whole image based on Splitter and Merger.
    This splits the input image into patches and then merge the resulted patches.

    Args:
        splitter: a `Splitter` object that split the inputs into patches. Defaults to None.
            If not provided or None, the inputs are considered to be already split into patches.
            In this case, the output `merged_shape` and the optional `cropped_shape` cannot be inferred
            and should be explicitly provided.
        merger_cls: a `Merger` subclass that can be instantiated to merges patch outputs.
            It can also be a string that matches the name of a class inherited from `Merger` class.
            Defaults to `AvgMerger`.
        batch_size: batch size for patches. If the input tensor is already batched [BxCxWxH],
            this adds additional batching [(Bp*B)xCxWpxHp] for inference on patches.
            Defaults to 1.
        preprocessing: a callable that process patches before the being fed to the network.
            Defaults to None.
        postprocessing: a callable that process the output of the network.
            Defaults to None.
        output_keys: if the network output is a dictionary, this defines the keys of
            the output dictionary to be used for merging.
            Defaults to None, where all the keys are used.
        match_spatial_shape: whether to crop the output to match the input shape. Defaults to True.
        buffer_size: number of patches to be held in the buffer with a separate thread for batch sampling. Defaults to 0.
        merger_kwargs: arguments to be passed to `merger_cls` for instantiation.
            `merged_shape` is calculated automatically based on the input shape and
            the output patch shape unless it is passed here.
    N   Tr   splitterSplitter | None
merger_clstype[Merger] | str
batch_sizeintpreprocessingCallable | Nonepostprocessingoutput_keysSequence | Nonematch_spatial_shapeboolbuffer_sizemerger_kwargsr   r9   Nonec	                 K  s0  t |  t|ttd fst|tstdt| d|| _t|trAtd|d\}
}|s3t	|}
|
d u r?t
d| d|
}t|tsNtd| d|| _|	| _|d urft|sftdt| d|| _|d ur{t|s{td	t| d|| _|d
k rt
d| d|| _|| _|| _|| _d S )Nz'splitter' should be a `Splitter` object that returns: an iterable of pairs of (patch, location) or a MetaTensor that has `PatchKeys.LOCATION` metadata).z
 is given.zmonai.inferers.mergerr+   zThe requested `merger_cls` ['z'] does not exist.z+'merger' should be a subclass of `Merger`, z-'preprocessing' should be a callable object, z.'postprocessing' should be a callable object, rE   z(`batch_size` must be a positive number, )r-   __init__
isinstancer   type	TypeErrorrF   strr&   r   
ValueError
issubclassr   rH   rT   callablerL   rN   rJ   rO   rQ   rS   )r>   rF   rH   rJ   rL   rN   rO   rQ   rS   rT   Zvalid_merger_clsZmerger_foundr?   r?   r@   rV      s@   




zPatchInferer.__init__patches9Iterable[tuple[torch.Tensor, Sequence[int]]] | MetaTensor,Iterator[tuple[torch.Tensor, Sequence, int]]c           
      c  s0   t |tr4t|}td|| jD ] }t| j|| }||||  ||||  jtj |fV  qdS | j	dkrBt
|| j	dd}n|}dg| j }dg| j }d}|D ].}	|	d ||< |	d ||< |d7 }|| jkrt|||fV  dg| j }dg| j }d}qT|dkrt|d| ||fV  dS dS )zGenerate batch of patches and locations

        Args:
            patches: a tensor or list of tensors

        Yields:
            A batch of patches (torch.Tensor or MetaTensor), a sequence of location tuples, and the batch size
        r   g?)rS   timeoutNrE   )rW   r   lenrangerJ   minmetar#   LOCATIONrS   r   torchcat)
r>   r^   
total_sizeirJ   bufferZpatch_batchZlocation_batchidx_in_batchsampler?   r?   r@   _batch_sampler   s4   
.

zPatchInferer._batch_sampleroutputstuplec                   sF   t  tr| jd u rt  | _t fdd| jD S t ddS )Nc                 3  s    | ]} | V  qd S Nr?   ).0kro   r?   r@   	<genexpr>   s    z5PatchInferer._ensure_tuple_outputs.<locals>.<genexpr>T)
wrap_array)rW   dictrO   listkeysrp   r%   )r>   ro   r?   rt   r@   _ensure_tuple_outputs   s
   

z"PatchInferer._ensure_tuple_outputsr6   r   patchr5   r7   r8   c                 O  s@   | j r|  |}||g|R i |}| jr| |}| |S rq   )rL   rN   rz   )r>   r6   r{   r7   r8   ro   r?   r?   r@   _run_inference   s   


zPatchInferer._run_inferencec                 C  s   t ||d }g }g }|D ]Z}t ||d }	tdd t|jdd  |	jdd  D }
| j }| ||	|
\}}d|vrN||d< |d d u rNtdd|vrV||d< | j	di |}|
| |
|
 q||fS )	Nr   c                 s  s    | ]	\}}|| V  qd S rq   r?   )rr   ipopr?   r?   r@   ru     s    z3PatchInferer._initialize_mergers.<locals>.<genexpr>   merged_shapez `merged_shape` cannot be `None`.cropped_shaper?   )rg   chunkrp   zipshaperT   copy_get_merged_shapesr[   rH   append)r>   r4   ro   r^   rJ   in_patchmergersratiosZout_patch_batch	out_patchratiorT   r   r   mergerr?   r?   r@   _initialize_mergers   s$   ,

z PatchInferer._initialize_mergersc                 C  sX   t |||D ]#\}}}t |t||D ]\}	}
dd t |	|D }||
| qqd S )Nc                 S  s   g | ]
\}}t || qS r?   round)rr   lrr?   r?   r@   
<listcomp>  s    z+PatchInferer._aggregate.<locals>.<listcomp>)r   rg   r   	aggregate)r>   ro   	locationsrJ   r   r   Zoutput_patchesr   r   Zin_locr   Zout_locr?   r?   r@   
_aggregate  s   zPatchInferer._aggregatec           
      C  s   | j du rdS | j |}| j |}tdd t||D }tdd t||D }|jdd | }|jdd | }	| jsB|	}||	fS )z:Define the shape of merged tensors (non-padded and padded)N)NNc                 s       | ]\}}t || V  qd S rq   r   rr   sr   r?   r?   r@   ru   )      z2PatchInferer._get_merged_shapes.<locals>.<genexpr>c                 s  r   rq   r   r   r?   r?   r@   ru   *  r   r   )rF   Zget_input_shapeZget_padded_shaperp   r   r   rQ   )
r>   r4   r   r   original_spatial_shapeZpadded_spatial_shapeoutput_spatial_shapeZpadded_output_spatial_shaper   r   r?   r?   r@   r     s   
zPatchInferer._get_merged_shapesr4   NCallable[..., torch.Tensor | Sequence[torch.Tensor] | dict[Any, torch.Tensor]]c                 O  s~  | dd}|durpt|tjr)t|tjr)|j|jkr(td|j d|j nGt|trlt|trlt|t|krJtdt| dt| dt||D ]\\}}\}}|j|jkrjtd|j d|j dqOntd	| j	du rt|tjrt|t
rtj|jvrtd
n
tdt| d|}	|dur|}
n| 	|}	|dur| 	|}
g }g }|durt| |	| |
D ]0\\}}}\}}}||d< | j||g|R i |}|s| ||||\}}| ||||| qn.| |	D ](\}}}| j||g|R i |}|s| ||||\}}| ||||| qdd |D }| jr2tt| j|S t|dkr=|d S |S )a  
        Args:
            inputs: input data for inference, a torch.Tensor, representing an image or batch of images.
                However if the data is already split, it can be fed by providing a list of tuple (patch, location),
                or a MetaTensor that has metadata for `PatchKeys.LOCATION`. In both cases no splitter should be provided.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): If provided via `**kwargs`,
                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
                The resulting segments will be passed to the model together with the corresponding input segments.

        	conditionN*`condition` must match shape of `inputs` (), but got z/Length of `condition` must match `inputs`. Got  and .zREach `condition` patch must match the shape of the corresponding input patch. Got zX`condition` and `inputs` must be of the same type (both Tensor or both list of patches).z`PatchKey.LOCATION` does not exists in `inputs.meta`. If the inputs are already split into patches, the location of patches needs to be provided as `PatchKey.LOCATION` metadata in a MetaTensor. If the input is not already split, please provide `splitter`.z`splitter` should be set if the input is not already split into patches. For inputs that are split, the location of patches needs to be provided as (image, location) pairs, or as `PatchKey.LOCATION` metadata in a MetaTensor. The provided inputs type is c                 S  s   g | ]}|  qS r?   )finalize)rr   r   r?   r?   r@   r         z)PatchInferer.__call__.<locals>.<listcomp>rE   r   )poprW   rg   Tensorr   r[   rx   rb   r   rF   r   r#   rf   re   rX   rn   r|   r   r   rO   rw   )r>   r4   r6   r7   r8   r   r   _Z
cond_patchZpatches_locationsZcondition_locationsr   r   r^   r   rJ   Zcondition_patchesro   Zmerged_outputsr?   r?   r@   rA   5  s   



zPatchInferer.__call__)rF   rG   rH   rI   rJ   rK   rL   rM   rN   rM   rO   rP   rQ   rR   rS   rK   rT   r   r9   rU   )r^   r_   r9   r`   )ro   r   r9   rp   )
r6   r   r{   r5   r7   r   r8   r   r9   rp   )
r4   r5   r6   r   r7   r   r8   r   r9   r   )r<   rB   rC   rD   r   rV   rn   rz   r|   r   r   r   rA   r?   r?   r?   r@   r.   d   s$    
?
(
r.   c                   @  s$   e Zd ZdZdddZdddZdS )r/   z
    SimpleInferer is the normal inference method that run model forward() directly.
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    r9   rU   c                 C     t |  d S rq   r-   rV   r>   r?   r?   r@   rV        zSimpleInferer.__init__r4   r5   r6   Callable[..., torch.Tensor]r7   r   r8   c                 O  s   ||g|R i |S )a  Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        r?   r=   r?   r?   r@   rA     s   zSimpleInferer.__call__Nr9   rU   )
r4   r5   r6   r   r7   r   r8   r   r9   r5   r<   rB   rC   rD   rV   rA   r?   r?   r?   r@   r/     s    
r/   c                      sN   e Zd ZdZddejdejdddddddddfd/ fd#d$Zd0d-d.Z  Z	S )1r0   a  
    Sliding window method for model inference,
    with `sw_batch_size` windows for every model.forward().
    Usage example can be found in the :py:class:`monai.inferers.Inferer` base class.

    Args:
        roi_size: the window size to execute SlidingWindow evaluation.
            If it has non-positive components, the corresponding `inputs` size will be used.
            if the components of the `roi_size` are non-positive values, the transform will use the
            corresponding components of img size. For example, `roi_size=(32, -1)` will be adapted
            to `(32, 64)` if the second spatial dimension size of img is `64`.
        sw_batch_size: the batch size to run window slices.
        overlap: Amount of overlap between scans along each spatial dimension, defaults to ``0.25``.
        mode: {``"constant"``, ``"gaussian"``}
            How to blend output of overlapping windows. Defaults to ``"constant"``.

            - ``"constant``": gives equal weight to all predictions.
            - ``"gaussian``": gives less weight to predictions on edges of windows.

        sigma_scale: the standard deviation coefficient of the Gaussian window when `mode` is ``"gaussian"``.
            Default: 0.125. Actual window sigma is ``sigma_scale`` * ``dim_size``.
            When sigma_scale is a sequence of floats, the values denote sigma_scale at the corresponding
            spatial dimensions.
        padding_mode: {``"constant"``, ``"reflect"``, ``"replicate"``, ``"circular"``}
            Padding mode when ``roi_size`` is larger than inputs. Defaults to ``"constant"``
            See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html
        cval: fill value for 'constant' padding mode. Default: 0
        sw_device: device for the window data.
            By default the device (and accordingly the memory) of the `inputs` is used.
            Normally `sw_device` should be consistent with the device where `predictor` is defined.
        device: device for the stitched output prediction.
            By default the device (and accordingly the memory) of the `inputs` is used. If for example
            set to device=torch.device('cpu') the gpu memory consumption is less and independent of the
            `inputs` and `roi_size`. Output is on the `device`.
        progress: whether to print a tqdm progress bar.
        cache_roi_weight_map: whether to precompute the ROI weight map.
        cpu_thresh: when provided, dynamically switch to stitching on cpu (to save gpu memory)
            when input image volume is larger than this threshold (in pixels/voxels).
            Otherwise use ``"device"``. Thus, the output may end-up on either cpu or gpu.
        buffer_steps: the number of sliding window iterations along the ``buffer_dim``
            to be buffered on ``sw_device`` before writing to ``device``.
            (Typically, ``sw_device`` is ``cuda`` and ``device`` is ``cpu``.)
            default is None, no buffering. For the buffer dim, when spatial size is divisible by buffer_steps*roi_size,
            (i.e. no overlapping among the buffers) non_blocking copy may be automatically enabled for efficiency.
        buffer_dim: the spatial dimension along which the buffers are created.
            0 indicates the first spatial dimension. Default is -1, the last spatial dimension.
        with_coord: whether to pass the window coordinates to ``network``. Defaults to False.
            If True, the ``network``'s 2nd input argument should accept the window coordinates.

    Note:
        ``sw_batch_size`` denotes the max number of windows per network inference iteration,
        not the batch size of inputs.

    rE   g      ?g      ?g        NFroi_sizeSequence[int] | intsw_batch_sizerK   overlapSequence[float] | floatmodeBlendMode | strsigma_scalepadding_modePytorchPadMode | strcvalfloat	sw_devicetorch.device | str | NonedeviceprogressrR   cache_roi_weight_map
cpu_thresh
int | Nonebuffer_steps
buffer_dim
with_coordr9   rU   c                   s  t    || _|| _|| _t|| _|| _|| _|| _	|| _
|	| _|
| _|| _|| _|| _|| _d | _z4|rTt|trTt|dkrT|	d u rHd}	tt| j|||	d| _|rc| jd u rftd W d S W d S W d S  ty } ztd| j d| d| d|	 d		|d }~ww )
Nr   cpu)r   r   r   zHcache_roi_weight_map=True, but cache is not created. (dynamic roi_size?)z	roi size z, mode=z, sigma_scale=z	, device=z^
Seems to be OOM. Please try smaller patch size or mode='constant' instead of mode='gaussian'.)superrV   r   r   r   r!   r   r   r   r   r   r   r   r   r   r   r   roi_weight_maprW   r	   rd   r   r%   warningswarnBaseExceptionRuntimeError)r>   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   er;   r?   r@   rV     sD   

zSlidingWindowInferer.__init__r4   r5   r6   r   r7   r   r8   Atorch.Tensor | tuple[torch.Tensor, ...] | dict[Any, torch.Tensor]c           	      O  s   | dd}|dur|j|jkrtd|j d|j |d| j}|d| j}|d| j}|du rH| jdurH|jdd  | jkrHd	}t	|| j
| j|| j| j| j| j| j| j|| j| jd||| jg|R i |S )
a  

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): If provided via `**kwargs`,
                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
                The resulting segments will be passed to the model together with the corresponding input segments.
        r   Nr   r   r   r   r   r   r   )getr   r[   r   r   r   r   r   numelr   r   r   r   r   r   r   r   r   r   r   r   )	r>   r4   r6   r7   r8   r   r   r   r   r?   r?   r@   rA   *  s>   *zSlidingWindowInferer.__call__) r   r   r   rK   r   r   r   r   r   r   r   r   r   r   r   r   r   r   r   rR   r   rR   r   r   r   r   r   rK   r   rR   r9   rU   
r4   r5   r6   r   r7   r   r8   r   r9   r   )
r<   rB   rC   rD   r!   CONSTANTr$   rV   rA   __classcell__r?   r?   r   r@   r0     s$    :4r0   c                      s"   e Zd ZdZd fddZ  ZS )r3   a(  
    SlidingWindowInfererAdapt extends SlidingWindowInferer to automatically switch to buffered and then to CPU stitching,
    when OOM on GPU. It also records a size of such large images to automatically
    try CPU stitching for the next large image of a similar size.  If the stitching 'device' input parameter is provided,
    automatic adaptation won't be attempted, please keep the default option device = None for adaptive behavior.
    Note: the output might be on CPU (even if the input was on GPU), if the GPU memory was not sufficient.

    r4   r5   r6   r   r7   r   r8   r9   r   c                   sV  | j durt j||g|R i |S | jduo| jdk}| jduo-|jdd  | jk}|jo3| }|jo;|o;| }| jdurGtd| jnd}	d}
t	|jdd }|
t|}|j|d  |jd  dkrk|}
tdD ]}z!t j||g|R |r|j nt d|r|	nd|
d|W   S  ty } zx|s|rd	tt|jvr|t| |rd
}|jdd  d | _|rd
}td|j d n?d}|	| _td|	 d|
 d|j d n)|	dkrtd|	d }	|	| _td|j d|	 d nd
}td|j d W Y d}~qod}~ww td| d| d| d| d|	 
)ag  

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.

        Nr   r   rE   r   
   r   )r   r   r   OutOfMemoryErrorFz3GPU stitching failed, attempting on CPU, image dim r   TzGPU stitching failed, buffer z dim z, image dim z)GPU buffered stitching failed, image dim z reducing buffer to z<GPU buffered stitching failed, attempting on CPU, image dim zSlidingWindowInfererAdapt  )r   r   rA   r   r   r   r   is_cudamaxrx   indexrc   rg   r   rZ   rX   r<   loggerinfowarning)r>   r4   r6   r7   r8   Zskip_bufferZcpu_condZgpu_stitchingZbuffered_stitchingr   r   shmax_dimr   r   r   r?   r@   rA   j  sl   
"
	

 z"SlidingWindowInfererAdapt.__call__r   )r<   rB   rC   rD   rA   r   r?   r?   r   r@   r3   `  s    	r3   c                   @  s(   e Zd ZdZ	ddddZdddZdS )r1   a  
    SaliencyInferer is inference with activation maps.

    Args:
        cam_name: expected CAM method name, should be: "CAM", "GradCAM" or "GradCAMpp".
        target_layers: name of the model layer to generate the feature map.
        class_idx: index of the class to be visualized. if None, default to argmax(logits).
        args: other optional args to be passed to the `__init__` of cam.
        kwargs: other optional keyword args to be passed to `__init__` of cam.

    Ncam_namerZ   target_layers	class_idxr   r7   r   r8   r9   rU   c                 O  sD   t |  | dvrtd| | _|| _|| _|| _|| _d S )N)camgradcamZ	gradcamppz4cam_name should be: 'CAM', 'GradCAM' or 'GradCAMpp'.)	r-   rV   lowerr[   r   r   r   r7   r8   )r>   r   r   r   r7   r8   r?   r?   r@   rV     s   


zSaliencyInferer.__init__r4   r5   r6   	nn.Modulec                 O  s   | j dkrt|| jg| jR i | j}n$| j dkr*t|| jg| jR i | j}nt|| jg| jR i | j}||| jg|R i |S )a  Unified callable function API of Inferers.

        Args:
            inputs: model input data for inference.
            network: target model to execute inference.
                supports callables such as ``lambda x: my_torch_model(x, additional_config)``
            args: other optional args to be passed to the `__call__` of cam.
            kwargs: other optional keyword args to be passed to `__call__` of cam.

        r   r   )r   r'   r   r7   r8   r(   r)   r   )r>   r4   r6   r7   r8   r   r?   r?   r@   rA     s   
 
 zSaliencyInferer.__call__rq   )r   rZ   r   rZ   r   r   r7   r   r8   r   r9   rU   )r4   r5   r6   r   r7   r   r8   r   r   r?   r?   r?   r@   r1     s
    r1   c                      s@   e Zd ZdZdd fd
dZd fddZ	ddddZ  ZS )r2   a[  
    SliceInferer extends SlidingWindowInferer to provide slice-by-slice (2D) inference when provided a 3D volume.
    A typical use case could be a 2D model (like 2D segmentation UNet) operates on the slices from a 3D volume,
    and the output is a 3D volume with 2D slices aggregated. Example::

        # sliding over the `spatial_dim`
        inferer = SliceInferer(roi_size=(64, 256), sw_batch_size=1, spatial_dim=1)
        output = inferer(input_volume, net)

    Args:
        spatial_dim: Spatial dimension over which the slice-by-slice inference runs on the 3D volume.
            For example ``0`` could slide over axial slices. ``1`` over coronal slices and ``2`` over sagittal slices.
        args: other optional args to be passed to the `__init__` of base class SlidingWindowInferer.
        kwargs: other optional keyword args to be passed to `__init__` of base class SlidingWindowInferer.

    Note:
        ``roi_size`` in SliceInferer is expected to be a 2D tuple when a 3D volume is provided. This allows
        sliding across slices along the 3D volume using a selected ``spatial_dim``.

    r   spatial_dimrK   r7   r   r8   r9   rU   c                   s(   || _ t j|i | t| j| _d S rq   )r   r   rV   r%   r   orig_roi_size)r>   r   r7   r8   r   r?   r@   rV     s   zSliceInferer.__init__r4   r5   r6   r   r   c                   s   j dkr	tdtj_tjdkr0t|jdd dkr0tj_jj d nt	dj d|j d|
d	d}|durY|j|jkrYtd
|j d|j |durkt j| fdd|dS t j| fdddS )aE  
        Args:
            inputs: 3D input for inference
            network: 2D model to execute inference on slices in the 3D input
            args: optional args to be passed to ``network``.
            kwargs: optional keyword args to be passed to ``network``.
            condition (torch.Tensor, optional): If provided via `**kwargs`,
                this tensor must match the shape of `inputs` and will be sliced, patched, or windowed alongside the inputs.
                The resulting segments will be passed to the model together with the corresponding input segments.r   zB`spatial_dim` can only be `0, 1, 2` with `[H, W, D]` respectively.N   rE   zCurrently, only 2D `roi_size` (z!) with 3D `inputs` tensor (shape=z) is supported.r   r   r   c                      j  | g|R i |S rq   network_wrapperxr7   r8   r6   r>   r?   r@   <lambda>)      z'SliceInferer.__call__.<locals>.<lambda>)r4   r6   r   c                   r   rq   r   r   r   r?   r@   r   .  r   )r4   r6   )r   r[   r%   r   rb   r   r   rx   insertr   r   r   rA   )r>   r4   r6   r7   r8   r   r   r   r@   rA     s*   
$zSliceInferer.__call__Nr   r   torch.Tensor | Nonec                   s   |j  jd d}|dur#|j  jd d}|||g|R i |}n||g|R i |}t|tjr=|j jd dS t|trX| D ]}|| j jd d||< qF|S t fdd|D S )zP
        Wrapper handles inference for 2D models over 3D volume inputs.
        r   dimNc                 3  s"    | ]}|j  jd  dV  qdS )r   r   N)	unsqueezer   )rr   out_ir   r?   r@   ru   O  s     z/SliceInferer.network_wrapper.<locals>.<genexpr>)	squeezer   rW   rg   r   r   r   ry   rp   )r>   r6   r   r   r7   r8   outrs   r?   r   r@   r   1  s   
zSliceInferer.network_wrapper)r   )r   rK   r7   r   r8   r   r9   rU   r   rq   )r6   r   r   r5   r   r   r7   r   r8   r   r9   r   )r<   rB   rC   rD   rV   rA   r   r   r?   r?   r   r@   r2     s    2r2   c                      s   e Zd ZdZd8 fddZ				d9d:ddZe 										d;d<d)d*Ze 						+	,		d=d>d0d1Z	d2d3 Z
	+	,d?d@d6d7Z  ZS )ADiffusionInfererz
    DiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal forward pass
    for a training iteration, and sample from the model.

    Args:
        scheduler: diffusion scheduler.
    	schedulerr   r9   rU   c                   s   t    || _d S rq   )r   rV   r   r>   r   r   r?   r@   rV   [  s   

zDiffusionInferer.__init__N	crossattnr4   r5   diffusion_modelr   noise	timestepsr   r   r   rZ   segc           
      C  s   |dvrt | d| jj|||d}|dkr+|du r tdtj||gdd}d}t|tr6t||d	n|}||||d
}	|	S )a>  
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: Input image to which noise is added.
            diffusion_model: diffusion model.
            noise: random noise, of the same shape as the input.
            timesteps: random timesteps.
            condition: Conditioning for network input.
            mode: Conditioning mode for the network.
            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
            provided on the forward (for SPADE-like AE or SPADE-like DM)
        r   concat condition is not supportedZoriginal_samplesr   r   r  Nz-Conditioning is required for concat conditionrE   r   r  r   r   context)	r:   r   	add_noiser[   rg   rh   rW   r   r
   )
r>   r4   r   r   r   r   r   r  noisy_image
predictionr?   r?   r@   rA   `  s   zDiffusionInferer.__call__Fd   T      input_noiseScheduler | Nonesave_intermediatesbool | Noneintermediate_stepsr   conditioningverboserR   cfgfloat | Nonecfg_fill_valuer   6torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]c                 C  s  |dvrt | d|dkr|du rtd|s| j}|}t|jdd tjdg|jjdf}|rItrIt	t
|j|tt|jt|d	}ntt
|j|}g }|D ]\}}t|trdt||	d
n|}|
durtj|gd dd}|durt|}|| tj||gdd}nd}n|}|}|dkr|durtj||gdd}||t|f|jdd}n||t|f|j|d}|
dur|d\}}||
||   }t|ts||||\}}n
|||||\}}|r|| dkr|| qU|r||fS |S )a  
        Args:
            input_noise: random noise, of the same shape as the desired sample.
            diffusion_model: model to sample from.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
            save_intermediates: whether to return intermediates along the sampling change
            intermediate_steps: if save_intermediates is True, saves every n steps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
        r  r  r  N>Conditioning must be supplied for if condition mode is concat.rE   r   dtypetotalr  r   r   r   r  )r:   r[   r   rg   rh   r   tensorr  has_tqdmr*   r   rd   rb   iterrW   r   r
   	ones_likefill_r   tor   r   r   stepr   )r>   r  r   r   r  r  r  r   r  r  r  r  imageall_next_timestepsprogress_barintermediatestnext_tmodel_inputunconditionconditioning_inputmodel_outputmodel_output_uncondmodel_output_condr   r?   r?   r@   rm     sb   (




zDiffusionInferer.sampler      r   rE   original_input_rangerp   scaled_input_rangec           !      C  s8  |s| j }| dkrtd|  |dvrt| d|dkr+|du r+td|	r5tr5t|j}nt|j}g }t	|
|j}t|jd 
|j}|D ]>}tj|jdd	 ||jd
 }| j j|||d}t|trzt||
dn|}|dkr|durtj||gd	d}|||dd}n||||d}|jd	 |jd	 d kr|jdv rtj||jd	 d	d\}}nd}|j| }|dkr|j|d	  n|j}d	| }d	| }|jdkr||d |  |d  }n|jdkr|}n|jdkr|d | |d |  }|jrt|dd	}|d |j|  | }|j| d | | }|| ||  }|j|||d}|j||d}t |}|rIt |n|}|dkr^| j!||d| ||d } ndd| | t"||  || d t"|    } || #| jd dj$d	d7 }|r|%| &  qS|r||fS |S )a[  
        Computes the log-likelihoods for an input.

        Args:
            inputs: input images, NxCxHxW[xD]
            diffusion_model: model to compute likelihood from
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
        DDPMSchedulerLLikelihood computation is only compatible with DDPMScheduler, you are using r  r  r  Nr  r   rE   r   r  r  r   r  r  r   ZlearnedZlearned_rangeepsilon      ?rm   v_predictionr   timestepx_0x_tr?  predicted_variancer4   means
log_scalesr5  r6  r  )'r   	_get_namer:   r[   r   r*   r   r!  rg   
randn_liker$  r   zerosr   fulllongr	  rW   r   r
   rh   variance_typesplitalphas_cumprodoneprediction_typeclip_sampleclampbetasalphas	_get_mean_get_variancelog_get_decoder_log_likelihoodexpviewmeanr   r   )!r>   r4   r   r   r  r  r   r5  r6  r  r  r(  r)  r   total_klr*  r   r
  r/  rC  alpha_prod_talpha_prod_t_prevbeta_prod_tbeta_prod_t_prevpred_original_samplepred_original_sample_coeffcurrent_sample_coeffpredicted_meanposterior_meanposterior_variancelog_posterior_variancelog_predicted_varianceklr?   r?   r@   get_likelihood  s   

 "





	zDiffusionInferer.get_likelihoodc                 C  sB   ddt t t dtj g|j|dt |d     S )z
        A fast approximation of the cumulative distribution function of the
        standard normal. Code adapted from https://github.com/openai/improved-diffusion.
        r<        ?g       @gHm?r   )	rg   tanhsqrtr   mathpir$  r   pow)r>   r   r?   r?   r@   _approx_standard_normal_cdfh  s   <z,DiffusionInferer._approx_standard_normal_cdfrE  rF  c                 C  s   |j |j krtd|j  d|j  |d |d  |d |d   }|| }t| }|||d   }	| |	}
|||d   }| |}t|
jdd}td| jdd}|
| }t|d	k |t|d
k|t|jdd}|S )ax  
        Compute the log-likelihood of a Gaussian distribution discretizing to a
        given image. Code adapted from https://github.com/openai/improved-diffusion.

        Args:
            input: the target images. It is assumed that this was uint8 values,
                      rescaled to the range [-1, 1].
            means: the Gaussian mean Tensor.
            log_scales: the Gaussian log stddev Tensor.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
        z/Inputs and means must have the same shape, got r   rE   r   r   g-q=)rd   rk  g+g+?)r   r[   rg   rY  rq  rW  rR  where)r>   r4   rE  rF  r5  r6  	bin_widthZ
centered_xZinv_stdvZplus_inZcdf_plusZmin_inZcdf_minZlog_cdf_plusZlog_one_minus_cdf_minZ	cdf_delta	log_probsr?   r?   r@   rX  r  s(   

z,DiffusionInferer._get_decoder_log_likelihoodr   r   r9   rU   Nr   N)r4   r5   r   r   r   r5   r   r5   r   r   r   rZ   r  r   r9   r5   	NFr  Nr   TNNr  )r  r5   r   r   r   r  r  r  r  r   r  r   r   rZ   r  rR   r  r   r  r  r  r   r9   r  NFNr   r2  r4  TN)r4   r5   r   r   r   r  r  r  r  r   r   rZ   r5  rp   r6  rp   r  rR   r  r   r9   r  )r2  r4  )r4   r5   rE  r5   rF  r5   r5  rp   r6  rp   r9   r5   )r<   rB   rC   rD   rV   rA   rg   no_gradrm   rj  rq  rX  r   r?   r?   r   r@   r   R  sB    *\ r   c                      s   e Zd ZdZ			d<d= fddZ			d>d? fddZe 			 			!			"d@dA fd0d1Ze 					2	3	!		4	dBdC fd:d;Z	  Z
S )DLatentDiffusionInferera-  
    LatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, and a scheduler, and can
    be used to perform a signal forward pass for a training iteration, and sample from the model.

    Args:
        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
            second stage.
        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
        autoencoder_latent_shape:  autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
             difference between the autoencoder's latent shape and the DM shape.
    rk  Nr   r   scale_factorr   ldm_latent_shapelist | Noneautoencoder_latent_shaper9   rU   c                   t   t  j|d || _|d u |d u A rtd|| _|| _| jd ur6| jd ur8t| jd| _t| jd| _	d S d S d S )Nr   zSIf ldm_latent_shape is None, autoencoder_latent_shape must be None, and vice versa.spatial_sizer   
r   rV   r{  r[   r|  r~  r    ldm_resizerr   autoencoder_resizerr>   r   r{  r|  r~  r   r?   r@   rV        zLatentDiffusionInferer.__init__r   r4   r5   autoencoder_modelAutoencoderKL | VQVAEr   r   r   r   r   r   r   rZ   r  c	              	     s~   t   || j }	W d   n1 sw   Y   jdur0t  fddt|	D d}	t j|	||||||d}
|
S )aw  
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: input image to which the latent representation will be extracted and noise is added.
            autoencoder_model: first stage model.
            diffusion_model: diffusion model.
            noise: random noise, of the same shape as the latent representation.
            timesteps: random timesteps.
            condition: conditioning for network input.
            mode: Conditioning mode for the network.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
        Nc                      g | ]}  |qS r?   r  rr   rj   r   r?   r@   r         z3LatentDiffusionInferer.__call__.<locals>.<listcomp>r   )r4   r   r   r   r   r   r  )	rg   ry  encode_stage_2_inputsr{  r|  stackr   r   rA   )r>   r4   r  r   r   r   r   r   r  latentr  r   r   r@   rA     s   

	zLatentDiffusionInferer.__call__Fr  Tr  r  r  r  r  r  r   r  r  rR   r  r  r  r  c                   s*  t |trt |tr|jj|jkrtd|jj d|j t j||||||||	|
||d}|r6|\}}n|} jdurWt	
 fddt|D d}|rW fdd|D }|j}t |trft|j|
d	}|| j }|rg }|D ]}|j}t |trt|j|
d	}||| j  qs||fS |S )
a>  
        Args:
            input_noise: random noise, of the same shape as the desired latent representation.
            autoencoder_model: first stage model.
            diffusion_model: model to sample from.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
            save_intermediates: whether to return intermediates along the sampling change
            intermediate_steps: if save_intermediates is True, saves every n steps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
             is instance of SPADEAutoencoderKL, segmentation must be provided.
            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
        zIf both autoencoder_model and diffusion_model implement SPADE, the number of semanticlabels for each must be compatible, but got z and)r  r   r   r  r  r  r   r  r  r  r  Nc                   r  r?   r  r  r   r?   r@   r   $  r  z1LatentDiffusionInferer.sample.<locals>.<listcomp>r   c                   *   g | ]}t  fd dt|D dqS )c                   r  r?   r  r  r   r?   r@   r   '  r  z<LatentDiffusionInferer.sample.<locals>.<listcomp>.<listcomp>r   rg   r  r   rr   r   r   r?   r@   r   &      r  )rW   r   r   decoderlabel_ncr[   r   rm   r~  rg   r  r   decode_stage_2_outputsr
   r{  r   )r>   r  r  r   r   r  r  r  r   r  r  r  r  ro   r  latent_intermediatesdecoder&  r)  latent_intermediater   r   r@   rm     s^   "




zLatentDiffusionInferer.sampler2  r4  nearestr5  tuple | Noner6  resample_latent_likelihoodsresample_interpolation_modec              
     s   |r|dvrt d| ||j }jdur)tfddt|D d}t j|||||||
|d}|rZ|rZ|d }t	j
|jd	d |d
  fdd|D }|d |f}|S )a,  
        Computes the log-likelihoods of the latent representations of the input.

        Args:
            inputs: input images, NxCxHxW[xD]
            autoencoder_model: first stage model.
            diffusion_model: model to compute likelihood from
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                dimension as the input images.
            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
                or 'trilinear;
            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
             is instance of SPADEAutoencoderKL, segmentation must be provided.
        r  bilinear	trilinearRresample_interpolation mode should be either nearest, bilinear, or trilinear, got Nc                   r  r?   r  r  r   r?   r@   r   h  r  z9LatentDiffusionInferer.get_likelihood.<locals>.<listcomp>r   )r4   r   r   r  r  r   r  r  rE   r   sizer   c                      g | ]} |qS r?   r?   rr   r   resizerr?   r@   r   x  r   )r[   r  r{  r|  rg   r  r   r   rj  nnUpsampler   )r>   r4   r  r   r   r  r  r   r5  r6  r  r  r  r  latentsro   r)  r   r  r>   r@   rj  ;  s.   &
z%LatentDiffusionInferer.get_likelihoodrk  NN
r   r   r{  r   r|  r}  r~  r}  r9   rU   rv  )r4   r5   r  r  r   r   r   r5   r   r5   r   r   r   rZ   r  r   r9   r5   rw  )r  r5   r  r  r   r   r   r  r  r  r  r   r  r   r   rZ   r  rR   r  r   r  r  r  r   r9   r  
NFNr   r2  r4  TFr  N)r4   r5   r  r  r   r   r   r  r  r  r  r   r   rZ   r5  r  r6  r  r  rR   r  rR   r  rZ   r  r   r9   r  r<   rB   rC   rD   rV   rA   rg   ry  rm   rj  r   r?   r?   r   r@   rz    sB    )Vrz  c                      s~   e Zd ZdZd5ddZ				d6d7ddZe 										d8d9d,d-Ze 						.	/		d:d; fd3d4Z	  Z
S )<ControlNetDiffusionInferera  
    ControlNetDiffusionInferer takes a trained diffusion model and a scheduler and can be used to perform a signal
    forward pass for a training iteration, and sample from the model, supporting ControlNet-based conditioning.

    Args:
        scheduler: diffusion scheduler.
    r   r   r9   rU   c                 C  s   t |  || _d S rq   )r-   rV   r   r   r?   r?   r@   rV     s   

z#ControlNetDiffusionInferer.__init__Nr   r4   r5   r   r   
controlnetr   r   r   cn_condr   r   r   rZ   r  c
                 C  s   |dvrt | d| jj|||d}
|dkr'|dur'tj|
|gdd}
d}||
|||d\}}|}t|tr>t||	d	}||
||||d
}|S )a  
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: Input image to which noise is added.
            diffusion_model: diffusion model.
            controlnet: controlnet sub-network.
            noise: random noise, of the same shape as the input.
            timesteps: random timesteps.
            cn_cond: conditioning image for the ControlNet.
            condition: Conditioning for network input.
            mode: Conditioning mode for the network.
            seg: if model is instance of SPADEDiffusionModelUnet, segmentation must be
            provided on the forward (for SPADE-like AE or SPADE-like DM)
        r  r  r  r  NrE   r   r   r   controlnet_condr  r  r   r   r  down_block_additional_residualsmid_block_additional_residual)r:   r   r	  rg   rh   rW   r   r
   )r>   r4   r   r  r   r   r  r   r   r  r
  down_block_res_samplesmid_block_res_samplediffuser  r?   r?   r@   rA     s(   

z#ControlNetDiffusionInferer.__call__Fr  Tr  r  r  r  r  r  r   r  r  rR   r  r  r  r   r  c                 C  sh  |	dvrt |	 d|s| j}|}t|jdd tjdg|jjdf}|
r=tr=tt	|j|t
t|jt|d}ntt	|j|}g }|durUtj|gd dd	}|D ]\}}|durtj|gd dd	}|durt|}|| tj||gdd	}nd}n|}|}|}t|trt||d
}|	dkr|durtj||gdd	}||t|f|j|dd\}}||t|f|jd||d}n#||t|f|j||d\}}||t|f|j|||d}|dur|d\}}||||   }t|ts||||\}}n
|||||\}}|r*|| dkr*|| qW|r2||fS |S )a#  
        Args:
            input_noise: random noise, of the same shape as the desired sample.
            diffusion_model: model to sample from.
            controlnet: controlnet sub-network.
            cn_cond: conditioning image for the ControlNet.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
            save_intermediates: whether to return intermediates along the sampling change
            intermediate_steps: if save_intermediates is True, saves every n steps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
        r  r  rE   Nr   r  r  r   r   r  r  r  r   r  r  r  )r:   r   rg   rh   r   r  r  r   r*   r   rd   rb   r!  r"  r#  rW   r   r
   r   r$  r   r   r   r%  r   )r>   r  r   r  r  r   r  r  r  r   r  r  r  r  r&  r'  r(  r)  r*  r+  r,  r-  r.  r  r  r  r/  r0  r1  r   r?   r?   r@   rm     s   !(




	



z!ControlNetDiffusionInferer.sampler2  r4  r5  rp   r6  c           &        sv  |s| j }| dkrtd|  |dvrt| d|r)tr)t|j}nt|j}g }t|	|j
}t|jd 	|j
}|D ]i}tj|jdd ||j
d }| j j|||d	}|}t|trpt||d
}|dkr|durtj||gdd}||t|f	|j
|dd\}}|||d||d}n||t|f	|j
||d\}}||||||d}|jd |jd d kr|jdv rtj||jd dd\}}nd}|j| }|dkr|j|d  n|j}d| }d| }|jdkr||d |  |d  }n|jdkr|}n|jdkr"|d | |d |  }|jr-t|dd}|d |j|  | }|j| d | | }|| ||  } |j|||d}!|j||d}"t |"}#|rgt |n|#}$|dkr}t! j"|| d|$ |	|
d }%ndd|$ |# t#|#|$  |!|  d t#|$    }%||%$|%jd dj%dd7 }|r|&|%'  qG|r||fS |S )a  
        Computes the log-likelihoods for an input.

        Args:
            inputs: input images, NxCxHxW[xD]
            diffusion_model: model to compute likelihood from
            controlnet: controlnet sub-network.
            cn_cond: conditioning image for the ControlNet.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
        r7  r8  r  r  r   NrE   r9  r  r  r  r   r  r  r  r   r:  r;  r<  rm   r=  r   r>  rB  rD  r  )(r   rG  r:   r   r*   r   r!  rg   rH  r$  r   rI  r   rJ  rK  r	  rW   r   r
   rh   r   rL  rM  rN  rO  rP  rQ  rR  rS  rT  rU  rV  rW  r   rX  rY  rZ  r[  r   r   )&r>   r4   r   r  r  r   r  r  r   r5  r6  r  r  r(  r)  r   r\  r*  r   r
  r  r  r  r/  rC  r]  r^  r_  r`  ra  rb  rc  rd  re  rf  rg  rh  ri  r   r?   r@   rj  >  s   "

 


"



	z)ControlNetDiffusionInferer.get_likelihoodru  rv  )r4   r5   r   r   r  r   r   r5   r   r5   r  r5   r   r   r   rZ   r  r   r9   r5   rw  )r  r5   r   r   r  r   r  r5   r   r  r  r  r  r   r  r   r   rZ   r  rR   r  r   r  r  r  r   r9   r  rx  )r4   r5   r   r   r  r   r  r5   r   r  r  r  r  r   r   rZ   r5  rp   r6  rp   r  rR   r  r   r9   r  r  r?   r?   r   r@   r  }  s8    
6}r  c                      s   e Zd ZdZ			d?d@ fddZ			dAdB fd d!Ze 		"	#			$			%dCdD fd3d4Ze 		"			5	6	$	"	7	dEdF fd=d>Z	  Z
S )G ControlNetLatentDiffusionInfereraG  
    ControlNetLatentDiffusionInferer takes a stage 1 model (VQVAE or AutoencoderKL), diffusion model, controlnet,
    and a scheduler, and can be used to perform a signal forward pass for a training iteration, and sample from
    the model.

    Args:
        scheduler: a scheduler to be used in combination with `unet` to denoise the encoded image latents.
        scale_factor: scale factor to multiply the values of the latent representation before processing it by the
            second stage.
        ldm_latent_shape: desired spatial latent space shape. Used if there is a difference in the autoencoder model's latent shape.
        autoencoder_latent_shape:  autoencoder_latent_shape: autoencoder spatial latent space shape. Used if there is a
             difference between the autoencoder's latent shape and the DM shape.
    rk  Nr   r   r{  r   r|  r}  r~  r9   rU   c                   r  )Nr  zQIf ldm_latent_shape is None, autoencoder_latent_shape must be Noneand vice versa.r  r  r  r  r   r?   r@   rV     r  z)ControlNetLatentDiffusionInferer.__init__r   r4   r5   r  r  r   r   r  r   r   r   r  r   r   r   rZ   r  c                   s   t   || j }W d   n1 sw   Y   jdur0t  fddt|D d}|jdd |jdd krIt	||jdd }t
 j||||||||	|
d	}|S )a  
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: input image to which the latent representation will be extracted and noise is added.
            autoencoder_model: first stage model.
            diffusion_model: diffusion model.
            controlnet: instance of ControlNet model
            noise: random noise, of the same shape as the latent representation.
            timesteps: random timesteps.
            cn_cond: conditioning tensor for the ControlNet network
            condition: conditioning for network input.
            mode: Conditioning mode for the network.
            seg: if diffusion model is instance of SPADEDiffusionModel, segmentation must be provided.
        Nc                   r  r?   r  r  r   r?   r@   r     r  z=ControlNetLatentDiffusionInferer.__call__.<locals>.<listcomp>r   r   )	r4   r   r  r   r   r  r   r   r  )rg   ry  r  r{  r|  r  r   r   Finterpolater   rA   )r>   r4   r  r   r  r   r   r  r   r   r  r  r  r   r   r@   rA     s&   

z)ControlNetLatentDiffusionInferer.__call__Fr  Tr  r  r  r  r  r  r   r  r  rR   r  r  r  r  c                   sN  t |trt |tr|jj|jkrtd|jdd |jdd kr.t||jdd }t	 j
||||||||	|
||||d}|rH|\}}n|} jdurit fddt|D d}|ri fdd|D }|j}t |trxt|j|d	}|| j }|rg }|D ]}|j}t |trt|j|d	}||| j  q||fS |S )
a  
        Args:
            input_noise: random noise, of the same shape as the desired latent representation.
            autoencoder_model: first stage model.
            diffusion_model: model to sample from.
            controlnet: instance of ControlNet model.
            cn_cond: conditioning tensor for the ControlNet network.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler.
            save_intermediates: whether to return intermediates along the sampling change
            intermediate_steps: if save_intermediates is True, saves every n steps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            verbose: if true, prints the progression bar of the sampling process.
            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
             is instance of SPADEAutoencoderKL, segmentation must be provided.
            cfg: classifier-free-guidance scale, which indicates the level of strengthening on the conditioning.
            cfg_fill_value: the fill value to use for the unconditioned input when using classifier-free guidance.
        zIf both autoencoder_model and diffusion_model implement SPADE, the number of semanticlabels for each must be compatible. Got {autoencoder_model.decoder.label_nc} and {diffusion_model.label_nc}r   N)r  r   r  r  r   r  r  r  r   r  r  r  r  c                   r  r?   r  r  r   r?   r@   r   q  r  z;ControlNetLatentDiffusionInferer.sample.<locals>.<listcomp>r   c                   r  )c                   r  r?   r  r  r   r?   r@   r   t  r  zFControlNetLatentDiffusionInferer.sample.<locals>.<listcomp>.<listcomp>r   r  r  r   r?   r@   r   s  r  r  )rW   r   r   r  r  r[   r   r  r  r   rm   r~  rg   r  r   r  r
   r{  r   )r>   r  r  r   r  r  r   r  r  r  r   r  r  r  r  ro   r  r  r  r&  r)  r  r   r   r@   rm   )  s^   &




z'ControlNetLatentDiffusionInferer.sampler2  r4  r  r5  r  r6  r  r  c                   s   |r|dvrt d| ||j }|jdd |jdd kr.t||jdd }jdurBtfddt	|D d}t
 j||||||||	||d
}|ru|ru|d	 }tj|jdd |d
  fdd|D }|d |f}|S )a  
        Computes the log-likelihoods of the latent representations of the input.

        Args:
            inputs: input images, NxCxHxW[xD]
            autoencoder_model: first stage model.
            diffusion_model: model to compute likelihood from
            controlnet: instance of ControlNet model.
            cn_cond: conditioning tensor for the ControlNet network.
            scheduler: diffusion scheduler. If none provided will use the class attribute scheduler
            save_intermediates: save the intermediate spatial KL maps
            conditioning: Conditioning for network input.
            mode: Conditioning mode for the network.
            original_input_range: the [min,max] intensity range of the input data before any scaling was applied.
            scaled_input_range: the [min,max] intensity range of the input data after scaling.
            verbose: if true, prints the progression bar of the sampling process.
            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                dimension as the input images.
            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
                or 'trilinear;
            seg: if diffusion model is instance of SPADEDiffusionModel, or autoencoder_model
             is instance of SPADEAutoencoderKL, segmentation must be provided.
        r  r  r   Nc                   r  r?   r  r  r   r?   r@   r     r  zCControlNetLatentDiffusionInferer.get_likelihood.<locals>.<listcomp>r   )
r4   r   r  r  r   r  r  r   r  r  rE   r  c                   r  r?   r?   r  r  r?   r@   r     r   )r[   r  r{  r   r  r  r|  rg   r  r   r   rj  r  r  )r>   r4   r  r   r  r  r   r  r  r   r5  r6  r  r  r  r  r  ro   r)  r   r  r@   rj    s6   *
z/ControlNetLatentDiffusionInferer.get_likelihoodr  r  rv  )r4   r5   r  r  r   r   r  r   r   r5   r   r5   r  r5   r   r   r   rZ   r  r   r9   r5   rw  )r  r5   r  r  r   r   r  r   r  r5   r   r  r  r  r  r   r  r   r   rZ   r  rR   r  r   r  r  r  r   r9   r  r  ) r4   r5   r  r  r   r   r  r   r  r5   r   r  r  r  r  r   r   rZ   r5  r  r6  r  r  rR   r  rR   r  rZ   r  r   r9   r  r  r?   r?   r   r@   r    sB    3`r  c                   @  sb   e Zd ZdZd*ddZ		d+d,ddZe 				d-d.d"d#Ze 			$	d/d0d(d)Z	dS )1VQVAETransformerInfererzF
    Class to perform inference with a VQVAE + Transformer model.
    r9   rU   c                 C  r   rq   r   r   r?   r?   r@   rV     r   z VQVAETransformerInferer.__init__NFr4   r5   vqvae_modelr   transformer_modelr   orderingr"   r   r   return_latentrR   7torch.Tensor | tuple[torch.Tensor, torch.Tensor, tuple]c                 C  s$  t   ||}W d   n1 sw   Y  t|jdd }||jd d}|dd| f }| }	t	|dd|j
}|ddddf }| }|jd }
|j}||
k rmtt jd|
d | dd }nd}||dd||| f |d	}|r||	dd||| f |fS |S )
a  
        Implements the forward pass for a supervised training iteration.

        Args:
            inputs: input image to which the latent representation will be extracted.
            vqvae_model: first stage model.
            transformer_model: autoregressive transformer model.
            ordering: ordering of the quantised latent representation.
            return_latent: also return latent sequence and spatial dim of the latent.
            condition: conditioning for network input.
        NrE   r   r   rE   r   constant)rE   )lowhighr  r   r  )rg   ry  index_quantizerp   r   reshapeget_sequence_orderingcloner  padnum_embeddingsrK  max_seq_lenrK   randintitem)r>   r4   r  r  r  r   r  r  latent_spatial_dimtargetseq_lenr  startr  r?   r?   r@   rA     s&   

" z VQVAETransformerInferer.__call__rk  Tr  &tuple[int, int, int] | tuple[int, int]starting_tokensr  temperaturer   top_kr   r  c
              	   C  s^  t |}
|	rtrtt|
}ntt|
}| }|D ]n}|d|jkr)|}n|dd|j df }|||d}|dddddf | }|durjt	
|t||d\}}td |||dddgf k < tj|dd}d|dd|jf< t	j|dd}t	j||fdd}q|ddddf }|dd| f }||jd f| }||S )	a@  
        Sampling function for the VQVAE + Transformer model.

        Args:
            latent_spatial_dim: shape of the sampled image.
            starting_tokens: starting tokens for the sampling. It must be vqvae_model.num_embeddings value.
            vqvae_model: first stage model.
            transformer_model: model to sample from.
            conditioning: Conditioning for network input.
            temperature: temperature for sampling.
            top_k: top k sampling.
            verbose: if true, prints the progression bar of the sampling process.
        rE   Nr  r   Infr   r   )num_samples)rn  prodr   r*   rc   r!  rK  r  r  rg   topkrd   r   r  softmaxr  multinomialrh   get_revert_sequence_orderingr  r   decode_samples)r>   r  r  r  r  r  r  r  r  r  r  r(  Z
latent_seqr   idx_condlogitsvprobsZidx_nextr  r?   r?   r@   rm     s,   
 
zVQVAETransformerInferer.sampler  r  r  rZ   c	                 C  sL  |r|dvrt d| t  ||}	W d   n1 s!w   Y  t|	jdd }
|	|	jd d}	|	dd| f }	t	|
}t
|	dd|j}	|	 }	||	ddd|jf |d	}t
j|dd
}|	ddddf }t|d|ddd|jf dd}|jd |jd k r|rtrtt|j|}ntt|j|}|D ]D}|	dd|d |j |d f }|||d	}|dddddf }t
j|dd
}t|d|dd|f d}tj||fdd
}qt|}|dd| f }||jd f|
 }|r$tj|jdd |d}||ddddf }|S )a  
        Computes the log-likelihoods of the latent representations of the input.

        Args:
            inputs: input images, NxCxHxW[xD]
            vqvae_model: first stage model.
            transformer_model: autoregressive transformer model.
            ordering: ordering of the quantised latent representation.
            condition: conditioning for network input.
            resample_latent_likelihoods: if true, resamples the intermediate likelihood maps to have the same spatial
                dimension as the input images.
            resample_interpolation_mode: if use resample_latent_likelihoods, select interpolation 'nearest', 'bilinear',
                or 'trilinear;
            verbose: if true, prints the progression bar of the sampling process.

        r  r  NrE   r   r   r  r  r  r   r   r  .)r[   rg   ry  r  rp   r   r  r  rn  r  r  r  r  rK  r  r  gatherr   r   r   r*   rc   r!  rh   rW  r  r  r  )r>   r4   r  r  r  r   r  r  r  r  r  r  r  r  r  r(  rj   r  pZprobs_reshapedr  r?   r?   r@   rj  O  sF   

," 
z&VQVAETransformerInferer.get_likelihoodr   )NF)r4   r5   r  r   r  r   r  r"   r   r   r  rR   r9   r  )Nrk  NT)r  r  r  r5   r  r   r  r   r  r"   r  r   r  r   r  r   r  rR   r9   r5   )NFr  F)r4   r5   r  r   r  r   r  r"   r   r   r  rR   r  rZ   r  rR   r9   r5   )
r<   rB   rC   rD   rV   rA   rg   ry  rm   rj  r?   r?   r?   r@   r    s$    
	1?r  )T
__future__r   rn  r   abcr   r   collections.abcr   r   r   r   r	   	functoolsr
   pydocr   typingr   rg   torch.nnr  torch.nn.functional
functionalr  monai.apps.utilsr   
monai.datar   monai.data.meta_tensorr   Zmonai.data.thread_bufferr   Zmonai.inferers.mergerr   r   Zmonai.inferers.splitterr   Zmonai.inferers.utilsr   r   monai.networks.netsr   r   r   r   r   r   r   Zmonai.networks.schedulersr   r   monai.transformsr   r    monai.utilsr!   r"   r#   r$   r%   r&   Zmonai.visualizer'   r(   r)   r*   r   r<   r   __all__r-   r.   r/   r0   r3   r1   r2   r   rz  r  r  Moduler  r?   r?   r?   r@   <module>   s`   $	 &  B #Y/j  L b  [  