Skip to content
Snippets Groups Projects
Commit 27e0b877 authored by Tulir Asokan :cat2:
Browse files

Automatically reconnect if puppet script goes away

parent 66091b28
No related branches found
No related tags found
No related merge requests found
Pipeline #2860 passed
......@@ -33,6 +33,9 @@ class LoginComplete(Exception):
class Client(RPCClient):
async def start(self, session_data: Any = None) -> StartStatus:
    """Connect to the puppet script and start the session.

    Args:
        session_data: Session data handed through unchanged to
            ``send_start``; defaults to ``None``.

    Returns:
        The status returned by ``send_start``.
    """
    await self.connect()
    # Return the status: the signature promises a StartStatus, and merely
    # awaiting ``send_start`` would hand every caller ``None``.
    return await self.send_start(session_data)
async def send_start(self, session_data: Any = None) -> StartStatus:
    """Send the "start" request and parse the reply into a StartStatus.

    Args:
        session_data: Value forwarded as the ``session_data`` field of the
            request; defaults to ``None``.

    Returns:
        The deserialized response of the "start" request.
    """
    raw_status = await self.request("start", session_data=session_data)
    return StartStatus.deserialize(raw_status)
async def stop(self) -> None:
......
......@@ -39,6 +39,7 @@ class RPCClient:
_min_broadcast_id: int
_response_waiters: Dict[int, asyncio.Future]
_event_handlers: Dict[str, List[EventHandler]]
_connect_task: Optional[asyncio.Task]
def __init__(self, user_id: UserID) -> None:
self.log = self.log.getChild(user_id)
......@@ -50,6 +51,7 @@ class RPCClient:
self._response_waiters = {}
self._writer = None
self._reader = None
self._connect_task = None
async def _open_connection(self) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
if self.config["puppeteer.connection.type"] == "unix":
......@@ -60,17 +62,56 @@ class RPCClient:
else:
raise RuntimeError("invalid puppeteer connection type")
@property
def _pretty_connect_str(self) -> str:
    """URL-style description of the configured puppeteer connection.

    Returns:
        ``unix://<path>`` for unix socket connections, or
        ``tcp://<host>:<port>`` for TCP connections.

    Raises:
        RuntimeError: If the configured connection type is neither "unix"
            nor "tcp".
    """
    conn_type = self.config["puppeteer.connection.type"]
    if conn_type == "unix":
        socket_path = self.config["puppeteer.connection.path"]
        return f"unix://{socket_path}"
    if conn_type == "tcp":
        tcp_host = self.config["puppeteer.connection.host"]
        tcp_port = self.config["puppeteer.connection.port"]
        return f"tcp://{tcp_host}:{tcp_port}"
    raise RuntimeError("invalid puppeteer connection type")
async def connect(self) -> None:
    """Start the background reconnect loop and wait for the first connection.

    Returns immediately when a writer is already set. Otherwise the
    connection is opened exclusively by ``_reconnect_forever``; opening it
    here as well would create a second connection and a second read loop
    for the same client.
    """
    if self._writer is not None:
        return
    # Resolved by _reconnect_forever once the first registration succeeds.
    initial_connect = self.loop.create_future()
    self._connect_task = asyncio.create_task(self._reconnect_forever(initial_connect))
    await initial_connect
async def _reconnect_forever(self, initial_connect: asyncio.Future) -> None:
    """Keep the RPC connection alive, reconnecting whenever it drops.

    Each iteration opens a connection, starts the read loop, registers the
    user and then waits for the read loop to finish. The first successful
    registration resolves ``initial_connect``; later ones run the internal
    "reconnect" handlers. After the read loop ends, the internal
    "disconnect" handlers run and the loop starts over. Failed connection
    attempts (``OSError``) are logged and retried after five seconds.

    Args:
        initial_connect: Future resolved after the first successful
            registration. If this coroutine fails or is cancelled before
            that point, the failure is delivered through this future so
            that whoever awaits it does not wait forever.

    Raises:
        Exception: Any non-``OSError`` failure that occurs after
            ``initial_connect`` has been resolved.
    """
    try:
        while True:
            try:
                self._reader, self._writer = await self._open_connection()
            except OSError as e:
                self.log.error(f"Connection to {self._pretty_connect_str} failed: {e}")
                await asyncio.sleep(5)
                continue
            self._min_broadcast_id = 0
            # Started before "register" is sent -- presumably so the reply
            # to that request can be read; confirm against request().
            read_loop = asyncio.create_task(self._try_read_loop())
            await self.request("register", user_id=self.user_id)
            if initial_connect is not None:
                self.log.debug("RPC connected")
                initial_connect.set_result(None)
                # Cleared so later iterations count as reconnects.
                initial_connect = None
            else:
                self.log.debug("RPC reconnected")
                await self._run_internal_handler("reconnect")
            await read_loop
            self.log.debug("RPC disconnected")
            await self._run_internal_handler("disconnect")
            # NOTE(review): the previous writer is not closed before the next
            # iteration replaces it -- confirm whether the read loop does so.
    except asyncio.CancelledError:
        # Cancelled before the first connection was reported: wake the
        # waiter instead of leaving it pending forever.
        if initial_connect is not None and not initial_connect.done():
            initial_connect.cancel()
        raise
    except Exception as e:
        if initial_connect is not None and not initial_connect.done():
            # Report the failure to the waiter; it is the one place where
            # the error can be handled.
            initial_connect.set_exception(e)
            return
        raise
