diff --git a/ci/requirements/all-but-numba.yml b/ci/requirements/all-but-numba.yml
index 61f64a176af..9bea3a17a85 100644
--- a/ci/requirements/all-but-numba.yml
+++ b/ci/requirements/all-but-numba.yml
@@ -22,6 +22,7 @@ dependencies:
   - hypothesis
   - iris
   - lxml # Optional dep of pydap
+  - lithops
   - matplotlib-base
   - nc-time-axis
   - netcdf4
diff --git a/ci/requirements/environment-3.13.yml b/ci/requirements/environment-3.13.yml
index 937cb013711..16f3cbf950e 100644
--- a/ci/requirements/environment-3.13.yml
+++ b/ci/requirements/environment-3.13.yml
@@ -19,6 +19,7 @@ dependencies:
   - hdf5
   - hypothesis
   - iris
+  - lithops
   - lxml # Optional dep of pydap
   - matplotlib-base
   - nc-time-axis
diff --git a/ci/requirements/environment-windows-3.13.yml b/ci/requirements/environment-windows-3.13.yml
index 448e3f70c0c..341ac182e43 100644
--- a/ci/requirements/environment-windows-3.13.yml
+++ b/ci/requirements/environment-windows-3.13.yml
@@ -18,6 +18,7 @@ dependencies:
   - hypothesis
   - iris
   - lxml # Optional dep of pydap
+  - lithops
   - matplotlib-base
   - nc-time-axis
   - netcdf4
diff --git a/ci/requirements/environment-windows.yml b/ci/requirements/environment-windows.yml
index 3b2e6dc62e6..61e84debfa4 100644
--- a/ci/requirements/environment-windows.yml
+++ b/ci/requirements/environment-windows.yml
@@ -18,6 +18,7 @@ dependencies:
   - hypothesis
   - iris
   - lxml # Optional dep of pydap
+  - lithops
   - matplotlib-base
   - nc-time-axis
   - netcdf4
diff --git a/ci/requirements/environment.yml b/ci/requirements/environment.yml
index 364ae03666f..156307dbbd6 100644
--- a/ci/requirements/environment.yml
+++ b/ci/requirements/environment.yml
@@ -19,6 +19,7 @@ dependencies:
   - hdf5
   - hypothesis
   - iris
+  - lithops
   - lxml # Optional dep of pydap
   - matplotlib-base
   - nc-time-axis
diff --git a/ci/requirements/min-all-deps.yml b/ci/requirements/min-all-deps.yml
index f3dab2e5bbf..dd9b040d713 100644
--- a/ci/requirements/min-all-deps.yml
+++ b/ci/requirements/min-all-deps.yml
@@ -30,6 +30,7 @@ dependencies:
   - hypothesis
   - iris=3.7
   - lxml=4.9 # Optional dep of pydap
+  - lithops=3.5.1
   - matplotlib-base=3.8
   - nc-time-axis=1.4
   # netcdf follows a 1.major.minor[.patch] convention
diff --git a/xarray/backends/api.py b/xarray/backends/api.py
index 2adcc57c6b9..d45d4391dc5 100644
--- a/xarray/backends/api.py
+++ b/xarray/backends/api.py
@@ -1365,6 +1365,9 @@ def open_groups(
     return groups
 
 
+import warnings
+
+
 def open_mfdataset(
     paths: str
     | os.PathLike
@@ -1480,9 +1483,10 @@ def open_mfdataset(
            those corresponding to other dimensions.
          * list of str: The listed coordinate variables will be concatenated, in addition
            the "minimal" coordinates.
-    parallel : bool, default: False
-        If True, the open and preprocess steps of this function will be
-        performed in parallel using ``dask.delayed``. Default is False.
+    parallel : 'dask', 'lithops', or False
+        Specify whether the open and preprocess steps of this function are performed
+        in parallel using ``dask.delayed``, in parallel using ``lithops.map``, or in
+        serial. Default is False. Passing True is a deprecated alias for 'dask'.
     join : {"outer", "inner", "left", "right", "exact", "override"}, default: "outer"
         String indicating how to combine differing indexes
         (excluding concat_dim) in objects
@@ -1596,7 +1600,15 @@ def open_mfdataset(
 
     open_kwargs = dict(engine=engine, chunks=chunks or {}, **kwargs)
 
-    if parallel:
+    if parallel is True:
+        warnings.warn(
+            "Passing ``parallel=True`` is deprecated; please pass ``parallel='dask'`` explicitly instead",
+            PendingDeprecationWarning,
+            stacklevel=2,
+        )
+        parallel = "dask"
+
+    if parallel == "dask":
         import dask
 
         # wrap the open_dataset, getattr, and preprocess with delayed
@@ -1604,19 +1616,51 @@
         open_ = dask.delayed(open_dataset)
         getattr_ = dask.delayed(getattr)
         if preprocess is not None:
             preprocess = dask.delayed(preprocess)
-    else:
+    elif parallel == "lithops":
+        import lithops
+
+        # TODO use RetryingFunctionExecutor instead?
+        fn_exec = lithops.FunctionExecutor()
+
+        # lithops doesn't have a delayed primitive
+        open_ = open_dataset
+        # TODO it is not yet clear how best to chain this with getattr_ to collect closers
+        # getattr_ = getattr
+    elif parallel is False:
         open_ = open_dataset
         getattr_ = getattr
+    else:
+        raise ValueError(
+            f"{parallel} is an invalid option for the keyword argument ``parallel``"
+        )
 
-    datasets = [open_(p, **open_kwargs) for p in paths1d]
-    closers = [getattr_(ds, "_close") for ds in datasets]
-    if preprocess is not None:
-        datasets = [preprocess(ds) for ds in datasets]
+    if parallel == "dask":
+        datasets = [open_(p, **open_kwargs) for p in paths1d]
+        closers = [getattr_(ds, "_close") for ds in datasets]
+        if preprocess is not None:
+            datasets = [preprocess(ds) for ds in datasets]
 
-    if parallel:
         # calling compute here will return the datasets/file_objs lists,
         # the underlying datasets will still be stored as dask arrays
         datasets, closers = dask.compute(datasets, closers)
+    elif parallel == "lithops":
+
+        def generate_lazy_ds(path):
+            # allows passing the open_dataset and preprocess steps to lithops without evaluating them on the client
+            ds = open_(path, **open_kwargs)
+            if preprocess is not None:
+                ds = preprocess(ds)
+            return ds
+
+        futures = fn_exec.map(generate_lazy_ds, paths1d)
+
+        # wait for all the serverless workers to finish, and send their resulting lazy datasets back to the client
+        # TODO do we need download_results?
+        completed_futures, _ = fn_exec.wait(futures, download_results=True)
+        datasets = fn_exec.get_result(fs=completed_futures)
+    elif parallel is False:
+        datasets = [open_(p, **open_kwargs) for p in paths1d]
+        closers = [getattr_(ds, "_close") for ds in datasets]
+        if preprocess is not None:
+            datasets = [preprocess(ds) for ds in datasets]
 
     # Combine all datasets, closing them in case of a ValueError
     try:
@@ -1654,7 +1698,9 @@
             ds.close()
         raise
 
-    combined.set_close(partial(_multi_file_closer, closers))
+    # TODO remove this guard once closers are collected in the lithops branch above
+    if parallel != "lithops":
+        combined.set_close(partial(_multi_file_closer, closers))
 
     # read global attributes from the attrs_file or from the first dataset
     if attrs_file is not None:
diff --git a/xarray/tests/__init__.py b/xarray/tests/__init__.py
index 1f2eedcd8f0..5c9187341d5 100644
--- a/xarray/tests/__init__.py
+++ b/xarray/tests/__init__.py
@@ -118,6 +118,7 @@ def _importorskip(
     category=DeprecationWarning,
 )
 has_dask_expr, requires_dask_expr = _importorskip("dask_expr")
+has_lithops, requires_lithops = _importorskip("lithops")
 has_bottleneck, requires_bottleneck = _importorskip("bottleneck")
 has_rasterio, requires_rasterio = _importorskip("rasterio")
 has_zarr, requires_zarr = _importorskip("zarr")
diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index cfca5e69048..11d5eacf344 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -80,6 +80,7 @@
     requires_h5netcdf_1_4_0_or_above,
     requires_h5netcdf_ros3,
     requires_iris,
+    requires_lithops,
     requires_netcdf,
     requires_netCDF4,
     requires_netCDF4_1_6_2_or_above,
@@ -4410,6 +4411,59 @@ def test_open_mfdataset_manyfiles(
     assert_identical(original, actual)
 
 
+@requires_netCDF4
+class TestParallel:
+    def test_validate_parallel_kwarg(self) -> None:
+        original = Dataset({"foo": ("x", np.random.randn(10))})
+        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
+        with create_tmp_file() as tmp1:
+            with create_tmp_file() as tmp2:
+                save_mfdataset(datasets, [tmp1, tmp2])
+
+                with pytest.raises(ValueError, match="garbage is an invalid option"):
+                    open_mfdataset(
+                        [tmp1, tmp2],
+                        concat_dim="x",
+                        combine="nested",
+                        parallel="garbage",
+                    )
+
+    def test_deprecation_warning(self) -> None:
+        original = Dataset({"foo": ("x", np.random.randn(10))})
+        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
+        with create_tmp_file() as tmp1:
+            with create_tmp_file() as tmp2:
+                save_mfdataset(datasets, [tmp1, tmp2])
+
+                with pytest.warns(
+                    PendingDeprecationWarning,
+                    match="please pass ``parallel='dask'`` explicitly",
+                ):
+                    open_mfdataset(
+                        [tmp1, tmp2],
+                        concat_dim="x",
+                        combine="nested",
+                        parallel=True,
+                    )
+
+    @requires_lithops
+    def test_lithops_parallel(self) -> None:
+        # default configuration of lithops will use the local executor
+
+        original = Dataset({"foo": ("x", np.random.randn(10))})
+        datasets = [original.isel(x=slice(5)), original.isel(x=slice(5, 10))]
+        with create_tmp_file() as tmp1:
+            with create_tmp_file() as tmp2:
+                save_mfdataset(datasets, [tmp1, tmp2])
+                with open_mfdataset(
+                    [tmp1, tmp2],
+                    concat_dim="x",
+                    combine="nested",
+                    parallel="lithops",
+                ) as actual:
+                    assert_identical(actual, original)
+
+
 @requires_netCDF4
 @requires_dask
 def test_open_mfdataset_can_open_path_objects() -> None:
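
Usage sketch (not part of the diff): a minimal illustration of how the ``parallel`` keyword introduced above would be called, assuming the behaviour this patch implements. The file paths are hypothetical, and the ``parallel='lithops'`` call assumes lithops falls back to its default local executor, as in the new test.

    import xarray as xr

    files = ["part-0.nc", "part-1.nc"]  # hypothetical pre-existing netCDF files

    # serial open/preprocess (unchanged default)
    ds = xr.open_mfdataset(files, combine="nested", concat_dim="x")

    # parallel open/preprocess via dask.delayed (previously spelled parallel=True)
    ds = xr.open_mfdataset(files, combine="nested", concat_dim="x", parallel="dask")

    # parallel open/preprocess via lithops.map on the configured lithops executor
    ds = xr.open_mfdataset(files, combine="nested", concat_dim="x", parallel="lithops")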