inoculatemedia committed 4086d20 (verified) · 1 parent: 32ffbb9

Upload 11 files

LICENSE.txt ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.

9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.

Copyright [yyyy] [name of copyright owner]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
kokoro/__init__.py ADDED
@@ -0,0 +1,23 @@
__version__ = '0.9.4'

from loguru import logger
import sys

# Remove default handler
logger.remove()

# Add custom handler with clean format including module and line number
logger.add(
    sys.stderr,
    format="<green>{time:HH:mm:ss}</green> | <cyan>{module:>16}:{line}</cyan> | <level>{level: >8}</level> | <level>{message}</level>",
    colorize=True,
    level="INFO"  # "DEBUG" to enable logger.debug("message") and up prints
                  # "ERROR" to enable only logger.error("message") prints
                  # etc
)

# Disable before release or as needed
logger.disable("kokoro")

from .model import KModel
from .pipeline import KPipeline
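
Because the module above ends with logger.disable("kokoro"), the package is silent by default. A minimal consumer-side sketch (not part of this commit; it only uses standard loguru calls) for turning the logging back on and raising verbosity:

import sys
from loguru import logger

import kokoro  # importing installs the INFO handler above, then disables the "kokoro" logger

logger.enable("kokoro")                # undo logger.disable("kokoro") from __init__.py
logger.remove()                        # drop the default INFO handler added on import
logger.add(sys.stderr, level="DEBUG")  # show debug-level messages as well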
kokoro/__main__.py ADDED
@@ -0,0 +1,148 @@
+ """Kokoro TTS CLI
2
+ Example usage:
3
+ python3 -m kokoro --text "The sky above the port was the color of television, tuned to a dead channel." -o file.wav --debug
4
+
5
+ echo "Bom dia mundo, como vão vocês" > text.txt
6
+ python3 -m kokoro -i text.txt -l p --voice pm_alex > audio.wav
7
+
8
+ Common issues:
9
+ pip not installed: `uv pip install pip`
10
+ (Temporary workaround while https://github.com/explosion/spaCy/issues/13747 is not fixed)
11
+
12
+ espeak not installed: `apt-get install espeak-ng`
13
+ """
14
+
15
+ import argparse
16
+ import wave
17
+ from pathlib import Path
18
+ from typing import Generator, TYPE_CHECKING
19
+
20
+ import numpy as np
21
+ from loguru import logger
22
+
23
+ languages = [
24
+ "a", # American English
25
+ "b", # British English
26
+ "h", # Hindi
27
+ "e", # Spanish
28
+ "f", # French
29
+ "i", # Italian
30
+ "p", # Brazilian Portuguese
31
+ "j", # Japanese
32
+ "z", # Mandarin Chinese
33
+ ]
34
+
35
+ if TYPE_CHECKING:
36
+ from kokoro import KPipeline
37
+
38
+
39
+ def generate_audio(
40
+ text: str, kokoro_language: str, voice: str, speed=1
41
+ ) -> Generator["KPipeline.Result", None, None]:
42
+ from kokoro import KPipeline
43
+
44
+ if not voice.startswith(kokoro_language):
45
+ logger.warning(f"Voice {voice} is not made for language {kokoro_language}")
46
+ pipeline = KPipeline(lang_code=kokoro_language)
47
+ yield from pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+")
48
+
49
+
50
+ def generate_and_save_audio(
51
+ output_file: Path, text: str, kokoro_language: str, voice: str, speed=1
52
+ ) -> None:
53
+ with wave.open(str(output_file.resolve()), "wb") as wav_file:
54
+ wav_file.setnchannels(1) # Mono audio
55
+ wav_file.setsampwidth(2) # 2 bytes per sample (16-bit audio)
56
+ wav_file.setframerate(24000) # Sample rate
57
+
58
+ for result in generate_audio(
59
+ text, kokoro_language=kokoro_language, voice=voice, speed=speed
60
+ ):
61
+ logger.debug(result.phonemes)
62
+ if result.audio is None:
63
+ continue
64
+ audio_bytes = (result.audio.numpy() * 32767).astype(np.int16).tobytes()
65
+ wav_file.writeframes(audio_bytes)
66
+
67
+
68
+ def main() -> None:
69
+ parser = argparse.ArgumentParser()
70
+ parser.add_argument(
71
+ "-m",
72
+ "--voice",
73
+ default="af_heart",
74
+ help="Voice to use",
75
+ )
76
+ parser.add_argument(
77
+ "-l",
78
+ "--language",
79
+ help="Language to use (defaults to the one corresponding to the voice)",
80
+ choices=languages,
81
+ )
82
+ parser.add_argument(
83
+ "-o",
84
+ "--output-file",
85
+ "--output_file",
86
+ type=Path,
87
+ help="Path to output WAV file",
88
+ required=True,
89
+ )
90
+ parser.add_argument(
91
+ "-i",
92
+ "--input-file",
93
+ "--input_file",
94
+ type=Path,
95
+ help="Path to input text file (default: stdin)",
96
+ )
97
+ parser.add_argument(
98
+ "-t",
99
+ "--text",
100
+ help="Text to use instead of reading from stdin",
101
+ )
102
+ parser.add_argument(
103
+ "-s",
104
+ "--speed",
105
+ type=float,
106
+ default=1.0,
107
+ help="Speech speed",
108
+ )
109
+ parser.add_argument(
110
+ "--debug",
111
+ action="store_true",
112
+ help="Print DEBUG messages to console",
113
+ )
114
+ args = parser.parse_args()
115
+ if args.debug:
116
+ logger.level("DEBUG")
117
+ logger.debug(args)
118
+
119
+ lang = args.language or args.voice[0]
120
+
121
+ if args.text is not None and args.input_file is not None:
122
+ raise Exception("You cannot specify both 'text' and 'input_file'")
123
+ elif args.text:
124
+ text = args.text
125
+ elif args.input_file:
126
+ file: Path = args.input_file
127
+ text = file.read_text()
128
+ else:
129
+ import sys
130
+ print("Press Ctrl+D to stop reading input and start generating", flush=True)
131
+ text = '\n'.join(sys.stdin)
132
+
133
+ logger.debug(f"Input text: {text!r}")
134
+
135
+ out_file: Path = args.output_file
136
+ if not out_file.suffix == ".wav":
137
+ logger.warning("The output file name should end with .wav")
138
+ generate_and_save_audio(
139
+ output_file=out_file,
140
+ text=text,
141
+ kokoro_language=lang,
142
+ voice=args.voice,
143
+ speed=args.speed,
144
+ )
145
+
146
+
147
+ if __name__ == "__main__":
148
+ main()
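
For callers who prefer the Python API over the CLI, here is a hypothetical programmatic equivalent of `python3 -m kokoro -t "..." -o file.wav` using the helpers defined above; the output path, text, and voice are illustrative only:

from pathlib import Path

from kokoro.__main__ import generate_and_save_audio

generate_and_save_audio(
    output_file=Path("hello.wav"),
    text="Hello from Kokoro.",
    kokoro_language="a",  # American English, per the languages list above
    voice="af_heart",     # default voice used by the CLI
    speed=1.0,
)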
kokoro/custom_stft.py ADDED
@@ -0,0 +1,197 @@
from attr import attr
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomSTFT(nn.Module):
    """
    STFT/iSTFT without unfold/complex ops, using conv1d and conv_transpose1d.

    - forward STFT => Real-part conv1d + Imag-part conv1d
    - inverse STFT => Real-part conv_transpose1d + Imag-part conv_transpose1d + sum
    - avoids F.unfold, so easier to export to ONNX
    - uses replicate or constant padding for 'center=True' to approximate 'reflect'
      (reflect is not supported for dynamic shapes in ONNX)
    """

    def __init__(
        self,
        filter_length=800,
        hop_length=200,
        win_length=800,
        window="hann",
        center=True,
        pad_mode="replicate",  # or 'constant'
    ):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.n_fft = filter_length
        self.center = center
        self.pad_mode = pad_mode

        # Number of frequency bins for real-valued STFT with onesided=True
        self.freq_bins = self.n_fft // 2 + 1

        # Build window
        assert window == 'hann', window
        window_tensor = torch.hann_window(win_length, periodic=True, dtype=torch.float32)
        if self.win_length < self.n_fft:
            # Zero-pad up to n_fft
            extra = self.n_fft - self.win_length
            window_tensor = F.pad(window_tensor, (0, extra))
        elif self.win_length > self.n_fft:
            window_tensor = window_tensor[: self.n_fft]
        self.register_buffer("window", window_tensor)

        # Precompute forward DFT (real, imag)
        # PyTorch stft uses e^{-j 2 pi k n / N} => real=cos(...), imag=-sin(...)
        n = np.arange(self.n_fft)
        k = np.arange(self.freq_bins)
        angle = 2 * np.pi * np.outer(k, n) / self.n_fft  # shape (freq_bins, n_fft)
        dft_real = np.cos(angle)
        dft_imag = -np.sin(angle)  # note negative sign

        # Combine window and dft => shape (freq_bins, filter_length)
        # We'll make 2 conv weight tensors of shape (freq_bins, 1, filter_length).
        forward_window = window_tensor.numpy()  # shape (n_fft,)
        forward_real = dft_real * forward_window  # (freq_bins, n_fft)
        forward_imag = dft_imag * forward_window

        # Convert to PyTorch
        forward_real_torch = torch.from_numpy(forward_real).float()
        forward_imag_torch = torch.from_numpy(forward_imag).float()

        # Register as Conv1d weight => (out_channels, in_channels, kernel_size)
        # out_channels = freq_bins, in_channels=1, kernel_size=n_fft
        self.register_buffer(
            "weight_forward_real", forward_real_torch.unsqueeze(1)
        )
        self.register_buffer(
            "weight_forward_imag", forward_imag_torch.unsqueeze(1)
        )

        # Precompute inverse DFT
        # Real iFFT formula => scale = 1/n_fft, doubling for bins 1..freq_bins-2 if n_fft even, etc.
        # For simplicity, we won't do the "DC/nyquist not doubled" approach here.
        # If you want perfect real iSTFT, you can add that logic.
        # This version just yields good approximate reconstruction with Hann + typical overlap.
        inv_scale = 1.0 / self.n_fft
        n = np.arange(self.n_fft)
        angle_t = 2 * np.pi * np.outer(n, k) / self.n_fft  # shape (n_fft, freq_bins)
        idft_cos = np.cos(angle_t).T  # => (freq_bins, n_fft)
        idft_sin = np.sin(angle_t).T  # => (freq_bins, n_fft)

        # Multiply by window again for typical overlap-add
        # We also incorporate the scale factor 1/n_fft
        inv_window = window_tensor.numpy() * inv_scale
        backward_real = idft_cos * inv_window  # (freq_bins, n_fft)
        backward_imag = idft_sin * inv_window

        # We'll implement iSTFT as real+imag conv_transpose with stride=hop.
        self.register_buffer(
            "weight_backward_real", torch.from_numpy(backward_real).float().unsqueeze(1)
        )
        self.register_buffer(
            "weight_backward_imag", torch.from_numpy(backward_imag).float().unsqueeze(1)
        )

    def transform(self, waveform: torch.Tensor):
        """
        Forward STFT => returns magnitude, phase
        Output shape => (batch, freq_bins, frames)
        """
        # waveform shape => (B, T). conv1d expects (B, 1, T).
        # Optional center pad
        if self.center:
            pad_len = self.n_fft // 2
            waveform = F.pad(waveform, (pad_len, pad_len), mode=self.pad_mode)

        x = waveform.unsqueeze(1)  # => (B, 1, T)
        # Convolution to get real part => shape (B, freq_bins, frames)
        real_out = F.conv1d(
            x,
            self.weight_forward_real,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # Imag part
        imag_out = F.conv1d(
            x,
            self.weight_forward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )

        # magnitude, phase
        magnitude = torch.sqrt(real_out**2 + imag_out**2 + 1e-14)
        phase = torch.atan2(imag_out, real_out)
        # Handle the case where imag_out is 0 and real_out is negative to correct ONNX atan2 to match PyTorch
        # In this case, PyTorch returns pi, ONNX returns -pi
        correction_mask = (imag_out == 0) & (real_out < 0)
        phase[correction_mask] = torch.pi
        return magnitude, phase

    def inverse(self, magnitude: torch.Tensor, phase: torch.Tensor, length=None):
        """
        Inverse STFT => returns waveform shape (B, T).
        """
        # magnitude, phase => (B, freq_bins, frames)
        # Re-create real/imag => shape (B, freq_bins, frames)
        real_part = magnitude * torch.cos(phase)
        imag_part = magnitude * torch.sin(phase)

        # conv_transpose wants shape (B, freq_bins, frames). We'll treat "frames" as time dimension
        # so we do (B, freq_bins, frames) => (B, freq_bins, frames)
        # But PyTorch conv_transpose1d expects (B, in_channels, input_length)
        real_part = real_part  # (B, freq_bins, frames)
        imag_part = imag_part

        # real iSTFT => convolve with "backward_real", "backward_imag", and sum
        # We'll do 2 conv_transpose calls, each giving (B, 1, time),
        # then add them => (B, 1, time).
        real_rec = F.conv_transpose1d(
            real_part,
            self.weight_backward_real,  # shape (freq_bins, 1, filter_length)
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        imag_rec = F.conv_transpose1d(
            imag_part,
            self.weight_backward_imag,
            bias=None,
            stride=self.hop_length,
            padding=0,
        )
        # sum => (B, 1, time)
        waveform = real_rec - imag_rec  # typical real iFFT has minus for imaginary part

        # If we used "center=True" in forward, we should remove pad
        if self.center:
            pad_len = self.n_fft // 2
            # Because of transposed convolution, total length might have extra samples
            # We remove `pad_len` from start & end if possible
            waveform = waveform[..., pad_len:-pad_len]

        # If a specific length is desired, clamp
        if length is not None:
            waveform = waveform[..., :length]

        # shape => (B, T)
        return waveform

    def forward(self, x: torch.Tensor):
        """
        Full STFT -> iSTFT pass: returns time-domain reconstruction.
        Same interface as your original code.
        """
        mag, phase = self.transform(x)
        return self.inverse(mag, phase, length=x.shape[-1])
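
A quick sanity sketch (not part of the commit) that round-trips a random signal through CustomSTFT; the class only promises an approximate reconstruction, so a small but non-zero error is expected:

import torch

from kokoro.custom_stft import CustomSTFT

stft = CustomSTFT(filter_length=800, hop_length=200, win_length=800)
x = torch.randn(1, 24000)                         # (B, T) waveform at the model's 24 kHz rate
mag, phase = stft.transform(x)                    # each (B, freq_bins=401, frames)
y = stft.inverse(mag, phase, length=x.shape[-1])  # (B, 1, T)
print((x - y.squeeze(1)).abs().mean())            # approximate reconstruction error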
kokoro/istftnet.py ADDED
@@ -0,0 +1,421 @@
# ADAPTED from https://github.com/yl4579/StyleTTS2/blob/main/Modules/istftnet.py
from kokoro.custom_stft import CustomSTFT
from torch.nn.utils import weight_norm
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


# https://github.com/yl4579/StyleTTS2/blob/main/Modules/utils.py
def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size*dilation - dilation)/2)


class AdaIN1d(nn.Module):
    def __init__(self, style_dim, num_features):
        super().__init__()
        # affine should be False; however, a bug in the old torch.onnx.export (not the newer dynamo exporter)
        # causes the channel dimension to be lost if affine=False. With affine=True there are additional
        # learnable parameters, but setting it to True shouldn't really matter since we're in inference mode.
        self.norm = nn.InstanceNorm1d(num_features, affine=True)
        self.fc = nn.Linear(style_dim, num_features*2)

    def forward(self, x, s):
        h = self.fc(s)
        h = h.view(h.size(0), h.size(1), 1)
        gamma, beta = torch.chunk(h, chunks=2, dim=1)
        return (1 + gamma) * self.norm(x) + beta


class AdaINResBlock1(nn.Module):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), style_dim=64):
        super(AdaINResBlock1, self).__init__()
        self.convs1 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[0],
                                  padding=get_padding(kernel_size, dilation[0]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[1],
                                  padding=get_padding(kernel_size, dilation[1]))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation[2],
                                  padding=get_padding(kernel_size, dilation[2])))
        ])
        self.convs1.apply(init_weights)
        self.convs2 = nn.ModuleList([
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1))),
            weight_norm(nn.Conv1d(channels, channels, kernel_size, 1, dilation=1,
                                  padding=get_padding(kernel_size, 1)))
        ])
        self.convs2.apply(init_weights)
        self.adain1 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.adain2 = nn.ModuleList([
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
            AdaIN1d(style_dim, channels),
        ])
        self.alpha1 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs1))])
        self.alpha2 = nn.ParameterList([nn.Parameter(torch.ones(1, channels, 1)) for i in range(len(self.convs2))])

    def forward(self, x, s):
        for c1, c2, n1, n2, a1, a2 in zip(self.convs1, self.convs2, self.adain1, self.adain2, self.alpha1, self.alpha2):
            xt = n1(x, s)
            xt = xt + (1 / a1) * (torch.sin(a1 * xt) ** 2)  # Snake1D
            xt = c1(xt)
            xt = n2(xt, s)
            xt = xt + (1 / a2) * (torch.sin(a2 * xt) ** 2)  # Snake1D
            xt = c2(xt)
            x = xt + x
        return x


class TorchSTFT(nn.Module):
    def __init__(self, filter_length=800, hop_length=200, win_length=800, window='hann'):
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        assert window == 'hann', window
        self.window = torch.hann_window(win_length, periodic=True, dtype=torch.float32)

    def transform(self, input_data):
        forward_transform = torch.stft(
            input_data,
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(input_data.device),
            return_complex=True)
        return torch.abs(forward_transform), torch.angle(forward_transform)

    def inverse(self, magnitude, phase):
        inverse_transform = torch.istft(
            magnitude * torch.exp(phase * 1j),
            self.filter_length, self.hop_length, self.win_length, window=self.window.to(magnitude.device))
        return inverse_transform.unsqueeze(-2)  # unsqueeze to stay consistent with conv_transpose1d implementation

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class SineGen(nn.Module):
    """ Definition of sine generator
    SineGen(samp_rate, harmonic_num = 0,
            sine_amp = 0.1, noise_std = 0.003,
            voiced_threshold = 0,
            flag_for_pulse=False)
    samp_rate: sampling rate in Hz
    harmonic_num: number of harmonic overtones (default 0)
    sine_amp: amplitude of sine waveform (default 0.1)
    noise_std: std of Gaussian noise (default 0.003)
    voiced_threshold: F0 threshold for U/V classification (default 0)
    flag_for_pulse: this SineGen is used inside PulseGen (default False)
    Note: when flag_for_pulse is True, the first time step of a voiced
        segment is always sin(torch.pi) or cos(0)
    """
    def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
                 sine_amp=0.1, noise_std=0.003,
                 voiced_threshold=0,
                 flag_for_pulse=False):
        super(SineGen, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = noise_std
        self.harmonic_num = harmonic_num
        self.dim = self.harmonic_num + 1
        self.sampling_rate = samp_rate
        self.voiced_threshold = voiced_threshold
        self.flag_for_pulse = flag_for_pulse
        self.upsample_scale = upsample_scale

    def _f02uv(self, f0):
        # generate uv signal
        uv = (f0 > self.voiced_threshold).type(torch.float32)
        return uv

    def _f02sine(self, f0_values):
        """ f0_values: (batchsize, length, dim)
            where dim indicates fundamental tone and overtones
        """
        # convert to F0 in rad. The integer part n can be ignored
        # because 2 * torch.pi * n doesn't affect phase
        rad_values = (f0_values / self.sampling_rate) % 1
        # initial phase noise (no noise for fundamental component)
        rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
        rand_ini[:, 0] = 0
        rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
        # instantaneous phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
        if not self.flag_for_pulse:
            rad_values = F.interpolate(rad_values.transpose(1, 2), scale_factor=1/self.upsample_scale, mode="linear").transpose(1, 2)
            phase = torch.cumsum(rad_values, dim=1) * 2 * torch.pi
            phase = F.interpolate(phase.transpose(1, 2) * self.upsample_scale, scale_factor=self.upsample_scale, mode="linear").transpose(1, 2)
            sines = torch.sin(phase)
        else:
            # If necessary, make sure that the first time step of every
            # voiced segment is sin(pi) or cos(0)
            # This is used for pulse-train generation
            # identify the last time step in unvoiced segments
            uv = self._f02uv(f0_values)
            uv_1 = torch.roll(uv, shifts=-1, dims=1)
            uv_1[:, -1, :] = 1
            u_loc = (uv < 1) * (uv_1 > 0)
            # get the instantaneous phase
            tmp_cumsum = torch.cumsum(rad_values, dim=1)
            # different batch needs to be processed differently
            for idx in range(f0_values.shape[0]):
                temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
                temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
                # stores the accumulation of i.phase within
                # each voiced segment
                tmp_cumsum[idx, :, :] = 0
                tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
            # rad_values - tmp_cumsum: remove the accumulation of i.phase
            # within the previous voiced segment.
            i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
            # get the sines
            sines = torch.cos(i_phase * 2 * torch.pi)
        return sines

    def forward(self, f0):
        """ sine_tensor, uv = forward(f0)
        input F0: tensor(batchsize=1, length, dim=1)
                  f0 for unvoiced steps should be 0
        output sine_tensor: tensor(batchsize=1, length, dim)
        output uv: tensor(batchsize=1, length, 1)
        """
        f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
        # fundamental component
        fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
        # generate sine waveforms
        sine_waves = self._f02sine(fn) * self.sine_amp
        # generate uv signal
        # uv = torch.ones(f0.shape)
        # uv = uv * (f0 > self.voiced_threshold)
        uv = self._f02uv(f0)
        # noise: for unvoiced should be similar to sine_amp
        # std = self.sine_amp/3 -> max value ~ self.sine_amp
        # for voiced regions is self.noise_std
        noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
        noise = noise_amp * torch.randn_like(sine_waves)
        # first: set the unvoiced part to 0 by uv
        # then: additive noise
        sine_waves = sine_waves * uv + noise
        return sine_waves, uv, noise


class SourceModuleHnNSF(nn.Module):
    """ SourceModule for hn-nsf
    SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0)
    sampling_rate: sampling_rate in Hz
    harmonic_num: number of harmonics above F0 (default: 0)
    sine_amp: amplitude of sine source signal (default: 0.1)
    add_noise_std: std of additive Gaussian noise (default: 0.003)
        note that amplitude of noise in unvoiced is decided
        by sine_amp
    voiced_threshold: threshold to set U/V given F0 (default: 0)
    Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
    F0_sampled (batchsize, length, 1)
    Sine_source (batchsize, length, 1)
    noise_source (batchsize, length 1)
    uv (batchsize, length, 1)
    """
    def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
                 add_noise_std=0.003, voiced_threshod=0):
        super(SourceModuleHnNSF, self).__init__()
        self.sine_amp = sine_amp
        self.noise_std = add_noise_std
        # to produce sine waveforms
        self.l_sin_gen = SineGen(sampling_rate, upsample_scale, harmonic_num,
                                 sine_amp, add_noise_std, voiced_threshod)
        # to merge source harmonics into a single excitation
        self.l_linear = nn.Linear(harmonic_num + 1, 1)
        self.l_tanh = nn.Tanh()

    def forward(self, x):
        """
        Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
        F0_sampled (batchsize, length, 1)
        Sine_source (batchsize, length, 1)
        noise_source (batchsize, length 1)
        """
        # source for harmonic branch
        with torch.no_grad():
            sine_wavs, uv, _ = self.l_sin_gen(x)
        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
        # source for noise branch, in the same shape as uv
        noise = torch.randn_like(uv) * self.sine_amp / 3
        return sine_merge, noise, uv


class Generator(nn.Module):
    def __init__(self, style_dim, resblock_kernel_sizes, upsample_rates, upsample_initial_channel, resblock_dilation_sizes, upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=False):
        super(Generator, self).__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.m_source = SourceModuleHnNSF(
            sampling_rate=24000,
            upsample_scale=math.prod(upsample_rates) * gen_istft_hop_size,
            harmonic_num=8, voiced_threshod=10)
        self.f0_upsamp = nn.Upsample(scale_factor=math.prod(upsample_rates) * gen_istft_hop_size)
        self.noise_convs = nn.ModuleList()
        self.noise_res = nn.ModuleList()
        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            self.ups.append(weight_norm(
                nn.ConvTranspose1d(upsample_initial_channel//(2**i), upsample_initial_channel//(2**(i+1)),
                                   k, u, padding=(k-u)//2)))
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel//(2**(i+1))
            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                self.resblocks.append(AdaINResBlock1(ch, k, d, style_dim))
            c_cur = upsample_initial_channel // (2 ** (i + 1))
            if i + 1 < len(upsample_rates):
                stride_f0 = math.prod(upsample_rates[i + 1:])
                self.noise_convs.append(nn.Conv1d(
                    gen_istft_n_fft + 2, c_cur, kernel_size=stride_f0 * 2, stride=stride_f0, padding=(stride_f0+1) // 2))
                self.noise_res.append(AdaINResBlock1(c_cur, 7, [1,3,5], style_dim))
            else:
                self.noise_convs.append(nn.Conv1d(gen_istft_n_fft + 2, c_cur, kernel_size=1))
                self.noise_res.append(AdaINResBlock1(c_cur, 11, [1,3,5], style_dim))
        self.post_n_fft = gen_istft_n_fft
        self.conv_post = weight_norm(nn.Conv1d(ch, self.post_n_fft + 2, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)
        self.reflection_pad = nn.ReflectionPad1d((1, 0))
        self.stft = (
            CustomSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
            if disable_complex
            else TorchSTFT(filter_length=gen_istft_n_fft, hop_length=gen_istft_hop_size, win_length=gen_istft_n_fft)
        )

    def forward(self, x, s, f0):
        with torch.no_grad():
            f0 = self.f0_upsamp(f0[:, None]).transpose(1, 2)  # bs,n,t
            har_source, noi_source, uv = self.m_source(f0)
            har_source = har_source.transpose(1, 2).squeeze(1)
            har_spec, har_phase = self.stft.transform(har_source)
            har = torch.cat([har_spec, har_phase], dim=1)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, negative_slope=0.1)
            x_source = self.noise_convs[i](har)
            x_source = self.noise_res[i](x_source, s)
            x = self.ups[i](x)
            if i == self.num_upsamples - 1:
                x = self.reflection_pad(x)
            x = x + x_source
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i*self.num_kernels+j](x, s)
                else:
                    xs += self.resblocks[i*self.num_kernels+j](x, s)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        spec = torch.exp(x[:,:self.post_n_fft // 2 + 1, :])
        phase = torch.sin(x[:, self.post_n_fft // 2 + 1:, :])
        return self.stft.inverse(spec, phase)


class UpSample1d(nn.Module):
    def __init__(self, layer_type):
        super().__init__()
        self.layer_type = layer_type

    def forward(self, x):
        if self.layer_type == 'none':
            return x
        else:
            return F.interpolate(x, scale_factor=2, mode='nearest')


class AdainResBlk1d(nn.Module):
    def __init__(self, dim_in, dim_out, style_dim=64, actv=nn.LeakyReLU(0.2), upsample='none', dropout_p=0.0):
        super().__init__()
        self.actv = actv
        self.upsample_type = upsample
        self.upsample = UpSample1d(upsample)
        self.learned_sc = dim_in != dim_out
        self._build_weights(dim_in, dim_out, style_dim)
        self.dropout = nn.Dropout(dropout_p)
        if upsample == 'none':
            self.pool = nn.Identity()
        else:
            self.pool = weight_norm(nn.ConvTranspose1d(dim_in, dim_in, kernel_size=3, stride=2, groups=dim_in, padding=1, output_padding=1))

    def _build_weights(self, dim_in, dim_out, style_dim):
        self.conv1 = weight_norm(nn.Conv1d(dim_in, dim_out, 3, 1, 1))
        self.conv2 = weight_norm(nn.Conv1d(dim_out, dim_out, 3, 1, 1))
        self.norm1 = AdaIN1d(style_dim, dim_in)
        self.norm2 = AdaIN1d(style_dim, dim_out)
        if self.learned_sc:
            self.conv1x1 = weight_norm(nn.Conv1d(dim_in, dim_out, 1, 1, 0, bias=False))

    def _shortcut(self, x):
        x = self.upsample(x)
        if self.learned_sc:
            x = self.conv1x1(x)
        return x

    def _residual(self, x, s):
        x = self.norm1(x, s)
        x = self.actv(x)
        x = self.pool(x)
        x = self.conv1(self.dropout(x))
        x = self.norm2(x, s)
        x = self.actv(x)
        x = self.conv2(self.dropout(x))
        return x

    def forward(self, x, s):
        out = self._residual(x, s)
        out = (out + self._shortcut(x)) * torch.rsqrt(torch.tensor(2))
        return out


class Decoder(nn.Module):
    def __init__(self, dim_in, style_dim, dim_out,
                 resblock_kernel_sizes,
                 upsample_rates,
                 upsample_initial_channel,
                 resblock_dilation_sizes,
                 upsample_kernel_sizes,
                 gen_istft_n_fft, gen_istft_hop_size,
                 disable_complex=False):
        super().__init__()
        self.encode = AdainResBlk1d(dim_in + 2, 1024, style_dim)
        self.decode = nn.ModuleList()
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 1024, style_dim))
        self.decode.append(AdainResBlk1d(1024 + 2 + 64, 512, style_dim, upsample=True))
        self.F0_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.N_conv = weight_norm(nn.Conv1d(1, 1, kernel_size=3, stride=2, groups=1, padding=1))
        self.asr_res = nn.Sequential(weight_norm(nn.Conv1d(512, 64, kernel_size=1)))
        self.generator = Generator(style_dim, resblock_kernel_sizes, upsample_rates,
                                   upsample_initial_channel, resblock_dilation_sizes,
                                   upsample_kernel_sizes, gen_istft_n_fft, gen_istft_hop_size, disable_complex=disable_complex)

    def forward(self, asr, F0_curve, N, s):
        F0 = self.F0_conv(F0_curve.unsqueeze(1))
        N = self.N_conv(N.unsqueeze(1))
        x = torch.cat([asr, F0, N], axis=1)
        x = self.encode(x, s)
        asr_res = self.asr_res(asr)
        res = True
        for block in self.decode:
            if res:
                x = torch.cat([x, asr_res, F0, N], axis=1)
            x = block(x, s)
            if block.upsample_type != "none":
                res = False
        x = self.generator(x, s, F0_curve)
        return x
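
An illustrative-only sketch of the style conditioning used throughout this file: AdaIN1d maps a style vector to per-channel (gamma, beta) and modulates an instance-normalized feature map. The shapes below are made up for the example and are not the shipped configuration:

import torch

from kokoro.istftnet import AdaIN1d

style_dim, channels, frames = 128, 64, 50
adain = AdaIN1d(style_dim, channels)
feats = torch.randn(2, channels, frames)  # (B, C, T) feature map
style = torch.randn(2, style_dim)         # (B, style_dim) style vector
print(adain(feats, style).shape)          # torch.Size([2, 64, 50])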
kokoro/model.py ADDED
@@ -0,0 +1,155 @@
from .istftnet import Decoder
from .modules import CustomAlbert, ProsodyPredictor, TextEncoder
from dataclasses import dataclass
from huggingface_hub import hf_hub_download
from loguru import logger
from transformers import AlbertConfig
from typing import Dict, Optional, Union
import json
import torch
import os

class KModel(torch.nn.Module):
    '''
    KModel is a torch.nn.Module with 2 main responsibilities:
    1. Init weights, downloading config.json + model.pth from HF if needed
    2. forward(phonemes: str, ref_s: FloatTensor) -> (audio: FloatTensor)

    You likely only need one KModel instance, and it can be reused across
    multiple KPipelines to avoid redundant memory allocation.

    Unlike KPipeline, KModel is language-blind.

    KModel stores self.vocab and thus knows how to map phonemes -> input_ids,
    so there is no need to repeatedly download config.json outside of KModel.
    '''

    MODEL_NAMES = {
        'hexgrad/Kokoro-82M': 'kokoro-v1_0.pth',
        'hexgrad/Kokoro-82M-v1.1-zh': 'kokoro-v1_1-zh.pth',
    }

    def __init__(
        self,
        repo_id: Optional[str] = None,
        config: Union[Dict, str, None] = None,
        model: Optional[str] = None,
        disable_complex: bool = False
    ):
        super().__init__()
        if repo_id is None:
            repo_id = 'hexgrad/Kokoro-82M'
            print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
        self.repo_id = repo_id
        if not isinstance(config, dict):
            if not config:
                logger.debug("No config provided, downloading from HF")
                config = hf_hub_download(repo_id=repo_id, filename='config.json')
            with open(config, 'r', encoding='utf-8') as r:
                config = json.load(r)
                logger.debug(f"Loaded config: {config}")
        self.vocab = config['vocab']
        self.bert = CustomAlbert(AlbertConfig(vocab_size=config['n_token'], **config['plbert']))
        self.bert_encoder = torch.nn.Linear(self.bert.config.hidden_size, config['hidden_dim'])
        self.context_length = self.bert.config.max_position_embeddings
        self.predictor = ProsodyPredictor(
            style_dim=config['style_dim'], d_hid=config['hidden_dim'],
            nlayers=config['n_layer'], max_dur=config['max_dur'], dropout=config['dropout']
        )
        self.text_encoder = TextEncoder(
            channels=config['hidden_dim'], kernel_size=config['text_encoder_kernel_size'],
            depth=config['n_layer'], n_symbols=config['n_token']
        )
        self.decoder = Decoder(
            dim_in=config['hidden_dim'], style_dim=config['style_dim'],
            dim_out=config['n_mels'], disable_complex=disable_complex, **config['istftnet']
        )
        if not model:
            try:
                model = hf_hub_download(repo_id=repo_id, filename=KModel.MODEL_NAMES[repo_id])
            except:
                model = os.path.join(repo_id, 'kokoro-v1_0.pth')
        for key, state_dict in torch.load(model, map_location='cpu', weights_only=True).items():
            assert hasattr(self, key), key
            try:
                getattr(self, key).load_state_dict(state_dict)
            except:
                logger.debug(f"Did not load {key} from state_dict")
                state_dict = {k[7:]: v for k, v in state_dict.items()}
                getattr(self, key).load_state_dict(state_dict, strict=False)

    @property
    def device(self):
        return self.bert.device

    @dataclass
    class Output:
        audio: torch.FloatTensor
        pred_dur: Optional[torch.LongTensor] = None

    @torch.no_grad()
    def forward_with_tokens(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        input_lengths = torch.full(
            (input_ids.shape[0],),
            input_ids.shape[-1],
            device=input_ids.device,
            dtype=torch.long
        )

        text_mask = torch.arange(input_lengths.max()).unsqueeze(0).expand(input_lengths.shape[0], -1).type_as(input_lengths)
        text_mask = torch.gt(text_mask+1, input_lengths.unsqueeze(1)).to(self.device)
        bert_dur = self.bert(input_ids, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)
        s = ref_s[:, 128:]
        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1) / speed
        pred_dur = torch.round(duration).clamp(min=1).long().squeeze()
        indices = torch.repeat_interleave(torch.arange(input_ids.shape[1], device=self.device), pred_dur)
        pred_aln_trg = torch.zeros((input_ids.shape[1], indices.shape[0]), device=self.device)
        pred_aln_trg[indices, torch.arange(indices.shape[0])] = 1
        pred_aln_trg = pred_aln_trg.unsqueeze(0).to(self.device)
        en = d.transpose(-1, -2) @ pred_aln_trg
        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)
        t_en = self.text_encoder(input_ids, input_lengths, text_mask)
        asr = t_en @ pred_aln_trg
        audio = self.decoder(asr, F0_pred, N_pred, ref_s[:, :128]).squeeze()
        return audio, pred_dur

    def forward(
        self,
        phonemes: str,
        ref_s: torch.FloatTensor,
        speed: float = 1,
        return_output: bool = False
    ) -> Union['KModel.Output', torch.FloatTensor]:
        input_ids = list(filter(lambda i: i is not None, map(lambda p: self.vocab.get(p), phonemes)))
        logger.debug(f"phonemes: {phonemes} -> input_ids: {input_ids}")
        assert len(input_ids)+2 <= self.context_length, (len(input_ids)+2, self.context_length)
        input_ids = torch.LongTensor([[0, *input_ids, 0]]).to(self.device)
        ref_s = ref_s.to(self.device)
        audio, pred_dur = self.forward_with_tokens(input_ids, ref_s, speed)
        audio = audio.squeeze().cpu()
        pred_dur = pred_dur.cpu() if pred_dur is not None else None
        logger.debug(f"pred_dur: {pred_dur}")
        return self.Output(audio=audio, pred_dur=pred_dur) if return_output else audio

class KModelForONNX(torch.nn.Module):
    def __init__(self, kmodel: KModel):
        super().__init__()
        self.kmodel = kmodel

    def forward(
        self,
        input_ids: torch.LongTensor,
        ref_s: torch.FloatTensor,
        speed: float = 1
    ) -> tuple[torch.FloatTensor, torch.LongTensor]:
        waveform, duration = self.kmodel.forward_with_tokens(input_ids, ref_s, speed)
        return waveform, duration
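
A hedged usage sketch for KModel, assuming access to the 'hexgrad/Kokoro-82M' repo; the phoneme string is illustrative and the random ref_s stands in for a real 256-dim voice style vector (the first 128 dims feed the decoder and the last 128 the prosody predictor, per forward_with_tokens above):

import torch

from kokoro.model import KModel

model = KModel(repo_id='hexgrad/Kokoro-82M').eval()
phonemes = "həlˈoʊ"          # illustrative phoneme string; KPipeline normally produces this
ref_s = torch.randn(1, 256)  # placeholder style; real voice tensors ship with the model repo
audio = model(phonemes, ref_s, speed=1.0)
print(audio.shape)           # 1-D waveform tensor at 24 kHz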
kokoro/modules.py ADDED
@@ -0,0 +1,183 @@
1
+ # https://github.com/yl4579/StyleTTS2/blob/main/models.py
2
+ from .istftnet import AdainResBlk1d
3
+ from torch.nn.utils import weight_norm
4
+ from transformers import AlbertModel
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class LinearNorm(nn.Module):
12
+ def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
13
+ super(LinearNorm, self).__init__()
14
+ self.linear_layer = nn.Linear(in_dim, out_dim, bias=bias)
15
+ nn.init.xavier_uniform_(self.linear_layer.weight, gain=nn.init.calculate_gain(w_init_gain))
16
+
17
+ def forward(self, x):
18
+ return self.linear_layer(x)
19
+
20
+
21
+ class LayerNorm(nn.Module):
22
+ def __init__(self, channels, eps=1e-5):
23
+ super().__init__()
24
+ self.channels = channels
25
+ self.eps = eps
26
+ self.gamma = nn.Parameter(torch.ones(channels))
27
+ self.beta = nn.Parameter(torch.zeros(channels))
28
+
29
+ def forward(self, x):
30
+ x = x.transpose(1, -1)
31
+ x = F.layer_norm(x, (self.channels,), self.gamma, self.beta, self.eps)
32
+ return x.transpose(1, -1)
33
+
34
+
35
+ class TextEncoder(nn.Module):
36
+ def __init__(self, channels, kernel_size, depth, n_symbols, actv=nn.LeakyReLU(0.2)):
37
+ super().__init__()
38
+ self.embedding = nn.Embedding(n_symbols, channels)
39
+ padding = (kernel_size - 1) // 2
40
+ self.cnn = nn.ModuleList()
41
+ for _ in range(depth):
42
+ self.cnn.append(nn.Sequential(
43
+ weight_norm(nn.Conv1d(channels, channels, kernel_size=kernel_size, padding=padding)),
44
+ LayerNorm(channels),
45
+ actv,
46
+ nn.Dropout(0.2),
47
+ ))
48
+ self.lstm = nn.LSTM(channels, channels//2, 1, batch_first=True, bidirectional=True)
49
+
50
+ def forward(self, x, input_lengths, m):
51
+ x = self.embedding(x) # [B, T, emb]
52
+ x = x.transpose(1, 2) # [B, emb, T]
53
+ m = m.unsqueeze(1)
54
+ x.masked_fill_(m, 0.0)
55
+ for c in self.cnn:
56
+ x = c(x)
57
+ x.masked_fill_(m, 0.0)
58
+ x = x.transpose(1, 2) # [B, T, chn]
59
+ lengths = input_lengths if input_lengths.device == torch.device('cpu') else input_lengths.to('cpu')
60
+ x = nn.utils.rnn.pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
61
+ self.lstm.flatten_parameters()
62
+ x, _ = self.lstm(x)
63
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
64
+ x = x.transpose(-1, -2)
65
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
66
+ x_pad[:, :, :x.shape[-1]] = x
67
+ x = x_pad
68
+ x.masked_fill_(m, 0.0)
69
+ return x
70
+
71
+
72
+ class AdaLayerNorm(nn.Module):
73
+ def __init__(self, style_dim, channels, eps=1e-5):
74
+ super().__init__()
75
+ self.channels = channels
76
+ self.eps = eps
77
+ self.fc = nn.Linear(style_dim, channels*2)
78
+
79
+ def forward(self, x, s):
80
+ x = x.transpose(-1, -2)
81
+ x = x.transpose(1, -1)
82
+ h = self.fc(s)
83
+ h = h.view(h.size(0), h.size(1), 1)
84
+ gamma, beta = torch.chunk(h, chunks=2, dim=1)
85
+ gamma, beta = gamma.transpose(1, -1), beta.transpose(1, -1)
86
+ x = F.layer_norm(x, (self.channels,), eps=self.eps)
87
+ x = (1 + gamma) * x + beta
88
+ return x.transpose(1, -1).transpose(-1, -2)
89
+
90
+
91
+ class ProsodyPredictor(nn.Module):
92
+ def __init__(self, style_dim, d_hid, nlayers, max_dur=50, dropout=0.1):
93
+ super().__init__()
94
+ self.text_encoder = DurationEncoder(sty_dim=style_dim, d_model=d_hid,nlayers=nlayers, dropout=dropout)
95
+ self.lstm = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
96
+ self.duration_proj = LinearNorm(d_hid, max_dur)
97
+ self.shared = nn.LSTM(d_hid + style_dim, d_hid // 2, 1, batch_first=True, bidirectional=True)
98
+ self.F0 = nn.ModuleList()
99
+ self.F0.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
100
+ self.F0.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
101
+ self.F0.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
102
+ self.N = nn.ModuleList()
103
+ self.N.append(AdainResBlk1d(d_hid, d_hid, style_dim, dropout_p=dropout))
104
+ self.N.append(AdainResBlk1d(d_hid, d_hid // 2, style_dim, upsample=True, dropout_p=dropout))
105
+ self.N.append(AdainResBlk1d(d_hid // 2, d_hid // 2, style_dim, dropout_p=dropout))
106
+ self.F0_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
107
+ self.N_proj = nn.Conv1d(d_hid // 2, 1, 1, 1, 0)
108
+
109
+ def forward(self, texts, style, text_lengths, alignment, m):
110
+ d = self.text_encoder(texts, style, text_lengths, m)
111
+ m = m.unsqueeze(1)
112
+ lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
113
+ x = nn.utils.rnn.pack_padded_sequence(d, lengths, batch_first=True, enforce_sorted=False)
114
+ self.lstm.flatten_parameters()
115
+ x, _ = self.lstm(x)
116
+ x, _ = nn.utils.rnn.pad_packed_sequence(x, batch_first=True)
117
+ x_pad = torch.zeros([x.shape[0], m.shape[-1], x.shape[-1]], device=x.device)
118
+ x_pad[:, :x.shape[1], :] = x
119
+ x = x_pad
120
+ duration = self.duration_proj(nn.functional.dropout(x, 0.5, training=False))
121
+ en = (d.transpose(-1, -2) @ alignment)
122
+ return duration.squeeze(-1), en
123
+
124
+ def F0Ntrain(self, x, s):
125
+ x, _ = self.shared(x.transpose(-1, -2))
126
+ F0 = x.transpose(-1, -2)
127
+ for block in self.F0:
128
+ F0 = block(F0, s)
129
+ F0 = self.F0_proj(F0)
130
+ N = x.transpose(-1, -2)
131
+ for block in self.N:
132
+ N = block(N, s)
133
+ N = self.N_proj(N)
134
+ return F0.squeeze(1), N.squeeze(1)
135
+
136
+
137
+ class DurationEncoder(nn.Module):
138
+ def __init__(self, sty_dim, d_model, nlayers, dropout=0.1):
139
+ super().__init__()
140
+ self.lstms = nn.ModuleList()
141
+ for _ in range(nlayers):
142
+ self.lstms.append(nn.LSTM(d_model + sty_dim, d_model // 2, num_layers=1, batch_first=True, bidirectional=True, dropout=dropout))
143
+ self.lstms.append(AdaLayerNorm(sty_dim, d_model))
144
+ self.dropout = dropout
145
+ self.d_model = d_model
146
+ self.sty_dim = sty_dim
147
+
148
+ def forward(self, x, style, text_lengths, m):
149
+ masks = m
150
+ x = x.permute(2, 0, 1)
151
+ s = style.expand(x.shape[0], x.shape[1], -1)
152
+ x = torch.cat([x, s], axis=-1)
153
+ x.masked_fill_(masks.unsqueeze(-1).transpose(0, 1), 0.0)
154
+ x = x.transpose(0, 1)
155
+ x = x.transpose(-1, -2)
156
+ for block in self.lstms:
157
+ if isinstance(block, AdaLayerNorm):
158
+ x = block(x.transpose(-1, -2), style).transpose(-1, -2)
159
+ x = torch.cat([x, s.permute(1, 2, 0)], axis=1)
160
+ x.masked_fill_(masks.unsqueeze(-1).transpose(-1, -2), 0.0)
161
+ else:
162
+ lengths = text_lengths if text_lengths.device == torch.device('cpu') else text_lengths.to('cpu')
163
+ x = x.transpose(-1, -2)
164
+ x = nn.utils.rnn.pack_padded_sequence(
165
+ x, lengths, batch_first=True, enforce_sorted=False)
166
+ block.flatten_parameters()
167
+ x, _ = block(x)
168
+ x, _ = nn.utils.rnn.pad_packed_sequence(
169
+ x, batch_first=True)
170
+ x = F.dropout(x, p=self.dropout, training=False)
171
+ x = x.transpose(-1, -2)
172
+ x_pad = torch.zeros([x.shape[0], x.shape[1], m.shape[-1]], device=x.device)
173
+ x_pad[:, :, :x.shape[-1]] = x
174
+ x = x_pad
175
+
176
+ return x.transpose(-1, -2)
177
+
178
+
179
+ # https://github.com/yl4579/StyleTTS2/blob/main/Utils/PLBERT/util.py
180
+ class CustomAlbert(AlbertModel):
181
+ def forward(self, *args, **kwargs):
182
+ outputs = super().forward(*args, **kwargs)
183
+ return outputs.last_hidden_state
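
A note for readers of the model code above: AdaLayerNorm projects the style vector into per-channel scale and shift terms that modulate a layer-normalized sequence, which is how the duration encoder injects style information. Below is a simplified, self-contained sketch of that idea (not the exact class above; the shapes and transposes are deliberately reduced), assuming only that torch is installed:

# Simplified sketch of the AdaLayerNorm idea used above: a style vector is
# mapped to per-channel gamma/beta that modulate a layer-normalized sequence.
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAdaLN(nn.Module):
    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels, self.eps = channels, eps
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):                           # x: (B, T, C), s: (B, style_dim)
        gamma, beta = self.fc(s).chunk(2, dim=-1)      # each (B, C)
        x = F.layer_norm(x, (self.channels,), eps=self.eps)
        return (1 + gamma.unsqueeze(1)) * x + beta.unsqueeze(1)

x = torch.randn(2, 7, 64)                              # batch 2, length 7, 64 channels
s = torch.randn(2, 128)                                # style vectors
print(TinyAdaLN(128, 64)(x, s).shape)                  # torch.Size([2, 7, 64])
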
kokoro/pipeline.py ADDED
@@ -0,0 +1,445 @@
1
+ from .model import KModel
2
+ from dataclasses import dataclass
3
+ from huggingface_hub import hf_hub_download
4
+ from loguru import logger
5
+ from misaki import en, espeak
6
+ from typing import Callable, Generator, List, Optional, Tuple, Union
7
+ import re
8
+ import torch
9
+ import os
10
+
11
+ ALIASES = {
12
+ 'en-us': 'a',
13
+ 'en-gb': 'b',
14
+ 'es': 'e',
15
+ 'fr-fr': 'f',
16
+ 'hi': 'h',
17
+ 'it': 'i',
18
+ 'pt-br': 'p',
19
+ 'ja': 'j',
20
+ 'zh': 'z',
21
+ }
22
+
23
+ LANG_CODES = dict(
24
+ # pip install misaki[en]
25
+ a='American English',
26
+ b='British English',
27
+
28
+ # espeak-ng
29
+ e='es',
30
+ f='fr-fr',
31
+ h='hi',
32
+ i='it',
33
+ p='pt-br',
34
+
35
+ # pip install misaki[ja]
36
+ j='Japanese',
37
+
38
+ # pip install misaki[zh]
39
+ z='Mandarin Chinese',
40
+ )
41
+
42
+ class KPipeline:
43
+ '''
44
+ KPipeline is a language-aware support class with 2 main responsibilities:
45
+ 1. Perform language-specific G2P, mapping (and chunking) text -> phonemes
46
+ 2. Manage and store voices, lazily downloaded from HF if needed
47
+
48
+ You are expected to have one KPipeline per language. If you have multiple
49
+ KPipelines, you should reuse one KModel instance across all of them.
50
+
51
+ KPipeline is designed to work with a KModel, but this is not required.
52
+ There are 2 ways to pass an existing model into a pipeline:
53
+ 1. On init: us_pipeline = KPipeline(lang_code='a', model=model)
54
+ 2. On call: us_pipeline(text, voice, model=model)
55
+
56
+ By default, KPipeline will automatically initialize its own KModel. To
57
+ suppress this, construct a "quiet" KPipeline with model=False.
58
+
59
+ A "quiet" KPipeline yields (graphemes, phonemes, None) without generating
60
+ any audio. You can use this to phonemize and chunk your text in advance.
61
+
62
+ A "loud" KPipeline _with_ a model yields (graphemes, phonemes, audio).
63
+ '''
64
+ def __init__(
65
+ self,
66
+ lang_code: str,
67
+ repo_id: Optional[str] = None,
68
+ model: Union[KModel, bool] = True,
69
+ trf: bool = False,
70
+ en_callable: Optional[Callable[[str], str]] = None,
71
+ device: Optional[str] = None
72
+ ):
73
+ """Initialize a KPipeline.
74
+
75
+ Args:
76
+ lang_code: Language code for G2P processing
77
+ model: KModel instance, True to create new model, False for no model
78
+ trf: Whether to use transformer-based G2P
79
+ device: Override default device selection ('cuda', 'mps', or 'cpu'; None for auto)
81
+ If None, will auto-select cuda, then mps (when PYTORCH_ENABLE_MPS_FALLBACK=1), else cpu
82
+ If 'cuda' or 'mps' is requested but not available, will explicitly raise an error
82
+ """
83
+ if repo_id is None:
84
+ repo_id = 'hexgrad/Kokoro-82M'
85
+ print(f"WARNING: Defaulting repo_id to {repo_id}. Pass repo_id='{repo_id}' to suppress this warning.")
86
+ config = None
87
+ else:
88
+ config = os.path.join(repo_id, 'config.json')
89
+ self.repo_id = repo_id
90
+ lang_code = lang_code.lower()
91
+ lang_code = ALIASES.get(lang_code, lang_code)
92
+ assert lang_code in LANG_CODES, (lang_code, LANG_CODES)
93
+ self.lang_code = lang_code
94
+ self.model = None
95
+ if isinstance(model, KModel):
96
+ self.model = model
97
+ elif model:
98
+ if device == 'cuda' and not torch.cuda.is_available():
99
+ raise RuntimeError("CUDA requested but not available")
100
+ if device == 'mps' and not torch.backends.mps.is_available():
101
+ raise RuntimeError("MPS requested but not available")
102
+ if device == 'mps' and os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') != '1':
103
+ raise RuntimeError("MPS requested but fallback not enabled")
104
+ if device is None:
105
+ if torch.cuda.is_available():
106
+ device = 'cuda'
107
+ elif os.environ.get('PYTORCH_ENABLE_MPS_FALLBACK') == '1' and torch.backends.mps.is_available():
108
+ device = 'mps'
109
+ else:
110
+ device = 'cpu'
111
+ try:
112
+ self.model = KModel(repo_id=repo_id, config=config).to(device).eval()
113
+ except RuntimeError as e:
114
+ if device == 'cuda':
115
+ raise RuntimeError(f"""Failed to initialize model on CUDA: {e}.
116
+ Try setting device='cpu' or check CUDA installation.""")
117
+ raise
118
+ self.voices = {}
119
+ if lang_code in 'ab':
120
+ try:
121
+ fallback = espeak.EspeakFallback(british=lang_code=='b')
122
+ except Exception as e:
123
+ logger.warning("EspeakFallback not Enabled: OOD words will be skipped")
124
+ logger.warning({str(e)})
125
+ fallback = None
126
+ self.g2p = en.G2P(trf=trf, british=lang_code=='b', fallback=fallback, unk='')
127
+ elif lang_code == 'j':
128
+ try:
129
+ from misaki import ja
130
+ self.g2p = ja.JAG2P()
131
+ except ImportError:
132
+ logger.error("You need to `pip install misaki[ja]` to use lang_code='j'")
133
+ raise
134
+ elif lang_code == 'z':
135
+ try:
136
+ from misaki import zh
137
+ self.g2p = zh.ZHG2P(
138
+ version=None if repo_id.endswith('/Kokoro-82M') else '1.1',
139
+ en_callable=en_callable
140
+ )
141
+ except ImportError:
142
+ logger.error("You need to `pip install misaki[zh]` to use lang_code='z'")
143
+ raise
144
+ else:
145
+ language = LANG_CODES[lang_code]
146
+ logger.warning(f"Using EspeakG2P(language='{language}'). Chunking logic not yet implemented, so long texts may be truncated unless you split them with '\\n'.")
147
+ self.g2p = espeak.EspeakG2P(language=language)
148
+
149
+ def load_single_voice(self, voice: str):
150
+ if voice in self.voices:
151
+ return self.voices[voice]
152
+ if voice.endswith('.pt'):
153
+ f = voice
154
+ else:
155
+ f = hf_hub_download(repo_id=self.repo_id, filename=f'voices/{voice}.pt')
156
+ if not voice.startswith(self.lang_code):
157
+ v = LANG_CODES.get(voice, voice)
158
+ p = LANG_CODES.get(self.lang_code, self.lang_code)
159
+ logger.warning(f'Language mismatch, loading {v} voice into {p} pipeline.')
160
+ pack = torch.load(f, weights_only=True)
161
+ self.voices[voice] = pack
162
+ return pack
163
+
164
+ """
165
+ load_voice is a helper function that lazily downloads and loads a voice:
166
+ Single voice can be requested (e.g. 'af_bella') or multiple voices (e.g. 'af_bella,af_jessica').
167
+ If multiple voices are requested, they are averaged.
168
+ Delimiter is optional and defaults to ','.
169
+ """
170
+ def load_voice(self, voice: Union[str, torch.FloatTensor], delimiter: str = ",") -> torch.FloatTensor:
171
+ if isinstance(voice, torch.FloatTensor):
172
+ return voice
173
+ if voice in self.voices:
174
+ return self.voices[voice]
175
+ logger.debug(f"Loading voice: {voice}")
176
+ packs = [self.load_single_voice(v) for v in voice.split(delimiter)]
177
+ if len(packs) == 1:
178
+ return packs[0]
179
+ self.voices[voice] = torch.mean(torch.stack(packs), dim=0)
180
+ return self.voices[voice]
181
+
182
+ @staticmethod
183
+ def tokens_to_ps(tokens: List[en.MToken]) -> str:
184
+ return ''.join(t.phonemes + (' ' if t.whitespace else '') for t in tokens).strip()
185
+
186
+ @staticmethod
187
+ def waterfall_last(
188
+ tokens: List[en.MToken],
189
+ next_count: int,
190
+ waterfall: List[str] = ['!.?…', ':;', ',—'],
191
+ bumps: List[str] = [')', '”']
192
+ ) -> int:
193
+ for w in waterfall:
194
+ z = next((i for i, t in reversed(list(enumerate(tokens))) if t.phonemes in set(w)), None)
195
+ if z is None:
196
+ continue
197
+ z += 1
198
+ if z < len(tokens) and tokens[z].phonemes in bumps:
199
+ z += 1
200
+ if next_count - len(KPipeline.tokens_to_ps(tokens[:z])) <= 510:
201
+ return z
202
+ return len(tokens)
203
+
204
+ @staticmethod
205
+ def tokens_to_text(tokens: List[en.MToken]) -> str:
206
+ return ''.join(t.text + t.whitespace for t in tokens).strip()
207
+
208
+ def en_tokenize(
209
+ self,
210
+ tokens: List[en.MToken]
211
+ ) -> Generator[Tuple[str, str, List[en.MToken]], None, None]:
212
+ tks = []
213
+ pcount = 0
214
+ for t in tokens:
215
+ # American English: ɾ => T
216
+ t.phonemes = '' if t.phonemes is None else t.phonemes#.replace('ɾ', 'T')
217
+ next_ps = t.phonemes + (' ' if t.whitespace else '')
218
+ next_pcount = pcount + len(next_ps.rstrip())
219
+ if next_pcount > 510:
220
+ z = KPipeline.waterfall_last(tks, next_pcount)
221
+ text = KPipeline.tokens_to_text(tks[:z])
222
+ logger.debug(f"Chunking text at {z}: '{text[:30]}{'...' if len(text) > 30 else ''}'")
223
+ ps = KPipeline.tokens_to_ps(tks[:z])
224
+ yield text, ps, tks[:z]
225
+ tks = tks[z:]
226
+ pcount = len(KPipeline.tokens_to_ps(tks))
227
+ if not tks:
228
+ next_ps = next_ps.lstrip()
229
+ tks.append(t)
230
+ pcount += len(next_ps)
231
+ if tks:
232
+ text = KPipeline.tokens_to_text(tks)
233
+ ps = KPipeline.tokens_to_ps(tks)
234
+ yield ''.join(text).strip(), ''.join(ps).strip(), tks
235
+
236
+ @staticmethod
237
+ def infer(
238
+ model: KModel,
239
+ ps: str,
240
+ pack: torch.FloatTensor,
241
+ speed: Union[float, Callable[[int], float]] = 1
242
+ ) -> KModel.Output:
243
+ if callable(speed):
244
+ speed = speed(len(ps))
245
+ return model(ps, pack[len(ps)-1], speed, return_output=True)
246
+
247
+ def generate_from_tokens(
248
+ self,
249
+ tokens: Union[str, List[en.MToken]],
250
+ voice: str,
251
+ speed: float = 1,
252
+ model: Optional[KModel] = None
253
+ ) -> Generator['KPipeline.Result', None, None]:
254
+ """Generate audio from either raw phonemes or pre-processed tokens.
255
+
256
+ Args:
257
+ tokens: Either a phoneme string or list of pre-processed MTokens
258
+ voice: The voice to use for synthesis
259
+ speed: Speech speed modifier (default: 1)
260
+ model: Optional KModel instance (uses pipeline's model if not provided)
261
+
262
+ Yields:
263
+ KPipeline.Result containing the input tokens and generated audio
264
+
265
+ Raises:
266
+ ValueError: If no voice is provided or token sequence exceeds model limits
267
+ """
268
+ model = model or self.model
269
+ if model and voice is None:
270
+ raise ValueError('Specify a voice: pipeline.generate_from_tokens(..., voice="af_heart")')
271
+
272
+ pack = self.load_voice(voice).to(model.device) if model else None
273
+
274
+ # Handle raw phoneme string
275
+ if isinstance(tokens, str):
276
+ logger.debug("Processing phonemes from raw string")
277
+ if len(tokens) > 510:
278
+ raise ValueError(f'Phoneme string too long: {len(tokens)} > 510')
279
+ output = KPipeline.infer(model, tokens, pack, speed) if model else None
280
+ yield self.Result(graphemes='', phonemes=tokens, output=output)
281
+ return
282
+
283
+ logger.debug("Processing MTokens")
284
+ # Handle pre-processed tokens
285
+ for gs, ps, tks in self.en_tokenize(tokens):
286
+ if not ps:
287
+ continue
288
+ elif len(ps) > 510:
289
+ logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
290
+ logger.warning("Truncating to 510 characters")
291
+ ps = ps[:510]
292
+ output = KPipeline.infer(model, ps, pack, speed) if model else None
293
+ if output is not None and output.pred_dur is not None:
294
+ KPipeline.join_timestamps(tks, output.pred_dur)
295
+ yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output)
296
+
297
+ @staticmethod
298
+ def join_timestamps(tokens: List[en.MToken], pred_dur: torch.LongTensor):
299
+ # Each pred_dur frame covers 600 samples at the 24000 Hz sample rate,
300
+ # i.e. 40 frames per second, so dividing a frame count by 40 gives seconds.
301
+ # We count in half-frames (so space characters can be split in half), hence the divisor 80
302
+ MAGIC_DIVISOR = 80
303
+ if not tokens or len(pred_dur) < 3:
304
+ # We expect at least 3: <bos>, token, <eos>
305
+ return
306
+ # We track 2 counts, measured in half-frames: (left, right)
307
+ # This way we can cut space characters in half
308
+ # TODO: Is -3 an appropriate offset?
309
+ left = right = 2 * max(0, pred_dur[0].item() - 3)
310
+ # Updates:
311
+ # left = right + (2 * token_dur) + space_dur
312
+ # right = left + space_dur
313
+ i = 1
314
+ for t in tokens:
315
+ if i >= len(pred_dur)-1:
316
+ break
317
+ if not t.phonemes:
318
+ if t.whitespace:
319
+ i += 1
320
+ left = right + pred_dur[i].item()
321
+ right = left + pred_dur[i].item()
322
+ i += 1
323
+ continue
324
+ j = i + len(t.phonemes)
325
+ if j >= len(pred_dur):
326
+ break
327
+ t.start_ts = left / MAGIC_DIVISOR
328
+ token_dur = pred_dur[i: j].sum().item()
329
+ space_dur = pred_dur[j].item() if t.whitespace else 0
330
+ left = right + (2 * token_dur) + space_dur
331
+ t.end_ts = left / MAGIC_DIVISOR
332
+ right = left + space_dur
333
+ i = j + (1 if t.whitespace else 0)
334
+
335
+ @dataclass
336
+ class Result:
337
+ graphemes: str
338
+ phonemes: str
339
+ tokens: Optional[List[en.MToken]] = None
340
+ output: Optional[KModel.Output] = None
341
+ text_index: Optional[int] = None
342
+
343
+ @property
344
+ def audio(self) -> Optional[torch.FloatTensor]:
345
+ return None if self.output is None else self.output.audio
346
+
347
+ @property
348
+ def pred_dur(self) -> Optional[torch.LongTensor]:
349
+ return None if self.output is None else self.output.pred_dur
350
+
351
+ ### MARK: BEGIN BACKWARD COMPAT ###
352
+ def __iter__(self):
353
+ yield self.graphemes
354
+ yield self.phonemes
355
+ yield self.audio
356
+
357
+ def __getitem__(self, index):
358
+ return [self.graphemes, self.phonemes, self.audio][index]
359
+
360
+ def __len__(self):
361
+ return 3
362
+ #### MARK: END BACKWARD COMPAT ####
363
+
364
+ def __call__(
365
+ self,
366
+ text: Union[str, List[str]],
367
+ voice: Optional[str] = None,
368
+ speed: Union[float, Callable[[int], float]] = 1,
369
+ split_pattern: Optional[str] = r'\n+',
370
+ model: Optional[KModel] = None
371
+ ) -> Generator['KPipeline.Result', None, None]:
372
+ model = model or self.model
373
+ if model and voice is None:
374
+ raise ValueError('Specify a voice: en_us_pipeline(text="Hello world!", voice="af_heart")')
375
+ pack = self.load_voice(voice).to(model.device) if model else None
376
+
377
+ # Convert input to list of segments
378
+ if isinstance(text, str):
379
+ text = re.split(split_pattern, text.strip()) if split_pattern else [text]
380
+
381
+ # Process each segment
382
+ for graphemes_index, graphemes in enumerate(text):
383
+ if not graphemes.strip(): # Skip empty segments
384
+ continue
385
+
386
+ # English processing (unchanged)
387
+ if self.lang_code in 'ab':
388
+ logger.debug(f"Processing English text: {graphemes[:50]}{'...' if len(graphemes) > 50 else ''}")
389
+ _, tokens = self.g2p(graphemes)
390
+ for gs, ps, tks in self.en_tokenize(tokens):
391
+ if not ps:
392
+ continue
393
+ elif len(ps) > 510:
394
+ logger.warning(f"Unexpected len(ps) == {len(ps)} > 510 and ps == '{ps}'")
395
+ ps = ps[:510]
396
+ output = KPipeline.infer(model, ps, pack, speed) if model else None
397
+ if output is not None and output.pred_dur is not None:
398
+ KPipeline.join_timestamps(tks, output.pred_dur)
399
+ yield self.Result(graphemes=gs, phonemes=ps, tokens=tks, output=output, text_index=graphemes_index)
400
+
401
+ # Non-English processing with chunking
402
+ else:
403
+ # Split long text into smaller chunks (roughly 400 characters each)
404
+ # Using sentence boundaries when possible
405
+ chunk_size = 400
406
+ chunks = []
407
+
408
+ # Try to split on sentence boundaries first
409
+ sentences = re.split(r'([.!?]+)', graphemes)
410
+ current_chunk = ""
411
+
412
+ for i in range(0, len(sentences), 2):
413
+ sentence = sentences[i]
414
+ # Add the punctuation back if it exists
415
+ if i + 1 < len(sentences):
416
+ sentence += sentences[i + 1]
417
+
418
+ if len(current_chunk) + len(sentence) <= chunk_size:
419
+ current_chunk += sentence
420
+ else:
421
+ if current_chunk:
422
+ chunks.append(current_chunk.strip())
423
+ current_chunk = sentence
424
+
425
+ if current_chunk:
426
+ chunks.append(current_chunk.strip())
427
+
428
+ # If no chunks were created (no sentence boundaries), fall back to character-based chunking
429
+ if not chunks:
430
+ chunks = [graphemes[i:i+chunk_size] for i in range(0, len(graphemes), chunk_size)]
431
+
432
+ # Process each chunk
433
+ for chunk in chunks:
434
+ if not chunk.strip():
435
+ continue
436
+
437
+ ps, _ = self.g2p(chunk)
438
+ if not ps:
439
+ continue
440
+ elif len(ps) > 510:
441
+ logger.warning(f'Truncating len(ps) == {len(ps)} > 510')
442
+ ps = ps[:510]
443
+
444
+ output = KPipeline.infer(model, ps, pack, speed) if model else None
445
+ yield self.Result(graphemes=chunk, phonemes=ps, output=output, text_index=graphemes_index)
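
A minimal usage sketch for the pipeline above, based on its class docstring and error messages (this snippet is not part of the commit). It assumes KPipeline is re-exported by the kokoro package (otherwise import it from kokoro.pipeline) and that soundfile is installed; the voice names and the 24 kHz output rate come from the code and comments in this diff:

# Hedged usage sketch for KPipeline; see assumptions in the note above.
import soundfile as sf
from kokoro import KPipeline

pipeline = KPipeline(lang_code='a', repo_id='hexgrad/Kokoro-82M')

text = "Hello world!\nEach newline-separated segment is synthesized separately."
for i, result in enumerate(pipeline(text, voice='af_heart', speed=1.0)):
    print(i, result.phonemes)
    if result.audio is not None:
        sf.write(f'segment_{i}.wav', result.audio.cpu().numpy(), 24000)

# Multiple voices separated by ',' are averaged by load_voice:
for result in pipeline("A blended voice.", voice='af_bella,af_jessica'):
    pass
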
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ opencv-python>=4.9.0.80
2
+ diffusers>=0.31.0
3
+ transformers>=4.49.0
4
+ tokenizers>=0.20.3
5
+ accelerate>=1.1.1
6
+ tqdm
7
+ imageio
8
+ easydict
9
+ ftfy
10
+ dashscope
11
+ imageio-ffmpeg
12
+ scikit-image
13
+ loguru
14
+ gradio>=5.0.0
15
+ numpy>=1.23.5,<2
16
+ xfuser>=0.4.1
17
+ pyloudnorm
18
+ optimum-quanto==0.2.6
19
+ scenedetect
20
+ moviepy==1.0.3
21
+ decord
tools/convert_img_to_video.py ADDED
@@ -0,0 +1,152 @@
1
+ import yaml
2
+ import cv2
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
+ class ImageProcessor:
7
+ def __init__(self, yaml_path):
8
+ with open(yaml_path, 'r') as f:
9
+ self.config = yaml.safe_load(f)
10
+
11
+ self.images_info = []
12
+ self.reference_size = None
13
+ self._load_images()
14
+
15
+ def _load_images(self):
16
+ for img_config in self.config['images']:
17
+ img = cv2.imread(img_config['path'])
18
+ if img is None:
19
+ raise ValueError(f"Cannot load image: {img_config['path']}")
20
+
21
+ info = {
22
+ 'image': img,
23
+ 'duration': float(img_config.get('duration', 1.0)),
24
+ 'translation': img_config.get('translation', [0, 0]),
25
+ 'scale': float(img_config.get('scale', 1.0))
26
+ }
27
+ self.images_info.append(info)
28
+
29
+ if self.reference_size is None:
30
+ self.reference_size = (img.shape[1], img.shape[0])
31
+
32
+ def _translate_image(self, img, translation):
33
+ """Perform only translation"""
34
+ height, width = img.shape[:2]
35
+
36
+ # Calculate translation amount (pixels)
37
+ tx = int(width * translation[0] / 100)
38
+ ty = int(height * translation[1] / 100)
39
+
40
+ # Create translation matrix
41
+ M = np.float32([[1, 0, tx], [0, 1, ty]])
42
+
43
+ # Apply translation while maintaining original dimensions
44
+ translated = cv2.warpAffine(img, M, (width, height))
45
+
46
+ return translated
47
+
48
+ def _crop_black_borders(self, img):
49
+ """Crop out black borders from the image"""
50
+ # Convert to grayscale
51
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
52
+
53
+ # Threshold to identify non-black areas
54
+ _, thresh = cv2.threshold(gray, 1, 255, cv2.THRESH_BINARY)
55
+
56
+ # Find bounding box of non-black pixels
57
+ coords = cv2.findNonZero(thresh)
58
+ if coords is None:
59
+ return img
60
+
61
+ x, y, w, h = cv2.boundingRect(coords)
62
+
63
+ # Crop the image to the bounding box
64
+ return img[y:y+h, x:x+w]
65
+
66
+ def _scale_image(self, img, scale, target_size):
67
+ """Scale the image"""
68
+ if scale <= 1:
69
+ return cv2.resize(img, target_size)
70
+
71
+ # First scale up
72
+ height, width = img.shape[:2]
73
+ scaled_width = int(width * scale)
74
+ scaled_height = int(height * scale)
75
+ scaled = cv2.resize(img, (scaled_width, scaled_height))
76
+
77
+ # Center-crop to target dimensions
78
+ start_x = (scaled_width - target_size[0]) // 2
79
+ start_y = (scaled_height - target_size[1]) // 2
80
+ cropped = scaled[start_y:start_y+target_size[1],
81
+ start_x:start_x+target_size[0]]
82
+
83
+ return cropped
84
+
85
+ def _transform_image(self, img, translation, scale):
86
+ """Apply transformations in sequence: translation → cropping → scaling"""
87
+ original_size = (img.shape[1], img.shape[0])
88
+
89
+ # 1. Translation
90
+ translated = self._translate_image(img, translation)
91
+
92
+ # 2. Black border cropping
93
+ cropped = self._crop_black_borders(translated)
94
+
95
+ # 3. Scale back to original dimensions
96
+ transformed = self._scale_image(cropped, scale, original_size)
97
+
98
+ return transformed
99
+
100
+ def create_video(self, output_path, fps=25):
101
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
102
+ out = cv2.VideoWriter(output_path, fourcc, fps, self.reference_size)
103
+
104
+ try:
105
+ for info in self.images_info:
106
+ # Transform image
107
+ transformed = self._transform_image(
108
+ info['image'],
109
+ info['translation'],
110
+ info['scale']
111
+ )
112
+
113
+ # Resize to reference dimensions if needed
114
+ if transformed.shape[:2] != (self.reference_size[1], self.reference_size[0]):
115
+ transformed = cv2.resize(transformed, self.reference_size)
116
+
117
+ # Write video frames
118
+ n_frames = int(info['duration'] * fps)
119
+ for _ in range(n_frames):
120
+ out.write(transformed)
121
+
122
+ finally:
123
+ out.release()
124
+
125
+ # Enhance video quality
126
+ self._improve_video_quality(output_path)
127
+
128
+ def _improve_video_quality(self, video_path):
129
+ import subprocess
130
+ temp_path = video_path + '.temp.mp4'
131
+
132
+ cmd = [
133
+ 'ffmpeg', '-i', video_path,
134
+ '-c:v', 'libx264',
135
+ '-preset', 'slow',
136
+ '-crf', '18',
137
+ '-y',
138
+ temp_path
139
+ ]
140
+
141
+ subprocess.run(cmd, check=True)  # fail loudly if ffmpeg errors instead of replacing with a missing file
142
+
143
+ import os
144
+ os.replace(temp_path, video_path)
145
+
146
+ def main():
147
+ processor = ImageProcessor('tools/i2v_config.yaml')
148
+ processor.create_video('converted_video.mp4', fps=25)
149
+
150
+ if __name__ == '__main__':
151
+ main()
152
+
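
A small worked example of the translation math in _translate_image above: translation values are percentages of the frame size, converted once per image into a fixed pixel offset (the frame size below is hypothetical):

# Worked example of the percentage-based translation used by _translate_image.
width, height = 1280, 720
translation = [5, -2]                    # 5% to the right, 2% up (negative y is up)
tx = int(width * translation[0] / 100)   # 64 px
ty = int(height * translation[1] / 100)  # -14 px (int() truncates toward zero)
print(tx, ty)                            # 64 -14
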
tools/i2v_config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ images:
2
+ # - path: "xxx.jpg" # Image path
3
+ # duration: 4.0 # Display duration (in seconds)
4
+ # # Translation: [x, y] percentage. Positive x is right, positive y is down.
5
+ # # With this tool the offset is static: the frame is shifted 5% right and 2% up for the image's whole duration.
6
+ # translation: [5, -2]
7
+ # # Scale: The final zoom factor. 1.0 is no zoom.
8
+ # # With this tool the zoom is static: a 1.2x center crop applied for the image's whole duration.
9
+ # scale: 1.2
10
+ - path: "examples/single/ref_image.png"
11
+ duration: 2 # seconds
12
+ translation: [0, 0] # [x, y] shift as a percentage of frame width/height (not pixels; see comments above)
13
+ scale: 1.0 # Scale factor (1.0 = no change, >1.0 zooms in; values <= 1.0 are treated as no zoom by this tool)
14
+ - path: "examples/single/ref_image.png"
15
+ duration: 3.0
16
+ translation: [-7, -7]
17
+ scale: 1.0
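
To tie the last two files together, a sketch of building an equivalent config programmatically and rendering it with the tool above. It assumes the script is run from the repo root with tools importable as a package, that the referenced image exists, and that ffmpeg is on PATH (it is invoked by _improve_video_quality after the frames are written):

# Hedged end-to-end sketch; the generated config file name is hypothetical.
import yaml
from tools.convert_img_to_video import ImageProcessor

config = {
    'images': [
        {'path': 'examples/single/ref_image.png', 'duration': 2.0,
         'translation': [0, 0], 'scale': 1.0},
        {'path': 'examples/single/ref_image.png', 'duration': 3.0,
         'translation': [-7, -7], 'scale': 1.0},
    ]
}
with open('my_i2v_config.yaml', 'w') as f:
    yaml.safe_dump(config, f)

ImageProcessor('my_i2v_config.yaml').create_video('converted_video.mp4', fps=25)
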