Trynitzan committed on
Commit
33c162a
·
verified ·
1 Parent(s): 51ac236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -37
app.py CHANGED
@@ -10,24 +10,24 @@ import torch
10
  from pathlib import Path
11
  from datetime import datetime
12
 
13
- # קריאת טוכן מ-Secrets
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if not HF_TOKEN:
16
- raise RuntimeError("HF_TOKEN environment variable is required")
17
 
18
- # טעינת המודל
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
- print(f"🚀 Loading model on {device}...")
21
 
22
  try:
23
  pipeline = Pipeline.from_pretrained(
24
  "ivrit-ai/pyannote-speaker-diarization-3.1",
25
  use_auth_token=HF_TOKEN,
26
  )
27
- pipeline.to(torch.device(device)) # ✅ תיקון כאן!
28
- print("Model loaded successfully!")
29
  except Exception as e:
30
- print(f"Failed to load model: {e}")
31
  raise
32
 
33
  app = FastAPI(
@@ -36,18 +36,18 @@ app = FastAPI(
36
  version="1.0.0"
37
  )
38
 
39
- # הגבלות
40
  MAX_FILE_SIZE_MB = 50
41
  MAX_DURATION_MINUTES = 15
42
  MAX_CONCURRENT_REQUESTS = 2
43
 
44
- # ניהול תור
45
  processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
46
  active_requests = 0
47
 
48
 
49
  def ensure_wav_16k_mono(in_path: str) -> str:
50
- """ממיר אודיו ל-WAV 16kHz mono"""
51
  out_path = str(Path(in_path).with_suffix(".wav"))
52
  cmd = [
53
  "ffmpeg", "-y", "-i", in_path,
@@ -62,7 +62,7 @@ def ensure_wav_16k_mono(in_path: str) -> str:
62
 
63
 
64
  def estimate_duration(file_path: str) -> float:
65
- """אומדן אורך קובץ בדקות"""
66
  try:
67
  file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
68
  return file_size_mb / 2.0
@@ -72,7 +72,7 @@ def estimate_duration(file_path: str) -> float:
72
 
73
  @app.get("/")
74
  def root():
75
- """מידע על ה-API"""
76
  global active_requests
77
  return {
78
  "service": "Hebrew Speaker Diarization API",
@@ -99,7 +99,7 @@ def root():
99
 
100
  @app.get("/health")
101
  def health():
102
- """בדיקת בריאות"""
103
  global active_requests
104
  return {
105
  "status": "healthy",
@@ -113,13 +113,13 @@ def health():
113
  @app.post("/diarize")
114
  async def diarize(file: UploadFile = File(...)):
115
  """
116
- זיהוי דוברים בקובץ אודיו
117
 
118
  Args:
119
- file: קובץ אודיו (MP3, WAV, M4A, וכו')
120
 
121
  Returns:
122
- JSON: רשימת מקטעים עם זיהוי דוברים
123
  """
124
  global active_requests
125
 
@@ -160,13 +160,13 @@ async def diarize(file: UploadFile = File(...)):
160
  detail=f"File too long: ~{duration:.1f} min (max: {MAX_DURATION_MINUTES} min)"
161
  )
162
 
163
- print(f"🎤 Processing: {file.filename} ({file_size_mb:.1f}MB)")
164
  start_time = datetime.now()
165
 
166
  annotation = pipeline(wav_path)
167
 
168
  processing_time = (datetime.now() - start_time).total_seconds()
169
- print(f"Done in {processing_time:.1f}s")
170
 
171
  segments = []
172
  last_segment = None
@@ -209,7 +209,7 @@ async def diarize(file: UploadFile = File(...)):
209
  except HTTPException:
210
  raise
211
  except Exception as e:
212
- print(f"Error: {str(e)}")
213
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
214
  finally:
215
  for path in [tmp_path, wav_path]:
@@ -222,21 +222,4 @@ async def diarize(file: UploadFile = File(...)):
222
 
223
  if __name__ == "__main__":
224
  import uvicorn
225
- uvicorn.run(app, host="0.0.0.0", port=7860)
226
- ```
227
-
228
- ---
229
-
230
- ## 🚀 עכשיו זה אמור לעבוד!
231
-
232
- 1. **עדכן את `app.py`** ב-Space
233
- 2. שמור
234
- 3. המתן לבנייה (~1-2 דקות - כי יש cache)
235
- 4. בדוק `/health`
236
-
237
- אחרי התיקון הזה, אמור לראות בלוגים:
238
- ```
239
- 🚀 Loading model on cpu...
240
- ✅ Model loaded successfully!
241
- INFO: Started server process
242
- INFO: Uvicorn running on http://0.0.0.0:7860
 
10
  from pathlib import Path
11
  from datetime import datetime
12
 
13
+ # Read token from environment
14
  HF_TOKEN = os.getenv("HF_TOKEN")
15
  if not HF_TOKEN:
16
+ raise RuntimeError("HF_TOKEN environment variable is required")
17
 
18
+ # Load model
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
20
+ print(f"Loading model on {device}...")
21
 
22
  try:
23
  pipeline = Pipeline.from_pretrained(
24
  "ivrit-ai/pyannote-speaker-diarization-3.1",
25
  use_auth_token=HF_TOKEN,
26
  )
27
+ pipeline.to(torch.device(device))
28
+ print("Model loaded successfully!")
29
  except Exception as e:
30
+ print(f"Failed to load model: {e}")
31
  raise
32
 
33
  app = FastAPI(
 
36
  version="1.0.0"
37
  )
38
 
39
+ # Limits
40
  MAX_FILE_SIZE_MB = 50
41
  MAX_DURATION_MINUTES = 15
42
  MAX_CONCURRENT_REQUESTS = 2
43
 
44
+ # Queue management
45
  processing_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
46
  active_requests = 0
47
 
48
 
49
  def ensure_wav_16k_mono(in_path: str) -> str:
50
+ """Convert audio to WAV 16kHz mono"""
51
  out_path = str(Path(in_path).with_suffix(".wav"))
52
  cmd = [
53
  "ffmpeg", "-y", "-i", in_path,
 
62
 
63
 
64
  def estimate_duration(file_path: str) -> float:
65
+ """Estimate file duration in minutes"""
66
  try:
67
  file_size_mb = os.path.getsize(file_path) / (1024 * 1024)
68
  return file_size_mb / 2.0
 
72
 
73
  @app.get("/")
74
  def root():
75
+ """API information"""
76
  global active_requests
77
  return {
78
  "service": "Hebrew Speaker Diarization API",
 
99
 
100
  @app.get("/health")
101
  def health():
102
+ """Health check"""
103
  global active_requests
104
  return {
105
  "status": "healthy",
 
113
  @app.post("/diarize")
114
  async def diarize(file: UploadFile = File(...)):
115
  """
116
+ Speaker diarization for audio file
117
 
118
  Args:
119
+ file: Audio file (MP3, WAV, M4A, etc.)
120
 
121
  Returns:
122
+ JSON: List of segments with speaker identification
123
  """
124
  global active_requests
125
 
 
160
  detail=f"File too long: ~{duration:.1f} min (max: {MAX_DURATION_MINUTES} min)"
161
  )
162
 
163
+ print(f"Processing: {file.filename} ({file_size_mb:.1f}MB)")
164
  start_time = datetime.now()
165
 
166
  annotation = pipeline(wav_path)
167
 
168
  processing_time = (datetime.now() - start_time).total_seconds()
169
+ print(f"Done in {processing_time:.1f}s")
170
 
171
  segments = []
172
  last_segment = None
 
209
  except HTTPException:
210
  raise
211
  except Exception as e:
212
+ print(f"Error: {str(e)}")
213
  raise HTTPException(status_code=500, detail=f"Processing error: {str(e)}")
214
  finally:
215
  for path in [tmp_path, wav_path]:
 
222
 
223
  if __name__ == "__main__":
224
  import uvicorn
225
+ uvicorn.run(app, host="0.0.0.0", port=7860)