o
    ia                     @  s   d dl mZ d dlZd dlZd dlmZ d dlmZmZm	Z	 d dl
Zd dlZd dlmZ d dlmZ d dlmZ edd	d
\ZZeddd
\ZZg dZ								d>d?d#d$ZG d%d& d&ejjZd@dAd+d,ZdBdCd2d3Z		4dDdEd:d;ZdBdCd<d=ZdS )F    )annotationsN)Path)AnyOptionalUnion)cudnn)
MetaTensor)optional_importz4batchgenerators.utilities.file_and_folder_operationsjoin)name	load_json)get_nnunet_trainerget_nnunet_monai_predictorget_network_from_nnunet_plansconvert_nnunet_to_monai_bundleconvert_monai_bundle_to_nnunetModelnnUNetWrappernnUNetTrainernnUNetPlansFcudadataset_name_or_idUnion[str, int]configurationstrfoldUnion[int, str]trainer_class_nameplans_identifieruse_compressed_databoolcontinue_trainingonly_run_validationdisable_checkpointingdevicepretrained_modelOptional[str]returnr   c              
   C  s   t |tr&|dkr&zt|}W n ty% } z
td| d |d}~ww ddlm}m} |t| ||||t	|	d}|rB||_
|rJ|rJJ d|||| |  tj r_d	t_d
t_|
durwtj|
d
d}d|v rw|jj|d  |S )a	  
    Get the nnUNet trainer instance based on the provided configuration.
    The returned nnUNet trainer can be used to initialize the SupervisedTrainer for training, including the network,
    optimizer, loss function, DataLoader, etc.

    Example::

        from monai.apps import SupervisedTrainer
        from monai.bundle.nnunet import get_nnunet_trainer

        dataset_name_or_id = 'Task009_Spleen'
        fold = 0
        configuration = '3d_fullres'
        nnunet_trainer = get_nnunet_trainer(dataset_name_or_id, configuration, fold)

        trainer = SupervisedTrainer(
            device=nnunet_trainer.device,
            max_epochs=nnunet_trainer.num_epochs,
            train_data_loader=nnunet_trainer.dataloader_train,
            network=nnunet_trainer.network,
            optimizer=nnunet_trainer.optimizer,
            loss_function=nnunet_trainer.loss_function,
            epoch_length=nnunet_trainer.num_iterations_per_epoch,
        )

    Parameters
    ----------
    dataset_name_or_id : Union[str, int]
        The name or ID of the dataset to be used.
    configuration : str
        The configuration name for the training.
    fold : Union[int, str]
        The fold number or 'all' for cross-validation.
    trainer_class_name : str, optional
        The class name of the trainer to be used. Default is 'nnUNetTrainer'.
        For a complete list of supported trainers, check:
        https://github.com/MIC-DKFZ/nnUNet/tree/master/nnunetv2/training/nnUNetTrainer/variants
    plans_identifier : str, optional
        Identifier for the plans to be used. Default is 'nnUNetPlans'.
    use_compressed_data : bool, optional
        Whether to use compressed data. Default is False.
    continue_training : bool, optional
        Whether to continue training from a checkpoint. Default is False.
    only_run_validation : bool, optional
        Whether to only run validation. Default is False.
    disable_checkpointing : bool, optional
        Whether to disable checkpointing. Default is False.
    device : str, optional
        The device to be used for training. Default is 'cuda'.
    pretrained_model : Optional[str], optional
        Path to the pretrained model file.

    Returns
    -------
    nnunet_trainer : object
        The nnUNet trainer instance.
    allz/Unable to convert given value for fold to int: z+. fold must bei either "all" or an integer!Nr   )get_trainer_from_argsmaybe_load_checkpoint)r#   z6Cannot set --c and --val flag at the same time. Dummy.FTweights_onlynetwork_weights)
isinstancer   int
ValueErrorprintZnnunetv2.run.run_trainingr(   r)   torchr#   r"   Zon_train_startr   is_availabler   deterministic	benchmarkloadnetwork	_orig_modload_state_dict)r   r   r   r   r   r   r    r!   r"   r#   r$   er(   r)   nnunet_trainer
state_dict r<   a/home/dell461/cl/sdc2/last_ska_mid/HISourceFinder-master-l/src/monai/apps/nnunet/nnunet_bundle.pyr   &   s8   
G

r   c                      s.   e Zd ZdZdd fd	d
ZdddZ  ZS )r   at  
    A wrapper class for nnUNet model integration with MONAI framework.
    The wrapper can be use to integrate the nnUNet Bundle within MONAI framework for inference.

    Parameters
    ----------
    predictor : nnUNetPredictor
        The nnUNet predictor object used for inference.
    model_folder : Union[str, Path]
        The folder path where the model and related files are stored.
    model_name : str, optional
        The name of the model file, by default "model.pt".

    Attributes
    ----------
    predictor : nnUNetPredictor
        The nnUNet predictor object used for inference.
    network_weights : torch.nn.Module
        The network weights of the model.

    Notes
    -----
    This class integrates nnUNet model with MONAI framework by loading necessary configurations,
    restoring network architecture, and setting up the predictor for inference.
    model.pt	predictorobjectmodel_folderUnion[str, Path]
model_namer   c                   s  t    || _|}ddlm} ttt|jd}ttt|jd}||}g }	t	j
tt|jdt	ddd}
|
d	 }|
d
 d }d|
 v rO|
d nd }t|| r{t	j
t||t	ddd}d| v rv|	|d  n|	| ||}dd l}ddlm} ddlm} ||||}|t|jd dd|d}|d u rtd| d|j|j|j|j|||jdd}||_||_|	|_ ||_!||_"||_#||_$|||_%| jj!| _&d S )Nr   PlansManagerdataset.json
plans.jsonnnunet_checkpoint.pthcpuT)map_locationr+   trainer_name	init_argsr    inference_allowed_mirroring_axesr,   )recursive_find_python_classdetermine_num_input_channelstrainingr   znnunetv2.training.nnUNetTrainerzUnable to locate trainer class zM in nnunetv2.training.nnUNetTrainer. Please place it there (in any .py file)!F)enable_deep_supervision)'super__init__r?   /nnunetv2.utilities.plans_handling.plans_handlerrE   r   r
   r   parentr1   r5   r#   keysjoinpathis_fileappendget_configurationnnunetv2Z%nnunetv2.utilities.find_class_by_namerN   0nnunetv2.utilities.label_handling.label_handlingrP   __path__RuntimeErrorZbuild_network_architecturenetwork_arch_class_namenetwork_arch_init_kwargs#network_arch_init_kwargs_req_importget_label_managernum_segmentation_headsplans_managerconfiguration_managerZlist_of_parametersr6   dataset_jsonrK   Zallowed_mirroring_axeslabel_managerr,   )selfr?   rA   rC   Zmodel_training_output_dirrE   rg   plansre   
parameters
checkpointrK   Zconfiguration_namerM   Zmonai_checkpointrf   r\   rN   rP   num_input_channelsZtrainer_classr6   	__class__r<   r=   rT      sn   




	zModelnnUNetWrapper.__init__xr   r&   c           	      C  s(  t |trSd|jv rd|jd d dd   i}n;d|jv rLt|jd d d  t|jd d d  t|jd d d  g}d|i}ndg di}ntd	|  dd
d
f }| j	j
|d
|d
dddd}g }|D ]}|ttt|dd qtt|d}t||jdS )a  
        Forward pass for the nnUNet model.

        Args:
            x (MetaTensor): Input tensor. If the input is a tuple,
                it is assumed to be a decollated batch (list of tensors). Otherwise, it is assumed to be a collated batch.

        Returns:
            MetaTensor: The output tensor with the same metadata as the input.

        Raises:
            TypeError: If the input is not a torch.Tensor or a tuple of MetaTensors.

        Notes:
            - If the input is a tuple, the filenames are extracted from the metadata of each tensor in the tuple.
            - If the input is a collated batch, the filenames are extracted from the metadata of the input tensor.
            - The filenames are used to generate predictions using the nnUNet predictor.
            - The predictions are converted to torch tensors, with added batch and channel dimensions.
            - The output tensor is concatenated along the batch dimension and returned as a MetaTensor with the same metadata.
        pixdimspacingr         affine   )      ?rw   rw   z5Input must be a MetaTensor or a tuple of MetaTensors.NF)Ztruncated_ofnameZsave_probabilitiesZnum_processesZ!num_processes_segmentation_export)meta)r-   r   rx   numpytolistabsitem	TypeErrorrI   r?   Zpredict_from_list_of_npy_arraysrZ   r1   
from_numpynpexpand_dimscat)	ri   rp   Z properties_or_list_of_propertiesrr   Zimage_or_list_of_imagesZprediction_outputZout_tensorsout
out_tensorr<   r<   r=   forward   s4   

$

