Spaces:
Running
on
Zero
Running
on
Zero
| import yt_dlp | |
| import re | |
| import subprocess | |
| import os | |
| import shutil | |
| from pydub import AudioSegment, silence | |
| import gradio as gr | |
| import traceback | |
| import logging | |
| from inference import proc_folder_direct | |
| from pathlib import Path | |
| import spaces | |
| from pydub.exceptions import CouldntEncodeError | |
| from transformers import pipeline | |
| import requests | |
# Initialize text generation model
# NOTE(review): `model` is not referenced by any function visible in this file, yet the
# pipeline is constructed at import time — confirm it is still needed.
model = pipeline('text-generation', model='EleutherAI/gpt-neo-125M')
# Define constants
# Root folder for separation results. The trailing slash matters: process_audio builds
# paths by plain string concatenation (f'{OUTPUT_FOLDER}{formatted_title}/...').
OUTPUT_FOLDER = "separation_results/"
# Uploaded audio is converted to WAV under INPUT_FOLDER/wav (see handle_file_upload).
INPUT_FOLDER = "input"
# NOTE(review): not read anywhere in the visible code — possibly a leftover.
download_path = ""
# URL for the cookies.txt file in the Hugging Face repository
cookies_url = "https://huggingface.co/spaces/Awell00/music_drums_separation/raw/main/cookies.txt"
def download_cookies():
    """
    Download cookies.txt from the Hugging Face repository into the working directory.

    The file is saved as "cookies.txt", the name the yt-dlp option dictionaries in this
    module pass as 'cookiefile'. Network/HTTP failures are reported on stdout and
    otherwise ignored (best effort).

    Returns:
        bool: True if the file was downloaded and written, False if the request failed.
    """
    try:
        # Without a timeout, requests waits indefinitely on a stalled connection.
        response = requests.get(cookies_url, timeout=30)
        response.raise_for_status()  # Check for HTTP errors
    except requests.exceptions.RequestException as e:
        print(f"Error downloading cookies.txt: {e}")
        return False
    # Write content to cookies.txt with an explicit encoding so the result does not
    # depend on the container's locale.
    with open("cookies.txt", "w", encoding="utf-8") as file:
        file.write(response.text)
    print("cookies.txt downloaded successfully.")
    return True
class MyLogger:
    """Logger object handed to yt-dlp: everything is discarded except errors, which are printed."""

    def debug(self, msg):
        """Forward non-'[debug] ' messages to info(); drop the rest."""
        # For compatibility with youtube-dl, both debug and info are passed into debug
        if not msg.startswith('[debug] '):
            self.info(msg)

    def info(self, msg):
        """Discard informational messages."""

    def warning(self, msg):
        """Discard warnings."""

    def error(self, msg):
        """Print the error message to stdout."""
        print(msg)
def my_hook(d):
    """
    Progress hook passed to yt-dlp: announce the end of the download phase.
    Args:
        d (dict): Progress dictionary; only its 'status' entry is read
    """
    download_finished = d['status'] == 'finished'
    if download_finished:
        print('Done downloading, now post-processing ...')
def sanitize_filename(filename):
    """
    Replace characters that are not allowed in file names on common file systems.
    Each of \\ / * ? : " < > | is replaced by an underscore; everything else is kept.
    Args:
        filename (str): The original filename
    Returns:
        str: Sanitized filename
    """
    forbidden = re.compile(r'[\\/*?:"<>|]')
    sanitized = forbidden.sub('_', filename)
    return sanitized
def delete_input_files(input_dir):
    """
    Remove every *.wav file found directly inside "<input_dir>/wav".
    Other files in that folder are left alone; each removal is logged to stdout.
    Args:
        input_dir (str): Path to the input directory
    """
    wav_files = (Path(input_dir) / "wav").glob("*.wav")
    for path in wav_files:
        path.unlink()
        print(f"Deleted {path}")
