# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import json
import os
import warnings
from collections.abc import Sequence
from pathlib import Path
from typing import overload

from monai.config import KeysCollection, PathLike
from monai.data.utils import partition_dataset, select_cross_validation_folds
from monai.utils import ensure_tuple


@overload
def _compute_path(base_dir: PathLike, element: PathLike, check_path: bool = False) -> str: ...


@overload
def _compute_path(base_dir: PathLike, element: list[PathLike], check_path: bool = False) -> list[str]: ...


def _compute_path(base_dir, element, check_path=False):
    """
    Join ``element`` (a single path or a list of paths) onto ``base_dir``.

    Args:
        base_dir: the base directory of the dataset.
        element: file path(s) to append to directory. A value that is neither a path
            (``str`` / ``os.PathLike``) nor a list made up only of paths is returned unchanged.
        check_path: if `True`, only keep the joined result when it is an existing path;
            otherwise the original value is kept (converted to ``str``). In this mode an
            empty string is never treated as a path.

    Returns:
        The normalized joined path as ``str``, a list of such strings, or ``element``
        itself when it is not path-like.

    """

    def _join_path(base_dir: PathLike, item: PathLike) -> str:
        result = os.path.normpath(os.path.join(base_dir, item))
        # In auto-detect mode keep the joined path only when it exists. An empty string
        # would otherwise normalize to ``base_dir`` itself (which exists) and turn an
        # empty, non-path value into the dataset directory.
        if check_path and (item == "" or not os.path.exists(result)):
            # if not an existing path, don't join with base dir
            return f"{item}"
        return f"{result}"

    if isinstance(element, (str, os.PathLike)):
        return _join_path(base_dir, element)
    if isinstance(element, list):
        # only join when every entry is path-like; mixed lists are left untouched
        if all(isinstance(e, (str, os.PathLike)) for e in element):
            return [_join_path(base_dir, e) for e in element]
        return element
    return element


def _append_paths(base_dir: PathLike, is_segmentation: bool, items: list[dict]) -> list[dict]:
    """
    Resolve the values of every datalist item against ``base_dir``, updating the dicts in place.

    ``image`` values, and also ``label`` values when ``is_segmentation`` is True, are always
    joined with ``base_dir``. Every other value is joined only when the joined result exists
    on disk; otherwise the original value is kept.

    Args:
        base_dir: the base directory of the dataset.
        is_segmentation: whether the datalist is for segmentation task.
        items: list of data items, each of which is a dict keyed by element names.

    Raises:
        TypeError: When ``items`` contains a non ``dict``.

    Returns:
        The same ``items`` list, whose dicts have been modified in place.

    """
    always_joined = {"image", "label"} if is_segmentation else {"image"}
    for entry in items:
        if not isinstance(entry, dict):
            raise TypeError(f"Every item in items must be a dict but got {type(entry).__name__}.")
        for key, value in entry.items():
            # keys outside ``always_joined`` are auto-detected: joined only if the result exists
            entry[key] = _compute_path(base_dir, value, check_path=key not in always_joined)
    return items


