Skip to content

Commit

Permalink
returning old default values
Browse files Browse the repository at this point in the history
  • Loading branch information
ilongin committed Jan 27, 2025
1 parent 3fd19b6 commit 12641b9
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
29 changes: 10 additions & 19 deletions tests/unit/lib/test_datachain.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,13 +1052,17 @@ def test_parse_nested_json(tmp_dir, test_session):
# E.g. nAmE -> name, l--as@t -> l_as_t, etc
df1 = dc.select("na_me", "age", "city").to_pandas()

# In CH we replace None with '' for peforance reasons,
# have to handle it here
string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

assert sorted(df1["na_me"]["first_select"].to_list()) == sorted(
d["first-SELECT"] for d in df["nA-mE"].to_list()
)
assert sorted(
df1["na_me"]["l_as_t"].to_list(), key=lambda x: (x is None, x)
) == sorted(
[d.get("l--as@t", None) for d in df["nA-mE"].to_list()],
[d.get("l--as@t", string_default) for d in df["nA-mE"].to_list()],
key=lambda x: (x is None, x),
)

Expand Down Expand Up @@ -1300,7 +1304,6 @@ def test_from_csv_null_collect(tmp_dir, test_session):
for i, row in enumerate(dc.collect()):
# None value in numeric column will get converted to nan.
if not height[i]:
print(row[1].height)
assert math.isnan(row[1].height)
else:
assert row[1].height == height[i]
Expand Down Expand Up @@ -1417,6 +1420,10 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
object_name = object_name or "content_expl"
model_name = model_name or "ContentExplodedModel"

# In CH we have (atm at least) None converted to ''
# for performance reasons, so we need to handle this case
string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

assert set(
dc.collect(
f"{object_name}.na_me.first_select",
Expand All @@ -1426,7 +1433,7 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name):
) == {
("Alice", 25, "New York"),
("Bob", 30, "Los Angeles"),
("Charlie", 35, None),
("Charlie", 35, string_default),
("David", 40, "Houston"),
("Eva", 45, "Phoenix"),
("Ivan", 41, "San Francisco"),
Expand Down Expand Up @@ -2090,22 +2097,6 @@ def test_from_values_array_of_floats(test_session):
assert list(chain.order_by("emd").collect("emd")) == embeddings


def test_from_values_array_of_ints_with_nones(test_session):
ids = [1, 2]
embeddings = [[1, None], [4, 5]]
chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session)

assert list(chain.order_by("ids").collect("emd")) == embeddings


def test_from_values_with_nones(test_session):
ids = [1, 2, 3, 4]
embeddings = [100, None, 300, None]
chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session)

assert list(chain.order_by("ids").collect("emd")) == [100, None, 300, None]


def test_custom_model_with_nested_lists(test_session):
class Trace(BaseModel):
x: float
Expand Down
29 changes: 18 additions & 11 deletions tests/unit/lib/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from datachain.diff import CompareStatus, compare_and_split
from datachain.lib.dc import DataChain
from datachain.lib.file import File
from datachain.sql.types import Int64, String
from tests.utils import sorted_dicts


Expand Down Expand Up @@ -162,13 +163,15 @@ def test_compare_with_explicit_compare_fields(test_session, right_name):
status_col="diff",
)

string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

expected = [
(CompareStatus.MODIFIED, 1, "John1", "New York"),
(CompareStatus.ADDED, 2, "Doe", "Boston"),
(
CompareStatus.DELETED,
3,
None if right_name == "other_name" else "Mark",
string_default if right_name == "other_name" else "Mark",
"Seattle",
),
(CompareStatus.SAME, 4, "Andy", "San Francisco"),
Expand Down Expand Up @@ -199,11 +202,13 @@ def test_compare_different_left_right_on_columns(test_session):
status_col="diff",
)

int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect)

expected = [
(CompareStatus.SAME, 4, "Andy"),
(CompareStatus.ADDED, 2, "Doe"),
(CompareStatus.MODIFIED, 1, "John1"),
(CompareStatus.DELETED, None, "Mark"),
(CompareStatus.DELETED, int_default, "Mark"),
]

collect_fields = ["diff", "id", "name"]
Expand Down Expand Up @@ -311,6 +316,8 @@ def test_compare_additional_column_on_left(test_session):
session=test_session,
).save("ds2")

string_default = String.default_value(test_session.catalog.warehouse.db.dialect)

diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff")

assert sorted_dicts(diff.to_records(), "id") == sorted_dicts(
Expand All @@ -321,7 +328,7 @@ def test_compare_additional_column_on_left(test_session):
"diff": CompareStatus.DELETED,
"id": 3,
"name": "Mark",
"city": None,
"city": string_default,
},
{"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy", "city": "Tokyo"},
],
Expand Down Expand Up @@ -356,8 +363,8 @@ def test_compare_additional_column_on_right(test_session):


def test_compare_missing_on(test_session):
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1")
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # . save("ds2")
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session)
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session)

with pytest.raises(ValueError) as exc_info:
ds1.compare(ds2, on=None)
Expand All @@ -366,8 +373,8 @@ def test_compare_missing_on(test_session):


def test_compare_right_on_wrong_length(test_session):
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1")
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds2")
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session)
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session)

with pytest.raises(ValueError) as exc_info:
ds1.compare(ds2, on=["id"], right_on=["id", "name"])
Expand All @@ -376,8 +383,8 @@ def test_compare_right_on_wrong_length(test_session):


def test_compare_right_compare_defined_but_not_compare(test_session):
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1")
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds2")
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session)
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session)

with pytest.raises(ValueError) as exc_info:
ds1.compare(ds2, on=["id"], right_compare=["name"])
Expand All @@ -388,8 +395,8 @@ def test_compare_right_compare_defined_but_not_compare(test_session):


def test_compare_right_compare_wrong_length(test_session):
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1")
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2")
ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session)
ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session)

with pytest.raises(ValueError) as exc_info:
ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"])
Expand Down

0 comments on commit 12641b9

Please sign in to comment.