Spaces:
Sleeping
Sleeping
File size: 6,182 Bytes
73edc95 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 |
#!/usr/bin/env python3
"""
Concurrency test for benchmark environment using WebSockets.
Each WebSocket connection gets its own dedicated environment session,
enabling true concurrent execution across multiple sessions.
Run the server first:
cd benchmark && uvicorn server.app:app --port 8000
Then run this script:
python test_concurrency.py --requests 100 --wait 1.0
python test_concurrency.py -n 100 -w 1 --url wss://your-server.hf.space
"""
import argparse
import asyncio
import json
import time
from dataclasses import dataclass
try:
import websockets
except ImportError:
print("Install websockets: pip install websockets")
raise
@dataclass
class RequestResult:
    """Outcome of one WebSocket session (connect, reset, step, close).

    The last four fields are copied from the ``observation`` object of the
    server's step response; each falls back to a default when the server
    omits the key.
    """

    # Index of this session within the batch.
    request_id: int
    # Wait duration (seconds) sent to the server in the step request.
    wait_requested: float
    # Wait duration (seconds) reported back by the server; 0.0 if absent.
    waited_seconds: float
    # Client-side wall-clock duration (seconds) of the whole session.
    elapsed: float
    # "pid" value reported by the server; 0 if absent.
    pid: int
    # "session_hash" value reported by the server; "" if absent.
    session_hash: str
    # "host_url" value reported by the server; "" if absent.
    host_url: str
def convert_to_ws_url(url: str) -> str:
    """Convert a server URL to the WebSocket endpoint URL ending in ``/ws``.

    Scheme mapping: ``http://`` -> ``ws://``, ``https://`` -> ``wss://``;
    ``ws://`` and ``wss://`` are kept; a URL with none of these schemes is
    treated as a bare ``host[:port][/path]`` and gets ``ws://``.

    Trailing slashes are stripped, and ``/ws`` is appended unless the part
    after the scheme already ends with ``/ws``. That check is applied
    uniformly to every scheme, so ``http://host/ws`` no longer becomes
    ``ws://host/ws/ws``.

    Args:
        url: Server address, with or without a scheme.

    Returns:
        The ``ws://`` or ``wss://`` URL of the ``/ws`` endpoint.
    """
    url = url.rstrip("/")

    # (input prefix, output scheme) pairs.
    scheme_map = (
        ("https://", "wss://"),
        ("http://", "ws://"),
        ("wss://", "wss://"),
        ("ws://", "ws://"),
    )
    for prefix, ws_scheme in scheme_map:
        if url.startswith(prefix):
            rest = url[len(prefix):]
            break
    else:
        # No recognised scheme: assume a bare host and use plain ws://.
        ws_scheme, rest = "ws://", url

    # Test the suffix on the part after the scheme only; otherwise a host
    # literally named "ws" ("ws://ws") would match the "//" of the scheme
    # and be returned without any path.
    if not rest.endswith("/ws"):
        rest += "/ws"
    return ws_scheme + rest
async def ws_session(
    ws_url: str,
    request_id: int,
    wait_seconds: float,
    timeout: float = 60.0,
) -> RequestResult:
    """
    Drive one full WebSocket session: connect, reset, step, close.

    A fresh connection is opened per call. The session sends a ``reset``
    message, then a ``step`` message carrying ``wait_seconds``, then a
    ``close`` message. ``timeout`` bounds the connection handshake and each
    individual receive; sends are not bounded.

    Raises:
        RuntimeError: if the reset or step reply has ``"type": "error"``.
        asyncio.TimeoutError: if a reply does not arrive within ``timeout``.
    """
    began = time.perf_counter()

    async with websockets.connect(ws_url, open_timeout=timeout) as conn:
        # Reset: initialise the environment for this connection.
        reset_msg = {"type": "reset", "data": {}}
        await conn.send(json.dumps(reset_msg))
        raw_reset = await asyncio.wait_for(conn.recv(), timeout)
        reset_reply = json.loads(raw_reset)
        if reset_reply.get("type") == "error":
            raise RuntimeError(f"Reset error: {reset_reply}")

        # Step: ask the server to wait for the requested duration.
        step_msg = {"type": "step", "data": {"wait_seconds": wait_seconds}}
        await conn.send(json.dumps(step_msg))
        raw_step = await asyncio.wait_for(conn.recv(), timeout)
        step_reply = json.loads(raw_step)
        if step_reply.get("type") == "error":
            raise RuntimeError(f"Step error: {step_reply}")

        # Tell the server we are done before the socket is torn down.
        await conn.send(json.dumps({"type": "close"}))

    elapsed = time.perf_counter() - began

    # Missing "data"/"observation" objects degrade to empty dicts so every
    # field below falls back to its default.
    observation = step_reply.get("data", {}).get("observation", {})
    return RequestResult(
        request_id=request_id,
        wait_requested=wait_seconds,
        waited_seconds=observation.get("waited_seconds", 0.0),
        elapsed=elapsed,
        pid=observation.get("pid", 0),
        session_hash=observation.get("session_hash", ""),
        host_url=observation.get("host_url", ""),
    )