def standardize_title(input_title):
    """
    Standardize the title format by removing unnecessary words and rearranging artist and title.

    The title is stripped of bracketed/parenthesised text and filler words, then split on
    "-" and ",". The last non-empty piece becomes the title and the preceding pieces the
    artist(s). A "(feat. X)" / "(with X)" group in the original title adds X to the artists.
    Args:
        input_title (str): The original title
    Returns:
        str: Standardized title in "Artist - Title" format ("Unknown Artist" is used when
        no artist can be separated from the title)
    """
    # Remove content within parentheses or brackets
    title_cleaned = re.sub(r"[\(\[].*?[\)\]]", "", input_title)
    # Remove unnecessary words
    unnecessary_words = ["official", "video", "hd", "4k", "lyrics", "music", "audio", "visualizer", "remix"]
    title_cleaned = re.sub(r"\b(?:{})\b".format("|".join(unnecessary_words)), "", title_cleaned, flags=re.IGNORECASE)
    # Split title into parts. Removing filler words can leave empty pieces behind
    # (e.g. "Artist - Song - Official Audio"); drop them so they are not mistaken
    # for the title.
    parts = [part.strip() for part in re.split(r"\s*-\s*|\s*,\s*", title_cleaned)]
    parts = [part for part in parts if part]
    # Determine artist and title parts
    if len(parts) >= 2:
        title_part = parts[-1]
        artist_part = ', '.join(parts[:-1])
    else:
        artist_part = "Unknown Artist"
        title_part = parts[0] if parts else ""
    # Handle "with" or "feat" in the title
    lowered = input_title.lower()
    if "with" in lowered or "feat" in lowered:
        match = re.search(r"\((with|feat\.?) (.*?)\)", input_title, re.IGNORECASE)
        additional_artist = match.group(2).strip() if match else ""
        # An empty capture would otherwise leave a dangling ", " in the artist list
        if additional_artist:
            artist_part = f"{artist_part}, {additional_artist}" if artist_part != "Unknown Artist" else additional_artist
    # Clean up and capitalize
    artist_part = re.sub(r'\s+', ' ', artist_part).title()
    title_part = re.sub(r'\s+', ' ', title_part).title()
    # Combine artist and title
    standardized_output = f"{artist_part} - {title_part}"
    return standardized_output.strip()
def get_video_title(video_url):
    """
    Look up a video's title through yt-dlp without downloading the media.
    Args:
        video_url (str): URL of the video
    Returns:
        str: The 'title' entry of the extracted video information
    """
    options = {
        'logger': MyLogger(),
        'progress_hooks': [my_hook],
        'cookiefile': 'cookies.txt',
        'quiet': True,
        'ratelimit': 500000,
        'retries': 3,
    }
    with yt_dlp.YoutubeDL(options) as downloader:
        # download=False: only the metadata is extracted
        info = downloader.extract_info(video_url, download=False)
        return info['title']
def download_youtube_audio(youtube_url: str, output_dir: str = './download', delete_existing: bool = True, simulate: bool = False) -> str:
    """
    Downloads audio from a YouTube URL and converts it to a WAV file with yt-dlp.

    The output file is named after the (sanitized) video title.
    Args:
        youtube_url (str): URL of the YouTube video.
        output_dir (str): Directory to save the downloaded audio file.
        delete_existing (bool): If True, deletes any existing WAV file with the same name.
        simulate (bool): If True, simulates the download without actually downloading.
    Returns:
        str: Path of the WAV file ("<output_dir>/<title>.wav"). With simulate=True the
        path is returned even though no file is written.
    """
    os.makedirs(output_dir, exist_ok=True)
    download_cookies()
    # Video titles are arbitrary text; sanitize them so characters such as "/" cannot
    # change the directory the file ends up in.
    title = sanitize_filename(get_video_title(youtube_url))
    audio_file = os.path.join(output_dir, title)
    wav_path = audio_file + '.wav'
    # Remove existing file if requested. The post-processor below produces a .wav file,
    # so that is the file that can be stale.
    if delete_existing and os.path.exists(wav_path):
        os.remove(wav_path)
    # Prepare yt-dlp options
    ydl_opts = {
        'logger': MyLogger(),
        'progress_hooks': [my_hook],
        'format': 'bestaudio',
        # 'outtmpl' is a %-style template: escape literal percent signs from the title
        'outtmpl': audio_file.replace('%', '%%'),
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'wav',
        }],
        'extractor_retries': 10,
        'force_overwrites': True,
        'cookiefile': 'cookies.txt',
        'verbose': True,
        'ratelimit': 500000,
        'retries': 3,
        'sleep_interval': 10,
        'max_sleep_interval': 30,
    }
    if simulate:
        ydl_opts['simulate'] = True
    # Download the audio using yt-dlp
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        ydl.download([youtube_url])
    return wav_path
