PyTorch Implementation of Nested Logit Mode Choice Model with Heterogeneous Features + Dynamic Pricing Tools
This is a nested logit discrete choice model implemented in PyTorch and deployed as a Streamlit app (containerized with Docker). The public ModeCanada dataset is used to model travel mode choice across train, car, bus, and air using a two-level nest structure (Land vs. Air), where Land = {train, car, bus} and Air = {air}. Heterogeneous effects are introduced via income and urban features. The project is framed as a hypothetical travel-agency e-ticketing case study, where train recall is prioritized to reduce missed public transport demand signals.
The dataset is based on the ModeCanada example commonly used in discrete choice modeling:
- Source documentation: ModeCanada dataset (mlogit)
https://rdrr.io/rforge/mlogit/man/ModeCanada.html
Each observation represents a choice occasion with:
- Alternative-specific attributes (vary by mode): `cost`, `ivt` (in-vehicle time), `ovt` (out-of-vehicle time), `freq` (frequency)
- Individual-specific attributes (vary by user): `income`, `urban` (indicator)
A nested logit model is used to relax the IIA assumption by allowing correlated unobserved utility within a nest (substitution patterns are captured more realistically than in plain multinomial logit).
Utilities are computed as

$$U_{ni} = V_{ni} + \varepsilon_{ni}, \qquad P(i \mid n) = P(i \mid m, n)\,P(m \mid n) \quad \text{(nested logit factorization)}$$

where $V_{ni}$ is formed from base alternative features and engineered features, with heterogeneity introduced via alternative-specific covariates (e.g., `income_train`) and interactions (e.g., `urban_x_ovt`).
Choice probabilities are decomposed into:
- probability of choosing a nest
- probability of choosing an alternative conditional on the nest (via inclusive value / log-sum)
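Written out with inclusive values (standard textbook notation; $\lambda_m$ is the log-sum coefficient of nest $m$, and the symbols are generic rather than the repo's exact variable names):

$$I_{nm} = \log \sum_{j \in m} e^{V_{nj}/\lambda_m}, \qquad P(i \mid m, n) = \frac{e^{V_{ni}/\lambda_m}}{\sum_{j \in m} e^{V_{nj}/\lambda_m}}, \qquad P(m \mid n) = \frac{e^{\lambda_m I_{nm}}}{\sum_{m'} e^{\lambda_{m'} I_{nm'}}}$$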
PyTorch is used to compute utilities and nested probabilities, and parameters are optimized by minimizing negative log-likelihood (NLL) with model selection guided by validation performance.
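A minimal sketch of how this can be computed in PyTorch (illustrative only; function and variable names here are assumptions, not the repo's actual API):

```python
import torch

def nested_logit_probs(V, nest_of_alt, lam):
    """V: (n_cases, n_alts) systematic utilities; nest_of_alt: (n_alts,) nest id per
    alternative; lam: (n_nests,) log-sum coefficients (0 < lam <= 1)."""
    n_nests = lam.shape[0]
    logsum = V.new_zeros(V.shape[0], n_nests)
    within = torch.zeros_like(V)
    for m in range(n_nests):
        idx = nest_of_alt == m
        Vm = V[:, idx] / lam[m]                     # scaled utilities inside nest m
        logsum[:, m] = torch.logsumexp(Vm, dim=1)   # inclusive value I_m
        within[:, idx] = torch.softmax(Vm, dim=1)   # P(i | m)
    p_nest = torch.softmax(lam * logsum, dim=1)     # P(m), driven by lam_m * I_m
    probs = torch.zeros_like(V)
    for m in range(n_nests):
        idx = nest_of_alt == m
        probs[:, idx] = within[:, idx] * p_nest[:, m:m + 1]
    return probs

def nll(probs, y):
    # Mean negative log-likelihood of observed choices y (alternative indices per case).
    return -torch.log(probs.gather(1, y.unsqueeze(1)).clamp_min(1e-12)).mean()
```

The parameters entering `V` (and optionally `lam`) can then be fit with a standard optimizer such as `torch.optim.Adam` by minimizing `nll`.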
A base feature set is used: `cost`, `ivt`, `ovt`, `freq`.
Individual attributes are also included via alternative-specific coding (mode-specific covariates): `income_train`, `income_car`, `income_bus`, `income_air`, `urban_train`, `urban_car`, `urban_bus`, `urban_air`.
Alternative-specific coding allows person-level attributes to affect different modes differently by interacting the attribute with mode indicators, so income/urban can shift the relative utilities and therefore the predicted choice.
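As a concrete illustration (a hypothetical long-format layout with one row per case-alternative pair and an `alt` column; not necessarily the repo's column names):

```python
import pandas as pd

def add_alt_specific_coding(df: pd.DataFrame) -> pd.DataFrame:
    # Interact person-level attributes with mode indicators so that each
    # mode gets its own coefficient on income and urban.
    df = df.copy()
    for mode in ["train", "car", "bus", "air"]:
        df[f"income_{mode}"] = df["income"] * (df["alt"] == mode)
        df[f"urban_{mode}"] = df["urban"] * (df["alt"] == mode)
    return df
```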
Candidate engineered features are evaluated via forward selection:
- `cost_log`: log of monetary cost, so the model treats cost sensitivity as diminishing (an extra $10 hurts more when the trip is cheap than when it is already expensive).
- `wait_time`: expected waiting time derived from service frequency using the half-headway rule.
- `gen_time`: generalized time, `ivt + w_ovt * ovt + wait_time`.
- `rel_cost`: how much more expensive the alternative is than the cheapest option in the same case.
- `rel_gen_time`: how much slower (in generalized time) the alternative is than the fastest option in the same case.
- `income_x_rel_cost`: interaction between `income` and `rel_cost`.
- `urban_x_ovt`: interaction between `urban` and `ovt`.
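A sketch of how these candidates could be built (hypothetical helper; the `case` column name, the `w_ovt` default, and the frequency units are assumptions — the half-headway rule below assumes `freq` is departures per hour and waiting time is in minutes):

```python
import numpy as np
import pandas as pd

def engineer_features(df: pd.DataFrame, w_ovt: float = 1.5) -> pd.DataFrame:
    df = df.copy()
    df["cost_log"] = np.log(df["cost"])                  # diminishing cost sensitivity (assumes cost > 0)
    df["wait_time"] = 60.0 / (2.0 * df["freq"])          # half-headway rule: expected wait = headway / 2
    df["gen_time"] = df["ivt"] + w_ovt * df["ovt"] + df["wait_time"]
    grp = df.groupby("case")
    df["rel_cost"] = df["cost"] - grp["cost"].transform("min")              # vs. cheapest option in case
    df["rel_gen_time"] = df["gen_time"] - grp["gen_time"].transform("min")  # vs. fastest option in case
    df["income_x_rel_cost"] = df["income"] * df["rel_cost"]
    df["urban_x_ovt"] = df["urban"] * df["ovt"]
    return df
```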
Feature selection is performed with a two-objective routine: validation NLL is minimized while train recall is tracked as a priority metric for the case study.
- Stratified sampling is used to preserve class proportions across the train/validation/test splits (see the sketch after this list).
- Optimization is performed in PyTorch.
- Model selection is guided by validation NLL.
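For the split, something like scikit-learn's stratified splitting can be used (illustrative sizes and names):

```python
import numpy as np
from sklearn.model_selection import train_test_split

# y holds the chosen mode per case; stratifying preserves mode shares in each split.
idx = np.arange(len(y))
train_idx, rest = train_test_split(idx, test_size=0.3, stratify=y, random_state=42)
val_idx, test_idx = train_test_split(rest, test_size=0.5, stratify=y[rest], random_state=42)
```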
For this particular dataset, the selected features (in addition to the alternative-specific coding) are `cost`, `ivt`, `ovt`, `freq`, `urban_x_ovt`, and `gen_time`.
The table below shows overall metrics for the train/val/test splits:
| Split | Per-case NLL | Accuracy |
|---|---|---|
| Train | 0.6193 | 0.7601 |
| Val | 0.6295 | 0.7612 |
| Test | 0.5967 | 0.7673 |
Predicted vs. actual mode shares are as below:
| Split | Type | Train | Car | Bus | Air |
|---|---|---|---|---|---|
| Train | Pred | 14.37% | 51.24% | 0.39% | 34.00% |
| Train | Actual | 14.41% | 51.19% | 0.36% | 34.04% |
| Train | Delta | -0.04% | +0.05% | +0.02% | -0.03% |
| Val | Pred | 15.09% | 49.35% | 0.47% | 35.09% |
| Val | Actual | 14.33% | 51.16% | 0.46% | 34.05% |
| Val | Delta | +0.76% | -1.81% | +0.01% | +1.04% |
| Test | Pred | 14.10% | 50.20% | 0.38% | 35.32% |
| Test | Actual | 14.48% | 51.16% | 0.31% | 34.05% |
| Test | Delta | -0.39% | -0.95% | +0.07% | +1.27% |
Across train/validation/test:
- Accuracy is stable at around 0.76–0.77.
- Predicted shares are close to observed shares (aggregate demand is reproduced reasonably well).
On the test set:
- ROC AUC: 0.8591 (strong) | ROC AUC (weighted): 0.8908 (strong)
ROC AUC measures how well the model can discriminate a given class from "not that class" across a range of decision thresholds. A strong ROC AUC means the model's scoring is generally well-ordered: positive cases tend to receive higher predicted scores than negative cases. Equivalently, if one positive case and one negative case are sampled at random, the model will assign a higher score to the positive case most of the time. In practice, this suggests the model has learned useful signal and can separate modes reasonably well at the ranking/score level, even before committing to a single hard label via argmax or a fixed threshold.
- PR AUC: 0.5439 (moderate)
PR AUC evaluates performance in terms of the precision–recall trade-off, and is typically more informative under class imbalance since it focuses on how reliable positive predictions remain as we attempt to capture more positives. Unlike ROC AUC (which can remain high when negatives dominate), PR AUC drops quickly if improving recall comes at the cost of many false positives. A moderate PR AUC here indicates that while the model can separate classes reasonably well overall, its positive predictions are less robust for rarer modes, i.e., as we try to recover more minority-class cases, precision degrades more than we would ideally want.
- Accuracy: 0.7673 (moderately strong)
Accuracy reports the fraction of samples where the predicted mode matches the true mode. An accuracy of 0.7673 means the model is correct on ~76.7% of cases overall. Because this metric counts every sample equally, it is most influenced by the majority classes where strong performance on common modes can yield a high accuracy even when performance on rarer modes is weaker.
- Macro F1: 0.4530 (moderate-weak) | Weighted Average F1: 0.7267 (moderate-strong)
Macro F1 averages per-class F1 with equal weight on every class, so a macro F1 of 0.4530 indicates uneven performance, typically driven by weaker minority-class detection. Weighted F1 averages per-class F1 weighted by class frequency, so it emphasizes performance on the most common modes. A weighted F1 of 0.7267 suggests the model is comparatively strong on majority classes even if some minority classes remain challenging.
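These metrics can be reproduced with scikit-learn (illustrative; assumes integer labels `y_true` in {0..3} and predicted probabilities `y_prob` of shape `(n, 4)`):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score, f1_score

roc_macro = roc_auc_score(y_true, y_prob, multi_class="ovr", average="macro")
roc_weighted = roc_auc_score(y_true, y_prob, multi_class="ovr", average="weighted")
pr_macro = average_precision_score(np.eye(4)[y_true], y_prob, average="macro")  # PR AUC
y_pred = y_prob.argmax(axis=1)                        # hard labels via argmax
f1_macro = f1_score(y_true, y_pred, average="macro")
f1_weighted = f1_score(y_true, y_pred, average="weighted")
```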
The table below shows the confusion matrix:
| True \ Pred | pred_train | pred_car | pred_bus | pred_air |
|---|---|---|---|---|
| true_train | 9 | 60 | 0 | 25 |
| true_car | 5 | 294 | 0 | 33 |
| true_bus | 0 | 2 | 0 | 0 |
| true_air | 3 | 23 | 0 | 195 |
Reading note: most true train cases are predicted as car (60) or air (25), which explains the low train recall.
The table below shows the classification report:
| Class | Precision | Recall | F1 | Support |
|---|---|---|---|---|
| train | 0.5294 | 0.0957 | 0.1622 | 94 |
| car | 0.7757 | 0.8855 | 0.8270 | 332 |
| bus | 0.0000 | 0.0000 | 0.0000 | 2 |
| air | 0.7708 | 0.8824 | 0.8228 | 221 |
| accuracy | | | 0.7673 | 649 |
| macro avg | 0.5190 | 0.4659 | 0.4530 | 649 |
| weighted avg | 0.7360 | 0.7673 | 0.7267 | 649 |
Note: bus has support = 2, so bus metrics are not statistically meaningful.
Class-wise performance is still uneven:
- car and air show strong recall (around 0.88) and decent precision (around 0.77–0.78).
- train's precision is mediocre (~0.53) and its recall is especially low (under 0.10), with many true train cases predicted as car or air.
- bus has extremely low support (2 samples), so class-specific bus metrics are not reliable.
This is consistent with a nested logit model that matches aggregate shares but struggles to separate train at the individual decision level under the current signal and objective.
Price response is analyzed via counterfactuals. The cost feature for a selected mode is multiplied by a factor (e.g., +1%), choice probabilities are recomputed, and expected aggregate demand is obtained by summing probabilities across observations. Price elasticity of each specific mode is estimated using a finite-difference approximation, and substitution curves are generated by sweeping a multiplier grid and plotting relative demand by mode.
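A sketch of the finite-difference elasticity (hypothetical interface: `probs_fn(mult)` is assumed to return choice probabilities of shape `(n_cases, n_alts)` after multiplying the selected mode's cost by `mult`):

```python
def own_price_elasticity(probs_fn, mode_idx, eps=0.01):
    base = probs_fn(1.0)[:, mode_idx].sum()           # expected baseline demand
    bumped = probs_fn(1.0 + eps)[:, mode_idx].sum()   # demand after a +1% price change
    # Elasticity ~ (% change in demand) / (% change in price).
    return float((bumped - base) / base / eps)
```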
Pricing is optimized with gradient-based expected revenue maximization (PyTorch autograd + Adam) using multiplicative price factors. In the scenario-based optimizer, demand can be softly capped with a differentiable capacity constraint to approximate inventory limits.
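A minimal sketch of that optimizer, assuming a `demand_fn` that maps per-mode price multipliers to expected demand per mode (names and the particular soft cap below are assumptions, not the repo's exact implementation):

```python
import torch

def optimize_price_multipliers(demand_fn, base_prices, capacity=None, steps=300, lr=0.05):
    log_mult = torch.zeros(len(base_prices), requires_grad=True)  # multipliers start at 1.0
    opt = torch.optim.Adam([log_mult], lr=lr)
    for _ in range(steps):
        opt.zero_grad()
        mult = torch.exp(log_mult)          # log-space parameterization keeps multipliers positive
        demand = demand_fn(mult)
        if capacity is not None:
            # One possible differentiable soft cap approximating inventory limits.
            demand = capacity * torch.tanh(demand / capacity)
        revenue = (base_prices * mult * demand).sum()
        (-revenue).backward()               # gradient ascent on expected revenue
        opt.step()
    return torch.exp(log_mult).detach()
```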
- Substitution simulation: Explore how predicted demand reallocates across modes when the price of a selected mode changes (counterfactual demand + substitution curves).
- Population pricing optimization: Solve for revenue-maximizing price multipliers under different scenarios (e.g., capacity constraints or policy assumptions), using differentiable expected revenue.
- User-context sandbox: Run single-user simulations to understand sensitivity and implied willingness-to-pay in a controlled setting.
The user-level sandbox is included strictly for learning and sensitivity analysis. In real deployments, personalized pricing can raise serious ethical and legal concerns (e.g., discrimination or unfair treatment), so it should not be used without careful governance, transparency, and jurisdiction review.
A Streamlit interface is provided for interactive inference and policy evaluation. The app is intended to demonstrate how a choice model can support dynamic pricing optimization where candidate price changes are simulated and compared using predicted demand and substitution effects before deployment.
- Create a virtual environment:

  ```
  python -m venv venv
  ```

- Activate the virtual environment:
  - Windows:

    ```
    .\venv\Scripts\activate
    ```

  - Unix/macOS:

    ```
    source venv/bin/activate
    ```

- Install dependencies as specified in `pyproject.toml`:

  ```
  pip install .
  ```

- Run Streamlit:

  ```
  streamlit run app/Home.py
  ```

To run with Docker:

- Create the backend network:

  ```
  docker network create backend
  ```

- Create a `.env` file and copy the content from `.env.example`. Adjust accordingly.
- Build the Docker image:

  ```
  docker compose build
  ```

- Run the Docker container:

  ```
  docker compose up -d
  ```

To stop the container, run:

```
docker compose down
```

Below is the NGINX configuration to forward requests to this application.

```nginx
map $http_host $proxy_host {
'' $host;
default $http_host;
}
map $http_x_forwarded_proto $proxy_x_forwarded_proto {
'' $scheme;
default $http_x_forwarded_proto;
}
map $http_x_forwarded_scheme $proxy_x_forwarded_scheme {
'' $scheme;
default $http_x_forwarded_scheme;
}
map $http_x_forwarded_for $proxy_x_forwarded_for {
'' $proxy_add_x_forwarded_for;
default $http_x_forwarded_for;
}
map $http_x_real_ip $proxy_x_real_ip {
'' $remote_addr;
default $http_x_real_ip;
}
map $http_upgrade $connection_upgrade {
'' $http_connection;
default "upgrade";
}
server {
listen 443 ssl;
listen [::]:443 ssl;
server_tokens off;
server_name <server_name>;
ssl_certificate <path_to_cert>;
ssl_certificate_key <path_to_key>;
location / {
# Additional response headers
add_header Strict-Transport-Security "max-age=31536000; includeSubDomains; preload";
# Request headers passed to the origin.
proxy_set_header Host $proxy_host;
proxy_set_header X-Forwarded-Scheme $proxy_x_forwarded_scheme;
proxy_set_header X-Forwarded-Proto $proxy_x_forwarded_proto;
proxy_set_header X-Forwarded-For $proxy_x_forwarded_for;
proxy_set_header X-Real-IP $proxy_x_real_ip;
# WebSocket support
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection $connection_upgrade;
proxy_http_version 1.1;
proxy_pass http://mode-choice:8501;
}
}
```

Note:
- Replace `<server_name>` with the host name you want to use.
- Replace `<path_to_cert>` with the actual path to the SSL certificate file.
- Replace `<path_to_key>` with the actual path to the SSL private key file.
Several changes are likely to improve the model's performance (including train recall):
- Add mode-specific slopes for key attributes (e.g., `cost_train`, `gen_time_air`) on top of the existing user-level alternative-specific coding (e.g., `income_train`, `urban_air`).
- Use cost-sensitive training (class-weighted NLL) to penalize train false negatives more heavily (see the sketch after this list).
- Revisit the nest structure, e.g., replace Land vs. Air with a structure closer to behavior, such as public (train/bus) vs. private (car), optionally treating air as its own nest.
- With extremely low bus support, estimates may be noisy, so consider merging bus into an "other" bucket or collecting more bus samples.
- Add nonlinear time sensitivity (e.g., log transforms or piecewise/threshold effects for time components).
- If stronger heterogeneity is needed, a mixed logit model (random coefficients) can be an upgrade.
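For the cost-sensitive option, one way to weight the NLL by class (the weights below are placeholders, not tuned values):

```python
import torch

class_weights = torch.tensor([3.0, 1.0, 1.0, 1.0])  # order: train, car, bus, air (placeholder weights)

def weighted_nll(probs, y):
    # Upweight the loss on under-recalled classes (here: train) to trade precision for recall.
    chosen = probs.gather(1, y.unsqueeze(1)).squeeze(1).clamp_min(1e-12)
    return -(class_weights[y] * torch.log(chosen)).mean()
```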