Skip to content

Commit

Permalink
reverting to main
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Jan 27, 2025
1 parent ccc79fa commit 3fd19b6
Show file tree
Hide file tree
Showing 6 changed files with 44 additions and 30 deletions.
2 changes: 1 addition & 1 deletion src/datachain/data_storage/db_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def has_table(self, name: str) -> bool:
return sa.inspect(self.engine).has_table(name)

@abstractmethod
def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table": ...
def create_table(self, table: "Table", if_not_exists: bool = True) -> None: ...

@abstractmethod
def drop_table(self, table: "Table", if_exists: bool = False) -> None: ...
Expand Down
3 changes: 1 addition & 2 deletions src/datachain/data_storage/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,9 +276,8 @@ def has_table(self, name: str) -> bool:
)
return bool(next(self.execute(query))[0])

def create_table(self, table: "Table", if_not_exists: bool = True) -> "Table":
def create_table(self, table: "Table", if_not_exists: bool = True) -> None:
self.execute(CreateTable(table, if_not_exists=if_not_exists))
return table

def drop_table(self, table: "Table", if_exists: bool = False) -> None:
self.execute(DropTable(table, if_exists=if_exists))
Expand Down
1 change: 0 additions & 1 deletion src/datachain/diff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,6 @@ def _to_list(obj: Union[str, Sequence[str]]) -> list[str]:
)

def _default_val(chain: "DataChain", col: str):
print(chain._query.column_types)
col_type = chain._query.column_types[col] # type: ignore[index]
val = sa.literal(col_type.default_value(dialect)).label(col)
val.type = col_type()
Expand Down
35 changes: 21 additions & 14 deletions tests/func/test_datachain_merge.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from datachain.lib.dc import DataChain
from datachain.sql.types import Int


@pytest.mark.parametrize(
Expand All @@ -10,13 +11,16 @@
)
@pytest.mark.parametrize("inner", [True, False])
def test_merge_union(cloud_test_catalog, inner, cloud_type):
catalog = cloud_test_catalog.catalog
session = cloud_test_catalog.session

src = cloud_test_catalog.src_uri

dogs = DataChain.from_storage(f"{src}/dogs/*", session=session)
cats = DataChain.from_storage(f"{src}/cats/*", session=session)

signal_default_value = Int.default_value(catalog.warehouse.db.dialect)

dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int})
dogs2 = dogs.map(sig2=lambda: 2, output={"sig2": int})
cats1 = cats.map(sig1=lambda: 1, output={"sig1": int})
Expand All @@ -33,8 +37,8 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type):
]
else:
assert signals == [
("cats/cat1", 1, None),
("cats/cat2", 1, None),
("cats/cat1", 1, signal_default_value),
("cats/cat2", 1, signal_default_value),
("dogs/dog1", 1, 2),
("dogs/dog2", 1, 2),
("dogs/dog3", 1, 2),
Expand All @@ -51,13 +55,16 @@ def test_merge_union(cloud_test_catalog, inner, cloud_type):
@pytest.mark.parametrize("inner2", [True, False])
@pytest.mark.parametrize("inner3", [True, False])
def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3):
catalog = cloud_test_catalog.catalog
session = cloud_test_catalog.session

src = cloud_test_catalog.src_uri

dogs = DataChain.from_storage(f"{src}/dogs/*", session=session)
cats = DataChain.from_storage(f"{src}/cats/*", session=session)

signal_default_value = Int.default_value(catalog.warehouse.db.dialect)