def handle_file_upload(file):
    """
    Convert an uploaded file to WAV inside the input folder under a standardized name.

    The base name (without extension) is passed through standardize_title and
    sanitize_filename, and the audio is exported to "<INPUT_FOLDER>/wav/<title>.wav".
    Args:
        file: Uploaded file object (its path is read from `.name`) or file path string
    Returns:
        tuple: (input_path, formatted_title) or (None, error_message)
    """
    if file is None:
        return None, "No file uploaded"
    # A plain string already is the path; any other object exposes it as `.name`
    file_path = file if isinstance(file, str) else file.name
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    formatted_title = sanitize_filename(standardize_title(stem).strip())
    # The stored copy always carries the .wav extension
    input_path = os.path.join(INPUT_FOLDER, "wav", f"{formatted_title}.wav")
    os.makedirs(os.path.dirname(input_path), exist_ok=True)
    # Decode whatever format was uploaded and re-encode it as WAV
    AudioSegment.from_file(file_path).export(input_path, format="wav")
    return input_path, formatted_title
def run_inference(model_type, config_path, start_check_point, input_dir, output_dir, device_ids="0"):
    """
    Launch inference.py in a child process with the given model settings.
    Args:
        model_type (str): Type of the model
        config_path (str): Path to the model configuration
        start_check_point (str): Path to the model checkpoint
        input_dir (str): Input directory
        output_dir (str): Output directory
        device_ids (str): GPU device IDs to use
    Returns:
        subprocess.CompletedProcess: Result of the subprocess run (stdout/stderr captured as text)
    Raises:
        subprocess.CalledProcessError: If the child process exits with a non-zero status
    """
    # NOTE(review): "--INPUT_FOLDER" is upper-case unlike the other flags — verify
    # against the argument parser in inference.py.
    flag_values = (
        ("--model_type", model_type),
        ("--config_path", config_path),
        ("--start_check_point", start_check_point),
        ("--INPUT_FOLDER", input_dir),
        ("--store_dir", output_dir),
        ("--device_ids", device_ids),
    )
    command = ["python", "inference.py"]
    for flag, value in flag_values:
        command.extend((flag, value))
    return subprocess.run(command, check=True, capture_output=True, text=True)
def move_stems_to_parent(input_dir):
    """
    Move generated stem files out of the per-model folders into their parent (song) folder.

    For every directory below input_dir whose path contains a model name, the file
    "<song>_<stem>.wav" (song = name of the parent directory) is moved to the parent
    directory under a fixed name:

        htdemucs          <song>_bass.wav   -> bass.wav
        mel_band_roformer <song>_vocals.wav -> vocals.wav
        scnet             <song>_other.wav  -> other.wav
        bs_roformer       <song>_other.wav  -> instrumental.wav

    A missing stem file is reported on stdout and skipped.
    Args:
        input_dir (str): Input directory containing stem folders
    """
    # (model name looked for in the path, stem suffix, destination file, log label)
    stem_rules = (
        ("htdemucs", "bass", "bass.wav", "Bass"),
        ("mel_band_roformer", "vocals", "vocals.wav", "Vocals"),
        ("scnet", "other", "other.wav", "Other"),
        ("bs_roformer", "other", "instrumental.wav", "Instrumental"),
    )
    for subdir, _dirs, _files in os.walk(input_dir):
        if subdir == input_dir:
            continue
        parent_dir = os.path.dirname(subdir)
        song_name = os.path.basename(parent_dir)
        # The first matching rule wins, in the order listed above
        rule = next((candidate for candidate in stem_rules if candidate[0] in subdir), None)
        if rule is None:
            continue
        _model, suffix, target_name, label = rule
        source_path = os.path.join(subdir, f"{song_name}_{suffix}.wav")
        if os.path.exists(source_path):
            shutil.move(source_path, os.path.join(parent_dir, target_name))
        else:
            print(f"{label} file not found: {source_path}")
