"""Collate function and a simple container that builds train/val/test DataLoaders."""
from functools import partial  # noqa: F401  - no longer used in this module
from typing import Any, Callable, Dict, Mapping, Sequence

import torch
from torch import nn
from torch.utils.data import DataLoader


def collate_fn(batch: Sequence[Mapping[str, Any]]) -> Dict[str, torch.Tensor]:
    """Convert a one-sample batch into a dict of tensors.

    Only ``'input_ids'`` and ``'attention_mask'`` are kept. No batch dimension
    is added: each returned tensor has the shape of the sample's own value
    (e.g. a flat list of ints becomes a 1-D tensor).

    Args:
        batch: Sequence containing exactly one sample mapping. The sample must
            provide ``'input_ids'`` and ``'attention_mask'`` values accepted by
            ``torch.tensor`` (presumably token-id / mask lists produced by a
            tokenizer -- confirm against the dataset).

    Returns:
        ``{'input_ids': Tensor, 'attention_mask': Tensor}``.

    Raises:
        ValueError: If ``batch`` does not contain exactly one sample. The
            original implementation read only ``batch[0]``, which would
            silently drop every other sample in a larger batch.
        KeyError: If the sample lacks one of the required keys.
    """
    if len(batch) != 1:
        raise ValueError(
            f"collate_fn expects a batch of exactly 1 sample, got {len(batch)}; "
            "larger batches would silently drop samples"
        )
    (sample,) = batch
    return {
        'input_ids': torch.tensor(sample['input_ids']),
        'attention_mask': torch.tensor(sample['attention_mask']),
    }


class CustomDataModule(nn.Module):
    """Hold train/val/test datasets and build a ``DataLoader`` for each split.

    Every loader uses the DataLoader default ``batch_size`` (1), the configured
    ``collate_fn``, ``num_workers`` and ``pin_memory``. Only the training
    loader shuffles.

    NOTE(review): this subclasses ``nn.Module`` but defines no parameters or
    ``forward``; it acts only as a container. The base class is kept so
    existing callers are unaffected.
    """

    def __init__(
        self,
        train_dataset,
        val_dataset,
        test_dataset,
        collate_fn: Callable = collate_fn,
        num_workers: int = 8,
        pin_memory: bool = True,
    ):
        """Store the datasets and loader settings.

        Args:
            train_dataset: Dataset used by ``train_dataloader``.
            val_dataset: Dataset used by ``val_dataloader``.
            test_dataset: Dataset used by ``test_dataloader``.
            collate_fn: Function that turns a list of samples into a batch.
                Defaults to the module-level ``collate_fn``.
            num_workers: Worker processes per DataLoader (default 8).
            pin_memory: Whether DataLoaders pin host memory (default True).
        """
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _make_dataloader(self, dataset, shuffle: bool) -> DataLoader:
        """Build a DataLoader for ``dataset`` with this module's shared settings."""
        return DataLoader(
            dataset,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self) -> DataLoader:
        """Return a shuffled DataLoader over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Return a non-shuffled DataLoader over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Return a non-shuffled DataLoader over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)