diff --git a/onedal/ensemble/forest.py b/onedal/ensemble/forest.py index 7001bf3fbe..002254126d 100644 --- a/onedal/ensemble/forest.py +++ b/onedal/ensemble/forest.py @@ -17,11 +17,7 @@ import numbers import warnings from abc import ABCMeta, abstractmethod -from math import ceil - -import numpy as np -from sklearn.ensemble import BaseEnsemble -from sklearn.utils import check_random_state +import math from daal4py.sklearn._utils import daal_check_version from sklearnex import get_hyperparameters @@ -30,16 +26,9 @@ from ..common._estimator_checks import _check_is_fitted from ..common._mixin import ClassifierMixin, RegressorMixin from ..datatypes import _convert_to_supported, from_table, to_table -from ..utils import ( - _check_array, - _check_n_features, - _check_X_y, - _column_or_1d, - _validate_targets, -) -class BaseForest(BaseEstimator, BaseEnsemble, metaclass=ABCMeta): +class BaseForest(BaseEstimator, metaclass=ABCMeta): @abstractmethod def __init__( self, @@ -100,8 +89,8 @@ def _to_absolute_max_features(self, n_features): if self.max_features is None: return n_features elif isinstance(self.max_features, str): - return max(1, int(getattr(np, self.max_features)(n_features))) - elif isinstance(self.max_features, (numbers.Integral, np.integer)): + return max(1, int(getattr(math, self.max_features)(n_features))) + elif isinstance(self.max_features, numbers.Integral): return self.max_features elif self.max_features > 0.0: return max(1, int(self.max_features * n_features)) @@ -132,31 +121,18 @@ def _get_onedal_params(self, data): self.observations_per_tree_fraction = ( self.observations_per_tree_fraction if bool(self.bootstrap) else 1.0 ) - - if not self.bootstrap and self.max_samples is not None: - raise ValueError( - "`max_sample` cannot be set if `bootstrap=False`. " - "Either switch to `bootstrap=True` or set " - "`max_sample=None`." 
- ) - if not self.bootstrap and self.oob_score: - raise ValueError("Out of bag estimation only available" " if bootstrap=True") - min_observations_in_leaf_node = ( self.min_samples_leaf if isinstance(self.min_samples_leaf, numbers.Integral) - else int(ceil(self.min_samples_leaf * n_samples)) + else int(math.ceil(self.min_samples_leaf * n_samples)) ) min_observations_in_split_node = ( self.min_samples_split if isinstance(self.min_samples_split, numbers.Integral) - else int(ceil(self.min_samples_split * n_samples)) + else int(math.ceil(self.min_samples_split * n_samples)) ) - rs = check_random_state(self.random_state) - seed = rs.randint(0, np.iinfo("i").max) - onedal_params = { "fptype": data.dtype, "method": self.algorithm, @@ -176,7 +152,7 @@ def _get_onedal_params(self, data): "max_leaf_nodes": (0 if self.max_leaf_nodes is None else self.max_leaf_nodes), "max_bins": self.max_bins, "min_bin_size": self.min_bin_size, - "seed": seed, + "seed": self.random_state, "memory_saving_mode": False, "bootstrap": bool(self.bootstrap), "error_metric_mode": self.error_metric_mode, @@ -190,125 +166,12 @@ def _get_onedal_params(self, data): onedal_params["splitter_mode"] = self.splitter_mode return onedal_params - def _check_parameters(self): - if isinstance(self.min_samples_leaf, numbers.Integral): - if not 1 <= self.min_samples_leaf: - raise ValueError( - "min_samples_leaf must be at least 1 " - "or in (0, 0.5], got %s" % self.min_samples_leaf - ) - else: # float - if not 0.0 < self.min_samples_leaf <= 0.5: - raise ValueError( - "min_samples_leaf must be at least 1 " - "or in (0, 0.5], got %s" % self.min_samples_leaf - ) - if isinstance(self.min_samples_split, numbers.Integral): - if not 2 <= self.min_samples_split: - raise ValueError( - "min_samples_split must be an integer " - "greater than 1 or a float in (0.0, 1.0]; " - "got the integer %s" % self.min_samples_split - ) - else: # float - if not 0.0 < self.min_samples_split <= 1.0: - raise ValueError( - "min_samples_split must be an integer " - "greater than 1 or a float in (0.0, 1.0]; " - "got the float %s" % self.min_samples_split - ) - if not 0 <= self.min_weight_fraction_leaf <= 0.5: - raise ValueError("min_weight_fraction_leaf must in [0, 0.5]") - if self.min_impurity_split is not None: - warnings.warn( - "The min_impurity_split parameter is deprecated. " - "Its default value has changed from 1e-7 to 0 in " - "version 0.23, and it will be removed in 0.25. 
" - "Use the min_impurity_decrease parameter instead.", - FutureWarning, - ) - - if self.min_impurity_split < 0.0: - raise ValueError( - "min_impurity_split must be greater than " "or equal to 0" - ) - if self.min_impurity_decrease < 0.0: - raise ValueError( - "min_impurity_decrease must be greater than " "or equal to 0" - ) - if self.max_leaf_nodes is not None: - if not isinstance(self.max_leaf_nodes, numbers.Integral): - raise ValueError( - "max_leaf_nodes must be integral number but was " - "%r" % self.max_leaf_nodes - ) - if self.max_leaf_nodes < 2: - raise ValueError( - ("max_leaf_nodes {0} must be either None " "or larger than 1").format( - self.max_leaf_nodes - ) - ) - if isinstance(self.max_bins, numbers.Integral): - if not 2 <= self.max_bins: - raise ValueError("max_bins must be at least 2, got %s" % self.max_bins) - else: - raise ValueError( - "max_bins must be integral number but was " "%r" % self.max_bins - ) - if isinstance(self.min_bin_size, numbers.Integral): - if not 1 <= self.min_bin_size: - raise ValueError( - "min_bin_size must be at least 1, got %s" % self.min_bin_size - ) - else: - raise ValueError( - "min_bin_size must be integral number but was " "%r" % self.min_bin_size - ) - - def _validate_targets(self, y, dtype): - self.class_weight_ = None - self.classes_ = None - return _column_or_1d(y, warn=True).astype(dtype, copy=False) - - def _get_sample_weight(self, sample_weight, X): - sample_weight = np.asarray(sample_weight, dtype=X.dtype).ravel() - - sample_weight = _check_array( - sample_weight, accept_sparse=False, ensure_2d=False, dtype=X.dtype, order="C" - ) - - if sample_weight.size != X.shape[0]: - raise ValueError( - "sample_weight and X have incompatible shapes: " - "%r vs %r\n" - "Note: Sparse matrices cannot be indexed w/" - "boolean masks (use `indices=True` in CV)." - % (sample_weight.shape, X.shape) - ) - - return sample_weight - def _fit(self, X, y, sample_weight, module, queue): - X, y = _check_X_y( - X, - y, - dtype=[np.float64, np.float32], - force_all_finite=True, - accept_sparse="csr", - ) - y = self._validate_targets(y, X.dtype) - - self.n_features_in_ = X.shape[1] - - if sample_weight is not None and len(sample_weight) > 0: - sample_weight = self._get_sample_weight(sample_weight, X) - data = (X, y, sample_weight) - else: - data = (X, y) + data = (X, y, sample_weight) if sample_weight else (X, y) policy = self._get_policy(queue, *data) - data = _convert_to_supported(policy, *data) + data = to_table(*_convert_to_supported(policy, *data)) params = self._get_onedal_params(data[0]) - train_result = module.train(policy, params, *to_table(*data)) + train_result = module.train(policy, params, *data) self._onedal_model = train_result.model @@ -318,25 +181,10 @@ def _fit(self, X, y, sample_weight, module, queue): self.oob_decision_function_ = from_table( train_result.oob_err_decision_function ) - if np.any(self.oob_decision_function_ == 0): - warnings.warn( - "Some inputs do not have OOB scores. This probably means " - "too few trees were used to compute any reliable OOB " - "estimates.", - UserWarning, - ) + else: self.oob_score_ = from_table(train_result.oob_err_r2).item() - self.oob_prediction_ = from_table( - train_result.oob_err_prediction - ).reshape(-1) - if np.any(self.oob_prediction_ == 0): - warnings.warn( - "Some inputs do not have OOB scores. 
This probably means " - "too few trees were used to compute any reliable OOB " - "estimates.", - UserWarning, - ) + self.oob_prediction_ = from_table(train_result.oob_err_prediction) return self @@ -347,42 +195,32 @@ def _create_model(self, module): def _predict(self, X, module, queue, hparams=None): _check_is_fitted(self) - X = _check_array( - X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False - ) - _check_n_features(self, X, False) policy = self._get_policy(queue, X) model = self._onedal_model - X = _convert_to_supported(policy, X) + X = to_table(_convert_to_supported(policy, X)) params = self._get_onedal_params(X) if hparams is not None and not hparams.is_default: - result = module.infer(policy, params, hparams.backend, model, to_table(X)) + result = module.infer(policy, params, hparams.backend, model, X) else: - result = module.infer(policy, params, model, to_table(X)) + result = module.infer(policy, params, model, X) - y = from_table(result.responses) - return y + return from_table(result.responses) def _predict_proba(self, X, module, queue, hparams=None): _check_is_fitted(self) - X = _check_array( - X, dtype=[np.float64, np.float32], force_all_finite=True, accept_sparse=False - ) - _check_n_features(self, X, False) policy = self._get_policy(queue, X) - X = _convert_to_supported(policy, X) + X = to_table(_convert_to_supported(policy, X)) params = self._get_onedal_params(X) params["infer_mode"] = "class_probabilities" model = self._onedal_model if hparams is not None and not hparams.is_default: - result = module.infer(policy, params, hparams.backend, model, to_table(X)) + result = module.infer(policy, params, hparams.backend, model, X) else: - result = module.infer(policy, params, model, to_table(X)) + result = module.infer(policy, params, model, X) - y = from_table(result.probabilities) - return y + return from_table(result.probabilities) class RandomForestClassifier(ClassifierMixin, BaseForest, metaclass=ABCMeta): @@ -443,18 +281,6 @@ def __init__( algorithm=algorithm, ) - def _validate_targets(self, y, dtype): - y, self.class_weight_, self.classes_ = _validate_targets( - y, self.class_weight, dtype - ) - - # Decapsulate classes_ attributes - # TODO: - # align with `n_classes_` and `classes_` attr with daal4py implementations. 
- # if hasattr(self, "classes_"): - # self.n_classes_ = self.classes_ - return y - def fit(self, X, y, sample_weight=None, queue=None): return self._fit( X, @@ -466,15 +292,13 @@ def fit(self, X, y, sample_weight=None, queue=None): def predict(self, X, queue=None): hparams = get_hyperparameters("decision_forest", "infer") - pred = super()._predict( + return super()._predict( X, self._get_backend("decision_forest", "classification", None), queue, hparams, ) - return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe")) - def predict_proba(self, X, queue=None): hparams = get_hyperparameters("decision_forest", "infer") @@ -514,7 +338,6 @@ def __init__( error_metric_mode="none", variable_importance_mode="none", algorithm="hist", - **kwargs, ): super().__init__( n_estimators=n_estimators, @@ -558,10 +381,8 @@ def fit(self, X, y, sample_weight=None, queue=None): ) def predict(self, X, queue=None): - return ( - super() - ._predict(X, self._get_backend("decision_forest", "regression", None), queue) - .ravel() + return super()._predict( + X, self._get_backend("decision_forest", "regression", None), queue ) @@ -593,7 +414,6 @@ def __init__( error_metric_mode="none", variable_importance_mode="none", algorithm="hist", - **kwargs, ): super().__init__( n_estimators=n_estimators, @@ -623,18 +443,6 @@ def __init__( algorithm=algorithm, ) - def _validate_targets(self, y, dtype): - y, self.class_weight_, self.classes_ = _validate_targets( - y, self.class_weight, dtype - ) - - # Decapsulate classes_ attributes - # TODO: - # align with `n_classes_` and `classes_` attr with daal4py implementations. - # if hasattr(self, "classes_"): - # self.n_classes_ = self.classes_ - return y - def fit(self, X, y, sample_weight=None, queue=None): return self._fit( X, @@ -645,12 +453,10 @@ def fit(self, X, y, sample_weight=None, queue=None): ) def predict(self, X, queue=None): - pred = super()._predict( + return super()._predict( X, self._get_backend("decision_forest", "classification", None), queue ) - return np.take(self.classes_, pred.ravel().astype(np.int64, casting="unsafe")) - def predict_proba(self, X, queue=None): return super()._predict_proba( X, self._get_backend("decision_forest", "classification", None), queue @@ -685,7 +491,6 @@ def __init__( error_metric_mode="none", variable_importance_mode="none", algorithm="hist", - **kwargs, ): super().__init__( n_estimators=n_estimators, @@ -729,8 +534,6 @@ def fit(self, X, y, sample_weight=None, queue=None): ) def predict(self, X, queue=None): - return ( - super() - ._predict(X, self._get_backend("decision_forest", "regression", None), queue) - .ravel() + return super()._predict( + X, self._get_backend("decision_forest", "regression", None), queue ) diff --git a/sklearnex/ensemble/_forest.py b/sklearnex/ensemble/_forest.py index 2a04962645..e0c088cbfc 100644 --- a/sklearnex/ensemble/_forest.py +++ b/sklearnex/ensemble/_forest.py @@ -58,35 +58,33 @@ from onedal.primitives import get_tree_state_cls, get_tree_state_reg from onedal.utils import _num_features, _num_samples from sklearnex import get_hyperparameters -from sklearnex._utils import register_hyperparameters from .._device_offload import dispatch, wrap_output_data -from .._utils import PatchingConditionsChain +from .._utils import register_hyperparameters, PatchingConditionsChain from ..utils._array_api import get_namespace +from ..utils.validation import assert_all_finite, validate_data, _check_sample_weight if sklearn_check_version("1.2"): from sklearn.utils._param_validation import Interval -if 
sklearn_check_version("1.4"): - from daal4py.sklearn.utils import _assert_all_finite - -if sklearn_check_version("1.6"): - from sklearn.utils.validation import validate_data -else: - validate_data = BaseEstimator._validate_data class BaseForest(ABC): _onedal_factory = None def _onedal_fit(self, X, y, sample_weight=None, queue=None): + if sp.issparse(y): + raise ValueError("sparse multilabel-indicator for y is not supported.") + + xp, _ = get_namespace(X) + X, y = validate_data( self, X, y, multi_output=True, accept_sparse=False, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], + ensure_all_finite=False, ensure_2d=True, ) @@ -103,9 +101,7 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): ) if y.ndim == 1: - # reshape is necessary to preserve the data contiguity against vs - # [:, np.newaxis] that does not. - y = np.reshape(y, (-1, 1)) + y = xp.reshape(y, (-1, 1)) self._n_samples, self.n_outputs_ = y.shape @@ -119,6 +115,29 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): if sample_weight is not None: sample_weight = [sample_weight] + if not self.bootstrap and self.max_samples is not None: + raise ValueError( + "`max_sample` cannot be set if `bootstrap=False`. " + "Either switch to `bootstrap=True` or set " + "`max_sample=None`." + ) + elif self.bootstrap: + n_samples_bootstrap = _get_n_samples_bootstrap( + n_samples=X.shape[0], max_samples=self.max_samples + ) + else: + n_samples_bootstrap = None + + self._n_samples_bootstrap = n_samples_bootstrap + + self._validate_estimator() + + if not self.bootstrap and self.oob_score: + raise ValueError("Out of bag estimation only available if bootstrap=True") + + rs = check_random_state(self.random_state) + seed = rs.randint(0, xp.iinfo("i").max) + onedal_params = { "n_estimators": self.n_estimators, "criterion": self.criterion, @@ -127,14 +146,14 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): "min_samples_leaf": self.min_samples_leaf, "min_weight_fraction_leaf": self.min_weight_fraction_leaf, "max_features": self._to_absolute_max_features( - self.max_features, self.n_features_in_ + self.max_features, self.n_features_in_, xp ), "max_leaf_nodes": self.max_leaf_nodes, "min_impurity_decrease": self.min_impurity_decrease, "bootstrap": self.bootstrap, "oob_score": self.oob_score, "n_jobs": self.n_jobs, - "random_state": self.random_state, + "random_state": seed, "verbose": self.verbose, "warm_start": self.warm_start, "error_metric_mode": self._err if self.oob_score else "none", @@ -155,9 +174,9 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): # Compute self._onedal_estimator = self._onedal_factory(**onedal_params) - self._onedal_estimator.fit(X, np.ravel(y), sample_weight, queue=queue) + self._onedal_estimator.fit(X, y, sample_weight, queue=queue) - self._save_attributes() + self._save_attributes(xp) # Decapsulate classes_ attributes if hasattr(self, "classes_") and self.n_outputs_ == 1: @@ -166,15 +185,31 @@ def _onedal_fit(self, X, y, sample_weight=None, queue=None): return self - def _save_attributes(self): + def _save_attributes(self, xp): if self.oob_score: self.oob_score_ = self._onedal_estimator.oob_score_ if hasattr(self._onedal_estimator, "oob_prediction_"): self.oob_prediction_ = self._onedal_estimator.oob_prediction_ + if xp.any(self.oob_prediction_ == 0): + warnings.warn( + "Some inputs do not have OOB scores. 
This probably means " + "too few trees were used to compute any reliable OOB " + "estimates.", + UserWarning, + ) + if hasattr(self._onedal_estimator, "oob_decision_function_"): self.oob_decision_function_ = ( self._onedal_estimator.oob_decision_function_ ) + if xp.any(self.oob_decision_function_ == 0): + warnings.warn( + "Some inputs do not have OOB scores. This probably means " + "too few trees were used to compute any reliable OOB " + "estimates.", + UserWarning, + ) + if self.bootstrap: self._n_samples_bootstrap = max( round( @@ -188,7 +223,7 @@ def _save_attributes(self): self._validate_estimator() return self - def _to_absolute_max_features(self, max_features, n_features): + def _to_absolute_max_features(self, max_features, n_features, xp=None): if max_features is None: return n_features if isinstance(max_features, str): @@ -204,14 +239,14 @@ def _to_absolute_max_features(self, max_features, n_features): FutureWarning, ) return ( - max(1, int(np.sqrt(n_features))) + max(1, int(xp.sqrt(n_features))) if isinstance(self, ForestClassifier) else n_features ) if max_features == "sqrt": - return max(1, int(np.sqrt(n_features))) + return max(1, int(xp.sqrt(n_features))) if max_features == "log2": - return max(1, int(np.log2(n_features))) + return max(1, int(xp.log2(n_features))) allowed_string_values = ( '"sqrt" or "log2"' if sklearn_check_version("1.3") @@ -221,6 +256,7 @@ def _to_absolute_max_features(self, max_features, n_features): "Invalid value for max_features. Allowed string " f"values are {allowed_string_values}." ) + if isinstance(max_features, (numbers.Integral, np.integer)): return max_features if max_features > 0.0: @@ -319,6 +355,10 @@ def estimators_(self, estimators): self._cached_estimators_ = estimators def _estimators_(self): + """This attribute provides lazy creation of scikit-learn conformant + Decision Trees used for analysis in such as 'apply'. This will stay + array_api non-conformant as this is inherently creating sklearn + objects which are not array_api conformant""" # _estimators_ should only be called if _onedal_estimator exists check_is_fitted(self, "_onedal_estimator") if hasattr(self, "n_classes_"): @@ -475,17 +515,12 @@ def fit(self, X, y, sample_weight=None): return self def _onedal_fit_ready(self, patching_status, X, y, sample_weight): - if sp.issparse(y): - raise ValueError("sparse multilabel-indicator for y is not supported.") - + xp, _ = get_namespace(X) if sklearn_check_version("1.2"): self._validate_params() else: self._check_parameters() - if not self.bootstrap and self.oob_score: - raise ValueError("Out of bag estimation only available" " if bootstrap=True") - patching_status.and_conditions( [ ( @@ -524,7 +559,7 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): if patching_status.get_status() and sklearn_check_version("1.4"): try: - _assert_all_finite(X) + assert_all_finite(X) input_is_finite = True except ValueError: input_is_finite = False @@ -539,52 +574,21 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): ) if patching_status.get_status(): - X, y = check_X_y( - X, - y, - multi_output=True, - accept_sparse=True, - dtype=[np.float64, np.float32], - force_all_finite=False, - ) - - if y.ndim == 2 and y.shape[1] == 1: - warnings.warn( - "A column-vector y was passed when a 1d array was" - " expected. 
Please change the shape of y to " - "(n_samples,), for example using ravel().", - DataConversionWarning, - stacklevel=2, - ) - - if y.ndim == 1: - y = np.reshape(y, (-1, 1)) - - self.n_outputs_ = y.shape[1] patching_status.and_conditions( [ ( - self.n_outputs_ == 1, - f"Number of outputs ({self.n_outputs_}) is not 1.", + _num_features(y, fallback_1d=True) == 1, + f"Number of outputs is not 1.", ), ( - y.dtype in [np.float32, np.float64, np.int32, np.int64], + y.dtype in [xp.float32, xp.float64, xp.int32, xp.int64], f"Datatype ({y.dtype}) for y is not supported.", ), ] ) # TODO: Fix to support integers as input - _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples) - - if not self.bootstrap and self.max_samples is not None: - raise ValueError( - "`max_sample` cannot be set if `bootstrap=False`. " - "Either switch to `bootstrap=True` or set " - "`max_sample=None`." - ) - if ( patching_status.get_status() and (self.random_state is not None) @@ -596,7 +600,7 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): RuntimeWarning, ) - return patching_status, X, y, sample_weight + return patching_status @wrap_output_data def predict(self, X): @@ -613,9 +617,6 @@ def predict(self, X): @wrap_output_data def predict_proba(self, X): - # TODO: - # _check_proba() - # self._check_proba() check_is_fitted(self) return dispatch( self, @@ -668,9 +669,7 @@ def _onedal_cpu_supported(self, method_name, *data): ) if method_name == "fit": - patching_status, X, y, sample_weight = self._onedal_fit_ready( - patching_status, *data - ) + patching_status = self._onedal_fit_ready(patching_status, *data) patching_status.and_conditions( [ @@ -680,7 +679,7 @@ def _onedal_cpu_supported(self, method_name, *data): "ExtraTrees only supported starting from oneDAL version 2023.2", ), ( - not sp.issparse(sample_weight), + not sp.issparse(data[2]), "sample_weight is sparse. 
" "Sparse input is not supported.", ), ] @@ -736,9 +735,7 @@ def _onedal_gpu_supported(self, method_name, *data): ) if method_name == "fit": - patching_status, X, y, sample_weight = self._onedal_fit_ready( - patching_status, *data - ) + patching_status = self._onedal_fit_ready(patching_status, *data) patching_status.and_conditions( [ @@ -751,7 +748,7 @@ def _onedal_gpu_supported(self, method_name, *data): not self.oob_score, "oob_scores using r2 or accuracy not implemented.", ), - (sample_weight is None, "sample_weight is not supported."), + (data[2] is None, "sample_weight is not supported."), ] ) @@ -790,21 +787,19 @@ def _onedal_gpu_supported(self, method_name, *data): return patching_status def _onedal_predict(self, X, queue=None): - + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): X = validate_data( self, X, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], reset=False, ensure_2d=True, ) else: X = check_array( X, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], ) # Warning, order of dtype matters if hasattr(self, "n_features_in_"): try: @@ -822,24 +817,22 @@ def _onedal_predict(self, X, queue=None): self._check_n_features(X, reset=False) res = self._onedal_estimator.predict(X, queue=queue) - return np.take(self.classes_, res.ravel().astype(np.int64, casting="unsafe")) + return xp.take(self.classes_, xp.reshape(res, (-1,)).astype(xp.int64)) def _onedal_predict_proba(self, X, queue=None): - + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): X = validate_data( self, X, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], reset=False, ensure_2d=True, ) else: X = check_array( X, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], ) # Warning, order of dtype matters self._check_n_features(X, reset=False) @@ -897,25 +890,11 @@ def __init__( raise TypeError(f" oneDAL estimator has not been set.") def _onedal_fit_ready(self, patching_status, X, y, sample_weight): - if sp.issparse(y): - raise ValueError("sparse multilabel-indicator for y is not supported.") - if sklearn_check_version("1.2"): self._validate_params() else: self._check_parameters() - if not self.bootstrap and self.oob_score: - raise ValueError("Out of bag estimation only available" " if bootstrap=True") - - if sklearn_check_version("1.0") and self.criterion == "mse": - warnings.warn( - "Criterion 'mse' was deprecated in v1.0 and will be " - "removed in version 1.2. Use `criterion='squared_error'` " - "which is equivalent.", - FutureWarning, - ) - patching_status.and_conditions( [ ( @@ -944,7 +923,7 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): if patching_status.get_status() and sklearn_check_version("1.4"): try: - _assert_all_finite(X) + assert_all_finite(X) input_is_finite = True except ValueError: input_is_finite = False @@ -959,50 +938,15 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): ) if patching_status.get_status(): - X, y = check_X_y( - X, - y, - multi_output=True, - accept_sparse=True, - dtype=[np.float64, np.float32], - force_all_finite=False, - ) - - if y.ndim == 2 and y.shape[1] == 1: - warnings.warn( - "A column-vector y was passed when a 1d array was" - " expected. 
Please change the shape of y to " - "(n_samples,), for example using ravel().", - DataConversionWarning, - stacklevel=2, - ) - - if y.ndim == 1: - # reshape is necessary to preserve the data contiguity against vs - # [:, np.newaxis] that does not. - y = np.reshape(y, (-1, 1)) - - self.n_outputs_ = y.shape[1] - patching_status.and_conditions( [ ( - self.n_outputs_ == 1, - f"Number of outputs ({self.n_outputs_}) is not 1.", + _num_features(y, fallback_1d=True) == 1, + f"Number of outputs is not 1.", ) ] ) - # Sklearn function used for doing checks on max_samples attribute - _get_n_samples_bootstrap(n_samples=X.shape[0], max_samples=self.max_samples) - - if not self.bootstrap and self.max_samples is not None: - raise ValueError( - "`max_sample` cannot be set if `bootstrap=False`. " - "Either switch to `bootstrap=True` or set " - "`max_sample=None`." - ) - if ( patching_status.get_status() and (self.random_state is not None) @@ -1014,7 +958,7 @@ def _onedal_fit_ready(self, patching_status, X, y, sample_weight): RuntimeWarning, ) - return patching_status, X, y, sample_weight + return patching_status def _onedal_cpu_supported(self, method_name, *data): class_name = self.__class__.__name__ @@ -1023,9 +967,7 @@ def _onedal_cpu_supported(self, method_name, *data): ) if method_name == "fit": - patching_status, X, y, sample_weight = self._onedal_fit_ready( - patching_status, *data - ) + patching_status = self._onedal_fit_ready(patching_status, *data) patching_status.and_conditions( [ @@ -1035,7 +977,7 @@ def _onedal_cpu_supported(self, method_name, *data): "ExtraTrees only supported starting from oneDAL version 2023.2", ), ( - not sp.issparse(sample_weight), + not sp.issparse(data[2]), "sample_weight is sparse. " "Sparse input is not supported.", ), ] @@ -1080,9 +1022,7 @@ def _onedal_gpu_supported(self, method_name, *data): ) if method_name == "fit": - patching_status, X, y, sample_weight = self._onedal_fit_ready( - patching_status, *data - ) + patching_status = self._onedal_fit_ready(patching_status, *data) patching_status.and_conditions( [ @@ -1092,7 +1032,7 @@ def _onedal_gpu_supported(self, method_name, *data): "ExtraTrees only supported starting from oneDAL version 2023.1", ), (not self.oob_score, "oob_score value is not sklearn conformant."), - (sample_weight is None, "sample_weight is not supported."), + (data[2] is None, "sample_weight is not supported."), ] ) @@ -1130,19 +1070,18 @@ def _onedal_gpu_supported(self, method_name, *data): def _onedal_predict(self, X, queue=None): check_is_fitted(self, "_onedal_estimator") - + xp, _ = get_namespace(X) if sklearn_check_version("1.0"): X = validate_data( self, X, - dtype=[np.float64, np.float32], - force_all_finite=False, + dtype=[xp.float64, xp.float32], reset=False, ensure_2d=True, ) # Warning, order of dtype matters else: X = check_array( - X, dtype=[np.float64, np.float32], force_all_finite=False + X, dtype=[xp.float64, xp.float32] ) # Warning, order of dtype matters return self._onedal_estimator.predict(X, queue=queue) diff --git a/sklearnex/tests/test_memory_usage.py b/sklearnex/tests/test_memory_usage.py index aa92df1d6a..2d52a545cf 100644 --- a/sklearnex/tests/test_memory_usage.py +++ b/sklearnex/tests/test_memory_usage.py @@ -35,10 +35,14 @@ get_dataframes_and_queues, ) from onedal.tests.utils._device_selection import get_queues, is_dpctl_device_available -from onedal.utils._array_api import _get_sycl_namespace from onedal.utils._dpep_helpers import dpctl_available, dpnp_available from sklearnex import config_context -from 
sklearnex.tests.utils import PATCHED_FUNCTIONS, PATCHED_MODELS, SPECIAL_INSTANCES +from sklearnex.tests.utils import ( + PATCHED_FUNCTIONS, + PATCHED_MODELS, + SPECIAL_INSTANCES, + DummyEstimator, +) from sklearnex.utils._array_api import get_namespace if dpctl_available: @@ -131,41 +135,6 @@ def gen_functions(functions): ORDER_DICT = {"F": np.asfortranarray, "C": np.ascontiguousarray} -if _is_dpc_backend: - - from sklearn.utils.validation import check_is_fitted - - from onedal.datatypes import from_table, to_table - - class DummyEstimatorWithTableConversions(BaseEstimator): - - def fit(self, X, y=None): - sua_iface, xp, _ = _get_sycl_namespace(X) - X_table = to_table(X) - y_table = to_table(y) - # The presence of the fitted attributes (ending with a trailing - # underscore) is required for the correct check. The cleanup of - # the memory will occur at the estimator instance deletion. - self.x_attr_ = from_table( - X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - self.y_attr_ = from_table( - y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - return self - - def predict(self, X): - # Checks if the estimator is fitted by verifying the presence of - # fitted attributes (ending with a trailing underscore). - check_is_fitted(self) - sua_iface, xp, _ = _get_sycl_namespace(X) - X_table = to_table(X) - returned_X = from_table( - X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp - ) - return returned_X - - def gen_clsf_data(n_samples, n_features, dtype=None): data, label = make_classification( n_classes=2, n_samples=n_samples, n_features=n_features, random_state=777 @@ -369,7 +338,7 @@ def test_table_conversions_memory_leaks(dataframe, queue, order, data_shape, dty pytest.skip("SYCL device memory leak check requires the level zero sysman") _kfold_function_template( - DummyEstimatorWithTableConversions, + DummyEstimator, dataframe, data_shape, queue, diff --git a/sklearnex/tests/utils/__init__.py b/sklearnex/tests/utils/__init__.py index 60ca67fa37..db728fe913 100644 --- a/sklearnex/tests/utils/__init__.py +++ b/sklearnex/tests/utils/__init__.py @@ -21,6 +21,7 @@ SPECIAL_INSTANCES, UNPATCHED_FUNCTIONS, UNPATCHED_MODELS, + DummyEstimator, _get_processor_info, call_method, gen_dataset, @@ -39,6 +40,7 @@ "gen_models_info", "gen_dataset", "sklearn_clone_dict", + "DummyEstimator", ] _IS_INTEL = "GenuineIntel" in _get_processor_info() diff --git a/sklearnex/tests/utils/base.py b/sklearnex/tests/utils/base.py index 1949519585..706de39a91 100755 --- a/sklearnex/tests/utils/base.py +++ b/sklearnex/tests/utils/base.py @@ -32,8 +32,11 @@ ) from sklearn.datasets import load_diabetes, load_iris from sklearn.neighbors._base import KNeighborsMixin +from sklearn.utils.validation import check_is_fitted +from onedal.datatypes import from_table, to_table from onedal.tests.utils._dataframes_support import _convert_to_dataframe +from onedal.utils._array_api import _get_sycl_namespace from sklearnex import get_patch_map, patch_sklearn, sklearn_is_patched, unpatch_sklearn from sklearnex.basic_statistics import BasicStatistics, IncrementalBasicStatistics from sklearnex.linear_model import LogisticRegression @@ -369,3 +372,41 @@ def _get_processor_info(): ) return proc + + +class DummyEstimator(BaseEstimator): + + def fit(self, X, y=None): + sua_iface, xp, _ = _get_sycl_namespace(X) + X_table = to_table(X) + y_table = to_table(y) + # The presence of the fitted attributes (ending with a trailing + # underscore) is required for the correct check. 
The cleanup of + # the memory will occur at the estimator instance deletion. + if sua_iface: + self.x_attr_ = from_table( + X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp + ) + self.y_attr_ = from_table( + y_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp + ) + else: + self.x_attr = from_table(X_table) + self.y_attr = from_table(y_table) + + return self + + def predict(self, X): + # Checks if the estimator is fitted by verifying the presence of + # fitted attributes (ending with a trailing underscore). + check_is_fitted(self) + sua_iface, xp, _ = _get_sycl_namespace(X) + X_table = to_table(X) + if sua_iface: + returned_X = from_table( + X_table, sua_iface=sua_iface, sycl_queue=X.sycl_queue, xp=xp + ) + else: + returned_X = from_table(X_table) + + return returned_X diff --git a/sklearnex/utils/__init__.py b/sklearnex/utils/__init__.py index 4c3fe21154..686e089adf 100755 --- a/sklearnex/utils/__init__.py +++ b/sklearnex/utils/__init__.py @@ -14,6 +14,6 @@ # limitations under the License. # =============================================================================== -from .validation import _assert_all_finite +from .validation import assert_all_finite -__all__ = ["_assert_all_finite"] +__all__ = ["assert_all_finite"] diff --git a/sklearnex/utils/tests/test_finite.py b/sklearnex/utils/tests/test_finite.py deleted file mode 100644 index 7d83667699..0000000000 --- a/sklearnex/utils/tests/test_finite.py +++ /dev/null @@ -1,89 +0,0 @@ -# ============================================================================== -# Copyright 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================== - -import time - -import numpy as np -import numpy.random as rand -import pytest -from numpy.testing import assert_raises - -from sklearnex.utils import _assert_all_finite - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize( - "shape", - [ - [16, 2048], - [ - 2**16 + 3, - ], - [1000, 1000], - ], -) -@pytest.mark.parametrize("allow_nan", [False, True]) -def test_sum_infinite_actually_finite(dtype, shape, allow_nan): - X = np.empty(shape, dtype=dtype) - X.fill(np.finfo(dtype).max) - _assert_all_finite(X, allow_nan=allow_nan) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize( - "shape", - [ - [16, 2048], - [ - 65539, # 2**16 + 3, - ], - [1000, 1000], - ], -) -@pytest.mark.parametrize("allow_nan", [False, True]) -@pytest.mark.parametrize("check", ["inf", "NaN", None]) -@pytest.mark.parametrize("seed", [0, int(time.time())]) -def test_assert_finite_random_location(dtype, shape, allow_nan, check, seed): - rand.seed(seed) - X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) - - if check: - loc = rand.randint(0, X.size - 1) - X.reshape((-1,))[loc] = float(check) - - if check is None or (allow_nan and check == "NaN"): - _assert_all_finite(X, allow_nan=allow_nan) - else: - assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan) - - -@pytest.mark.parametrize("dtype", [np.float32, np.float64]) -@pytest.mark.parametrize("allow_nan", [False, True]) -@pytest.mark.parametrize("check", ["inf", "NaN", None]) -@pytest.mark.parametrize("seed", [0, int(time.time())]) -def test_assert_finite_random_shape_and_location(dtype, allow_nan, check, seed): - lb, ub = 32768, 1048576 # lb is a patching condition, ub 2^20 - rand.seed(seed) - X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype) - - if check: - loc = rand.randint(0, X.size - 1) - X[loc] = float(check) - - if check is None or (allow_nan and check == "NaN"): - _assert_all_finite(X, allow_nan=allow_nan) - else: - assert_raises(ValueError, _assert_all_finite, X, allow_nan=allow_nan) diff --git a/sklearnex/utils/tests/test_validation.py b/sklearnex/utils/tests/test_validation.py new file mode 100644 index 0000000000..70da28dbce --- /dev/null +++ b/sklearnex/utils/tests/test_validation.py @@ -0,0 +1,236 @@ +# ============================================================================== +# Copyright 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import time + +import numpy as np +import numpy.random as rand +import pytest + +from daal4py.sklearn._utils import sklearn_check_version +from onedal.tests.utils._dataframes_support import ( + _convert_to_dataframe, + get_dataframes_and_queues, +) +from sklearnex import config_context +from sklearnex.tests.utils import DummyEstimator, gen_dataset +from sklearnex.utils.validation import _check_sample_weight, validate_data + +# array_api support starts in sklearn 1.2, and array_api_strict conformance starts in sklearn 1.3 +_dataframes_supported = ( + "numpy,pandas" + + (",dpctl" if sklearn_check_version("1.2") else "") + + (",array_api" if sklearn_check_version("1.3") else "") +) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize( + "shape", + [ + [16, 2048], + [2**16 + 3], + [1000, 1000], + ], +) +@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True]) +def test_sum_infinite_actually_finite(dtype, shape, ensure_all_finite): + est = DummyEstimator() + X = np.empty(shape, dtype=dtype) + X.fill(np.finfo(dtype).max) + X = np.atleast_2d(X) + X_array = validate_data(est, X, ensure_all_finite=ensure_all_finite) + assert type(X_array) == type(X) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize( + "shape", + [ + [16, 2048], + [2**16 + 3], + [1000, 1000], + ], +) +@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +@pytest.mark.parametrize( + "dataframe, queue", + get_dataframes_and_queues(_dataframes_supported), +) +def test_validate_data_random_location( + dataframe, queue, dtype, shape, ensure_all_finite, check, seed +): + est = DummyEstimator() + rand.seed(seed) + X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) + + if check: + loc = rand.randint(0, X.size - 1) + X.reshape((-1,))[loc] = float(check) + + # column heavy pandas inputs are very slow in sklearn's check_array even without + # the finite check, just transpose inputs to guarantee fast processing in tests + X = _convert_to_dataframe( + np.atleast_2d(X).T, + target_df=dataframe, + sycl_queue=queue, + ) + + dispatch = {} + if sklearn_check_version("1.2") and dataframe != "pandas": + dispatch["array_api_dispatch"] = True + + with config_context(**dispatch): + + allow_nan = ensure_all_finite == "allow-nan" + if check is None or (allow_nan and check == "NaN"): + validate_data(est, X, ensure_all_finite=ensure_all_finite) + else: + type_err = "infinity" if allow_nan else "[NaN|infinity]" + msg_err = f"Input X contains {type_err}" + with pytest.raises(ValueError, match=msg_err): + validate_data(est, X, ensure_all_finite=ensure_all_finite) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("ensure_all_finite", ["allow-nan", True]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +@pytest.mark.parametrize( + "dataframe, queue", + get_dataframes_and_queues(_dataframes_supported), +) +def test_validate_data_random_shape_and_location( + dataframe, queue, dtype, ensure_all_finite, check, seed +): + est = DummyEstimator() + lb, ub = 32768, 1048576 # lb is a patching condition, ub 2^20 + rand.seed(seed) + X = rand.uniform(high=np.finfo(dtype).max, size=rand.randint(lb, ub)).astype(dtype) + + if check: + loc = rand.randint(0, 
X.size - 1) + X[loc] = float(check) + + X = _convert_to_dataframe( + np.atleast_2d(X).T, + target_df=dataframe, + sycl_queue=queue, + ) + + dispatch = {} + if sklearn_check_version("1.2") and dataframe != "pandas": + dispatch["array_api_dispatch"] = True + + with config_context(**dispatch): + + allow_nan = ensure_all_finite == "allow-nan" + if check is None or (allow_nan and check == "NaN"): + validate_data(est, X, ensure_all_finite=ensure_all_finite) + else: + type_err = "infinity" if allow_nan else "[NaN|infinity]" + msg_err = f"Input X contains {type_err}." + with pytest.raises(ValueError, match=msg_err): + validate_data(est, X, ensure_all_finite=ensure_all_finite) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize("check", ["inf", "NaN", None]) +@pytest.mark.parametrize("seed", [0, int(time.time())]) +@pytest.mark.parametrize( + "dataframe, queue", + get_dataframes_and_queues(_dataframes_supported), +) +def test__check_sample_weight_random_shape_and_location( + dataframe, queue, dtype, check, seed +): + # This testing assumes that array api inputs to validate_data will only occur + # with sklearn array_api support which began in sklearn 1.2. This would assume + # that somewhere upstream of the validate_data call, a data conversion of dpnp, + # dpctl, or array_api inputs to numpy inputs would have occurred. + + lb, ub = 32768, 1048576 # lb is a patching condition, ub 2^20 + rand.seed(seed) + shape = (rand.randint(lb, ub), 2) + X = rand.uniform(high=np.finfo(dtype).max, size=shape).astype(dtype) + sample_weight = rand.uniform(high=np.finfo(dtype).max, size=shape[0]).astype(dtype) + + if check: + loc = rand.randint(0, shape[0] - 1) + sample_weight[loc] = float(check) + + X = _convert_to_dataframe( + X, + target_df=dataframe, + sycl_queue=queue, + ) + sample_weight = _convert_to_dataframe( + sample_weight, + target_df=dataframe, + sycl_queue=queue, + ) + + dispatch = {} + if sklearn_check_version("1.2") and dataframe != "pandas": + dispatch["array_api_dispatch"] = True + + with config_context(**dispatch): + + if check is None: + X_out = _check_sample_weight(sample_weight, X) + if dispatch: + assert type(X_out) == type(X) + else: + assert isinstance(X_out, np.ndarray) + else: + msg_err = "Input sample_weight contains [NaN|infinity]" + with pytest.raises(ValueError, match=msg_err): + X_out = _check_sample_weight(sample_weight, X) + + +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +@pytest.mark.parametrize( + "dataframe, queue", + get_dataframes_and_queues(_dataframes_supported), +) +def test_validate_data_output(dtype, dataframe, queue): + # This testing assumes that array api inputs to validate_data will only occur + # with sklearn array_api support which began in sklearn 1.2. This would assume + # that somewhere upstream of the validate_data call, a data conversion of dpnp, + # dpctl, or array_api inputs to numpy inputs would have occurred. 
+ est = DummyEstimator() + X, y = gen_dataset(est, queue=queue, target_df=dataframe, dtype=dtype)[0] + + dispatch = {} + if sklearn_check_version("1.2") and dataframe != "pandas": + dispatch["array_api_dispatch"] = True + + with config_context(**dispatch): + X_out, y_out = validate_data(est, X, y) + # check sklearn validate_data operations work underneath + X_array = validate_data(est, X, reset=False) + + if dispatch: + assert type(X) == type( + X_array + ), f"validate_data converted {type(X)} to {type(X_array)}" + assert type(X) == type(X_out), f"from_array converted {type(X)} to {type(X_out)}" + else: + # array_api_strict from sklearn < 1.2 and pandas will convert to numpy arrays + assert isinstance(X_array, np.ndarray) + assert isinstance(X_out, np.ndarray) diff --git a/sklearnex/utils/validation.py b/sklearnex/utils/validation.py index b2d1898643..c2ba2c1dc5 100755 --- a/sklearnex/utils/validation.py +++ b/sklearnex/utils/validation.py @@ -14,4 +14,162 @@ # limitations under the License. # =============================================================================== -from daal4py.sklearn.utils.validation import _assert_all_finite +import numbers + +import scipy.sparse as sp +from sklearn.utils.validation import _assert_all_finite as _sklearn_assert_all_finite +from sklearn.utils.validation import _num_samples, check_array, check_non_negative + +from daal4py.sklearn._utils import daal_check_version, sklearn_check_version + +from ._array_api import get_namespace + +if sklearn_check_version("1.6"): + from sklearn.utils.validation import validate_data as _sklearn_validate_data + + _finite_keyword = "ensure_all_finite" + +else: + from sklearn.base import BaseEstimator + + _sklearn_validate_data = BaseEstimator._validate_data + _finite_keyword = "force_all_finite" + + +if daal_check_version((2024, "P", 700)): + from onedal.utils.validation import _assert_all_finite as _onedal_assert_all_finite + + def _onedal_supported_format(X, xp=None): + # array_api does not have a `strides` or `flags` attribute for testing memory + # order. When dlpack support is brought in for oneDAL, general support for + # array_api can be enabled and the hasattr check can be removed. + # _onedal_supported_format is therefore conservative in verifying attributes and + # does not support array_api. This will block onedal_assert_all_finite from being + # used for array_api inputs but will allow dpnp ndarrays and dpctl tensors. 
+ return X.dtype in [xp.float32, xp.float64] and hasattr(X, "flags") + +else: + from daal4py.utils.validation import _assert_all_finite as _onedal_assert_all_finite + from onedal.utils._array_api import _is_numpy_namespace + + def _onedal_supported_format(X, xp=None): + # daal4py _assert_all_finite only supports numpy namespaces, use internally- + # defined check to validate inputs, otherwise offload to sklearn + return X.dtype in [xp.float32, xp.float64] and _is_numpy_namespace(xp) + + +def _sklearnex_assert_all_finite( + X, + *, + allow_nan=False, + input_name="", +): + # size check is an initial match to daal4py for performance reasons, can be + # optimized later + xp, _ = get_namespace(X) + if X.size < 32768 or not _onedal_supported_format(X, xp): + if sklearn_check_version("1.1"): + _sklearn_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name) + else: + _sklearn_assert_all_finite(X, allow_nan=allow_nan) + else: + _onedal_assert_all_finite(X, allow_nan=allow_nan, input_name=input_name) + + +def assert_all_finite( + X, + *, + allow_nan=False, + input_name="", +): + _sklearnex_assert_all_finite( + X.data if sp.issparse(X) else X, + allow_nan=allow_nan, + input_name=input_name, + ) + + +def validate_data( + _estimator, + /, + X="no_validation", + y="no_validation", + **kwargs, +): + # force finite check to not occur in sklearn, default is True + # `ensure_all_finite` is the most up-to-date keyword name in sklearn + # _finite_keyword provides backward compatability for `force_all_finite` + ensure_all_finite = kwargs.pop("ensure_all_finite", True) + kwargs[_finite_keyword] = False + + out = _sklearn_validate_data( + _estimator, + X=X, + y=y, + **kwargs, + ) + if ensure_all_finite: + # run local finite check + allow_nan = ensure_all_finite == "allow-nan" + arg = iter(out if isinstance(out, tuple) else (out,)) + if not isinstance(X, str) or X != "no_validation": + assert_all_finite(next(arg), allow_nan=allow_nan, input_name="X") + if not (y is None or isinstance(y, str) and y == "no_validation"): + assert_all_finite(next(arg), allow_nan=allow_nan, input_name="y") + return out + + +def _check_sample_weight( + sample_weight, X, dtype=None, copy=False, only_non_negative=False +): + + n_samples = _num_samples(X) + xp, _ = get_namespace(X) + + if dtype is not None and dtype not in [xp.float32, xp.float64]: + dtype = xp.float64 + + if sample_weight is None: + if hasattr(X, "device"): + sample_weight = xp.ones(n_samples, dtype=dtype, device=X.device) + else: + sample_weight = xp.ones(n_samples, dtype=dtype) + elif isinstance(sample_weight, numbers.Number): + if hasattr(X, "device"): + sample_weight = xp.full( + n_samples, sample_weight, dtype=dtype, device=X.device + ) + else: + sample_weight = xp.full(n_samples, sample_weight, dtype=dtype) + else: + if dtype is None: + dtype = [xp.float64, xp.float32] + + params = { + "accept_sparse": False, + "ensure_2d": False, + "dtype": dtype, + "order": "C", + "copy": copy, + _finite_keyword: False, + } + if sklearn_check_version("1.1"): + params["input_name"] = "sample_weight" + + sample_weight = check_array(sample_weight, **params) + assert_all_finite(sample_weight, input_name="sample_weight") + + if sample_weight.ndim != 1: + raise ValueError("Sample weights must be 1D array or scalar") + + if sample_weight.shape != (n_samples,): + raise ValueError( + "sample_weight.shape == {}, expected {}!".format( + sample_weight.shape, (n_samples,) + ) + ) + + if only_non_negative: + check_non_negative(sample_weight, "`sample_weight`") + + return 
sample_weight
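
Illustrative usage of the new validation helpers added in sklearnex/utils/validation.py. This is a minimal sketch, not part of the patch: the stub estimator and the randomly generated data are assumptions made only for the example; the imported names (validate_data, assert_all_finite, _check_sample_weight) are the ones introduced by the diff above.

    # Minimal sketch exercising the validation helpers introduced above.
    # The stub estimator and the generated data are assumptions for illustration.
    import numpy as np
    from sklearn.base import BaseEstimator

    from sklearnex.utils.validation import (
        _check_sample_weight,
        assert_all_finite,
        validate_data,
    )


    class _StubEstimator(BaseEstimator):
        # Plain BaseEstimator, mirroring how the tests use DummyEstimator.
        pass


    rng = np.random.default_rng(0)
    est = _StubEstimator()
    X = rng.standard_normal((40000, 4))  # large enough to cross the 32768-element threshold
    y = rng.standard_normal(40000)

    # validate_data disables sklearn's built-in finite check and runs the local
    # assert_all_finite instead, which offloads to oneDAL for large, supported
    # float32/float64 inputs and falls back to sklearn otherwise.
    X_checked, y_checked = validate_data(est, X, y, ensure_all_finite=True)

    # assert_all_finite raises ValueError on NaN or infinity; allow_nan=True
    # would permit NaN but still reject infinity.
    assert_all_finite(X_checked)

    # _check_sample_weight broadcasts a scalar weight to shape (n_samples,);
    # array-like weights would additionally go through check_array and
    # assert_all_finite before the shape checks.
    sample_weight = _check_sample_weight(2.0, X_checked)
    assert sample_weight.shape == (X_checked.shape[0],)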