diff --git a/bidi.py b/bidi.py
index b0e994c..2937ac7 100755
--- a/bidi.py
+++ b/bidi.py
@@ -6,7 +6,6 @@ import logging
 import subprocess
 import sys
 import time
-import urllib.request
 from dataclasses import dataclass
 
 import aiohttp
@@ -46,17 +45,29 @@ class LogMessage:
         return f"LogMessage: {self.type} {self.level} @{self.timestamp}: {self.text} {self.args}"
 
 
+@dataclass
+class Session:
+    """State of an established webdriver session: HTTP client, BiDi websocket, session URL, reader task"""
+    http: aiohttp.client.ClientSession
+    ws: aiohttp.client.ClientWebSocketResponse
+    session_url: str
+    task_reader: asyncio.Task
+
+
 class WebdriverBidi:
     def __init__(self, browser, headless=False) -> None:
         self.headless = headless
+        self.last_id = 0
+        self.pending_commands: dict[int, asyncio.Future] = {}
+        self.logs: list[LogMessage] = []
+        self.session: Session | None = None
 
         # TODO: make dynamic
         self.webdriver_port = 12345
-        self.webdriver_url = f"http://localhost:{self.webdriver_port}"
 
         chrome_binary = "/usr/lib64/chromium-browser/headless_shell" if self.headless else "/usr/bin/chromium-browser"
 
