from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn
from nvidia.dali.plugin.pytorch import DALIGenericIterator
import os

# To run with different data, see the documentation of
# nvidia.dali.fn.experimental.readers.fits. DALI's test data is published at
# https://github.com/NVIDIA/DALI_extra
# Root directory of the input data.
data_root_dir = "./data"
# Directory handed to the FITS reader as its file_root.
fits_dir = os.path.join(data_root_dir, 'fits')


def loss_func(pred, y):
    """Placeholder for a real loss function.

    Accepts a prediction ``pred`` and a target ``y``, performs no
    computation and returns ``None``.
    """
    return None


def model(x):
    """Placeholder for a real model.

    Accepts an input ``x``, performs no computation and returns ``None``.
    """
    return None


def backward(loss, model):
    """Placeholder for a real backward/optimisation step.

    Accepts a ``loss`` value and a ``model``, performs no computation and
    returns ``None``.
    """
    return None


@pipeline_def(num_threads=4, device_id=0)
def get_dali_pipeline():
    """Define a DALI pipeline that reads FITS files found under ``fits_dir``.

    Returns:
        The output of the FITS reader. The reader produces one output per
        requested HDU index; with its default ``hdu_indices`` that is a
        single data output (no labels).
    """
    # The reader result has to be returned, otherwise the pipeline ends up
    # with no outputs at all.
    data = fn.experimental.readers.fits(
        file_root=fits_dir,
        random_shuffle=True,
        name='FITS_READER',
    )
    return data


train_data = DALIGenericIterator(
    [get_dali_pipeline(batch_size=16)],
    # One name per pipeline output: the FITS reader yields data only.
    ['data'],
    # Must be identical to the ``name`` given to the reader in the pipeline,
    # since the iterator queries that reader for the epoch size.
    reader_name='FITS_READER',
)


for i, data in enumerate(train_data):
    # data is a list with one dict per pipeline, keyed by the output names.
    x = data[0]['data']
    # TODO: the FITS reader provides no label output; supply the training
    # target from another source before using a real loss function.
    y = None
    pred = model(x)
    loss = loss_func(pred, y)
    backward(loss, model)