Spaces:
Runtime error
Runtime error
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import os | |
| import copy | |
| import math | |
| import torch | |
| from torch import nn, Tensor | |
| import torch.nn.functional as F | |
def rand_sample(x, divisor, max_len):
    """Randomly keep at most ``max_len`` non-zero positions of ``x``, balanced per mask.

    The coordinates returned by ``x.nonzero()`` are divided by ``divisor`` and
    transposed so that every column is one point. Row 0 of that result is
    used as the mask id: each distinct id receives the same total sampling
    probability, shared uniformly among its points, so an id with few points
    is as likely to be represented as one with many.

    Args:
        x: Tensor whose non-zero entries are the candidate points.
            NOTE(review): presumably a stack of masks with the mask index on
            dim 0 -- confirm against callers.
        divisor: Value the integer coordinates are divided by. It must
            broadcast against the ``(num_points, x.dim())`` result of
            ``x.nonzero()``.
            NOTE(review): the entry applied to the first coordinate should
            leave mask ids integral (e.g. 1) -- confirm against callers.
        max_len: Upper bound on the number of points kept.

    Returns:
        Tensor of shape ``(x.dim(), k)`` with ``k = min(max_len, num_points)``;
        the kept columns stay in ascending ``nonzero`` order. When ``x`` has
        no non-zero entry, the empty integer index tensor of shape
        ``(x.dim(), 0)`` is returned without being divided.
    """
    # nonzero() is a full scan (and a host sync on CUDA): run it only once.
    nonzero_idx = x.nonzero()
    if nonzero_idx.shape[0] == 0:
        return nonzero_idx.t()

    points = (nonzero_idx / divisor).t()
    mask_row = points[0]

    # Compute the probability of each sample in one pass: a point belonging
    # to a mask with `count` points gets 1 / (num_masks * count).
    _, inverse, counts = mask_row.unique(return_inverse=True, return_counts=True)
    probs = (1.0 / (counts.numel() * counts[inverse])).to(mask_row.dtype)

    # Sample without replacement, then sort so the kept points preserve the
    # order in which nonzero() produced them.
    num_samples = min(max_len, probs.numel())
    keep = torch.multinomial(probs, num_samples=num_samples, replacement=False).sort()[0]
    return points[:, keep]
def rand_sample_plain(x, max_len):
    """Return ``x`` limited to at most ``max_len`` randomly chosen columns.

    Args:
        x: Tensor indexed along dim 1 (it must have at least two dims).
        max_len: Maximum number of entries to keep along dim 1.

    Returns:
        ``x`` itself when it already has ``max_len`` or fewer entries on
        dim 1; otherwise a new tensor holding ``max_len`` entries of dim 1
        picked by a random permutation (the picked order is random too).
    """
    num_cols = x.shape[1]
    # Nothing to trim: hand back the very same tensor object.
    if num_cols <= max_len:
        return x

    # The first max_len entries of a random permutation give a uniform
    # sample without replacement.
    keep = torch.randperm(num_cols)[:max_len]
    return x[:, keep]
def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
    """Build per-level source features, positional encodings and spatial sizes.

    For each of the first ``num_feature_levels`` entries of ``x`` this
    records the last two dims as the spatial size, computes a positional
    encoding with ``pe_layer`` (called with ``None`` as its second argument),
    and projects the feature with ``input_proj[level]`` before adding that
    level's row of ``level_embed.weight``. Both results are flattened from
    dim 2 onward and permuted so the flattened dim comes first.

    Args:
        x: Indexable collection of feature maps.
            NOTE(review): assumed to be NxCxHxW tensors -- confirm with callers.
        num_feature_levels: Number of leading entries of ``x`` to process.
        pe_layer: Callable invoked as ``pe_layer(feature, None)``.
        input_proj: Indexable collection of callables, one per level.
        level_embed: Object exposing a ``weight`` tensor indexable by level.

    Returns:
        Tuple ``(src, pos, size_list)`` of three lists with one entry per
        level: projected features, positional encodings, and the last two
        dims of the input feature's shape.
    """
    src, pos, size_list = [], [], []

    # The mask stays disabled (None is passed to pe_layer); per the original
    # note this does not affect performance.
    for level in range(num_feature_levels):
        feature = x[level]
        size_list.append(feature.shape[-2:])

        level_pos = pe_layer(feature, None).flatten(2)
        level_bias = level_embed.weight[level][None, :, None]
        level_src = input_proj[level](feature).flatten(2) + level_bias

        # Move the flattened dim to the front (NxCxHxW -> HWxNxC).
        pos.append(level_pos.permute(2, 0, 1))
        src.append(level_src.permute(2, 0, 1))

    return src, pos, size_list