diff --git a/src/marketdata/input_types/stocks.py b/src/marketdata/input_types/stocks.py index cf10534..f814cd2 100644 --- a/src/marketdata/input_types/stocks.py +++ b/src/marketdata/input_types/stocks.py @@ -4,6 +4,7 @@ from pydantic import Field, field_validator, model_validator from marketdata.input_types.base import BaseInputType, BaseModelConfig +from marketdata.utils import format_timestamp class StocksPricesInput(BaseInputType): @@ -48,11 +49,11 @@ class StocksCandlesInput(BaseInputType): resolution: str = Field(description="The resolution to use", default="D") - from_date: datetime.date | str | None = Field( + from_date: datetime.date | datetime.datetime | str | None = Field( description="The start date to fetch candles for", default=None, alias="from" ) - to_date: datetime.date | str | None = Field( + to_date: datetime.date | datetime.datetime | str | None = Field( description="The end date to fetch candles for", default=None, alias="to" ) @@ -71,6 +72,24 @@ class StocksCandlesInput(BaseInputType): @model_validator(mode="after") def validate_input(self) -> "StocksCandlesInput": self._validate_min_max_dates("from_date", "to_date") + + if self.is_intraday: + # Intraday resolution needs datetime objects to work with split_dates_by_timeframe + # But str is allowed in the input type for "yesterday" and others. + # So Pydantic will skip the validation for str, and we need to convert str to datetime objects here. + if isinstance(self.from_date, str): + self.from_date = format_timestamp(self.from_date) + if isinstance(self.to_date, str): + self.to_date = format_timestamp(self.to_date) + + if isinstance(self.from_date, datetime.date): + self.from_date = datetime.datetime.combine( + self.from_date, datetime.time.min + ) + if isinstance(self.to_date, datetime.date): + self.to_date = datetime.datetime.combine( + self.to_date, datetime.time.min + ) return self @field_validator("resolution") diff --git a/src/marketdata/resources/stocks/candles.py b/src/marketdata/resources/stocks/candles.py index a3cc450..107e35b 100644 --- a/src/marketdata/resources/stocks/candles.py +++ b/src/marketdata/resources/stocks/candles.py @@ -80,7 +80,7 @@ def _get_response( if input_params.is_intraday: year_ranges = split_dates_by_timeframe( input_params.from_date, - input_params.to_date or datetime.date.today(), + input_params.to_date or datetime.datetime.now(), datetime.timedelta(days=365), ) else: diff --git a/src/tests/test_stocks_candles.py b/src/tests/test_stocks_candles.py index bc53258..fe686d1 100644 --- a/src/tests/test_stocks_candles.py +++ b/src/tests/test_stocks_candles.py @@ -444,3 +444,21 @@ def test_get_stocks_candles_response_200_csv(respx_mock, client): filename="test.csv", ) assert pathlib.Path(output).read_text() is not "" + + +def test_stocks_candles_intraday_string_dates(load_json, respx_mock, client): + mock_data = load_json("stocks_candles_response_200") + + respx_mock.get("https://api.marketdata.app/v1/stocks/candles/4H/AAPL/").respond( + json=mock_data, + status_code=200, + ) + + candles = client.stocks.candles( + symbol="AAPL", + resolution="4H", + from_date="2023-01-01", + to_date="2023-01-05", + output_format=OutputFormat.INTERNAL, + ) + assert len(candles) == 253