Clean up session tracking
This commit is contained in:
parent
eb6fd8ac1d
commit
025731f0c8
1 changed files with 97 additions and 71 deletions
168
bidi.py
168
bidi.py
|
|
@ -6,7 +6,6 @@ import logging
|
|||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
|
||||
import aiohttp
|
||||
|
|
@ -46,17 +45,28 @@ class LogMessage:
|
|||
return f"LogMessage: {self.type} {self.level} @{self.timestamp}: {self.text} {self.args}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Session:
|
||||
http: aiohttp.client.ClientSession
|
||||
ws: aiohttp.client.ClientWebSocketResponse
|
||||
session_url: str
|
||||
task_reader: asyncio.Task
|
||||
|
||||
|
||||
class WebdriverBidi:
|
||||
def __init__(self, browser, headless=False) -> None:
|
||||
self.headless = headless
|
||||
self.last_id = 0
|
||||
self.pending_commands: dict[int, asyncio.Future] = {}
|
||||
self.logs: list[LogMessage] = []
|
||||
self.session: Session | None = None
|
||||
|
||||
# TODO: make dynamic
|
||||
self.webdriver_port = 12345
|
||||
self.webdriver_url = f"http://localhost:{self.webdriver_port}"
|
||||
|
||||
chrome_binary = "/usr/lib64/chromium-browser/headless_shell" if self.headless else "/usr/bin/chromium-browser"
|
||||
|
||||
session_args = {"capabilities": {
|
||||
self.session_args = {"capabilities": {
|
||||
"alwaysMatch": {
|
||||
"webSocketUrl": True,
|
||||
"goog:chromeOptions": {"binary": chrome_binary},
|
||||
|
|
@ -69,32 +79,40 @@ class WebdriverBidi:
|
|||
except KeyError as e:
|
||||
raise ValueError(f"unknown browser {browser}") from e
|
||||
|
||||
req = urllib.request.Request(
|
||||
f"{self.webdriver_url}/session",
|
||||
json.dumps(session_args).encode(),
|
||||
headers={"Content-Type": "application/json"})
|
||||
async def ensure_session(self) -> None:
|
||||
if self.session is not None:
|
||||
return
|
||||
|
||||
aiohttp_session = aiohttp.ClientSession(raise_for_status=True)
|
||||
wd_url = f"http://localhost:{self.webdriver_port}"
|
||||
|
||||
# webdriver needs some time to launch
|
||||
for retry in range(1, 10):
|
||||
try:
|
||||
with urllib.request.urlopen(req) as f:
|
||||
resp = json.load(f)
|
||||
async with aiohttp_session.post(f"{wd_url}/session",
|
||||
data=json.dumps(self.session_args).encode()) as resp:
|
||||
session_info = json.loads(await resp.text())["value"]
|
||||
logger.debug("webdriver session request: %r %r", resp, session_info)
|
||||
break
|
||||
except urllib.error.URLError as e:
|
||||
except (IOError, aiohttp.client.ClientResponseError) as e:
|
||||
logger.debug("waiting for webdriver: %s", e)
|
||||
time.sleep(0.1 * retry)
|
||||
else:
|
||||
raise WebdriverError("could not connect to webdriver")
|
||||
|
||||
self.session_info = resp["value"]
|
||||
self.last_id = 0
|
||||
self.ws = None
|
||||
self.session = None
|
||||
self.pending_commands: dict[int, asyncio.Future] = {}
|
||||
self.logs: list[LogMessage] = []
|
||||
ws = await aiohttp_session.ws_connect(session_info["capabilities"]["webSocketUrl"])
|
||||
|
||||
async def ws_reader(self) -> None:
|
||||
assert self.ws
|
||||
async for msg in self.ws:
|
||||
self.session = Session(
|
||||
http=aiohttp_session,
|
||||
ws=ws,
|
||||
session_url=f"{wd_url}/session/{session_info['sessionId']}",
|
||||
task_reader=asyncio.create_task(self.ws_reader(ws), name="bidi_reader")
|
||||
)
|
||||
|
||||
logger.debug("Established session %r", self.session)
|
||||
|
||||
async def ws_reader(self, ws: aiohttp.client.ClientWebSocketResponse) -> None:
|
||||
async for msg in ws:
|
||||
if msg.type == aiohttp.WSMsgType.TEXT:
|
||||
data = json.loads(msg.data)
|
||||
logger.debug("ws TEXT → %r", data)
|
||||
|
|
@ -121,78 +139,83 @@ class WebdriverBidi:
|
|||
async def bidi(self, method, **params) -> asyncio.Future:
|
||||
"""Send a Webdriver BiDI command and return the JSON response"""
|
||||
|
||||
assert self.ws
|
||||
assert self.session
|
||||
|
||||
payload = json.dumps({"id": self.last_id, "method": method, "params": params})
|
||||
logger.debug("ws ← %r", payload)
|
||||
await self.ws.send_str(payload)
|
||||
await self.session.ws.send_str(payload)
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
self.pending_commands[self.last_id] = future
|
||||
self.last_id += 1
|
||||
return await future
|
||||
|
||||
async def webdriver(self, path: str, data: dict | None = None) -> dict:
|
||||
async def webdriver(self, path: str, data: dict | None = None, method: str | None = None) -> dict:
|
||||
"""Send a classic Webdriver request and return the JSON response"""
|
||||
|
||||
assert self.session
|
||||
# asyncio shares the connection
|
||||
url = f"{self.webdriver_url}/session/{self.session_info['sessionId']}/{path}"
|
||||
post_data = json.dumps(data).encode() if data is not None else None
|
||||
method = "POST" if post_data is not None else "GET"
|
||||
|
||||
async with self.session.request(method, url, data=post_data) as resp:
|
||||
# asyncio shares the connection
|
||||
post_data = json.dumps(data).encode() if data is not None else None
|
||||
method = method if method else ("POST" if post_data is not None else "GET")
|
||||
|
||||
async with self.session.http.request(method, f"{self.session.session_url}{path}", data=post_data) as resp:
|
||||
r = await resp.text()
|
||||
logger.debug("webdriver %s %s %r → %r", method, path, post_data, r)
|
||||
return json.loads(r)
|
||||
|
||||
async def run(self):
|
||||
# open bidi websocket for session
|
||||
async with aiohttp.ClientSession(raise_for_status=True) as session:
|
||||
self.session = session
|
||||
async with session.ws_connect(self.session_info["capabilities"]["webSocketUrl"]) as ws:
|
||||
self.ws = ws
|
||||
self.task_reader = asyncio.create_task(self.ws_reader(), name="bidi_reader")
|
||||
await self.ensure_session()
|
||||
|
||||
# wait for browser to initialize default context
|
||||
for _ in range(10):
|
||||
realms = (await self.bidi("script.getRealms"))["realms"]
|
||||
if len(realms) > 0:
|
||||
context = realms[0]["context"]
|
||||
break
|
||||
else:
|
||||
raise WebdriverError("timed out waiting for default realm")
|
||||
# wait for browser to initialize default context
|
||||
for _ in range(10):
|
||||
realms = (await self.bidi("script.getRealms"))["realms"]
|
||||
if len(realms) > 0:
|
||||
context = realms[0]["context"]
|
||||
break
|
||||
else:
|
||||
raise WebdriverError("timed out waiting for default realm")
|
||||
|
||||
await self.bidi("session.subscribe", events=["log.entryAdded"])
|
||||
await self.bidi("session.subscribe", events=["log.entryAdded"])
|
||||
|
||||
await self.bidi("script.evaluate", expression="console.log('Hello BiDi')",
|
||||
awaitPromise=False, target={"context": context})
|
||||
await self.bidi("browsingContext.navigate", context=context,
|
||||
url="https://piware.de", wait="complete")
|
||||
await self.bidi("script.evaluate", expression="console.log('Hello BiDi')",
|
||||
awaitPromise=False, target={"context": context})
|
||||
await self.bidi("browsingContext.navigate", context=context,
|
||||
url="https://piware.de", wait="complete")
|
||||
|
||||
r = (await self.bidi("browsingContext.locateNodes", context=context,
|
||||
locator={"type": "css", "value": "#menu-content"}))["nodes"]
|
||||
assert len(r) == 1
|
||||
menu_content_id = r[0]['sharedId']
|
||||
r = (await self.bidi("browsingContext.locateNodes", context=context,
|
||||
locator={"type": "css", "value": "#menu-content"}))["nodes"]
|
||||
assert len(r) == 1
|
||||
menu_content_id = r[0]['sharedId']
|
||||
|
||||
# this doesn't yet have a BiDi command
|
||||
r = await self.webdriver(f"element/{menu_content_id}/text")
|
||||
assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value']
|
||||
# this doesn't yet have a BiDi command
|
||||
r = await self.webdriver(f"/element/{menu_content_id}/text")
|
||||
assert 'ADDICTED TO FREE SOFTWARE DEVELOPMENT' in r['value']
|
||||
|
||||
# locate first social link
|
||||
r = (await self.bidi("browsingContext.locateNodes", context=context,
|
||||
locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"]
|
||||
assert len(r) == 1
|
||||
# click it (again, no BiDi command)
|
||||
await self.webdriver(f"element/{r[0]['sharedId']}/click", {})
|
||||
# locate first social link
|
||||
r = (await self.bidi("browsingContext.locateNodes", context=context,
|
||||
locator={"type": "css", "value": "a[rel='me']:first-child"}))["nodes"]
|
||||
assert len(r) == 1
|
||||
# click it (again, no BiDi command)
|
||||
await self.webdriver(f"/element/{r[0]['sharedId']}/click", {})
|
||||
|
||||
if not self.headless:
|
||||
await asyncio.sleep(3)
|
||||
self.task_reader.cancel()
|
||||
del self.task_reader
|
||||
if not self.headless:
|
||||
await asyncio.sleep(3)
|
||||
|
||||
def __del__(self):
|
||||
logger.debug("cleaning up webdriver")
|
||||
urllib.request.urlopen(urllib.request.Request(
|
||||
f"{self.webdriver_url}/session/{self.session_info['sessionId']}", method="DELETE"))
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *_excinfo):
|
||||
if self.session is not None:
|
||||
logger.debug("cleaning up webdriver")
|
||||
|
||||
self.session.task_reader.cancel()
|
||||
del self.session.task_reader
|
||||
|
||||
await self.webdriver("", method="DELETE")
|
||||
await self.session.ws.close()
|
||||
await self.session.http.close()
|
||||
|
||||
self.session = None
|
||||
|
||||
logger.info("Collected debug messages:")
|
||||
for log in self.logs:
|
||||
|
|
@ -202,7 +225,10 @@ class WebdriverBidi:
|
|||
self.driver.wait()
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
d = WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium',
|
||||
headless=True if len(sys.argv) > 2 and sys.argv[2] == 'headless' else False)
|
||||
asyncio.run(d.run())
|
||||
async def main():
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
async with WebdriverBidi(sys.argv[1] if len(sys.argv) > 1 else 'chromium',
|
||||
headless=True if len(sys.argv) > 2 and sys.argv[2] == 'headless' else False) as d:
|
||||
await d.run()
|
||||
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue