Spaces:
Runtime error
Runtime error
Update
Browse files- .gitignore +0 -1
- .gitmodules +3 -0
- ELITE +1 -0
- README.md +3 -3
- model.py +4 -16
.gitignore
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
ELITE/
|
| 2 |
|
| 3 |
# Byte-compiled / optimized / DLL files
|
| 4 |
__pycache__/
|
|
|
|
|
|
|
| 1 |
|
| 2 |
# Byte-compiled / optimized / DLL files
|
| 3 |
__pycache__/
|
.gitmodules
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[submodule "ELITE"]
|
| 2 |
+
path = ELITE
|
| 3 |
+
url = https://huggingface.co/ELITE-library/ELITE
|
ELITE
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Subproject commit 9f563c699684b8b44358b0ab2f5dafd0a5af24b1
|
README.md
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
---
|
| 2 |
title: ELITE
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.20.1
|
| 8 |
app_file: app.py
|
|
|
|
| 1 |
---
|
| 2 |
title: ELITE
|
| 3 |
+
emoji: π
|
| 4 |
+
colorFrom: green
|
| 5 |
+
colorTo: green
|
| 6 |
sdk: gradio
|
| 7 |
sdk_version: 3.20.1
|
| 8 |
app_file: app.py
|
model.py
CHANGED
|
@@ -1,6 +1,5 @@
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
| 3 |
-
import os
|
| 4 |
import pathlib
|
| 5 |
import random
|
| 6 |
import sys
|
|
@@ -15,17 +14,11 @@ import torch.nn.functional as F
|
|
| 15 |
import torchvision.transforms as T
|
| 16 |
import tqdm.auto
|
| 17 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
| 18 |
-
from huggingface_hub import hf_hub_download
|
| 19 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
| 20 |
|
| 21 |
-
HF_TOKEN = os.getenv('HF_TOKEN')
|
| 22 |
-
|
| 23 |
repo_dir = pathlib.Path(__file__).parent
|
| 24 |
submodule_dir = repo_dir / 'ELITE'
|
| 25 |
-
snapshot_download('ELITE-library/ELITE',
|
| 26 |
-
repo_type='model',
|
| 27 |
-
local_dir=submodule_dir.as_posix(),
|
| 28 |
-
token=HF_TOKEN)
|
| 29 |
sys.path.insert(0, submodule_dir.as_posix())
|
| 30 |
|
| 31 |
from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
|
|
@@ -64,13 +57,11 @@ class Model:
|
|
| 64 |
global_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
| 65 |
'global_mapper.pt',
|
| 66 |
subfolder='checkpoints',
|
| 67 |
-
repo_type='model'
|
| 68 |
-
token=HF_TOKEN)
|
| 69 |
local_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
| 70 |
'local_mapper.pt',
|
| 71 |
subfolder='checkpoints',
|
| 72 |
-
repo_type='model'
|
| 73 |
-
token=HF_TOKEN)
|
| 74 |
return global_mapper_path, local_mapper_path
|
| 75 |
|
| 76 |
def load_model(
|
|
@@ -139,10 +130,7 @@ class Model:
|
|
| 139 |
mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
|
| 140 |
to_k_local)
|
| 141 |
|
| 142 |
-
|
| 143 |
-
global_mapper_path = submodule_dir / 'checkpoints/global_mapper.pt'
|
| 144 |
-
local_mapper_path = submodule_dir / 'checkpoints/local_mapper.pt'
|
| 145 |
-
|
| 146 |
mapper.load_state_dict(
|
| 147 |
torch.load(global_mapper_path, map_location='cpu'))
|
| 148 |
mapper.half()
|
|
|
|
| 1 |
from __future__ import annotations
|
| 2 |
|
|
|
|
| 3 |
import pathlib
|
| 4 |
import random
|
| 5 |
import sys
|
|
|
|
| 14 |
import torchvision.transforms as T
|
| 15 |
import tqdm.auto
|
| 16 |
from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
|
| 17 |
+
from huggingface_hub import hf_hub_download
|
| 18 |
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
|
| 19 |
|
|
|
|
|
|
|
| 20 |
repo_dir = pathlib.Path(__file__).parent
|
| 21 |
submodule_dir = repo_dir / 'ELITE'
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
sys.path.insert(0, submodule_dir.as_posix())
|
| 23 |
|
| 24 |
from train_local import (Mapper, MapperLocal, inj_forward_crossattention,
|
|
|
|
| 57 |
global_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
| 58 |
'global_mapper.pt',
|
| 59 |
subfolder='checkpoints',
|
| 60 |
+
repo_type='model')
|
|
|
|
| 61 |
local_mapper_path = hf_hub_download('ELITE-library/ELITE',
|
| 62 |
'local_mapper.pt',
|
| 63 |
subfolder='checkpoints',
|
| 64 |
+
repo_type='model')
|
|
|
|
| 65 |
return global_mapper_path, local_mapper_path
|
| 66 |
|
| 67 |
def load_model(
|
|
|
|
| 130 |
mapper_local.add_module(f'{_name.replace(".", "_")}_to_k',
|
| 131 |
to_k_local)
|
| 132 |
|
| 133 |
+
global_mapper_path, local_mapper_path = self.download_mappers()
|
|
|
|
|
|
|
|
|
|
| 134 |
mapper.load_state_dict(
|
| 135 |
torch.load(global_mapper_path, map_location='cpu'))
|
| 136 |
mapper.half()
|