| import torch | |
| from functools import partial | |
| from torch.utils.data import DataLoader | |
| from torch import nn | |
def collate_fn(batch):
    """Build a model-input dict from the first sample of ``batch``.

    Only ``batch[0]`` is read; no batch dimension is added and any further
    samples are not included in the result.  That matches how the loaders in
    this file call it (they do not pass ``batch_size``), but with a larger
    batch size the extra samples would be dropped, so a warning is emitted
    in that case instead of losing data silently.

    Args:
        batch: Indexable collection of samples.  Each sample must support
            lookup by the keys ``'input_ids'`` and ``'attention_mask'``.

    Returns:
        dict: ``{'input_ids': Tensor, 'attention_mask': Tensor}``, each value
        created with ``torch.tensor`` from the first sample's entry.

    Raises:
        IndexError: If ``batch`` is an empty sequence.
        KeyError: If the first sample lacks one of the two keys.

    Warns:
        UserWarning: If ``batch`` contains more than one sample.
    """
    if len(batch) > 1:
        # Local import keeps this function self-contained within the file.
        import warnings

        warnings.warn(
            f"collate_fn received {len(batch)} samples but only uses the "
            "first one; the remaining samples are discarded.",
            UserWarning,
            stacklevel=2,
        )

    sample = batch[0]
    return {
        key: torch.tensor(sample[key])
        for key in ('input_ids', 'attention_mask')
    }
class CustomDataModule(nn.Module):
    """Hold train/val/test datasets and build a ``DataLoader`` for each.

    All three loaders share the same collate function, worker count and
    pinned-memory setting; only the training loader shuffles.  ``batch_size``
    is not passed, so it stays at the ``DataLoader`` default.

    Args:
        train_dataset: Dataset served by :meth:`train_dataloader`.
        val_dataset: Dataset served by :meth:`val_dataloader`.
        test_dataset: Dataset served by :meth:`test_dataloader`.
        collate_fn: Callable handed to every ``DataLoader`` as ``collate_fn``.
        num_workers: Worker-process count passed to every ``DataLoader``.
        pin_memory: ``pin_memory`` flag passed to every ``DataLoader``.
    """

    def __init__(self, train_dataset, val_dataset, test_dataset,
                 collate_fn=collate_fn, *, num_workers=8, pin_memory=True):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _make_loader(self, dataset, *, shuffle):
        """Return a ``DataLoader`` over ``dataset`` with the shared settings."""
        # The collate function is passed directly: wrapping it in an
        # argument-less ``partial`` adds nothing.
        return DataLoader(
            dataset,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation dataset."""
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the test dataset."""
        return self._make_loader(self.test_dataset, shuffle=False)