Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from pathlib import Path

import pytest

REPO_ROOT = Path(__file__).resolve().parents[1]
DATA_DIR = REPO_ROOT / "data"


@pytest.fixture()
def sales_data_csv(tmp_path: Path) -> Path:
"""Return a temporary copy of ``sales_data.csv``."""
src = DATA_DIR / "sales_data.csv"
dst = tmp_path / "sales_data.csv"
dst.write_text(src.read_text())
return dst


@pytest.fixture()
def insurance_sales_csv(tmp_path: Path) -> Path:
"""Return a temporary copy of ``insurance_sales.csv``."""
src = DATA_DIR / "insurance_sales.csv"
dst = tmp_path / "insurance_sales.csv"
dst.write_text(src.read_text())
return dst
56 changes: 19 additions & 37 deletions tests/test_business_tools.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@

import os
import sys
from pathlib import Path

# Ensure the repository root is on sys.path so business_tools can be imported
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
sys.path.insert(0, os.path.join(REPO_ROOT, "src"))
REPO_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(REPO_ROOT / "src"))

from business_tools import ( # noqa: E402
calculate_profit,
Expand All @@ -18,57 +18,39 @@
)


def test_calculate_profit():
def test_calculate_profit() -> None:
assert calculate_profit(1000, 600) == 400


def test_get_sales_from_csv(tmp_path):
# copy sample sales_data.csv to a temp directory to avoid modifying repo file
src = os.path.join(REPO_ROOT, 'data', 'sales_data.csv')
dst = tmp_path / 'sales_data.csv'
with open(src, 'r') as fsrc, open(dst, 'w') as fdst:
fdst.write(fsrc.read())
assert get_sales_from_csv(str(dst)) == 950
def test_get_sales_from_csv(sales_data_csv: Path) -> None:
"""Verify reading the sales CSV sums correctly."""
assert get_sales_from_csv(str(sales_data_csv)) == 950


def test_calculate_commission():
premiums = [300, 700, 200]
def test_calculate_commission() -> None:
premiums = [300.0, 700.0, 200.0]
assert calculate_commission(premiums, rate=0.1) == 120.0


def test_load_insurance_sales_and_total_commission(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
def test_load_insurance_sales_and_total_commission(
insurance_sales_csv: Path,
) -> None:
records = load_insurance_sales(str(insurance_sales_csv))
assert len(records) == 15
assert total_commission(records) == 2545.0


def test_filter_by_state(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
def test_filter_by_state(insurance_sales_csv: Path) -> None:
records = load_insurance_sales(str(insurance_sales_csv))
ca_records = filter_by_state(records, "CA")
assert len(ca_records) == 4

def test_calculate_total_premium(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
def test_calculate_total_premium(insurance_sales_csv: Path) -> None:
records = load_insurance_sales(str(insurance_sales_csv))
assert calculate_total_premium(records) == 18480.0


def test_filter_policies_by_state(tmp_path):
src = os.path.join(REPO_ROOT, "data", "insurance_sales.csv")
dst = tmp_path / "insurance_sales.csv"
with open(src, "r") as fsrc, open(dst, "w") as fdst:
fdst.write(fsrc.read())
records = load_insurance_sales(str(dst))
def test_filter_policies_by_state(insurance_sales_csv: Path) -> None:
records = load_insurance_sales(str(insurance_sales_csv))
ca_records = filter_policies_by_state(records, "CA")
assert len(ca_records) == 4
16 changes: 5 additions & 11 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,15 @@ def run_cli(args: list[str]) -> str:
return result.strip()


def test_profit_cli():
def test_profit_cli() -> None:
assert run_cli(["profit", "100", "40"]) == "60"


def test_commission_cli(tmp_path: Path):
src = REPO_ROOT / "data" / "insurance_sales.csv"
dst = tmp_path / "insurance_sales.csv"
dst.write_text(src.read_text())
out = run_cli(["commission", str(dst)])
def test_commission_cli(insurance_sales_csv: Path) -> None:
out = run_cli(["commission", str(insurance_sales_csv)])
assert out == "2545.0"


def test_premium_cli(tmp_path: Path):
src = REPO_ROOT / "data" / "insurance_sales.csv"
dst = tmp_path / "insurance_sales.csv"
dst.write_text(src.read_text())
out = run_cli(["premium", str(dst)])
def test_premium_cli(insurance_sales_csv: Path) -> None:
out = run_cli(["premium", str(insurance_sales_csv)])
assert out == "18480.0"