diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 00000000..d92458d3 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,24 @@ +from pathlib import Path + +import pytest + +REPO_ROOT = Path(__file__).resolve().parents[1] +DATA_DIR = REPO_ROOT / "data" + + +@pytest.fixture() +def sales_data_csv(tmp_path: Path) -> Path: + """Return a temporary copy of ``sales_data.csv``.""" + src = DATA_DIR / "sales_data.csv" + dst = tmp_path / "sales_data.csv" + dst.write_text(src.read_text()) + return dst + + +@pytest.fixture() +def insurance_sales_csv(tmp_path: Path) -> Path: + """Return a temporary copy of ``insurance_sales.csv``.""" + src = DATA_DIR / "insurance_sales.csv" + dst = tmp_path / "insurance_sales.csv" + dst.write_text(src.read_text()) + return dst diff --git a/tests/test_business_tools.py b/tests/test_business_tools.py index 5a7e26ee..1ef4ad6b 100644 --- a/tests/test_business_tools.py +++ b/tests/test_business_tools.py @@ -1,10 +1,10 @@ -import os import sys +from pathlib import Path # Ensure the repository root is on sys.path so business_tools can be imported -REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) -sys.path.insert(0, os.path.join(REPO_ROOT, "src")) +REPO_ROOT = Path(__file__).resolve().parents[1] +sys.path.insert(0, str(REPO_ROOT / "src")) from business_tools import ( # noqa: E402 calculate_profit, @@ -18,57 +18,39 @@ ) -def test_calculate_profit(): +def test_calculate_profit() -> None: assert calculate_profit(1000, 600) == 400 -def test_get_sales_from_csv(tmp_path): - # copy sample sales_data.csv to a temp directory to avoid modifying repo file - src = os.path.join(REPO_ROOT, 'data', 'sales_data.csv') - dst = tmp_path / 'sales_data.csv' - with open(src, 'r') as fsrc, open(dst, 'w') as fdst: - fdst.write(fsrc.read()) - assert get_sales_from_csv(str(dst)) == 950 +def test_get_sales_from_csv(sales_data_csv: Path) -> None: + """Verify reading the sales CSV sums correctly.""" + assert get_sales_from_csv(str(sales_data_csv)) == 950 -def test_calculate_commission(): - premiums = [300, 700, 200] +def test_calculate_commission() -> None: + premiums = [300.0, 700.0, 200.0] assert calculate_commission(premiums, rate=0.1) == 120.0 -def test_load_insurance_sales_and_total_commission(tmp_path): - src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") - dst = tmp_path / "insurance_sales.csv" - with open(src, "r") as fsrc, open(dst, "w") as fdst: - fdst.write(fsrc.read()) - records = load_insurance_sales(str(dst)) +def test_load_insurance_sales_and_total_commission( + insurance_sales_csv: Path, +) -> None: + records = load_insurance_sales(str(insurance_sales_csv)) assert len(records) == 15 assert total_commission(records) == 2545.0 -def test_filter_by_state(tmp_path): - src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") - dst = tmp_path / "insurance_sales.csv" - with open(src, "r") as fsrc, open(dst, "w") as fdst: - fdst.write(fsrc.read()) - records = load_insurance_sales(str(dst)) +def test_filter_by_state(insurance_sales_csv: Path) -> None: + records = load_insurance_sales(str(insurance_sales_csv)) ca_records = filter_by_state(records, "CA") assert len(ca_records) == 4 -def test_calculate_total_premium(tmp_path): - src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") - dst = tmp_path / "insurance_sales.csv" - with open(src, "r") as fsrc, open(dst, "w") as fdst: - fdst.write(fsrc.read()) - records = load_insurance_sales(str(dst)) +def test_calculate_total_premium(insurance_sales_csv: Path) -> None: + records = load_insurance_sales(str(insurance_sales_csv)) assert calculate_total_premium(records) == 18480.0 -def test_filter_policies_by_state(tmp_path): - src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv") - dst = tmp_path / "insurance_sales.csv" - with open(src, "r") as fsrc, open(dst, "w") as fdst: - fdst.write(fsrc.read()) - records = load_insurance_sales(str(dst)) +def test_filter_policies_by_state(insurance_sales_csv: Path) -> None: + records = load_insurance_sales(str(insurance_sales_csv)) ca_records = filter_policies_by_state(records, "CA") assert len(ca_records) == 4 diff --git a/tests/test_cli.py b/tests/test_cli.py index 0b4ef07e..8075c1d5 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -13,21 +13,15 @@ def run_cli(args: list[str]) -> str: return result.strip() -def test_profit_cli(): +def test_profit_cli() -> None: assert run_cli(["profit", "100", "40"]) == "60" -def test_commission_cli(tmp_path: Path): - src = REPO_ROOT / "data" / "insurance_sales.csv" - dst = tmp_path / "insurance_sales.csv" - dst.write_text(src.read_text()) - out = run_cli(["commission", str(dst)]) +def test_commission_cli(insurance_sales_csv: Path) -> None: + out = run_cli(["commission", str(insurance_sales_csv)]) assert out == "2545.0" -def test_premium_cli(tmp_path: Path): - src = REPO_ROOT / "data" / "insurance_sales.csv" - dst = tmp_path / "insurance_sales.csv" - dst.write_text(src.read_text()) - out = run_cli(["premium", str(dst)]) +def test_premium_cli(insurance_sales_csv: Path) -> None: + out = run_cli(["premium", str(insurance_sales_csv)]) assert out == "18480.0"