| """ | |
| Based on the implementation from: | |
| https://huggingface.co/spaces/fffiloni/lama-video-watermark-remover/tree/main | |
| Modules were adapted by Hans Brouwer to only support the final configuration of the model uploaded here: | |
| https://huggingface.co/akhaliq/lama | |
| Apache License 2.0: https://github.com/advimman/lama/blob/main/LICENSE | |
| @article{suvorov2021resolution, | |
| title={Resolution-robust Large Mask Inpainting with Fourier Convolutions}, | |
| author={Suvorov, Roman and Logacheva, Elizaveta and Mashikhin, Anton and Remizova, Anastasia and Ashukha, Arsenii and Silvestrov, Aleksei and Kong, Naejin and Goka, Harshith and Park, Kiwoong and Lempitsky, Victor}, | |
| journal={arXiv preprint arXiv:2109.07161}, | |
| year={2021} | |
| } | |
| """ | |
import os
import sys
from urllib.request import urlretrieve

import torch
from einops import rearrange
from PIL import Image
from torch import nn
from torch.nn import functional as F
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm

from train import export_to_video

LAMA_URL = "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt"
LAMA_PATH = "models/lama.ckpt"
def download_progress(t):
    """Build a tqdm-compatible reporthook for urlretrieve."""
    last_b = [0]

    def update_to(b=1, bsize=1, tsize=None):
        if tsize is not None:
            t.total = tsize
        t.update((b - last_b[0]) * bsize)
        last_b[0] = b

    return update_to
def download(url, path):
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)  # ensure e.g. models/ exists before writing
    with tqdm(unit="B", unit_scale=True, unit_divisor=1024, miniters=1, desc=path) as t:
        urlretrieve(url, filename=path, reporthook=download_progress(t), data=None)
class FourierUnit(nn.Module):
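    """Pointwise convolution in the frequency domain.

    Applies a real 2D FFT over the spatial dimensions, runs a 1x1 conv + BN + ReLU
    on the stacked real/imaginary channels, then transforms back with the inverse
    FFT. This gives every output position a receptive field spanning the whole image.
    """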
    def __init__(self, in_channels, out_channels, groups=1):
        super(FourierUnit, self).__init__()
        self.groups = groups
        self.conv_layer = torch.nn.Conv2d(
            in_channels=in_channels * 2,
            out_channels=out_channels * 2,
            kernel_size=1,
            stride=1,
            padding=0,
            groups=self.groups,
            bias=False,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels * 2)
        self.relu = torch.nn.ReLU(inplace=True)
    def forward(self, x):
        batch = x.shape[0]

        fft_dim = (-2, -1)
        ffted = torch.fft.rfftn(x, dim=fft_dim, norm="ortho")
        ffted = torch.stack((ffted.real, ffted.imag), dim=-1)  # (batch, c, h, w/2+1, 2)
        ffted = ffted.permute(0, 1, 4, 2, 3).contiguous()  # (batch, c, 2, h, w/2+1)
        ffted = ffted.view((batch, -1) + ffted.size()[3:])

        ffted = self.conv_layer(ffted)  # (batch, c*2, h, w/2+1)
        ffted = self.relu(self.bn(ffted))

        # back to (batch, c, h, w/2+1, 2), then to a complex tensor
        ffted = ffted.view((batch, -1, 2) + ffted.size()[2:]).permute(0, 1, 3, 4, 2).contiguous()
        ffted = torch.complex(ffted[..., 0], ffted[..., 1])

        ifft_shape_slice = x.shape[-2:]
        output = torch.fft.irfftn(ffted, s=ifft_shape_slice, dim=fft_dim, norm="ortho")
        return output
class SpectralTransform(nn.Module):
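    """Bottleneck around a FourierUnit: a 1x1 conv halves the channels, the
    FourierUnit output is added back to its input, and a final 1x1 conv expands
    to out_channels. stride == 2 downsamples with 2x2 average pooling first."""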
    def __init__(self, in_channels, out_channels, stride=1, groups=1):
        super(SpectralTransform, self).__init__()
        self.stride = stride
        if stride == 2:
            self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
        else:
            self.downsample = nn.Identity()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False),
            nn.BatchNorm2d(out_channels // 2),
            nn.ReLU(inplace=True),
        )
        self.fu = FourierUnit(out_channels // 2, out_channels // 2, groups)
        self.conv2 = torch.nn.Conv2d(out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False)

    def forward(self, x):
        x = self.downsample(x)
        x = self.conv1(x)
        output = self.fu(x)
        output = self.conv2(x + output)
        return output
class FFC(nn.Module):
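    """Fast Fourier Convolution: channels are split into a local branch
    (regular convolutions) and a global branch (SpectralTransform).
    ratio_gin / ratio_gout set the fraction of global channels at the input /
    output; four paths (local->local, local->global, global->local,
    global->global) are summed into the two output branches."""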
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin,
        ratio_gout,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        padding_type="reflect",
        gated=False,
    ):
        super(FFC, self).__init__()
        assert stride == 1 or stride == 2, "Stride should be 1 or 2."
        self.stride = stride

        in_cg = int(in_channels * ratio_gin)
        in_cl = in_channels - in_cg
        out_cg = int(out_channels * ratio_gout)
        out_cl = out_channels - out_cg

        self.ratio_gin = ratio_gin
        self.ratio_gout = ratio_gout
        self.global_in_num = in_cg

        module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
        self.convl2l = module(
            in_cl, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
        self.convl2g = module(
            in_cl, out_cg, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
        self.convg2l = module(
            in_cg, out_cl, kernel_size, stride, padding, dilation, groups, bias, padding_mode=padding_type
        )
        module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
        self.convg2g = module(in_cg, out_cg, stride, 1 if groups == 1 else groups // 2)

        self.gated = gated
        module = nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
        self.gate = module(in_channels, 2, 1)
    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        out_xl, out_xg = 0, 0

        if self.gated:
            total_input_parts = [x_l]
            if torch.is_tensor(x_g):
                total_input_parts.append(x_g)
            total_input = torch.cat(total_input_parts, dim=1)
            gates = torch.sigmoid(self.gate(total_input))
            g2l_gate, l2g_gate = gates.chunk(2, dim=1)
        else:
            g2l_gate, l2g_gate = 1, 1

        if self.ratio_gout != 1:
            out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
        if self.ratio_gout != 0:
            out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)

        return out_xl, out_xg
class FFC_BN_ACT(nn.Module):
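    """FFC followed by per-branch BatchNorm and activation; empty branches get
    nn.Identity so the module stays valid at ratio 0 or 1."""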
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        ratio_gin=0,
        ratio_gout=0,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        bias=False,
        norm_layer=nn.BatchNorm2d,
        activation_layer=nn.ReLU,
    ):
        super(FFC_BN_ACT, self).__init__()
        self.ffc = FFC(
            in_channels, out_channels, kernel_size, ratio_gin, ratio_gout, stride, padding, dilation, groups, bias
        )
        lnorm = nn.Identity if ratio_gout == 1 else norm_layer
        gnorm = nn.Identity if ratio_gout == 0 else norm_layer
        global_channels = int(out_channels * ratio_gout)
        self.bn_l = lnorm(out_channels - global_channels)
        self.bn_g = gnorm(global_channels)

        lact = nn.Identity if ratio_gout == 1 else activation_layer
        gact = nn.Identity if ratio_gout == 0 else activation_layer
        self.act_l = lact(inplace=True)
        self.act_g = gact(inplace=True)

    def forward(self, x):
        x_l, x_g = self.ffc(x)
        x_l = self.act_l(self.bn_l(x_l))
        x_g = self.act_g(self.bn_g(x_g))
        return x_l, x_g
class FFCResnetBlock(nn.Module):
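    """Residual block of two FFC_BN_ACT layers; the skip connection is applied
    to the local and global branches independently."""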
    def __init__(self, dim, ratio_gin, ratio_gout):
        super().__init__()
        self.conv1 = FFC_BN_ACT(
            dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
        )
        self.conv2 = FFC_BN_ACT(
            dim, dim, kernel_size=3, padding=1, dilation=1, ratio_gin=ratio_gin, ratio_gout=ratio_gout
        )

    def forward(self, x):
        x_l, x_g = x if type(x) is tuple else (x, 0)
        id_l, id_g = x_l, x_g
        x_l, x_g = self.conv1((x_l, x_g))
        x_l, x_g = self.conv2((x_l, x_g))
        x_l, x_g = id_l + x_l, id_g + x_g
        return x_l, x_g
class ConcatTupleLayer(nn.Module):
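    """Concatenates the (local, global) branch tuple back into a single tensor
    along the channel dimension."""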
    def forward(self, x):
        assert isinstance(x, tuple)
        x_l, x_g = x
        assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
        if not torch.is_tensor(x_g):
            return x_l
        return torch.cat(x, dim=1)
class LargeMaskInpainting(nn.Module):
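    """LaMa generator: a reflection-padded 7x7 stem, n_downsampling strided FFC
    stages, n_blocks FFC resnet blocks (18 by default) at a 0.75 global ratio,
    and a transposed-conv decoder ending in a sigmoid. The input is the image
    concatenated with a binary mask (1 = region to inpaint; hence input_nc=4),
    and the prediction is composited back so pixels outside the mask pass
    through unchanged."""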
    def __init__(self, input_nc=4, output_nc=3, ngf=64, n_downsampling=3, n_blocks=18, max_features=1024):
        super().__init__()
        model = [nn.ReflectionPad2d(3), FFC_BN_ACT(input_nc, ngf, kernel_size=7)]

        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                FFC_BN_ACT(
                    min(max_features, ngf * mult),
                    min(max_features, ngf * mult * 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    ratio_gout=0.75 if i == n_downsampling - 1 else 0,
                )
            ]

        ### resnet blocks
        for _ in range(n_blocks):
            model += [FFCResnetBlock(min(max_features, ngf * 2**n_downsampling), ratio_gin=0.75, ratio_gout=0.75)]
        model += [ConcatTupleLayer()]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    min(max_features, ngf * mult),
                    min(max_features, int(ngf * mult / 2)),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                nn.BatchNorm2d(min(max_features, int(ngf * mult / 2))),
                nn.ReLU(True),
            ]

        model += [nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7), nn.Sigmoid()]
        self.model = nn.Sequential(*model)
    def forward(self, img, mask):
        masked_img = img * (1 - mask)  # zero out the region to be inpainted
        masked_img = torch.cat([masked_img, mask], dim=1)  # the mask becomes the 4th input channel
        pred = self.model(masked_img)
        inpainted = mask * pred + (1 - mask) * img  # keep original pixels outside the mask
        return inpainted
def inpaint_watermark(imgs):
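    """Inpaint the watermark region of a batch of frames.

    Expects `imgs` as a (frames, 3, H, W) float tensor in [0, 1]. The watermark
    mask is read from ./utils/mask.png and resized to match the frames; the
    LaMa checkpoint is downloaded on first use.
    """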
    if not os.path.exists(LAMA_PATH):
        download(LAMA_URL, LAMA_PATH)

    mask = to_tensor(Image.open("./utils/mask.png").convert("L")).unsqueeze(0).to(imgs.device)
    if mask.shape[-2:] != imgs.shape[-2:]:  # compare both spatial dims, not just width
        mask = F.interpolate(mask, size=(imgs.shape[2], imgs.shape[3]), mode="nearest")
    mask = mask.expand(imgs.shape[0], 1, mask.shape[2], mask.shape[3])

    model = LargeMaskInpainting().to(imgs.device)
    state_dict = torch.load(LAMA_PATH, map_location=imgs.device)["state_dict"]
    g_dict = {k.replace("generator.", ""): v for k, v in state_dict.items() if k.startswith("generator")}
    model.load_state_dict(g_dict)
    model.eval()  # inference: use the checkpoint's BatchNorm running statistics

    with torch.no_grad():
        inpainted = model(imgs, mask)
    return inpainted
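
# A minimal programmatic usage sketch (hypothetical values; any (f, 3, H, W)
# float tensor in [0, 1] works, on CPU or GPU):
#
#     frames = torch.rand(8, 3, 256, 256)
#     clean = inpaint_watermark(frames)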
if __name__ == "__main__":
    import decord

    decord.bridge.set_bridge("torch")

    if len(sys.argv) < 2:
        print("Usage: python -m utils.lama <path/to/video>")
        sys.exit(1)
    video_path = sys.argv[1]
    out_path = video_path.replace(".mp4", "-inpainted.mp4")

    vr = decord.VideoReader(video_path)
    fps = vr.get_avg_fps()
    video = rearrange(vr[:], "f h w c -> f c h w").float().div(255)  # uint8 frames -> float in [0, 1]

    inpainted = inpaint_watermark(video)
    inpainted = rearrange(inpainted, "f c h w -> f h w c").clamp(0, 1).mul(255).byte().cpu().numpy()

    export_to_video(inpainted, out_path, fps)