Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Port to pydantic 2 #192

Merged
merged 24 commits into from
Nov 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
For pre-1.0 releases, see [0.0.35 Changelog](https://github.com/noteable-io/origami/blob/0.0.35/CHANGELOG.md)

## [Unreleased]
### Changed
- Upgraded pydantic to 2.4.2 up from 1.X.

### [1.1.5] - 2023-11-06
### Fixed
Expand Down
30 changes: 15 additions & 15 deletions origami/clients/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ async def user_info(self) -> User:
endpoint = "/users/me"
resp = await self.client.get(endpoint)
resp.raise_for_status()
user = User.parse_obj(resp.json())
user = User.model_validate(resp.json())
self.add_tags_and_contextvars(user_id=str(user.id))
return user

Expand Down Expand Up @@ -191,7 +191,7 @@ async def create_space(self, name: str, description: Optional[str] = None) -> Sp
endpoint = "/spaces"
resp = await self.client.post(endpoint, json={"name": name, "description": description})
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
self.add_tags_and_contextvars(space_id=str(space.id))
return space

Expand All @@ -200,7 +200,7 @@ async def get_space(self, space_id: uuid.UUID) -> Space:
endpoint = f"/spaces/{space_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
space = Space.parse_obj(resp.json())
space = Space.model_validate(resp.json())
return space

async def delete_space(self, space_id: uuid.UUID) -> None:
Expand All @@ -216,7 +216,7 @@ async def list_space_projects(self, space_id: uuid.UUID) -> List[Project]:
endpoint = f"/spaces/{space_id}/projects"
resp = await self.client.get(endpoint)
resp.raise_for_status()
projects = [Project.parse_obj(project) for project in resp.json()]
projects = [Project.model_validate(project) for project in resp.json()]
return projects

async def share_space(
Expand Down Expand Up @@ -267,7 +267,7 @@ async def create_project(
},
)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
self.add_tags_and_contextvars(project_id=str(project.id))
return project

Expand All @@ -276,15 +276,15 @@ async def get_project(self, project_id: uuid.UUID) -> Project:
endpoint = f"/projects/{project_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def delete_project(self, project_id: uuid.UUID) -> Project:
self.add_tags_and_contextvars(project_id=str(project_id))
endpoint = f"/projects/{project_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
project = Project.parse_obj(resp.json())
project = Project.model_validate(resp.json())
return project

async def share_project(
Expand Down Expand Up @@ -323,7 +323,7 @@ async def list_project_files(self, project_id: uuid.UUID) -> List[File]:
endpoint = f"/projects/{project_id}/files"
resp = await self.client.get(endpoint)
resp.raise_for_status()
files = [File.parse_obj(file) for file in resp.json()]
files = [File.model_validate(file) for file in resp.json()]
return files

# Files are flat files (like text, csv, etc) or Notebooks.
Expand Down Expand Up @@ -355,7 +355,7 @@ async def _multi_step_file_create(
upload_url = js["presigned_upload_url_info"]["parts"][0]["upload_url"]
upload_id = js["presigned_upload_url_info"]["upload_id"]
upload_key = js["presigned_upload_url_info"]["key"]
file = File.parse_obj(js)
file = File.model_validate(js)

# (2) Upload to pre-signed url
# TODO: remove this hack if/when we get containers in Skaffold to be able to translate
Expand Down Expand Up @@ -393,7 +393,7 @@ async def create_notebook(
self.add_tags_and_contextvars(project_id=str(project_id))
if notebook is None:
notebook = Notebook()
content = notebook.json().encode()
content = notebook.model_dump_json().encode()
file = await self._multi_step_file_create(project_id, path, "notebook", content)
self.add_tags_and_contextvars(file_id=str(file.id))
logger.info("Created new notebook", extra={"file_id": str(file.id)})
Expand All @@ -405,7 +405,7 @@ async def get_file(self, file_id: uuid.UUID) -> File:
endpoint = f"/v1/files/{file_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def get_file_content(self, file_id: uuid.UUID) -> bytes:
Expand Down Expand Up @@ -433,15 +433,15 @@ async def get_file_versions(self, file_id: uuid.UUID) -> List[FileVersion]:
endpoint = f"/files/{file_id}/versions"
resp = await self.client.get(endpoint)
resp.raise_for_status()
versions = [FileVersion.parse_obj(version) for version in resp.json()]
versions = [FileVersion.model_validate(version) for version in resp.json()]
return versions

async def delete_file(self, file_id: uuid.UUID) -> File:
self.add_tags_and_contextvars(file_id=str(file_id))
endpoint = f"/v1/files/{file_id}"
resp = await self.client.delete(endpoint)
resp.raise_for_status()
file = File.parse_obj(resp.json())
file = File.model_validate(resp.json())
return file

async def share_file(
Expand Down Expand Up @@ -497,7 +497,7 @@ async def launch_kernel(
}
resp = await self.client.post(endpoint, json=data)
resp.raise_for_status()
kernel_session = KernelSession.parse_obj(resp.json())
kernel_session = KernelSession.model_validate(resp.json())
self.add_tags_and_contextvars(kernel_session_id=str(kernel_session.id))
logger.info(
"Launched new kernel",
Expand All @@ -517,7 +517,7 @@ async def get_output_collection(
endpoint = f"/outputs/collection/{output_collection_id}"
resp = await self.client.get(endpoint)
resp.raise_for_status()
return KernelOutputCollection.parse_obj(resp.json())
return KernelOutputCollection.model_validate(resp.json())

async def connect_realtime(self, file: Union[File, uuid.UUID, str]) -> "RTUClient": # noqa
"""
Expand Down
28 changes: 16 additions & 12 deletions origami/clients/rtu.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import httpx
import orjson
from pydantic import BaseModel, parse_obj_as
from sending.backends.websocket import WebsocketManager
from websockets.client import WebSocketClientProtocol

Expand Down Expand Up @@ -51,7 +50,7 @@
KernelStatusUpdateResponse,
)
from origami.models.rtu.channels.system import AuthenticateReply, AuthenticateRequest
from origami.models.rtu.discriminators import RTURequest, RTUResponse
from origami.models.rtu.discriminators import RTURequest, RTUResponse, RTUResponseParser
from origami.models.rtu.errors import InconsistentStateEvent
from origami.notebook.builder import CellNotFound, NotebookBuilder

Expand Down Expand Up @@ -87,7 +86,8 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
# to error or BaseRTUResponse)
data: dict = orjson.loads(contents)
data["channel_prefix"] = data.get("channel", "").split("/")[0]
rtu_event = parse_obj_as(RTUResponse, data)

rtu_event = RTUResponseParser.validate_python(data)

# Debug Logging
extra_dict = {
Expand All @@ -98,15 +98,18 @@ async def inbound_message_hook(self, contents: str) -> RTUResponse:
if isinstance(rtu_event, NewDeltaEvent):
extra_dict["delta_type"] = rtu_event.data.delta_type
extra_dict["delta_action"] = rtu_event.data.delta_action
logger.debug(f"Received: {data}\nParsed: {rtu_event.dict()}", extra=extra_dict)

if logging.DEBUG >= logging.root.level:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't make the logging call that includes dumping the model if our current logging level is lower than DEBUG.

logger.debug(f"Received: {data}\nParsed: {rtu_event.model_dump()}", extra=extra_dict)

return rtu_event

async def outbound_message_hook(self, contents: RTURequest) -> str:
"""
Hook applied to every message we send out over the websocket.
- Anything calling .send() should pass in an RTU Request pydantic model
"""
return contents.json()
return contents.model_dump_json()

def send(self, message: RTURequest) -> None:
"""Override WebsocketManager-defined method for type hinting and logging."""
Expand All @@ -118,7 +121,9 @@ def send(self, message: RTURequest) -> None:
if message.event == "new_delta_request":
extra_dict["delta_type"] = message.data.delta.delta_type
extra_dict["delta_action"] = message.data.delta.delta_action

logger.debug("Sending: RTU request", extra=extra_dict)

super().send(message) # the .outbound_message_hook handles serializing this to json

async def on_exception(self, exc: Exception):
Expand All @@ -143,11 +148,10 @@ class DeltaRejected(Exception):


# Used in registering callback functions that get called right after squashing a Delta
class DeltaCallback(BaseModel):
# callback function should be async and expect one argument: a FileDelta
# Doesn't matter what it returns. Pydantic doesn't validate Callable args/return.
delta_class: Type[FileDelta]
fn: Callable[[FileDelta], Awaitable[None]]
class DeltaCallback:
def __init__(self, delta_class: Type[FileDelta], fn: Callable[[FileDelta], Awaitable[None]]):
self.delta_class = delta_class
self.fn = fn


class DeltaRequestCallbackManager:
Expand Down Expand Up @@ -455,7 +459,7 @@ async def load_seed_notebook(self):
resp = await plain_http_client.get(file.presigned_download_url)
resp.raise_for_status()

seed_notebook = Notebook.parse_obj(resp.json())
seed_notebook = Notebook.model_validate(resp.json())
self.builder = NotebookBuilder(seed_notebook=seed_notebook)

# See Sending backends.websocket for details but a quick refresher on hook timing:
Expand Down Expand Up @@ -494,7 +498,7 @@ async def auth_hook(self, *args, **kwargs):
# we observe the auth reply. Instead use the unauth_ws directly and manually serialize
ws: WebSocketClientProtocol = await self.manager.unauth_ws
logger.info(f"Sending auth request with jwt {jwt[:5]}...{jwt[-5:]}")
await ws.send(auth_request.json())
await ws.send(auth_request.model_dump_json())

async def on_auth(self, msg: AuthenticateReply):
# hook for Application code to override, consider catastrophic failure on auth failure
Expand Down
2 changes: 1 addition & 1 deletion origami/models/api/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@ class ResourceBase(BaseModel):
id: uuid.UUID
created_at: datetime
updated_at: datetime
deleted_at: Optional[datetime]
deleted_at: Optional[datetime] = None
6 changes: 3 additions & 3 deletions origami/models/api/datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ class DataSource(BaseModel):
type_id: str # e.g. duckdb, postgresql
sql_cell_handle: str # this goes in cell metadata for SQL cells
# One of these three will be not None, and that tells you the scope of the datasource
space_id: Optional[uuid.UUID]
project_id: Optional[uuid.UUID]
user_id: Optional[uuid.UUID]
space_id: Optional[uuid.UUID] = None
project_id: Optional[uuid.UUID] = None
user_id: Optional[uuid.UUID] = None
created_by_id: uuid.UUID
created_at: datetime
updated_at: datetime
Expand Down
10 changes: 6 additions & 4 deletions origami/models/api/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import uuid
from typing import Literal, Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

Expand All @@ -22,10 +22,12 @@ class File(ResourceBase):
presigned_download_url: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/f/{values['id']}/{values['path']}"
self.url = f"{noteable_url}/f/{self.id}/{self.path}"

return self


class FileVersion(ResourceBase):
Expand Down
6 changes: 3 additions & 3 deletions origami/models/api/outputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@ class KernelOutputContent(BaseModel):

class KernelOutput(ResourceBase):
type: str
display_id: Optional[str]
display_id: Optional[str] = None
available_mimetypes: List[str]
content_metadata: KernelOutputContent
content: Optional[KernelOutputContent]
content_for_llm: Optional[KernelOutputContent]
content: Optional[KernelOutputContent] = None
content_for_llm: Optional[KernelOutputContent] = None
parent_collection_id: uuid.UUID


Expand Down
12 changes: 7 additions & 5 deletions origami/models/api/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,20 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Project(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
space_id: uuid.UUID
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/p/{values['id']}"
self.url = f"{noteable_url}/p/{self.id}"

return self
12 changes: 7 additions & 5 deletions origami/models/api/spaces.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import os
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase


class Space(ResourceBase):
name: str
description: Optional[str]
description: Optional[str] = None
url: Optional[str] = None

@validator("url", always=True)
def construct_url(cls, v, values):
@model_validator(mode="after")
def construct_url(self):
noteable_url = os.environ.get("PUBLIC_NOTEABLE_URL", "https://app.noteable.io")
return f"{noteable_url}/s/{values['id']}"
self.url = f"{noteable_url}/s/{self.id}"

return self
20 changes: 11 additions & 9 deletions origami/models/api/users.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from typing import Optional

from pydantic import validator
from pydantic import model_validator

from origami.models.api.base import ResourceBase

Expand All @@ -10,14 +10,16 @@ class User(ResourceBase):
"""The user fields sent to/from the server"""

handle: str
email: Optional[str] # not returned if looking up user other than yourself
email: Optional[str] = None # not returned if looking up user other than yourself
first_name: str
last_name: str
origamist_default_project_id: Optional[uuid.UUID]
principal_sub: Optional[str] # from /users/me only, represents auth type
auth_type: Optional[str]
origamist_default_project_id: Optional[uuid.UUID] = None
principal_sub: Optional[str] = None # from /users/me only, represents auth type
auth_type: Optional[str] = None

@validator("auth_type", always=True)
def construct_auth_type(cls, v, values):
if values.get("principal_sub"):
return values["principal_sub"].split("|")[0]
@model_validator(mode="after")
def construct_auth_type(self):
if self.principal_sub:
self.auth_type = self.principal_sub.split("|")[0]

return self
6 changes: 3 additions & 3 deletions origami/models/deltas/delta_types/cell_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class CellMetadataDelta(FileDeltaBase):
# A lot of state is stored in cell metadata, including DEX and execute time
class CellMetadataUpdateProperties(BaseModel):
path: list
value: Any
value: Any = None
prior_value: Any = NULL_PRIOR_VALUE_SENTINEL


Expand All @@ -26,8 +26,8 @@ class CellMetadataUpdate(CellMetadataDelta):

# Cell metadata replace is used for changing cell type and language (Python/R/etc)
class CellMetadataReplaceProperties(BaseModel):
type: Optional[str]
language: Optional[str]
type: Optional[str] = None
language: Optional[str] = None


class CellMetadataReplace(CellMetadataDelta):
Expand Down
Loading