diff --git a/tests/unit/lib/test_datachain.py b/tests/unit/lib/test_datachain.py index 92f3dfec8..1d0496170 100644 --- a/tests/unit/lib/test_datachain.py +++ b/tests/unit/lib/test_datachain.py @@ -1052,13 +1052,17 @@ def test_parse_nested_json(tmp_dir, test_session): # E.g. nAmE -> name, l--as@t -> l_as_t, etc df1 = dc.select("na_me", "age", "city").to_pandas() + # In CH we replace None with '' for peforance reasons, + # have to handle it here + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + assert sorted(df1["na_me"]["first_select"].to_list()) == sorted( d["first-SELECT"] for d in df["nA-mE"].to_list() ) assert sorted( df1["na_me"]["l_as_t"].to_list(), key=lambda x: (x is None, x) ) == sorted( - [d.get("l--as@t", None) for d in df["nA-mE"].to_list()], + [d.get("l--as@t", string_default) for d in df["nA-mE"].to_list()], key=lambda x: (x is None, x), ) @@ -1300,7 +1304,6 @@ def test_from_csv_null_collect(tmp_dir, test_session): for i, row in enumerate(dc.collect()): # None value in numeric column will get converted to nan. if not height[i]: - print(row[1].height) assert math.isnan(row[1].height) else: assert row[1].height == height[i] @@ -1417,6 +1420,10 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name): object_name = object_name or "content_expl" model_name = model_name or "ContentExplodedModel" + # In CH we have (atm at least) None converted to '' + # for performance reasons, so we need to handle this case + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + assert set( dc.collect( f"{object_name}.na_me.first_select", @@ -1426,7 +1433,7 @@ def test_explode(tmp_dir, test_session, column_type, object_name, model_name): ) == { ("Alice", 25, "New York"), ("Bob", 30, "Los Angeles"), - ("Charlie", 35, None), + ("Charlie", 35, string_default), ("David", 40, "Houston"), ("Eva", 45, "Phoenix"), ("Ivan", 41, "San Francisco"), @@ -2090,22 +2097,6 @@ def test_from_values_array_of_floats(test_session): assert list(chain.order_by("emd").collect("emd")) == embeddings -def test_from_values_array_of_ints_with_nones(test_session): - ids = [1, 2] - embeddings = [[1, None], [4, 5]] - chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session) - - assert list(chain.order_by("ids").collect("emd")) == embeddings - - -def test_from_values_with_nones(test_session): - ids = [1, 2, 3, 4] - embeddings = [100, None, 300, None] - chain = DataChain.from_values(emd=embeddings, ids=ids, session=test_session) - - assert list(chain.order_by("ids").collect("emd")) == [100, None, 300, None] - - def test_custom_model_with_nested_lists(test_session): class Trace(BaseModel): x: float diff --git a/tests/unit/lib/test_diff.py b/tests/unit/lib/test_diff.py index 4f604c7f8..298c6c5c9 100644 --- a/tests/unit/lib/test_diff.py +++ b/tests/unit/lib/test_diff.py @@ -4,6 +4,7 @@ from datachain.diff import CompareStatus, compare_and_split from datachain.lib.dc import DataChain from datachain.lib.file import File +from datachain.sql.types import Int64, String from tests.utils import sorted_dicts @@ -162,13 +163,15 @@ def test_compare_with_explicit_compare_fields(test_session, right_name): status_col="diff", ) + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + expected = [ (CompareStatus.MODIFIED, 1, "John1", "New York"), (CompareStatus.ADDED, 2, "Doe", "Boston"), ( CompareStatus.DELETED, 3, - None if right_name == "other_name" else "Mark", + string_default if right_name == "other_name" else "Mark", "Seattle", ), (CompareStatus.SAME, 4, "Andy", "San Francisco"), @@ -199,11 +202,13 @@ def test_compare_different_left_right_on_columns(test_session): status_col="diff", ) + int_default = Int64.default_value(test_session.catalog.warehouse.db.dialect) + expected = [ (CompareStatus.SAME, 4, "Andy"), (CompareStatus.ADDED, 2, "Doe"), (CompareStatus.MODIFIED, 1, "John1"), - (CompareStatus.DELETED, None, "Mark"), + (CompareStatus.DELETED, int_default, "Mark"), ] collect_fields = ["diff", "id", "name"] @@ -311,6 +316,8 @@ def test_compare_additional_column_on_left(test_session): session=test_session, ).save("ds2") + string_default = String.default_value(test_session.catalog.warehouse.db.dialect) + diff = ds1.compare(ds2, same=True, on=["id"], status_col="diff") assert sorted_dicts(diff.to_records(), "id") == sorted_dicts( @@ -321,7 +328,7 @@ def test_compare_additional_column_on_left(test_session): "diff": CompareStatus.DELETED, "id": 3, "name": "Mark", - "city": None, + "city": string_default, }, {"diff": CompareStatus.MODIFIED, "id": 4, "name": "Andy", "city": "Tokyo"}, ], @@ -356,8 +363,8 @@ def test_compare_additional_column_on_right(test_session): def test_compare_missing_on(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # . save("ds2") + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) with pytest.raises(ValueError) as exc_info: ds1.compare(ds2, on=None) @@ -366,8 +373,8 @@ def test_compare_missing_on(test_session): def test_compare_right_on_wrong_length(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds2") + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) with pytest.raises(ValueError) as exc_info: ds1.compare(ds2, on=["id"], right_on=["id", "name"]) @@ -376,8 +383,8 @@ def test_compare_right_on_wrong_length(test_session): def test_compare_right_compare_defined_but_not_compare(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) # .save("ds2") + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) with pytest.raises(ValueError) as exc_info: ds1.compare(ds2, on=["id"], right_compare=["name"]) @@ -388,8 +395,8 @@ def test_compare_right_compare_defined_but_not_compare(test_session): def test_compare_right_compare_wrong_length(test_session): - ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds1") - ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session).save("ds2") + ds1 = DataChain.from_values(id=[1, 2, 4], session=test_session) + ds2 = DataChain.from_values(id=[1, 2, 4], session=test_session) with pytest.raises(ValueError) as exc_info: ds1.compare(ds2, on=["id"], compare=["name"], right_compare=["name", "city"])