dogs_and_cats = dogs | cats
dogs1 = dogs.map(sig1=lambda: 1, output={"sig1": int})
cats1 = cats.map(sig2=lambda: 2, output={"sig2": int})
Expand All @@ -73,22 +80,22 @@ def test_merge_multiple(cloud_test_catalog, inner1, inner2, inner3):
assert merged_signals == []
elif inner1:
assert merged_signals == [
("dogs/dog1", 1, None),
("dogs/dog2", 1, None),
("dogs/dog3", 1, None),
("dogs/others/dog4", 1, None),
("dogs/dog1", 1, signal_default_value),
("dogs/dog2", 1, signal_default_value),
("dogs/dog3", 1, signal_default_value),
("dogs/others/dog4", 1, signal_default_value),
]
elif inner2 and inner3:
assert merged_signals == [
("cats/cat1", None, 2),
("cats/cat2", None, 2),
("cats/cat1", signal_default_value, 2),
("cats/cat2", signal_default_value, 2),
]
else:
assert merged_signals == [
("cats/cat1", None, 2),
("cats/cat2", None, 2),
("dogs/dog1", 1, None),
("dogs/dog2", 1, None),
("dogs/dog3", 1, None),
("dogs/others/dog4", 1, None),
("cats/cat1", signal_default_value, 2),
("cats/cat2", signal_default_value, 2),
("dogs/dog1", 1, signal_default_value),
("dogs/dog2", 1, signal_default_value),
("dogs/dog3", 1, signal_default_value),
("dogs/others/dog4", 1, signal_default_value),
]
15 changes: 9 additions & 6 deletions tests/func/test_dataset_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -743,9 +743,10 @@ def test_join_with_binary_expression(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", None),
("cats/cat2", None),
("cats/cat1", string_default),
("cats/cat2", string_default),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down Expand Up @@ -792,9 +793,10 @@ def test_join_with_combination_binary_expression_and_column_predicates(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", None),
("cats/cat2", None),
("cats/cat1", string_default),
("cats/cat2", string_default),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down Expand Up @@ -916,9 +918,10 @@ def test_join_with_using_functions_in_expression(
("dogs/others/dog4", "dogs/others/dog4"),
]
else:
string_default = String.default_value(catalog.warehouse.db.dialect)
expected = [
("cats/cat1", None),
("cats/cat2", None),
("cats/cat1", string_default),
("cats/cat2", string_default),
("dogs/dog1", "dogs/dog1"),
("dogs/dog2", "dogs/dog2"),
("dogs/dog3", "dogs/dog3"),
Expand Down
18 changes: 12 additions & 6 deletions tests/unit/lib/test_datachain_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy import func

from datachain.lib.dc import C, DataChain, DatasetMergeError
from datachain.sql.types import Int, String
from tests.utils import skip_if_not_sqlite


Expand Down Expand Up @@ -51,6 +52,8 @@ def test_merge_objects(test_session):
ch2 = DataChain.from_values(team=team, session=test_session)
ch = ch1.merge(ch2, "emp.person.name", "team.player")

str_default = String.default_value(test_session.catalog.warehouse.db.dialect)

i = 0
j = 0
for items in ch.order_by("emp.person.name", "team.player").collect():
Expand All @@ -69,8 +72,8 @@ def test_merge_objects(test_session):
assert math.isclose(player.height, team[j].height, rel_tol=1e-7)
j += 1
else:
assert player.player is None
assert player.sport is None
assert player.player == str_default
assert player.sport == str_default
assert pd.isnull(player.weight)
assert pd.isnull(player.height)

Expand All @@ -92,6 +95,9 @@ def test_merge_objects_full_join(test_session, multiple_predicates):
else:
ch = ch1.merge(ch2, "emp.person.name", "team.player", full=True)

str_default = String.default_value(test_session.catalog.warehouse.db.dialect)
int_default = Int.default_value(test_session.catalog.warehouse.db.dialect)

i = 0
for items in ch.order_by("emp.person.name", "team.player").collect():
assert len(items) == 2
Expand All @@ -101,13 +107,13 @@ def test_merge_objects_full_join(test_session, multiple_predicates):
assert isinstance(player, TeamMember)

if player.player == "John":
assert empl.person.name is None
assert empl.person.age is None
assert empl.person.name == str_default
assert empl.person.age == int_default
continue

if empl.person.name == "Bob":
assert player.player is None
assert player.sport is None
assert player.player == str_default
assert player.sport == str_default
assert pd.isnull(player.weight)
assert pd.isnull(player.height)
continue
Expand Down

0 comments on commit 3fd19b6

Please sign in to comment.