File size: 428 Bytes
c69c4af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
"""Task class definition """

from dataclasses import dataclass
from typing import  List, Type
import torch.nn as nn



@dataclass
class Task:
    """Configuration bundle describing one task.

    Fields, in constructor order:
        name: Identifier for the task.
        class_labels: Label strings; their count defines ``num_classes``.
        criterion: A loss *class* (annotated ``Type[nn.Module]``), not an
            instance. NOTE(review): presumably instantiated by the caller,
            possibly with class weights when ``use_weighted_loss`` is set --
            confirm against the consuming code.
        weight: Per-task scalar, ``1.0`` by default. Presumably scales this
            task's loss in a multi-task objective -- confirm with caller.
        use_weighted_loss: Flag, ``False`` by default; its effect is decided
            by whatever consumes this config.
    """

    name: str  # task identifier
    class_labels: List[str]  # label names for this task
    criterion: Type[nn.Module]  # loss class itself, not an instance
    weight: float = 1.0  # per-task scalar
    use_weighted_loss: bool = False  # consumer-defined toggle

    @property
    def num_classes(self) -> int:
        """Return how many labels ``class_labels`` holds."""
        label_count = len(self.class_labels)
        return label_count