import copy
import pickle

import torch


class EnhancerDataset(torch.utils.data.Dataset):
    """Map-style dataset of (sequence, class) pairs read from a pickled enhancer file.

    The pickle is expected to be a dict holding ``'{split}_data'`` (sequence
    encodings) and ``'y_{split}'`` (class labels). Both are reduced with
    ``argmax`` over their last axis, so they are presumably one-hot encoded
    numpy arrays — TODO confirm against the data-preparation script.
    """

    def __init__(self, mel_enhancer=True, split='train',
                 data_dir='./dataset/enhancer_data'):
        """Load one split of the DeepMEL2 or DeepFlyBrain enhancer data.

        Args:
            mel_enhancer: If True, read ``DeepMEL2_data.pkl``; otherwise read
                ``DeepFlyBrain_data.pkl``.
            split: Split name used to build the dict keys ``'{split}_data'``
                and ``'y_{split}'`` (e.g. ``'train'``).
            data_dir: Directory containing the pickle files. The default
                matches the previously hard-coded location.

        Raises:
            FileNotFoundError: If the pickle file does not exist.
            KeyError: If the requested split is missing from the pickle.
        """
        dataset_name = "MEL2" if mel_enhancer else "FlyBrain"
        path = f'{data_dir}/Deep{dataset_name}_data.pkl'
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # files from a trusted source.
        with open(path, 'rb') as handle:  # closes the file even if loading fails
            all_data = pickle.load(handle)

        seq_onehot = all_data[f'{split}_data']
        label_onehot = all_data[f'y_{split}']

        # torch.argmax allocates a new tensor, so the result never aliases the
        # numpy buffer shared by torch.from_numpy; no defensive copy is needed.
        self.seqs = torch.argmax(torch.from_numpy(seq_onehot), dim=-1)
        self.clss = torch.argmax(torch.from_numpy(label_onehot), dim=-1)
        # Number of classes = size of the label array's last axis.
        self.num_cls = label_onehot.shape[-1]
        # Presumably the 4-letter DNA alphabet — TODO confirm.
        self.alphabet_size = 4

    def __len__(self):
        """Return the number of sequences in this split."""
        return len(self.seqs)

    def __getitem__(self, idx):
        """Return the ``(sequence_indices, class_index)`` pair at position ``idx``."""
        return self.seqs[idx], self.clss[idx]