Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,5 +552,8 @@ def run_stubgen(self):
"geopandas",
"shapely",
"folium",
"pyspark",
"matplotlib",
"pandas"
Comment on lines +555 to +557
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyspark is added without a minimum version constraint, but the code relies on newer APIs (e.g., pyspark.sql.functions.try_divide). To prevent runtime incompatibilities, consider pinning/declaring a minimum supported PySpark version here.

Copilot uses AI. Check for mistakes.
],
)
1 change: 1 addition & 0 deletions src/dsf/tsm/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .tsm import TSM as TSM
from .tsm import _get_or_create_spark as _get_or_create_spark
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Re-exporting a leading-underscore symbol makes it part of the public package surface even though the name signals “private”. Consider either not exporting it from __init__.py, or renaming it to a public name (and documenting it) if it’s intended for external use.

Suggested change
from .tsm import _get_or_create_spark as _get_or_create_spark

Copilot uses AI. Check for mistakes.
222 changes: 155 additions & 67 deletions src/dsf/tsm/tsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,35 @@
from pathlib import Path
from typing import Dict, Optional

import polars as pl
from pyspark.sql import DataFrame, SparkSession, Window
import pyspark.sql.functions as F
import pyspark.sql.types as T
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pyspark.sql.types as T is imported but not used in this module; please remove it to avoid lint/type-check noise.

Suggested change
import pyspark.sql.types as T

Copilot uses AI. Check for mistakes.


def _get_or_create_spark() -> SparkSession:
    """Return the active SparkSession, or create one with default settings.

    Returns
    -------
    SparkSession
        The currently active session if one exists, otherwise a new session
        named ``TSM`` built with ``getOrCreate()``.

    Notes
    -----
    This helper deliberately does NOT set ``master``, memory, or shuffle
    configuration: in a library context, hard-coding ``local[*]`` and large
    resource settings would override user/cluster configuration and can fail
    on machines without those resources.
    """
    active = SparkSession.getActiveSession()
    if active is not None:
        return active
    return SparkSession.builder.appName("TSM").getOrCreate()


Comment on lines +12 to 25
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This helper hard-codes local[*], very large driver/executor memory, and shuffle/adaptive settings. In a library context this can override user/cluster configuration and cause failures on machines without these resources. Prefer returning an existing active session (or a plain builder.getOrCreate()), and avoid setting resource configs here (or make them optional parameters).

Suggested change
"""Return the active SparkSession or create a local one."""
return (
SparkSession.builder
.master("local[*]")
.appName("TSM")
.config("spark.driver.memory", "128g") \
.config("spark.executor.memory", "128g") \
.config("spark.sql.shuffle.partitions", "1600") \
.config("spark.sql.adaptive.enabled", "true") \
.config("spark.sql.adaptive.coalescePartitions.enabled", "true") \
.getOrCreate()
)
"""Return the active SparkSession or create one with default settings."""
active = SparkSession.getActiveSession()
if active is not None:
return active
return SparkSession.builder.appName("TSM").getOrCreate()

Copilot uses AI. Check for mistakes.
class TSM:
"""Traffic State Monitoring data.

Builds density/flow clusters from per-vehicle detector data stored in a
Polars ``DataFrame``.
PySpark ``DataFrame``.

Parameters
----------
data : pl.DataFrame
data : pyspark.sql.DataFrame
Raw detector data. Must contain at least the columns
``detector``, ``timestamp``, and ``speed_kph`` (or names that are
mapped to them via *column_mapping*).
Expand All @@ -27,7 +44,7 @@ class TSM:

- ``detector``: unique ID of the traffic detector (e.g. loop sensor).
- ``timestamp``: timestamp of the vehicle passage (must be a
Polars datetime type).
PySpark ``TimestampType``).
- ``speed_kph``: speed of the vehicle in km/h.

Optional target columns:
Expand All @@ -40,14 +57,15 @@ class TSM:

def __init__(
self,
data: pl.DataFrame,
data: DataFrame,
column_mapping: Optional[Dict[str, str]] = None,
) -> None:
if column_mapping is not None:
rename = {
src: eng for src, eng in column_mapping.items() if src in data.columns
}
self._df: pl.DataFrame = data.rename(rename)
df = data
for src, eng in column_mapping.items():
if src in df.columns:
df = df.withColumnRenamed(src, eng)
self._df: DataFrame = df
else:
self._df = data

Expand All @@ -63,7 +81,8 @@ def __init__(
f"Available columns: {self._df.columns}"
)

self._result: Optional[pl.DataFrame] = None
self._result: Optional[DataFrame] = None
self._result_intratimes: Optional[DataFrame] = None

# ------------------------------------------------------------------
# helpers
Expand All @@ -83,6 +102,7 @@ def clusterize(
self,
min_vehicles: int = 5,
gap_factor: float = 3.0,
intermediates=False
) -> "TSM":
"""Run the clustering pipeline.

