Skip to content

Commit

Permalink
Rehydrate RTU session on error and add backoff (#77)
Browse files Browse the repository at this point in the history
* Rehydrate RTU session on error and add backoff

* rename task variable

* remove ensure open calls -- eafp

* refactor

* Bump version: 0.0.16 → 0.0.17

* add changelog
  • Loading branch information
rohitsanj authored Jan 4, 2023
1 parent 546495a commit 11ecb40
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 0.0.16
current_version = 0.0.17
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)
serialize =
{major}.{minor}.{patch}
Expand Down
10 changes: 10 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

## [0.0.17] - 2023-01-04
### Added
- Handle websocket ConnectionClosedErrors in the process messages infinite loop:
- reconnect to the RTU websocket
- handle authentication
- resubscribe to all the previously subscribed channels
- Add backoff retry to `send_rtu_request` when a `ConnectionClosedError` is raised, and reconnect to RTU before retrying.
- Add backoff retry to `update_job_instance` on `ReadTimeout` error
- Add backoff retry to `get_or_launch_ready_kernel_session` on `TimeoutError`.

## [0.0.16] - 2022-12-02
### Fixed
- Fix API incompatibility when creating a parameterized notebook
Expand Down
2 changes: 1 addition & 1 deletion origami/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
version = "0.0.16"
version = "0.0.17"
62 changes: 54 additions & 8 deletions origami/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
import jwt
import structlog
import websockets
import websockets.exceptions
from httpx import ReadTimeout
from nbclient.util import run_sync
from pydantic import BaseModel, BaseSettings, ValidationError
from websockets.legacy.client import WebSocketClientProtocol

from origami.defs.deltas import NBMetadataProperties, V2CellMetadataProperties
from origami.defs.rtu import BulkCellStateMessage
Expand Down Expand Up @@ -162,7 +164,7 @@ def __init__(
self.token = api_token or os.getenv("NOTEABLE_TOKEN") or self.get_token()
if isinstance(self.token, str):
self.token = Token(access_token=self.token)
self.rtu_socket = None
self.rtu_socket: WebSocketClientProtocol = None
self.process_task_loop = None

headers = kwargs.pop('headers', {})
Expand All @@ -181,6 +183,8 @@ def __init__(
**kwargs,
)

self.reconnect_rtu_task = None

@property
def origin(self):
"""Formats the domain in an origin string for websocket headers."""
Expand Down Expand Up @@ -272,6 +276,9 @@ async def launch_kernel_session(
self.file_session_cache[file.id] = session
return session

@backoff.on_exception(
backoff.expo, asyncio.exceptions.TimeoutError, max_time=EXP_BACKOFF_MAX_TIME
)
async def get_or_launch_ready_kernel_session(
self,
file: NotebookFile,
Expand Down Expand Up @@ -397,6 +404,7 @@ async def create_job_instance(
return CustomerJobInstanceReference.parse_obj(resp.json())

@_default_timeout_arg
@backoff.on_exception(backoff.expo, ReadTimeout, max_time=EXP_BACKOFF_MAX_TIME)
async def update_job_instance(
self,
job_instance_attempt_id: uuid.UUID,
Expand Down Expand Up @@ -427,9 +435,8 @@ async def __aenter__(self):
validate and extract principal-user-id from the token.
"""
res = await httpx.AsyncClient.__aenter__(self)
# Origin is needed, else the server request crashes and rejects the connection
headers = {'Authorization': self.headers['authorization'], 'Origin': self.origin}
self.rtu_socket = await websockets.connect(self.ws_uri, extra_headers=headers)
# Create a new RTU socket for this context
await self._connect_rtu_socket()
# Loop indefinitely over the incoming websocket messages
self.process_task_loop = asyncio.create_task(self._process_messages())
# Authenticate for more advanced API calls
Expand Down Expand Up @@ -595,20 +602,59 @@ async def _process_messages(self):
logger.debug(
f"Callable for {channel}/{event} was a {'successful' if processed else 'failed'} match"
)

except websockets.exceptions.ConnectionClosed:
except websockets.exceptions.ConnectionClosedError:
await asyncio.sleep(0)
logger.exception("Websocket connection closed unexpectedly; reconnecting")
await self._reconnect_rtu()
continue
except websockets.exceptions.ConnectionClosedOK:
await asyncio.sleep(0)
break
except Exception:
logger.exception("Unexpected callback failure")
await asyncio.sleep(0)
break

async def _connect_rtu_socket(self):
"""Opens a websocket connection to Noteable RTU."""
# Origin is needed, else the server request crashes and rejects the connection
headers = {'Authorization': self.headers['authorization'], 'Origin': self.origin}
self.rtu_socket = await websockets.connect(self.ws_uri, extra_headers=headers)
logger.debug("Opened websocket connection")

async def _resubscribe_channels(self):
"""Rehydrates the session by re-authenticating and re-subscribing to all channels."""
await self.authenticate()
for channel in self.subscriptions:
# Only re-subscribe to the files channels
if channel.startswith("files/"):
await self.subscribe_file(channel.split('/')[1])

async def _reconnect_rtu(self):
"""Reconnects the RTU websocket connection."""
await self._connect_rtu_socket()
if not self.reconnect_rtu_task:
self.reconnect_rtu_task = asyncio.create_task(self._resubscribe_channels())

# set rehydrate_task to None so that the next connection failure will trigger a new rehydrate
self.reconnect_rtu_task = None

@_requires_ws_context
@backoff.on_exception(
backoff.expo, websockets.exceptions.ConnectionClosedError, max_time=EXP_BACKOFF_MAX_TIME
)
async def send_rtu_request(self, req: GenericRTURequestSchema):
"""Wraps converting a pydantic request model to be send down the websocket."""
"""Wraps converting a pydantic request model to be sent down the websocket."""
logger.debug(f"Sending websocket request: {req}")
return await self.rtu_socket.send(req.json())
try:
return await self.rtu_socket.send(req.json())
except websockets.exceptions.ConnectionClosedError:
logger.debug(
"Websocket connection closed unexpectedly while trying to send RTU request; reconnecting"
)
await self._reconnect_rtu()
# Raise the exception to trigger backoff
raise

@_requires_ws_context
@_default_timeout_arg
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

[tool.poetry]
name = "noteable-origami"
version = "0.0.16"
version = "0.0.17"
description = "The Noteable API interface"
authors = ["Matt Seal <[email protected]>"]
maintainers = ["Matt Seal <[email protected]>"]
Expand Down

0 comments on commit 11ecb40

Please sign in to comment.