Created February 18, 2026 17:10
-
-
Save MahmoudAshraf97/504eb60dd19ea352728665ae74a51d05 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
| import argparse | |
| import asyncio | |
| import json | |
| import uuid | |
| import wave | |
| from websockets import connect, WebSocketException | |
# Stream and connection settings used throughout this script.
SAMPLE_RATE: int = 16000  # Hz; send_audio rejects WAV files recorded at any other rate
CHANNELS: int = 1  # send_audio rejects WAV files with a different channel count
PORT: int = 4001  # local WebSocket server port (referenced by the commented-out localhost URI)
EXIT_TIMEOUT: int = 2  # seconds inactivity_watch waits after the audio is sent before END_STREAM
class ClientEvents:
    """Names placed in the "event" field of messages this client sends."""

    START_STREAM: str = "START_STREAM"  # first message; its "data" is the client config
    END_STREAM: str = "END_STREAM"  # sent with empty "data" after the audio has been streamed
class RecitationMode:
    """Values for the "recitationMode" field of the START_STREAM client config."""

    DETECTION: str = "DETECTION"  # the mode evaluate() requests
    FOLLOW_ALONG: str = "FOLLOW_ALONG"  # not requested by evaluate()
async def send_audio(ws, audio_file, start_stream_data, shared, *, chunk_size=0.05):
    """Validate a WAV file, send START_STREAM, then stream its frames at real-time pace.

    Args:
        ws: Open WebSocket connection; only its ``send`` coroutine is used.
        audio_file: Path or binary file object accepted by ``wave.open``.
        start_stream_data: JSON-serialisable START_STREAM payload, sent first.
        shared: Dict whose ``"audio_sent_ms"`` entry is increased by the real
            duration of each chunk immediately after that chunk is sent, so a
            concurrent reader computing latencies sees every chunk already sent.
        chunk_size: Seconds of audio per binary message (default 0.05 = 50 ms).

    Raises:
        ValueError: If the file's sample rate or channel count differs from
            SAMPLE_RATE / CHANNELS. Nothing is sent in that case.
    """
    with wave.open(audio_file, "rb") as wf:
        if wf.getframerate() != SAMPLE_RATE:
            raise ValueError(f"Invalid sample rate: {wf.getframerate()}, expected {SAMPLE_RATE}")
        if wf.getnchannels() != CHANNELS:
            raise ValueError(f"Invalid channel count: {wf.getnchannels()}, expected {CHANNELS}")
        # NOTE(review): sample width is not validated; the server presumably
        # expects 16-bit PCM — confirm before streaming other formats.

        await ws.send(json.dumps(start_stream_data))

        frames_per_chunk = max(1, round(chunk_size * SAMPLE_RATE))
        bytes_per_frame = wf.getsampwidth() * wf.getnchannels()
        loop = asyncio.get_running_loop()
        started = loop.time()
        sent_ms = 0.0
        while chunk := wf.readframes(frames_per_chunk):
            await ws.send(chunk)
            # The last chunk is usually shorter than chunk_size, so count the
            # frames actually sent rather than a fixed chunk duration.
            chunk_ms = (len(chunk) // bytes_per_frame) * 1000 / SAMPLE_RATE
            sent_ms += chunk_ms
            shared["audio_sent_ms"] += chunk_ms
            # Sleep until this audio's real-time end so per-send overhead does
            # not accumulate into drift.
            await asyncio.sleep(max(0.0, started + sent_ms / 1000 - loop.time()))
def _format_position(pos):
    """Render a server position dict as "surah:ayah:word"."""
    return f"{pos['surahNumber']}:{pos['ayahNumber']}:{pos['wordNumber']}"


async def receive_messages(ws, shared):
    """Consume server messages until the connection ends, printing each event.

    Args:
        ws: Async-iterable WebSocket connection yielding JSON messages.
        shared: Dict whose ``"audio_sent_ms"`` entry is advanced by send_audio;
            it is read here to compute latencies.

    Returns:
        Tuple ``(mistakes, first_identified_position, state_latencies,
        mistake_latencies)``:
        - mistakes: mistake id -> latest mistake dict, omitting ids whose
          latest update was None.
        - first_identified_position: "surah:ayah:word" of the first state that
          carried a position, or None.
        - state_latencies: for each STATES_UPDATE with new states,
          audio_sent_ms minus the last new state's ``endTime`` (assumes
          ``endTime`` is in ms like ``endTimeMs`` — TODO confirm).
        - mistake_latencies: mistake id -> audio_sent_ms minus ``endTimeMs``
          when the mistake was (re-)reported, for ids still present in mistakes.
    """
    mistakes = {}
    first_identified_position = None
    state_latencies = []
    mistake_latencies = {}
    try:
        async for message in ws:
            try:
                result = json.loads(message)
            except json.JSONDecodeError:
                print(f"[UNKNOWN] {message}")
                continue
            if not isinstance(result, dict):
                # Valid JSON but not an event object; nothing to dispatch on.
                print(f"[UNKNOWN] {message}")
                continue

            evt = result.get("event")
            # Tolerate both a missing and an explicit null "data" field.
            data = result.get("data") or {}

            if evt == "STATES_UPDATE":
                mistake_updates = data.get("mistakeUpdates") or {}
                new_states = data.get("newStates") or []

                for mistake_id, mistake in mistake_updates.items():
                    print(mistake_id, mistake)
                    mistakes[mistake_id] = mistake
                    if mistake is None:
                        mistake_latencies[mistake_id] = None
                    elif mistake_latencies.get(mistake_id) is None:
                        # First report, or re-report after a None update: measure now.
                        mistake_latencies[mistake_id] = (
                            shared["audio_sent_ms"] - mistake["endTimeMs"]
                        )

                if first_identified_position is None:
                    for state in new_states:
                        if state.get("position"):
                            first_identified_position = _format_position(state["position"])
                            break

                for m in mistake_updates.values():
                    if m is not None:
                        print(
                            f"[{m['mistakeType']}] E: {m['expectedTranscript']} / R: {m['receivedTranscript']}"
                            f" / {m['startTimeMs']}-{m['endTimeMs']}"
                        )

                for state in new_states:
                    pos = state.get("position")
                    position = _format_position(pos) if pos else "0:0:0"
                    print(
                        f"[{state['type']}] {position} {state['word']} ({state['startTime']}-{state['endTime']})"
                    )

                if new_states:
                    print("audio sent ms", shared["audio_sent_ms"])
                    state_latencies.append(shared["audio_sent_ms"] - new_states[-1]["endTime"])
            elif evt == "PARTIAL_TRANSCRIPT":
                print(f"[{evt}] {data.get('queryText')}")
            elif evt in ("GOT_LOST", "ERROR"):
                print(f"[{evt}] {data if evt == 'ERROR' else ''}")
            else:
                print(f"[UNKNOWN_EVENT] {result}")
            print("=" * 20)
    except WebSocketException as e:
        print(f"WebSocket error: {e}")

    return (
        {k: v for k, v in mistakes.items() if v is not None},
        first_identified_position,
        state_latencies,
        {k: v for k, v in mistake_latencies.items() if v is not None},
    )
async def inactivity_watch(ws, end_stream_data):
    """Wait a fixed grace period, then send END_STREAM and close the socket.

    Despite the name, no activity is tracked: this sleeps for EXIT_TIMEOUT
    seconds unconditionally (evaluate() calls it once the audio has been sent),
    sends ``end_stream_data`` as JSON, and closes ``ws``.
    """
    grace_seconds = EXIT_TIMEOUT
    await asyncio.sleep(grace_seconds)
    payload = json.dumps(end_stream_data)
    await ws.send(payload)
    await ws.close()
async def evaluate(
    audio_file: str,
    verbose: bool = True,
    *,
    uri: str = "wss://voice-v2-dev.tarteel.io",
    # NOTE(security): this token is committed to source (and published with it);
    # it should be rotated and supplied by the caller instead of this default.
    auth_token: str = "4b70b75dc4d77118cd63adb3acbbc5d7eeca65bb",
):
    """Stream a WAV file to the recitation server and summarise its responses.

    Opens a WebSocket to ``uri`` and streams the audio (send_audio, which first
    sends START_STREAM) while concurrently collecting events (receive_messages).
    It then calls inactivity_watch to send END_STREAM and close the socket, and
    waits for the receiver to finish.

    Args:
        audio_file: Path to a WAV file matching SAMPLE_RATE and CHANNELS.
        verbose: Print the first identified position, each mistake, and the
            average state/mistake latencies at the end. (receive_messages prints
            every incoming event regardless of this flag.)
        uri: WebSocket endpoint. For a local server use f"ws://localhost:{PORT}".
        auth_token: Value sent as "authToken" in the client config.

    Returns:
        A ``dict_values`` view of the mistake dicts whose latest update was not None.

    Raises:
        Whatever connecting or send_audio raises (e.g. ValueError for a WAV
        with the wrong format); the receiver task is cancelled first.
    """
    client_config = {
        "appVersion": "dev",
        "audioConfig": {
            "fileFormat": "WAV",
            "channels": CHANNELS,
            "sampleRate": SAMPLE_RATE,
            "modelName": None,
        },
        "authToken": auth_token,
        "deviceId": "123",
        "devicePlatform": "WEB",
        "recitationMode": RecitationMode.DETECTION,
        "sessionId": str(uuid.uuid4()),
        "isDiacritized": True,
        "isMemorization": False,
        "shouldCollectAudio": False,
        "shouldLabelAudio": False,
        # Previously used alternative: mistakeReportingTimeLag=0,
        # isNewSttServer=True, isDualModel=False.
        "mistakeReportingTimeLag": 800,
        "isNewSttServer": False,
        "isDualModel": True,
    }
    start_stream_data = {"event": ClientEvents.START_STREAM, "data": client_config}
    end_stream_data = {"event": ClientEvents.END_STREAM, "data": {}}
    shared = {"audio_sent_ms": 0}

    async with connect(uri, close_timeout=200, open_timeout=200) as ws:
        send_task = asyncio.create_task(send_audio(ws, audio_file, start_stream_data, shared))
        recv_task = asyncio.create_task(receive_messages(ws, shared))
        try:
            await send_task
            # Grace period for trailing server events, then END_STREAM + close.
            await inactivity_watch(ws, end_stream_data)
        except BaseException:
            # Don't leave the receiver running (or its outcome unretrieved)
            # when sending fails or we are cancelled.
            recv_task.cancel()
            await asyncio.gather(recv_task, return_exceptions=True)
            raise
        # Drain any remaining messages until the connection finishes closing.
        mistakes, first_identified_position, state_latencies, mistake_latencies = await recv_task

    if verbose:
        print(f"First identified position: {first_identified_position}")
        for mistake in mistakes.values():
            print(
                f"[{mistake['mistakeType']}] E: {mistake['expectedTranscript']} / R: {mistake['receivedTranscript']}"
                f" / {mistake['startTimeMs']}-{mistake['endTimeMs']}"
            )
        if state_latencies:
            avg_state = sum(state_latencies) / len(state_latencies)
            print(f"Avg state latency: {avg_state:.1f}ms ({len(state_latencies)} states)")
        if mistake_latencies:
            avg_mistake = sum(mistake_latencies.values()) / len(mistake_latencies)
            print(f"Avg mistake latency: {avg_mistake:.1f}ms ({len(mistake_latencies)} mistakes)")
    return mistakes.values()
if __name__ == "__main__":
    # `exit` is injected by the `site` module and is absent under `python -S`;
    # sys.exit is the reliable way to set the process exit status.
    import sys

    parser = argparse.ArgumentParser(description="Async WAV-to-WebSocket streamer")
    parser.add_argument("file", help="Path to WAV audio file (16kHz, mono)")
    args = parser.parse_args()
    try:
        asyncio.run(evaluate(args.file))
    except Exception as e:
        # Top-level boundary: report on stderr so it isn't mixed into the
        # event log printed on stdout, then exit non-zero.
        print(f"Error: {e}", file=sys.stderr)
        sys.exit(1)
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment