Skip to content

Commit

Permalink
Write tests
Browse files Browse the repository at this point in the history
  • Loading branch information
agendazhang committed Jan 14, 2025
1 parent 9a4c80d commit 3c25022
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 8 deletions.
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"

[tool.poetry.group.dev.dependencies]
pytest = "^8.3.4"

[tool.semantic_release]
version_toml = [
"pyproject.toml:tool.poetry.version",
Expand Down
16 changes: 8 additions & 8 deletions src/linreg_ally/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,34 @@ def run_linear_regression(dataframe, target_column, numeric_feats, categorical_f
>>> numeric_feats = ['feature_1', 'feature_2']
>>> categorical_feats = ['category']
>>> drop_feats = []
>>> best_model, X_train, X_test, y_train, y_test, scoring_metrics = run_linear_regression(
>>> best_model, X_train, X_test, y_train, y_test, scores = run_linear_regression(
... df, target_column, numeric_feats, categorical_feats, drop_feats, scoring_metrics=['r2', 'neg_mean_squared_error']
... )
>>> scores
{'r2': 0.52, 'neg_mean_squared_error': 1.23}
"""

if not isinstance(dataframe, pd.DataFrame):
raise Exception("dataframe must be a pandas DataFrame.")
raise TypeError("dataframe must be a pandas DataFrame.")

if dataframe.shape[1] <= 1:
raise Exception("dataframe must contain more than one column.")
raise ValueError("dataframe must contain more than one column.")

if target_column not in dataframe.columns:
raise Exception(f"target_column '{target_column}' is not in the dataframe.")
raise ValueError(f"target_column '{target_column}' is not in the dataframe.")

if not (0.0 < test_size < 1.0):
raise Exception("test_size must be between 0.0 and 1.0.")
raise ValueError("test_size must be between 0.0 and 1.0.")

if random_state is not None and not isinstance(random_state, int):
raise Exception("random_state must be an integer.")
raise TypeError("random_state must be an integer.")

if not isinstance(scoring_metrics, list) or not all(isinstance(metric, str) for metric in scoring_metrics):
raise Exception("scoring_metrics must be a list of strings.")
raise TypeError("scoring_metrics must be a list of strings.")

if not all(metric in get_scorer_names() for metric in scoring_metrics):
invalid_metrics = [metric for metric in scoring_metrics if metric not in get_scorer_names()]
raise Exception(f"The following scoring metrics are not valid: {', '.join(invalid_metrics)}")
raise ValueError(f"The following scoring metrics are not valid: {', '.join(invalid_metrics)}")

drop_feats = drop_feats if drop_feats is not None else []

Expand Down
90 changes: 90 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from linreg_ally.models import run_linear_regression
from sklearn.pipeline import Pipeline
import numpy as np
import pandas as pd

df = pd.DataFrame({
"feature_1": [1, 2, 3, 4, 5],
"feature_2": [0.5, 0.1, 0.4, 0.9, 0.6],
"category": ["a", "b", "a", "b", "c"],
"target": [1.0, 2.5, 3.4, 4.3, 5.1]
})

target_column = 'target'
numeric_feats = ['feature_1', 'feature_2']
categorical_feats = ['category']
drop_feats = []

# Function to create sample DataFrame for testing
def create_sample_dataframe():
return pd.DataFrame({
"feature_1": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
"feature_2": [0.5, 0.1, 0.4, 0.9, 0.6, 0.7, 0.3, 0.2, 0.8, 1.0],
"category": ["a", "b", "a", "b", "c", "a", "b", "c", "a", "b"],
"target": [1.0, 2.5, 3.4, 4.3, 5.1, 6.2, 7.1, 8.3, 9.5, 10.0]
})

df = create_sample_dataframe()
target_column = 'target'
numeric_feats = ['feature_1', 'feature_2']
categorical_feats = ['category']
drop_feats = []

# Test 1: Check if the function result has correct return type and values
def test_return_type_and_length():
result = run_linear_regression(df, target_column, numeric_feats, categorical_feats, drop_feats)
assert isinstance(result, tuple), "The return type should be a tuple."
assert len(result) == 6, "The tuple should have 6 elements."
assert isinstance(result[0], Pipeline), "The first element should be a Pipeline."
assert isinstance(result[1], pd.DataFrame), "The second element should be a DataFrame (X_train)."
assert isinstance(result[2], pd.DataFrame), "The third element should be a DataFrame (X_test)."
assert isinstance(result[3], pd.Series), "The fourth element should be a Series (y_train)."
assert isinstance(result[4], pd.Series), "The fifth element should be a Series (y_test)."
assert isinstance(result[5], dict), "The sixth element should be a dictionary (scores)."
assert "r2" in result[5], "r2 score should be in the scores dictionary."
assert "neg_mean_squared_error" in result[5], "neg_mean_squared_error score should be in the scores dictionary."

# Test 2: Check if TypeError is raised when a non-DataFrame is provided
def test_invalid_dataframe():
try:
run_linear_regression("not_a_dataframe", target_column, numeric_feats, categorical_feats, drop_feats)
except TypeError as e:
assert str(e) == "dataframe must be a pandas DataFrame."
else:
assert False, "TypeError not raised for non-DataFrame input."

# Test 3: Check if ValueError is raised when the target column is not in the DataFrame
def test_target_column_present():
try:
run_linear_regression(df.drop(columns=[target_column]), target_column, numeric_feats, categorical_feats, drop_feats)
except ValueError as e:
assert f"target_column '{target_column}' is not in the dataframe." in str(e)
else:
assert False, "ValueError not raised for missing target column."

# Test 4: Check if an invalid test_size raises a ValueError
def test_invalid_test_size():
try:
run_linear_regression(df, target_column, numeric_feats, categorical_feats, drop_feats, test_size=1.5)
except ValueError as e:
assert str(e) == "test_size must be between 0.0 and 1.0."
else:
assert False, "ValueError not raised for invalid test_size."

# Test 5: Check if an invalid random_state raises a ValueError
def test_invalid_random_state():
try:
run_linear_regression(df, target_column, numeric_feats, categorical_feats, drop_feats, random_state="not_an_int")
except TypeError as e:
assert str(e) == "random_state must be an integer."
else:
assert False, "TypeError not raised for non-integer random_state."

# Test 6: Check if invalid scoring_metrics raises a ValueError
def test_invalid_scoring_metric():
try:
run_linear_regression(df, target_column, numeric_feats, categorical_feats, drop_feats, scoring_metrics=['invalid_metric'])
except ValueError as e:
assert "are not valid" in str(e)
else:
assert False, "ValueError not raised for invalid scoring_metric."

0 comments on commit 3c25022

Please sign in to comment.