move to aiohttp

This commit is contained in:
Martin Pitt 2024-07-19 18:12:41 +02:00
parent e34f982375
commit 577ec26307

44
bidi.py Normal file → Executable file
View file

@ -1,3 +1,5 @@
#!/usr/bin/python3
import asyncio
import json
import logging
@ -6,7 +8,7 @@ import time
import urllib.request
from dataclasses import dataclass
import websockets
import aiohttp
logger = logging.getLogger(__name__)
@ -80,34 +82,35 @@ class WebdriverBidi:
async def ws_reader(self) -> None:
assert self.ws
while True:
try:
msg = json.loads(await self.ws.recv())
except websockets.exceptions.ConnectionClosedOK:
logger.debug("ws_reader connection closed")
break
logger.debug("ws → %r", msg)
if "id" in msg and msg["id"] in self.pending_commands:
logger.debug("ws_reader: resolving pending command %i", msg["id"])
if msg["type"] == "success":
self.pending_commands[msg["id"]].set_result(msg["result"])
async for msg in self.ws:
if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data)
logger.debug("ws TEXT → %r", data)
if "id" in data and data["id"] in self.pending_commands:
logger.debug("ws_reader: resolving pending command %i", data["id"])
if data["type"] == "success":
self.pending_commands[data["id"]].set_result(data["result"])
else:
self.pending_commands[msg["id"]].set_exception(WebdriverError(f"{msg['type']}: {msg['message']}"))
del self.pending_commands[msg["id"]]
self.pending_commands[data["id"]].set_exception(
WebdriverError(f"{data['type']}: {data['message']}"))
del self.pending_commands[data["id"]]
continue
if msg["type"] == "event":
if msg["method"] == "log.entryAdded":
self.logs.append(LogMessage(msg["params"]))
if data["type"] == "event":
if data["method"] == "log.entryAdded":
self.logs.append(LogMessage(data["params"]))
continue
logger.warning("ws_reader: unhandled message %r", msg)
logger.warning("ws_reader: unhandled message %r", data)
elif msg.type == aiohttp.WSMsgType.ERROR:
logger.error("BiDi failure: %s", msg)
break
async def command(self, method, **params) -> asyncio.Future:
assert self.ws
payload = json.dumps({"id": self.last_id, "method": method, "params": params})
logger.debug("ws ← %r", payload)
await self.ws.send(payload)
await self.ws.send_str(payload)
future = asyncio.get_event_loop().create_future()
self.pending_commands[self.last_id] = future
self.last_id += 1
@ -115,7 +118,8 @@ class WebdriverBidi:
async def run(self):
# open bidi websocket for session
async with websockets.connect(self.session_info["capabilities"]["webSocketUrl"]) as ws:
async with aiohttp.ClientSession() as session:
async with session.ws_connect(self.session_info["capabilities"]["webSocketUrl"]) as ws:
self.ws = ws
self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader")