Port to pydantic 2 #192

Merged (24 commits) on Nov 6, 2023
Changes from 14 commits
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -10,6 +10,9 @@ For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/orig
### Added
- Environ variable `NOTEABLE_RTU_URL` to override RTU websocket, primarily for when apps are running in-cluster with Gate and need to use the http vs websocket service DNS

### Changed
- Upgraded pydantic to 2.1.4.

### [1.1.2] - 2023-10-12
### Added
- Environ variable `NOTEABLE_RTU_URL` to override RTU websocket, primarily for when apps are running in-cluster with Gate and need to use the http vs websocket service DNS
@@ -70,4 +73,4 @@ For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/orig
### Added
- `APIClient` and `RTUClient` for HTTP and Websocket connections to Noteable's API respectively
- Discriminated-union Pydantic modeling for RTU and Delta payloads
- End-to-end tests to run against a Noteable deployment
- End-to-end tests to run against a Noteable deployment
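Most of the diff that follows is the mechanical pydantic 1 → 2 API migration noted in the Changed entry above. As a rough cheat sheet, here is a sketch of the v1 → v2 equivalents used below (assumes pydantic 2 is installed; not itself part of this PR):

```python
from pydantic import BaseModel, TypeAdapter

class Example(BaseModel):
    name: str

# validation / parsing
Example.parse_obj({"name": "x"})         # v1 spelling (deprecated in v2)
Example.model_validate({"name": "x"})    # v2

# serialization
Example(name="x").dict()                 # v1 (deprecated in v2)
Example(name="x").model_dump()           # v2
Example(name="x").json()                 # v1 (deprecated in v2)
Example(name="x").model_dump_json()      # v2

# validating against an arbitrary type, replacing parse_obj_as
TypeAdapter(Example).validate_python({"name": "x"})  # v2
```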
30 changes: 15 additions & 15 deletions origami/clients/api.py
@@ -62,7 +62,7 @@ async def user_info(self) -> User:
endpoint = "/users/me"
resp = await self.client.get(endpoint)
resp.raise_for_status()
user = User.parse_obj(resp.json())
user = User.model_validate(resp.json())
self.add_tags_and_contextvars(user_id=str(user.id))
return user

@@ -72,7 +72,7 @@ async def create_space(self, name: str, description: Optional[str] = None) -> Sp
endpoint = "/spaces"
resp = await self.client.post(endpoint, json={"name": name, "description": description})
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
self.add_tags_and_contextvars(space_id=str(space.id))
return space

@@ -81,7 +81,7 @@ async def get_space(self, space_id: uuid.UUID) -> Space:
endpoint = f"/spaces/{space_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
return space

async def delete_space(self, space_id: uuid.UUID) -> None:
@@ -97,7 +97,7 @@ async def list_space_projects(self, space_id: uuid.UUID) -> List[Project]:
endpoint = f"/spaces/{space_id}/projects"
resp = await self.client.get(endpoint)
resp.raise_for_status()
projects = [Project.parse_obj(project) for project in resp.json()]
projects = [Project.model_validate(project) for project in resp.json()]
return projects

# Projects are collections of Files, including Notebooks. When a Kernel is launched for a
@@ -118,7 +118,7 @@ async def create_project(
},
)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
self.add_tags_and_contextvars(project_id=str(project.id))
return project

@@ -127,15 +127,15 @@ async def get_project(self, project_id: uuid.UUID) -> Project:
endpoint = f"/projects/{project_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def delete_project(self, project_id: uuid.UUID) -> Project:
self.add_tags_and_contextvars(project_id=str(project_id))
endpoint = f"/projects/{project_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def list_project_files(self, project_id: uuid.UUID) -> List[File]:
@@ -144,7 +144,7 @@ async def list_project_files(self, project_id: uuid.UUID) -> List[File]:
endpoint = f"/projects/{project_id}/files"
resp = await self.client.get(endpoint)
resp.raise_for_status()
files = [File.parse_obj(file) for file in resp.json()]
files = [File.model_validate(file) for file in resp.json()]
return files

# Files are flat files (like text, csv, etc) or Notebooks.
@@ -176,7 +176,7 @@ async def _multi_step_file_create(
upload_url = js["presigned_upload_url_info"]["parts"][0]["upload_url"]
upload_id = js["presigned_upload_url_info"]["upload_id"]
upload_key = js["presigned_upload_url_info"]["key"]
file = File.parse_obj(js)
file = File.model_validate(js)

# (2) Upload to pre-signed url
# TODO: remove this hack if/when we get containers in Skaffold to be able to translate
@@ -214,7 +214,7 @@ async def create_notebook(
self.add_tags_and_contextvars(project_id=str(project_id))
if notebook is None:
notebook = Notebook()
content = notebook.json().encode()
content = notebook.model_dump_json().encode()
file = await self._multi_step_file_create(project_id, path, "notebook", content)
self.add_tags_and_contextvars(file_id=str(file.id))
logger.info("Created new notebook", extra={"file_id": str(file.id)})
@@ -226,7 +226,7 @@ async def get_file(self, file_id: uuid.UUID) -> File:
endpoint = f"/v1/files/{file_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def get_file_content(self, file_id: uuid.UUID) -> bytes:
@@ -254,15 +254,15 @@ async def get_file_versions(self, file_id: uuid.UUID) -> List[FileVersion]:
endpoint = f"/files/{file_id}/versions"
resp = await self.client.get(endpoint)
resp.raise_for_status()
versions = [FileVersion.parse_obj(version) for version in resp.json()]
versions = [FileVersion.model_validate(version) for version in resp.json()]
return versions

async def delete_file(self, file_id: uuid.UUID) -> File:
self.add_tags_and_contextvars(file_id=str(file_id))
endpoint = f"/v1/files/{file_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def get_datasources_for_notebook(self, file_id: uuid.UUID) -> List[DataSource]:
@@ -288,7 +288,7 @@ async def launch_kernel(
}
resp = await self.client.post(endpoint, json=data)
resp.raise_for_status()
kernel_session = KernelSession.parse_obj(resp.json())
kernel_session = KernelSession.model_validate(resp.json())
self.add_tags_and_contextvars(kernel_session_id=str(kernel_session.id))
logger.info(
"Launched new kernel",
@@ -308,7 +308,7 @@ async def get_output_collection(
endpoint = f"/outputs/collection/{output_collection_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
return KernelOutputCollection.parse_obj(resp.json())
return KernelOutputCollection.model_validate(resp.json())

async def connect_realtime(self, file: Union[File, uuid.UUID, str]) -> "RTUClient": # noqa
"""
30 changes: 22 additions & 8 deletions origami/clients/rtu.py
@@ -15,7 +15,6 @@

import httpx
import orjson
from pydantic import BaseModel, parse_obj_as
from sending.backends.websocket import WebsocketManager
from websockets.client import WebSocketClientProtocol

@@ -51,7 +50,7 @@
KernelStatusUpdateResponse,
)
from origami.models.rtu.channels.system import AuthenticateReply, AuthenticateRequest
from origami.models.rtu.discriminators import RTURequest, RTUResponse
from origami.models.rtu.discriminators import RTURequest, RTUResponse, RTUResponseParser
from origami.models.rtu.errors import InconsistentStateEvent
from origami.notebook.builder import CellNotFound, NotebookBuilder
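`RTUResponseParser` is imported here but its definition isn't part of this diff. In pydantic 2 the removed `parse_obj_as(RTUResponse, data)` call is typically replaced by a reusable `TypeAdapter`, so a hypothetical sketch of what `origami/models/rtu/discriminators.py` likely gained (names assumed, not confirmed by this diff):

```python
from pydantic import TypeAdapter

# RTUResponse is the existing discriminated union of RTU event models,
# defined alongside this adapter in discriminators.py.
RTUResponseParser = TypeAdapter(RTUResponse)

# Usage, mirroring inbound_message_hook below:
# rtu_event = RTUResponseParser.validate_python(data)
```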

@@ -87,7 +86,8 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
# to error or BaseRTUResponse)
data: dict = orjson.loads(contents)
data["channel_prefix"] = data.get("channel", "").split("/")[0]
rtu_event = parse_obj_as(RTUResponse, data)

rtu_event = RTUResponseParser.validate_python(data)

# Debug Logging
extra_dict = {
@@ -98,15 +98,18 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
if isinstance(rtu_event, NewDeltaEvent):
extra_dict["delta_type"] = rtu_event.data.delta_type
extra_dict["delta_action"] = rtu_event.data.delta_action
logger.debug(f"Received: {data}\nParsed: {rtu_event.dict()}", extra=extra_dict)

if logging.DEBUG >= logging.root.level:
Comment (Contributor Author):
Don't make the logging call that includes dumping the model unless DEBUG logging is enabled.

logger.debug(f"Received: {data}\nParsed: {rtu_event.model_dump()}", extra=extra_dict)

return rtu_event
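As an aside on the level-check guard added above: an equivalent spelling via the stdlib logging API would be `logger.isEnabledFor` (a suggestion only, not what the PR uses; it checks this logger's effective level rather than comparing against `logging.root.level` directly):

```python
if logger.isEnabledFor(logging.DEBUG):
    logger.debug(f"Received: {data}\nParsed: {rtu_event.model_dump()}", extra=extra_dict)
```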

async def outbound_message_hook(self, contents: RTURequest) -> str:
"""
Hook applied to every message we send out over the websocket.
- Anything calling .send() should pass in an RTU Request pydantic model
"""
return contents.json()
return contents.model_dump_json()

def send(self, message: RTURequest) -> None:
"""Override WebsocketManager-defined method for type hinting and logging."""
@@ -118,7 +121,9 @@ def send(self, message: RTURequest) -> None:
if message.event == "new_delta_request":
extra_dict["delta_type"] = message.data.delta.delta_type
extra_dict["delta_action"] = message.data.delta.delta_action

logger.debug("Sending: RTU request", extra=extra_dict)

super().send(message) # the .outbound_message_hook handles serializing this to json

async def on_exception(self, exc: Exception):
@@ -143,12 +148,21 @@ class DeltaRejected(Exception):


# Used in registering callback functions that get called right after squashing a Delta
class DeltaCallback(BaseModel):
class DeltaCallback:
# callback function should be async and expect one argument: a FileDelta
# Doesn't matter what it returns. Pydantic doesn't validate Callable args/return.
delta_class: Type[FileDelta]
fn: Callable[[FileDelta], Awaitable[None]]

def __init__(self, delta_class: Type[FileDelta], fn: Callable[[FileDelta], Awaitable[None]]):
# With pydantic2, raises: "TypeError: Subscripted generics cannot be used with class and instance checks"
Comment (Contributor Author):
Had to comment out this well-meaning check, 'cause the pydantic 2 base implementation using parameterized types is incompatible with issubclass.

Would happily be open to suggestions on respelling it to pivot off of, say, a constant in the class or something to get the same effect?

Comment (Contributor):
Need to pair with you on this to better understand what's happening.

Comment (Contributor Author):
Sure, whenevs. If left in, the issubclass() check now raises an exception 'cause issubclass() can't be used with parameterized types.

# Sigh.
# if not issubclass(delta_class, FileDelta):
# raise ValueError(f"delta_class must be a FileDelta subclass, got {delta_class}")

self.delta_class = delta_class
self.fn = fn
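One possible respelling of the commented-out check discussed in the thread above, sketched under the assumption that the parameterized generic is the only obstacle: unwrap subscripted generics with `typing.get_origin` and skip the check when either side doesn't resolve to a plain class (hypothetical helper, not part of this PR):

```python
from typing import get_origin

def _runtime_class(tp):
    """Best effort: return the plain class behind a possibly subscripted generic."""
    return get_origin(tp) or tp

base = _runtime_class(FileDelta)
candidate = _runtime_class(delta_class)
if isinstance(base, type) and isinstance(candidate, type):
    if not issubclass(candidate, base):
        raise ValueError(f"delta_class must be a FileDelta subclass, got {delta_class}")
# If either side is an Annotated/Union alias rather than a plain class,
# the guard is skipped instead of raising TypeError.
```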


class DeltaRequestCallbackManager:
"""
@@ -453,7 +467,7 @@ async def load_seed_notebook(self):
resp = await plain_http_client.get(file.presigned_download_url)
resp.raise_for_status()

seed_notebook = Notebook.parse_obj(resp.json())
seed_notebook = Notebook.model_validate(resp.json())
self.builder = NotebookBuilder(seed_notebook=seed_notebook)

# See Sending backends.websocket for details but a quick refresher on hook timing:
Expand Down Expand Up @@ -492,7 +506,7 @@ async def auth_hook(self, *args, **kwargs):
# we observe the auth reply. Instead use the unauth_ws directly and manually serialize
ws: WebSocketClientProtocol = await self.manager.unauth_ws
logger.info(f"Sending auth request with jwt {jwt[:5]}...{jwt[-5:]}")
await ws.send(auth_request.json())
await ws.send(auth_request.model_dump_json())

async def on_auth(self, msg: AuthenticateReply):
# hook for Application code to override, consider catastrophic failure on auth failure
2 changes: 1 addition & 1 deletion origami/models/api/base.py
@@ -9,4 +9,4 @@ class ResourceBase(BaseModel):
id: uuid.UUID
created_at: datetime
updated_at: datetime
deleted_at: Optional[datetime]
deleted_at: Optional[datetime] = None
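The explicit `= None` defaults added here (and in the models below) are required because pydantic 2 no longer treats `Optional[X]` as implicitly optional. A quick toy check, not part of the PR:

```python
from typing import Optional
from pydantic import BaseModel, ValidationError

class WithoutDefault(BaseModel):
    deleted_at: Optional[str]          # pydantic 2: required, may be explicitly None

class WithDefault(BaseModel):
    deleted_at: Optional[str] = None   # truly optional

try:
    WithoutDefault()
except ValidationError as exc:
    print(exc.error_count())  # 1 ("Field required")

print(WithDefault().deleted_at)  # None
```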
6 changes: 3 additions & 3 deletions origami/models/api/datasources.py
@@ -12,9 +12,9 @@ class DataSource(BaseModel):
type_id: str # e.g. duckdb, postgresql
sql_cell_handle: str # this goes in cell metadata for SQL cells
# One of these three will be not None, and that tells you the scope of the datasource
space_id: Optional[uuid.UUID]
project_id: Optional[uuid.UUID]
user_id: Optional[uuid.UUID]
space_id: Optional[uuid.UUID] = None
project_id: Optional[uuid.UUID] = None
user_id: Optional[uuid.UUID] = None
created_by_id: uuid.UUID
created_at: datetime
updated_at: datetime
10 changes: 6 additions & 4 deletions origami/models/api/files.py
@@ -3,7 +3,7 @@
import uuid
from typing import Literal, Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

@@ -22,10 +22,12 @@ class File(ResourceBase):
presigned_download_url: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/f/{values['id']}/{values['path']}"
self.url = f"{noteable_url}/f/{self.id}/{self.path}"

return self


class FileVersion(ResourceBase):
6 changes: 3 additions & 3 deletions origami/models/api/outputs.py
@@ -14,11 +14,11 @@ class KernelOutputContent(BaseModel):

class KernelOutput(ResourceBase):
type: str
display_id: Optional[str]
display_id: Optional[str] = None
available_mimetypes: List[str]
content_metadata: KernelOutputContent
content: Optional[KernelOutputContent]
content_for_llm: Optional[KernelOutputContent]
content: Optional[KernelOutputContent] = None
content_for_llm: Optional[KernelOutputContent] = None
parent_collection_id: uuid.UUID


12 changes: 7 additions & 5 deletions origami/models/api/projects.py
@@ -2,18 +2,20 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Project(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
space_id: uuid.UUID
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/p/{values['id']}"
self.url = f"{noteable_url}/p/{self.id}"

return self
12 changes: 7 additions & 5 deletions origami/models/api/spaces.py
@@ -1,17 +1,19 @@
import os
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Space(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/s/{values['id']}"
self.url = f"{noteable_url}/s/{self.id}"

return self
20 changes: 11 additions & 9 deletions origami/models/api/users.py
@@ -1,7 +1,7 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

@@ -10,14 +10,16 @@ class User(ResourceBase):
"""The user fields sent to/from the server"""

handle: str
email: Optional[str] # not returned if looking up user other than yourself
email: Optional[str] = None # not returned if looking up user other than yourself
first_name: str
last_name: str
origamist_default_project_id: Optional[uuid.UUID]
principal_sub: Optional[str] # from /users/me only, represents auth type
auth_type: Optional[str]
origamist_default_project_id: Optional[uuid.UUID] = None
principal_sub: Optional[str] = None # from /users/me only, represents auth type
auth_type: Optional[str] = None

@validator("auth_type", always=True)
def construct_auth_type(cls, v, values):
if values.get("principal_sub"):
return values["principal_sub"].split("|")[0]
@model_validator(mode="after")
def construct_auth_type(self):
if self.principal_sub:
self.auth_type = self.principal_sub.split("|")[0]

return self
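A self-contained toy version of the `@validator(..., always=True)` → `@model_validator(mode="after")` pattern applied in files.py, projects.py, spaces.py, and users.py above (the URL base is hard-coded here purely for illustration):

```python
from typing import Optional
from pydantic import BaseModel, model_validator

class Resource(BaseModel):
    id: str
    path: str
    url: Optional[str] = None

    @model_validator(mode="after")
    def construct_url(self):
        # The "after" validator receives the fully constructed model, so derived
        # fields are set by attribute assignment and the instance is returned.
        self.url = f"https://app.noteable.io/f/{self.id}/{self.path}"
        return self

print(Resource(id="abc123", path="nb.ipynb").url)
# -> https://app.noteable.io/f/abc123/nb.ipynb
```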