o
    /iU                     @  sH  d dl mZ d dlZd dlZd dlZd dlmZ d dlmZm	Z	m
Z
 d dlZd dlZd dlmZ d dlmZ d dlmZ d dlmZ d dlmZ d d	lmZmZ d d
lmZmZmZ erkd dlm Z! dZ"d dl#Z#dZ$ned\Z!Z"ed\Z#Z$dgZ%G dd dZ&G dd de&Z'G dd de&Z(dddZ)dddZ*G dd dZ+dS )    )annotationsN)partial)TYPE_CHECKINGAnyCallable)	Optimizer)DEFAULT_PROTOCOL)
DataLoader)	eval_mode)ExponentialLRLinearLR)StateCachercopy_to_deviceoptional_importTzmatplotlib.pyplottqdmLearningRateFinderc                   @  s:   e Zd Zddd	Zed
d Zdd Zdd Zdd ZdS )DataLoaderIterdata_loaderr	   image_extractorr   label_extractorreturnNonec                 C  s>   t |tstdt| d|| _t|| _|| _|| _d S )NzLoader has unsupported type: z1. Expected type was `torch.utils.data.DataLoader`)	
isinstancer	   
ValueErrortyper   iter	_iteratorr   r   selfr   r   r    r   \/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/optimizers/lr_finder.py__init__/   s   


zDataLoaderIter.__init__c                 C  s   | j jS N)r   datasetr   r   r   r    r#   9   s   zDataLoaderIter.datasetc                 C  s   |  |}| |}||fS r"   )r   r   )r   
batch_dataimageslabelsr   r   r    inputs_labels_from_batch=   s   

z'DataLoaderIter.inputs_labels_from_batchc                 C  s   | S r"   r   r$   r   r   r    __iter__B   s   zDataLoaderIter.__iter__c                 C  s   t | j}| |S r"   )nextr   r(   )r   batchr   r   r    __next__E   s   

zDataLoaderIter.__next__Nr   r	   r   r   r   r   r   r   )	__name__
__module____qualname__r!   propertyr#   r(   r)   r,   r   r   r   r    r   -   s    


r   c                      s*   e Zd Z	dd fddZdd Z  ZS )TrainDataLoaderIterTr   r	   r   r   r   
auto_resetboolr   r   c                   s   t  ||| || _d S r"   )superr!   r3   )r   r   r   r   r3   	__class__r   r    r!   L   s   
zTrainDataLoaderIter.__init__c                 C  sh   zt | j}| |\}}W ||fS  ty3   | js t| j| _t | j}| |\}}Y ||fS w r"   )r*   r   r(   StopIterationr3   r   r   )r   r+   inputsr'   r   r   r    r,   R   s   

zTrainDataLoaderIter.__next__T)
r   r	   r   r   r   r   r3   r4   r   r   )r.   r/   r0   r!   r,   __classcell__r   r   r6   r    r2   J   s    r2   c                      s6   e Zd ZdZd fd	d
Zdd Z fddZ  ZS )ValDataLoaderItera  This iterator will reset itself **only** when it is acquired by
    the syntax of normal `iterator`. That is, this iterator just works
    like a `torch.data.DataLoader`. If you want to restart it, you
    should use it like:

        ```
        loader_iter = ValDataLoaderIter(data_loader)
        for batch in loader_iter:
            ...

        # `loader_iter` should run out of values now, you can restart it by:
        # 1. the way we use a `torch.data.DataLoader`
        for batch in loader_iter:        # __iter__ is called implicitly
            ...

        # 2. passing it into `iter()` manually
        loader_iter = iter(loader_iter)  # __iter__ is called by `iter()`
        ```
    r   r	   r   r   r   r   r   c                   s&   t  ||| t| j| _d| _d S Nr   )r5   r!   lenr   	run_limitrun_counterr   r6   r   r    r!   u   s   