-        session_args = {"capabilities": {
+        self.session_args = {"capabilities": {
             "alwaysMatch": {
                 "webSocketUrl": True,
                 "goog:chromeOptions": {"binary": chrome_binary},
@@ -69,32 +80,53 @@ class WebdriverBidi:
         except KeyError as e:
             raise ValueError(f"unknown browser {browser}") from e
 
-        req = urllib.request.Request(
-            f"{self.webdriver_url}/session",
-            json.dumps(session_args).encode(),
-            headers={"Content-Type": "application/json"})
+    async def ensure_session(self) -> None:
+        """Create the webdriver session and open its BiDi websocket, unless already done
+
+        Raises WebdriverError if the webdriver does not accept the session request.
+        """
+        if self.session is not None:
+            return
+
+        aiohttp_session = aiohttp.ClientSession(raise_for_status=True)
+        wd_url = f"http://localhost:{self.webdriver_port}"
+
         # webdriver needs some time to launch
         for retry in range(1, 10):
             try:
-                with urllib.request.urlopen(req) as f:
-                    resp = json.load(f)
+                async with aiohttp_session.post(f"{wd_url}/session",
+                                                data=json.dumps(self.session_args).encode()) as resp:
+                    session_info = json.loads(await resp.text())["value"]
+                    logger.debug("webdriver session request: %r %r", resp, session_info)
                 break
-            except urllib.error.URLError as e:
+            except (IOError, aiohttp.client.ClientResponseError) as e:
                 logger.debug("waiting for webdriver: %s", e)
-                time.sleep(0.1 * retry)
+                # back off without blocking the event loop
+                await asyncio.sleep(0.1 * retry)
         else:
+            # don't leak the HTTP client when giving up
+            await aiohttp_session.close()
             raise WebdriverError("could not connect to webdriver")
 
-        self.session_info = resp["value"]
-        self.last_id = 0
-        self.ws = None
-        self.session = None
-        self.pending_commands: dict[int, asyncio.Future] = {}
-        self.logs: list[LogMessage] = []
+        try:
+            ws = await aiohttp_session.ws_connect(session_info["capabilities"]["webSocketUrl"])
+        except BaseException:
+            # don't leak the HTTP client if the websocket cannot be opened
+            await aiohttp_session.close()
+            raise
 
-    async def ws_reader(self) -> None:
-        assert self.ws
-        async for msg in self.ws:
+        self.session = Session(
+            http=aiohttp_session,
+            ws=ws,
+            session_url=f"{wd_url}/session/{session_info['sessionId']}",
+            task_reader=asyncio.create_task(self.ws_reader(ws), name="bidi_reader")
+        )
+
+        logger.debug("Established session %r", self.session)
+
+    async def ws_reader(self, ws: aiohttp.client.ClientWebSocketResponse) -> None:
+        """Process incoming messages from the BiDi websocket *ws*"""
+        async for msg in ws:
             if msg.type == aiohttp.WSMsgType.TEXT:
                 data = json.loads(msg.data)
                 logger.debug("ws TEXT → %r", data)
@@ -121,78 +153,86 @@ class WebdriverBidi:
 
     async def bidi(self, method, **params) -> asyncio.Future:
         """Send a Webdriver BiDI command and return the JSON response"""
-        assert self.ws
+        assert self.session
+
         payload = json.dumps({"id": self.last_id, "method": method, "params": params})
         logger.debug("ws ← %r", payload)
-        await self.ws.send_str(payload)
+        await self.session.ws.send_str(payload)
 
         future = asyncio.get_event_loop().create_future()
         self.pending_commands[self.last_id] = future
         self.last_id += 1
         return await future
 
-    async def webdriver(self, path: str, data: dict | None = None) -> dict:
+    async def webdriver(self, path: str, data: dict | None = None, method: str | None = None) -> dict:
         """Send a classic Webdriver request and return the JSON response"""
         assert self.session
 
-        # asyncio shares the connection
-        url = f"{self.webdriver_url}/session/{self.session_info['sessionId']}/{path}"
-        post_data = json.dumps(data).encode() if data is not None else None
-        method = "POST" if post_data is not None else "GET"
-        async with self.session.request(method, url, data=post_data) as resp:
+        # asyncio shares the connection
+        post_data = json.dumps(data).encode() if data is not None else None
+        method = method if method else ("POST" if post_data is not None else "GET")
+
+        async with self.session.http.request(method, f"{self.session.session_url}{path}", data=post_data) as resp:
             r = await resp.text()
             logger.debug("webdriver %s %s %r → %r", method, path, post_data, r)
             return json.loads(r)
 
     async def run(self):
-        # open bidi websocket for session
-        async with aiohttp.ClientSession(raise_for_status=True) as session:
-            self.session = session
-            async with session.ws_connect(self.session_info["capabilities"]["webSocketUrl"]) as ws:
-                self.ws = ws
-                self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader")
+        await self.ensure_session()
 
-                # wait for browser to initialize default context
-                for _ in range(10):
-                    realms = (await self.bidi("script.getRealms"))["realms"]
-                    if len(realms) > 0:
-                        context = realms[0]["context"]
-                        break
-                else:
-                    raise WebdriverError("timed out waiting for default realm")
+        # wait for browser to initialize default context
+        for _ in range(10):
+            realms = (await self.bidi("script.getRealms"))["realms"]
+            if len(realms) > 0:
+                context = realms[0]["context"]
+                break
+        else:
+            raise WebdriverError("timed out waiting for default realm")
 
-                await self.bidi("session.subscribe", events=["log.entryAdded"])
+        await self.bidi("session.subscribe", events=["log.entryAdded"])
 
-                await self.bidi("script.evaluate", expression="console.log('Hello BiDi')",
-                                awaitPromise=False, target={"context": context})
-                await self.bidi("browsingContext.navigate", context=context,
-                                url="https://piware.de", wait="complete")
+        await self.bidi("script.evaluate", expression="console.log('Hello BiDi')",
+                        awaitPromise=False, target={"context": context})
+        await self.bidi("browsingContext.navigate", context=context,
+                        url="https://piware.de", wait="complete")
 
-                r = (await self.bidi("browsingContext.locateNodes", context=context,
-                                     locator={"type": "css", "value": "#menu-content"}))["nodes"]
-                assert len(r) == 1
-                menu_content_id = r[0]['sharedId']
+        r = (await self.bidi("browsingContext.locateNodes", context=context,
+                             locator={"type": "css", "value": "#menu-content"}))["nodes"]
+        assert len(r) == 1
+        menu_content_id = r[0]['sharedId']
 
-                # this doensn't yet have a BiDi command
-                r = await self.webdriver(f"element/{menu_content_id}/text")
-                assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value']
+        # this doesn't yet have a BiDi command
+        r = await self.webdriver(f"/element/{menu_content_id}/text")
+        assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value']
 
-                # locate first social link
-                r = (await self.bidi("browsingContext.locateNodes", context=context,
-                                     locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"]
-                assert len(r) == 1
-                # click it (again, no BiDi command)
-                await self.webdriver(f"element/{r[0]['sharedId']}/click", {})
+        # locate first social link
+        r = (await self.bidi("browsingContext.locateNodes", context=context,
+                             locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"]
+        assert len(r) == 1
+        # click it (again, no BiDi command)
+        await self.webdriver(f"/element/{r[0]['sharedId']}/click", {})
 
-                if not self.headless:
-                    await asyncio.sleep(3)
-                self.task_reader.cancel()
-                del self.task_reader
+        if not self.headless:
+            await asyncio.sleep(3)
 
-    def __del__(self):
-        logger.debug("cleaning up webdriver")
-        urllib.request.urlopen(urllib.request.Request(
-            f"{self.webdriver_url}/session/{self.session_info['sessionId']}", method="DELETE"))
+    async def __aenter__(self):
+        """Enter the context; the session itself is created lazily by ensure_session()"""
+        return self
+
+    async def __aexit__(self, *_excinfo):
+        """Delete the webdriver session and close its connections"""
+        if self.session is not None:
+            logger.debug("cleaning up webdriver")
+
+            self.session.task_reader.cancel()
+
+            try:
+                await self.webdriver("", method="DELETE")
+            finally:
+                # release the connections even if deleting the session fails
+                await self.session.ws.close()
+                await self.session.http.close()
+                self.session = None
 
         logger.info("Collected debug messages:")
         for log in self.logs:
@@ -202,7 +242,11 @@ class WebdriverBidi:
         self.driver.wait()
 
 
-logging.basicConfig(level=logging.DEBUG)
-d = WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium',
-                  headless=True if len(sys.argv) > 2 and sys.argv[2] == 'headless' else False)
-asyncio.run(d.run())
+async def main():
+    """Run the demo: argv[1] selects the browser, argv[2] == 'headless' enables headless mode"""
+    logging.basicConfig(level=logging.DEBUG)
+    async with WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium',
+                             headless=len(sys.argv) > 2 and sys.argv[2] == 'headless') as d:
+        await d.run()
+
+asyncio.run(main())