async def disconnect(self) -> None:
    """Stop reconnecting and shut down the current connection, if any.

    Safe to call when no connection is open: the writer is only touched
    when one is set. The stored reader, writer and reconnect task are
    always cleared.
    """
    # Stop the reconnect loop first so the EOF sent below cannot be followed
    # by a fresh connection attempt.
    if self._connect_task is not None:
        self._connect_task.cancel()
        self._connect_task = None
    # Clear the stored streams up front so the state is reset even when
    # flushing the writer fails.
    writer, self._writer = self._writer, None
    self._reader = None
    if writer is None:
        return
    try:
        writer.write_eof()
        await writer.drain()
    finally:
        # write_eof() only closes the write side; close the stream so the
        # underlying socket is released.
        writer.close()
......@@ -85,6 +126,18 @@ class RPCClient:
def remove_event_handler(self, method: str, handler: EventHandler) -> None:
    """Unregister ``handler`` from the handlers of ``method``.

    Args:
        method: Name of the event the handler was registered for.
        handler: The handler to remove.

    Raises:
        ValueError: If ``handler`` is not registered for ``method``. An
            empty handler list is still created for an unknown ``method``.
    """
    registered = self._event_handlers.setdefault(method, [])
    registered.remove(handler)
async def _run_internal_handler(self, event: str) -> None:
    """Run every handler registered for a client-generated event.

    Handlers are awaited one after another with an empty dict as payload.
    An exception raised by one handler is logged and does not prevent the
    remaining handlers from running. A warning is logged when the event
    has no entry in the handler table.

    Args:
        event: Name of the event whose handlers should run.
    """
    try:
        registered = self._event_handlers[event]
    except KeyError:
        self.log.warning("No handlers for %s", event)
        return
    for callback in registered:
        try:
            await callback({})
        except Exception:
            self.log.exception("Exception in event handler")
async def _run_event_handler(self, req_id: int, command: str, req: Dict[str, Any]) -> None:
if req_id > self._min_broadcast_id:
self.log.debug(f"Ignoring duplicate broadcast {req_id}")
......@@ -136,7 +189,11 @@ class RPCClient:
async def _read_loop(self) -> None:
while self._reader is not None and not self._reader.at_eof():
line = await self._reader.readline()
try:
line = await self._reader.readline()
except ConnectionResetError as e:
self.log.error(f"Server closed connection unexpectedly: {e}")
break
if not line:
continue
try:
......
......@@ -92,11 +92,20 @@ class User(DBUser, BaseUser):
self.client.on_message(self.handle_message)
self.client.on_chat_update(self.handle_chat_update)
self.client.add_event_handler("session_data", self.handle_session_data)
self.client.add_event_handler("disconnect", self.handle_disconnect)
self.client.add_event_handler("reconnect", self.handle_reconnect)
# if state.is_connected:
# self._track_metric(METRIC_CONNECTED, True)
# if state.is_logged_in:
# self.loop.create_task(self._try_sync())
async def handle_disconnect(self, _) -> None:
    """Handle the client's "disconnect" event; currently does nothing.

    Args:
        _: Event payload, ignored.
    """
    return None
async def handle_reconnect(self, _) -> None:
    """Handle the client's "reconnect" event by re-sending session data.

    Args:
        _: Event payload, ignored.
    """
    self.log.debug("Re-sending session data to puppet script after reconnect")
    saved_session = self.session_data
    await self.client.send_start(session_data=saved_session)
# async def _try_sync(self) -> None:
# try:
# await self.sync()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment