Add CSV parsing options (#813)
* Update dc.py

Add support for CSV files where values can span several lines; the pyarrow parser already supports this (a usage sketch follows the changed-files summary below).

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dc.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* adding csv parse options config

* rename parse_options_config to parse_options

* typo

* fix tests, address PR review

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ivan Shcheklein <[email protected]>
3 people authored Jan 21, 2025
1 parent 6cb6f20 commit 1b5a585
Showing 2 changed files with 41 additions and 3 deletions.
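
As the commit message notes, pyarrow can already parse CSV values that span several lines; the new `parse_options` argument exposes that capability. A minimal sketch (illustrative file path; `newlines_in_values` is the relevant `pyarrow.csv.ParseOptions` flag, and the dict passed here is forwarded to it by `from_csv`):

```python
from datachain.lib.dc import DataChain

# Illustrative only: a local CSV in which quoted values contain embedded
# newlines. The dict below is expanded into pyarrow.csv.ParseOptions.
chain = DataChain.from_csv(
    "file:///tmp/reviews.csv",
    parse_options={"newlines_in_values": True},
)
```
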
16 changes: 13 additions & 3 deletions src/datachain/lib/dc.py
@@ -1942,7 +1942,7 @@ def parse_tabular(
def from_csv(
cls,
path,
delimiter: str = ",",
delimiter: Optional[str] = None,
header: bool = True,
output: OutputType = None,
object_name: str = "",
@@ -1952,14 +1952,16 @@ def from_csv(
session: Optional[Session] = None,
settings: Optional[dict] = None,
column_types: Optional[dict[str, "Union[str, ArrowDataType]"]] = None,
parse_options: Optional[dict[str, "Union[str, Union[bool, Callable]]"]] = None,
**kwargs,
) -> "DataChain":
"""Generate chain from csv files.
Parameters:
path : Storage URI with directory. URI must start with storage prefix such
as `s3://`, `gs://`, `az://` or "file:///".
delimiter : Character for delimiting columns.
delimiter : Character for delimiting columns. Takes precedence if also
specified in `parse_options`. Defaults to ",".
header : Whether the files include a header row.
output : Dictionary or feature class defining column names and their
corresponding types. List of column names is also accepted, in which
@@ -1973,6 +1975,8 @@ def from_csv(
column_types : Dictionary of column names and their corresponding types.
It is passed to CSV reader and for each column specified type auto
inference is disabled.
parse_options: Tells the parser how to process lines.
See https://arrow.apache.org/docs/python/generated/pyarrow.csv.ParseOptions.html
Example:
Reading a csv file:
@@ -1990,6 +1994,12 @@ def from_csv(
from pyarrow.dataset import CsvFileFormat
from pyarrow.lib import type_for_alias

parse_options = parse_options or {}
if "delimiter" not in parse_options:
parse_options["delimiter"] = ","
if delimiter:
parse_options["delimiter"] = delimiter

if column_types:
column_types = {
name: type_for_alias(typ) if isinstance(typ, str) else typ
@@ -2017,7 +2027,7 @@ def from_csv(
msg = f"error parsing csv - incompatible output type {type(output)}"
raise DatasetPrepareError(chain.name, msg)

parse_options = ParseOptions(delimiter=delimiter)
parse_options = ParseOptions(**parse_options)
read_options = ReadOptions(column_names=column_names)
convert_options = ConvertOptions(
strings_can_be_null=True,
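
As the updated docstring and the precedence handling above state, an explicitly passed `delimiter` overrides a `delimiter` key in `parse_options`, and the remaining keys are forwarded to `pyarrow.csv.ParseOptions` via `ParseOptions(**parse_options)`. A small sketch of that interplay (illustrative path and options):

```python
from datachain.lib.dc import DataChain

# delimiter=";" overrides parse_options["delimiter"]; the other key is
# passed through to pyarrow.csv.ParseOptions unchanged.
chain = DataChain.from_csv(
    "file:///tmp/animals.csv",
    delimiter=";",
    parse_options={"delimiter": "|", "ignore_empty_lines": False},
)
```
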
28 changes: 28 additions & 0 deletions tests/unit/lib/test_datachain.py
@@ -1332,6 +1332,34 @@ def test_from_csv_column_types(tmp_dir, test_session):
assert df1["age"].dtype == pd.StringDtype


def test_from_csv_parse_options(tmp_dir, test_session):
def skip_comment(row):
if row.text.startswith("# "):
return "skip"
return "error"

s = (
"animals;n_legs;entry\n"
"Flamingo;2;2022-03-01\n"
"# Comment here:\n"
"Horse;4;2022-03-02\n"
"Brittle stars;5;2022-03-03\n"
"Centipede;100;2022-03-04"
)

path = tmp_dir / "test.csv"
path.write_text(s)

dc = DataChain.from_csv(
path.as_uri(),
session=test_session,
parse_options={"invalid_row_handler": skip_comment, "delimiter": ";"},
)

df = dc.select("animals", "n_legs", "entry").to_pandas()
assert set(df["animals"]) == {"Horse", "Centipede", "Brittle stars", "Flamingo"}


def test_to_csv_features(tmp_dir, test_session):
dc_to = DataChain.from_values(
f1=features, num=range(len(features)), session=test_session
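
The new test exercises pyarrow's `invalid_row_handler`: rows whose column count does not match the header (here, the `# Comment here:` line) are handed to the handler as a `pyarrow.csv.InvalidRow`, and returning `"skip"` drops them while `"error"` raises. A standalone sketch of the same contract against pyarrow directly, using hypothetical in-memory data:

```python
import io

from pyarrow import csv


def skip_comment(row):
    # row is a pyarrow.csv.InvalidRow; row.text holds the offending line.
    return "skip" if row.text.startswith("# ") else "error"


table = csv.read_csv(
    io.BytesIO(b"animals;n_legs\nFlamingo;2\n# Comment here:\nHorse;4\n"),
    parse_options=csv.ParseOptions(delimiter=";", invalid_row_handler=skip_comment),
)
print(table.num_rows)  # 2 valid rows remain after the comment line is skipped
```
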
