diff --git a/bidi.py b/bidi.py old mode 100644 new mode 100755 index 19ae842..37106b3 --- a/bidi.py +++ b/bidi.py @@ -1,3 +1,5 @@ +#!/usr/bin/python3 + import asyncio import json import logging @@ -6,7 +8,7 @@ import time import urllib.request from dataclasses import dataclass -import websockets +import aiohttp logger = logging.getLogger(__name__) @@ -80,34 +82,35 @@ class WebdriverBidi: async def ws_reader(self) -> None: assert self.ws - while True: - try: - msg = json.loads(await self.ws.recv()) - except websockets.exceptions.ConnectionClosedOK: - logger.debug("ws_reader connection closed") - break - logger.debug("ws → %r", msg) - if "id" in msg and msg["id"] in self.pending_commands: - logger.debug("ws_reader: resolving pending command %i", msg["id"]) - if msg["type"] == "success": - self.pending_commands[msg["id"]].set_result(msg["result"]) - else: - self.pending_commands[msg["id"]].set_exception(WebdriverError(f"{msg['type']}: {msg['message']}")) - del self.pending_commands[msg["id"]] - continue - - if msg["type"] == "event": - if msg["method"] == "log.entryAdded": - self.logs.append(LogMessage(msg["params"])) + async for msg in self.ws: + if msg.type == aiohttp.WSMsgType.TEXT: + data = json.loads(msg.data) + logger.debug("ws TEXT → %r", data) + if "id" in data and data["id"] in self.pending_commands: + logger.debug("ws_reader: resolving pending command %i", data["id"]) + if data["type"] == "success": + self.pending_commands[data["id"]].set_result(data["result"]) + else: + self.pending_commands[data["id"]].set_exception( + WebdriverError(f"{data['type']}: {data['message']}")) + del self.pending_commands[data["id"]] continue - logger.warning("ws_reader: unhandled message %r", msg) + if data["type"] == "event": + if data["method"] == "log.entryAdded": + self.logs.append(LogMessage(data["params"])) + continue + + logger.warning("ws_reader: unhandled message %r", data) + elif msg.type == aiohttp.WSMsgType.ERROR: + logger.error("BiDi failure: %s", msg) + break async def command(self, method, **params) -> asyncio.Future: assert self.ws payload = json.dumps({"id": self.last_id, "method": method, "params": params}) logger.debug("ws ← %r", payload) - await self.ws.send(payload) + await self.ws.send_str(payload) future = asyncio.get_event_loop().create_future() self.pending_commands[self.last_id] = future self.last_id += 1 @@ -115,19 +118,20 @@ class WebdriverBidi: async def run(self): # open bidi websocket for session - async with websockets.connect(self.session_info["capabilities"]["webSocketUrl"]) as ws: - self.ws = ws - self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader") + async with aiohttp.ClientSession() as session: + async with session.ws_connect(self.session_info["capabilities"]["webSocketUrl"]) as ws: + self.ws = ws + self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader") - await self.command("session.subscribe", events=["log.entryAdded"]) - context = (await self.command("browsingContext.create", type="tab"))["context"] - await self.command("script.evaluate", expression="console.log('Hello BiDi')", - awaitPromise=False, target={"context": context}) - await self.command("browsingContext.navigate", context=context, url="https://piware.de") + await self.command("session.subscribe", events=["log.entryAdded"]) + context = (await self.command("browsingContext.create", type="tab"))["context"] + await self.command("script.evaluate", expression="console.log('Hello BiDi')", + awaitPromise=False, target={"context": context}) + await self.command("browsingContext.navigate", context=context, url="https://piware.de") - await asyncio.sleep(5) - self.task_reader.cancel() - del self.task_reader + await asyncio.sleep(5) + self.task_reader.cancel() + del self.task_reader def __del__(self): logger.debug("cleaning up webdriver")