def load_decathlon_datalist(
    data_list_file_path: PathLike,
    is_segmentation: bool = True,
    data_list_key: str = "training",
    base_dir: PathLike | None = None,
) -> list[dict]:
    """Load image/label paths of decathlon challenge from JSON file

    JSON file should follow the format of the Medical Segmentation Decathlon
    datalist.json files, see http://medicaldecathlon.com.
    The files are structured as follows:

    .. code-block:: python

        {
            "metadata_key_0": "metadata_value_0",
            "metadata_key_1": "metadata_value_1",
            ...,
            "training": [
                {"image": "path/to/image_1.nii.gz", "label": "path/to/label_1.nii.gz"},
                {"image": "path/to/image_2.nii.gz", "label": "path/to/label_2.nii.gz"},
                ...
            ],
            "test": [
                "path/to/image_3.nii.gz",
                "path/to/image_4.nii.gz",
                ...
            ]
        }


    The metadata keys are optional for loading the datalist, but include:
        - some string items: ``name``, ``description``, ``reference``, ``licence``, ``release``, ``tensorImageSize``
        - two dict items: ``modality`` (keyed by channel index), and ``labels`` (keyed by label index)
        - and two integer items: ``numTraining`` and ``numTest``, with the number of items.

    The ``training`` key contains a list of dictionaries, each of which has at least
    the ``image`` and ``label`` keys.
    The image and label are loaded by :py:func:`monai.transforms.LoadImaged`, so both can be either
    a single file path or a list of file paths, in which case they are loaded as multi-channel images.
    Each item can also include a ``fold`` key for cross-validation purposes.
    The "test" key contains a list of image paths, without labels. MONAI also supports a "validation" list
    with the same format as the "training" list. Entries of the "test" list that are bare paths are
    wrapped as ``{"image": path}``; entries that are already dicts are kept as they are.


    Args:
        data_list_file_path: the path to the json file of datalist.
        is_segmentation: whether the datalist is for segmentation task, default is True.
        data_list_key: the key to get a list of dictionary to be used, default is "training".
        base_dir: the base directory of the dataset, if None, use the datalist directory.

    Raises:
        ValueError: When ``data_list_file_path`` does not point to a file.
        ValueError: When ``data_list_key`` is not specified in the data list file.

    Returns a list of data items, each of which is a dict keyed by element names, for example:

    .. code-block:: python

        [
            {'image': '/workspace/data/chest_19.nii.gz',  'label': '/workspace/labels/chest_19.nii.gz'},
            {'image': '/workspace/data/chest_31.nii.gz',  'label': '/workspace/labels/chest_31.nii.gz'},
        ]

    """
    data_list_file_path = Path(data_list_file_path)
    if not data_list_file_path.is_file():
        raise ValueError(f"Data list file {data_list_file_path} does not exist.")
    with open(data_list_file_path) as json_file:
        json_data = json.load(json_file)
    if data_list_key not in json_data:
        raise ValueError(f'Data list {data_list_key} not specified in "{data_list_file_path}".')
    expected_data = json_data[data_list_key]
    if data_list_key == "test":
        # decathlon datalist may save the test images in a list directly instead of dict.
        # Convert per item: this works for an empty list (no ``[0]`` lookup) and does not
        # wrap entries that are already dicts.
        expected_data = [i if isinstance(i, dict) else {"image": i} for i in expected_data]

    if base_dir is None:
        base_dir = data_list_file_path.parent

    return _append_paths(base_dir, is_segmentation, expected_data)


def load_decathlon_properties(data_property_file_path: PathLike, property_keys: Sequence[str] | str) -> dict:
    """Read selected top-level entries from a Decathlon-style JSON file.
    See `load_decathlon_datalist` for the keys expected in the Decathlon challenge.

    Args:
        data_property_file_path: the path to the JSON file of data properties.
        property_keys: a single key or a sequence of keys to extract from the JSON file. In the
            decathlon challenge these include, for example:
            `name`, `description`, `reference`, `licence`, `tensorImageSize`,
            `modality`, `labels`, `numTraining`, `numTest`, etc.

    Raises:
        ValueError: When ``data_property_file_path`` does not point to a file.
        KeyError: When one of ``property_keys`` is absent from the JSON file.

    Returns:
        A dict mapping each requested key to its value in the JSON file.

    """
    property_file = Path(data_property_file_path)
    if not property_file.is_file():
        raise ValueError(f"Data property file {property_file} does not exist.")
    with open(property_file) as handle:
        content = json.load(handle)

    wanted = ensure_tuple(property_keys)
    # validate every key first so a missing one fails before anything is returned
    for name in wanted:
        if name not in content:
            raise KeyError(f"key {name} is not in the data property file.")
    return {name: content[name] for name in wanted}