zValDataLoaderIter.__init__c                 C  s"   | j | jkrt| j| _d| _ | S r=   )r@   r?   r   r   r   r$   r   r   r    r)   z   s   zValDataLoaderIter.__iter__c                   s   |  j d7  _ t  S )N   )r@   r5   r,   r$   r6   r   r    r,      s   
zValDataLoaderIter.__next__r-   )r.   r/   r0   __doc__r!   r)   r,   r;   r   r   r6   r    r<   `   s
    r<   xr   r   torch.Tensorc                 C  "   t | tr| d }|S | d }|S )z3Default callable for getting image from batch data.imager   r   dictrC   outr   r   r    default_image_extractor      rK   c                 C  rE   )z3Default callable for getting label from batch data.labelrA   rG   rI   r   r   r    default_label_extractor   rL   rN   c                   @  s   e Zd ZdZddddeedfdUddZdVddZdee	ddddd d!d"ddfdWd6d7Z
dXd:d;Zd<d= Z	dYdZd@dAZdYd[dDdEZd\d]dJdKZd\d^dMdNZ	F	F			d_d`dSdTZdS )ar   a  Learning rate range test.

    The learning rate range test increases the learning rate in a pre-training run
    between two boundaries in a linear or exponential manner. It provides valuable
    information on how well the network can be trained over a range of learning rates
    and what is the optimal learning rate.

    Example (fastai approach):
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(data_loader, end_lr=100, num_iter=100)
    >>> lr_finder.get_steepest_gradient()
    >>> lr_finder.plot() # to inspect the loss-learning rate graph

    Example (Leslie Smith's approach):
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(train_loader, val_loader=val_loader, end_lr=1, num_iter=100, step_mode="linear")

    Gradient accumulation is supported; example:
    >>> train_data = ...    # prepared dataset
    >>> desired_bs, real_bs = 32, 4         # batch size
    >>> accumulation_steps = desired_bs // real_bs     # required steps for accumulation
    >>> data_loader = torch.utils.data.DataLoader(train_data, batch_size=real_bs, shuffle=True)
    >>> acc_lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> acc_lr_finder.range_test(data_loader, end_lr=10, num_iter=100, accumulation_steps=accumulation_steps)

    By default, image will be extracted from data loader with x["image"] and x[0], depending on whether
    batch data is a dictionary or not (and similar behaviour for extracting the label). If your data loader
    returns something other than this, pass a callable function to extract it, e.g.:
    >>> image_extractor = lambda x: x["input"]
    >>> label_extractor = lambda x: x[100]
    >>> lr_finder = LearningRateFinder(net, optimizer, criterion)
    >>> lr_finder.range_test(train_loader, val_loader, image_extractor, label_extractor)

    References:
    Modified from: https://github.com/davidtvs/pytorch-lr-finder.
    Cyclical Learning Rates for Training Neural Networks: https://arxiv.org/abs/1506.01186
    NTFmodel	nn.Module	optimizerr   	criteriontorch.nn.Moduledevicestr | torch.device | Nonememory_cacher4   	cache_dir
str | Noneamppickle_moduletypes.ModuleTypepickle_protocolintverboser   r   c                 C  s   || _ |   || _|| _g g d| _|| _|| _|| _|
| _t	| j
 j| _t||||	d| _| jd| j  | jd| j   |rL|| _dS | j| _dS )a8  Constructor.

        Args:
            model: wrapped model.
            optimizer: wrapped optimizer.
            criterion: wrapped loss function.
            device: device on which to test. run a string ("cpu" or "cuda") with an
                optional ordinal for the device type (e.g. "cuda:X", where is the ordinal).
                Alternatively, can be an object representing the device on which the
                computation will take place. Default: None, uses the same device as `model`.
            memory_cache: if this flag is set to True, `state_dict` of
                model and optimizer will be cached in memory. Otherwise, they will be saved
                to files under the `cache_dir`.
            cache_dir: path for storing temporary files. If no path is
                specified, system-wide temporary directory is used. Notice that this
                parameter will be ignored if `memory_cache` is True.
            amp: use Automatic Mixed Precision
            pickle_module: module used for pickling metadata and objects, default to `pickle`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            pickle_protocol: can be specified to override the default protocol, default to `2`.
                this arg is used by `torch.save`, for more details, please check:
                https://pytorch.org/docs/stable/generated/torch.save.html#torch.save.
            verbose: verbose output
        Returns:
            None
        lrloss)	in_memoryrW   rZ   r\   rO   rQ   N)rQ   _check_for_schedulerrO   rR   historyrV   rW   rY   r^   r*   