def combine_stems_for_all(input_dir, output_format="mp3"):
    """
    Combine the stems of each song folder below input_dir into a single audio file.

    A folder is processed only if it contains vocals.wav, bass.wav, other.wav and
    instrumental.wav. The instrumental stem is attenuated by 20 dB, all stems are
    overlaid, trailing silence is trimmed and the result is written to
    "<folder>/<folder name>.<output_format>".
    Args:
        input_dir (str): Input directory containing song folders
        output_format (str): Output audio format (default is 'mp3')
    Returns:
        str | None: Path of the last combined file that was exported; None if no folder
        contained all stems or if encoding failed.
    """
    export_format = output_format.lower()
    export_options = {"format": export_format, "bitrate": "320k"}
    # The LAME codec only makes sense for MP3; let the encoder pick for other formats
    if export_format == "mp3":
        export_options["codec"] = "libmp3lame"
    # Stays None when no folder qualifies, instead of being an unbound local
    output_file = None
    for subdir, _, _ in os.walk(input_dir):
        if subdir == input_dir:
            continue
        song_name = os.path.basename(subdir).strip()  # Remove any trailing spaces
        print(f"Processing {subdir}")
        stem_paths = {
            "vocals": os.path.join(subdir, "vocals.wav"),
            "bass": os.path.join(subdir, "bass.wav"),
            "others": os.path.join(subdir, "other.wav"),
            "instrumental": os.path.join(subdir, "instrumental.wav"),
        }
        # Skip if not all stems are present
        if not all(os.path.exists(path) for path in stem_paths.values()):
            print(f"Skipping {subdir}, not all stems are present.")
            continue
        # Load and combine stems
        stems = {name: AudioSegment.from_file(path) for name, path in stem_paths.items()}
        stems["instrumental"] = stems["instrumental"].apply_gain(-20)
        combined = stems["vocals"].overlay(stems["bass"]).overlay(stems["others"]).overlay(stems["instrumental"])
        # Trim silence at the end
        trimmed_combined = trim_silence_at_end(combined)
        # Format the output file name correctly
        candidate = os.path.join(subdir, f"{song_name}.{export_format}")
        # Export combined audio
        try:
            trimmed_combined.export(candidate, **export_options)
        except CouldntEncodeError as e:
            print(f"{output_format.upper()} Encoding failed: {e}")
            return None
        print(f"Exported combined stems to {output_format.upper()} format: {candidate}")
        output_file = candidate
    return output_file
def trim_silence_at_end(audio_segment, silence_thresh=-50, chunk_size=10):
    """
    Trim silence at the end of an audio segment.

    The segment is shortened only when its last detected silent range extends to the
    very end of the audio. A quiet passage in the middle of the track is left alone.
    Args:
        audio_segment (AudioSegment): Input audio segment
        silence_thresh (int): Silence threshold in dB
        chunk_size (int): Minimum length of a silent stretch in ms
    Returns:
        AudioSegment: Trimmed audio segment (the input itself if there is nothing to trim)
    """
    silent_ranges = silence.detect_silence(audio_segment, min_silence_len=chunk_size, silence_thresh=silence_thresh)
    if not silent_ranges:
        return audio_segment
    last_silence_start, last_silence_end = silent_ranges[-1]
    # detect_silence reports every silent range in the track. Cutting at the last one
    # unconditionally would drop real audio that follows a mid-track pause.
    if last_silence_end < len(audio_segment):
        return audio_segment
    return audio_segment[:last_silence_start]
def delete_folders_and_files(input_dir):
    """
    Delete temporary folders and files after processing.

    Inside every directory below input_dir (input_dir itself excluded) the per-model
    folders and the individual stem files are removed. Afterwards every directory whose
    name ends in "_vocals" anywhere under input_dir is removed as well.
    Args:
        input_dir (str): Input directory to clean up
    """
    folders_to_delete = ['htdemucs', 'mel_band_roformer', 'scnet', 'bs_roformer']
    files_to_delete = ['bass.wav', 'vocals.wav', 'other.wav', 'instrumental.wav']
    # Bottom-up, so a directory's children have been visited before they are removed
    for root, _dirs, _files in os.walk(input_dir, topdown=False):
        if root == input_dir:
            continue
        # Delete specified folders
        for folder in folders_to_delete:
            folder_path = os.path.join(root, folder)
            if os.path.isdir(folder_path):
                print(f"Deleting folder: {folder_path}")
                shutil.rmtree(folder_path)
        # Delete specified files
        for file in files_to_delete:
            file_path = os.path.join(root, file)
            if os.path.isfile(file_path):
                print(f"Deleting file: {file_path}")
                os.remove(file_path)
    # Delete vocals folders. This pass honours the input_dir argument rather than a
    # module-level constant, so the function cleans the directory it was asked to clean.
    for root, dirs, _files in os.walk(input_dir):
        for dir_name in list(dirs):
            if dir_name.endswith('_vocals'):
                dir_path = os.path.join(root, dir_name)
                print(f"Deleting folder: {dir_path}")
                shutil.rmtree(dir_path)
                # Prune so os.walk does not try to descend into the removed folder
                dirs.remove(dir_name)
    print("Cleanup completed.")