async def run_concurrent_test(
    url: str,
    num_requests: int,
    wait_seconds: float,
    timeout: float = 120.0,
) -> dict:
    """Run ``num_requests`` WebSocket sessions concurrently and summarise them.

    Args:
        url: Server URL; converted with ``convert_to_ws_url``.
        num_requests: Number of sessions launched at once.
        wait_seconds: Wait duration each session asks the server for.
        timeout: Timeout (seconds) forwarded to every session.

    Returns:
        ``{"error": "All requests failed"}`` when no session succeeded.
        Otherwise a dict with the keys ``num_requests``, ``successful``,
        ``failed``, ``wait_seconds``, ``total_time``, ``avg_time``,
        ``unique_pids``, ``unique_sessions``, ``unique_hosts`` and ``pids``
        (at most 10 distinct pid values, in arbitrary order).
    """
    ws_url = convert_to_ws_url(url)
    print(f"WebSocket URL: {ws_url}")
    print(f"Running {num_requests} concurrent sessions, each waiting {wait_seconds}s...")
    print()

    start = time.perf_counter()
    # Launch all sessions concurrently. return_exceptions=True keeps one
    # failing session from aborting the rest.
    results = await asyncio.gather(
        *(ws_session(ws_url, i, wait_seconds, timeout) for i in range(num_requests)),
        return_exceptions=True,
    )
    total_time = time.perf_counter() - start

    successful = [r for r in results if isinstance(r, RequestResult)]
    # Everything that is not a RequestResult is a failure. Filtering on
    # Exception would drop BaseException results such as
    # asyncio.CancelledError, leaving successful + failed != num_requests.
    failed = [r for r in results if not isinstance(r, RequestResult)]

    if not successful:
        print("All requests failed!")
        for i, err in enumerate(failed[:5]):
            # repr() rather than str(): timeout errors stringify to "",
            # which would print a blank, useless message.
            print(f" Error {i}: {err!r}")
        return {"error": "All requests failed"}

    avg_time = sum(r.elapsed for r in successful) / len(successful)
    unique_pids = {r.pid for r in successful}
    unique_sessions = {r.session_hash for r in successful}
    unique_hosts = {r.host_url for r in successful}

    return {
        "num_requests": num_requests,
        "successful": len(successful),
        "failed": len(failed),
        "wait_seconds": wait_seconds,
        "total_time": total_time,
        "avg_time": avg_time,
        "unique_pids": len(unique_pids),
        "unique_sessions": len(unique_sessions),
        "unique_hosts": len(unique_hosts),
        "pids": list(unique_pids)[:10],  # First 10 for display
    }
async def main():
    """Parse command-line options, run the concurrency test, print a report.

    Returns without printing a report when the test reports an error.
    """
    cli = argparse.ArgumentParser(
        description="Test benchmark environment concurrency via WebSocket"
    )
    cli.add_argument(
        "--requests",
        "-n",
        type=int,
        default=10,
        help="Number of concurrent WebSocket sessions",
    )
    cli.add_argument(
        "--wait",
        "-w",
        type=float,
        default=1.0,
        help="Wait time per request (seconds)",
    )
    cli.add_argument(
        "--url",
        "-u",
        type=str,
        default="http://localhost:8000",
        help="Server URL (http/https/ws/wss)",
    )
    cli.add_argument(
        "--timeout",
        "-t",
        type=float,
        default=120.0,
        help="Timeout per request (seconds)",
    )
    opts = cli.parse_args()

    summary = await run_concurrent_test(
        opts.url, opts.requests, opts.wait, opts.timeout
    )
    if "error" in summary:
        return

    print(f"Successful: {summary['successful']}/{summary['num_requests']}")
    failure_count = summary["failed"]
    if failure_count:
        print(f"Failed: {failure_count}")
    print(f"Total time: {summary['total_time']:.3f}s")
    print(f"Avg time: {summary['avg_time']:.3f}s")
    print(f"Unique PIDs: {summary['unique_pids']}")
    print(f"Unique sessions: {summary['unique_sessions']}")
    print(f"Unique hosts: {summary['unique_hosts']}")

    # Concurrency metrics: with full concurrency the whole batch takes about
    # one wait period; the ratio of summed requested wait time to measured
    # wall time is the effective parallelism.
    ideal_time = opts.wait
    effective_concurrency = (opts.requests * opts.wait) / summary["total_time"]
    print()
    print(f"Ideal time (full concurrency): {ideal_time:.3f}s")
    print(f"Effective concurrency: {effective_concurrency:.1f}x")


if __name__ == "__main__":
    asyncio.run(main())
|