parametersrT   model_devicer   state_cacherstore
state_dict)r   rO   rQ   rR   rT   rV   rW   rY   rZ   r\   r^   r   r   r    r!      s    )zLearningRateFinder.__init__c                 C  s:   | j | jd | j| jd | j | j dS )z9Restores the model and optimizer to their initial states.rO   rQ   N)rO   load_state_dictrg   retrieverQ   torf   r$   r   r   r    reset   s   zLearningRateFinder.resetg      $@d   expg?   rA   train_loaderr	   
val_loaderDataLoader | Noner   r   r   start_lrfloat | Noneend_lrfloatnum_iter	step_modestrsmooth_f
diverge_thaccumulation_stepsnon_blocking_transferr3   c                 C  s  g g d| _ td }| j| j |   |r| | |dkr%td| dkr3t	| j
||}n| dkrAt| j
||}ntd| |	dk sP|	dkrTtd	t|||}|rbt|||}| jrstrsttjd
d}tjj}nt}t}||D ]k}| jrtstd|d  d|  | j|||d}|r| j||d}| j d | d  |  |dkr|}n|	dkr|	| d|	 | j d d   }||k r|}| j d | ||
| kr| jr|d  nq{|r| jrtd |   dS dS )aj  Performs the learning rate range test.

        Args:
            train_loader: training set data loader.
            val_loader: validation data loader (if desired).
            image_extractor: callable function to get the image from a batch of data.
                Default: `x["image"] if isinstance(x, dict) else x[0]`.
            label_extractor: callable function to get the label from a batch of data.
                Default: `x["label"] if isinstance(x, dict) else x[1]`.
            start_lr : the starting learning rate for the range test.
                The default is the optimizer's learning rate.
            end_lr: the maximum learning rate to test. The test may stop earlier than
                this if the result starts diverging.
            num_iter: the max number of iterations for test.
            step_mode: schedule for increasing learning rate: (`linear` or `exp`).
            smooth_f: the loss smoothing factor within the `[0, 1[` interval. Disabled
                if set to `0`, otherwise loss is smoothed using exponential smoothing.
            diverge_th: test is stopped when loss surpasses threshold:
                `diverge_th * best_loss`.
            accumulation_steps: steps for gradient accumulation. If set to `1`,
                gradients are not accumulated.
            non_blocking_transfer: when `True`, moves data to device asynchronously if
                possible, e.g., moving CPU Tensors with pinned memory to CUDA devices.
            auto_reset: if `True`, returns model and optimizer to original states at end
                of test.
        Returns:
            None
        r_   infrA   z `num_iter` must be larger than 1ro   linearz#expected one of (exp, linear), got r   z$smooth_f is outside the range [0, 1[zComputing optimal learning rate)descz+Computing optimal learning rate, iteration /)r~   r`   ra   z%Stopping early, the loss has divergedzResetting model and optimizerN)rd   rw   rO   rl   rT   rc   _set_learning_rater   lowerr   rQ   r   r2   r<   r^   has_tqdmr   r   trangewriterangeprint_train_batch	_validateappendget_lrsteprm   )r   rq   rr   r   r   rt   rv   rx   ry   r{   r|   r}   r~   r3   	best_lossZlr_schedule
train_iterval_iterr   Ztprint	iterationra   r   r   r    
range_test   s`   .




zLearningRateFinder.range_testnew_lrsfloat | listc                 C  s\   t |ts|gt| jj }t|t| jjkrtdt| jj|D ]\}}||d< q#dS )z#Set learning rate(s) for optimizer.zYLength of `new_lrs` is not equal to the number of parameter groups in the given optimizerr`   N)r   listr>   rQ   param_groupsr   zip)r   r   param_groupnew_lrr   r   r    r   y  s   

z%LearningRateFinder._set_learning_ratec                 C  s"   | j jD ]
}d|v rtdqdS )z/Check optimizer doesn't already have scheduler.
initial_lrz0Optimizer already has a scheduler attached to itN)rQ   r   RuntimeError)r   r   r   r   r    rc     s
   z'LearningRateFinder._check_for_schedulerr   r2   c              	   C  s   | j   d}| j  t|D ]^}t|\}}t||g| j|d\}}|  |}| ||}	|	| }	| j	rdt