def process_audio(uploaded_file, link):
    """
    Main function to process an uploaded audio file or a linked video.

    Runs the separation models in sequence, reorganizes the produced stems, combines
    them into one MP3 and cleans up the intermediate files. Any exception is caught,
    logged and reported as the final status message.
    Args:
        uploaded_file: Uploaded file object or path; takes precedence over `link`
        link (str): Video URL to download the audio from when no file was uploaded
    Yields:
        tuple: (status_message, output_file_path); output_file_path is None for every
        update except the final successful one
    """
    try:
        yield "Processing audio...", None
        if uploaded_file:
            source_file = uploaded_file
        elif link:
            source_file = download_youtube_audio(link)
        else:
            raise ValueError("Please upload an audio file or provide a YouTube link.")
        input_path, formatted_title = handle_file_upload(source_file)
        if input_path is None:
            raise ValueError("File upload failed.")
        # Run inference for different models
        yield "Starting SCNet inference...", None
        proc_folder_direct("scnet", "configs/config_scnet_other.yaml", "results/model_scnet_other.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
        yield "Starting Mel Band Roformer inference...", None
        proc_folder_direct("mel_band_roformer", "configs/config_mel_band_roformer_vocals.yaml", "results/model_mel_band_roformer_vocals.ckpt", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER, extract_instrumental=True)
        yield "Starting HTDemucs inference...", None
        proc_folder_direct("htdemucs", "configs/config_htdemucs_bass.yaml", "results/model_htdemucs_bass.th", f"{INPUT_FOLDER}/wav", OUTPUT_FOLDER)
        # Rename instrumental file so the next model receives it under the song's name
        source_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}_instrumental.wav'
        destination_path = f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer/{formatted_title}.wav'
        os.rename(source_path, destination_path)
        yield "Starting BS Roformer inference...", None
        proc_folder_direct("bs_roformer", "configs/config_bs_roformer_instrumental.yaml", "results/model_bs_roformer_instrumental.ckpt", f'{OUTPUT_FOLDER}{formatted_title}/mel_band_roformer', OUTPUT_FOLDER)
        # Clean up and organize files
        yield "Moving input files...", None
        delete_input_files(INPUT_FOLDER)
        yield "Moving stems to parent...", None
        move_stems_to_parent(OUTPUT_FOLDER)
        yield "Combining stems...", None
        output_file = combine_stems_for_all(OUTPUT_FOLDER, "mp3")
        yield "Cleaning up...", None
        delete_folders_and_files(OUTPUT_FOLDER)
        # Do not report success when there is no file to hand back
        if output_file is None:
            raise RuntimeError("No combined audio file was produced.")
        yield "Audio processing completed successfully.", output_file
    except Exception as e:
        # Top-level boundary of the UI callback: report the failure instead of crashing
        error_msg = f"An error occurred: {str(e)}\n{traceback.format_exc()}"
        logging.error(error_msg)
        yield error_msg, None
# Set up Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Music Player and Processor")
    # The URL box is created read-only (interactive=False); it is still wired in as the
    # second input of process_audio below.
    youtube_url = gr.Textbox(
        label="YouTube Song URL",
        placeholder="This feature is currently disabled. You cannot input a URL.",
        interactive=False
    )
    # Uploads are restricted to .mp3; handle_file_upload converts the file to WAV.
    file_upload = gr.File(label="Upload MP3 file", file_types=[".mp3"])
    process_button = gr.Button("Process Audio")
    log_output = gr.Textbox(label="Processing Log", interactive=False)
    processed_audio_output = gr.File(label="Processed Audio")
    # process_audio is a generator: each yielded (status, file) pair is mapped onto the
    # two outputs in order.
    process_button.click(
        fn=process_audio,
        inputs=[file_upload, youtube_url],
        outputs=[log_output, processed_audio_output],
        show_progress=True
    )
# Launch the Gradio app
# NOTE(review): runs at import time (no `if __name__ == "__main__"` guard) — presumably
# intentional for a Spaces entry script; confirm before importing this module elsewhere.
demo.launch()