Expand All @@ -102,78 +122,101 @@ def clusterize(
group = self._group_cols

# --- lanes sub-table (only when lane info is available) -----------
lanes_df: Optional[pl.DataFrame] = None
lanes_df: Optional[DataFrame] = None
if self._has_lane:
lanes_df = self._df.group_by(group).agg(
pl.col("lane").n_unique().alias("n_lanes")
lanes_df = self._df.groupBy(group).agg(
F.count_distinct("lane").alias("n_lanes")
)

# --- window for per-detector and direction in ordered operations ----------------------
w = Window.partitionBy(group).orderBy("timestamp")
Comment on lines +127 to +132
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

groupBy() takes varargs; passing group (a Python list) as a single argument will fail at runtime. Use argument unpacking (groupBy(*group)) so the grouping columns are applied correctly.

Suggested change
lanes_df = self._df.groupBy(group).agg(
F.count_distinct("lane").alias("n_lanes")
)
# --- window for per-detector and direction in ordered operations ----------------------
w = Window.partitionBy(group).orderBy("timestamp")
lanes_df = self._df.groupBy(*group).agg(
F.count_distinct("lane").alias("n_lanes")
)
# --- window for per-detector and direction in ordered operations ----------------------
w = Window.partitionBy(*group).orderBy("timestamp")

Copilot uses AI. Check for mistakes.
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Window.partitionBy() takes varargs; passing group (a list) will be treated as a single invalid column and will error. Use Window.partitionBy(*group).

Suggested change
w = Window.partitionBy(group).orderBy("timestamp")
w = Window.partitionBy(*group).orderBy("timestamp")

Copilot uses AI. Check for mistakes.

# --- main pipeline ------------------------------------------------
df = self._df

# delta_t_s: seconds since previous row for the same detector and direction
df = df.withColumn(
"prev_timestamp", F.lag("timestamp").over(w)
).withColumn(
"delta_t_s",
F.when(
F.col("prev_timestamp").isNotNull(),
F.unix_timestamp(F.col("timestamp")) - F.unix_timestamp(F.col("prev_timestamp")),
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unix_timestamp() truncates to whole seconds, so delta_t_s will lose sub-second precision compared to the previous implementation (and can change cluster boundaries). Consider computing the difference in seconds using a higher-precision approach (e.g., casting timestamps to double/decimal or using an interval and extracting seconds) so you preserve fractional seconds.

Suggested change
F.unix_timestamp(F.col("timestamp")) - F.unix_timestamp(F.col("prev_timestamp")),
F.col("timestamp").cast("double") - F.col("prev_timestamp").cast("double"),

Copilot uses AI. Check for mistakes.
),
).drop("prev_timestamp")

# distance_m
df = df.withColumn(
"distance_m",
F.col("speed_kph") * F.col("delta_t_s") / 3.6,
)

# new_cluster flag
df = df.withColumn(
"new_cluster",
(
(F.col("distance_m") > F.lit(gap_factor) * (F.col("speed_kph") / 3.6))
| F.col("delta_t_s").isNull()
).cast("int"),
)

# cluster_local_id: cumulative sum of new_cluster within each detector and direction
w_unbounded = (
Window.partitionBy(group)
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Window.partitionBy(group) has the same issue here: partitionBy expects varargs, not a list. Use Window.partitionBy(*group) to avoid runtime errors.

Suggested change
Window.partitionBy(group)
Window.partitionBy(*group)

Copilot uses AI. Check for mistakes.
.orderBy("timestamp")
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
)
df = df.withColumn(
"cluster_local_id",
F.sum("new_cluster").over(w_unbounded),
)
df = df.withColumn(
"intra_cluster_intermediate_time_s",
F.when(F.lag("cluster_local_id").over(w) == F.col("cluster_local_id"), F.col("delta_t_s")).otherwise(None),
)
if intermediates:
self._result_intratimes = df.drop("new_cluster").drop("prev_timestamp")

# aggregate per cluster
result = (
self._df.sort(group + ["timestamp"])
.with_columns(
(pl.col("timestamp") - pl.col("timestamp").shift(1))
.dt.total_seconds()
.over(group)
.alias("delta_t_s")
)
.with_columns(
(pl.col("speed_kph") * pl.col("delta_t_s") / 3.6).alias("distance_m")
)
.with_row_index("row_idx")
.with_columns(
(
(pl.col("distance_m") > gap_factor * (pl.col("speed_kph") / 3.6))
| pl.col("delta_t_s").is_null()
).alias("new_cluster")
)
.with_columns(
pl.col("new_cluster")
.cast(pl.Int32)
.cum_sum()
.over(group)
.alias("cluster_local_id")
)
.group_by(group + ["cluster_local_id"])
df.groupBy(group + ["cluster_local_id"])
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above: groupBy() expects varargs of column names; group + ["cluster_local_id"] is a list being passed as a single argument. Use groupBy(*(group + ["cluster_local_id"])) (or groupBy(*group, "cluster_local_id")).

Suggested change
df.groupBy(group + ["cluster_local_id"])
df.groupBy(*(group + ["cluster_local_id"]))

Copilot uses AI. Check for mistakes.
.agg(
pl.col("speed_kph").mean().alias("mean_speed_kph"),
pl.len().alias("num_vehicles"),
(pl.col("distance_m") * 1e-3).sum().alias("cluster_len_km"),
pl.col("delta_t_s").sum().alias("cluster_dt_s"),
F.mean("speed_kph").alias("mean_speed_kph"),
F.count("*").alias("num_vehicles"),
F.sum(F.col("distance_m") * 1e-3).alias("cluster_len_km"),
F.sum("delta_t_s").alias("cluster_dt_s"),
F.mean("intra_cluster_intermediate_time_s").alias("mean_intra_cluster_dt_s"),
)
.filter(pl.col("num_vehicles") > min_vehicles)
.filter(F.col("num_vehicles") > min_vehicles)
)

# --- join lane count & compute density / flow ---------------------
if lanes_df is not None:
result = result.join(lanes_df, on=group, how="left").with_columns(
(
pl.col("num_vehicles")
/ pl.col("cluster_len_km")
/ pl.col("n_lanes")
).alias("density"),
(
pl.col("num_vehicles")
* 3.6e3
/ pl.col("cluster_dt_s")
/ pl.col("n_lanes")
).alias("flow"),
result = result.join(lanes_df, on=group, how="left").withColumn(
"density",
F.try_divide(F.col("num_vehicles"), F.col("cluster_len_km") * F.col("n_lanes")),
).withColumn(
"flow",
F.try_divide(F.col("num_vehicles") * 3.6e3, F.col("cluster_dt_s") * F.col("n_lanes")),
Comment on lines +195 to +200
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

F.try_divide is not available in all supported PySpark versions. Since install_requires doesn’t constrain a minimum PySpark version, this can break at runtime on older environments. Either add a minimum pyspark version constraint that guarantees try_divide, or implement the safe division with when/otherwise so it works across versions.

Copilot uses AI. Check for mistakes.
)
else:
# Without lane info assume 1 lane
result = result.with_columns(
(pl.col("num_vehicles") / pl.col("cluster_len_km")).alias("density"),
(pl.col("num_vehicles") * 3.6e3 / pl.col("cluster_dt_s")).alias("flow"),
result = result.withColumn(
"density",
F.try_divide(F.col("num_vehicles"), F.col("cluster_len_km")),
).withColumn(
"flow",
F.try_divide(F.col("num_vehicles") * 3.6e3, F.col("cluster_dt_s")),
)

self._result = result.sort(group + ["cluster_local_id"])
self._result = result.orderBy(*group, "cluster_local_id")
return self

# ------------------------------------------------------------------
# accessors
# ------------------------------------------------------------------
@property
def result(self) -> pl.DataFrame:
def result(self) -> DataFrame:
"""Return the clustered result DataFrame.

Raises
Expand All @@ -186,19 +229,64 @@ def result(self) -> pl.DataFrame:
return self._result

@property
def result_intratimes(self) -> DataFrame:
    """Return the intermediate (intra-cluster) times DataFrame.

    Returns
    -------
    pyspark.sql.DataFrame
        Per-row intra-cluster time deltas captured during clustering.

    Raises
    ------
    RuntimeError
        If :meth:`clusterize` has not been called with
        ``intermediates=True`` yet (``self._result_intratimes`` is None).
    """
    if self._result_intratimes is None:
        raise RuntimeError("Call .clusterize() with intermediates=True before accessing .result_intratimes")
    return self._result_intratimes

@property
def df(self) -> DataFrame:
    """Convenience alias forwarding to :attr:`result`."""
    return self.result

def to_csv(self, path: str | Path, **kwargs) -> None:
    """Write the result to a CSV directory (Spark partitioned output).

    Parameters
    ----------
    path : str or Path
        Destination directory. Spark writes one or more part-* files.
    **kwargs
        Extra options forwarded to ``DataFrameWriter.csv()``.

    Raises
    ------
    RuntimeError
        If :meth:`clusterize` has not been called yet (via :attr:`result`).
    """
    self.result.write.option("header", "true").csv(str(path), **kwargs)

def to_csv_intratimes(self, path: str | Path, **kwargs) -> None:
    """Write the intermediate times result to a CSV directory (Spark partitioned output).

    Parameters
    ----------
    path : str or Path
        Destination directory. Spark writes one or more part-* files.
    **kwargs
        Extra options forwarded to ``DataFrameWriter.csv()``.
    """
    writer = self.result_intratimes.write.option("header", "true")
    writer.csv(str(path), **kwargs)


def to_parquet(self, path: str | Path, **kwargs) -> None:
    """Write the result to a Parquet directory (Spark partitioned output).

    Parameters
    ----------
    path : str or Path
        Destination directory.
    **kwargs
        Extra options forwarded to ``DataFrameWriter.parquet()``.

    Raises
    ------
    RuntimeError
        If :meth:`clusterize` has not been called yet (via :attr:`result`).
    """
    self.result.write.parquet(str(path), **kwargs)

def __repr__(self) -> str:
    """Return a cheap textual summary of this TSM instance.

    Deliberately avoids calling ``DataFrame.count()``: that would trigger a
    full Spark job inside ``__repr__`` (expensive, and it can fail outright
    if the SparkSession has been stopped). Only the pipeline status, which
    is known without touching Spark, is reported.
    """
    status = "clusterized" if self._result is not None else "raw"
    return f"TSM(status={status})"
Comment on lines +288 to +292
Copy link

Copilot AI Feb 25, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Calling DataFrame.count() inside __repr__ triggers a Spark job and can be extremely expensive/unexpected (and may even fail if the session is stopped). Consider avoiding actions in __repr__ (e.g., omit row count or show a cached/known value only).

Suggested change
if self._result is not None:
rows = self._result.count()
else:
rows = self._df.count()
return f"TSM(status={status}, rows={rows})"
return f"TSM(status={status})"

Copilot uses AI. Check for mistakes.