| jdrd|d | dk}
tjj	j|	| j|
d}|  W d    n1 s^w   Y  n|	  ||	 7 }q| j  |S )Nr   rT   non_blockingZ
_amp_stashrA   )delay_unscale)rO   trainrQ   	zero_gradr   r*   r   rT   rR   rY   hasattrtorchcudaZ
scale_lossbackwarditemr   )r   r   r}   r~   
total_lossir9   r'   outputsra   r   Zscaled_lossr   r   r    r     s&   




zLearningRateFinder._train_batchr   r<   c                 C  s   d}t | j0 |D ]%\}}t||g| j|d\}}| |}| ||}|| t| 7 }q
W d    n1 s:w   Y  |t|j S )Nr   r   )r
   rO   r   rT   rR   r   r>   r#   )r   r   r~   Zrunning_lossr9   r'   r   ra   r   r   r    r     s   

zLearningRateFinder._validater   
skip_startskip_endtuple[list, list]c                 C  sd   |dk rt d|dk rt d| jd }| jd }t|| d }||| }||| }||fS )zGet learning rates and their corresponding losses

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the end.
        r   zskip_start cannot be negativezskip_end cannot be negativer`   ra   rA   )r   rd   r>   )r   r   r   lrslossesend_idxr   r   r    get_lrs_and_losses  s   

z%LearningRateFinder.get_lrs_and_losses'tuple[float, float] | tuple[None, None]c                 C  sT   |  ||\}}ztt| }|| || fW S  ty)   td Y dS w )aC  Get learning rate which has steepest gradient and its corresponding loss

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the end.

        Returns:
            Learning rate which has steepest gradient and its corresponding loss
        zBFailed to compute the gradients, there might not be enough points.)NN)r   npgradientarrayargminr   r   )r   r   r   r   r   Zmin_grad_idxr   r   r    get_steepest_gradient  s   
z(LearningRateFinder.get_steepest_gradientlog_lrax
Any | Nonesteepest_lrc              	   C  s   t s	td dS | ||\}}d}|du rt \}}||| |rE| ||\}	}
|	durE|
durE|j|	|
dddddd |	  |rL|
d	 |d
 |d |dur^t  |S )a  Plots the learning rate range test.

        Args:
            skip_start: number of batches to trim from the start.
            skip_end: number of batches to trim from the start.
            log_lr: True to plot the learning rate in a logarithmic
                scale; otherwise, plotted in a linear scale.
            ax: the plot is created in the specified matplotlib axes object and the
                figure is not be shown. If `None`, then the figure and axes object are
                created in this method and the figure is shown.
            steepest_lr: plot the learning rate which had the steepest gradient.

        Returns:
            The `matplotlib.axes.Axes` object that contains the plot. Returns `None` if
            `matplotlib` is not installed.
        z(Matplotlib is missing, can't plot resultNK   ored   zsteepest gradient)smarkercolorzorderrM   logzLearning rateLoss)has_matplotlibwarningswarnr   pltsubplotsplotr   scatterlegend
set_xscale
set_xlabel
set_ylabelshow)r   r   r   r   r   r   r   r   figZlr_at_steepest_gradZloss_at_steepest_gradr   r   r    r     s8   
	


zLearningRateFinder.plot)rO   rP   rQ   r   rR   rS   rT   rU   rV   r4   rW   rX   rY   r4   rZ   r[   r\   r]   r^   r4   r   r   )r   r   )rq   r	   rr   rs   r   r   r   r   rt   ru   rv   rw   rx   r]   ry   rz   r{   rw   r|   r]   r}   r]   r~   r4   r3   r4   r   r   )r   r   r   r   r:   )r   r2   r}   r]   r~   r4   r   rw   )r   r<   r~   r4   r   rw   )r   r   )r   r]   r   r]   r   r   )r   r]   r   r]   r   r   )r   r   TNT)r   r]   r   r]   r   r4   r   r   r   r4   r   r   )r.   r/   r0   rB   pickler   r!   rm   rK   rN   r   r   rc   r   r   r   r   r   r   r   r   r    r      sJ    +
@

z#)rC   r   r   rD   ),
__future__r   r   typesr   	functoolsr   typingr   r   r   numpyr   r   torch.nnnntorch.optimr   torch.serializationr   torch.utils.datar	   monai.networks.utilsr
   Zmonai.optimizers.lr_schedulerr   r   monai.utilsr   r   r   matplotlib.pyplotpyplotr   r   r   r   __all__r   r2   r<   rK   rN   r   r   r   r   r    <module>   s:   
%
