| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| from abc import ABC, abstractmethod | |
| from torch import nn | |
class BaseProjector(nn.Module, ABC):
    """Base class for projector modules.

    Subclasses are expected to create a ``self.projector`` module (see
    :meth:`setup_projector`). :meth:`forward` calls that module, but this base
    class never assigns it. An optional ``self.adaptive_avg_pool`` module is
    applied to the projector output. It is ``None`` by default, in which case
    it is skipped.
    """

    def __init__(self):
        super().__init__()
        # Optional post-projection pooling. Subclasses may replace this with a
        # module; ``forward`` skips it while it is ``None``.
        self.adaptive_avg_pool = None

    # NOTE(review): ``abstractmethod`` is imported but not applied here, so
    # subclasses are not forced to override this hook. Decorating it would
    # make subclasses without an override non-instantiable, so confirm that
    # every subclass implements it before adding the decorator.
    def setup_projector(self):
        """Create the ``projector`` attribute used by :meth:`forward`.

        Subclasses override this to assign ``self.projector``. The base
        implementation does nothing, and this base class never calls it
        itself, so subclasses must arrange for it to run before ``forward``.
        """
        pass

    def forward(self, x):
        """Project ``x`` and optionally pool the result.

        Args:
            x: A tensor with exactly three dimensions (required by
                ``permute(1, 0, 2)``). Its first two dimensions are swapped
                before ``self.projector`` is applied. The inline comment
                suggests an (N, L, D) layout; TODO confirm.

        Returns:
            The projector output with its first two dimensions swapped back.
            If ``self.adaptive_avg_pool`` is not ``None``, that result is
            passed through it.
        """
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.projector(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        if self.adaptive_avg_pool is not None:
            x = self.adaptive_avg_pool(x)
        return x