File size: 1,432 Bytes
b24eac9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
import torch
from functools import partial
from torch.utils.data import DataLoader
from torch import nn

def collate_fn(batch):
    """Turn the first sample of ``batch`` into a dict of tensors.

    Only ``batch[0]`` is read; any further samples in ``batch`` are ignored.
    The loaders in this module pass no ``batch_size``, so ``DataLoader`` uses
    its default of 1. Each dataset item therefore presumably already holds a
    whole pre-batched example (TODO: confirm against the datasets used).

    Args:
        batch: Sequence of samples as delivered by ``DataLoader``. Its first
            element must map ``'input_ids'`` and ``'attention_mask'`` to
            values accepted by ``torch.tensor``.

    Returns:
        dict with ``'input_ids'`` and ``'attention_mask'`` tensors, in that
        key order.
    """
    first = batch[0]
    return {
        key: torch.tensor(first[key])
        for key in ('input_ids', 'attention_mask')
    }

class CustomDataModule(nn.Module):
    """Hold train/val/test datasets and build a ``DataLoader`` for each.

    All three loaders share the same ``collate_fn``, worker count and
    pinned-memory setting; only the training loader shuffles. No
    ``batch_size`` is passed, so ``DataLoader`` uses its default of 1.

    Args:
        train_dataset: Dataset served by ``train_dataloader``.
        val_dataset: Dataset served by ``val_dataloader``.
        test_dataset: Dataset served by ``test_dataloader``.
        collate_fn: Callable that merges a list of samples into a batch;
            defaults to this module's ``collate_fn``.
        num_workers: Worker processes per loader (keyword-only, default 8).
        pin_memory: Whether the loaders pin host memory (keyword-only,
            default True).
    """

    def __init__(self, train_dataset, val_dataset, test_dataset,
                 collate_fn=collate_fn, *, num_workers=8, pin_memory=True):
        super().__init__()
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.collate_fn = collate_fn
        self.num_workers = num_workers
        self.pin_memory = pin_memory

    def _make_dataloader(self, dataset, shuffle):
        """Build a ``DataLoader`` over ``dataset`` with the shared settings."""
        # collate_fn is handed over as-is: wrapping it in a functools.partial
        # with no bound arguments would only add an indirection.
        return DataLoader(
            dataset,
            collate_fn=self.collate_fn,
            num_workers=self.num_workers,
            pin_memory=self.pin_memory,
            shuffle=shuffle,
        )

    def train_dataloader(self):
        """Return a shuffling ``DataLoader`` over the training dataset."""
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the validation dataset."""
        return self._make_dataloader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        """Return a non-shuffling ``DataLoader`` over the test dataset."""
        return self._make_dataloader(self.test_dataset, shuffle=False)