"zModelnnUNetWrapper.forwardr>   )r?   r@   rA   rB   rC   r   )rp   r   r&   r   )__name__
__module____qualname____doc__rT   r   __classcell__r<   r<   rn   r=   r      s    Fr   r>   rA   rB   rC   c              	   C  s:   ddl m} |dddtdddddd}t|| |}|S )a  
    Initializes and returns a `nnUNetMONAIModelWrapper` containing the corresponding `nnUNetPredictor`.
    The model folder should contain the following files, created during training:

        - dataset.json: from the nnUNet results folder
        - plans.json: from the nnUNet results folder
        - nnunet_checkpoint.pth: The nnUNet checkpoint file, containing the nnUNet training configuration
        - model.pt: The checkpoint file containing the model weights.

    The returned wrapper object can be used for inference with MONAI framework:

    Example::

        from monai.bundle.nnunet import get_nnunet_monai_predictor

        model_folder = 'path/to/monai_bundle/model'
        model_name = 'model.pt'
        wrapper = get_nnunet_monai_predictor(model_folder, model_name)

        # Perform inference
        input_data = ...
        output = wrapper(input_data)


    Parameters
    ----------
    model_folder : Union[str, Path]
        The folder where the model is stored.
    model_name : str, optional
        The name of the model file, by default "model.pt".

    Returns
    -------
    ModelnnUNetWrapper
        A wrapper object that contains the nnUNetPredictor and the loaded model.
    r   )nnUNetPredictorg      ?TFr   )Ztile_step_sizeZuse_gaussianZuse_mirroringr#   verboseZverbose_preprocessingZ
allow_tqdm)Z(nnunetv2.inference.predict_from_raw_datar   r1   r#   r   )rA   rC   r   r?   wrapperr<   r<   r=   r   )  s   &

r   nnunet_configdictbundle_root_folderr.   Nonec                 C  s  d}d}d}d| v r| d }d| v r| d }d| v r| d }ddl m} || d	 }ttjd
 || d| d| }tjt|d| ddd}	tjt|d| ddd}
i }|	d |d< |	d |d< |	d |d< t|t|dd t|dd| j	ddd i }|	d |d< t|t|dd| d i }|
d |d< t|t|dd| d tj
tj
|ddstt|dt|dd tj
tj
|ddstt|dt|dd dS dS )a  
    Convert nnUNet model checkpoints and configuration to MONAI bundle format.

    Parameters
    ----------
    nnunet_config : dict
        Configuration dictionary for nnUNet, containing keys such as 'dataset_name_or_id', 'nnunet_configuration',
        'nnunet_trainer', and 'nnunet_plans'.
    bundle_root_folder : str
        Root folder where the MONAI bundle will be saved.
    fold : int, optional
        Fold number of the nnUNet model to be converted, by default 0.

    Returns
    -------
    None
    r   r   Z
3d_fullresr:   nnunet_plansnnunet_configurationr   maybe_convert_to_dataset_namer   nnUNet_results__fold_checkpoint_final.pthTr*   checkpoint_best.pthrM   rL   rK   modelsrH   parentsexist_okr,   r>   zbest_model.ptrG   rF   N)-nnunetv2.utilities.dataset_name_id_conversionr   r   osenvironrX   r1   r5   savemkdirpathexistsr
   shutilcopy)r   r   r   r:   r   r   r   dataset_namennunet_model_folderZnnunet_checkpoint_finalZnnunet_checkpoint_bestnnunet_checkpointmonai_last_checkpointmonai_best_checkpointr<   r<   r=   r   _  sT      r   model
plans_filedataset_file
model_ckptmodel_key_in_ckptUnion[torch.nn.Module, Any]c              	   C  s   ddl m} ddlm} ddlm} ddlm} || }	||}
||	}||}||||
}|	|
}d}||j
|j|j||jd|d}|du rK|S tj|dd	}|||  |S )
a-  
    Load and initialize a nnUNet network based on nnUNet plans and configuration.

    Parameters
    ----------
    plans_file : str
        Path to the JSON file containing the nnUNet plans.
    dataset_file : str
        Path to the JSON file containing the dataset information.
    configuration : str
        The configuration name to be used from the plans.
    model_ckpt : Optional[str], optional
        Path to the model checkpoint file. If None, the network is returned without loading weights (default is None).
    model_key_in_ckpt : str, optional
        The key in the checkpoint file that contains the model state dictionary (default is "model").

    Returns
    -------
    network : torch.nn.Module
        The initialized neural network, with weights loaded if `model_ckpt` is provided.
    r   )r   )get_network_from_plansrO   rD   T)Z
allow_initdeep_supervisionNr*   )Z4batchgenerators.utilities.file_and_folder_operationsr   Z)nnunetv2.utilities.get_network_from_plansr   r]   rP   rU   rE   r[   rc   r`   ra   rb   rd   r1   r5   r8   )r   r   r   r   r   r   r   rP   rE   rj   rg   re   rf   rm   rh   rR   r6   r;   r<   r<   r=   r     s2   


r   c                 C  sR  ddl m } d}d}d| v r| d }d| v r| d }ddlm} ddlm} 	
d9d:dd}ttjd || d | d| d}	ttjd || d }
t|	d| j	d
d
d t
j| dd
d}|t|dd| d d
d!}g }|D ]}|t|td"td#   q}|  |d$ }t
j| d%| d&| d#d
d}|t|dd| d'd
d!}g }|D ]}|t|td(td#   q|  |d$ }t
j| d%| d)| d#d
d}|d* |d*< | |d+< |d+ D ]}|d+ | |d+ |< q||d,< |  |d-< d|d.< d	|d/< t
|t|	d| d0 | |d+< |d* |d*< |d+ D ]}|d+ | |d+ |< q4t
|t|	d| d1 tjtj|	d2sft| d3|	 tjtj|	d4s{t| d5|	 tjtj|	d6st|
 d7|	 tjtj|	d8st| d|	 d	S d	S );av  
    Convert a MONAI bundle to nnU-Net format.

    Parameters
    ----------
    nnunet_config : dict
        Configuration dictionary for nnU-Net. Expected keys are:
        - "dataset_name_or_id": str, name or ID of the dataset.
        - "nnunet_trainer": str, optional, name of the nnU-Net trainer (default is "nnUNetTrainer").
        - "nnunet_plans": str, optional, name of the nnU-Net plans (default is "nnUNetPlans").
    bundle_root_folder : str
        Path to the root folder of the MONAI bundle.
    fold : int, optional
        Fold number for cross-validation (default is 0).

    Returns
    -------
    None
    r   )odictr   r   r:   r   )nnUNetLoggerr   NTfolderrB   prefixr%   suffixsortr   r&   	list[str]c                   s,    fddt |  D }|r|  |S )Nc                   sB   g | ]}|  r d u s|j rd u s|jr|jqS )N)rY   r   
startswithendswith).0ir   r   r<   r=   
<listcomp>  s    zDconvert_monai_bundle_to_nnunet.<locals>.subfiles.<locals>.<listcomp>)r   iterdirr   )r   r   r   r   resr<   r   r=   subfiles  s   
z0convert_monai_bundle_to_nnunet.<locals>.subfilesr   r   r   Z__3d_fullresZnnUNet_preprocessedr   r   z/models/nnunet_checkpoint.pthr*   r   Zcheckpoint_epoch)r   r   zcheckpoint_epoch=z.ptz/models/fold_z/checkpoint_epoch=Zcheckpoint_key_metriczcheckpoint_key_metric=z/checkpoint_key_metric=optimizer_stater,   current_epochloggingZ	_best_emaZgrad_scaler_stater   r   rF   z/models/dataset.jsonrG   z/models/plans.jsonzdataset_fingerprint.jsonz/dataset_fingerprint.jsonrH   )NNT)
r   rB   r   r%   r   r%   r   r   r&   r   )r   Z'nnunetv2.training.logging.nnunet_loggerr   r   r   r   r   r   rX   r   r1   r5   rZ   r.   lenr   r   Zget_checkpointr   r   r   r
   r   r   )r   r   r   r   r:   r   r   r   r   r   Znnunet_preprocess_model_folderr   Zlatest_checkpointsepochsZlatest_checkpointZfinal_epochr   Zbest_checkpointsZkey_metricsZbest_checkpointZbest_key_metricr   keyr<   r<   r=   r     s   

""

r   )r   r   FFFFr   N)r   r   r   r   r   r   r   r   r   r   r   r   r    r   r!   r   r"   r   r#   r   r$   r%   r&   r   r   )rA   rB   rC   r   r&   r   )r   )r   r   r   r   r   r.   r&   r   )Nr   )r   r   r   r   r   r   r   r%   r   r   r&   r   )
__future__r   r   r   pathlibr   typingr   r   r   ry   r   r1   torch.backendsr   monai.data.meta_tensorr   monai.utilsr	   r
   _r   __all__r   nnModuler   r   r   r   r   r<   r<   r<   r=   <module>   s<   
h 6N=