Clean up session tracking

This commit is contained in:
Martin Pitt 2024-07-22 10:30:21 +02:00
parent eb6fd8ac1d
commit 025731f0c8

112
bidi.py
View file

@ -6,7 +6,6 @@ import logging
import subprocess import subprocess
import sys import sys
import time import time
import urllib.request
from dataclasses import dataclass from dataclasses import dataclass
import aiohttp import aiohttp
@ -46,17 +45,28 @@ class LogMessage:
return f"LogMessage: {self.type} {self.level} @{self.timestamp}: {self.text} {self.args}" return f"LogMessage: {self.type} {self.level} @{self.timestamp}: {self.text} {self.args}"
@dataclass
class Session:
http: aiohttp.client.ClientSession
ws: aiohttp.client.ClientWebSocketResponse
session_url: str
task_reader: asyncio.Task
class WebdriverBidi: class WebdriverBidi:
def __init__(self, browser, headless=False) -> None: def __init__(self, browser, headless=False) -> None:
self.headless = headless self.headless = headless
self.last_id = 0
self.pending_commands: dict[int, asyncio.Future] = {}
self.logs: list[LogMessage] = []
self.session: Session | None = None
# TODO: make dynamic # TODO: make dynamic
self.webdriver_port = 12345 self.webdriver_port = 12345
self.webdriver_url = f"http://localhost:{self.webdriver_port}"
chrome_binary = "/usr/lib64/chromium-browser/headless_shell" if self.headless else "/usr/bin/chromium-browser" chrome_binary = "/usr/lib64/chromium-browser/headless_shell" if self.headless else "/usr/bin/chromium-browser"
session_args = {"capabilities": { self.session_args = {"capabilities": {
"alwaysMatch": { "alwaysMatch": {
"webSocketUrl": True, "webSocketUrl": True,
"goog:chromeOptions": {"binary": chrome_binary}, "goog:chromeOptions": {"binary": chrome_binary},
@ -69,32 +79,40 @@ class WebdriverBidi:
except KeyError as e: except KeyError as e:
raise ValueError(f"unknown browser {browser}") from e raise ValueError(f"unknown browser {browser}") from e
req = urllib.request.Request( async def ensure_session(self) -> None:
f"{self.webdriver_url}/session", if self.session is not None:
json.dumps(session_args).encode(), return
headers={"Content-Type": "application/json"})
aiohttp_session = aiohttp.ClientSession(raise_for_status=True)
wd_url = f"http://localhost:{self.webdriver_port}"
# webdriver needs some time to launch # webdriver needs some time to launch
for retry in range(1, 10): for retry in range(1, 10):
try: try:
with urllib.request.urlopen(req) as f: async with aiohttp_session.post(f"{wd_url}/session",
resp = json.load(f) data=json.dumps(self.session_args).encode()) as resp:
session_info = json.loads(await resp.text())["value"]
logger.debug("webdriver session request: %r %r", resp, session_info)
break break
except urllib.error.URLError as e: except (IOError, aiohttp.client.ClientResponseError) as e:
logger.debug("waiting for webdriver: %s", e) logger.debug("waiting for webdriver: %s", e)
time.sleep(0.1 * retry) time.sleep(0.1 * retry)
else: else:
raise WebdriverError("could not connect to webdriver") raise WebdriverError("could not connect to webdriver")
self.session_info = resp["value"] ws = await aiohttp_session.ws_connect(session_info["capabilities"]["webSocketUrl"])
self.last_id = 0
self.ws = None
self.session = None
self.pending_commands: dict[int, asyncio.Future] = {}
self.logs: list[LogMessage] = []
async def ws_reader(self) -> None: self.session = Session(
assert self.ws http=aiohttp_session,
async for msg in self.ws: ws=ws,
session_url=f"{wd_url}/session/{session_info['sessionId']}",
task_reader=asyncio.create_task(self.ws_reader(ws), name="bidi_reader")
)
logger.debug("Established session %r", self.session)
async def ws_reader(self, ws: aiohttp.client.ClientWebSocketResponse) -> None:
async for msg in ws:
if msg.type == aiohttp.WSMsgType.TEXT: if msg.type == aiohttp.WSMsgType.TEXT:
data = json.loads(msg.data) data = json.loads(msg.data)
logger.debug("ws TEXT → %r", data) logger.debug("ws TEXT → %r", data)
@ -121,36 +139,32 @@ class WebdriverBidi:
async def bidi(self, method, **params) -> asyncio.Future: async def bidi(self, method, **params) -> asyncio.Future:
"""Send a Webdriver BiDI command and return the JSON response""" """Send a Webdriver BiDI command and return the JSON response"""
assert self.ws assert self.session
payload = json.dumps({"id": self.last_id, "method": method, "params": params}) payload = json.dumps({"id": self.last_id, "method": method, "params": params})
logger.debug("ws ← %r", payload) logger.debug("ws ← %r", payload)
await self.ws.send_str(payload) await self.session.ws.send_str(payload)
future = asyncio.get_event_loop().create_future() future = asyncio.get_event_loop().create_future()
self.pending_commands[self.last_id] = future self.pending_commands[self.last_id] = future
self.last_id += 1 self.last_id += 1
return await future return await future
async def webdriver(self, path: str, data: dict | None = None) -> dict: async def webdriver(self, path: str, data: dict | None = None, method: str | None = None) -> dict:
"""Send a classic Webdriver request and return the JSON response""" """Send a classic Webdriver request and return the JSON response"""
assert self.session assert self.session
# asyncio shares the connection
url = f"{self.webdriver_url}/session/{self.session_info['sessionId']}/{path}"
post_data = json.dumps(data).encode() if data is not None else None
method = "POST" if post_data is not None else "GET"
async with self.session.request(method, url, data=post_data) as resp: # asyncio shares the connection
post_data = json.dumps(data).encode() if data is not None else None
method = method if method else ("POST" if post_data is not None else "GET")
async with self.session.http.request(method, f"{self.session.session_url}{path}", data=post_data) as resp:
r = await resp.text() r = await resp.text()
logger.debug("webdriver %s %s %r%r", method, path, post_data, r) logger.debug("webdriver %s %s %r%r", method, path, post_data, r)
return json.loads(r) return json.loads(r)
async def run(self): async def run(self):
# open bidi websocket for session await self.ensure_session()
async with aiohttp.ClientSession(raise_for_status=True) as session:
self.session = session
async with session.ws_connect(self.session_info["capabilities"]["webSocketUrl"]) as ws:
self.ws = ws
self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader")
# wait for browser to initialize default context # wait for browser to initialize default context
for _ in range(10): for _ in range(10):
@ -174,7 +188,7 @@ class WebdriverBidi:
menu_content_id = r[0]['sharedId'] menu_content_id = r[0]['sharedId']
# this doensn't yet have a BiDi command # this doensn't yet have a BiDi command
r = await self.webdriver(f"element/{menu_content_id}/text") r = await self.webdriver(f"/element/{menu_content_id}/text")
assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value'] assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value']
# locate first social link # locate first social link
@ -182,17 +196,26 @@ class WebdriverBidi:
locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"] locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"]
assert len(r) == 1 assert len(r) == 1
# click it (again, no BiDi command) # click it (again, no BiDi command)
await self.webdriver(f"element/{r[0]['sharedId']}/click", {}) await self.webdriver(f"/element/{r[0]['sharedId']}/click", {})
if not self.headless: if not self.headless:
await asyncio.sleep(3) await asyncio.sleep(3)
self.task_reader.cancel()
del self.task_reader
def __del__(self): async def __aenter__(self):
return self
async def __aexit__(self, *_excinfo):
if self.session is not None:
logger.debug("cleaning up webdriver") logger.debug("cleaning up webdriver")
urllib.request.urlopen(urllib.request.Request(
f"{self.webdriver_url}/session/{self.session_info['sessionId']}", method="DELETE")) self.session.task_reader.cancel()
del self.session.task_reader
await self.webdriver("", method="DELETE")
await self.session.ws.close()
await self.session.http.close()
self.session = None
logger.info("Collected debug messages:") logger.info("Collected debug messages:")
for log in self.logs: for log in self.logs:
@ -202,7 +225,10 @@ class WebdriverBidi:
self.driver.wait() self.driver.wait()
async def main():
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
d = WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium', async with WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium',
headless=True if len(sys.argv) > 2 and sys.argv[2] == 'headless' else False) headless=True if len(sys.argv) > 2 and sys.argv[2] == 'headless' else False) as d:
asyncio.run(d.run()) await d.run()
asyncio.run(main())