def check_missing_files(
    datalist: list[dict], keys: KeysCollection, root_dir: PathLike | None = None, allow_missing_keys: bool = False
):
    """Report the files referenced by a Decathlon-style datalist that do not exist on disk.
    Running this before a long training job can catch broken paths early.

    Args:
        datalist: a list of data items, every item is a dictionary,
            typically produced by the `load_decathlon_datalist` API.
        keys: the keys whose file paths should be checked in every item.
        root_dir: if not None, relative paths in `datalist` are resolved against this directory.
        allow_missing_keys: whether an item may lack some of `keys`.
            If False (default), a missing key raises an exception.

    Raises:
        ValueError: When an item lacks one of ``keys`` and ``allow_missing_keys`` is False.
        ValueError: When a checked value is not a path or a list of paths.

    Returns:
        A list of ``pathlib.Path`` objects for the files that do not exist.

    """
    root = Path(root_dir) if isinstance(root_dir, (str, os.PathLike)) else None
    missing = []
    for entry in datalist:
        for key in ensure_tuple(keys):
            if key not in entry:
                if allow_missing_keys:
                    continue
                raise ValueError(f"key `{key}` is missing in the datalist item: {entry}")

            for raw in ensure_tuple(entry[key]):
                if not isinstance(raw, (str, os.PathLike)):
                    raise ValueError(f"filepath of key `{key}` must be a string or a list of strings, but got: {raw}.")
                candidate = Path(raw) if root is None else root.joinpath(raw)
                if not candidate.exists():
                    missing.append(candidate)

    return missing


def create_cross_validation_datalist(
    datalist: list[dict],
    nfolds: int,
    train_folds: Sequence[int] | int,
    val_folds: Sequence[int] | int,
    train_key: str = "training",
    val_key: str = "validation",
    filename: Path | str | None = None,
    shuffle: bool = True,
    seed: int = 0,
    check_missing: bool = False,
    keys: KeysCollection | None = None,
    root_dir: PathLike | None = None,
    allow_missing_keys: bool = False,
    raise_error: bool = True,
):
    """
    Utility to create new Decathlon style datalist based on cross validation partition.

    Args:
        datalist: loaded list of dictionaries for all the items to partition.
        nfolds: number of the kfold split.
        train_folds: indices of folds for training part.
        val_folds: indices of folds for validation part.
        train_key: the key of train part in the new datalist, defaults to "training".
        val_key: the key of validation part in the new datalist, defaults to "validation".
        filename: if not None, save the new datalist into this file in JSON format.
            The file extension is not checked.
        shuffle: whether to shuffle the datalist before partition, defaults to `True`.
        seed: if `shuffle` is True, set the random seed, defaults to `0`.
        check_missing: whether to check that all the files specified by `keys` exist.
            The check is skipped when `keys` is None.
        keys: if not None and `check_missing` is True, the expected keys to check in the datalist.
        root_dir: if not None, provides the root dir for the relative file paths in `datalist`.
        allow_missing_keys: if `check_missing` is `True`, whether to allow missing keys in the datalist items.
            if False, raise exception if missing. default to False.
        raise_error: when found missing files, if `True`, raise exception and stop, if `False`, emit a warning.

    Raises:
        ValueError: When missing files are found and ``raise_error`` is True.

    Returns:
        A dict ``{train_key: <training items>, val_key: <validation items>}``.

    """
    if check_missing and keys is not None:
        files = check_missing_files(datalist, keys, root_dir, allow_missing_keys)
        if files:
            msg = f"some files of the datalist are missing: {files}"
            if raise_error:
                raise ValueError(msg)
            # stacklevel=2 attributes the warning to the caller rather than this module
            warnings.warn(msg, stacklevel=2)

    data = partition_dataset(data=datalist, num_partitions=nfolds, shuffle=shuffle, seed=seed)
    train_list = select_cross_validation_folds(partitions=data, folds=train_folds)
    val_list = select_cross_validation_folds(partitions=data, folds=val_folds)
    ret = {train_key: train_list, val_key: val_list}
    if isinstance(filename, (str, Path)):
        with open(filename, "w") as f:
            json.dump(ret, f, indent=4)

    return ret
