From c1496f15a30b7c1e504386e8682766a7a367a297 Mon Sep 17 00:00:00 2001 From: Martin Pitt Date: Tue, 23 Jul 2024 07:46:12 +0200 Subject: [PATCH] Refactor bidi session --- bidi.py | 97 +++++++++++++++++++++++++-------------------------------- 1 file changed, 42 insertions(+), 55 deletions(-) diff --git a/bidi.py b/bidi.py index 014ad46..e0326a3 100755 --- a/bidi.py +++ b/bidi.py @@ -38,11 +38,10 @@ class LogMessage: @dataclass -class Session: - http: aiohttp.client.ClientSession - ws: aiohttp.client.ClientWebSocketResponse +class BidiSession: + ws_url: str session_url: str - task_reader: asyncio.Task + process: asyncio.subprocess.Process # Return port numbers that were free at the time of checking @@ -69,32 +68,41 @@ class WebdriverBidi: self.last_id = 0 self.pending_commands: dict[int, asyncio.Future] = {} self.logs: list[LogMessage] = [] - self.session: Session | None = None + self.bidi_session: BidiSession | None = None self.future_wait_page_load = None - async def start_session(self) -> None: + async def start_bidi_session(self) -> None: raise NotImplementedError('must be implemented by concrete subclass') - async def close_session(self) -> None: + async def close_bidi_session(self) -> None: raise NotImplementedError('must be implemented by concrete subclass') async def close(self): - assert self.session is not None + assert self.bidi_session is not None logger.debug("cleaning up webdriver") - self.session.task_reader.cancel() - del self.session.task_reader - await self.session.ws.close() - await self.close_session() - await self.session.http.close() - self.session = None + self.task_reader.cancel() + del self.task_reader + await self.ws.close() + await self.close_bidi_session() + self.bidi_session.process.terminate() + await self.bidi_session.process.wait() + self.bidi_session = None + await self.http_session.close() + + async def start_session(self) -> None: + self.http_session = aiohttp.ClientSession(raise_for_status=True) + await self.start_bidi_session() + assert self.bidi_session + self.ws = await self.http_session.ws_connect(self.bidi_session.ws_url) + self.task_reader = asyncio.create_task(self.ws_reader(self.ws), name="bidi_reader") async def __aenter__(self): await self.start_session() return self async def __aexit__(self, *_excinfo): - if self.session is not None: + if self.bidi_session is not None: await self.close() async def ws_reader(self, ws: aiohttp.client.ClientWebSocketResponse) -> None: @@ -128,13 +136,11 @@ class WebdriverBidi: break async def bidi(self, method, **params) -> asyncio.Future: - """Send a Webdriver BiDI command and return the JSON response""" - - assert self.session + """Send a Webdriver BiDi command and return the JSON response""" payload = json.dumps({"id": self.last_id, "method": method, "params": params}) logger.debug("ws ← %r", payload) - await self.session.ws.send_str(payload) + await self.ws.send_str(payload) future = asyncio.get_event_loop().create_future() self.pending_commands[self.last_id] = future self.last_id += 1 @@ -208,8 +214,8 @@ class WebdriverBidi: class ChromiumBidi(WebdriverBidi): - async def start_session(self) -> None: - assert self.session is None + async def start_bidi_session(self) -> None: + assert self.bidi_session is None chrome_binary = "/usr/lib64/chromium-browser/headless_shell" if self.headless else "/usr/bin/chromium-browser" @@ -221,16 +227,15 @@ class ChromiumBidi(WebdriverBidi): }} [webdriver_port] = pick_ports(1) - self.driver = await asyncio.create_subprocess_exec("chromedriver", "--port=" + str(webdriver_port)) + driver = await asyncio.create_subprocess_exec("chromedriver", "--port=" + str(webdriver_port)) - aiohttp_session = aiohttp.ClientSession(raise_for_status=True) wd_url = f"http://localhost:{webdriver_port}" # webdriver needs some time to launch for retry in range(1, 10): try: - async with aiohttp_session.post(f"{wd_url}/session", - data=json.dumps(session_args).encode()) as resp: + async with self.http_session.post(f"{wd_url}/session", + data=json.dumps(session_args).encode()) as resp: session_info = json.loads(await resp.text())["value"] logger.debug("webdriver session request: %r %r", resp, session_info) break @@ -240,27 +245,20 @@ class ChromiumBidi(WebdriverBidi): else: raise WebdriverError("could not connect to chromedriver") - ws = await aiohttp_session.ws_connect(session_info["capabilities"]["webSocketUrl"]) - - self.session = Session( - http=aiohttp_session, - ws=ws, + self.bidi_session = BidiSession( session_url=f"{wd_url}/session/{session_info['sessionId']}", - task_reader=asyncio.create_task(self.ws_reader(ws), name="bidi_reader") - ) + ws_url=session_info["capabilities"]["webSocketUrl"], + process=driver) + logger.debug("Established chromium session %r", self.bidi_session) - logger.debug("Established chromium session %r", self.session) - - async def close_session(self): - await self.session.http.delete(self.session.session_url) - self.driver.terminate() - await self.driver.wait() + async def close_bidi_session(self): + await self.http_session.delete(self.bidi_session.session_url) # We could do this with https://github.com/mozilla/geckodriver/releases with a similar protocol as ChromeBidi # But let's use https://firefox-source-docs.mozilla.org/testing/marionette/Protocol.html directly, fewer moving parts class FirefoxBidi(WebdriverBidi): - async def start_session(self) -> None: + async def start_bidi_session(self) -> None: [marionette_port, bidi_port] = pick_ports(2) self.homedir = tempfile.TemporaryDirectory(prefix="firefox-home-") @@ -284,7 +282,7 @@ class FirefoxBidi(WebdriverBidi): user_pref('marionette.port', {marionette_port}); """) - self.driver = await asyncio.create_subprocess_exec( + driver = await asyncio.create_subprocess_exec( "firefox", "-profile", str(self.profiledir), "--marionette", "--no-remote", f"--remote-debugging-port={bidi_port}", *(["-headless"] if self.headless else []), "about:blank") @@ -314,24 +312,13 @@ class FirefoxBidi(WebdriverBidi): raise WebdriverError(f"unexpected marionette session request reply: {reply!r}") logger.debug("marionette session request reply: %s", reply) - aiohttp_session = aiohttp.ClientSession(raise_for_status=True) - ws_url = reply[3]["capabilities"]["webSocketUrl"] - ws = await aiohttp_session.ws_connect(ws_url) + url = reply[3]["capabilities"]["webSocketUrl"] + self.bidi_session = BidiSession(session_url=url, ws_url=url, process=driver) + logger.debug("Established firefox session %r", self.bidi_session) - self.session = Session( - http=aiohttp_session, - ws=ws, - session_url=ws_url, - task_reader=asyncio.create_task(self.ws_reader(ws), name="bidi_reader") - ) - - logger.debug("Established firefox session %r", self.session) - - async def close_session(self): + async def close_bidi_session(self): self.writer_marionette.close() await self.writer_marionette.wait_closed() - self.driver.terminate() - self.driver.wait() async def main():