'''
Create the model and its optimizer.
Adapted from MedicalZooPytorch: https://github.com/black0017/MedicalZooPytorch
'''
import torch.optim as optim
# from .Vnet import VNet, VNetLight,VNetAttention,Swinunetr
from .dice import DiceLoss, compute_per_channel_dice

from .Basicunetplus import Basicunetplusplus

# Model architecture names exposed by this module.
# NOTE(review): only 'VNET' is listed, but create_model currently builds
# Basicunetplusplus, and nothing in this file reads this list. Confirm
# whether external callers rely on it.
model_list = [
    "VNET",
]


def create_model(args, *, lr=1e-3, weight_decay=1e-3):
    """Build the network and an optimizer over its parameters.

    Args:
        args: Namespace-like object; only ``args.opt`` is read. Supported
            values are ``'sgd'``, ``'adam'`` and ``'rmsprop'``.
        lr: Learning rate given to the optimizer (default ``1e-3``).
        weight_decay: Weight-decay coefficient given to the optimizer
            (default ``1e-3``).

    Returns:
        A ``(model, optimizer)`` tuple.

    Raises:
        ValueError: If ``args.opt`` is not a supported optimizer name. The
            check runs before the model is constructed.
    """
    optimizer_name = args.opt

    # Each factory takes the model's parameter iterable and returns a
    # configured optimizer. Note: 'adam' maps to AdamW (decoupled weight
    # decay), not plain Adam.
    optimizer_factories = {
        'sgd': lambda params: optim.SGD(
            params, lr=lr, momentum=0.5, weight_decay=weight_decay),
        'adam': lambda params: optim.AdamW(
            params, lr=lr, weight_decay=weight_decay),
        'rmsprop': lambda params: optim.RMSprop(
            params, lr=lr, weight_decay=weight_decay),
    }

    # Validate before building the model so a typo fails fast with a clear
    # message instead of an UnboundLocalError on `optimizer` at return time.
    if optimizer_name not in optimizer_factories:
        raise ValueError(
            f"Unknown optimizer {optimizer_name!r}; expected one of "
            f"{sorted(optimizer_factories)}"
        )

    # Single input channel, two output classes, dropout 0.2.
    model = Basicunetplusplus(
        in_channels=1,
        classes=2,
        dropout=0.2,
    )

    optimizer = optimizer_factories[optimizer_name](model.parameters())
    return model, optimizer
