diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml
index 2f06667..803d71d 100644
--- a/.github/workflows/ci-build.yml
+++ b/.github/workflows/ci-build.yml
@@ -43,18 +43,18 @@ jobs:
- name: Install wheel
run: pip install wheel
- - name: Build encrypted filesystem artifacta, contract ledger client & depa-training container
+ - name: Build encrypted filesystem artifacts, contract ledger client & depa-training container
run: ci/build.sh
- name: Build container images
run: cd ${{ github.workspace }}/scenarios/covid && ./ci/build.sh
- name: Run pre-processing
- run: cd ./scenarios/covid/deployment/docker && ./preprocess.sh
+ run: cd ./scenarios/covid/deployment/local && ./preprocess.sh
- name: Run model saving
- run: cd ./scenarios/covid/deployment/docker && ./save-model.sh
+ run: cd ./scenarios/covid/deployment/local && ./save-model.sh
- name: Run training
- run: cd ./scenarios/covid/deployment/docker && ./train.sh
+ run: cd ./scenarios/covid/deployment/local && ./train.sh
diff --git a/.github/workflows/ci-local.yml b/.github/workflows/ci-local.yml
index 5595509..7e9f613 100644
--- a/.github/workflows/ci-local.yml
+++ b/.github/workflows/ci-local.yml
@@ -42,10 +42,10 @@ jobs:
run: cd ${{ github.workspace }}/ci && ./pull-containers.sh
- name: Run pre-processing
- run: cd ./scenarios/covid/deployment/docker && ./preprocess.sh
+ run: cd ./scenarios/covid/deployment/local && ./preprocess.sh
- name: Run model saving
- run: cd ./scenarios/covid/deployment/docker && ./save-model.sh
+ run: cd ./scenarios/covid/deployment/local && ./save-model.sh
- name: Run training
- run: cd ./scenarios/covid/deployment/docker && ./train.sh
+ run: cd ./scenarios/covid/deployment/local && ./train.sh
diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 7273ad0..8205868 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -79,31 +79,34 @@ jobs:
run: sudo usermod -aG docker $USER
- name: Run pre-processing
- run: cd ${{ github.workspace }}/scenarios/covid/deployment/docker && ./preprocess.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/local && ./preprocess.sh
- name: Run model saving
- run: cd ${{ github.workspace }}/scenarios/covid/deployment/docker && ./save-model.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/local && ./save-model.sh
- name: Pull container images for generating policy
run: cd ${{ github.workspace }}/ci && ./pull-containers.sh
+ - name: Consolidate pipeline configuration
+ run: cd ${{ github.workspace }}/scenarios/covid/ && ./config/consolidate_pipeline.sh
+
- name: create storage and containers
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./1-create-storage-containers.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./1-create-storage-containers.sh
- name: create azure key vault
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./2-create-akv.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./2-create-akv.sh
- name: Import data and model encryption keys with key release policies
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./3-import-keys.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./3-import-keys.sh
- name: Encrypt data and models
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./4-encrypt-data.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./4-encrypt-data.sh
- name: Upload data and model
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./5-upload-encrypted-data.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./5-upload-encrypted-data.sh
- name: Run training
- run: cd ${{ github.workspace }}/scenarios/covid/deployment/aci && ./deploy.sh -c ${{ github.event.inputs.contract }} -p ../../config/pipeline_config.json
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./deploy.sh -c ${{ github.event.inputs.contract }} -p ../../config/pipeline_config.json
- name: Dump training container logs
run: sleep 200 && az container logs --name depa-training-covid --resource-group $AZURE_RESOURCE_GROUP --container-name depa-training
@@ -112,7 +115,7 @@ jobs:
run: az container logs --name depa-training-covid --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
- name: Download and decrypt model
- run: cd ${{ github.workspace }}/scenarios/covid/data && ./6-download-decrypt-model.sh
+ run: cd ${{ github.workspace }}/scenarios/covid/deployment/azure && ./6-download-decrypt-model.sh
- name: Clean up resource group and all resources
run: az group delete --yes --name $AZURE_RESOURCE_GROUP
diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml
index 73e7116..c5b6500 100644
--- a/.github/workflows/release.yml
+++ b/.github/workflows/release.yml
@@ -34,7 +34,7 @@ jobs:
context: ./scenarios/covid/src
buildargs: |
- dockerfile: ./scenarios/covid/ci/Dockerfile.modelsave
- name: ccr-model-save
+ name: covid-model-save
context: ./scenarios/covid/src
buildargs: |
- dockerfile: ./ci/Dockerfile.encfs
diff --git a/.gitignore b/.gitignore
index ed20c36..e8db1be 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,10 @@
**/*.onnx
+**/*.pth
+**/*.pt
+**/*.img
+**/*.bin
+**/*.pem
+
+venv/
+
+**/__pycache__/
\ No newline at end of file
diff --git a/README.md b/README.md
index 7c088db..552bd4e 100644
--- a/README.md
+++ b/README.md
@@ -6,7 +6,7 @@
## GitHub Codespaces
-The simplest way to setup a development environment is using [GitHub Codespaces](https://github.com/codespaces). The repository includes a [devcontainer.json](../../.devcontainer/devcontainer.json), which customizes your codespace to install all required dependencies. Please ensure you allocate at least 64GB disk space in your codespace. Also, run the following command in the codespace to update submodules.
+The simplest way to setup a development environment is using [GitHub Codespaces](https://github.com/codespaces). The repository includes a [devcontainer.json](.devcontainer/devcontainer.json), which customizes your codespace to install all required dependencies. Please ensure you allocate at least 8 vCPUs and 64GB disk space in your codespace. Also, run the following command in the codespace to update submodules.
```bash
git submodule update --init --recursive
@@ -14,20 +14,23 @@ git submodule update --init --recursive
## Local Development Environment
-Alternatively, you can build and develop locally in a Linux environment (we have tested with Ubuntu 20.04 and 22.04), or Windows with WSL 2. Install the following dependencies.
+Alternatively, you can build and develop locally in a Linux environment (we have tested with Ubuntu 20.04 and 22.04), or Windows with WSL 2.
-- [docker](https://docs.docker.com/engine/install/ubuntu/) and docker-compose. After installing docker, add your user to the docker group using `sudo usermod -aG docker $USER`, and log back in to a shell.
-- make (install using ```sudo apt-get install make```)
-- Python 3.6.9 and pip
-- [Go](https://go.dev/doc/install). Follow the instructions to install Go. After installing, ensure that the PATH environment variable is set to include ```go``` runtime.
-- Python wheel package (install using ```pip install wheel```)
-
-Clone this repo as follows.
+Clone this repo to your local machine / virtual machine as follows.
```bash
git clone --recursive http://github.com/iSPIRT/depa-training
+cd depa-training
```
+Install the below listed dependencies by running the [install-prerequisites.sh](./install-prerequisites.sh) script.
+
+```bash
+./install-prerequisites.sh
+```
+
+Note: You may need to restart your machine to ensure that the changes take effect.
+
## Build CCR containers
To build your own CCR container images, use the following command from the root of the repository.
@@ -44,16 +47,39 @@ This scripts build the following containers.
Alternatively, you can use pre-built container images from the ispirt repository by setting the following environment variable. Docker hub has started throttling which may effect the upload/download time, especially when images are bigger size. So, It is advisable to use other container registries, we are using azure container registry as shown below
```bash
export CONTAINER_REGISTRY=ispirt.azurecr.io
+./ci/pull-containers.sh
```
# Scenarios
This repository contains two samples that illustrate the kinds of scenarios DEPA for Training can support.
-- [Training a differentially private COVID prediction model on private datasets](./scenarios/covid/README.md)
-- [Convolutional Neural Network training on MNIST dataset](./scenarios/mnist/README.md)
+Follow the links to build and deploy these scenarios.
+
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| [COVID-19](./scenarios/covid/README.md) | Training - Deep Learning | Binary Classification | Differentially Private | 3 | PII tabular data (CSV) | MLP (ONNX) | Horizontal (3)|
+| [BraTS](./scenarios/brats/README.md) | Training - Deep Learning | Image Segmentation | Differentially Private | 4 | MRI scans data (NIfTI/PNG) | UNet (Safetensors) | Vertical (4)|
+| [Credit Risk](./scenarios/credit-risk/README.md) | Training - Classical ML | Binary Classification | Differentially Private | 4 | PII tabular data (Parquet) | XGBoost (JSON) | Horizontal (4)|
+| [CIFAR-10](./scenarios/cifar10/README.md) | Training - Deep Learning | Multi-class Image Classification | NA | 1 | Non-PII image data (SafeTensors) | CNN (Safetensors) | NA (1)|
+| [MNIST](./scenarios/mnist/README.md) | Training - Deep Learning | Multi-class Image Classification | NA | 1 | Non-PII image data (HDF5) | CNN (ONNX) | NA (1)|
+
+_NA: Not Applicable_
+_DL: Deep Learning, ML: Classical Machine Learning_
+_*Training Data Providers (TDPs) involved in the scenario._
+
+## Build your own Scenarios
+
+A guide to build your own scenarios is coming soon. Stay tuned!
+
+Currently, DEPA for Training supports the following training frameworks, libraries and file formats (more will be included soon):
+
+- Training frameworks: PyTorch, Scikit-learn, XGBoost
+- Libraries: Opacus, PySpark, Pandas
+- File formats (for models and datasets): ONNX, Safetensors, Parquet, CSV, HDF5, PNG
+
+Note: Due to security reasons, we do not support Pickle based file formats such as .pkl, .pt/.pth, .npy/.npz, .joblib, etc.
-Follow these links to build and deploy these scenarios.
# Contributing
diff --git a/ci/Dockerfile.train b/ci/Dockerfile.train
index e7a3a80..dd1111f 100644
--- a/ci/Dockerfile.train
+++ b/ci/Dockerfile.train
@@ -1,23 +1,33 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get -y upgrade \
&& apt-get install -y curl \
- && apt-get install -y python3.9 python3.9-dev python3.9-distutils \
- && apt-get install -y openjdk-8-jdk
+ && apt-get install -y python3 python3-dev python3-distutils \
+ && apt-get install -y openjdk-17-jdk
## Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-RUN python3.9 get-pip.py
+RUN python3 get-pip.py
## Install dependencies
RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-RUN pip3 --default-timeout=1000 install pyspark pandas opacus onnx onnx2pytorch scikit-learn scipy matplotlib
+RUN pip3 --default-timeout=1000 install pyspark pandas opacus==1.5.3 onnx onnx2pytorch scikit-learn scipy matplotlib
+RUN pip3 install safetensors h5py pyarrow xgboost
+
+# For computer vision tasks
+RUN pip3 install --default-timeout=100 opencv-python pillow monai==1.4.0
+
+# # For natural language processing tasks
+# RUN pip3 install transformers datasets peft
RUN apt-get install -y jq
-# Install contract ledger client
+ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
+RUN export JAVA_HOME
+
+# Install pytrain package for training
COPY train/dist/pytrain-0.0.1-py3-none-any.whl .
RUN pip3 install pytrain-0.0.1-py3-none-any.whl
diff --git a/ci/pull-containers.sh b/ci/pull-containers.sh
index d12f5fc..ea72ef5 100755
--- a/ci/pull-containers.sh
+++ b/ci/pull-containers.sh
@@ -1,6 +1,6 @@
#!/bin/bash
-containers=("ccr-model-save:latest" "depa-training:latest" "depa-training-encfs:latest")
+containers=("depa-training:latest" "depa-training-encfs:latest")
for container in "${containers[@]}"
do
docker pull $CONTAINER_REGISTRY"/"$container
diff --git a/install-prerequisites.sh b/install-prerequisites.sh
new file mode 100755
index 0000000..91cdea0
--- /dev/null
+++ b/install-prerequisites.sh
@@ -0,0 +1,159 @@
+#!/bin/bash
+#
+# Script Name: install-prerequisites.sh
+# Description:
+# Checks for the presence of required development tools and installs
+# any that are missing. Supports Ubuntu/Debian systems.
+# At the end, prints a summary table with each tool's status and version.
+#
+# Prerequisites Checked:
+# - Python3
+# - pip
+# - Go
+# - make
+# - jq
+# - wheel (Python package)
+# - Docker
+# - docker-compose
+# - Azure CLI
+# - Azure CLI extension: confcom
+#
+# Notes:
+# - Requires sudo privileges for package installation.
+# - After adding the user to the 'docker' group, a re-login or `newgrp docker`
+# may be needed for changes to take effect.
+
+set -e
+
+echo "=== Checking and installing prerequisites ==="
+
+# Check if apt is installed
+if ! command -v apt >/dev/null 2>&1; then
+ echo "apt is not installed. Installing it using apt-get."
+ sudo apt-get update
+ sudo apt-get install -y apt
+fi
+
+# Update apt package index
+sudo apt update -y
+
+# Arrays to track results
+pkg_names=()
+pkg_status=()
+pkg_version=()
+
+# Record result
+record_result() {
+ pkg_names+=("$1")
+ pkg_status+=("$2")
+ pkg_version+=("$3")
+}
+
+# Check & install function
+check_and_install() {
+ local cmd="$1"
+ local pkg="$2"
+ local install_cmd="$3"
+ local ver_cmd="$4"
+
+ if command -v "$cmd" >/dev/null 2>&1; then
+ local ver
+ ver=$($ver_cmd 2>/dev/null | head -n 1)
+ echo "[OK] $pkg is already installed: $ver"
+ record_result "$pkg" "Already Installed" "$ver"
+ else
+ echo "[Installing] $pkg..."
+ eval "$install_cmd"
+ local ver
+ ver=$($ver_cmd 2>/dev/null | head -n 1)
+ record_result "$pkg" "Installed Now" "$ver"
+ fi
+}
+
+# --- Python3 ---
+check_and_install python3 python3 "sudo apt install -y python3" "python3 --version"
+
+# --- pip ---
+check_and_install pip pip "sudo apt install -y python3-pip" "pip --version"
+
+# --- Go ---
+check_and_install go golang-go "sudo apt install -y golang-go" "go version"
+
+# --- make ---
+check_and_install make make "sudo apt install -y make" "make --version"
+
+# --- jq ---
+check_and_install jq jq "sudo apt install -y jq" "jq --version"
+
+# --- wheel ---
+if python3 -m pip show wheel >/dev/null 2>&1; then
+ ver=$(python3 -m pip show wheel | grep Version: | awk '{print $2}')
+ echo "[OK] wheel is already installed: version $ver"
+ record_result "wheel (pip)" "Already Installed" "$ver"
+else
+ echo "[Installing] wheel..."
+ python3 -m pip install wheel
+ ver=$(python3 -m pip show wheel | grep Version: | awk '{print $2}')
+ record_result "wheel (pip)" "Installed Now" "$ver"
+fi
+
+# --- Docker ---
+if command -v docker >/dev/null 2>&1; then
+ ver=$(docker -v)
+ echo "[OK] docker is already installed: $ver"
+ record_result "docker" "Already Installed" "$ver"
+else
+ echo "[Installing] docker..."
+ curl -fsSL https://get.docker.com -o get-docker.sh
+ sudo sh get-docker.sh
+ ver=$(docker -v)
+ record_result "docker" "Installed Now" "$ver"
+fi
+
+# --- docker-compose ---
+check_and_install docker-compose docker-compose "sudo apt install -y docker-compose" "docker-compose --version"
+
+# --- Azure CLI ---
+if command -v az >/dev/null 2>&1; then
+ ver=$(az version --query '[].azure-cli' --output tsv 2>/dev/null || az version | grep azure-cli | head -n1)
+ echo "[OK] Azure CLI is already installed: $ver"
+ record_result "Azure CLI" "Already Installed" "$ver"
+else
+ echo "[Installing] Azure CLI..."
+ curl -sL https://aka.ms/InstallAzureCLIDeb | sudo bash
+ ver=$(az version --query '[].azure-cli' --output tsv 2>/dev/null || az version | grep azure-cli | head -n1)
+ record_result "Azure CLI" "Installed Now" "$ver"
+fi
+
+# --- Azure CLI extension: confcom ---
+if az extension show --name confcom >/dev/null 2>&1; then
+ ver=$(az extension show --name confcom --query version -o tsv)
+ echo "[OK] Azure CLI extension 'confcom' is already installed: $ver"
+ record_result "Azure CLI ext: confcom" "Already Installed" "$ver"
+else
+ echo "[Installing] Azure CLI extension 'confcom'..."
+ az extension add --name confcom -y
+ ver=$(az extension show --name confcom --query version -o tsv)
+ record_result "Azure CLI ext: confcom" "Installed Now" "$ver"
+fi
+
+# --- Docker group setup ---
+if groups "$USER" | grep &>/dev/null '\bdocker\b'; then
+ echo "[OK] User '$USER' is already in docker group."
+else
+ echo "[Adding] User '$USER' to docker group..."
+ sudo usermod -aG docker "$USER"
+ echo "You may need to log out and log back in for docker group changes to take effect."
+fi
+
+# --- Summary Table ---
+echo
+echo "=== Installation Summary ==="
+printf "%-30s | %-17s | %-30s\n" "Package" "Status" "Version"
+printf "%-30s | %-17s | %-30s\n" "------------------------------" "-----------------" "------------------------------"
+
+for i in "${!pkg_names[@]}"; do
+ printf "%-30s | %-17s | %-30s\n" "${pkg_names[$i]}" "${pkg_status[$i]}" "${pkg_version[$i]}"
+done
+
+echo "=== All prerequisites are installed. ==="
diff --git a/scenarios/brats/.gitignore b/scenarios/brats/.gitignore
new file mode 100644
index 0000000..ed87890
--- /dev/null
+++ b/scenarios/brats/.gitignore
@@ -0,0 +1,22 @@
+# Ignore model and binary files
+*.img
+*.bin
+*.pth
+*.pt
+*.onnx
+*.npy
+*.png
+*.nii.gz
+*.nii
+*.safetensors
+*.h5
+*.hdf5
+
+# Ignore modeller output folder (relative to repo root)
+modeller/output/
+
+# Ignore any folder named preprocessed (anywhere)
+preprocessed/
+
+# Ignore pycache
+__pycache__/
\ No newline at end of file
diff --git a/scenarios/brats/README.md b/scenarios/brats/README.md
new file mode 100644
index 0000000..df2a8ff
--- /dev/null
+++ b/scenarios/brats/README.md
@@ -0,0 +1,402 @@
+# Brain Tumor Segmentation
+
+## Scenario Type
+
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| BraTS | Training - Deep Learning | Image Segmentation | Differentially Private | 4 | MRI scans data (NIfTI/PNG) | UNet (Safetensors) | Vertical (4)|
+
+---
+
+## Scenario Description
+
+This scenario demonstrates how a deep learning model can be trained for Brain MRI Tumor Segmentation using the join of multiple medical imaging datasets (potentially PII sensitive, due to combination of quasi-identifiers such as biodata or the possibility of volumetric facial reconstruction). The Training Data Consumer (TDC) building the model gets into a contractual agreement with multiple Training Data Providers (TDPs) having annotated MRI data, and the model is trained on the joined datasets in a data-blind manner within the CCR, maintaining privacy guarantees (as per need, keeping in mind the privacy-utility trade-off) using differential privacy. For demonstration purposes, this scenario uses annotated Brain MRI data made available through the BraTS 2020 challenge [[1, 2, 3]](README.md#references), and a custom UNet architecture model for segmentation.
+
+The end-to-end training pipeline consists of the following phases:
+
+1. Data pre-processing
+2. Data packaging, encryption and upload
+3. Model packaging, encryption and upload
+4. Encryption key import with key release policies
+5. Deployment and execution of CCR
+6. Trained model decryption
+
+## Build container images
+
+Build container images required for this sample as follows:
+
+```bash
+export SCENARIO=brats
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/build.sh
+```
+
+This script builds the following container images:
+
+- `preprocess-brats-a, preprocess-brats-b, preprocess-brats-c`: Containers that pre-process the individual MRI datasets
+- `brats-model-save`: Container that saves the base model to be trained.
+
+Alternatively, you can pull and use pre-built container images from the ispirt container registry by setting the following environment variable. Docker Hub has started throttling, which may affect the upload/download time, especially when images are larger. So, it is advisable to use other container registries. We are using Azure Container Registry (ACR) as shown below:
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/pull-containers.sh
+```
+
+## Data pre-processing
+
+The folder ```scenarios/brats/src``` contains scripts for extracting and pre-processing the BraTS MRI datasets. Acting as a Training Data Provider (TDP), prepare your datasets.
+
+For ease of execution, the individual preprocessed BraTS MRI datasets are already made available in the repo under `scenarios/brats/data` as `tar.gz` files. Run the following scripts to extract them:
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/local
+./preprocess.sh
+```
+
+The datasets are saved to the [data](./data/) directory.
+
+> Note: If you wish to pre-process the datasets yourself (in this case, extract 2D slices from the original 3D MRI NIfTI volumes and perform preprocessing and augmentation steps), uncomment and modify the preprocess scripts located in [src](./src/preprocess_brats_A.py).
+
+## Prepare model for training
+
+Next, acting as a Training Data Consumer (TDC), define and save your base model for training using the following script. This calls the [save_base_model.py](./src/save_base_model.py) script, which is a custom script that saves the model to the [models](./modeller/models) directory, as a PyTorch file.
+
+```bash
+./save-model.sh
+```
+
+## Deploy locally
+
+Assuming you have cleartext access to all the datasets, you can train the model _locally_ as follows:
+
+```bash
+./train.sh
+```
+
+The script joins the datasets and trains the model using a pipeline configuration. To modify the various components of the training pipeline, you can edit the training config files in the [config](./config/) directory. The training config files are used to create the pipeline configuration ([pipeline_config.json](./config/pipeline_config.json)) created by consolidating all the TDC's training config files, namely the [model config](./config/model_config.json), [dataset config](./config/dataset_config.json), [loss function config](./config/loss_config.json), [training config](./config/train_config_template.json), [evaluation config](./config/eval_config.json), and if multiple datasets are used, the [data join config](./config/join_config.json). These enable the TDC to design highly customized training pipelines without requiring review and approval of new custom code for each use case—reducing risks from potentially malicious or non-compliant code. The consolidated pipeline configuration is then attested against the signed contract using the TDP’s policy-as-code. If approved, it is executed in the CCR to train the model, which we will deploy in the next section.
+
+```mermaid
+flowchart TD
+
+ subgraph Config Files
+ C1[model_config.json]
+ C2[dataset_config.json]
+ C3[loss_config.json]
+ C4[train_config_template.json]
+ C5[eval_config.json]
+ C6[join_config.json]
+ end
+
+    B[Consolidated into<br>pipeline_config.json]
+
+
+ C1 --> B
+ C2 --> B
+ C3 --> B
+ C4 --> B
+ C5 --> B
+ C6 --> B
+
+    B --> D[Attested against contract<br>using policy-as-code]
+
+ D --> E{Approved?}
+ E -- Yes --> F[CCR training begins]
+ E -- No --> H[Rejected: fix config]
+```
+
+If all goes well, you should see output similar to the following output, and the trained model and evaluation metrics will be saved under the folder [output](./modeller/output).
+
+```bash
+train-1 | Merged dataset 'brats_A' into '/tmp/brats_joined'
+train-1 | Merged dataset 'brat_B' into '/tmp/brats_joined'
+train-1 | Merged dataset 'brats_C' into '/tmp/brats_joined'
+train-1 | Merged dataset 'brats_D' into '/tmp/brats_joined'
+train-1 |
+train-1 | All datasets joined in: /tmp/brats_joined
+train-1 | Training samples: 228
+train-1 | Validation samples: 66
+train-1 | Test samples: 33
+train-1 | Dataset constructed from config
+train-1 | Custom model loaded from PyTorch config
+train-1 | Created non-private baseline model for comparison
+train-1 | Optimizer Adam loaded from config
+train-1 | Scheduler CyclicLR loaded from config
+train-1 | Custom loss function loaded from config
+train-1 | Epoch 1/1 completed | Training Loss: 1.8462 | Epsilon: 1.4971
+train-1 | Epoch 1/1 completed | Validation Loss: 1.8034
+train-1 |
+train-1 | Training non-private replica model for comparison...
+train-1 | Non-private baseline model - Epoch 1/1 completed | Training Loss: 1.7808
+train-1 | Non-private baseline model - Epoch 1/1 completed | Validation Loss: 1.6784
+train-1 | Saving trained model to /mnt/remote/output/trained_model.pth
+train-1 | Evaluation Metrics: {'test_loss': 1.8256315231323241, 'dice_score': 0.00021866214829874793, 'jaccard_index': 0.00010934302874015687, 'hausdorff_distance': 99.54396013822235}
+train-1 | CCR Training complete!
+train-1 |
+train-1 exited with code 0
+```
+
+The trained model along with sample predictions on the validation set will be saved under the [output](./modeller/output/) directory.
+
+Now that training has run successfully locally, let's move on to the actual execution using a Confidential Clean Room (CCR) equipped with confidential computing, key release policies, and contract-based access control.
+
+## Deploy on CCR
+
+In a more realistic scenario, these datasets will not be available in the clear to the TDC, and the TDC will be required to use a CCR for training her model. The following steps describe the process of sharing encrypted datasets with TDCs and setting up a CCR in Azure for training models. Please stay tuned for CCR on other cloud platforms.
+
+To deploy in Azure, you will need the following.
+
+- Docker Hub account to store container images. Alternatively, you can use pre-built images from the ```ispirt``` container registry.
+- [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault/) to store encryption keys and implement secure key release to CCR. You can either use Azure Key Vault Premium (lower cost), or [Azure Key Vault managed HSM](https://learn.microsoft.com/en-us/azure/key-vault/managed-hsm/overview) for enhanced security. Please see instructions below on how to create and setup your AKV instance.
+- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances (ACI).
+
+If you are using your own development environment instead of a dev container or codespaces, you will need to install the following dependencies.
+
+- [Azure CLI](https://learn.microsoft.com/en-us/cli/azure/install-azure-cli-linux).
+- [Azure CLI Confidential containers extension](https://learn.microsoft.com/en-us/cli/azure/confcom?view=azure-cli-latest). After installing Azure CLI, you can install this extension using ```az extension add --name confcom -y```
+- [Go](https://go.dev/doc/install). Follow the instructions to install Go. After installing, ensure that the PATH environment variable is set to include ```go``` runtime.
+- ```jq```. You can install jq using ```sudo apt-get install -y jq```
+
+We will be creating the following resources as part of the deployment.
+
+- Azure Key Vault
+- Azure Storage account
+- Storage containers to host encrypted datasets
+- Azure Container Instances (ACI) to deploy the CCR and train the model
+
+### 1\. Push Container Images
+
+Pre-built container images are available in iSPIRT's container registry, which can be pulled by setting the following environment variable.
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+```
+
+If you wish to use your own container images, login to docker hub (or your container registry of choice) and then build and push the container images to it, so that they can be pulled by the CCR. This is a one-time operation, and you can skip this step if you have already pushed the images to your container registry.
+
+```bash
+export CONTAINER_REGISTRY=<container-registry>
+docker login -u <username> -p <password> ${CONTAINER_REGISTRY}
+cd $REPO_ROOT
+./ci/push-containers.sh
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/push-containers.sh
+```
+
+> **Note:** Replace `<container-registry>`, `<username>` and `<password>` with your container registry name, Docker Hub username and password respectively. Preferably use registry services other than Docker Hub, as throttling restrictions may cause delays or image push/pull failures.
+
+---
+
+### 2\. Create Resources
+
+First, set up the necessary environment variables for your deployment.
+
+```bash
+az login
+
+export SCENARIO=brats
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+export AZURE_LOCATION=northeurope
+export AZURE_SUBSCRIPTION_ID=
+export AZURE_RESOURCE_GROUP=
+export AZURE_KEYVAULT_ENDPOINT=.vault.azure.net
+export AZURE_STORAGE_ACCOUNT_NAME=
+
+export AZURE_BRATS_A_CONTAINER_NAME=bratsacontainer
+export AZURE_BRATS_B_CONTAINER_NAME=bratsbcontainer
+export AZURE_BRATS_C_CONTAINER_NAME=bratsccontainer
+export AZURE_BRATS_D_CONTAINER_NAME=bratsdcontainer
+export AZURE_MODEL_CONTAINER_NAME=modelcontainer
+export AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+```
+
+Alternatively, you can edit the values in the [export-variables.sh](./export-variables.sh) script and run it to set the environment variables.
+
+```bash
+./export-variables.sh
+source export-variables.sh
+```
+
+Azure Naming Rules:
+- Resource Group:
+ - 1–90 characters
+ - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+ - Cannot end with a period (.)
+ - Case-insensitive, unique within subscription\
+- Key Vault:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with letter or number
+- Storage Account:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters and numbers only
+- Storage Container:
+ - 3-63 characters
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with a letter or number
+ - No consecutive hyphens
+ - Unique within storage account
+
+---
+
+**Important:**
+
+The values for the environment variables listed below must precisely match the namesake environment variables used during contract signing (next step). Any mismatch will lead to execution failure.
+
+- `SCENARIO`
+- `AZURE_KEYVAULT_ENDPOINT`
+- `CONTRACT_SERVICE_URL`
+- `AZURE_STORAGE_ACCOUNT_NAME`
+- `AZURE_BRATS_A_CONTAINER_NAME`
+- `AZURE_BRATS_B_CONTAINER_NAME`
+- `AZURE_BRATS_C_CONTAINER_NAME`
+- `AZURE_BRATS_D_CONTAINER_NAME`
+
+---
+With the environment variables set, we are ready to create the resources -- Azure Key Vault and Azure Storage containers.
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/azure
+./1-create-storage-containers.sh
+./2-create-akv.sh
+```
+---
+
+### 3\. Contract Signing
+
+Navigate to the [contract-ledger](https://github.com/kapilvgit/contract-ledger/blob/main/README.md) repository and follow the instructions for contract signing.
+
+Once the contract is signed, export the contract sequence number as an environment variable in the same terminal where you set the environment variables for the deployment.
+
+```bash
+export CONTRACT_SEQ_NO=
+```
+
+---
+
+### 4\. Data Encryption and Upload
+
+Using their respective keys, the TDPs and TDC encrypt their datasets and model (respectively) and upload them to the Storage containers created in the previous step.
+
+Navigate to the [Azure deployment](./deployment/azure/) directory and execute the scripts for key import, data encryption and upload to Azure Blob Storage, in preparation of the CCR deployment.
+
+The import-keys script generates and imports encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys. The generated keys are available as files with the extension `.bin`.
+
+```bash
+export CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+export TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+./3-import-keys.sh
+```
+
+The data and model are then packaged as encrypted filesystems by the TDPs and TDC using their respective keys, which are saved as `.img` files.
+
+```bash
+./4-encrypt-data.sh
+```
+
+The encrypted data and model are then uploaded to the Storage containers created in the previous step. The `.img` files are uploaded to the Storage containers as blobs.
+
+```bash
+./5-upload-encrypted-data.sh
+```
+
+---
+
+### 5\. CCR Deployment
+
+With the resources ready, we are ready to deploy the Confidential Clean Room (CCR) for executing the privacy-preserving model training.
+
+```bash
+export CONTRACT_SEQ_NO=
+./deploy.sh -c $CONTRACT_SEQ_NO -p ../../config/pipeline_config.json
+```
+
+Set the `$CONTRACT_SEQ_NO` variable to the numeric suffix of the contract sequence number (which has the format 2.XX). For example, if the number was 2.15, export as:
+
+```bash
+export CONTRACT_SEQ_NO=15
+```
+
+This script will deploy the container images from your container registry, including the encrypted filesystem sidecar. The sidecar will generate an SEV-SNP attestation report, generate an attestation token using the Microsoft Azure Attestation (MAA) service, retrieve dataset, model and output encryption keys from the TDP and TDC's Azure Key Vault, train the model, and save the resulting model into TDC's output filesystem image, which the TDC can later decrypt.
+
+
+
+**Note:** The completion of this script's execution simply creates a CCR instance, and doesn't indicate whether training has completed or not. The training process might still be ongoing. Poll the container logs (see below) to track progress until training is complete.
+
+### 6\. Monitor Container Logs
+
+Use the following commands to monitor the logs of the deployed containers. You might have to repeatedly poll this command to monitor the training progress:
+
+```bash
+az container logs \
+ --name "depa-training-$SCENARIO" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --container-name depa-training
+```
+
+You will know training has completed when the logs print "CCR Training complete!".
+
+#### Troubleshooting
+
+In case training fails, you might want to monitor the logs of the encrypted storage sidecar container to see if the encryption process completed successfully:
+
+```bash
+az container logs --name depa-training-$SCENARIO --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+```
+
+And to further debug, inspect the logs of the encrypted filesystem sidecar container:
+
+```bash
+az container exec \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --name depa-training-$SCENARIO \
+ --container-name encrypted-storage-sidecar \
+ --exec-command "/bin/sh"
+```
+
+Once inside the sidecar container shell, view the logs:
+
+```bash
+cat log.txt
+```
+Or inspect the individual mounted directories in `mnt/remote/`:
+
+```bash
+cd mnt/remote && ls
+```
+
+### 7\. Download and Decrypt Model
+
+Once training has completed successfully (the training container logs will mention it explicitly), download and decrypt the trained model and other training outputs.
+
+```bash
+./6-download-decrypt-model.sh
+```
+
+The outputs will be saved to the [output](./modeller/output/) directory.
+
+To check if the trained model is fresh, you can run the following command:
+
+```bash
+stat $REPO_ROOT/scenarios/$SCENARIO/modeller/output/trained_model.pth
+```
+
+---
+### Clean-up
+
+You can use the following command to delete the resource group and clean-up all resources used in the demo. Alternatively, you can navigate to the Azure portal and delete the resource group created for this demo.
+
+```bash
+az group delete --yes --name $AZURE_RESOURCE_GROUP
+```
+
+### References
+
+[1] B. H. Menze, A. Jakab, S. Bauer, J. Kalpathy-Cramer, K. Farahani, J. Kirby, et al. "The Multimodal Brain Tumor Image Segmentation Benchmark (BRATS)", IEEE Transactions on Medical Imaging 34(10), 1993-2024 (2015) DOI: 10.1109/TMI.2014.2377694
+
+[2] S. Bakas, H. Akbari, A. Sotiras, M. Bilello, M. Rozycki, J.S. Kirby, et al., "Advancing The Cancer Genome Atlas glioma MRI collections with expert segmentation labels and radiomic features", Nature Scientific Data, 4:170117 (2017) DOI: 10.1038/sdata.2017.117
+
+[3] S. Bakas, M. Reyes, A. Jakab, S. Bauer, M. Rempfler, A. Crimi, et al., "Identifying the Best Machine Learning Algorithms for Brain Tumor Segmentation, Progression Assessment, and Overall Survival Prediction in the BRATS Challenge", arXiv preprint arXiv:1811.02629 (2018)
\ No newline at end of file
diff --git a/scenarios/brats/ci/Dockerfile.bratsA b/scenarios/brats/ci/Dockerfile.bratsA
new file mode 100644
index 0000000..7489ce9
--- /dev/null
+++ b/scenarios/brats/ci/Dockerfile.bratsA
@@ -0,0 +1,19 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install numpy opencv-python pillow nibabel matplotlib
+
+# ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
+# RUN export JAVA_HOME
+
+COPY preprocess_brats_A.py preprocess_brats_A.py
diff --git a/scenarios/brats/ci/Dockerfile.bratsB b/scenarios/brats/ci/Dockerfile.bratsB
new file mode 100644
index 0000000..495a930
--- /dev/null
+++ b/scenarios/brats/ci/Dockerfile.bratsB
@@ -0,0 +1,19 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install numpy opencv-python pillow nibabel matplotlib
+
+# ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
+# RUN export JAVA_HOME
+
+COPY preprocess_brats_B.py preprocess_brats_B.py
diff --git a/scenarios/brats/ci/Dockerfile.bratsC b/scenarios/brats/ci/Dockerfile.bratsC
new file mode 100644
index 0000000..ba2f284
--- /dev/null
+++ b/scenarios/brats/ci/Dockerfile.bratsC
@@ -0,0 +1,19 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install numpy opencv-python pillow nibabel matplotlib
+
+# ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
+# RUN export JAVA_HOME
+
+COPY preprocess_brats_C.py preprocess_brats_C.py
diff --git a/scenarios/brats/ci/Dockerfile.bratsD b/scenarios/brats/ci/Dockerfile.bratsD
new file mode 100644
index 0000000..c91140a
--- /dev/null
+++ b/scenarios/brats/ci/Dockerfile.bratsD
@@ -0,0 +1,19 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install numpy opencv-python pillow nibabel matplotlib
+
+# ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
+# RUN export JAVA_HOME
+
+COPY preprocess_brats_D.py preprocess_brats_D.py
diff --git a/scenarios/brats/ci/Dockerfile.modelsave b/scenarios/brats/ci/Dockerfile.modelsave
new file mode 100644
index 0000000..f21c6c9
--- /dev/null
+++ b/scenarios/brats/ci/Dockerfile.modelsave
@@ -0,0 +1,17 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 install safetensors packaging numpy
+COPY save_base_model.py save_base_model.py
+COPY model_constructor.py model_constructor.py
diff --git a/scenarios/brats/ci/build.sh b/scenarios/brats/ci/build.sh
new file mode 100755
index 0000000..ad4e226
--- /dev/null
+++ b/scenarios/brats/ci/build.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+docker build -f ci/Dockerfile.bratsA src -t preprocess-brats-a:latest
+docker build -f ci/Dockerfile.bratsB src -t preprocess-brats-b:latest
+docker build -f ci/Dockerfile.bratsC src -t preprocess-brats-c:latest
+docker build -f ci/Dockerfile.bratsD src -t preprocess-brats-d:latest
+docker build -f ci/Dockerfile.modelsave src -t brats-model-save:latest
diff --git a/scenarios/brats/ci/pull-containers.sh b/scenarios/brats/ci/pull-containers.sh
new file mode 100755
index 0000000..9cd9161
--- /dev/null
+++ b/scenarios/brats/ci/pull-containers.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+containers=("preprocess-brats-a:latest" "preprocess-brats-b:latest" "preprocess-brats-c:latest" "preprocess-brats-d:latest" "brats-model-save:latest")
+for container in "${containers[@]}"
+do
+ docker pull $CONTAINER_REGISTRY"/"$container
+done
\ No newline at end of file
diff --git a/scenarios/brats/ci/push-containers.sh b/scenarios/brats/ci/push-containers.sh
new file mode 100755
index 0000000..cfcf80c
--- /dev/null
+++ b/scenarios/brats/ci/push-containers.sh
@@ -0,0 +1,6 @@
+containers=("brats-model-save:latest" "preprocess-brats-a:latest" "preprocess-brats-b:latest" "preprocess-brats-c:latest" "preprocess-brats-d:latest")
+for container in "${containers[@]}"
+do
+ docker tag $container $CONTAINER_REGISTRY"/"$container
+ docker push $CONTAINER_REGISTRY"/"$container
+done
diff --git a/scenarios/brats/config/consolidate_pipeline.sh b/scenarios/brats/config/consolidate_pipeline.sh
new file mode 100755
index 0000000..082cb4d
--- /dev/null
+++ b/scenarios/brats/config/consolidate_pipeline.sh
@@ -0,0 +1,58 @@
+#! /bin/bash
+
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCENARIO=brats
+
+template_path="$REPO_ROOT/scenarios/$SCENARIO/config/templates"
+model_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json"
+data_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/dataset_config.json"
+loss_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/loss_config.json"
+train_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/train_config.json"
+eval_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/eval_config.json"
+join_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/join_config.json"
+pipeline_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/pipeline_config.json"
+
+# populate "model_config", "data_config", and "loss_config" keys in train config
+train_config=$(cat $template_path/train_config_template.json)
+
+# Only merge if the file exists
+if [[ -f "$model_config_path" ]]; then
+ model_config=$(cat $model_config_path)
+ train_config=$(echo "$train_config" | jq --argjson model "$model_config" '.config.model_config = $model')
+fi
+
+if [[ -f "$data_config_path" ]]; then
+ data_config=$(cat $data_config_path)
+ train_config=$(echo "$train_config" | jq --argjson data "$data_config" '.config.dataset_config = $data')
+fi
+
+if [[ -f "$loss_config_path" ]]; then
+ loss_config=$(cat $loss_config_path)
+ train_config=$(echo "$train_config" | jq --argjson loss "$loss_config" '.config.loss_config = $loss')
+fi
+
+if [[ -f "$eval_config_path" ]]; then
+ eval_config=$(cat $eval_config_path)
+ # Get all keys from eval_config and copy them to train_config
+ for key in $(echo "$eval_config" | jq -r 'keys[]'); do
+ train_config=$(echo "$train_config" | jq --argjson eval "$eval_config" --arg key "$key" '.config[$key] = $eval[$key]')
+ done
+fi
+
+# save train_config
+echo "$train_config" > $train_config_path
+
+# prepare pipeline config from join_config.json (first dict "config") and train_config.json (second dict "config")
+pipeline_config=$(cat $template_path/pipeline_config_template.json)
+
+# Only merge join_config if the file exists
+if [[ -f "$join_config_path" ]]; then
+ join_config=$(cat $join_config_path)
+ pipeline_config=$(echo "$pipeline_config" | jq --argjson join "$join_config" '.pipeline += [$join]')
+fi
+
+# Always merge train_config as it's required
+pipeline_config=$(echo "$pipeline_config" | jq --argjson train "$train_config" '.pipeline += [$train]')
+
+# save pipeline_config to pipeline_config.json
+echo "$pipeline_config" > $pipeline_config_path
\ No newline at end of file
diff --git a/scenarios/brats/config/dataset_config.json b/scenarios/brats/config/dataset_config.json
new file mode 100644
index 0000000..05a36e2
--- /dev/null
+++ b/scenarios/brats/config/dataset_config.json
@@ -0,0 +1,27 @@
+{
+ "type": "directory",
+ "structure_type": "paired",
+ "data_type": "image",
+ "pairing": {
+ "folder_pattern": "BraTS20_Training_*",
+ "input_pattern": "*_flair.png",
+ "target_pattern": "*_seg.png"
+ },
+ "image_config": {
+ "use_cv2": true,
+ "convert_to_pil": true,
+ "grayscale": true,
+ "to_tensor": true,
+ "binarize": true,
+ "binarize_threshold": 0
+ },
+ "filtering": {
+ "filter_empty_targets": true
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/eval_config.json b/scenarios/brats/config/eval_config.json
new file mode 100644
index 0000000..f6e614e
--- /dev/null
+++ b/scenarios/brats/config/eval_config.json
@@ -0,0 +1,24 @@
+{
+ "task_type": "segmentation",
+ "threshold": 0.3,
+ "metrics": [
+ {
+ "name": "dice_score",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "jaccard_index",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "hausdorff_distance",
+ "params": {
+ "threshold": 0.3
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/join_config.json b/scenarios/brats/config/join_config.json
new file mode 100644
index 0000000..3d163a4
--- /dev/null
+++ b/scenarios/brats/config/join_config.json
@@ -0,0 +1,32 @@
+{
+ "name": "DirectoryJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_A",
+ "name": "BraTS_Brain_MRI_set_A",
+ "mount_path": "/mnt/remote/brats_A"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_B",
+ "name": "BraTS_Brain_MRI_set_B",
+ "mount_path": "/mnt/remote/brats_B"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_C",
+ "name": "BraTS_Brain_MRI_set_C",
+ "mount_path": "/mnt/remote/brats_C"
+ },
+ {
+ "id": "2a3b0c4e-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_D",
+ "name": "BraTS_Brain_MRI_set_D",
+ "mount_path": "/mnt/remote/brats_D"
+ }
+ ],
+ "joined_dataset": "/tmp/brats_joined"
+ }
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/loss_config.json b/scenarios/brats/config/loss_config.json
new file mode 100644
index 0000000..8516f4c
--- /dev/null
+++ b/scenarios/brats/config/loss_config.json
@@ -0,0 +1,27 @@
+{
+ "expression": "dice_loss + 2 * l1_loss",
+ "components": {
+ "dice_loss": {
+ "class": "monai.losses.DiceLoss",
+ "params": {
+ "sigmoid": true,
+ "squared_pred": true,
+ "reduction": "mean"
+ }
+ },
+ "l1_loss": {
+ "class": "torch.nn.L1Loss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "bce_loss": {
+ "class": "torch.nn.functional.binary_cross_entropy_with_logits",
+ "params": {
+ "input": "outputs",
+ "target": "targets",
+ "reduction": "mean"
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/model_config.json b/scenarios/brats/config/model_config.json
new file mode 100644
index 0000000..5b05712
--- /dev/null
+++ b/scenarios/brats/config/model_config.json
@@ -0,0 +1,560 @@
+{
+ "submodules": {
+ "ConvBlock2d": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$mid_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$mid_ch"
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$mid_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm2": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "conv2",
+ "norm2"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ },
+ "Upsample": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x",
+ "encoded_feature"
+ ],
+ "forward": [
+ {
+ "ops": [
+ [
+ "F.interpolate",
+ {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ }
+ ]
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "torch.cat",
+ {
+ "dim": 1
+ }
+ ]
+ ],
+ "input": [
+ [
+ "x3",
+ "encoded_feature"
+ ]
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ }
+ },
+ "layers": {
+ "in_conv": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 1,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "down_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 8,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "down_pool0": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 16,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "down_pool1": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "down_pool2": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "down_pool3": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "bottleneck": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 256,
+ "out_ch": 256
+ }
+ },
+ "up_samp3": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 256,
+ "out_ch": 128
+ }
+ },
+ "up_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 256,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "up_samp2": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 128,
+ "out_ch": 64
+ }
+ },
+ "up_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "up_samp1": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 64,
+ "out_ch": 32
+ }
+ },
+ "up_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "up_samp0": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 32,
+ "out_ch": 16
+ }
+ },
+ "up_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "out_conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "out_conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 8,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "in_conv"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "down_conv0"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "down_pool0"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "down_conv1"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "down_pool1"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "down_conv2"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ },
+ {
+ "ops": [
+ "down_pool2"
+ ],
+ "input": [
+ "x6"
+ ],
+ "output": "x7"
+ },
+ {
+ "ops": [
+ "down_conv3"
+ ],
+ "input": [
+ "x7"
+ ],
+ "output": "x8"
+ },
+ {
+ "ops": [
+ "down_pool3"
+ ],
+ "input": [
+ "x8"
+ ],
+ "output": "x9"
+ },
+ {
+ "ops": [
+ "bottleneck"
+ ],
+ "input": [
+ "x9"
+ ],
+ "output": "x10"
+ },
+ {
+ "ops": [
+ "up_samp3"
+ ],
+ "input": [
+ "x10",
+ "x8"
+ ],
+ "output": "x11"
+ },
+ {
+ "ops": [
+ "up_conv3"
+ ],
+ "input": [
+ "x11"
+ ],
+ "output": "x12"
+ },
+ {
+ "ops": [
+ "up_samp2"
+ ],
+ "input": [
+ "x12",
+ "x6"
+ ],
+ "output": "x13"
+ },
+ {
+ "ops": [
+ "up_conv2"
+ ],
+ "input": [
+ "x13"
+ ],
+ "output": "x14"
+ },
+ {
+ "ops": [
+ "up_samp1"
+ ],
+ "input": [
+ "x14",
+ "x4"
+ ],
+ "output": "x15"
+ },
+ {
+ "ops": [
+ "up_conv1"
+ ],
+ "input": [
+ "x15"
+ ],
+ "output": "x16"
+ },
+ {
+ "ops": [
+ "up_samp0"
+ ],
+ "input": [
+ "x16",
+ "x2"
+ ],
+ "output": "x17"
+ },
+ {
+ "ops": [
+ "up_conv0"
+ ],
+ "input": [
+ "x17"
+ ],
+ "output": "x18"
+ },
+ {
+ "ops": [
+ "out_conv1"
+ ],
+ "input": [
+ "x18"
+ ],
+ "output": "x19"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x19"
+ ],
+ "output": "x20"
+ },
+ {
+ "ops": [
+ "out_conv2"
+ ],
+ "input": [
+ "x20"
+ ],
+ "output": "x21"
+ },
+ {
+ "ops": [
+ "torch.sigmoid"
+ ],
+ "input": [
+ "x21"
+ ],
+ "output": "x22"
+ }
+ ],
+ "output": [
+ "x22"
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/model_config_old.json b/scenarios/brats/config/model_config_old.json
new file mode 100644
index 0000000..b6d558e
--- /dev/null
+++ b/scenarios/brats/config/model_config_old.json
@@ -0,0 +1,889 @@
+{
+ "layers": {
+ "input_conv": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 1,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "input"
+ ]
+ },
+ "down_conv1_0": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 8,
+ "out_channels": 16,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "input_conv"
+ ]
+ },
+ "down_norm1_0": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 16
+ },
+ "inputs": [
+ "down_conv1_0"
+ ]
+ },
+ "down_relu1_0": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm1_0"
+ ]
+ },
+ "down_conv2_0": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_relu1_0"
+ ]
+ },
+ "down_norm2_0": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 16
+ },
+ "inputs": [
+ "down_conv2_0"
+ ]
+ },
+ "down_relu2_0": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm2_0"
+ ]
+ },
+ "down_pool_0": {
+ "type": "maxpool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ },
+ "inputs": [
+ "down_relu2_0"
+ ]
+ },
+ "down_conv1_1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 32,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_pool_0"
+ ]
+ },
+ "down_norm1_1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 32
+ },
+ "inputs": [
+ "down_conv1_1"
+ ]
+ },
+ "down_relu1_1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm1_1"
+ ]
+ },
+ "down_conv2_1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 32,
+ "out_channels": 32,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_relu1_1"
+ ]
+ },
+ "down_norm2_1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 32
+ },
+ "inputs": [
+ "down_conv2_1"
+ ]
+ },
+ "down_relu2_1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm2_1"
+ ]
+ },
+ "down_pool_1": {
+ "type": "maxpool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ },
+ "inputs": [
+ "down_relu2_1"
+ ]
+ },
+ "down_conv1_2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 32,
+ "out_channels": 64,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_pool_1"
+ ]
+ },
+ "down_norm1_2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 64
+ },
+ "inputs": [
+ "down_conv1_2"
+ ]
+ },
+ "down_relu1_2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm1_2"
+ ]
+ },
+ "down_conv2_2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 64,
+ "out_channels": 64,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_relu1_2"
+ ]
+ },
+ "down_norm2_2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 64
+ },
+ "inputs": [
+ "down_conv2_2"
+ ]
+ },
+ "down_relu2_2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm2_2"
+ ]
+ },
+ "down_pool_2": {
+ "type": "maxpool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ },
+ "inputs": [
+ "down_relu2_2"
+ ]
+ },
+ "down_conv1_3": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 64,
+ "out_channels": 128,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_pool_2"
+ ]
+ },
+ "down_norm1_3": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 128
+ },
+ "inputs": [
+ "down_conv1_3"
+ ]
+ },
+ "down_relu1_3": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm1_3"
+ ]
+ },
+ "down_conv2_3": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 128,
+ "out_channels": 128,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_relu1_3"
+ ]
+ },
+ "down_norm2_3": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 128
+ },
+ "inputs": [
+ "down_conv2_3"
+ ]
+ },
+ "down_relu2_3": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "down_norm2_3"
+ ]
+ },
+ "down_pool_3": {
+ "type": "maxpool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ },
+ "inputs": [
+ "down_relu2_3"
+ ]
+ },
+ "bottleneck_conv1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 128,
+ "out_channels": 256,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "down_pool_3"
+ ]
+ },
+ "bottleneck_norm1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 256
+ },
+ "inputs": [
+ "bottleneck_conv1"
+ ]
+ },
+ "bottleneck_relu1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "bottleneck_norm1"
+ ]
+ },
+ "bottleneck_conv2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 256,
+ "out_channels": 256,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "bottleneck_relu1"
+ ]
+ },
+ "bottleneck_norm2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 256
+ },
+ "inputs": [
+ "bottleneck_conv2"
+ ]
+ },
+ "bottleneck_relu2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "bottleneck_norm2"
+ ]
+ },
+ "up_interpolate_3": {
+ "type": "interpolate",
+ "params": {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ },
+ "inputs": [
+ "bottleneck_relu2"
+ ]
+ },
+ "up_conv_3": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 256,
+ "out_channels": 128,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_interpolate_3"
+ ]
+ },
+ "up_norm_3": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 128
+ },
+ "inputs": [
+ "up_conv_3"
+ ]
+ },
+ "up_relu_3": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_norm_3"
+ ]
+ },
+ "skip_concat_3": {
+ "type": "concat",
+ "params": {
+ "dim": 1
+ },
+ "inputs": [
+ "down_relu2_3",
+ "up_relu_3"
+ ]
+ },
+ "up_conv_block1_3": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 256,
+ "out_channels": 128,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "skip_concat_3"
+ ]
+ },
+ "up_conv_norm1_3": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 128
+ },
+ "inputs": [
+ "up_conv_block1_3"
+ ]
+ },
+ "up_conv_relu1_3": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm1_3"
+ ]
+ },
+ "up_conv_block2_3": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 128,
+ "out_channels": 128,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_conv_relu1_3"
+ ]
+ },
+ "up_conv_norm2_3": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 128
+ },
+ "inputs": [
+ "up_conv_block2_3"
+ ]
+ },
+ "up_conv_relu2_3": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm2_3"
+ ]
+ },
+ "up_interpolate_2": {
+ "type": "interpolate",
+ "params": {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ },
+ "inputs": [
+ "up_conv_relu2_3"
+ ]
+ },
+ "up_conv_2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 128,
+ "out_channels": 64,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_interpolate_2"
+ ]
+ },
+ "up_norm_2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 64
+ },
+ "inputs": [
+ "up_conv_2"
+ ]
+ },
+ "up_relu_2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_norm_2"
+ ]
+ },
+ "skip_concat_2": {
+ "type": "concat",
+ "params": {
+ "dim": 1
+ },
+ "inputs": [
+ "down_relu2_2",
+ "up_relu_2"
+ ]
+ },
+ "up_conv_block1_2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 128,
+ "out_channels": 64,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "skip_concat_2"
+ ]
+ },
+ "up_conv_norm1_2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 64
+ },
+ "inputs": [
+ "up_conv_block1_2"
+ ]
+ },
+ "up_conv_relu1_2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm1_2"
+ ]
+ },
+ "up_conv_block2_2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 64,
+ "out_channels": 64,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_conv_relu1_2"
+ ]
+ },
+ "up_conv_norm2_2": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 64
+ },
+ "inputs": [
+ "up_conv_block2_2"
+ ]
+ },
+ "up_conv_relu2_2": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm2_2"
+ ]
+ },
+ "up_interpolate_1": {
+ "type": "interpolate",
+ "params": {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ },
+ "inputs": [
+ "up_conv_relu2_2"
+ ]
+ },
+ "up_conv_1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 64,
+ "out_channels": 32,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_interpolate_1"
+ ]
+ },
+ "up_norm_1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 32
+ },
+ "inputs": [
+ "up_conv_1"
+ ]
+ },
+ "up_relu_1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_norm_1"
+ ]
+ },
+ "skip_concat_1": {
+ "type": "concat",
+ "params": {
+ "dim": 1
+ },
+ "inputs": [
+ "down_relu2_1",
+ "up_relu_1"
+ ]
+ },
+ "up_conv_block1_1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 64,
+ "out_channels": 32,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "skip_concat_1"
+ ]
+ },
+ "up_conv_norm1_1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 32
+ },
+ "inputs": [
+ "up_conv_block1_1"
+ ]
+ },
+ "up_conv_relu1_1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm1_1"
+ ]
+ },
+ "up_conv_block2_1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 32,
+ "out_channels": 32,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_conv_relu1_1"
+ ]
+ },
+ "up_conv_norm2_1": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 32
+ },
+ "inputs": [
+ "up_conv_block2_1"
+ ]
+ },
+ "up_conv_relu2_1": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm2_1"
+ ]
+ },
+ "up_interpolate_0": {
+ "type": "interpolate",
+ "params": {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ },
+ "inputs": [
+ "up_conv_relu2_1"
+ ]
+ },
+ "up_conv_0": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 32,
+ "out_channels": 16,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_interpolate_0"
+ ]
+ },
+ "up_norm_0": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 16
+ },
+ "inputs": [
+ "up_conv_0"
+ ]
+ },
+ "up_relu_0": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_norm_0"
+ ]
+ },
+ "skip_concat_0": {
+ "type": "concat",
+ "params": {
+ "dim": 1
+ },
+ "inputs": [
+ "down_relu2_0",
+ "up_relu_0"
+ ]
+ },
+ "up_conv_block1_0": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 32,
+ "out_channels": 16,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "skip_concat_0"
+ ]
+ },
+ "up_conv_norm1_0": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 16
+ },
+ "inputs": [
+ "up_conv_block1_0"
+ ]
+ },
+ "up_conv_relu1_0": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm1_0"
+ ]
+ },
+ "up_conv_block2_0": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 16,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_conv_relu1_0"
+ ]
+ },
+ "up_conv_norm2_0": {
+ "type": "groupnorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": 16
+ },
+ "inputs": [
+ "up_conv_block2_0"
+ ]
+ },
+ "up_conv_relu2_0": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "up_conv_norm2_0"
+ ]
+ },
+ "out_conv1": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "up_conv_relu2_0"
+ ]
+ },
+ "out_relu": {
+ "type": "leaky_relu",
+ "params": {
+ "negative_slope": 0.1
+ },
+ "inputs": [
+ "out_conv1"
+ ]
+ },
+ "out_conv2": {
+ "type": "conv2d",
+ "params": {
+ "in_channels": 8,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ },
+ "inputs": [
+ "out_relu"
+ ]
+ },
+ "final_output": {
+ "type": "sigmoid",
+ "inputs": [
+ "out_conv2"
+ ]
+ }
+ },
+ "outputs": [
+ "final_output"
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/pipeline_config.json b/scenarios/brats/config/pipeline_config.json
new file mode 100644
index 0000000..d678435
--- /dev/null
+++ b/scenarios/brats/config/pipeline_config.json
@@ -0,0 +1,706 @@
+{
+ "pipeline": [
+ {
+ "name": "DirectoryJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_A",
+ "name": "BraTS_Brain_MRI_set_A",
+ "mount_path": "/mnt/remote/brats_A"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_B",
+ "name": "BraTS_Brain_MRI_set_B",
+ "mount_path": "/mnt/remote/brats_B"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_C",
+ "name": "BraTS_Brain_MRI_set_C",
+ "mount_path": "/mnt/remote/brats_C"
+ },
+ {
+ "id": "2a3b0c4e-bab8-11ed-afa1-0242ac120002",
+ "provider": "brats_D",
+ "name": "BraTS_Brain_MRI_set_D",
+ "mount_path": "/mnt/remote/brats_D"
+ }
+ ],
+ "joined_dataset": "/tmp/brats_joined"
+ }
+ },
+ {
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/brats_joined",
+ "saved_weights_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.5,
+ "delta": 0.005
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 0.0001,
+ "max_lr": 0.01,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1,
+ "model_config": {
+ "submodules": {
+ "ConvBlock2d": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$mid_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$mid_ch"
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$mid_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm2": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "conv2",
+ "norm2"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ },
+ "Upsample": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x",
+ "encoded_feature"
+ ],
+ "forward": [
+ {
+ "ops": [
+ [
+ "F.interpolate",
+ {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ }
+ ]
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "torch.cat",
+ {
+ "dim": 1
+ }
+ ]
+ ],
+ "input": [
+ [
+ "x3",
+ "encoded_feature"
+ ]
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ }
+ },
+ "layers": {
+ "in_conv": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 1,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "down_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 8,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "down_pool0": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 16,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "down_pool1": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "down_pool2": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "down_pool3": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "bottleneck": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 256,
+ "out_ch": 256
+ }
+ },
+ "up_samp3": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 256,
+ "out_ch": 128
+ }
+ },
+ "up_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 256,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "up_samp2": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 128,
+ "out_ch": 64
+ }
+ },
+ "up_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "up_samp1": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 64,
+ "out_ch": 32
+ }
+ },
+ "up_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "up_samp0": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 32,
+ "out_ch": 16
+ }
+ },
+ "up_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "out_conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "out_conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 8,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "in_conv"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "down_conv0"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "down_pool0"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "down_conv1"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "down_pool1"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "down_conv2"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ },
+ {
+ "ops": [
+ "down_pool2"
+ ],
+ "input": [
+ "x6"
+ ],
+ "output": "x7"
+ },
+ {
+ "ops": [
+ "down_conv3"
+ ],
+ "input": [
+ "x7"
+ ],
+ "output": "x8"
+ },
+ {
+ "ops": [
+ "down_pool3"
+ ],
+ "input": [
+ "x8"
+ ],
+ "output": "x9"
+ },
+ {
+ "ops": [
+ "bottleneck"
+ ],
+ "input": [
+ "x9"
+ ],
+ "output": "x10"
+ },
+ {
+ "ops": [
+ "up_samp3"
+ ],
+ "input": [
+ "x10",
+ "x8"
+ ],
+ "output": "x11"
+ },
+ {
+ "ops": [
+ "up_conv3"
+ ],
+ "input": [
+ "x11"
+ ],
+ "output": "x12"
+ },
+ {
+ "ops": [
+ "up_samp2"
+ ],
+ "input": [
+ "x12",
+ "x6"
+ ],
+ "output": "x13"
+ },
+ {
+ "ops": [
+ "up_conv2"
+ ],
+ "input": [
+ "x13"
+ ],
+ "output": "x14"
+ },
+ {
+ "ops": [
+ "up_samp1"
+ ],
+ "input": [
+ "x14",
+ "x4"
+ ],
+ "output": "x15"
+ },
+ {
+ "ops": [
+ "up_conv1"
+ ],
+ "input": [
+ "x15"
+ ],
+ "output": "x16"
+ },
+ {
+ "ops": [
+ "up_samp0"
+ ],
+ "input": [
+ "x16",
+ "x2"
+ ],
+ "output": "x17"
+ },
+ {
+ "ops": [
+ "up_conv0"
+ ],
+ "input": [
+ "x17"
+ ],
+ "output": "x18"
+ },
+ {
+ "ops": [
+ "out_conv1"
+ ],
+ "input": [
+ "x18"
+ ],
+ "output": "x19"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x19"
+ ],
+ "output": "x20"
+ },
+ {
+ "ops": [
+ "out_conv2"
+ ],
+ "input": [
+ "x20"
+ ],
+ "output": "x21"
+ },
+ {
+ "ops": [
+ "torch.sigmoid"
+ ],
+ "input": [
+ "x21"
+ ],
+ "output": "x22"
+ }
+ ],
+ "output": [
+ "x22"
+ ]
+ },
+ "dataset_config": {
+ "type": "directory",
+ "structure_type": "paired",
+ "data_type": "image",
+ "pairing": {
+ "folder_pattern": "BraTS20_Training_*",
+ "input_pattern": "*_flair.png",
+ "target_pattern": "*_seg.png"
+ },
+ "image_config": {
+ "use_cv2": true,
+ "convert_to_pil": true,
+ "grayscale": true,
+ "to_tensor": true,
+ "binarize": true,
+ "binarize_threshold": 0
+ },
+ "filtering": {
+ "filter_empty_targets": true
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "expression": "dice_loss + 2 * l1_loss",
+ "components": {
+ "dice_loss": {
+ "class": "monai.losses.DiceLoss",
+ "params": {
+ "sigmoid": true,
+ "squared_pred": true,
+ "reduction": "mean"
+ }
+ },
+ "l1_loss": {
+ "class": "torch.nn.L1Loss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "bce_loss": {
+ "class": "torch.nn.functional.binary_cross_entropy_with_logits",
+ "params": {
+ "input": "outputs",
+ "target": "targets",
+ "reduction": "mean"
+ }
+ }
+ }
+ },
+ "metrics": [
+ {
+ "name": "dice_score",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "jaccard_index",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "hausdorff_distance",
+ "params": {
+ "threshold": 0.3
+ }
+ }
+ ],
+ "task_type": "segmentation",
+ "threshold": 0.3
+ }
+ }
+ ]
+}
diff --git a/scenarios/brats/config/templates/pipeline_config_template.json b/scenarios/brats/config/templates/pipeline_config_template.json
new file mode 100644
index 0000000..43e9e84
--- /dev/null
+++ b/scenarios/brats/config/templates/pipeline_config_template.json
@@ -0,0 +1,3 @@
+{
+ "pipeline": []
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/templates/train_config_template.json b/scenarios/brats/config/templates/train_config_template.json
new file mode 100644
index 0000000..a20bdcd
--- /dev/null
+++ b/scenarios/brats/config/templates/train_config_template.json
@@ -0,0 +1,34 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/brats_joined",
+ "saved_weights_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.5,
+ "delta": 0.005
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 1e-4
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 1e-4,
+ "max_lr": 1e-2,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1
+ }
+}
\ No newline at end of file
diff --git a/scenarios/brats/config/train_config.json b/scenarios/brats/config/train_config.json
new file mode 100644
index 0000000..40c3e06
--- /dev/null
+++ b/scenarios/brats/config/train_config.json
@@ -0,0 +1,670 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/brats_joined",
+ "saved_weights_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.5,
+ "delta": 0.005
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 0.0001,
+ "max_lr": 0.01,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1,
+ "model_config": {
+ "submodules": {
+ "ConvBlock2d": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$mid_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$mid_ch"
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$mid_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm2": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "conv2",
+ "norm2"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ },
+ "Upsample": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": "$in_ch",
+ "out_channels": "$out_ch",
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "norm1": {
+ "class": "nn.GroupNorm",
+ "params": {
+ "num_groups": 1,
+ "num_channels": "$out_ch"
+ }
+ }
+ },
+ "input": [
+ "x",
+ "encoded_feature"
+ ],
+ "forward": [
+ {
+ "ops": [
+ [
+ "F.interpolate",
+ {
+ "scale_factor": 2.0,
+ "mode": "bilinear",
+ "align_corners": false
+ }
+ ]
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv1",
+ "norm1"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ [
+ "torch.cat",
+ {
+ "dim": 1
+ }
+ ]
+ ],
+ "input": [
+ [
+ "x3",
+ "encoded_feature"
+ ]
+ ],
+ "output": "x4"
+ }
+ ],
+ "output": [
+ "x4"
+ ]
+ }
+ },
+ "layers": {
+ "in_conv": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 1,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "down_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 8,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "down_pool0": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 16,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "down_pool1": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "down_pool2": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "down_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "down_pool3": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "bottleneck": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 256,
+ "out_ch": 256
+ }
+ },
+ "up_samp3": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 256,
+ "out_ch": 128
+ }
+ },
+ "up_conv3": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 256,
+ "mid_ch": 128,
+ "out_ch": 128
+ }
+ },
+ "up_samp2": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 128,
+ "out_ch": 64
+ }
+ },
+ "up_conv2": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 128,
+ "mid_ch": 64,
+ "out_ch": 64
+ }
+ },
+ "up_samp1": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 64,
+ "out_ch": 32
+ }
+ },
+ "up_conv1": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 64,
+ "mid_ch": 32,
+ "out_ch": 32
+ }
+ },
+ "up_samp0": {
+ "submodule": "Upsample",
+ "params": {
+ "in_ch": 32,
+ "out_ch": 16
+ }
+ },
+ "up_conv0": {
+ "submodule": "ConvBlock2d",
+ "params": {
+ "in_ch": 32,
+ "mid_ch": 16,
+ "out_ch": 16
+ }
+ },
+ "out_conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 16,
+ "out_channels": 8,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ },
+ "out_conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 8,
+ "out_channels": 1,
+ "kernel_size": 3,
+ "stride": 1,
+ "padding": 1
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "in_conv"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "down_conv0"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "down_pool0"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "down_conv1"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "down_pool1"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "down_conv2"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ },
+ {
+ "ops": [
+ "down_pool2"
+ ],
+ "input": [
+ "x6"
+ ],
+ "output": "x7"
+ },
+ {
+ "ops": [
+ "down_conv3"
+ ],
+ "input": [
+ "x7"
+ ],
+ "output": "x8"
+ },
+ {
+ "ops": [
+ "down_pool3"
+ ],
+ "input": [
+ "x8"
+ ],
+ "output": "x9"
+ },
+ {
+ "ops": [
+ "bottleneck"
+ ],
+ "input": [
+ "x9"
+ ],
+ "output": "x10"
+ },
+ {
+ "ops": [
+ "up_samp3"
+ ],
+ "input": [
+ "x10",
+ "x8"
+ ],
+ "output": "x11"
+ },
+ {
+ "ops": [
+ "up_conv3"
+ ],
+ "input": [
+ "x11"
+ ],
+ "output": "x12"
+ },
+ {
+ "ops": [
+ "up_samp2"
+ ],
+ "input": [
+ "x12",
+ "x6"
+ ],
+ "output": "x13"
+ },
+ {
+ "ops": [
+ "up_conv2"
+ ],
+ "input": [
+ "x13"
+ ],
+ "output": "x14"
+ },
+ {
+ "ops": [
+ "up_samp1"
+ ],
+ "input": [
+ "x14",
+ "x4"
+ ],
+ "output": "x15"
+ },
+ {
+ "ops": [
+ "up_conv1"
+ ],
+ "input": [
+ "x15"
+ ],
+ "output": "x16"
+ },
+ {
+ "ops": [
+ "up_samp0"
+ ],
+ "input": [
+ "x16",
+ "x2"
+ ],
+ "output": "x17"
+ },
+ {
+ "ops": [
+ "up_conv0"
+ ],
+ "input": [
+ "x17"
+ ],
+ "output": "x18"
+ },
+ {
+ "ops": [
+ "out_conv1"
+ ],
+ "input": [
+ "x18"
+ ],
+ "output": "x19"
+ },
+ {
+ "ops": [
+ [
+ "F.leaky_relu",
+ {
+ "negative_slope": 0.1
+ }
+ ]
+ ],
+ "input": [
+ "x19"
+ ],
+ "output": "x20"
+ },
+ {
+ "ops": [
+ "out_conv2"
+ ],
+ "input": [
+ "x20"
+ ],
+ "output": "x21"
+ },
+ {
+ "ops": [
+ "torch.sigmoid"
+ ],
+ "input": [
+ "x21"
+ ],
+ "output": "x22"
+ }
+ ],
+ "output": [
+ "x22"
+ ]
+ },
+ "dataset_config": {
+ "type": "directory",
+ "structure_type": "paired",
+ "data_type": "image",
+ "pairing": {
+ "folder_pattern": "BraTS20_Training_*",
+ "input_pattern": "*_flair.png",
+ "target_pattern": "*_seg.png"
+ },
+ "image_config": {
+ "use_cv2": true,
+ "convert_to_pil": true,
+ "grayscale": true,
+ "to_tensor": true,
+ "binarize": true,
+ "binarize_threshold": 0
+ },
+ "filtering": {
+ "filter_empty_targets": true
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "expression": "dice_loss + 2 * l1_loss",
+ "components": {
+ "dice_loss": {
+ "class": "monai.losses.DiceLoss",
+ "params": {
+ "sigmoid": true,
+ "squared_pred": true,
+ "reduction": "mean"
+ }
+ },
+ "l1_loss": {
+ "class": "torch.nn.L1Loss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "bce_loss": {
+ "class": "torch.nn.functional.binary_cross_entropy_with_logits",
+ "params": {
+ "input": "outputs",
+ "target": "targets",
+ "reduction": "mean"
+ }
+ }
+ }
+ },
+ "metrics": [
+ {
+ "name": "dice_score",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "jaccard_index",
+ "params": {
+ "threshold": 0.3
+ }
+ },
+ {
+ "name": "hausdorff_distance",
+ "params": {
+ "threshold": 0.3
+ }
+ }
+ ],
+ "task_type": "segmentation",
+ "threshold": 0.3
+ }
+}
diff --git a/scenarios/brats/contract/contract.json b/scenarios/brats/contract/contract.json
new file mode 100644
index 0000000..0204000
--- /dev/null
+++ b/scenarios/brats/contract/contract.json
@@ -0,0 +1,114 @@
+{
+ "id": "f4f72a88-bab1-11ed-afa1-0242ac120002",
+ "schemaVersion": "0.1",
+ "startTime": "2023-03-14T00:00:00.000Z",
+ "expiryTime": "2024-03-14T00:00:00.000Z",
+ "tdc": "",
+ "tdps": [],
+ "ccrp": "did:web:$CCRP_USERNAME.github.io",
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "name": "brats_A",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BRATS_A_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BRATSAFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "name": "brats_B",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BRATS_B_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BRATSBFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "name": "brats_C",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BRATS_C_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BRATSCFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "2a3b0c4e-bab8-11ed-afa1-0242ac120002",
+ "name": "brats_D",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BRATS_D_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BRATSDFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ }
+ ],
+ "purpose": "TRAINING",
+ "constraints": [
+ {
+ "privacy": [
+ {
+ "dataset": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "1.5",
+ "noise_multiplier": "2.0",
+ "delta": "0.01",
+ "epochs_per_report": "2"
+ },
+ {
+ "dataset": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "1.5",
+ "noise_multiplier": "2.0",
+ "delta": "0.01",
+ "epochs_per_report": "2"
+ },
+ {
+ "dataset": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "1.5",
+ "noise_multiplier": "2.0",
+ "delta": "0.01",
+ "epochs_per_report": "2"
+ },
+ {
+ "dataset": "2a3b0c4e-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "1.5",
+ "noise_multiplier": "2.0",
+ "delta": "0.01",
+ "epochs_per_report": "2"
+ }
+ ]
+ }
+ ],
+ "terms": {
+ "payment": {},
+ "revocation": {}
+ }
+}
\ No newline at end of file
diff --git a/scenarios/brats/data/brats_A.tar.gz b/scenarios/brats/data/brats_A.tar.gz
new file mode 100644
index 0000000..1066587
Binary files /dev/null and b/scenarios/brats/data/brats_A.tar.gz differ
diff --git a/scenarios/brats/data/brats_B.tar.gz b/scenarios/brats/data/brats_B.tar.gz
new file mode 100644
index 0000000..53f3117
Binary files /dev/null and b/scenarios/brats/data/brats_B.tar.gz differ
diff --git a/scenarios/brats/data/brats_C.tar.gz b/scenarios/brats/data/brats_C.tar.gz
new file mode 100644
index 0000000..9d7584b
Binary files /dev/null and b/scenarios/brats/data/brats_C.tar.gz differ
diff --git a/scenarios/brats/data/brats_D.tar.gz b/scenarios/brats/data/brats_D.tar.gz
new file mode 100644
index 0000000..2f8026c
Binary files /dev/null and b/scenarios/brats/data/brats_D.tar.gz differ
diff --git a/scenarios/brats/deployment/azure/0-create-acr.sh b/scenarios/brats/deployment/azure/0-create-acr.sh
new file mode 100755
index 0000000..71372dc
--- /dev/null
+++ b/scenarios/brats/deployment/azure/0-create-acr.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+
+# Only to be run when creating a new ACR
+
+# Ensure required env vars are set
+if [[ -z "$CONTAINER_REGISTRY" || -z "$AZURE_RESOURCE_GROUP" || -z "$AZURE_LOCATION" ]]; then
+ echo "ERROR: CONTAINER_REGISTRY, AZURE_RESOURCE_GROUP, and AZURE_LOCATION environment variables must be set."
+ exit 1
+fi
+
+echo "Checking if ACR '$CONTAINER_REGISTRY' exists in resource group '$AZURE_RESOURCE_GROUP'..."
+
+# Check if ACR exists
+ACR_EXISTS=$(az acr show --name "$CONTAINER_REGISTRY" --resource-group "$AZURE_RESOURCE_GROUP" --query "name" -o tsv 2>/dev/null)
+
+if [[ -n "$ACR_EXISTS" ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' already exists."
+else
+ echo "⏳ ACR '$CONTAINER_REGISTRY' does not exist. Creating..."
+
+ az acr create \
+ --name "$CONTAINER_REGISTRY" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --location "$AZURE_LOCATION" \
+ --sku Premium \
+ --admin-enabled true \
+ --anonymous-pull-enabled true \
+ --output table
+
+ if [[ $? -eq 0 ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' created successfully."
+ else
+ echo "❌ Failed to create ACR."
+ exit 1
+ fi
+fi
+
+# Login to the ACR
+az acr login --name "$CONTAINER_REGISTRY"
\ No newline at end of file
diff --git a/scenarios/brats/deployment/azure/1-create-storage-containers.sh b/scenarios/brats/deployment/azure/1-create-storage-containers.sh
new file mode 100755
index 0000000..4a30d09
--- /dev/null
+++ b/scenarios/brats/deployment/azure/1-create-storage-containers.sh
@@ -0,0 +1,73 @@
+#!/bin/bash
+#
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+echo "Checking if storage account $AZURE_STORAGE_ACCOUNT_NAME exists..."
+STORAGE_ACCOUNT_EXISTS=$(az storage account check-name --name $AZURE_STORAGE_ACCOUNT_NAME --query "nameAvailable" --output tsv)
+
+if [ "$STORAGE_ACCOUNT_EXISTS" == "true" ]; then
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME does not exist. Creating it now..."
+ az storage account create --resource-group $AZURE_RESOURCE_GROUP --name $AZURE_STORAGE_ACCOUNT_NAME
+else
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME already exists. Skipping creation."
+fi
+
+# Get the storage account key
+ACCOUNT_KEY=$(az storage account keys list --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --query "[0].value" --output tsv)
+
+# Check if the BRATS-A container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BRATS_A_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BRATS_A_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BRATS_A_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the BRATS-B container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BRATS_B_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BRATS_B_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BRATS_B_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the BRATS-C container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BRATS_C_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BRATS_C_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BRATS_C_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the BRATS-D container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BRATS_D_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BRATS_D_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BRATS_D_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the MODEL container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_MODEL_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_MODEL_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_MODEL_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the OUTPUT container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_OUTPUT_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_OUTPUT_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_OUTPUT_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
diff --git a/scenarios/brats/deployment/azure/2-create-akv.sh b/scenarios/brats/deployment/azure/2-create-akv.sh
new file mode 100755
index 0000000..8a30dd5
--- /dev/null
+++ b/scenarios/brats/deployment/azure/2-create-akv.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
+ # Check if the Key Vault already exists
+  echo "Checking if Key Vault $AZURE_AKV_RESOURCE_NAME exists..."
+ NAME_AVAILABLE=$(az rest --method post \
+ --uri "https://management.azure.com/subscriptions/$AZURE_SUBSCRIPTION_ID/providers/Microsoft.KeyVault/checkNameAvailability?api-version=2019-09-01" \
+ --headers "Content-Type=application/json" \
+ --body "{\"name\": \"$AZURE_AKV_RESOURCE_NAME\", \"type\": \"Microsoft.KeyVault/vaults\"}" | jq -r '.nameAvailable')
+ if [ "$NAME_AVAILABLE" == true ]; then
+    echo "Key Vault $AZURE_AKV_RESOURCE_NAME does not exist. Creating it now..."
+    echo CREATING $AZURE_KEYVAULT_ENDPOINT in resource group $AZURE_RESOURCE_GROUP
+ # Create Azure key vault with RBAC authorization
+ az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
+ # Assign RBAC roles to the resource owner so they can import keys
+ AKV_SCOPE=`az keyvault show --name $AZURE_AKV_RESOURCE_NAME --query id --output tsv`
+ az role assignment create --role "Key Vault Crypto Officer" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ az role assignment create --role "Key Vault Crypto User" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ else
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME already exists. Skipping creation."
+ fi
+else
+  echo "Automated creation is supported only for Azure Key Vault endpoints (*.vault.azure.net)"
+fi
diff --git a/scenarios/brats/deployment/azure/3-import-keys.sh b/scenarios/brats/deployment/azure/3-import-keys.sh
new file mode 100755
index 0000000..951271e
--- /dev/null
+++ b/scenarios/brats/deployment/azure/3-import-keys.sh
@@ -0,0 +1,59 @@
+#!/bin/bash
+
+# Function to import a key with a given key ID and key material into AKV
+# The key is bound to a key release policy with host data defined in the environment variable CCE_POLICY_HASH
+function import_key() {
+ export KEYID=$1
+ export KEYFILE=$2
+
+ # For RSA-HSM keys, we need to set a salt and label which will be used in the symmetric key derivation
+ if [ "$AZURE_AKV_KEY_TYPE" = "RSA-HSM" ]; then
+ export AZURE_AKV_KEY_DERIVATION_LABEL=$KEYID
+ fi
+
+ CONFIG=$(jq '.claims[0][0].equals = env.CCE_POLICY_HASH' importkey-config-template.json)
+ CONFIG=$(echo $CONFIG | jq '.key.kid = env.KEYID')
+ CONFIG=$(echo $CONFIG | jq '.key.kty = env.AZURE_AKV_KEY_TYPE')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.label = env.AZURE_AKV_KEY_DERIVATION_LABEL')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.bearer_token = env.BEARER_TOKEN')
+ echo $CONFIG > /tmp/importkey-config.json
+ echo "Importing $KEYID key with key release policy"
+ jq '.key.akv.bearer_token = "REDACTED"' /tmp/importkey-config.json
+ pushd . && cd $TOOLS_HOME/importkey && go run main.go -c /tmp/importkey-config.json -out && popd
+ mv $TOOLS_HOME/importkey/keyfile.bin $KEYFILE
+}
+
+echo Obtaining contract service parameters...
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"http://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+# Obtain the token based on the AKV resource endpoint subdomain
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+  echo "Keys imported to AKV key vaults must be of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+fi
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+import_key "BRATSAFilesystemEncryptionKey" $DATADIR/brats_A_key.bin
+import_key "BRATSBFilesystemEncryptionKey" $DATADIR/brats_B_key.bin
+import_key "BRATSCFilesystemEncryptionKey" $DATADIR/brats_C_key.bin
+import_key "BRATSDFilesystemEncryptionKey" $DATADIR/brats_D_key.bin
+import_key "ModelFilesystemEncryptionKey" $MODELDIR/model_key.bin
+import_key "OutputFilesystemEncryptionKey" $MODELDIR/output_key.bin
+
+## Cleanup
+rm /tmp/importkey-config.json
+rm /tmp/policy-in.json
diff --git a/scenarios/brats/deployment/azure/4-encrypt-data.sh b/scenarios/brats/deployment/azure/4-encrypt-data.sh
new file mode 100755
index 0000000..a01a01e
--- /dev/null
+++ b/scenarios/brats/deployment/azure/4-encrypt-data.sh
@@ -0,0 +1,14 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+./generatefs.sh -d $DATADIR/brats_A/preprocessed -k $DATADIR/brats_A_key.bin -i $DATADIR/brats_A.img
+./generatefs.sh -d $DATADIR/brats_B/preprocessed -k $DATADIR/brats_B_key.bin -i $DATADIR/brats_B.img
+./generatefs.sh -d $DATADIR/brats_C/preprocessed -k $DATADIR/brats_C_key.bin -i $DATADIR/brats_C.img
+./generatefs.sh -d $DATADIR/brats_D/preprocessed -k $DATADIR/brats_D_key.bin -i $DATADIR/brats_D.img
+./generatefs.sh -d $MODELDIR/models -k $MODELDIR/model_key.bin -i $MODELDIR/model.img
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+./generatefs.sh -d $MODELDIR/output -k $MODELDIR/output_key.bin -i $MODELDIR/output.img
\ No newline at end of file
diff --git a/scenarios/brats/deployment/azure/5-upload-encrypted-data.sh b/scenarios/brats/deployment/azure/5-upload-encrypted-data.sh
new file mode 100755
index 0000000..7b56305
--- /dev/null
+++ b/scenarios/brats/deployment/azure/5-upload-encrypted-data.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BRATS_A_CONTAINER_NAME \
+ --file $DATADIR/brats_A.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BRATS_B_CONTAINER_NAME \
+ --file $DATADIR/brats_B.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BRATS_C_CONTAINER_NAME \
+ --file $DATADIR/brats_C.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BRATS_D_CONTAINER_NAME \
+ --file $DATADIR/brats_D.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_MODEL_CONTAINER_NAME \
+ --file $MODELDIR/model.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODELDIR/output.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
diff --git a/scenarios/covid/data/6-download-decrypt-model.sh b/scenarios/brats/deployment/azure/6-download-decrypt-model.sh
similarity index 75%
rename from scenarios/covid/data/6-download-decrypt-model.sh
rename to scenarios/brats/deployment/azure/6-download-decrypt-model.sh
index 0975334..b6d043a 100755
--- a/scenarios/covid/data/6-download-decrypt-model.sh
+++ b/scenarios/brats/deployment/azure/6-download-decrypt-model.sh
@@ -1,16 +1,21 @@
#!/bin/bash
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+
ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
az storage blob download \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_OUTPUT_CONTAINER_NAME \
- --file output.img \
+ --file $MODELDIR/output.img \
--name data.img \
--account-key $ACCOUNT_KEY
-encryptedImage=output.img
-keyFilePath=outputkey.bin
+encryptedImage=$MODELDIR/output.img
+keyFilePath=$MODELDIR/output_key.bin
echo Decrypting $encryptedImage with key $keyFilePath
deviceName=cryptdevice1
@@ -23,7 +28,7 @@ sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
mountPoint=`mktemp -d`
sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
-cp -r $mountPoint/* ./output/
+cp -r $mountPoint/* $MODELDIR/output/
echo "[!] Closing device..."
diff --git a/scenarios/covid/deployment/aci/aci-parameters-template.json b/scenarios/brats/deployment/azure/aci-parameters-template.json
similarity index 100%
rename from scenarios/covid/deployment/aci/aci-parameters-template.json
rename to scenarios/brats/deployment/azure/aci-parameters-template.json
diff --git a/scenarios/brats/deployment/azure/arm-template.json b/scenarios/brats/deployment/azure/arm-template.json
new file mode 100644
index 0000000..47f4e68
--- /dev/null
+++ b/scenarios/brats/deployment/azure/arm-template.json
@@ -0,0 +1,181 @@
+{
+ "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
+ "contentVersion": "1.0.0.0",
+ "parameters": {
+ "name": {
+ "defaultValue": "depa-training-brats",
+ "type": "string",
+ "metadata": {
+ "description": "Name for the container group"
+ }
+ },
+ "location": {
+ "defaultValue": "northeurope",
+ "type": "string",
+ "metadata": {
+ "description": "Location for all resources."
+ }
+ },
+ "port": {
+ "defaultValue": 8080,
+ "type": "int",
+ "metadata": {
+ "description": "Port to open on the container and the public IP address."
+ }
+ },
+ "containerRegistry": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "The container registry login server."
+ }
+ },
+ "restartPolicy": {
+ "defaultValue": "Never",
+ "allowedValues": [
+ "Always",
+ "Never",
+ "OnFailure"
+ ],
+ "type": "string",
+ "metadata": {
+ "description": "The behavior of Azure runtime if container has stopped."
+ }
+ },
+ "ccePolicy": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "cce policy"
+ }
+ },
+ "EncfsSideCarArgs": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Remote file system information for storage sidecar."
+ }
+ },
+ "ContractService": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "URL of contract service"
+ }
+ },
+ "Contracts": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "List of contracts"
+ }
+ },
+ "ContractServiceParameters": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Contract service parameters"
+ }
+ },
+ "PipelineConfiguration": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Pipeline configuration"
+ }
+ }
+ },
+ "resources": [
+ {
+ "type": "Microsoft.ContainerInstance/containerGroups",
+ "apiVersion": "2023-05-01",
+ "name": "[parameters('name')]",
+ "location": "[parameters('location')]",
+ "properties": {
+ "confidentialComputeProperties": {
+ "ccePolicy": "[parameters('ccePolicy')]"
+ },
+ "containers": [
+ {
+ "name": "depa-training",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training:latest')]",
+ "command": [
+ "/bin/bash",
+ "run.sh"
+ ],
+ "environmentVariables": [],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "resources": {
+ "requests": {
+ "cpu": 3,
+ "memoryInGB": 12
+ }
+ }
+ }
+ },
+ {
+ "name": "encrypted-storage-sidecar",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training-encfs:latest')]",
+ "command": [
+ "/encfs.sh"
+ ],
+ "environmentVariables": [
+ {
+ "name": "EncfsSideCarArgs",
+ "value": "[parameters('EncfsSideCarArgs')]"
+ },
+ {
+ "name": "ContractService",
+ "value": "[parameters('ContractService')]"
+ },
+ {
+ "name": "Contracts",
+ "value": "[parameters('Contracts')]"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "[parameters('ContractServiceParameters')]"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": "[parameters('PipelineConfiguration')]"
+ }
+ ],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "securityContext": {
+ "privileged": "true"
+ },
+ "resources": {
+ "requests": {
+ "cpu": 0.5,
+ "memoryInGB": 2
+ }
+ }
+ }
+ }
+ ],
+ "sku": "Confidential",
+ "osType": "Linux",
+ "restartPolicy": "[parameters('restartPolicy')]",
+ "volumes": [
+ {
+ "name": "remotemounts",
+ "emptydir": {}
+ }
+ ]
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/brats/deployment/azure/deploy.sh b/scenarios/brats/deployment/azure/deploy.sh
new file mode 100755
index 0000000..63b3087
--- /dev/null
+++ b/scenarios/brats/deployment/azure/deploy.sh
@@ -0,0 +1,173 @@
+#!/bin/bash
+
+set -e
+
+while getopts ":c:p:" options; do
+ case $options in
+ c)contract=$OPTARG;;
+ p)pipelineConfiguration=$OPTARG;;
+ esac
+done
+
+if [[ -z "${contract}" ]]; then
+ echo "No contract specified"
+ exit 1
+fi
+
+if [[ -z "${pipelineConfiguration}" ]]; then
+ echo "No pipeline configuration specified"
+ exit 1
+fi
+
+if [[ -z "${AZURE_KEYVAULT_ENDPOINT}" ]]; then
+ echo "Environment variable AZURE_KEYVAULT_ENDPOINT not defined"
+fi
+
+echo Obtaining contract service parameters...
+
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"https://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+echo Computing CCE policy...
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+export CONTRACTS=$contract
+export PIPELINE_CONFIGURATION=`cat $pipelineConfiguration | base64 --wrap=0`
+
+function generate_encrypted_filesystem_information() {
+ end=`date -u -d "60 minutes" '+%Y-%m-%dT%H:%MZ'`
+ BRATS_A_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BRATS_A_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BRATS_A_SAS_TOKEN="$(echo -n $BRATS_A_SAS_TOKEN | tr -d \")"
+ export BRATS_A_SAS_TOKEN="?$BRATS_A_SAS_TOKEN"
+
+ BRATS_B_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BRATS_B_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BRATS_B_SAS_TOKEN=$(echo $BRATS_B_SAS_TOKEN | tr -d \")
+ export BRATS_B_SAS_TOKEN="?$BRATS_B_SAS_TOKEN"
+
+ BRATS_C_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BRATS_C_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BRATS_C_SAS_TOKEN=$(echo $BRATS_C_SAS_TOKEN | tr -d \")
+ export BRATS_C_SAS_TOKEN="?$BRATS_C_SAS_TOKEN"
+
+ BRATS_D_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BRATS_D_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BRATS_D_SAS_TOKEN=$(echo $BRATS_D_SAS_TOKEN | tr -d \")
+ export BRATS_D_SAS_TOKEN="?$BRATS_D_SAS_TOKEN"
+
+ MODEL_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_MODEL_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export MODEL_SAS_TOKEN=$(echo $MODEL_SAS_TOKEN | tr -d \")
+ export MODEL_SAS_TOKEN="?$MODEL_SAS_TOKEN"
+
+ OUTPUT_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_OUTPUT_CONTAINER_NAME --permissions rw --name data.img --expiry $end --only-show-errors)
+ export OUTPUT_SAS_TOKEN=$(echo $OUTPUT_SAS_TOKEN | tr -d \")
+ export OUTPUT_SAS_TOKEN="?$OUTPUT_SAS_TOKEN"
+
+ # Obtain the token based on the AKV resource endpoint subdomain
+ if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+    echo "Keys imported to AKV key vaults must be of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+ elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+ fi
+
+ TMP=$(jq . encrypted-filesystem-config-template.json)
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[0].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BRATS_A_CONTAINER_NAME + "/data.img" + env.BRATS_A_SAS_TOKEN' | \
+ jq '.azure_filesystems[0].mount_point = "/mnt/remote/brats_A"' | \
+ jq '.azure_filesystems[0].key.kid = "BRATSAFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[0].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[0].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[0].key_derivation.label = "BRATSAFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[1].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BRATS_B_CONTAINER_NAME + "/data.img" + env.BRATS_B_SAS_TOKEN' | \
+ jq '.azure_filesystems[1].mount_point = "/mnt/remote/brats_B"' | \
+ jq '.azure_filesystems[1].key.kid = "BRATSBFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[1].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[1].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[1].key_derivation.label = "BRATSBFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[2].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BRATS_C_CONTAINER_NAME + "/data.img" + env.BRATS_C_SAS_TOKEN' | \
+ jq '.azure_filesystems[2].mount_point = "/mnt/remote/brats_C"' | \
+ jq '.azure_filesystems[2].key.kid = "BRATSCFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[2].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[2].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[2].key_derivation.label = "BRATSCFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[3].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BRATS_D_CONTAINER_NAME + "/data.img" + env.BRATS_D_SAS_TOKEN' | \
+ jq '.azure_filesystems[3].mount_point = "/mnt/remote/brats_D"' | \
+ jq '.azure_filesystems[3].key.kid = "BRATSDFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[3].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[3].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[3].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[3].key_derivation.label = "BRATSDFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[3].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[4].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_MODEL_CONTAINER_NAME + "/data.img" + env.MODEL_SAS_TOKEN' | \
+ jq '.azure_filesystems[4].mount_point = "/mnt/remote/model"' | \
+ jq '.azure_filesystems[4].key.kid = "ModelFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[4].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[4].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[4].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[4].key_derivation.label = "ModelFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[4].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[5].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_OUTPUT_CONTAINER_NAME + "/data.img" + env.OUTPUT_SAS_TOKEN' | \
+ jq '.azure_filesystems[5].mount_point = "/mnt/remote/output"' | \
+ jq '.azure_filesystems[5].key.kid = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[5].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[5].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[5].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[5].key_derivation.label = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[5].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ ENCRYPTED_FILESYSTEM_INFORMATION=`echo $TMP | base64 --wrap=0`
+}
+
+echo Generating encrypted file system information...
+generate_encrypted_filesystem_information
+echo $ENCRYPTED_FILESYSTEM_INFORMATION > /tmp/encrypted-filesystem-config.json
+export ENCRYPTED_FILESYSTEM_INFORMATION
+
+echo Generating parameters for ACI deployment...
+TMP=$(jq '.containerRegistry.value = env.CONTAINER_REGISTRY' aci-parameters-template.json)
+TMP=`echo $TMP | jq '.ccePolicy.value = env.CCE_POLICY'`
+TMP=`echo $TMP | jq '.EncfsSideCarArgs.value = env.ENCRYPTED_FILESYSTEM_INFORMATION'`
+TMP=`echo $TMP | jq '.ContractService.value = env.CONTRACT_SERVICE_URL'`
+TMP=`echo $TMP | jq '.ContractServiceParameters.value = env.CONTRACT_SERVICE_PARAMETERS'`
+TMP=`echo $TMP | jq '.Contracts.value = env.CONTRACTS'`
+TMP=`echo $TMP | jq '.PipelineConfiguration.value = env.PIPELINE_CONFIGURATION'`
+echo $TMP > /tmp/aci-parameters.json
+
+echo Deploying training clean room...
+
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+az deployment group create \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --template-file arm-template.json \
+ --parameters @/tmp/aci-parameters.json
+
+echo Deployment complete.
diff --git a/scenarios/brats/deployment/azure/encrypted-filesystem-config-template.json b/scenarios/brats/deployment/azure/encrypted-filesystem-config-template.json
new file mode 100644
index 0000000..d000326
--- /dev/null
+++ b/scenarios/brats/deployment/azure/encrypted-filesystem-config-template.json
@@ -0,0 +1,142 @@
+{
+ "azure_filesystems": [
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": true,
+ "mount_point": "",
+ "key":{
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/covid/data/generatefs.sh b/scenarios/brats/deployment/azure/generatefs.sh
similarity index 100%
rename from scenarios/covid/data/generatefs.sh
rename to scenarios/brats/deployment/azure/generatefs.sh
diff --git a/scenarios/covid/data/importkey-config-template.json b/scenarios/brats/deployment/azure/importkey-config-template.json
similarity index 100%
rename from scenarios/covid/data/importkey-config-template.json
rename to scenarios/brats/deployment/azure/importkey-config-template.json
diff --git a/scenarios/brats/deployment/local/docker-compose-modelsave.yml b/scenarios/brats/deployment/local/docker-compose-modelsave.yml
new file mode 100644
index 0000000..60a3ed7
--- /dev/null
+++ b/scenarios/brats/deployment/local/docker-compose-modelsave.yml
@@ -0,0 +1,7 @@
+services:
+ model_save:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}brats-model-save:latest
+ volumes:
+ - $MODEL_OUTPUT_PATH:/mnt/model
+ - $MODEL_CONFIG_PATH:/mnt/config/model_config.json
+ command: ["python3", "save_base_model.py"]
diff --git a/scenarios/brats/deployment/local/docker-compose-preprocess.yml b/scenarios/brats/deployment/local/docker-compose-preprocess.yml
new file mode 100644
index 0000000..1bc5fe5
--- /dev/null
+++ b/scenarios/brats/deployment/local/docker-compose-preprocess.yml
@@ -0,0 +1,25 @@
+services:
+ brats_A:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-brats-a:latest
+ volumes:
+ - $BRATS_A_INPUT_PATH:/mnt/input/data
+ - $BRATS_A_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_brats_A.py"]
+ brats_B:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-brats-b:latest
+ volumes:
+ - $BRATS_B_INPUT_PATH:/mnt/input/data
+ - $BRATS_B_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_brats_B.py"]
+ brats_C:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-brats-c:latest
+ volumes:
+ - $BRATS_C_INPUT_PATH:/mnt/input/data
+ - $BRATS_C_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_brats_C.py"]
+ brats_D:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-brats-d:latest
+ volumes:
+ - $BRATS_D_INPUT_PATH:/mnt/input/data
+ - $BRATS_D_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_brats_D.py"]
diff --git a/scenarios/brats/deployment/local/docker-compose-train.yml b/scenarios/brats/deployment/local/docker-compose-train.yml
new file mode 100644
index 0000000..58429d5
--- /dev/null
+++ b/scenarios/brats/deployment/local/docker-compose-train.yml
@@ -0,0 +1,13 @@
+services:
+ train:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-training:latest
+ volumes:
+ - $BRATS_A_INPUT_PATH:/mnt/remote/brats_A
+ - $BRATS_B_INPUT_PATH:/mnt/remote/brats_B
+ - $BRATS_C_INPUT_PATH:/mnt/remote/brats_C
+ - $BRATS_D_INPUT_PATH:/mnt/remote/brats_D
+ - $MODEL_INPUT_PATH:/mnt/remote/model
+ - $MODEL_OUTPUT_PATH:/mnt/remote/output
+ - $CONFIGURATION_PATH:/mnt/remote/config
+ command: ["/bin/bash", "run.sh"]
+
\ No newline at end of file
diff --git a/scenarios/brats/deployment/local/preprocess.sh b/scenarios/brats/deployment/local/preprocess.sh
new file mode 100755
index 0000000..8074e3e
--- /dev/null
+++ b/scenarios/brats/deployment/local/preprocess.sh
@@ -0,0 +1,21 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO=brats
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+
+tar -xzf $DATA_DIR/brats_A.tar.gz -C $DATA_DIR/
+tar -xzf $DATA_DIR/brats_B.tar.gz -C $DATA_DIR/
+tar -xzf $DATA_DIR/brats_C.tar.gz -C $DATA_DIR/
+tar -xzf $DATA_DIR/brats_D.tar.gz -C $DATA_DIR/
+
+export BRATS_A_INPUT_PATH=$DATA_DIR/brats_A/
+export BRATS_A_OUTPUT_PATH=$DATA_DIR/brats_A/preprocessed
+export BRATS_B_INPUT_PATH=$DATA_DIR/brats_B/
+export BRATS_B_OUTPUT_PATH=$DATA_DIR/brats_B/preprocessed
+export BRATS_C_INPUT_PATH=$DATA_DIR/brats_C/
+export BRATS_C_OUTPUT_PATH=$DATA_DIR/brats_C/preprocessed
+export BRATS_D_INPUT_PATH=$DATA_DIR/brats_D/
+export BRATS_D_OUTPUT_PATH=$DATA_DIR/brats_D/preprocessed
+
+docker compose -f docker-compose-preprocess.yml up --remove-orphans
diff --git a/scenarios/brats/deployment/local/save-model.sh b/scenarios/brats/deployment/local/save-model.sh
new file mode 100755
index 0000000..090e90a
--- /dev/null
+++ b/scenarios/brats/deployment/local/save-model.sh
@@ -0,0 +1,8 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO=brats
+export MODEL_OUTPUT_PATH=$REPO_ROOT/scenarios/$SCENARIO/modeller/models
+mkdir -p $MODEL_OUTPUT_PATH
+export MODEL_CONFIG_PATH=$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json
+docker compose -f docker-compose-modelsave.yml up --remove-orphans
\ No newline at end of file
diff --git a/scenarios/brats/deployment/local/train.sh b/scenarios/brats/deployment/local/train.sh
new file mode 100755
index 0000000..97a2990
--- /dev/null
+++ b/scenarios/brats/deployment/local/train.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO=brats
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+export BRATS_A_INPUT_PATH=$DATA_DIR/brats_A/preprocessed
+export BRATS_B_INPUT_PATH=$DATA_DIR/brats_B/preprocessed
+export BRATS_C_INPUT_PATH=$DATA_DIR/brats_C/preprocessed
+export BRATS_D_INPUT_PATH=$DATA_DIR/brats_D/preprocessed
+
+export MODEL_INPUT_PATH=$MODEL_DIR/models
+
+# export MODEL_OUTPUT_PATH=/tmp/output
+export MODEL_OUTPUT_PATH=$MODEL_DIR/output
+sudo rm -rf $MODEL_OUTPUT_PATH
+mkdir -p $MODEL_OUTPUT_PATH
+
+# export CONFIGURATION_PATH=/tmp
+export CONFIGURATION_PATH=$REPO_ROOT/scenarios/$SCENARIO/config
+# cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
+
+# Run consolidate_pipeline.sh to create pipeline_config.json
+$REPO_ROOT/scenarios/$SCENARIO/config/consolidate_pipeline.sh
+
+docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/brats/export-variables.sh b/scenarios/brats/export-variables.sh
new file mode 100755
index 0000000..7eeeff7
--- /dev/null
+++ b/scenarios/brats/export-variables.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+
+# Azure Naming Rules:
+#
+# Resource Group:
+# - 1–90 characters
+# - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+# - Cannot end with a period (.)
+# - Case-insensitive, unique within subscription
+#
+# Key Vault:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with letter or number
+#
+# Storage Account:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters and numbers only
+#
+# Storage Container:
+# - 3-63 characters
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with a letter or number
+# - No consecutive hyphens
+# - Unique within storage account
+
+# For cloud resource creation:
+declare -x SCENARIO=brats
+declare -x REPO_ROOT="$(git rev-parse --show-toplevel)"
+declare -x CONTAINER_REGISTRY=ispirt.azurecr.io
+declare -x AZURE_LOCATION=centralindia
+declare -x AZURE_SUBSCRIPTION_ID=
+declare -x AZURE_RESOURCE_GROUP=
+declare -x AZURE_KEYVAULT_ENDPOINT=
+declare -x AZURE_STORAGE_ACCOUNT_NAME=
+
+declare -x AZURE_BRATS_A_CONTAINER_NAME=bratsacontainer
+declare -x AZURE_BRATS_B_CONTAINER_NAME=bratsbcontainer
+declare -x AZURE_BRATS_C_CONTAINER_NAME=bratsccontainer
+declare -x AZURE_BRATS_D_CONTAINER_NAME=bratsdcontainer
+declare -x AZURE_MODEL_CONTAINER_NAME=modelcontainer
+declare -x AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+
+# For key import:
+declare -x CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+declare -x TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+# Export all variables to make them available to other scripts
+export SCENARIO
+export REPO_ROOT
+export CONTAINER_REGISTRY
+export AZURE_LOCATION
+export AZURE_SUBSCRIPTION_ID
+export AZURE_RESOURCE_GROUP
+export AZURE_KEYVAULT_ENDPOINT
+export AZURE_STORAGE_ACCOUNT_NAME
+export AZURE_BRATS_A_CONTAINER_NAME
+export AZURE_BRATS_B_CONTAINER_NAME
+export AZURE_BRATS_C_CONTAINER_NAME
+export AZURE_BRATS_D_CONTAINER_NAME
+export AZURE_MODEL_CONTAINER_NAME
+export AZURE_OUTPUT_CONTAINER_NAME
+export CONTRACT_SERVICE_URL
+export TOOLS_HOME
\ No newline at end of file
diff --git a/scenarios/brats/policy/policy-in-template.json b/scenarios/brats/policy/policy-in-template.json
new file mode 100644
index 0000000..3f8567f
--- /dev/null
+++ b/scenarios/brats/policy/policy-in-template.json
@@ -0,0 +1,58 @@
+{
+ "version": "1.0",
+ "containers": [
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training:latest",
+ "command": ["/bin/bash", "run.sh"],
+ "environmentVariables": [],
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ },
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training-encfs:latest",
+ "environmentVariables": [
+ {
+ "name": "EncfsSideCarArgs",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "ContractService",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "$CONTRACT_SERVICE_PARAMETERS",
+ "strategy": "string"
+ },
+ {
+ "name": "Contracts",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": ".+",
+ "strategy": "re2"
+ }
+ ],
+ "command": ["/encfs.sh"],
+ "securityContext": {
+ "privileged": "true"
+ },
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ }
+ ]
+}
diff --git a/scenarios/brats/src/model_constructor.py b/scenarios/brats/src/model_constructor.py
new file mode 100644
index 0000000..1d01585
--- /dev/null
+++ b/scenarios/brats/src/model_constructor.py
@@ -0,0 +1,362 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+
+from typing import Any, Dict, List, Tuple, Callable
+import types
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+APPROVED_NAMESPACES = {
+ "torch": torch,
+ "nn": nn,
+ "F": F,
+}
+
+# Security controls
+ALLOWED_OP_PREFIXES = {"F.", "torch.nn.functional."} # Allow only torch.nn.functional.* by default
+ALLOWED_OPS = {
+ "torch.cat", "torch.stack", "torch.concat", "torch.flatten", "torch.reshape", "torch.permute", "torch.transpose",
+ "torch.unsqueeze", "torch.squeeze", "torch.chunk", "torch.split", "torch.gather", "torch.index_select", "torch.narrow",
+ "torch.sum", "torch.mean", "torch.std", "torch.var", "torch.max", "torch.min", "torch.argmax", "torch.argmin", "torch.norm",
+ "torch.exp", "torch.log", "torch.log1p", "torch.sigmoid", "torch.tanh", "torch.softmax", "torch.log_softmax", "torch.relu", "torch.gelu",
+ "torch.matmul", "torch.mm", "torch.bmm", "torch.addmm", "torch.einsum",
+    "torch.roll", "torch.flip", "torch.rot90",  # NOTE: torch.rot180/270/360 do not exist; use torch.rot90 with k=2/3
+}
+
+# Denylist of potentially dangerous kwarg names (case-insensitive)
+DENYLIST_ARG_NAMES = {
+ "out", # in-place writes to user-provided buffers
+ "file", "filename", "path", "dir", "directory", # filesystem
+ "map_location", # avoid device remap surprises
+}
+
+# DoS safeguards
+MAX_FORWARD_STEPS = 200
+MAX_OPS_PER_STEP = 10
+
+
+def _resolve_submodule(path: str) -> Any:
+ """Resolve dotted path like 'nn.Conv2d' or 'torch.sigmoid' to an object.
+ Raises AttributeError if resolution fails.
+ """
+ try:
+ if not isinstance(path, str):
+ raise TypeError("path must be a string")
+ parts = path.split(".")
+ if parts[0] in APPROVED_NAMESPACES:
+ obj = APPROVED_NAMESPACES[parts[0]]
+ else:
+            # reject any root namespace outside APPROVED_NAMESPACES
+ raise AttributeError(f"Unknown root namespace '{parts[0]}' in path '{path}'")
+ for p in parts[1:]:
+ try:
+ obj = getattr(obj, p)
+ except AttributeError:
+ raise AttributeError(f"Could not resolve attribute '{p}' in path '{path}'")
+ return obj
+ except Exception as e:
+ raise RuntimeError(f"Error resolving dotted path '{path}': {str(e)}") from e
+
+
+def _replace_placeholders(obj: Any, params: Dict[str, Any]) -> Any:
+ """Recursively replace strings of the form '$name' using params mapping."""
+ try:
+ if isinstance(obj, str) and obj.startswith("$"):
+ key = obj[1:]
+ if key not in params:
+ raise KeyError(f"Placeholder '{obj}' not found in params {params}")
+ return params[key]
+ elif isinstance(obj, dict):
+ try:
+ return {k: _replace_placeholders(v, params) for k, v in obj.items()}
+ except Exception as e:
+ raise RuntimeError(f"Error replacing placeholders in dict: {str(e)}") from e
+ elif isinstance(obj, (list, tuple)):
+ try:
+ seq_type = list if isinstance(obj, list) else tuple
+ return seq_type(_replace_placeholders(x, params) for x in obj)
+ except Exception as e:
+ raise RuntimeError(f"Error replacing placeholders in sequence: {str(e)}") from e
+ else:
+ return obj
+ except Exception as e:
+ raise RuntimeError(f"Error in placeholder replacement: {str(e)}") from e
+
+
+class ModelFactory:
+ """Factory for building PyTorch nn.Module instances from config dicts.
+
+ Public API:
+ ModelFactory.load_from_dict(config: dict) -> nn.Module
+ """
+
+ @classmethod
+ def load_from_dict(cls, config: Dict[str, Any]) -> nn.Module:
+ """Create an nn.Module instance from a top-level config.
+
+ The config may define 'submodules' (a dict of reusable component templates) and
+ a top-level 'layers' and 'forward' graph. Submodules are used by layers that have
+ a 'submodule' key and are instantiated with their provided params.
+ """
+ try:
+ if not isinstance(config, dict):
+ raise TypeError("Config must be a dictionary")
+
+ submodules_defs = config.get("submodules", {})
+
+ def create_instance_from_def(def_cfg: Dict[str, Any], provided_params: Dict[str, Any]):
+ try:
+ # Replace placeholders in the def_cfg copy
+ # Deep copy not strictly necessary since we replace on the fly
+ replaced_cfg = {
+ k: (_replace_placeholders(v, provided_params) if k in ("layers",) or isinstance(v, dict) else v)
+ for k, v in def_cfg.items()
+ }
+ # Build module from replaced config (submodule templates should not themselves contain further 'submodules')
+ return cls._build_module_from_config(replaced_cfg, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error creating instance from definition: {str(e)}") from e
+
+ # When a layer entry references a 'submodule', we instantiate it using template from submodules_defs
+ return cls._build_module_from_config(config, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error loading model from config: {str(e)}") from e
+
+ @classmethod
+ def _build_module_from_config(cls, config: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+ try:
+ layers_cfg = config.get("layers", {})
+ forward_cfg = config.get("forward", [])
+ input_names = config.get("input", [])
+ output_names = config.get("output", [])
+
+ # Create dynamic module class
+ class DynamicModule(nn.Module):
+ def __init__(self):
+ try:
+ super().__init__()
+ # ModuleDict to register submodules / layers
+ self._layers = nn.ModuleDict()
+ # Save forward graph and io names
+ self._forward_cfg = forward_cfg
+ self._input_names = input_names
+ self._output_names = output_names
+
+ # Build each layer / submodule
+ for name, entry in layers_cfg.items():
+ try:
+ if "class" in entry:
+ cls_obj = _resolve_submodule(entry["class"]) # e.g. nn.Conv2d
+ if not (isinstance(cls_obj, type) and issubclass(cls_obj, nn.Module)):
+ raise TypeError(f"Layer '{name}' class must be an nn.Module subclass, got {cls_obj}")
+ params = entry.get("params", {})
+ inst_params = _replace_placeholders(params, {}) # top-level layers likely have no placeholders
+ module = cls_obj(**inst_params)
+ self._layers[name] = module
+ elif "submodule" in entry:
+ sub_name = entry["submodule"]
+ if sub_name not in submodules_defs:
+ raise KeyError(f"Submodule '{sub_name}' not found in submodules definitions")
+ sub_def = submodules_defs[sub_name]
+ provided_params = entry.get("params", {})
+ # Replace placeholders inside sub_def using provided_params
+ # We create a fresh instance of submodule by calling helper
+ sub_inst = cls._instantiate_submodule(sub_def, provided_params, submodules_defs)
+ self._layers[name] = sub_inst
+ else:
+ raise KeyError(f"Layer '{name}' must contain either 'class' or 'submodule' key")
+ except Exception as e:
+ raise RuntimeError(f"Error building layer '{name}': {str(e)}") from e
+ except Exception as e:
+ raise RuntimeError(f"Error initializing DynamicModule: {str(e)}") from e
+
+ def forward(self, *args, **kwargs):
+ try:
+ # Map inputs
+ env: Dict[str, Any] = {}
+ # assign by position
+ for i, in_name in enumerate(self._input_names):
+ if i < len(args):
+ env[in_name] = args[i]
+ elif in_name in kwargs:
+ env[in_name] = kwargs[in_name]
+ else:
+ raise ValueError(f"Missing input '{in_name}' for forward; provided args={len(args)}, kwargs keys={list(kwargs.keys())}")
+
+ # Execute forward graph
+ if len(self._forward_cfg) > MAX_FORWARD_STEPS:
+ raise RuntimeError(f"Too many forward steps: {len(self._forward_cfg)} > {MAX_FORWARD_STEPS}. This is a security feature to prevent infinite loops.")
+
+ for idx, step in enumerate(self._forward_cfg):
+ try:
+ ops = step.get("ops", [])
+ if isinstance(ops, (list, tuple)) and len(ops) > MAX_OPS_PER_STEP:
+ raise RuntimeError(f"Too many ops in step {idx}: {len(ops)} > {MAX_OPS_PER_STEP}")
+ inputs_spec = step.get("input", [])
+ out_name = step.get("output", None)
+
+ # Resolve input tensors for this step
+ # inputs_spec might be: ['x'] or ['x1','x2'] or [['x3','encoded_feature']]
+ if len(inputs_spec) == 1 and isinstance(inputs_spec[0], (list, tuple)):
+ args_list = [env[n] for n in inputs_spec[0]]
+ else:
+ args_list = [env[n] for n in inputs_spec]
+
+ # Apply ops sequentially
+ current = args_list
+ for op in ops:
+ try:
+ # op can be string like 'conv1' or dotted 'F.relu'
+ # or can be a list like ['torch.flatten', {'start_dim':1}]
+ op_callable, op_kwargs = self._resolve_op(op)
+ # Validate kwargs denylist
+ for k in op_kwargs.keys():
+ if isinstance(k, str) and k.lower() in DENYLIST_ARG_NAMES:
+ raise PermissionError(f"Denied kwarg '{k}' for op '{op}'")
+
+ # If op_callable is a module in self._layers, call with module semantics
+ if isinstance(op_callable, str) and op_callable in self._layers:
+ module = self._layers[op_callable]
+ # if current is list of multiple args, pass them all
+ if isinstance(current, (list, tuple)) and len(current) > 1:
+ result = module(*current)
+ else:
+ result = module(current[0])
+ else:
+ # op_callable is a real callable object
+
+ if op_callable in {torch.cat, torch.stack}: # Ops that require a sequence input (instead of varargs)
+ # Wrap current into a list
+ result = op_callable(list(current), **op_kwargs)
+ elif isinstance(current, (list, tuple)):
+ result = op_callable(*current, **op_kwargs)
+ else:
+ result = op_callable(current, **op_kwargs)
+
+ # prepare current for next op
+ current = [result]
+ except Exception as e:
+ raise RuntimeError(f"Error applying operation '{op}': {str(e)}") from e
+
+ # write outputs back into env
+ if out_name is None:
+ continue
+ if isinstance(out_name, (list, tuple)):
+ # if step produces multiple outputs (rare), try unpacking
+ if len(out_name) == 1:
+ env[out_name[0]] = current[0]
+ else:
+ # try to unpack
+ try:
+ for k, v in zip(out_name, current[0]):
+ env[k] = v
+ except Exception as e:
+ raise RuntimeError(f"Could not assign multiple outputs for step {step}: {e}")
+ else:
+ env[out_name] = current[0]
+ except Exception as e:
+ raise RuntimeError(f"Error executing forward step: {str(e)}") from e
+
+ # Build function return
+ if len(self._output_names) == 0:
+ return None
+ if len(self._output_names) == 1:
+ return env[self._output_names[0]]
+ return tuple(env[n] for n in self._output_names)
+ except Exception as e:
+ raise RuntimeError(f"Error in forward pass: {str(e)}") from e
+
+ def _resolve_op(self, op_spec):
+ """Return (callable_or_module_name, kwargs)
+
+ If op_spec is a string and matches a layer name -> returns (layer_name_str, {}).
+ If op_spec is a string dotted path -> resolve dotted and return (callable, {}).
+ If op_spec is a list like ["torch.flatten", {"start_dim":1}] -> resolve and return (callable, kwargs)
+ """
+ try:
+ # module reference by name
+ if isinstance(op_spec, str):
+ if op_spec in self._layers:
+ return (op_spec, {})
+ # dotted function (F.relu, torch.sigmoid)
+ if not _is_allowed_op_path(op_spec):
+ raise PermissionError(f"Operation '{op_spec}' is not allowed")
+ callable_obj = _resolve_submodule(op_spec)
+ if not callable(callable_obj):
+ raise TypeError(f"Resolved object for '{op_spec}' is not callable")
+ return (callable_obj, {})
+ elif isinstance(op_spec, (list, tuple)):
+ if len(op_spec) == 0:
+ raise ValueError("Empty op_spec list")
+ path = op_spec[0]
+ kwargs = op_spec[1] if len(op_spec) > 1 else {}
+ if not _is_allowed_op_path(path):
+ raise PermissionError(f"Operation '{path}' is not allowed")
+ callable_obj = _resolve_submodule(path)
+ if not callable(callable_obj):
+ raise TypeError(f"Resolved object for '{path}' is not callable")
+ return (callable_obj, kwargs)
+ else:
+ raise TypeError(f"Unsupported op spec type: {type(op_spec)}")
+ except Exception as e:
+ raise RuntimeError(f"Error resolving operation '{op_spec}': {str(e)}") from e
+
+ # Instantiate dynamic module and return
+ dyn = DynamicModule()
+ return dyn
+ except Exception as e:
+ raise RuntimeError(f"Error building module from config: {str(e)}") from e
+
+ @classmethod
+ def _instantiate_submodule(cls, sub_def: Dict[str, Any], provided_params: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+ """Instantiate a submodule defined in 'submodules' using provided_params to replace placeholders.
+
+ provided_params are used to replace occurrences of strings like '$in_ch' inside the sub_def's 'layers' params.
+ """
+ try:
+ # Deep replace placeholders within sub_def copy
+ # We'll construct a new config where the "layers"->"params" are substituted
+ replaced = {}
+ for k, v in sub_def.items():
+ try:
+ if k == "layers":
+ new_layers = {}
+ for lname, lentry in v.items():
+ new_entry = dict(lentry)
+ if "params" in lentry:
+ new_entry["params"] = _replace_placeholders(lentry["params"], provided_params)
+ new_layers[lname] = new_entry
+ replaced[k] = new_layers
+ else:
+ # copy other keys directly (input/forward/output)
+ replaced[k] = v
+ except Exception as e:
+ raise RuntimeError(f"Error processing key '{k}': {str(e)}") from e
+
+ # Now build a module from this replaced config. This call may in turn instantiate nested submodules.
+ return cls._build_module_from_config(replaced, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error instantiating submodule: {str(e)}") from e
+
+
+def _is_allowed_op_path(path: str) -> bool:
+ if any(path.startswith(p) for p in ALLOWED_OP_PREFIXES):
+ return True
+ return path in ALLOWED_OPS
\ No newline at end of file
diff --git a/scenarios/brats/src/preprocess_brats_A.py b/scenarios/brats/src/preprocess_brats_A.py
new file mode 100644
index 0000000..c2fe7e4
--- /dev/null
+++ b/scenarios/brats/src/preprocess_brats_A.py
@@ -0,0 +1,80 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import nibabel as nib
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+# Note: This script is designed to preprocess the BraTS dataset by extracting the middle axial slice from each NIfTI file and saving as PNG.
+
+# Commented out since NIfTI files are bulky and time-consuming for this demo. The preprocessed datasets are already present in the repository.
+# To run the preprocessing step, uncomment the code below if you have the NIfTI files available in the specified input directories (scenarios/brats/data).
+
+'''
+def get_middle_axial_slice(nifti_path):
+ """Load NIfTI file and return center axial slice"""
+ img = nib.load(nifti_path)
+ data = img.get_fdata()
+
+ # Get center axial slice
+ axial_slices = data.shape[2]
+ center_slice = data[:, :, axial_slices // 2]
+
+ return center_slice
+
+
+def normalize_slice(slice_data):
+ """Normalize slice to 0-255 range"""
+ slice_data = slice_data.astype(np.float32)
+ if np.max(slice_data) > 0: # avoid division by zero
+ slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
+ return slice_data.astype(np.uint8)
+
+
+## Process all NIfTI files in directory structure and save as PNGs (middle axial slice)
+
+input_root = "/mnt/input/data"
+output_root = "/mnt/output/preprocessed/"
+
+for root, dirs, files in os.walk(input_root):
+ for file in files:
+ if file.endswith('.nii.gz'):
+ # Create output path maintaining structure
+ rel_path = os.path.relpath(root, input_root)
+ output_dir = os.path.join(output_root, rel_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Process NIfTI file
+ input_path = os.path.join(root, file)
+ try:
+ center_slice = get_middle_axial_slice(input_path)
+ normalized_slice = normalize_slice(center_slice)
+
+ # Create PNG filename (replace .nii.gz with .png)
+ png_filename = file.replace('.nii.gz', '.png')
+ output_path = os.path.join(output_dir, png_filename)
+
+ # Save as PNG
+ Image.fromarray(normalized_slice).save(output_path)
+ print(f"Processed: {input_path} -> {output_path}")
+ except Exception as e:
+ print(f"Error processing {input_path}: {str(e)}")
+'''
+
+print("Preprocessed BraTS_A dataset saved to data/brats_A/preprocessed/")
\ No newline at end of file
diff --git a/scenarios/brats/src/preprocess_brats_B.py b/scenarios/brats/src/preprocess_brats_B.py
new file mode 100644
index 0000000..64427ca
--- /dev/null
+++ b/scenarios/brats/src/preprocess_brats_B.py
@@ -0,0 +1,80 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import nibabel as nib
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+# Note: This script is designed to preprocess the BraTS dataset by extracting the middle axial slice from each NIfTI file and saving as PNG.
+
+# Commented out since NIfTI files are bulky and time-consuming for this demo. The preprocessed datasets are already present in the repository.
+# To run the preprocessing step, uncomment the code below if you have the NIfTI files available in the specified input directories (scenarios/brats/data).
+
+'''
+def get_middle_axial_slice(nifti_path):
+ """Load NIfTI file and return center axial slice"""
+ img = nib.load(nifti_path)
+ data = img.get_fdata()
+
+ # Get center axial slice
+ axial_slices = data.shape[2]
+ center_slice = data[:, :, axial_slices // 2]
+
+ return center_slice
+
+
+def normalize_slice(slice_data):
+ """Normalize slice to 0-255 range"""
+ slice_data = slice_data.astype(np.float32)
+ if np.max(slice_data) > 0: # avoid division by zero
+ slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
+ return slice_data.astype(np.uint8)
+
+
+## Process all NIfTI files in directory structure and save as PNGs (middle axial slice)
+
+input_root = "/mnt/input/data"
+output_root = "/mnt/output/preprocessed/"
+
+for root, dirs, files in os.walk(input_root):
+ for file in files:
+ if file.endswith('.nii.gz'):
+ # Create output path maintaining structure
+ rel_path = os.path.relpath(root, input_root)
+ output_dir = os.path.join(output_root, rel_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Process NIfTI file
+ input_path = os.path.join(root, file)
+ try:
+ center_slice = get_middle_axial_slice(input_path)
+ normalized_slice = normalize_slice(center_slice)
+
+ # Create PNG filename (replace .nii.gz with .png)
+ png_filename = file.replace('.nii.gz', '.png')
+ output_path = os.path.join(output_dir, png_filename)
+
+ # Save as PNG
+ Image.fromarray(normalized_slice).save(output_path)
+ print(f"Processed: {input_path} -> {output_path}")
+ except Exception as e:
+ print(f"Error processing {input_path}: {str(e)}")
+'''
+
+print("Preprocessed BraTS_B dataset saved to data/brats_B/preprocessed/")
\ No newline at end of file
diff --git a/scenarios/brats/src/preprocess_brats_C.py b/scenarios/brats/src/preprocess_brats_C.py
new file mode 100644
index 0000000..4d0ccbf
--- /dev/null
+++ b/scenarios/brats/src/preprocess_brats_C.py
@@ -0,0 +1,80 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import nibabel as nib
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+# Note: This script is designed to preprocess the BraTS dataset by extracting the middle axial slice from each NIfTI file and saving as PNG.
+
+# Commented out since NIfTI files are bulky and time-consuming for this demo. The preprocessed datasets are already present in the repository.
+# To run the preprocessing step, uncomment the code below if you have the NIfTI files available in the specified input directories (scenarios/brats/data).
+
+'''
+def get_middle_axial_slice(nifti_path):
+ """Load NIfTI file and return center axial slice"""
+ img = nib.load(nifti_path)
+ data = img.get_fdata()
+
+ # Get center axial slice
+ axial_slices = data.shape[2]
+ center_slice = data[:, :, axial_slices // 2]
+
+ return center_slice
+
+
+def normalize_slice(slice_data):
+ """Normalize slice to 0-255 range"""
+ slice_data = slice_data.astype(np.float32)
+ if np.max(slice_data) > 0: # avoid division by zero
+ slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
+ return slice_data.astype(np.uint8)
+
+
+## Process all NIfTI files in directory structure and save as PNGs (middle axial slice)
+
+input_root = "/mnt/input/data"
+output_root = "/mnt/output/preprocessed/"
+
+for root, dirs, files in os.walk(input_root):
+ for file in files:
+ if file.endswith('.nii.gz'):
+ # Create output path maintaining structure
+ rel_path = os.path.relpath(root, input_root)
+ output_dir = os.path.join(output_root, rel_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Process NIfTI file
+ input_path = os.path.join(root, file)
+ try:
+ center_slice = get_middle_axial_slice(input_path)
+ normalized_slice = normalize_slice(center_slice)
+
+ # Create PNG filename (replace .nii.gz with .png)
+ png_filename = file.replace('.nii.gz', '.png')
+ output_path = os.path.join(output_dir, png_filename)
+
+ # Save as PNG
+ Image.fromarray(normalized_slice).save(output_path)
+ print(f"Processed: {input_path} -> {output_path}")
+ except Exception as e:
+ print(f"Error processing {input_path}: {str(e)}")
+'''
+
+print("Preprocessed BraTS_C dataset saved to data/brats_C/preprocessed/")
\ No newline at end of file
diff --git a/scenarios/brats/src/preprocess_brats_D.py b/scenarios/brats/src/preprocess_brats_D.py
new file mode 100644
index 0000000..e978b4a
--- /dev/null
+++ b/scenarios/brats/src/preprocess_brats_D.py
@@ -0,0 +1,80 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import nibabel as nib
+import numpy as np
+from PIL import Image
+import matplotlib.pyplot as plt
+
+# Note: This script is designed to preprocess the BraTS dataset by extracting the middle axial slice from each NIfTI file and saving as PNG.
+
+# Commented out since NIfTI files are bulky and time-consuming for this demo. The preprocessed datasets are already present in the repository.
+# To run the preprocessing step, uncomment the code below if you have the NIfTI files available in the specified input directories (scenarios/brats/data).
+
+'''
+def get_middle_axial_slice(nifti_path):
+ """Load NIfTI file and return center axial slice"""
+ img = nib.load(nifti_path)
+ data = img.get_fdata()
+
+ # Get center axial slice
+ axial_slices = data.shape[2]
+ center_slice = data[:, :, axial_slices // 2]
+
+ return center_slice
+
+
+def normalize_slice(slice_data):
+ """Normalize slice to 0-255 range"""
+ slice_data = slice_data.astype(np.float32)
+ if np.max(slice_data) > 0: # avoid division by zero
+ slice_data = (slice_data - np.min(slice_data)) / (np.max(slice_data) - np.min(slice_data)) * 255
+ return slice_data.astype(np.uint8)
+
+
+## Process all NIfTI files in directory structure and save as PNGs (middle axial slice)
+
+input_root = "/mnt/input/data"
+output_root = "/mnt/output/preprocessed/"
+
+for root, dirs, files in os.walk(input_root):
+ for file in files:
+ if file.endswith('.nii.gz'):
+ # Create output path maintaining structure
+ rel_path = os.path.relpath(root, input_root)
+ output_dir = os.path.join(output_root, rel_path)
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Process NIfTI file
+ input_path = os.path.join(root, file)
+ try:
+ center_slice = get_middle_axial_slice(input_path)
+ normalized_slice = normalize_slice(center_slice)
+
+ # Create PNG filename (replace .nii.gz with .png)
+ png_filename = file.replace('.nii.gz', '.png')
+ output_path = os.path.join(output_dir, png_filename)
+
+ # Save as PNG
+ Image.fromarray(normalized_slice).save(output_path)
+ print(f"Processed: {input_path} -> {output_path}")
+ except Exception as e:
+ print(f"Error processing {input_path}: {str(e)}")
+'''
+
+print("Preprocessed BraTS_D dataset saved to data/brats_D/preprocessed/")
\ No newline at end of file
diff --git a/scenarios/brats/src/save_base_model.py b/scenarios/brats/src/save_base_model.py
new file mode 100644
index 0000000..87bd101
--- /dev/null
+++ b/scenarios/brats/src/save_base_model.py
@@ -0,0 +1,128 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import json
+from safetensors.torch import save_file as st_save
+
+from model_constructor import *
+
+model_config_path = "/mnt/config/model_config.json"
+
+with open(model_config_path, 'r') as f:
+ config = json.load(f)
+
+model = ModelFactory.load_from_dict(config)
+
+model.eval()
+st_save(model.state_dict(), "/mnt/model/model.safetensors")
+print("Model saved as safetensors to /mnt/model/model.safetensors")
+
+
+'''
+
+# Reference model architecture and components for Anatomy UNet
+
+class ConvBlock2d(nn.Module):
+ def __init__(self, in_ch, mid_ch, out_ch):
+ super().__init__()
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
+ # nn.InstanceNorm2d(mid_ch),
+ nn.GroupNorm(1, mid_ch),
+ nn.LeakyReLU(0.1),
+ nn.Conv2d(mid_ch, out_ch, 3, 1, 1),
+ # nn.InstanceNorm2d(out_ch),
+ nn.GroupNorm(1, out_ch),
+ nn.LeakyReLU(0.1)
+ )
+
+ def forward(self, in_tensor):
+ return self.conv(in_tensor)
+
+
+class Upsample(nn.Module):
+ def __init__(self, in_ch):
+ super().__init__()
+ out_ch = in_ch // 2
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_ch, out_ch, 3, 1, 1),
+ # nn.InstanceNorm2d(out_ch),
+ nn.GroupNorm(1, out_ch),
+ nn.LeakyReLU(0.1)
+ )
+
+ def forward(self, in_tensor, encoded_feature):
+ up_sampled_tensor = F.interpolate(in_tensor, size=None, scale_factor=2.0, mode='bilinear', align_corners=False)
+ up_sampled_tensor = self.conv(up_sampled_tensor)
+ return torch.cat([encoded_feature, up_sampled_tensor], dim=1)
+
+class Base_Model(nn.Module):
+ def __init__(self, in_ch, out_ch, num_lvs=4, base_ch=16, final_act='noact'):
+ super().__init__()
+ self.final_act = final_act
+ self.in_conv = nn.Conv2d(in_ch, base_ch, 3, 1, 1)
+
+ self.down_convs = nn.ModuleList()
+ self.down_samples = nn.ModuleList()
+ self.up_samples = nn.ModuleList()
+ self.up_convs = nn.ModuleList()
+ for lv in range(num_lvs):
+ ch = base_ch * (2 ** lv)
+ self.down_convs.append(ConvBlock2d(ch, ch * 2, ch * 2))
+ self.down_samples.append(nn.MaxPool2d(kernel_size=2, stride=2))
+ self.up_samples.append(Upsample(ch * 4))
+ self.up_convs.append(ConvBlock2d(ch * 4, ch * 2, ch * 2))
+ bottleneck_ch = base_ch * (2 ** num_lvs)
+ self.bottleneck_conv = ConvBlock2d(bottleneck_ch, bottleneck_ch * 2, bottleneck_ch * 2)
+ self.out_conv = nn.Sequential(nn.Conv2d(base_ch * 2, base_ch, 3, 1, 1),
+ nn.LeakyReLU(0.1),
+ nn.Conv2d(base_ch, out_ch, 3, 1, 1))
+
+ def forward(self, in_tensor):
+ encoded_features = []
+ x = self.in_conv(in_tensor)
+ for down_conv, down_sample in zip(self.down_convs, self.down_samples):
+ down_conv_out = down_conv(x)
+ x = down_sample(down_conv_out)
+ encoded_features.append(down_conv_out)
+ x = self.bottleneck_conv(x)
+ for encoded_feature, up_conv, up_sample in zip(reversed(encoded_features),
+ reversed(self.up_convs),
+ reversed(self.up_samples)):
+ x = up_sample(x, encoded_feature)
+ x = up_conv(x)
+ x = self.out_conv(x)
+ if self.final_act == 'sigmoid':
+ x = torch.sigmoid(x)
+ elif self.final_act == "relu":
+ x = torch.relu(x)
+ elif self.final_act == 'tanh':
+ x = torch.tanh(x)
+ else:
+ x = x
+ return x
+
+
+model = Base_Model(in_ch=1, out_ch=1, base_ch=8, final_act='sigmoid').to('cpu')
+model.eval()
+
+torch.save(model.state_dict(), "/mnt/model/model.pth")
+print("Model saved as pth to /mnt/model/model.pth")
+'''
\ No newline at end of file
diff --git a/scenarios/cifar10/.gitignore b/scenarios/cifar10/.gitignore
new file mode 100644
index 0000000..9d6d918
--- /dev/null
+++ b/scenarios/cifar10/.gitignore
@@ -0,0 +1,17 @@
+**/preprocessed/*.csv
+*.bin
+*.img
+*.pth
+*.pt
+*.onnx
+*.npy
+*.safetensors
+*.h5
+*.hdf5
+
+# Ignore modeller output folder (relative to repo root)
+modeller/output/
+
+data/cifar*
+
+**/__pycache__/
\ No newline at end of file
diff --git a/scenarios/cifar10/README.md b/scenarios/cifar10/README.md
new file mode 100644
index 0000000..e92b08f
--- /dev/null
+++ b/scenarios/cifar10/README.md
@@ -0,0 +1,368 @@
+# CIFAR-10 Image Classification
+
+## Scenario Type
+
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| CIFAR-10 | Training - Deep Learning | Multi-class Image Classification | NA | 1 | Non-PII image data (SafeTensors) | CNN (Safetensors) | NA (1)|
+
+---
+
+## Scenario Description
+
+This scenario involves training a CNN for image classification using the CIFAR-10 image dataset. It involves one Training Data Provider (TDP), and a Training Data Consumer (TDC) who wishes to train a model on the dataset. The CIFAR-10 dataset is a collection of 60,000 32x32 color images in 10 classes (viz. airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks), with 6,000 images per class.
+
+The end-to-end training pipeline consists of the following phases.
+
+1. Data pre-processing
+2. Data packaging, encryption and upload
+3. Model packaging, encryption and upload
+4. Encryption key import with key release policies
+5. Deployment and execution of CCR
+6. Trained model decryption
+
+## Build container images
+
+Build container images required for this sample as follows:
+
+```bash
+export SCENARIO=cifar10
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/build.sh
+```
+
+This script builds the following container images.
+
+- ```preprocess-cifar10```: Container for pre-processing CIFAR-10 dataset.
+- ```cifar10-model-save```: Container that saves the model to be trained in PyTorch format.
+
+Alternatively, you can pull and use pre-built container images from the ispirt container registry by setting the following environment variable. Docker Hub has started throttling, which may affect the upload/download time, especially when images are larger. So, it is advisable to use other container registries. We are using Azure Container Registry (ACR) as shown below:
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/pull-containers.sh
+```
+
+## Data pre-processing
+
+The folder ```scenarios/cifar10/src``` contains scripts for downloading and pre-processing the CIFAR-10 dataset. Acting as a Training Data Provider (TDP), prepare your datasets.
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/local
+./preprocess.sh
+```
+
+The datasets are saved to the [data](./data/) directory.
+
+## Prepare model for training
+
+Next, acting as a Training Data Consumer (TDC), define and save your base model for training using the following script. This calls the [save_base_model.py](./src/save_base_model.py) script, which is a custom script that saves the model to the [models](./modeller/models) directory, as a PyTorch file. This may first require you to edit the [model config](./config/model_config.json) to specify the model architecture and parameters.
+
+```bash
+./save-model.sh
+```
+
+## Deploy locally
+
+Assuming you have cleartext access to all the datasets, you can train the model _locally_ as follows:
+
+```bash
+./train.sh
+```
+
+The script joins the datasets and trains the model using a pipeline configuration. To modify the various components of the training pipeline, you can edit the training config files in the [config](./config/) directory. The training config files are used to create the pipeline configuration ([pipeline_config.json](./config/pipeline_config.json)) created by consolidating all the TDC's training config files, namely the [model config](./config/model_config.json), [dataset config](./config/dataset_config.json), [loss function config](./config/loss_config.json), [training config](./config/train_config_template.json), [evaluation config](./config/eval_config.json), and if multiple datasets are used, the [data join config](./config/join_config.json). These enable the TDC to design highly customized training pipelines without requiring review and approval of new custom code for each use case—reducing risks from potentially malicious or non-compliant code. The consolidated pipeline configuration is then attested against the signed contract using the TDP’s policy-as-code. If approved, it is executed in the CCR to train the model, which we will deploy in the next section.
+
+```mermaid
+flowchart TD
+
+ subgraph Config Files
+ C1[model_config.json]
+ C2[dataset_config.json]
+ C3[loss_config.json]
+ C4[train_config_template.json]
+ C5[eval_config.json]
+ C6[join_config.json]
+ end
+
+ B[Consolidated into
pipeline_config.json]
+
+ C1 --> B
+ C2 --> B
+ C3 --> B
+ C4 --> B
+ C5 --> B
+ C6 --> B
+
+ B --> D[Attested against contract
using policy-as-code]
+ D --> E{Approved?}
+ E -- Yes --> F[CCR training begins]
+ E -- No --> H[Rejected: fix config]
+```
+
+If all goes well, you should see output similar to the following, and the trained model and evaluation metrics will be saved under the folder [output](./modeller/output).
+
+```
+train-1 | Training samples: 36363
+train-1 | Validation samples: 9091
+train-1 | Test samples: 4546
+train-1 | Dataset constructed from config
+train-1 | Custom model loaded from PyTorch config
+train-1 | Optimizer Adam loaded from config
+train-1 | Custom loss function loaded from config
+train-1 | Epoch 1/2 completed | Training Loss: 1.9266
+train-1 | Epoch 1/2 completed | Validation Loss: 1.7594
+train-1 | Epoch 2/2 completed | Training Loss: 1.6811
+train-1 | Epoch 2/2 completed | Validation Loss: 1.6222
+train-1 | Saving trained model to /mnt/remote/output/trained_model.pth
+train-1 | Evaluation Metrics: {'test_loss': 1.6260074046620152, 'accuracy': 0.4104707435107787, 'f1_score': 0.40647825713476216}
+train-1 | CCR Training complete!
+train-1 |
+train-1 exited with code 0
+```
+
+## Deploy on CCR
+
+In a more realistic scenario, these datasets will not be available in the clear to the TDC, and the TDC will be required to use a CCR for training. The following steps describe the process of sharing an encrypted dataset with TDCs and setting up a CCR in Azure for training. Please stay tuned for CCR on other cloud platforms.
+
+To deploy in Azure, you will need the following.
+
+- Docker Hub account to store container images. Alternatively, you can use pre-built images from the ```ispirt``` container registry.
+- [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault/) to store encryption keys and implement secure key release to CCR. You can either use Azure Key Vault Premium (lower cost), or [Azure Key Vault managed HSM](https://learn.microsoft.com/en-us/azure/key-vault/managed-hsm/overview) for enhanced security. Please see instructions below on how to create and set up your AKV instance.
+- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances (ACI).
+
+If you are using your own development environment instead of a dev container or codespaces, you will need to install the following dependencies.
+
+- [Azure CLI](https://learn.microsoft.com/en-us/cli/azure/install-azure-cli-linux).
+- [Azure CLI Confidential containers extension](https://learn.microsoft.com/en-us/cli/azure/confcom?view=azure-cli-latest). After installing Azure CLI, you can install this extension using ```az extension add --name confcom -y```
+- [Go](https://go.dev/doc/install). Follow the instructions to install Go. After installing, ensure that the PATH environment variable is set to include ```go``` runtime.
+- ```jq```. You can install jq using ```sudo apt-get install -y jq```
+
+We will be creating the following resources as part of the deployment.
+
+- Azure Key Vault
+- Azure Storage account
+- Storage containers to host encrypted datasets
+- Azure Container Instances (ACI) to deploy the CCR and train the model
+
+### 1. Push Container Images
+
+Pre-built container images are available in iSPIRT's container registry, which can be pulled by setting the following environment variable.
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+```
+
+If you wish to use your own container images, login to docker hub (or your container registry of choice) and then build and push the container images to it, so that they can be pulled by the CCR. This is a one-time operation, and you can skip this step if you have already pushed the images to your container registry.
+
+```bash
+export CONTAINER_REGISTRY=<container-registry>
+docker login -u <username> -p <password> ${CONTAINER_REGISTRY}
+cd $REPO_ROOT
+./ci/push-containers.sh
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/push-containers.sh
+```
+
+> **Note:** Replace `<container-registry>`, `<username>` and `<password>` with your container registry name, docker hub username and password respectively. Preferably use registry services other than Docker Hub, as throttling restrictions will cause delays or image push/pull failures.
+
+### 2. Create Resources
+
+First, set up the necessary environment variables for your deployment.
+
+```bash
+az login
+
+export SCENARIO=cifar10
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+export AZURE_LOCATION=northeurope
+export AZURE_SUBSCRIPTION_ID=
+export AZURE_RESOURCE_GROUP=
+export AZURE_KEYVAULT_ENDPOINT=.vault.azure.net
+export AZURE_STORAGE_ACCOUNT_NAME=
+
+export AZURE_CIFAR10_CONTAINER_NAME=cifar10container
+export AZURE_MODEL_CONTAINER_NAME=modelcontainer
+export AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+```
+
+Alternatively, you can edit the values in the [export-variables.sh](./export-variables.sh) script and run it to set the environment variables.
+
+```bash
+./export-variables.sh
+source export-variables.sh
+```
+
+Azure Naming Rules:
+- Resource Group:
+ - 1–90 characters
+ - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+ - Cannot end with a period (.)
+ - Case-insensitive, unique within subscription\
+- Key Vault:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with letter or number
+- Storage Account:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters and numbers only
+- Storage Container:
+ - 3-63 characters
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with a letter or number
+ - No consecutive hyphens
+ - Unique within storage account
+
+---
+
+**Important:**
+
+The values for the environment variables listed below must precisely match the corresponding environment variables used during contract signing (next step). Any mismatch will lead to execution failure.
+
+- `SCENARIO`
+- `AZURE_KEYVAULT_ENDPOINT`
+- `CONTRACT_SERVICE_URL`
+- `AZURE_STORAGE_ACCOUNT_NAME`
+- `AZURE_CIFAR10_CONTAINER_NAME`
+
+---
+With the environment variables set, we are ready to create the resources -- Azure Key Vault and Azure Storage containers.
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/azure
+./1-create-storage-containers.sh
+./2-create-akv.sh
+```
+---
+
+### 3\. Contract Signing
+
+Navigate to the [contract-ledger](https://github.com/kapilvgit/contract-ledger/blob/main/README.md) repository and follow the instructions for contract signing.
+
+Once the contract is signed, export the contract sequence number as an environment variable in the same terminal where you set the environment variables for the deployment.
+
+```bash
+export CONTRACT_SEQ_NO=
+```
+
+---
+
+### 4\. Data Encryption and Upload
+
+Using their respective keys, the TDPs and TDC encrypt their datasets and model (respectively) and upload them to the Storage containers created in the previous step.
+
+Navigate to the [Azure deployment](./deployment/azure/) directory and execute the scripts for key import, data encryption and upload to Azure Blob Storage, in preparation of the CCR deployment.
+
+The import-keys script generates and imports encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys. The generated keys are available as files with the extension `.bin`.
+
+```bash
+export CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+export TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+./3-import-keys.sh
+```
+
+The data and model are then packaged as encrypted filesystems by the TDPs and TDC using their respective keys, which are saved as `.img` files.
+
+```bash
+./4-encrypt-data.sh
+```
+
+The encrypted data and model are then uploaded to the Storage containers created in the previous step. The `.img` files are uploaded to the Storage containers as blobs.
+
+```bash
+./5-upload-encrypted-data.sh
+```
+
+---
+
+### 5\. CCR Deployment
+
+With the resources ready, we are ready to deploy the Confidential Clean Room (CCR) for executing the privacy-preserving model training.
+
+```bash
+export CONTRACT_SEQ_NO=
+./deploy.sh -c $CONTRACT_SEQ_NO -p ../../config/pipeline_config.json
+```
+
+Set the `$CONTRACT_SEQ_NO` variable to the numeric suffix (XX) of the contract sequence number (of format 2.XX). For example, if the number was 2.15, export as:
+
+```bash
+export CONTRACT_SEQ_NO=15
+```
+
+This script will deploy the container images from your container registry, including the encrypted filesystem sidecar. The sidecar will generate an SEV-SNP attestation report, generate an attestation token using the Microsoft Azure Attestation (MAA) service, retrieve dataset, model and output encryption keys from the TDP and TDC's Azure Key Vault, train the model, and save the resulting model into TDC's output filesystem image, which the TDC can later decrypt.
+
+
+
+**Note:** The completion of this script's execution simply creates a CCR instance, and doesn't indicate whether training has completed or not. The training process might still be ongoing. Poll the container logs (see below) to track progress until training is complete.
+
+### 6\. Monitor Container Logs
+
+Use the following commands to monitor the logs of the deployed containers. You might have to repeatedly poll this command to monitor the training progress:
+
+```bash
+az container logs \
+ --name "depa-training-$SCENARIO" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --container-name depa-training
+```
+
+You will know training has completed when the logs print "CCR Training complete!".
+
+#### Troubleshooting
+
+In case training fails, you might want to monitor the logs of the encrypted storage sidecar container to see if the encryption process completed successfully:
+
+```bash
+az container logs --name depa-training-$SCENARIO --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+```
+
+And to further debug, inspect the logs of the encrypted filesystem sidecar container:
+
+```bash
+az container exec \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --name depa-training-$SCENARIO \
+ --container-name encrypted-storage-sidecar \
+ --exec-command "/bin/sh"
+```
+
+Once inside the sidecar container shell, view the logs:
+
+```bash
+cat log.txt
+```
+Or inspect the individual mounted directories in `mnt/remote/`:
+
+```bash
+cd mnt/remote && ls
+```
+
+### 7\. Download and Decrypt Model
+
+Once training has completed successfully (the training container logs will mention it explicitly), download and decrypt the trained model and other training outputs.
+
+```bash
+./6-download-decrypt-model.sh
+```
+
+The outputs will be saved to the [output](./modeller/output/) directory.
+
+To check if the trained model is fresh, you can run the following command:
+
+```bash
+stat $REPO_ROOT/scenarios/$SCENARIO/modeller/output/trained_model.pth
+```
+
+---
+### Clean-up
+
+You can use the following command to delete the resource group and clean-up all resources used in the demo. Alternatively, you can navigate to the Azure portal and delete the resource group created for this demo.
+
+```bash
+az group delete --yes --name $AZURE_RESOURCE_GROUP
+```
\ No newline at end of file
diff --git a/scenarios/mnist/ci/Dockerfile.preprocess b/scenarios/cifar10/ci/Dockerfile.cifar10
similarity index 61%
rename from scenarios/mnist/ci/Dockerfile.preprocess
rename to scenarios/cifar10/ci/Dockerfile.cifar10
index 7f7d0fa..365b476 100644
--- a/scenarios/mnist/ci/Dockerfile.preprocess
+++ b/scenarios/cifar10/ci/Dockerfile.cifar10
@@ -1,19 +1,19 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get -y upgrade \
&& apt-get install -y curl \
- && apt-get install -y python3.9 python3.9-dev python3.9-distutils
+ && apt-get install -y python3 python3-dev python3-distutils
## Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-RUN python3.9 get-pip.py
+RUN python3 get-pip.py
## Install dependencies
RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
-RUN pip3 --default-timeout=1000 install onnx onnx2pytorch
+RUN pip3 install safetensors packaging
RUN apt-get install -y jq
-COPY preprocess.py preprocess.py
\ No newline at end of file
+COPY preprocess_cifar10.py preprocess_cifar10.py
\ No newline at end of file
diff --git a/scenarios/cifar10/ci/Dockerfile.modelsave b/scenarios/cifar10/ci/Dockerfile.modelsave
new file mode 100644
index 0000000..2bd8829
--- /dev/null
+++ b/scenarios/cifar10/ci/Dockerfile.modelsave
@@ -0,0 +1,18 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y gcc g++ curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install safetensors packaging numpy
+
+COPY save_base_model.py save_base_model.py
+COPY model_constructor.py model_constructor.py
diff --git a/scenarios/cifar10/ci/build.sh b/scenarios/cifar10/ci/build.sh
new file mode 100755
index 0000000..ee9e233
--- /dev/null
+++ b/scenarios/cifar10/ci/build.sh
@@ -0,0 +1,4 @@
+#!/bin/bash
+
+docker build -f ci/Dockerfile.cifar10 src -t preprocess-cifar10:latest
+docker build -f ci/Dockerfile.modelsave src -t cifar10-model-save:latest
\ No newline at end of file
diff --git a/scenarios/cifar10/ci/pull-containers.sh b/scenarios/cifar10/ci/pull-containers.sh
new file mode 100755
index 0000000..d7904f0
--- /dev/null
+++ b/scenarios/cifar10/ci/pull-containers.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+containers=("preprocess-cifar10:latest" "cifar10-model-save:latest")
+for container in "${containers[@]}"
+do
+ docker pull $CONTAINER_REGISTRY"/"$container
+done
\ No newline at end of file
diff --git a/scenarios/cifar10/ci/push-containers.sh b/scenarios/cifar10/ci/push-containers.sh
new file mode 100755
index 0000000..f84bc18
--- /dev/null
+++ b/scenarios/cifar10/ci/push-containers.sh
@@ -0,0 +1,6 @@
+containers=("cifar10-model-save:latest" "preprocess-cifar10:latest")
+for container in "${containers[@]}"
+do
+ docker tag $container $CONTAINER_REGISTRY"/"$container
+ docker push $CONTAINER_REGISTRY"/"$container
+done
diff --git a/scenarios/cifar10/config/consolidate_pipeline.sh b/scenarios/cifar10/config/consolidate_pipeline.sh
new file mode 100755
index 0000000..7fefbdf
--- /dev/null
+++ b/scenarios/cifar10/config/consolidate_pipeline.sh
@@ -0,0 +1,58 @@
+#! /bin/bash
+
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCENARIO=cifar10
+
+template_path="$REPO_ROOT/scenarios/$SCENARIO/config/templates"
+model_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json"
+data_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/dataset_config.json"
+loss_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/loss_config.json"
+train_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/train_config.json"
+eval_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/eval_config.json"
+join_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/join_config.json"
+pipeline_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/pipeline_config.json"
+
+# populate "model_config", "data_config", and "loss_config" keys in train config
+train_config=$(cat $template_path/train_config_template.json)
+
+# Only merge if the file exists
+if [[ -f "$model_config_path" ]]; then
+ model_config=$(cat $model_config_path)
+ train_config=$(echo "$train_config" | jq --argjson model "$model_config" '.config.model_config = $model')
+fi
+
+if [[ -f "$data_config_path" ]]; then
+ data_config=$(cat $data_config_path)
+ train_config=$(echo "$train_config" | jq --argjson data "$data_config" '.config.dataset_config = $data')
+fi
+
+if [[ -f "$loss_config_path" ]]; then
+ loss_config=$(cat $loss_config_path)
+ train_config=$(echo "$train_config" | jq --argjson loss "$loss_config" '.config.loss_config = $loss')
+fi
+
+if [[ -f "$eval_config_path" ]]; then
+ eval_config=$(cat $eval_config_path)
+ # Get all keys from eval_config and copy them to train_config
+ for key in $(echo "$eval_config" | jq -r 'keys[]'); do
+ train_config=$(echo "$train_config" | jq --argjson eval "$eval_config" --arg key "$key" '.config[$key] = $eval[$key]')
+ done
+fi
+
+# save train_config
+echo "$train_config" > $train_config_path
+
+# prepare pipeline config from join_config.json (first dict "config") and train_config.json (second dict "config")
+pipeline_config=$(cat $template_path/pipeline_config_template.json)
+
+# Only merge join_config if the file exists
+if [[ -f "$join_config_path" ]]; then
+ join_config=$(cat $join_config_path)
+ pipeline_config=$(echo "$pipeline_config" | jq --argjson join "$join_config" '.pipeline += [$join]')
+fi
+
+# Always merge train_config as it's required
+pipeline_config=$(echo "$pipeline_config" | jq --argjson train "$train_config" '.pipeline += [$train]')
+
+# save pipeline_config to pipeline_config.json
+echo "$pipeline_config" > $pipeline_config_path
\ No newline at end of file
diff --git a/scenarios/cifar10/config/dataset_config.json b/scenarios/cifar10/config/dataset_config.json
new file mode 100644
index 0000000..bc14fca
--- /dev/null
+++ b/scenarios/cifar10/config/dataset_config.json
@@ -0,0 +1,17 @@
+{
+ "type": "serialized",
+ "format": "safetensors",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.85,
+ "val": 0.15,
+ "test": 0.05,
+ "random_state": 42
+ }
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/eval_config.json b/scenarios/cifar10/config/eval_config.json
new file mode 100644
index 0000000..2b4b375
--- /dev/null
+++ b/scenarios/cifar10/config/eval_config.json
@@ -0,0 +1,20 @@
+{
+ "task_type": "classification",
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/loss_config.json b/scenarios/cifar10/config/loss_config.json
new file mode 100644
index 0000000..99f914b
--- /dev/null
+++ b/scenarios/cifar10/config/loss_config.json
@@ -0,0 +1,6 @@
+{
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "mean"
+ }
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/model_config.json b/scenarios/cifar10/config/model_config.json
new file mode 100644
index 0000000..caa9878
--- /dev/null
+++ b/scenarios/cifar10/config/model_config.json
@@ -0,0 +1,121 @@
+{
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 3,
+ "out_channels": 6,
+ "kernel_size": 5
+ }
+ },
+ "pool": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 6,
+ "out_channels": 16,
+ "kernel_size": 5
+ }
+ },
+ "fc1": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 400,
+ "out_features": 120
+ }
+ },
+ "fc2": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 120,
+ "out_features": 84
+ }
+ },
+ "fc3": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 84,
+ "out_features": 10
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv2",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "torch.flatten",
+ {
+ "start_dim": 1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "fc1",
+ "F.relu"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "fc2",
+ "F.relu"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "fc3"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ }
+ ],
+ "output": [
+ "x6"
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/pipeline_config.json b/scenarios/cifar10/config/pipeline_config.json
new file mode 100644
index 0000000..74f5f8b
--- /dev/null
+++ b/scenarios/cifar10/config/pipeline_config.json
@@ -0,0 +1,187 @@
+{
+ "pipeline": [
+ {
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/mnt/remote/cifar10/cifar10-dataset.safetensors",
+ "base_model_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": null,
+ "total_epochs": 2,
+ "model_config": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 3,
+ "out_channels": 6,
+ "kernel_size": 5
+ }
+ },
+ "pool": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 6,
+ "out_channels": 16,
+ "kernel_size": 5
+ }
+ },
+ "fc1": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 400,
+ "out_features": 120
+ }
+ },
+ "fc2": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 120,
+ "out_features": 84
+ }
+ },
+ "fc3": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 84,
+ "out_features": 10
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv2",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "torch.flatten",
+ {
+ "start_dim": 1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "fc1",
+ "F.relu"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "fc2",
+ "F.relu"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "fc3"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ }
+ ],
+ "output": [
+ "x6"
+ ]
+ },
+ "dataset_config": {
+ "type": "serialized",
+ "format": "safetensors",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.8,
+ "val": 0.2,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
+ }
+ }
+ ]
+}
diff --git a/scenarios/cifar10/config/templates/pipeline_config_template.json b/scenarios/cifar10/config/templates/pipeline_config_template.json
new file mode 100644
index 0000000..43e9e84
--- /dev/null
+++ b/scenarios/cifar10/config/templates/pipeline_config_template.json
@@ -0,0 +1,3 @@
+{
+ "pipeline": []
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/templates/train_config_template.json b/scenarios/cifar10/config/templates/train_config_template.json
new file mode 100644
index 0000000..ecaaf6a
--- /dev/null
+++ b/scenarios/cifar10/config/templates/train_config_template.json
@@ -0,0 +1,22 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/mnt/remote/cifar10/cifar10-dataset.safetensors",
+ "base_model_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 1e-4
+ }
+ },
+ "scheduler": null,
+ "total_epochs": 2
+ }
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/config/train_config.json b/scenarios/cifar10/config/train_config.json
new file mode 100644
index 0000000..f247681
--- /dev/null
+++ b/scenarios/cifar10/config/train_config.json
@@ -0,0 +1,183 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/mnt/remote/cifar10/cifar10-dataset.safetensors",
+ "base_model_path": "/mnt/remote/model/model.safetensors",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "safetensors",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": null,
+ "total_epochs": 2,
+ "model_config": {
+ "layers": {
+ "conv1": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 3,
+ "out_channels": 6,
+ "kernel_size": 5
+ }
+ },
+ "pool": {
+ "class": "nn.MaxPool2d",
+ "params": {
+ "kernel_size": 2,
+ "stride": 2
+ }
+ },
+ "conv2": {
+ "class": "nn.Conv2d",
+ "params": {
+ "in_channels": 6,
+ "out_channels": 16,
+ "kernel_size": 5
+ }
+ },
+ "fc1": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 400,
+ "out_features": 120
+ }
+ },
+ "fc2": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 120,
+ "out_features": 84
+ }
+ },
+ "fc3": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 84,
+ "out_features": 10
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "conv1",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "conv2",
+ "F.relu",
+ "pool"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ [
+ "torch.flatten",
+ {
+ "start_dim": 1
+ }
+ ]
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ },
+ {
+ "ops": [
+ "fc1",
+ "F.relu"
+ ],
+ "input": [
+ "x3"
+ ],
+ "output": "x4"
+ },
+ {
+ "ops": [
+ "fc2",
+ "F.relu"
+ ],
+ "input": [
+ "x4"
+ ],
+ "output": "x5"
+ },
+ {
+ "ops": [
+ "fc3"
+ ],
+ "input": [
+ "x5"
+ ],
+ "output": "x6"
+ }
+ ],
+ "output": [
+ "x6"
+ ]
+ },
+ "dataset_config": {
+ "type": "serialized",
+ "format": "safetensors",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.8,
+ "val": 0.2,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
+ }
+}
diff --git a/scenarios/cifar10/contract/contract.json b/scenarios/cifar10/contract/contract.json
new file mode 100644
index 0000000..c59d85f
--- /dev/null
+++ b/scenarios/cifar10/contract/contract.json
@@ -0,0 +1,32 @@
+{
+ "id": "f4f72a88-bab1-11ed-afa1-0242ac120002",
+ "schemaVersion": "0.1",
+ "startTime": "2023-03-14T00:00:00.000Z",
+ "expiryTime": "2024-03-14T00:00:00.000Z",
+ "tdc": "",
+ "tdps": [],
+ "ccrp": "did:web:$CCRP_USERNAME.github.io",
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "name": "cifar10",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_CIFAR10_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "CIFAR10FilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ }
+ ],
+ "purpose": "TRAINING",
+ "terms": {
+ "payment": {},
+ "revocation": {}
+ }
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/azure/0-create-acr.sh b/scenarios/cifar10/deployment/azure/0-create-acr.sh
new file mode 100755
index 0000000..fdd8103
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/0-create-acr.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Only to be run when creating a new Azure Container Registry (ACR).
+
+# Ensure required env vars are set
+if [[ -z "$CONTAINER_REGISTRY" || -z "$AZURE_RESOURCE_GROUP" || -z "$AZURE_LOCATION" ]]; then
+ echo "ERROR: CONTAINER_REGISTRY, AZURE_RESOURCE_GROUP, and AZURE_LOCATION environment variables must be set."
+ exit 1
+fi
+
+echo "Checking if ACR '$CONTAINER_REGISTRY' exists in resource group '$AZURE_RESOURCE_GROUP'..."
+
+# Check if ACR exists
+ACR_EXISTS=$(az acr show --name "$CONTAINER_REGISTRY" --resource-group "$AZURE_RESOURCE_GROUP" --query "name" -o tsv 2>/dev/null)
+
+if [[ -n "$ACR_EXISTS" ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' already exists."
+else
+ echo "⏳ ACR '$CONTAINER_REGISTRY' does not exist. Creating..."
+
+ # Create ACR with premium SKU and admin enabled
+ az acr create \
+ --name "$CONTAINER_REGISTRY" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --location "$AZURE_LOCATION" \
+ --sku Premium \
+ --admin-enabled true \
+ --output table
+
+ # Enable anonymous pull
+ az acr update --name "$CONTAINER_REGISTRY" --anonymous-pull-enabled true
+
+ if [[ $? -eq 0 ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' created successfully."
+ else
+ echo "❌ Failed to create ACR."
+ exit 1
+ fi
+fi
+
+# Login to the ACR
+az acr login --name "$CONTAINER_REGISTRY"
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/azure/1-create-storage-containers.sh b/scenarios/cifar10/deployment/azure/1-create-storage-containers.sh
new file mode 100755
index 0000000..da98b8a
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/1-create-storage-containers.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+echo "Checking if storage account $AZURE_STORAGE_ACCOUNT_NAME exists..."
+STORAGE_ACCOUNT_EXISTS=$(az storage account check-name --name $AZURE_STORAGE_ACCOUNT_NAME --query "nameAvailable" --output tsv)
+
+if [ "$STORAGE_ACCOUNT_EXISTS" == "true" ]; then
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME does not exist. Creating it now..."
+ az storage account create --resource-group $AZURE_RESOURCE_GROUP --name $AZURE_STORAGE_ACCOUNT_NAME
+else
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME already exists. Skipping creation."
+fi
+
+# Get the storage account key
+ACCOUNT_KEY=$(az storage account keys list --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --query "[0].value" --output tsv)
+
+
+# Check if the CIFAR-10 container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_CIFAR10_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_CIFAR10_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_CIFAR10_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the MODEL container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_MODEL_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_MODEL_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_MODEL_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the OUTPUT container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_OUTPUT_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_OUTPUT_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_OUTPUT_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/azure/2-create-akv.sh b/scenarios/cifar10/deployment/azure/2-create-akv.sh
new file mode 100755
index 0000000..c20a75e
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/2-create-akv.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
+ # Check if the Key Vault already exists
+ echo "Checking if Key Vault $AZURE_AKV_RESOURCE_NAME exists..."
+ NAME_AVAILABLE=$(az rest --method post \
+ --uri "https://management.azure.com/subscriptions/$AZURE_SUBSCRIPTION_ID/providers/Microsoft.KeyVault/checkNameAvailability?api-version=2019-09-01" \
+ --headers "Content-Type=application/json" \
+ --body "{\"name\": \"$AZURE_AKV_RESOURCE_NAME\", \"type\": \"Microsoft.KeyVault/vaults\"}" | jq -r '.nameAvailable')
+ if [ "$NAME_AVAILABLE" == true ]; then
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME does not exist. Creating it now..."
+    echo CREATING $AZURE_KEYVAULT_ENDPOINT in resource group $AZURE_RESOURCE_GROUP
+ # Create Azure key vault with RBAC authorization
+ az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
+ # Assign RBAC roles to the resource owner so they can import keys
+ AKV_SCOPE=`az keyvault show --name $AZURE_AKV_RESOURCE_NAME --query id --output tsv`
+ az role assignment create --role "Key Vault Crypto Officer" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ az role assignment create --role "Key Vault Crypto User" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ else
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME already exists. Skipping creation."
+ fi
+else
+  echo "Automated creation of key vaults is supported only for Azure Key Vault endpoints (*.vault.azure.net)"
+fi
diff --git a/scenarios/covid/data/3-import-keys.sh b/scenarios/cifar10/deployment/azure/3-import-keys.sh
similarity index 86%
rename from scenarios/covid/data/3-import-keys.sh
rename to scenarios/cifar10/deployment/azure/3-import-keys.sh
index 3202001..a010412 100755
--- a/scenarios/covid/data/3-import-keys.sh
+++ b/scenarios/cifar10/deployment/azure/3-import-keys.sh
@@ -29,7 +29,7 @@ echo Obtaining contract service parameters...
CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"http://localhost:8000"}
export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
-envsubst < ../policy/policy-in-template.json > /tmp/policy-in.json
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
echo "Training container policy hash $CCE_POLICY_HASH"
@@ -44,12 +44,12 @@ elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
export AZURE_AKV_KEY_TYPE="oct-HSM"
fi
-DATADIR=`pwd`
-import_key "ICMRFilesystemEncryptionKey" $DATADIR/icmrkey.bin
-import_key "COWINFilesystemEncryptionKey" $DATADIR/cowinkey.bin
-import_key "IndexFilesystemEncryptionKey" $DATADIR/indexkey.bin
-import_key "ModelFilesystemEncryptionKey" $DATADIR/modelkey.bin
-import_key "OutputFilesystemEncryptionKey" $DATADIR/outputkey.bin
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+import_key "CIFAR10FilesystemEncryptionKey" $DATADIR/cifar10_key.bin
+import_key "ModelFilesystemEncryptionKey" $MODELDIR/model_key.bin
+import_key "OutputFilesystemEncryptionKey" $MODELDIR/output_key.bin
## Cleanup
rm /tmp/importkey-config.json
diff --git a/scenarios/cifar10/deployment/azure/4-encrypt-data.sh b/scenarios/cifar10/deployment/azure/4-encrypt-data.sh
new file mode 100755
index 0000000..82d4967
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/4-encrypt-data.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+./generatefs.sh -d $DATADIR/preprocessed -k $DATADIR/cifar10_key.bin -i $DATADIR/cifar10.img
+./generatefs.sh -d $MODELDIR/models -k $MODELDIR/model_key.bin -i $MODELDIR/model.img
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+./generatefs.sh -d $MODELDIR/output -k $MODELDIR/output_key.bin -i $MODELDIR/output.img
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/azure/5-upload-encrypted-data.sh b/scenarios/cifar10/deployment/azure/5-upload-encrypted-data.sh
new file mode 100755
index 0000000..440455a
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/5-upload-encrypted-data.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_CIFAR10_CONTAINER_NAME \
+ --file $DATA_DIR/cifar10.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_MODEL_CONTAINER_NAME \
+ --file $MODEL_DIR/model.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODEL_DIR/output.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
diff --git a/scenarios/mnist/data/6-download-decrypt-model.sh b/scenarios/cifar10/deployment/azure/6-download-decrypt-model.sh
similarity index 75%
rename from scenarios/mnist/data/6-download-decrypt-model.sh
rename to scenarios/cifar10/deployment/azure/6-download-decrypt-model.sh
index 0975334..b6d043a 100755
--- a/scenarios/mnist/data/6-download-decrypt-model.sh
+++ b/scenarios/cifar10/deployment/azure/6-download-decrypt-model.sh
@@ -1,16 +1,21 @@
#!/bin/bash
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+
ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
az storage blob download \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_OUTPUT_CONTAINER_NAME \
- --file output.img \
+ --file $MODELDIR/output.img \
--name data.img \
--account-key $ACCOUNT_KEY
-encryptedImage=output.img
-keyFilePath=outputkey.bin
+encryptedImage=$MODELDIR/output.img
+keyFilePath=$MODELDIR/output_key.bin
echo Decrypting $encryptedImage with key $keyFilePath
deviceName=cryptdevice1
@@ -23,7 +28,7 @@ sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
mountPoint=`mktemp -d`
sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
-cp -r $mountPoint/* ./output/
+cp -r $mountPoint/* $MODELDIR/output/
echo "[!] Closing device..."
diff --git a/scenarios/mnist/deployment/aci/aci-parameters-template.json b/scenarios/cifar10/deployment/azure/aci-parameters-template.json
similarity index 100%
rename from scenarios/mnist/deployment/aci/aci-parameters-template.json
rename to scenarios/cifar10/deployment/azure/aci-parameters-template.json
diff --git a/scenarios/cifar10/deployment/azure/arm-template.json b/scenarios/cifar10/deployment/azure/arm-template.json
new file mode 100644
index 0000000..50deadd
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/arm-template.json
@@ -0,0 +1,181 @@
+{
+ "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
+ "contentVersion": "1.0.0.0",
+ "parameters": {
+ "name": {
+ "defaultValue": "depa-training-cifar10",
+ "type": "string",
+ "metadata": {
+ "description": "Name for the container group"
+ }
+ },
+ "location": {
+ "defaultValue": "northeurope",
+ "type": "string",
+ "metadata": {
+ "description": "Location for all resources."
+ }
+ },
+ "port": {
+ "defaultValue": 8080,
+ "type": "int",
+ "metadata": {
+ "description": "Port to open on the container and the public IP address."
+ }
+ },
+ "containerRegistry": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "The container registry login server."
+ }
+ },
+ "restartPolicy": {
+ "defaultValue": "Never",
+ "allowedValues": [
+ "Always",
+ "Never",
+ "OnFailure"
+ ],
+ "type": "string",
+ "metadata": {
+ "description": "The behavior of Azure runtime if container has stopped."
+ }
+ },
+ "ccePolicy": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "cce policy"
+ }
+ },
+ "EncfsSideCarArgs": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Remote file system information for storage sidecar."
+ }
+ },
+ "ContractService": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "URL of contract service"
+ }
+ },
+ "Contracts": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "List of contracts"
+ }
+ },
+ "ContractServiceParameters": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Contract service parameters"
+ }
+ },
+ "PipelineConfiguration": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Pipeline configuration"
+ }
+ }
+ },
+ "resources": [
+ {
+ "type": "Microsoft.ContainerInstance/containerGroups",
+ "apiVersion": "2023-05-01",
+ "name": "[parameters('name')]",
+ "location": "[parameters('location')]",
+ "properties": {
+ "confidentialComputeProperties": {
+ "ccePolicy": "[parameters('ccePolicy')]"
+ },
+ "containers": [
+ {
+ "name": "depa-training",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training:latest')]",
+ "command": [
+ "/bin/bash",
+ "run.sh"
+ ],
+ "environmentVariables": [],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "resources": {
+ "requests": {
+ "cpu": 3,
+ "memoryInGB": 12
+ }
+ }
+ }
+ },
+ {
+ "name": "encrypted-storage-sidecar",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training-encfs:latest')]",
+ "command": [
+ "/encfs.sh"
+ ],
+ "environmentVariables": [
+ {
+ "name": "EncfsSideCarArgs",
+ "value": "[parameters('EncfsSideCarArgs')]"
+ },
+ {
+ "name": "ContractService",
+ "value": "[parameters('ContractService')]"
+ },
+ {
+ "name": "Contracts",
+ "value": "[parameters('Contracts')]"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "[parameters('ContractServiceParameters')]"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": "[parameters('PipelineConfiguration')]"
+ }
+ ],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "securityContext": {
+ "privileged": "true"
+ },
+ "resources": {
+ "requests": {
+ "cpu": 0.5,
+ "memoryInGB": 2
+ }
+ }
+ }
+ }
+ ],
+ "sku": "Confidential",
+ "osType": "Linux",
+ "restartPolicy": "[parameters('restartPolicy')]",
+ "volumes": [
+ {
+ "name": "remotemounts",
+ "emptydir": {}
+ }
+ ]
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/azure/deploy.sh b/scenarios/cifar10/deployment/azure/deploy.sh
new file mode 100755
index 0000000..f1c5b02
--- /dev/null
+++ b/scenarios/cifar10/deployment/azure/deploy.sh
@@ -0,0 +1,131 @@
+#!/bin/bash
+
+set -e
+
+while getopts ":c:p:" options; do
+ case $options in
+ c)contract=$OPTARG;;
+ p)pipelineConfiguration=$OPTARG;;
+ esac
+done
+
+if [[ -z "${contract}" ]]; then
+ echo "No contract specified"
+ exit 1
+fi
+
+if [[ -z "${pipelineConfiguration}" ]]; then
+ echo "No pipeline configuration specified"
+ exit 1
+fi
+
+if [[ -z "${AZURE_KEYVAULT_ENDPOINT}" ]]; then
+ echo "Environment variable AZURE_KEYVAULT_ENDPOINT not defined"
+fi
+
+echo Obtaining contract service parameters...
+
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"https://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+echo Computing CCE policy...
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+export CONTRACTS=$contract
+export PIPELINE_CONFIGURATION=`cat $pipelineConfiguration | base64 --wrap=0`
+
+function generate_encrypted_filesystem_information() {
+ end=`date -u -d "60 minutes" '+%Y-%m-%dT%H:%MZ'`
+ CIFAR10_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_CIFAR10_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export CIFAR10_SAS_TOKEN="$(echo -n $CIFAR10_SAS_TOKEN | tr -d \")"
+ export CIFAR10_SAS_TOKEN="?$CIFAR10_SAS_TOKEN"
+
+ MODEL_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_MODEL_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export MODEL_SAS_TOKEN=$(echo $MODEL_SAS_TOKEN | tr -d \")
+ export MODEL_SAS_TOKEN="?$MODEL_SAS_TOKEN"
+
+ OUTPUT_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_OUTPUT_CONTAINER_NAME --permissions rw --name data.img --expiry $end --only-show-errors)
+ export OUTPUT_SAS_TOKEN=$(echo $OUTPUT_SAS_TOKEN | tr -d \")
+ export OUTPUT_SAS_TOKEN="?$OUTPUT_SAS_TOKEN"
+
+ # Obtain the token based on the AKV resource endpoint subdomain
+ if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+ echo "Importing keys to AKV key vaults can be only of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+ elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+ fi
+
+ TMP=$(jq . encrypted-filesystem-config-template.json)
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[0].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_CIFAR10_CONTAINER_NAME + "/data.img" + env.CIFAR10_SAS_TOKEN' | \
+ jq '.azure_filesystems[0].mount_point = "/mnt/remote/cifar10"' | \
+ jq '.azure_filesystems[0].key.kid = "CIFAR10FilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[0].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[0].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[0].key_derivation.label = "CIFAR10FilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[1].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_MODEL_CONTAINER_NAME + "/data.img" + env.MODEL_SAS_TOKEN' | \
+ jq '.azure_filesystems[1].mount_point = "/mnt/remote/model"' | \
+ jq '.azure_filesystems[1].key.kid = "ModelFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[1].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[1].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[1].key_derivation.label = "ModelFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[2].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_OUTPUT_CONTAINER_NAME + "/data.img" + env.OUTPUT_SAS_TOKEN' | \
+ jq '.azure_filesystems[2].mount_point = "/mnt/remote/output"' | \
+ jq '.azure_filesystems[2].key.kid = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[2].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[2].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[2].key_derivation.label = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ ENCRYPTED_FILESYSTEM_INFORMATION=`echo $TMP | base64 --wrap=0`
+}
+
+echo Generating encrypted file system information...
+generate_encrypted_filesystem_information
+echo $ENCRYPTED_FILESYSTEM_INFORMATION > /tmp/encrypted-filesystem-config.json
+export ENCRYPTED_FILESYSTEM_INFORMATION
+
+echo Generating parameters for ACI deployment...
+TMP=$(jq '.containerRegistry.value = env.CONTAINER_REGISTRY' aci-parameters-template.json)
+TMP=`echo $TMP | jq '.ccePolicy.value = env.CCE_POLICY'`
+TMP=`echo $TMP | jq '.EncfsSideCarArgs.value = env.ENCRYPTED_FILESYSTEM_INFORMATION'`
+TMP=`echo $TMP | jq '.ContractService.value = env.CONTRACT_SERVICE_URL'`
+TMP=`echo $TMP | jq '.ContractServiceParameters.value = env.CONTRACT_SERVICE_PARAMETERS'`
+TMP=`echo $TMP | jq '.Contracts.value = env.CONTRACTS'`
+TMP=`echo $TMP | jq '.PipelineConfiguration.value = env.PIPELINE_CONFIGURATION'`
+echo $TMP > /tmp/aci-parameters.json
+
+echo Deploying training clean room...
+
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+az deployment group create \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --template-file arm-template.json \
+ --parameters @/tmp/aci-parameters.json
+
+echo Deployment complete.
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/aci/encrypted-filesystem-config-template.json b/scenarios/cifar10/deployment/azure/encrypted-filesystem-config-template.json
similarity index 52%
rename from scenarios/mnist/deployment/aci/encrypted-filesystem-config-template.json
rename to scenarios/cifar10/deployment/azure/encrypted-filesystem-config-template.json
index b710182..2af9e95 100644
--- a/scenarios/mnist/deployment/aci/encrypted-filesystem-config-template.json
+++ b/scenarios/cifar10/deployment/azure/encrypted-filesystem-config-template.json
@@ -5,69 +5,66 @@
"azure_url_private": false,
"read_write": false,
"mount_point": "",
- "key":{
+ "key": {
"kid": "",
"kty": "",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"akv": {
- "endpoint": "",
- "api_version": "api-version=7.3-preview",
- "bearer_token": ""
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
}
},
- "key_derivation":
- {
+ "key_derivation": {
"salt": "",
"label": ""
- }
+ }
},
{
"azure_url": "",
"azure_url_private": false,
"read_write": false,
"mount_point": "",
- "key":{
+ "key": {
"kid": "",
"kty": "",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"akv": {
- "endpoint": "",
- "api_version": "api-version=7.3-preview",
- "bearer_token": ""
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
}
},
- "key_derivation":
- {
+ "key_derivation": {
"salt": "",
"label": ""
- }
+ }
},
{
"azure_url": "",
"azure_url_private": false,
- "read_write": false,
+ "read_write": true,
"mount_point": "",
- "key":{
+ "key": {
"kid": "",
"kty": "",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"akv": {
- "endpoint": "",
- "api_version": "api-version=7.3-preview",
- "bearer_token": ""
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
}
},
- "key_derivation":
- {
+ "key_derivation": {
"salt": "",
"label": ""
- }
+ }
}
- ]
+ ]
}
\ No newline at end of file
diff --git a/scenarios/mnist/data/generatefs.sh b/scenarios/cifar10/deployment/azure/generatefs.sh
similarity index 100%
rename from scenarios/mnist/data/generatefs.sh
rename to scenarios/cifar10/deployment/azure/generatefs.sh
diff --git a/scenarios/mnist/data/importkey-config-template.json b/scenarios/cifar10/deployment/azure/importkey-config-template.json
similarity index 100%
rename from scenarios/mnist/data/importkey-config-template.json
rename to scenarios/cifar10/deployment/azure/importkey-config-template.json
diff --git a/scenarios/cifar10/deployment/local/docker-compose-modelsave.yml b/scenarios/cifar10/deployment/local/docker-compose-modelsave.yml
new file mode 100644
index 0000000..2957150
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/docker-compose-modelsave.yml
@@ -0,0 +1,7 @@
+services:
+ model_save:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}cifar10-model-save:latest
+ volumes:
+ - $MODEL_OUTPUT_PATH:/mnt/model
+ - $MODEL_CONFIG_PATH:/mnt/config/model_config.json
+ command: ["python3", "save_base_model.py"]
diff --git a/scenarios/cifar10/deployment/local/docker-compose-preprocess.yml b/scenarios/cifar10/deployment/local/docker-compose-preprocess.yml
new file mode 100644
index 0000000..bc171d0
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/docker-compose-preprocess.yml
@@ -0,0 +1,7 @@
+services:
+ cifar10:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-cifar10:latest
+ volumes:
+ - $CIFAR10_INPUT_PATH:/mnt/input/data
+ - $CIFAR10_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_cifar10.py"]
diff --git a/scenarios/cifar10/deployment/local/docker-compose-train.yml b/scenarios/cifar10/deployment/local/docker-compose-train.yml
new file mode 100644
index 0000000..336ce34
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/docker-compose-train.yml
@@ -0,0 +1,10 @@
+services:
+ train:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-training:latest
+ volumes:
+ - $CIFAR10_INPUT_PATH:/mnt/remote/cifar10
+ - $MODEL_INPUT_PATH:/mnt/remote/model
+ - $MODEL_OUTPUT_PATH:/mnt/remote/output
+ - $CONFIGURATION_PATH:/mnt/remote/config
+ command: ["/bin/bash", "run.sh"]
+
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/local/preprocess.sh b/scenarios/cifar10/deployment/local/preprocess.sh
new file mode 100755
index 0000000..53510a2
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/preprocess.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="cifar10"
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export CIFAR10_INPUT_PATH=$DATA_DIR
+export CIFAR10_OUTPUT_PATH=$DATA_DIR/preprocessed
+mkdir -p $CIFAR10_OUTPUT_PATH
+docker compose -f docker-compose-preprocess.yml up --remove-orphans
diff --git a/scenarios/cifar10/deployment/local/save-model.sh b/scenarios/cifar10/deployment/local/save-model.sh
new file mode 100755
index 0000000..b154c7b
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/save-model.sh
@@ -0,0 +1,9 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="cifar10"
+export MODEL_OUTPUT_PATH=$REPO_ROOT/scenarios/$SCENARIO/modeller/models
+mkdir -p $MODEL_OUTPUT_PATH
+export MODEL_CONFIG_PATH=$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json
+
+docker compose -f docker-compose-modelsave.yml up --remove-orphans
\ No newline at end of file
diff --git a/scenarios/cifar10/deployment/local/train.sh b/scenarios/cifar10/deployment/local/train.sh
new file mode 100755
index 0000000..7aca6d3
--- /dev/null
+++ b/scenarios/cifar10/deployment/local/train.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="cifar10"
+
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+export CIFAR10_INPUT_PATH=$DATA_DIR/preprocessed
+
+export MODEL_INPUT_PATH=$MODEL_DIR/models
+
+# export MODEL_OUTPUT_PATH=/tmp/output
+export MODEL_OUTPUT_PATH=$MODEL_DIR/output
+sudo rm -rf $MODEL_OUTPUT_PATH
+mkdir -p $MODEL_OUTPUT_PATH
+
+export CONFIGURATION_PATH=$REPO_ROOT/scenarios/$SCENARIO/config
+# export CONFIGURATION_PATH=/tmp
+# cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
+
+# Run consolidate_pipeline.sh to create pipeline_config.json
+$REPO_ROOT/scenarios/$SCENARIO/config/consolidate_pipeline.sh
+
+docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/cifar10/export-variables.sh b/scenarios/cifar10/export-variables.sh
new file mode 100755
index 0000000..54b0936
--- /dev/null
+++ b/scenarios/cifar10/export-variables.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# Azure Naming Rules:
+#
+# Resource Group:
+# - 1–90 characters
+# - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+# - Cannot end with a period (.)
+# - Case-insensitive, unique within subscription
+#
+# Key Vault:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with letter or number
+#
+# Storage Account:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters and numbers only
+#
+# Storage Container:
+# - 3-63 characters
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with a letter or number
+# - No consecutive hyphens
+# - Unique within storage account
+
+# For cloud resource creation:
+declare -x SCENARIO=cifar10
+declare -x REPO_ROOT="$(git rev-parse --show-toplevel)"
+declare -x CONTAINER_REGISTRY=ispirt.azurecr.io
+declare -x AZURE_LOCATION=centralindia
+declare -x AZURE_SUBSCRIPTION_ID=
+declare -x AZURE_RESOURCE_GROUP=
+declare -x AZURE_KEYVAULT_ENDPOINT=
+declare -x AZURE_STORAGE_ACCOUNT_NAME=
+
+declare -x AZURE_CIFAR10_CONTAINER_NAME=cifar10container
+declare -x AZURE_MODEL_CONTAINER_NAME=modelcontainer
+declare -x AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+
+# For key import:
+declare -x CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+declare -x TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+# Export all variables to make them available to other scripts
+export SCENARIO
+export REPO_ROOT
+export CONTAINER_REGISTRY
+export AZURE_LOCATION
+export AZURE_SUBSCRIPTION_ID
+export AZURE_RESOURCE_GROUP
+export AZURE_KEYVAULT_ENDPOINT
+export AZURE_STORAGE_ACCOUNT_NAME
+export AZURE_CIFAR10_CONTAINER_NAME
+export AZURE_MODEL_CONTAINER_NAME
+export AZURE_OUTPUT_CONTAINER_NAME
+export CONTRACT_SERVICE_URL
+export TOOLS_HOME
\ No newline at end of file
diff --git a/scenarios/cifar10/policy/policy-in-template.json b/scenarios/cifar10/policy/policy-in-template.json
new file mode 100644
index 0000000..c093bdd
--- /dev/null
+++ b/scenarios/cifar10/policy/policy-in-template.json
@@ -0,0 +1,63 @@
+{
+ "version": "1.0",
+ "containers": [
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training:latest",
+ "command": [
+ "/bin/bash",
+ "run.sh"
+ ],
+ "environmentVariables": [],
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ },
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training-encfs:latest",
+ "environmentVariables": [
+ {
+ "name" : "EncfsSideCarArgs",
+ "value" : ".+",
+ "strategy" : "re2"
+ },
+ {
+ "name": "ContractService",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "$CONTRACT_SERVICE_PARAMETERS",
+ "strategy": "string"
+ },
+ {
+ "name": "Contracts",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": ".+",
+ "strategy": "re2"
+ }
+ ],
+ "command": [
+ "/encfs.sh"
+ ],
+ "securityContext": {
+ "privileged": "true"
+ },
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ }
+ ]
+}
diff --git a/scenarios/cifar10/src/model_constructor.py b/scenarios/cifar10/src/model_constructor.py
new file mode 100644
index 0000000..1d01585
--- /dev/null
+++ b/scenarios/cifar10/src/model_constructor.py
@@ -0,0 +1,362 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+
+from typing import Any, Dict, List, Tuple, Callable
+import types
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+APPROVED_NAMESPACES = {
+ "torch": torch,
+ "nn": nn,
+ "F": F,
+}
+
+# Security controls
+ALLOWED_OP_PREFIXES = {"F.", "torch.nn.functional."} # Allow only torch.nn.functional.* by default
+ALLOWED_OPS = {
+ "torch.cat", "torch.stack", "torch.concat", "torch.flatten", "torch.reshape", "torch.permute", "torch.transpose",
+ "torch.unsqueeze", "torch.squeeze", "torch.chunk", "torch.split", "torch.gather", "torch.index_select", "torch.narrow",
+ "torch.sum", "torch.mean", "torch.std", "torch.var", "torch.max", "torch.min", "torch.argmax", "torch.argmin", "torch.norm",
+ "torch.exp", "torch.log", "torch.log1p", "torch.sigmoid", "torch.tanh", "torch.softmax", "torch.log_softmax", "torch.relu", "torch.gelu",
+ "torch.matmul", "torch.mm", "torch.bmm", "torch.addmm", "torch.einsum",
+    "torch.roll", "torch.flip", "torch.rot90", "torch.rot180", "torch.rot270", "torch.rot360",  # NOTE(review): torch.rot180/rot270/rot360 do not exist in PyTorch (only torch.rot90) — dead allowlist entries
+}
+
+# Denylist of potentially dangerous kwarg names (case-insensitive)
+DENYLIST_ARG_NAMES = {
+ "out", # in-place writes to user-provided buffers
+ "file", "filename", "path", "dir", "directory", # filesystem
+ "map_location", # avoid device remap surprises
+}
+
+# DoS safeguards
+MAX_FORWARD_STEPS = 200
+MAX_OPS_PER_STEP = 10
+
+
+def _resolve_submodule(path: str) -> Any:
+ """Resolve dotted path like 'nn.Conv2d' or 'torch.sigmoid' to an object.
+ Raises AttributeError if resolution fails.
+ """
+ try:
+ if not isinstance(path, str):
+ raise TypeError("path must be a string")
+ parts = path.split(".")
+ if parts[0] in APPROVED_NAMESPACES:
+ obj = APPROVED_NAMESPACES[parts[0]]
+ else:
+ # allow direct module names like 'math' if needed
+ raise AttributeError(f"Unknown root namespace '{parts[0]}' in path '{path}'")
+ for p in parts[1:]:
+ try:
+ obj = getattr(obj, p)
+ except AttributeError:
+ raise AttributeError(f"Could not resolve attribute '{p}' in path '{path}'")
+ return obj
+ except Exception as e:
+ raise RuntimeError(f"Error resolving dotted path '{path}': {str(e)}") from e
+
+
+def _replace_placeholders(obj: Any, params: Dict[str, Any]) -> Any:
+ """Recursively replace strings of the form '$name' using params mapping."""
+ try:
+ if isinstance(obj, str) and obj.startswith("$"):
+ key = obj[1:]
+ if key not in params:
+ raise KeyError(f"Placeholder '{obj}' not found in params {params}")
+ return params[key]
+ elif isinstance(obj, dict):
+ try:
+ return {k: _replace_placeholders(v, params) for k, v in obj.items()}
+ except Exception as e:
+ raise RuntimeError(f"Error replacing placeholders in dict: {str(e)}") from e
+ elif isinstance(obj, (list, tuple)):
+ try:
+ seq_type = list if isinstance(obj, list) else tuple
+ return seq_type(_replace_placeholders(x, params) for x in obj)
+ except Exception as e:
+ raise RuntimeError(f"Error replacing placeholders in sequence: {str(e)}") from e
+ else:
+ return obj
+ except Exception as e:
+ raise RuntimeError(f"Error in placeholder replacement: {str(e)}") from e
+
+
+class ModelFactory:
+ """Factory for building PyTorch nn.Module instances from config dicts.
+
+ Public API:
+ ModelFactory.load_from_dict(config: dict) -> nn.Module
+ """
+
+ @classmethod
+ def load_from_dict(cls, config: Dict[str, Any]) -> nn.Module:
+ """Create an nn.Module instance from a top-level config.
+
+ The config may define 'submodules' (a dict of reusable component templates) and
+ a top-level 'layers' and 'forward' graph. Submodules are used by layers that have
+ a 'submodule' key and are instantiated with their provided params.
+ """
+ try:
+ if not isinstance(config, dict):
+ raise TypeError("Config must be a dictionary")
+
+ submodules_defs = config.get("submodules", {})
+
+ def create_instance_from_def(def_cfg: Dict[str, Any], provided_params: Dict[str, Any]):
+ try:
+ # Replace placeholders in the def_cfg copy
+ # Deep copy not strictly necessary since we replace on the fly
+ replaced_cfg = {
+ k: (_replace_placeholders(v, provided_params) if k in ("layers",) or isinstance(v, dict) else v)
+ for k, v in def_cfg.items()
+ }
+ # Build module from replaced config (submodule templates should not themselves contain further 'submodules')
+ return cls._build_module_from_config(replaced_cfg, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error creating instance from definition: {str(e)}") from e
+
+ # When a layer entry references a 'submodule', we instantiate it using template from submodules_defs
+ return cls._build_module_from_config(config, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error loading model from config: {str(e)}") from e
+
+ @classmethod
+ def _build_module_from_config(cls, config: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+ try:
+ layers_cfg = config.get("layers", {})
+ forward_cfg = config.get("forward", [])
+ input_names = config.get("input", [])
+ output_names = config.get("output", [])
+
+ # Create dynamic module class
+ class DynamicModule(nn.Module):
+ def __init__(self):
+ try:
+ super().__init__()
+ # ModuleDict to register submodules / layers
+ self._layers = nn.ModuleDict()
+ # Save forward graph and io names
+ self._forward_cfg = forward_cfg
+ self._input_names = input_names
+ self._output_names = output_names
+
+ # Build each layer / submodule
+ for name, entry in layers_cfg.items():
+ try:
+ if "class" in entry:
+ cls_obj = _resolve_submodule(entry["class"]) # e.g. nn.Conv2d
+ if not (isinstance(cls_obj, type) and issubclass(cls_obj, nn.Module)):
+ raise TypeError(f"Layer '{name}' class must be an nn.Module subclass, got {cls_obj}")
+ params = entry.get("params", {})
+ inst_params = _replace_placeholders(params, {}) # top-level layers likely have no placeholders
+ module = cls_obj(**inst_params)
+ self._layers[name] = module
+ elif "submodule" in entry:
+ sub_name = entry["submodule"]
+ if sub_name not in submodules_defs:
+ raise KeyError(f"Submodule '{sub_name}' not found in submodules definitions")
+ sub_def = submodules_defs[sub_name]
+ provided_params = entry.get("params", {})
+ # Replace placeholders inside sub_def using provided_params
+ # We create a fresh instance of submodule by calling helper
+ sub_inst = cls._instantiate_submodule(sub_def, provided_params, submodules_defs)
+ self._layers[name] = sub_inst
+ else:
+ raise KeyError(f"Layer '{name}' must contain either 'class' or 'submodule' key")
+ except Exception as e:
+ raise RuntimeError(f"Error building layer '{name}': {str(e)}") from e
+ except Exception as e:
+ raise RuntimeError(f"Error initializing DynamicModule: {str(e)}") from e
+
+ def forward(self, *args, **kwargs):
+ try:
+ # Map inputs
+ env: Dict[str, Any] = {}
+ # assign by position
+ for i, in_name in enumerate(self._input_names):
+ if i < len(args):
+ env[in_name] = args[i]
+ elif in_name in kwargs:
+ env[in_name] = kwargs[in_name]
+ else:
+ raise ValueError(f"Missing input '{in_name}' for forward; provided args={len(args)}, kwargs keys={list(kwargs.keys())}")
+
+ # Execute forward graph
+ if len(self._forward_cfg) > MAX_FORWARD_STEPS:
+ raise RuntimeError(f"Too many forward steps: {len(self._forward_cfg)} > {MAX_FORWARD_STEPS}. This is a security feature to prevent infinite loops.")
+
+ for idx, step in enumerate(self._forward_cfg):
+ try:
+ ops = step.get("ops", [])
+ if isinstance(ops, (list, tuple)) and len(ops) > MAX_OPS_PER_STEP:
+ raise RuntimeError(f"Too many ops in step {idx}: {len(ops)} > {MAX_OPS_PER_STEP}")
+ inputs_spec = step.get("input", [])
+ out_name = step.get("output", None)
+
+ # Resolve input tensors for this step
+ # inputs_spec might be: ['x'] or ['x1','x2'] or [['x3','encoded_feature']]
+ if len(inputs_spec) == 1 and isinstance(inputs_spec[0], (list, tuple)):
+ args_list = [env[n] for n in inputs_spec[0]]
+ else:
+ args_list = [env[n] for n in inputs_spec]
+
+ # Apply ops sequentially
+ current = args_list
+ for op in ops:
+ try:
+ # op can be string like 'conv1' or dotted 'F.relu'
+ # or can be a list like ['torch.flatten', {'start_dim':1}]
+ op_callable, op_kwargs = self._resolve_op(op)
+ # Validate kwargs denylist
+ for k in op_kwargs.keys():
+ if isinstance(k, str) and k.lower() in DENYLIST_ARG_NAMES:
+ raise PermissionError(f"Denied kwarg '{k}' for op '{op}'")
+
+ # If op_callable is a module in self._layers, call with module semantics
+ if isinstance(op_callable, str) and op_callable in self._layers:
+ module = self._layers[op_callable]
+ # if current is list of multiple args, pass them all
+ if isinstance(current, (list, tuple)) and len(current) > 1:
+ result = module(*current)
+ else:
+ result = module(current[0])
+ else:
+ # op_callable is a real callable object
+
+ if op_callable in {torch.cat, torch.stack}: # Ops that require a sequence input (instead of varargs)
+ # Wrap current into a list
+ result = op_callable(list(current), **op_kwargs)
+ elif isinstance(current, (list, tuple)):
+ result = op_callable(*current, **op_kwargs)
+ else:
+ result = op_callable(current, **op_kwargs)
+
+ # prepare current for next op
+ current = [result]
+ except Exception as e:
+ raise RuntimeError(f"Error applying operation '{op}': {str(e)}") from e
+
+ # write outputs back into env
+ if out_name is None:
+ continue
+ if isinstance(out_name, (list, tuple)):
+ # if step produces multiple outputs (rare), try unpacking
+ if len(out_name) == 1:
+ env[out_name[0]] = current[0]
+ else:
+ # try to unpack
+ try:
+ for k, v in zip(out_name, current[0]):
+ env[k] = v
+ except Exception as e:
+ raise RuntimeError(f"Could not assign multiple outputs for step {step}: {e}")
+ else:
+ env[out_name] = current[0]
+ except Exception as e:
+ raise RuntimeError(f"Error executing forward step: {str(e)}") from e
+
+ # Build function return
+ if len(self._output_names) == 0:
+ return None
+ if len(self._output_names) == 1:
+ return env[self._output_names[0]]
+ return tuple(env[n] for n in self._output_names)
+ except Exception as e:
+ raise RuntimeError(f"Error in forward pass: {str(e)}") from e
+
+ def _resolve_op(self, op_spec):
+ """Return (callable_or_module_name, kwargs)
+
+ If op_spec is a string and matches a layer name -> returns (layer_name_str, {}).
+ If op_spec is a string dotted path -> resolve dotted and return (callable, {}).
+ If op_spec is a list like ["torch.flatten", {"start_dim":1}] -> resolve and return (callable, kwargs)
+ """
+ try:
+ # module reference by name
+ if isinstance(op_spec, str):
+ if op_spec in self._layers:
+ return (op_spec, {})
+ # dotted function (F.relu, torch.sigmoid)
+ if not _is_allowed_op_path(op_spec):
+ raise PermissionError(f"Operation '{op_spec}' is not allowed")
+ callable_obj = _resolve_submodule(op_spec)
+ if not callable(callable_obj):
+ raise TypeError(f"Resolved object for '{op_spec}' is not callable")
+ return (callable_obj, {})
+ elif isinstance(op_spec, (list, tuple)):
+ if len(op_spec) == 0:
+ raise ValueError("Empty op_spec list")
+ path = op_spec[0]
+ kwargs = op_spec[1] if len(op_spec) > 1 else {}
+ if not _is_allowed_op_path(path):
+ raise PermissionError(f"Operation '{path}' is not allowed")
+ callable_obj = _resolve_submodule(path)
+ if not callable(callable_obj):
+ raise TypeError(f"Resolved object for '{path}' is not callable")
+ return (callable_obj, kwargs)
+ else:
+ raise TypeError(f"Unsupported op spec type: {type(op_spec)}")
+ except Exception as e:
+ raise RuntimeError(f"Error resolving operation '{op_spec}': {str(e)}") from e
+
+ # Instantiate dynamic module and return
+ dyn = DynamicModule()
+ return dyn
+ except Exception as e:
+ raise RuntimeError(f"Error building module from config: {str(e)}") from e
+
+ @classmethod
+ def _instantiate_submodule(cls, sub_def: Dict[str, Any], provided_params: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+ """Instantiate a submodule defined in 'submodules' using provided_params to replace placeholders.
+
+ provided_params are used to replace occurrences of strings like '$in_ch' inside the sub_def's 'layers' params.
+ """
+ try:
+ # Deep replace placeholders within sub_def copy
+ # We'll construct a new config where the "layers"->"params" are substituted
+ replaced = {}
+ for k, v in sub_def.items():
+ try:
+ if k == "layers":
+ new_layers = {}
+ for lname, lentry in v.items():
+ new_entry = dict(lentry)
+ if "params" in lentry:
+ new_entry["params"] = _replace_placeholders(lentry["params"], provided_params)
+ new_layers[lname] = new_entry
+ replaced[k] = new_layers
+ else:
+ # copy other keys directly (input/forward/output)
+ replaced[k] = v
+ except Exception as e:
+ raise RuntimeError(f"Error processing key '{k}': {str(e)}") from e
+
+ # Now build a module from this replaced config. This call may in turn instantiate nested submodules.
+ return cls._build_module_from_config(replaced, submodules_defs)
+ except Exception as e:
+ raise RuntimeError(f"Error instantiating submodule: {str(e)}") from e
+
+
+def _is_allowed_op_path(path: str) -> bool:
+ if any(path.startswith(p) for p in ALLOWED_OP_PREFIXES):
+ return True
+ return path in ALLOWED_OPS
\ No newline at end of file
diff --git a/scenarios/cifar10/src/preprocess_cifar10.py b/scenarios/cifar10/src/preprocess_cifar10.py
new file mode 100644
index 0000000..214ad9e
--- /dev/null
+++ b/scenarios/cifar10/src/preprocess_cifar10.py
@@ -0,0 +1,52 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import torch
+import torchvision
+import torchvision.transforms as transforms
+from safetensors.torch import save_file as st_save
+
+cifar10_input_folder='/mnt/input/data/'
+
+# Location of preprocessed CIFAR-10 dataset
+cifar10_output_folder='/mnt/output/preprocessed/'
+
+transform = transforms.Compose(
+ [transforms.ToTensor(),
+     transforms.Normalize((0.1307,), (0.3081,))]) # NOTE(review): these are the MNIST mean/std; CIFAR-10 per-channel stats are ~(0.4914, 0.4822, 0.4465)/(0.2470, 0.2435, 0.2616) — confirm intent
+
+trainset = torchvision.datasets.CIFAR10(root=cifar10_input_folder, train=True, download=True, transform=transform)
+
+# Build tensors (N, C, H, W) and labels (N,)
+features = []
+targets = []
+for img, label in trainset:
+ features.append(img)
+ targets.append(label)
+
+features = torch.stack(features).to(torch.float32)
+targets = torch.tensor(targets, dtype=torch.int64)
+
+# Ensure output directory exists
+os.makedirs(cifar10_output_folder, exist_ok=True)
+
+# Save as SafeTensors with keys 'features' and 'targets'
+out_path = os.path.join(cifar10_output_folder, 'cifar10-dataset.safetensors')
+st_save({'features': features, 'targets': targets}, out_path)
+
+print(f"Saved CIFAR-10 dataset to {out_path} as SafeTensors with keys 'features' and 'targets'.")
diff --git a/scenarios/cifar10/src/save_base_model.py b/scenarios/cifar10/src/save_base_model.py
new file mode 100644
index 0000000..dbfb7d4
--- /dev/null
+++ b/scenarios/cifar10/src/save_base_model.py
@@ -0,0 +1,60 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import json
+from safetensors.torch import save_file as st_save
+
+from model_constructor import *
+
+model_config_path = "/mnt/config/model_config.json"
+
+with open(model_config_path, 'r') as f:
+ config = json.load(f)
+
+model = ModelFactory.load_from_dict(config)
+
+model.eval()
+st_save(model.state_dict(), "/mnt/model/model.safetensors")
+print("Model saved as safetensors to /mnt/model/model.safetensors")
+
+
+'''
+# Reference model for CIFAR-10
+
+class Net(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(3, 6, 5)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(6, 16, 5)
+ self.fc1 = nn.Linear(16 * 5 * 5, 120)
+ self.fc2 = nn.Linear(120, 84)
+ self.fc3 = nn.Linear(84, 10)
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = torch.flatten(x, 1) # flatten all dimensions except batch
+ x = F.relu(self.fc1(x))
+ x = F.relu(self.fc2(x))
+ x = self.fc3(x)
+ return x
+
+'''
\ No newline at end of file
diff --git a/scenarios/covid/.gitignore b/scenarios/covid/.gitignore
index 11ead02..1f98b27 100644
--- a/scenarios/covid/.gitignore
+++ b/scenarios/covid/.gitignore
@@ -1,3 +1,12 @@
**/preprocessed/*.csv
*.bin
-*.img
\ No newline at end of file
+*.img
+*.pth
+*.pt
+*.onnx
+*.npy
+
+# Ignore modeller output folder (paths here are relative to this .gitignore's directory, scenarios/covid)
+modeller/output/
+
+**/__pycache__/
\ No newline at end of file
diff --git a/scenarios/covid/README.md b/scenarios/covid/README.md
index 319bdd2..913d0fb 100644
--- a/scenarios/covid/README.md
+++ b/scenarios/covid/README.md
@@ -1,71 +1,152 @@
-# COVID predictive modelling
+# COVID-19 Predictive Modelling
-This hypothetical scenario involves three training data providers (TDPs), ICMR, COWIN and a state war room, and a TDC who wishes the train a model using datasets from these TDPs. The repository contains sample datasets and a model. The model and datasets are for illustrative purposes only; none of these organizations have been involved in contributing to the code or datasets.
+## Scenario Type
-The end-to-end training pipeline consists of the following phases.
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| COVID-19 | Training - Deep Learning | Binary Classification | Differentially Private | 3 | PII tabular data (CSV) | MLP (ONNX) | Horizontal (3)|
+
+---
+
+## Scenario Description
+
+This hypothetical scenario involves three Training Data Providers (TDPs), ICMR (providing Covid test results), COWIN (providing vaccine data) and a State War Room ("Index") (providing patient records), and a Training Data Consumer (TDC) who wishes to train a model for predicting Covid infection using these datasets. The model and datasets are for illustrative purposes only; none of these organizations have been involved in contributing to the code or datasets. The data used in this scenario is not real or representative of any actual data, and has been synthetically generated.
+
+The end-to-end training pipeline consists of the following phases:
1. Data pre-processing and de-identification
2. Data packaging, encryption and upload
3. Model packaging, encryption and upload
4. Encryption key import with key release policies
5. Deployment and execution of CCR
-6. Model decryption
+6. Trained model decryption
## Build container images
-Build container images required for this sample as follows.
+Build container images required for this sample as follows:
```bash
-cd scenarios/covid
+export SCENARIO=covid
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+cd $REPO_ROOT/scenarios/$SCENARIO
./ci/build.sh
-
```
-This script builds the following container images.
+This script builds the following container images:
- ```preprocess-icmr, preprocess-cowin, preprocess-index```: Containers that pre-process and de-identify datasets.
- ```ccr-model-save```: Container that saves the model to be trained in ONNX format.
+Alternatively, you can pull and use pre-built container images from the ispirt container registry by setting the following environment variable. Docker Hub has started throttling, which may affect the upload/download time, especially when images are larger. So, it is advisable to use other container registries. We are using Azure Container Registry (ACR) as shown below:
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/build.sh
+```
+
## Data pre-processing and de-identification
-The folders ```scenarios/covid/data``` contains three sample training datasets. Acting as TDPs for these datasets, run the following scripts to de-identify the datasets.
+The folder ```scenarios/covid/src``` contains scripts for pre-processing and de-identifying sample COVID-19 datasets. Acting as a Training Data Provider (TDP), prepare your datasets:
```bash
-cd scenarios/covid/deployment/docker
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/local
./preprocess.sh
```
-This script performs pre-processing and de-identification of these datasets before sharing with the TDC.
+The datasets are saved to the [data](./data/) directory.
## Prepare model for training
-Next, acting as a TDC, save a sample model using the following script.
+Next, acting as a Training Data Consumer (TDC), define and save your base model for training using the following script. This calls the [save_base_model.py](./src/save_base_model.py) script, which is a custom script that saves the model to the [models](./modeller/models) directory, as an ONNX file:
```bash
./save-model.sh
```
-This script will save the model as ```scenarios/covid/data/modeller/model/model.onnx.```
-
## Deploy locally
-Assuming you have cleartext access to all the de-identified datasets, you can train the model as follows.
+Assuming you have cleartext access to all the de-identified datasets, you can train the model as follows:
```bash
./train.sh
```
-The script joins the datasets and trains the model using a pipeline configuration defined in [pipeline_config.json](./config/pipeline_config.json). If all goes well, you should see output similar to the following output, and the trained model will be saved under the folder `/tmp/output`.
+The script joins the datasets and trains the model using a pipeline configuration. To modify the various components of the training pipeline, you can edit the training config files in the [config](./config/) directory. The training config files are used to create the pipeline configuration ([pipeline_config.json](./config/pipeline_config.json)) created by consolidating all the TDC's training config files, namely the [model config](./config/model_config.json), [dataset config](./config/dataset_config.json), [loss function config](./config/loss_config.json), [training config](./config/train_config_template.json), [evaluation config](./config/eval_config.json), and if multiple datasets are used, the [data join config](./config/join_config.json). These enable the TDC to design highly customized training pipelines without requiring review and approval of new custom code for each use case—reducing risks from potentially malicious or non-compliant code. The consolidated pipeline configuration is then attested against the signed contract using the TDP’s policy-as-code. If approved, it is executed in the CCR to train the model, which we will deploy in the next section.
+
+```mermaid
+flowchart TD
+
+ subgraph Config Files
+ C1[model_config.json]
+ C2[dataset_config.json]
+ C3[loss_config.json]
+ C4[train_config_template.json]
+ C5[eval_config.json]
+ C6[join_config.json]
+ end
+
+  B["Consolidated into<br/>pipeline_config.json"]
+
+ C1 --> B
+ C2 --> B
+ C3 --> B
+ C4 --> B
+ C5 --> B
+ C6 --> B
+
+ B --> D[Attested against contract
+  using policy-as-code]
+ D --> E{Approved?}
+ E -- Yes --> F[CCR training begins]
+ E -- No --> H[Rejected: fix config]
```
-docker-train-1 | {'input_dataset_path': '/tmp/sandbox_icmr_cowin_index_without_key_identifiers.csv', 'saved_model_path': '/mnt/remote/model/model.onnx', 'saved_model_optimizer': '/mnt/remote/model/dpsgd_model_opimizer.pth', 'saved_weights_path': '', 'batch_size': 2, 'total_epochs': 5, 'max_grad_norm': 0.1, 'epsilon_threshold': 1.0, 'delta': 0.01, 'sample_size': 60000, 'target_variable': 'icmr_a_icmr_test_result', 'test_train_split': 0.2, 'metrics': ['accuracy', 'precision', 'recall']}
-docker-train-1 | Epoch [1/5], Loss: 0.0084
-docker-train-1 | Epoch [2/5], Loss: 0.4231
-docker-train-1 | Epoch [3/5], Loss: 0.0008
-docker-train-1 | Epoch [4/5], Loss: 0.0138
-docker-train-1 | Epoch [5/5], Loss: 0.0489
+
+Note: A few model config variants for training can be found [here](./config/sample_variants).
+
+If all goes well, you should see output similar to the following, and the trained model and evaluation metrics will be saved under the folder [output](./modeller/output).
+
```
+train-1 | Generating aggregated data in /tmp/covid_joined.csv
+train-1 | Training samples: 1483
+train-1 | Validation samples: 424
+train-1 | Test samples: 212
+train-1 | Dataset constructed from config
+train-1 | Model loaded from ONNX file
+train-1 | Created non-private baseline model for comparison
+train-1 | Optimizer SGD loaded from config
+train-1 | Custom loss function loaded from config
+train-1 | Epoch 1/5 completed | Training Loss: 0.6092 | Epsilon: 0.3496
+train-1 | Epoch 1/5 completed | Validation Loss: 0.7853
+train-1 | Epoch 2/5 completed | Training Loss: 0.8024 | Epsilon: 0.6328
+train-1 | Epoch 2/5 completed | Validation Loss: 0.7784
+train-1 | Epoch 3/5 completed | Training Loss: 0.6584 | Epsilon: 0.8921
+train-1 | Epoch 3/5 completed | Validation Loss: 0.6838
+train-1 | Epoch 4/5 completed | Training Loss: 0.6794 | Epsilon: 1.2342
+train-1 | Epoch 4/5 completed | Validation Loss: 0.6140
+train-1 | Epoch 5/5 completed | Training Loss: 0.5554 | Epsilon: 1.4947
+train-1 | Epoch 5/5 completed | Validation Loss: 0.6091
+train-1 | Non-private baseline model - Epoch 1/5 completed | Training Loss: 0.4564
+train-1 | Non-private baseline model - Epoch 1/5 completed | Validation Loss: 0.6719
+train-1 | Non-private baseline model - Epoch 2/5 completed | Training Loss: 0.5467
+train-1 | Non-private baseline model - Epoch 2/5 completed | Validation Loss: 0.6719
+train-1 | Non-private baseline model - Epoch 3/5 completed | Training Loss: 0.4356
+train-1 | Non-private baseline model - Epoch 3/5 completed | Validation Loss: 0.6719
+train-1 | Non-private baseline model - Epoch 4/5 completed | Training Loss: 0.3814
+train-1 | Non-private baseline model - Epoch 4/5 completed | Validation Loss: 0.6719
+train-1 | Non-private baseline model - Epoch 5/5 completed | Training Loss: 0.3352
+train-1 | Non-private baseline model - Epoch 5/5 completed | Validation Loss: 0.4864
+train-1 | Saving trained model to /mnt/remote/output/trained_model.onnx
+train-1 | Evaluation Metrics: {'test_loss': 0.6747791530951014, 'accuracy': 0.6415094339622641, 'f1_score': 0.5365853658536586, 'roc_auc': 0.6637159032424087}
+train-1 | CCR Training complete!
+train-1 |
+train-1 exited with code 0
+```
+
+The trained model along with sample predictions on the validation set will be saved under the [output](./modeller/output/) directory.
-## Deploy to Azure
+Now that training has run successfully locally, let's move on to the actual execution using a Confidential Clean Room (CCR) equipped with confidential computing, key release policies, and contract-based access control.
+
+## Deploy on CCR
In a more realistic scenario, these datasets will not be available in the clear to the TDC, and the TDC will be required to use a CCR for training her model. The following steps describe the process of sharing encrypted datasets with TDCs and setting up a CCR in Azure for training models. Please stay tuned for CCR on other cloud platforms.
@@ -73,7 +154,7 @@ To deploy in Azure, you will need the following.
- Docker Hub account to store container images. Alternatively, you can use pre-built images from the ```ispirt``` container registry.
- [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault/) to store encryption keys and implement secure key release to CCR. You can either you Azure Key Vault Premium (lower cost), or [Azure Key Vault managed HSM](https://learn.microsoft.com/en-us/azure/key-vault/managed-hsm/overview) for enhanced security. Please see instructions below on how to create and setup your AKV instance.
-- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances.
+- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances (ACI).
If you are using your own development environment instead of a dev container or codespaces, you will to install the following dependencies.
@@ -87,122 +168,237 @@ We will be creating the following resources as part of the deployment.
- Azure Key Vault
- Azure Storage account
- Storage containers to host encrypted datasets
-- Azure Container Instances to deploy the CCR and train the model
+- Azure Container Instances (ACI) to deploy the CCR and train the model
-### Push Container Images
+### 1\. Push Container Images
-If you wish to use your own container images, login to docker hub and push containers to your container registry.
+Pre-built container images are available in iSPIRT's container registry, which can be pulled by setting the following environment variable.
-> **Note:** Replace `` the name of your container registry name, preferably use registry services other than docker hub as throttling restrictions will cause delays (or) image push/pull failures
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+```
+
+If you wish to use your own container images, login to docker hub (or your container registry of choice) and then build and push the container images to it, so that they can be pulled by the CCR. This is a one-time operation, and you can skip this step if you have already pushed the images to your container registry.
```bash
export CONTAINER_REGISTRY=
-docker login -u ${USERNAME} -p ${PASSWORD} ${CONTAINER_REGISTRY}
+docker login -u <username> -p <password> ${CONTAINER_REGISTRY}
+cd $REPO_ROOT
./ci/push-containers.sh
-cd scenarios/covid
+cd $REPO_ROOT/scenarios/$SCENARIO
./ci/push-containers.sh
```
-### Create Resources
+> **Note:** Replace `<container-registry>`, `<username>` and `<password>` with your container registry name, Docker Hub username and password respectively. Preferably use registry services other than Docker Hub, as throttling restrictions can cause delays or image push/pull failures.
+
+---
-Acting as the TDP, we will create a resource group, a key vault instance and storage containers to host encrypted training datasets and encryption keys. In a real deployments, TDPs and TDCs will use their own key vault instance. However, for this sample, we will use one key vault instance to store keys for all datasets and models.
+### 2\. Create Resources
-> **Note:** At this point, automated creation of AKV managed HSMs is not supported.
+First, set up the necessary environment variables for your deployment.
-> **Note:** Replace `` and `` with names of your choice. Storage account names must not container any special characters. Key vault endpoints are of the form `.vault.azure.net` (for Azure Key Vault Premium) and `.managedhsm.azure.net` for AKV managed HSM, **with no leading https**. This endpoint must be the same endpoint you used while creating the contract.
+Option 1: Manually set the environment variables.
```bash
az login
+export SCENARIO=covid
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+export AZURE_LOCATION=northeurope
+export AZURE_SUBSCRIPTION_ID=
export AZURE_RESOURCE_GROUP=
-export AZURE_KEYVAULT_ENDPOINT=
-export AZURE_STORAGE_ACCOUNT_NAME=
+export AZURE_KEYVAULT_ENDPOINT=<keyvault-name>.vault.azure.net
+export AZURE_STORAGE_ACCOUNT_NAME=
+
export AZURE_ICMR_CONTAINER_NAME=icmrcontainer
export AZURE_COWIN_CONTAINER_NAME=cowincontainer
export AZURE_INDEX_CONTAINER_NAME=indexcontainer
export AZURE_MODEL_CONTAINER_NAME=modelcontainer
export AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+```
-cd scenarios/covid/data
-./1-create-storage-containers.sh
-./2-create-akv.sh
+Option 2: Configurable script to set the environment variables.
+
+Alternatively, you can edit the values in the [export-variables.sh](./export-variables.sh) script and run it to set the environment variables.
+
+```bash
+./export-variables.sh
+source export-variables.sh
```
-### Sign and Register Contract
+Azure Naming Rules:
+- Resource Group:
+ - 1–90 characters
+ - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+ - Cannot end with a period (.)
+  - Case-insensitive, unique within subscription
+- Key Vault:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with letter or number
+- Storage Account:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters and numbers only
+- Storage Container:
+ - 3-63 characters
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with a letter or number
+ - No consecutive hyphens
+ - Unique within storage account
+
+---
+
+**Important:**
+
+The values for the environment variables listed below must precisely match the namesake environment variables used during contract signing (next step). Any mismatch will lead to execution failure.
+
+- `SCENARIO`
+- `AZURE_KEYVAULT_ENDPOINT`
+- `CONTRACT_SERVICE_URL`
+- `AZURE_STORAGE_ACCOUNT_NAME`
+- `AZURE_ICMR_CONTAINER_NAME`
+- `AZURE_COWIN_CONTAINER_NAME`
+- `AZURE_INDEX_CONTAINER_NAME`
+
+---
+With the environment variables set, we are ready to create the resources -- Azure Key Vault and Azure Storage containers.
-Next, follow instructions [here](https://github.com/kapilvgit/contract-ledger/blob/675003b83211e6d3d2c15864523bf875e0172cba/demo/contract/README.md) to sign and register a contract with the contract service. You can either deploy your own contract service or use a test contract service hosted at ```https://contract-service.eastus.cloudapp.azure.com:8000/```. The registered contract must contain references to the datasets with matching names, keyIDs and Azure Key Vault endpoints used in this sample. A sample contract template for this scenario is provided [here](./contract/contract.json). After updating, signing and registering the contract, retain the contract service URL and sequence number of the contract for the rest of this sample.
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/azure
+./1-create-storage-containers.sh
+./2-create-akv.sh
+```
+---
-### Import encryption keys
+### 3\. Contract Signing
-Next, use the following script to generate and import encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys.
+Navigate to the [contract-ledger](https://github.com/kapilvgit/contract-ledger/blob/main/README.md) repository and follow the instructions for contract signing.
-> **Note:** Replace `` with the path to and including the `depa-training` folder where the repository was cloned.
+Once the contract is signed, export the contract sequence number as an environment variable in the same terminal where you set the environment variables for the deployment.
```bash
-export CONTRACT_SERVICE_URL=
-export TOOLS_HOME=/external/confidential-sidecar-containers/tools
-./3-import-keys.sh
+export CONTRACT_SEQ_NO=
```
-The generated keys are available as files with the extension `.bin`.
+---
-### Encrypt Datasets and Model
+### 4\. Data Encryption and Upload
-Next, encrypt the datasets and models using keys generated in the previous step.
+Using their respective keys, the TDPs and TDC encrypt their datasets and model (respectively) and upload them to the Storage containers created in the previous step.
+
+Navigate to the [Azure deployment](./deployment/azure/) directory and execute the scripts for key import, data encryption and upload to Azure Blob Storage, in preparation of the CCR deployment.
+
+The import-keys script generates and imports encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys. The generated keys are available as files with the extension `.bin`.
```bash
-cd scenarios/covid/data
-./4-encrypt-data.sh
+export CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+export TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+./3-import-keys.sh
```
-This step will generate five encrypted file system images (with extension `.img`), three for the datasets, one encrypted file system image containing the model, and one image where the trained model will be stored.
+The data and model are then packaged as encrypted filesystems by the TDPs and TDC using their respective keys, which are saved as `.img` files.
-### Upload Datasets
+```bash
+./4-encrypt-data.sh
+```
-Now upload encrypted datasets to Azure storage containers.
+The encrypted data and model are then uploaded to the Storage containers created in the previous step. The `.img` files are uploaded to the Storage containers as blobs.
```bash
./5-upload-encrypted-data.sh
```
-### Deploy CCR in Azure
+---
-Acting as a TDC, use the following script to deploy the CCR using Confidential Containers on Azure Container Instances.
+### 5\. CCR Deployment
-> **Note:** Replace `` with the sequence number of the contract registered with the contract service.
+With the resources ready, we are ready to deploy the Confidential Clean Room (CCR) for executing the privacy-preserving model training.
```bash
-cd scenarios/covid/deployment/aci
-./deploy.sh -c -p ../../config/pipeline_config.json
+export CONTRACT_SEQ_NO=
+./deploy.sh -c $CONTRACT_SEQ_NO -p ../../config/pipeline_config.json
+```
+
+Set the `$CONTRACT_SEQ_NO` variable to the numeric suffix of the contract sequence number (which has the format 2.XX). For example, if the sequence number is 2.15, export it as:
+
+```bash
+export CONTRACT_SEQ_NO=15
```
This script will deploy the container images from your container registry, including the encrypted filesystem sidecar. The sidecar will generate an SEV-SNP attestation report, generate an attestation token using the Microsoft Azure Attestation (MAA) service, retrieve dataset, model and output encryption keys from the TDP and TDC's Azure Key Vault, train the model, and save the resulting model into TDC's output filesystem image, which the TDC can later decrypt.
-Once the deployment is complete, you can obtain logs from the CCR using the following commands. Note there may be some delay in getting the logs are deployment is complete.
+
+
+**Note:** The completion of this script's execution simply creates a CCR instance, and doesn't indicate whether training has completed or not. The training process might still be ongoing. Poll the container logs (see below) to track progress until training is complete.
+
+### 6\. Monitor Container Logs
+
+Use the following commands to monitor the logs of the deployed containers. You might have to repeatedly poll this command to monitor the training progress:
+
+```bash
+az container logs \
+ --name "depa-training-$SCENARIO" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --container-name depa-training
+```
+
+You will know training has completed when the logs print "CCR Training complete!".
+
+#### Troubleshooting
+
+In case training fails, you might want to monitor the logs of the encrypted storage sidecar container to see if the encryption process completed successfully:
```bash
-# Obtain logs from the training container
-az container logs --name depa-training-covid --resource-group $AZURE_RESOURCE_GROUP --container-name depa-training
+az container logs --name depa-training-$SCENARIO --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+```
+
+And to further debug, inspect the logs of the encrypted filesystem sidecar container:
-# Obtain logs from the encrypted filesystem sidecar
-az container logs --name depa-training-covid --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+```bash
+az container exec \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --name depa-training-$SCENARIO \
+ --container-name encrypted-storage-sidecar \
+ --exec-command "/bin/sh"
```
-### Download and decrypt trained model
+Once inside the sidecar container shell, view the logs:
-You can download and decrypt the trained model using the following script.
+```bash
+cat log.txt
+```
+Or inspect the individual mounted directories in `/mnt/remote/`:
+
+```bash
+cd /mnt/remote && ls
+```
+
+### 7\. Download and Decrypt Model
+
+Once training has completed successfully (the training container logs will mention it explicitly), download and decrypt the trained model and other training outputs.
```bash
-cd scenarios/covid/data
./6-download-decrypt-model.sh
```
-The trained model is available in `output` folder.
+The outputs will be saved to the [output](./modeller/output/) directory.
+
+To check if the trained model is fresh, you can run the following command:
+
+```bash
+stat $REPO_ROOT/scenarios/$SCENARIO/modeller/output/trained_model.onnx
+```
+---
### Clean-up
-You can use the following command to delete the resource group and clean-up all resources used in the demo.
+
+
+You can use the following command to delete the resource group and clean-up all resources used in the demo. Alternatively, you can navigate to the Azure portal and delete the resource group created for this demo.
```bash
az group delete --yes --name $AZURE_RESOURCE_GROUP
-```
+```
\ No newline at end of file
diff --git a/scenarios/covid/ci/Dockerfile.cowin b/scenarios/covid/ci/Dockerfile.cowin
index b4a8c05..7ee8fc0 100644
--- a/scenarios/covid/ci/Dockerfile.cowin
+++ b/scenarios/covid/ci/Dockerfile.cowin
@@ -1,14 +1,14 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get upgrade && apt-get update \
&& apt-get install -y python3 python3-pip \
- && apt-get install -y openjdk-8-jdk
+ && apt-get install -y openjdk-17-jdk
RUN pip3 install pyspark pandas
-ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
+ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
RUN export JAVA_HOME
-COPY ccr_depa_covid_poc_dp_data_prep_cowin.py ccr_depa_covid_poc_dp_data_prep_cowin.py
+COPY preprocess_cowin.py preprocess_cowin.py
diff --git a/scenarios/covid/ci/Dockerfile.icmr b/scenarios/covid/ci/Dockerfile.icmr
index ea7f72b..b87e8ea 100644
--- a/scenarios/covid/ci/Dockerfile.icmr
+++ b/scenarios/covid/ci/Dockerfile.icmr
@@ -1,14 +1,14 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get upgrade && apt-get update \
&& apt-get install -y python3 python3-pip \
- && apt-get install -y openjdk-8-jdk
+ && apt-get install -y openjdk-17-jdk
RUN pip3 install pyspark pandas
-ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
+ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
RUN export JAVA_HOME
-COPY ccr_depa_covid_poc_dp_data_prep_icmr.py ccr_depa_covid_poc_dp_data_prep_icmr.py
+COPY preprocess_icmr.py preprocess_icmr.py
diff --git a/scenarios/covid/ci/Dockerfile.index b/scenarios/covid/ci/Dockerfile.index
index d99958b..9238136 100644
--- a/scenarios/covid/ci/Dockerfile.index
+++ b/scenarios/covid/ci/Dockerfile.index
@@ -1,14 +1,14 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get upgrade && apt-get update \
&& apt-get install -y python3 python3-pip \
- && apt-get install -y openjdk-8-jdk
+ && apt-get install -y openjdk-17-jdk
RUN pip3 install pyspark pandas
-ENV JAVA_HOME /usr/lib/jvm/java-8-openjdk-amd64/
+ENV JAVA_HOME /usr/lib/jvm/java-17-openjdk-amd64/
RUN export JAVA_HOME
-COPY ccr_depa_covid_poc_dp_data_prep_index.py ccr_depa_covid_poc_dp_data_prep_index.py
+COPY preprocess_index.py preprocess_index.py
diff --git a/scenarios/covid/ci/Dockerfile.modelsave b/scenarios/covid/ci/Dockerfile.modelsave
index dfc3be2..de8ba5c 100644
--- a/scenarios/covid/ci/Dockerfile.modelsave
+++ b/scenarios/covid/ci/Dockerfile.modelsave
@@ -1,16 +1,17 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get -y upgrade \
&& apt-get install -y gcc g++ curl \
- && apt-get install -y python3.9 python3.9-dev python3.9-distutils
+ && apt-get install -y python3 python3-dev python3-distutils
## Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-RUN python3.9 get-pip.py
+RUN python3 get-pip.py
## Install dependencies
-RUN pip3 --default-timeout=1000 install pandas torch onnx onnx2pytorch scikit-learn scipy matplotlib
+RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 --default-timeout=1000 install pandas onnx scikit-learn scipy matplotlib
-COPY ccr_dpsgd_model_saving_template_v2.py ccr_dpsgd_model_saving_template_v2.py
+COPY save_base_model.py save_base_model.py
diff --git a/scenarios/covid/ci/build.sh b/scenarios/covid/ci/build.sh
index eba2fb0..5290862 100755
--- a/scenarios/covid/ci/build.sh
+++ b/scenarios/covid/ci/build.sh
@@ -3,4 +3,4 @@
docker build -f ci/Dockerfile.icmr src -t preprocess-icmr:latest
docker build -f ci/Dockerfile.index src -t preprocess-index:latest
docker build -f ci/Dockerfile.cowin src -t preprocess-cowin:latest
-docker build -f ci/Dockerfile.modelsave src -t ccr-model-save:latest
+docker build -f ci/Dockerfile.modelsave src -t covid-model-save:latest
diff --git a/scenarios/covid/ci/pull-containers.sh b/scenarios/covid/ci/pull-containers.sh
new file mode 100755
index 0000000..ee90457
--- /dev/null
+++ b/scenarios/covid/ci/pull-containers.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+containers=("preprocess-icmr:latest" "preprocess-cowin:latest" "preprocess-index:latest" "covid-model-save:latest")
+for container in "${containers[@]}"
+do
+ docker pull $CONTAINER_REGISTRY"/"$container
+done
\ No newline at end of file
diff --git a/scenarios/covid/ci/push-containers.sh b/scenarios/covid/ci/push-containers.sh
index e0bf0d5..18631a8 100755
--- a/scenarios/covid/ci/push-containers.sh
+++ b/scenarios/covid/ci/push-containers.sh
@@ -1,4 +1,4 @@
-containers=("preprocess-icmr:latest" "preprocess-cowin:latest" "preprocess-index:latest" "ccr-model-save:latest")
+containers=("preprocess-icmr:latest" "preprocess-cowin:latest" "preprocess-index:latest" "covid-model-save:latest")
for container in "${containers[@]}"
do
docker tag $container $CONTAINER_REGISTRY"/"$container
diff --git a/scenarios/covid/config/consolidate_pipeline.sh b/scenarios/covid/config/consolidate_pipeline.sh
new file mode 100755
index 0000000..50b5a1d
--- /dev/null
+++ b/scenarios/covid/config/consolidate_pipeline.sh
@@ -0,0 +1,58 @@
+#! /bin/bash
+
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCENARIO=covid
+
+template_path="$REPO_ROOT/scenarios/$SCENARIO/config/templates"
+model_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json"
+data_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/dataset_config.json"
+loss_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/loss_config.json"
+train_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/train_config.json"
+eval_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/eval_config.json"
+join_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/join_config.json"
+pipeline_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/pipeline_config.json"
+
+# populate "model_config", "data_config", and "loss_config" keys in train config
+train_config=$(cat $template_path/train_config_template.json)
+
+# Only merge if the file exists
+if [[ -f "$model_config_path" ]]; then
+ model_config=$(cat $model_config_path)
+ train_config=$(echo "$train_config" | jq --argjson model "$model_config" '.config.model_config = $model')
+fi
+
+if [[ -f "$data_config_path" ]]; then
+ data_config=$(cat $data_config_path)
+ train_config=$(echo "$train_config" | jq --argjson data "$data_config" '.config.dataset_config = $data')
+fi
+
+if [[ -f "$loss_config_path" ]]; then
+ loss_config=$(cat $loss_config_path)
+ train_config=$(echo "$train_config" | jq --argjson loss "$loss_config" '.config.loss_config = $loss')
+fi
+
+if [[ -f "$eval_config_path" ]]; then
+ eval_config=$(cat $eval_config_path)
+ # Get all keys from eval_config and copy them to train_config
+ for key in $(echo "$eval_config" | jq -r 'keys[]'); do
+ train_config=$(echo "$train_config" | jq --argjson eval "$eval_config" --arg key "$key" '.config[$key] = $eval[$key]')
+ done
+fi
+
+# save train_config
+echo "$train_config" > $train_config_path
+
+# prepare pipeline config from join_config.json (first dict "config") and train_config.json (second dict "config")
+pipeline_config=$(cat $template_path/pipeline_config_template.json)
+
+# Only merge join_config if the file exists
+if [[ -f "$join_config_path" ]]; then
+ join_config=$(cat $join_config_path)
+ pipeline_config=$(echo "$pipeline_config" | jq --argjson join "$join_config" '.pipeline += [$join]')
+fi
+
+# Always merge train_config as it's required
+pipeline_config=$(echo "$pipeline_config" | jq --argjson train "$train_config" '.pipeline += [$train]')
+
+# save pipeline_config to pipeline_config.json
+echo "$pipeline_config" > $pipeline_config_path
\ No newline at end of file
diff --git a/scenarios/covid/config/dataset_config.json b/scenarios/covid/config/dataset_config.json
new file mode 100644
index 0000000..23c7ada
--- /dev/null
+++ b/scenarios/covid/config/dataset_config.json
@@ -0,0 +1,17 @@
+{
+ "type": "tabular",
+ "target_variable": "icmr_a_icmr_test_result",
+ "preprocessing": {
+ "scaler": "standard"
+ },
+ "missing_strategy": "fill",
+ "fill_value": 0,
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "tensor"
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/eval_config.json b/scenarios/covid/config/eval_config.json
new file mode 100644
index 0000000..c0c99cb
--- /dev/null
+++ b/scenarios/covid/config/eval_config.json
@@ -0,0 +1,24 @@
+{
+ "task_type": "classification",
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "binary"
+ }
+ },
+ {
+ "name": "precision_recall_curve",
+ "params": {}
+ },
+ {
+ "name": "roc_auc",
+ "params": {}
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/join_config.json b/scenarios/covid/config/join_config.json
new file mode 100644
index 0000000..37ecc79
--- /dev/null
+++ b/scenarios/covid/config/join_config.json
@@ -0,0 +1,166 @@
+{
+ "name": "SparkJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "icmr",
+ "name": "icmr_test_results",
+ "file": "dp_icmr_standardised_anon.csv",
+ "select_variables": [
+ "icmr_test_result",
+ "fk_genetic_strain",
+ "test_ct_value",
+ "sample_genetic_sequenced"
+ ],
+ "mount_path": "/mnt/remote/icmr/"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "cowin",
+ "name": "cowin_vaccine_data",
+ "file": "dp_cowin_standardised_anon.csv",
+ "select_variables": [
+ "age",
+ "vaccine_name"
+ ],
+ "mount_path": "/mnt/remote/cowin/"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "index",
+ "name": "index_patient_records",
+ "file": "dp_index_standardised_anon.csv",
+ "select_variables": [
+ "pasymp",
+ "dsinfection",
+ "potravel",
+ "page"
+ ],
+ "mount_path": "/mnt/remote/index/"
+ }
+ ],
+ "joined_dataset": {
+ "joined_dataset": "/tmp/covid_joined.csv",
+ "joining_query": "SELECT * FROM icmr_test_results AS icmr JOIN index_patient_records AS idx ON idx.pk_mobno_hashed = icmr.pk_mobno_hashed JOIN cowin_vaccine_data AS cowin ON idx.pk_mobno_hashed = cowin.pk_mobno_hashed",
+ "joining_key": "pk_mobno_hashed",
+ "drop_columns": [
+ "pk_icmrno",
+ "pk_mobno",
+ "ref_srfno",
+ "ref_index_id",
+ "ref_tr_index_id",
+ "fk_pname",
+ "ref_paddress",
+ "cowin_beneficiary_name"
+ ],
+ "identifiers": [
+ "pk_mobno_hashed",
+ "ref_srfno_hashed",
+ "pk_icmrno_hashed",
+ "index_idcpatient",
+ "index_pname",
+ "index_paddress",
+ "index_phcname",
+ "pk_icmrno_hashed",
+ "pk_mobno_hashed",
+ "ref_bucode_hashed",
+ "cowin_beneficiary_name",
+ "cowin_d1_vaccinated_at",
+ "cowin_d2_vaccinated_at",
+ "pk_beneficiary_id_hashed",
+ "ref_uhid_hashed",
+ "pk_mobno_hashed",
+ "ref_id_verified_hashed",
+ "ref_index_id_hashed",
+ "fk_pname_hashed",
+ "fk_cpatient_hashed",
+ "ref_tr_index_id_hashed",
+ "ref_pahospital_code_hashed",
+ "ref_tpahospital_code_hashed",
+ "ref_paddress_hashed",
+ "fk_icmr_labid_hashed",
+ "index_labcode",
+ "ref_labid",
+ "index_pgender",
+ "index_pstate",
+ "index_pdistrict",
+ "index_plocation",
+ "index_pzone",
+ "index_pward",
+ "index_ptaluka",
+ "index_astatus",
+ "index_anumber",
+ "index_adname",
+ "index_adnumber",
+ "index_stambulance",
+ "index_stdate",
+ "index_audate",
+ "index_cdate",
+ "index_pcdate",
+ "index_adddate",
+ "index_bmdate",
+ "index_moddate",
+ "index_admdate",
+ "index_movdate",
+ "index_disdate",
+ "index_pudate",
+ "index_cudate",
+ "index_distcode",
+ "index_distrefno",
+ "index_ptype",
+ "ref_labid_hashed",
+ "index_fdate",
+ "index_todate",
+ "index_apcnumber",
+ "index_pbtype",
+ "index_pbquota",
+ "icmr_pupdate",
+ "index_bucode",
+ "index_trbucode",
+ "index_commstatus",
+ "index_commby",
+ "index_dristatus",
+ "index_driby",
+ "index_vehnumber",
+ "index_hosstatus",
+ "index_hosby",
+ "index_padone",
+ "index_padate",
+ "index_pstatus",
+ "index_dsummary",
+ "index_statusuby",
+ "index_ureason",
+ "index_usummary",
+ "index_padmitted",
+ "index_hcode",
+ "index_pahospital",
+ "index_tpahospital",
+ "index_htype",
+ "index_pbedcode",
+ "index_labname",
+ "index_pid",
+ "cowin_dose_1_date",
+ "cowin_d1_vaccinated_by",
+ "cowin_dose_2_date",
+ "cowin_d2_vaccinated_by",
+ "cowin_pupdate",
+ "cowin_gender",
+ "index_remarks",
+ "icmr_a_icmr_test_type"
+ ],
+ "joined_result_columns": [
+ "icmr_a_icmr_test_result",
+ "icmr_a_test_ct_value",
+ "icmr_a_sample_genetic_sequenced",
+ "fk_genetic_strain",
+ "index_pasymp",
+ "index_dsinfection",
+ "index_potravel",
+ "index_page",
+ "cowin_age",
+ "cowin_vaccine_name"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/loss_config.json b/scenarios/covid/config/loss_config.json
new file mode 100644
index 0000000..ca60f35
--- /dev/null
+++ b/scenarios/covid/config/loss_config.json
@@ -0,0 +1,6 @@
+{
+ "class": "torch.nn.BCELoss",
+ "params": {
+ "reduction": "mean"
+ }
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/pipeline_config.json b/scenarios/covid/config/pipeline_config.json
index 6f798b4..ca17150 100644
--- a/scenarios/covid/config/pipeline_config.json
+++ b/scenarios/covid/config/pipeline_config.json
@@ -1,191 +1,242 @@
{
- "pipeline": [
- {
- "name": "Join",
- "config": {
- "datasets": [
- {
- "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
- "name": "icmr",
- "file": "dp_icmr_standardised_anon.csv",
- "select_variables": [
- "icmr_test_result",
- "fk_genetic_strain",
- "test_ct_value",
- "sample_genetic_sequenced"
- ],
- "mount_path": "/mnt/remote/icmr/"
- },
- {
- "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
- "name": "cowin",
- "file": "dp_cowin_standardised_anon.csv",
- "select_variables": [
- "age",
- "vaccine_name"
- ],
- "mount_path": "/mnt/remote/cowin/"
- },
- {
- "id": "2830a144-bab8-11ed-afa1-0242ac120002",
- "name": "index",
- "file": "dp_index_standardised_anon.csv",
- "select_variables": [
- "pasymp",
- "dsinfection",
- "potravel",
- "page"
- ],
- "mount_path": "/mnt/remote/index/"
- }
- ],
- "joined_dataset": {
- "joined_dataset": "sandbox_icmr_cowin_index_without_key_identifiers.csv",
- "joining_query": "select * from icmr, index, cowin where index.pk_mobno_hashed == icmr.pk_mobno_hashed and index.pk_mobno_hashed == cowin.pk_mobno_hashed",
- "joining_key": "pk_mobno_hashed",
- "model_output_folder": "/tmp/",
- "drop_columns": [
- "pk_icmrno",
- "pk_mobno",
- "ref_srfno",
- "ref_index_id",
- "ref_tr_index_id",
- "fk_pname",
- "ref_paddress",
- "cowin_beneficiary_name"
- ],
- "identifiers": [
- "pk_mobno_hashed",
- "ref_srfno_hashed",
- "pk_icmrno_hashed",
- "index_idcpatient",
- "index_pname",
- "index_paddress",
- "index_phcname",
- "pk_icmrno_hashed",
- "pk_mobno_hashed",
- "ref_bucode_hashed",
- "cowin_beneficiary_name",
- "cowin_d1_vaccinated_at",
- "cowin_d2_vaccinated_at",
- "pk_beneficiary_id_hashed",
- "ref_uhid_hashed",
- "pk_mobno_hashed",
- "ref_id_verified_hashed",
- "ref_index_id_hashed",
- "fk_pname_hashed",
- "fk_cpatient_hashed",
- "ref_tr_index_id_hashed",
- "ref_pahospital_code_hashed",
- "ref_tpahospital_code_hashed",
- "ref_paddress_hashed",
- "fk_icmr_labid_hashed",
- "index_labcode",
- "ref_labid",
- "index_pgender",
- "index_pstate",
- "index_pdistrict",
- "index_plocation",
- "index_pzone",
- "index_pward",
- "index_ptaluka",
- "index_astatus",
- "index_anumber",
- "index_adname",
- "index_adnumber",
- "index_stambulance",
- "index_stdate",
- "index_audate",
- "index_cdate",
- "index_pcdate",
- "index_adddate",
- "index_bmdate",
- "index_moddate",
- "index_admdate",
- "index_movdate",
- "index_disdate",
- "index_pudate",
- "index_cudate",
- "index_distcode",
- "index_distrefno",
- "index_ptype",
- "ref_labid_hashed",
- "index_fdate",
- "index_todate",
- "index_apcnumber",
- "index_pbtype",
- "index_pbquota",
- "icmr_pupdate",
- "index_bucode",
- "index_trbucode",
- "index_commstatus",
- "index_commby",
- "index_dristatus",
- "index_driby",
- "index_vehnumber",
- "index_hosstatus",
- "index_hosby",
- "index_padone",
- "index_padate",
- "index_pstatus",
- "index_dsummary",
- "index_statusuby",
- "index_ureason",
- "index_usummary",
- "index_padmitted",
- "index_hcode",
- "index_pahospital",
- "index_tpahospital",
- "index_htype",
- "index_pbedcode",
- "index_labname",
- "index_pid",
- "cowin_dose_1_date",
- "cowin_d1_vaccinated_by",
- "cowin_dose_2_date",
- "cowin_d2_vaccinated_by",
- "cowin_pupdate",
- "cowin_gender",
- "index_remarks",
- "icmr_a_icmr_test_type"
- ],
- "joined_result_columns": [
- "icmr_a_icmr_test_result",
- "icmr_a_test_ct_value",
- "icmr_a_sample_genetic_sequenced",
- "fk_genetic_strain",
- "index_pasymp",
- "index_dsinfection",
- "index_potravel",
- "index_page",
- "cowin_age",
- "cowin_vaccine_name"
- ]
- }
- }
+ "pipeline": [
+ {
+ "name": "SparkJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "icmr",
+ "name": "icmr_test_results",
+ "file": "dp_icmr_standardised_anon.csv",
+ "select_variables": [
+ "icmr_test_result",
+ "fk_genetic_strain",
+ "test_ct_value",
+ "sample_genetic_sequenced"
+ ],
+ "mount_path": "/mnt/remote/icmr/"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "cowin",
+ "name": "cowin_vaccine_data",
+ "file": "dp_cowin_standardised_anon.csv",
+ "select_variables": [
+ "age",
+ "vaccine_name"
+ ],
+ "mount_path": "/mnt/remote/cowin/"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "index",
+ "name": "index_patient_records",
+ "file": "dp_index_standardised_anon.csv",
+ "select_variables": [
+ "pasymp",
+ "dsinfection",
+ "potravel",
+ "page"
+ ],
+ "mount_path": "/mnt/remote/index/"
+ }
+ ],
+ "joined_dataset": {
+ "joined_dataset": "/tmp/covid_joined.csv",
+ "joining_query": "SELECT * FROM icmr_test_results AS icmr JOIN index_patient_records AS idx ON idx.pk_mobno_hashed = icmr.pk_mobno_hashed JOIN cowin_vaccine_data AS cowin ON idx.pk_mobno_hashed = cowin.pk_mobno_hashed",
+ "joining_key": "pk_mobno_hashed",
+ "drop_columns": [
+ "pk_icmrno",
+ "pk_mobno",
+ "ref_srfno",
+ "ref_index_id",
+ "ref_tr_index_id",
+ "fk_pname",
+ "ref_paddress",
+ "cowin_beneficiary_name"
+ ],
+ "identifiers": [
+ "pk_mobno_hashed",
+ "ref_srfno_hashed",
+ "pk_icmrno_hashed",
+ "index_idcpatient",
+ "index_pname",
+ "index_paddress",
+ "index_phcname",
+ "pk_icmrno_hashed",
+ "pk_mobno_hashed",
+ "ref_bucode_hashed",
+ "cowin_beneficiary_name",
+ "cowin_d1_vaccinated_at",
+ "cowin_d2_vaccinated_at",
+ "pk_beneficiary_id_hashed",
+ "ref_uhid_hashed",
+ "pk_mobno_hashed",
+ "ref_id_verified_hashed",
+ "ref_index_id_hashed",
+ "fk_pname_hashed",
+ "fk_cpatient_hashed",
+ "ref_tr_index_id_hashed",
+ "ref_pahospital_code_hashed",
+ "ref_tpahospital_code_hashed",
+ "ref_paddress_hashed",
+ "fk_icmr_labid_hashed",
+ "index_labcode",
+ "ref_labid",
+ "index_pgender",
+ "index_pstate",
+ "index_pdistrict",
+ "index_plocation",
+ "index_pzone",
+ "index_pward",
+ "index_ptaluka",
+ "index_astatus",
+ "index_anumber",
+ "index_adname",
+ "index_adnumber",
+ "index_stambulance",
+ "index_stdate",
+ "index_audate",
+ "index_cdate",
+ "index_pcdate",
+ "index_adddate",
+ "index_bmdate",
+ "index_moddate",
+ "index_admdate",
+ "index_movdate",
+ "index_disdate",
+ "index_pudate",
+ "index_cudate",
+ "index_distcode",
+ "index_distrefno",
+ "index_ptype",
+ "ref_labid_hashed",
+ "index_fdate",
+ "index_todate",
+ "index_apcnumber",
+ "index_pbtype",
+ "index_pbquota",
+ "icmr_pupdate",
+ "index_bucode",
+ "index_trbucode",
+ "index_commstatus",
+ "index_commby",
+ "index_dristatus",
+ "index_driby",
+ "index_vehnumber",
+ "index_hosstatus",
+ "index_hosby",
+ "index_padone",
+ "index_padate",
+ "index_pstatus",
+ "index_dsummary",
+ "index_statusuby",
+ "index_ureason",
+ "index_usummary",
+ "index_padmitted",
+ "index_hcode",
+ "index_pahospital",
+ "index_tpahospital",
+ "index_htype",
+ "index_pbedcode",
+ "index_labname",
+ "index_pid",
+ "cowin_dose_1_date",
+ "cowin_d1_vaccinated_by",
+ "cowin_dose_2_date",
+ "cowin_d2_vaccinated_by",
+ "cowin_pupdate",
+ "cowin_gender",
+ "index_remarks",
+ "icmr_a_icmr_test_type"
+ ],
+ "joined_result_columns": [
+ "icmr_a_icmr_test_result",
+ "icmr_a_test_ct_value",
+ "icmr_a_sample_genetic_sequenced",
+ "fk_genetic_strain",
+ "index_pasymp",
+ "index_dsinfection",
+ "index_potravel",
+ "index_page",
+ "cowin_age",
+ "cowin_vaccine_name"
+ ]
+ }
+ }
+ },
+ {
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/covid_joined.csv",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.0,
+ "delta": 0.01
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "SGD",
+ "params": {
+ "lr": 0.0001
+ }
},
- {
- "name": "PrivateTrain",
- "config": {
- "input_dataset_path": "/tmp/sandbox_icmr_cowin_index_without_key_identifiers.csv",
- "saved_model_path": "/mnt/remote/model/model.onnx",
- "saved_model_optimizer": "/mnt/remote/model/dpsgd_model_opimizer.pth",
- "trained_model_output_path": "/mnt/remote/output/model.onnx",
- "saved_weights_path": "",
- "batch_size": 2,
- "total_epochs": 5,
- "max_grad_norm": 0.1,
- "epsilon_threshold": 1.0,
- "delta": 0.01,
- "sample_size": 60000,
- "target_variable": "icmr_a_icmr_test_result",
- "test_train_split": 0.2,
- "metrics": [
- "accuracy",
- "precision",
- "recall"
- ]
+ "scheduler": null,
+ "total_epochs": 1,
+ "dataset_config": {
+ "type": "tabular",
+ "target_variable": "icmr_a_icmr_test_result",
+ "preprocessing": {
+ "scaler": "standard"
+ },
+ "missing_strategy": "fill",
+ "fill_value": 0,
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "tensor"
+ },
+ "loss_config": {
+ "class": "torch.nn.BCELoss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "binary"
}
- }
- ]
-}
\ No newline at end of file
+ },
+ {
+ "name": "precision_recall_curve",
+ "params": {}
+ },
+ {
+ "name": "roc_auc",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
+ }
+ }
+ ]
+}
diff --git a/scenarios/covid/config/sample_variants/model_config_safetensors.json b/scenarios/covid/config/sample_variants/model_config_safetensors.json
new file mode 100644
index 0000000..763d175
--- /dev/null
+++ b/scenarios/covid/config/sample_variants/model_config_safetensors.json
@@ -0,0 +1,71 @@
+{
+ "layers": {
+ "fc1": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 10,
+ "out_features": 128
+ }
+ },
+ "fc2": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 128,
+ "out_features": 64
+ }
+ },
+ "fc3": {
+ "class": "nn.Linear",
+ "params": {
+ "in_features": 64,
+ "out_features": 1
+ }
+ },
+ "dropout": {
+ "class": "nn.Dropout",
+ "params": {
+ "p": 0.3
+ }
+ }
+ },
+ "input": [
+ "x"
+ ],
+ "forward": [
+ {
+ "ops": [
+ "fc1",
+ "torch.relu",
+ "dropout"
+ ],
+ "input": [
+ "x"
+ ],
+ "output": "x1"
+ },
+ {
+ "ops": [
+ "fc2",
+ "torch.relu",
+ "dropout"
+ ],
+ "input": [
+ "x1"
+ ],
+ "output": "x2"
+ },
+ {
+ "ops": [
+ "fc3",
+ "torch.sigmoid"
+ ],
+ "input": [
+ "x2"
+ ],
+ "output": "x3"
+ }
+ ],
+ "output": [
+ "x3"
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/sample_variants/model_config_xgboost.json b/scenarios/covid/config/sample_variants/model_config_xgboost.json
new file mode 100644
index 0000000..6ac75b1
--- /dev/null
+++ b/scenarios/covid/config/sample_variants/model_config_xgboost.json
@@ -0,0 +1,8 @@
+{
+ "num_boost_round": 100,
+ "booster_params": {
+ "max_depth": 6,
+ "learning_rate": 0.1,
+ "objective": "binary:logistic"
+ }
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/templates/pipeline_config_template.json b/scenarios/covid/config/templates/pipeline_config_template.json
new file mode 100644
index 0000000..43e9e84
--- /dev/null
+++ b/scenarios/covid/config/templates/pipeline_config_template.json
@@ -0,0 +1,3 @@
+{
+ "pipeline": []
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/templates/train_config_template.json b/scenarios/covid/config/templates/train_config_template.json
new file mode 100644
index 0000000..08499c4
--- /dev/null
+++ b/scenarios/covid/config/templates/train_config_template.json
@@ -0,0 +1,27 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/covid_joined.csv",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.0,
+ "delta": 0.01
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "SGD",
+ "params": {
+ "lr": 1e-4
+ }
+ },
+ "scheduler": null,
+ "total_epochs": 1
+ }
+}
\ No newline at end of file
diff --git a/scenarios/covid/config/train_config.json b/scenarios/covid/config/train_config.json
new file mode 100644
index 0000000..06ac125
--- /dev/null
+++ b/scenarios/covid/config/train_config.json
@@ -0,0 +1,72 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/covid_joined.csv",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": true,
+ "privacy_params": {
+ "max_grad_norm": 0.1,
+ "epsilon": 1.0,
+ "delta": 0.01
+ },
+ "device": "cpu",
+ "batch_size": 8,
+ "optimizer": {
+ "name": "SGD",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": null,
+ "total_epochs": 1,
+ "dataset_config": {
+ "type": "tabular",
+ "target_variable": "icmr_a_icmr_test_result",
+ "preprocessing": {
+ "scaler": "standard"
+ },
+ "missing_strategy": "fill",
+ "fill_value": 0,
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "tensor"
+ },
+ "loss_config": {
+ "class": "torch.nn.BCELoss",
+ "params": {
+ "reduction": "mean"
+ }
+ },
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "binary"
+ }
+ },
+ {
+ "name": "precision_recall_curve",
+ "params": {}
+ },
+ {
+ "name": "roc_auc",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
+ }
+}
diff --git a/scenarios/covid/contract/contract.json b/scenarios/covid/contract/contract.json
index c0c5e67..85820c0 100644
--- a/scenarios/covid/contract/contract.json
+++ b/scenarios/covid/contract/contract.json
@@ -3,53 +3,53 @@
"schemaVersion": "0.1",
"startTime": "2023-03-14T00:00:00.000Z",
"expiryTime": "2024-03-14T00:00:00.000Z",
- "tdc" : "",
- "tdps" : [],
- "ccrp": "did:web:ccrprovider.github.io",
+ "tdc": "",
+ "tdps": [],
+ "ccrp": "did:web:$CCRP_USERNAME.github.io",
"datasets": [
{
- "id" : "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
"name": "icmr",
- "url" : "https://ccrcontainer.blob.core.windows.net/icmr/data.img",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_ICMR_CONTAINER_NAME/data.img",
"provider": "",
- "key" : {
+ "key": {
"type": "azure",
"properties": {
"kid": "ICMRFilesystemEncryptionKey",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"endpoint": ""
}
}
},
{
- "id" : "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
"name": "cowin",
- "url" : "https://ccrcontainer.blob.core.windows.net/cowin/data.img",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_COWIN_CONTAINER_NAME/data.img",
"provider": "",
- "key" : {
+ "key": {
"type": "azure",
"properties": {
"kid": "COWINFilesystemEncryptionKey",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"endpoint": ""
}
}
},
{
- "id" : "2830a144-bab8-11ed-afa1-0242ac120002",
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
"name": "index",
- "url" : "https://ccrcontainer.blob.core.windows.net/swr/data.img",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_INDEX_CONTAINER_NAME/data.img",
"provider": "",
- "key" : {
+ "key": {
"type": "azure",
"properties": {
"kid": "IndexFilesystemEncryptionKey",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"endpoint": ""
}
@@ -61,21 +61,21 @@
{
"privacy": [
{
- "dataset": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "dataset": "19517ba8-bab8-11ed-afa1-0242ac120002",
"epsilon_threshold": "1.5",
"noise_multiplier": "2.0",
"delta": "0.01",
"epochs_per_report": "2"
},
{
- "dataset": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "dataset": "216d5cc6-bab8-11ed-afa1-0242ac120002",
"epsilon_threshold": "1.5",
"noise_multiplier": "2.0",
"delta": "0.01",
"epochs_per_report": "2"
},
{
- "dataset": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "dataset": "2830a144-bab8-11ed-afa1-0242ac120002",
"epsilon_threshold": "1.5",
"noise_multiplier": "2.0",
"delta": "0.01",
@@ -85,9 +85,7 @@
}
],
"terms": {
- "payment" : {
- },
- "revocation": {
- }
+ "payment": {},
+ "revocation": {}
}
-}
+}
\ No newline at end of file
diff --git a/scenarios/covid/data/4-encrypt-data.sh b/scenarios/covid/data/4-encrypt-data.sh
deleted file mode 100755
index 6294a3f..0000000
--- a/scenarios/covid/data/4-encrypt-data.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-#!/bin/bash
-
-./generatefs.sh -d icmr/preprocessed -k icmrkey.bin -i icmr.img
-./generatefs.sh -d cowin/preprocessed -k cowinkey.bin -i cowin.img
-./generatefs.sh -d index/preprocessed -k indexkey.bin -i index.img
-./generatefs.sh -d modeller/model -k modelkey.bin -i model.img
-mkdir -p output
-./generatefs.sh -d output -k outputkey.bin -i output.img
\ No newline at end of file
diff --git a/scenarios/covid/data/output/ccr_depa_trg_model_logger.json b/scenarios/covid/data/output/ccr_depa_trg_model_logger.json
deleted file mode 100644
index 55e6e19..0000000
--- a/scenarios/covid/data/output/ccr_depa_trg_model_logger.json
+++ /dev/null
@@ -1,8 +0,0 @@
-Model Architecture
-GradSampleModule(ConvertModel(
- (Gemm_/fc1/Gemm_output_0): Linear(in_features=10, out_features=128, bias=True)
- (Relu_/Relu_output_0): ReLU(inplace=True)
- (Gemm_/fc2/Gemm_output_0): Linear(in_features=128, out_features=64, bias=True)
- (Relu_/Relu_1_output_0): ReLU(inplace=True)
- (Gemm_11): Linear(in_features=64, out_features=1, bias=True)
-))Epoch [{epoch+1}/{self.model_config["total_epochs"]}], Loss: {loss.item():.4f}Epoch [{epoch+1}/{self.model_config["total_epochs"]}], Loss: {loss.item():.4f}Epoch [{epoch+1}/{self.model_config["total_epochs"]}], Loss: {loss.item():.4f}Epoch [{epoch+1}/{self.model_config["total_epochs"]}], Loss: {loss.item():.4f}Epoch [{epoch+1}/{self.model_config["total_epochs"]}], Loss: {loss.item():.4f}
\ No newline at end of file
diff --git a/scenarios/covid/deployment/azure/0-create-acr.sh b/scenarios/covid/deployment/azure/0-create-acr.sh
new file mode 100755
index 0000000..4719bad
--- /dev/null
+++ b/scenarios/covid/deployment/azure/0-create-acr.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Only to be run when creating a new ACR
+
+# Ensure required env vars are set
+if [[ -z "$CONTAINER_REGISTRY" || -z "$AZURE_RESOURCE_GROUP" || -z "$AZURE_LOCATION" ]]; then
+ echo "ERROR: CONTAINER_REGISTRY, AZURE_RESOURCE_GROUP, and AZURE_LOCATION environment variables must be set."
+ exit 1
+fi
+
+echo "Checking if ACR '$CONTAINER_REGISTRY' exists in resource group '$AZURE_RESOURCE_GROUP'..."
+
+# Check if ACR exists
+ACR_EXISTS=$(az acr show --name "$CONTAINER_REGISTRY" --resource-group "$AZURE_RESOURCE_GROUP" --query "name" -o tsv 2>/dev/null)
+
+if [[ -n "$ACR_EXISTS" ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' already exists."
+else
+ echo "⏳ ACR '$CONTAINER_REGISTRY' does not exist. Creating..."
+
+ # Create ACR with premium SKU and admin enabled
+ az acr create \
+ --name "$CONTAINER_REGISTRY" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --location "$AZURE_LOCATION" \
+ --sku Premium \
+ --admin-enabled true \
+ --output table
+
+ # Enable anonymous pull
+ az acr update --name "$CONTAINER_REGISTRY" --anonymous-pull-enabled true
+
+ if [[ $? -eq 0 ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' created successfully."
+ else
+ echo "❌ Failed to create ACR."
+ exit 1
+ fi
+fi
+
+# Login to the ACR
+az acr login --name "$CONTAINER_REGISTRY"
\ No newline at end of file
diff --git a/scenarios/covid/data/1-create-storage-containers.sh b/scenarios/covid/deployment/azure/1-create-storage-containers.sh
similarity index 96%
rename from scenarios/covid/data/1-create-storage-containers.sh
rename to scenarios/covid/deployment/azure/1-create-storage-containers.sh
index e6e2a43..dbab27f 100755
--- a/scenarios/covid/data/1-create-storage-containers.sh
+++ b/scenarios/covid/deployment/azure/1-create-storage-containers.sh
@@ -11,14 +11,14 @@ else
echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
fi
-#echo "Check if storage account $STORAGE_ACCOUNT_NAME exists..."
+echo "Check if storage account $AZURE_STORAGE_ACCOUNT_NAME exists..."
STORAGE_ACCOUNT_EXISTS=$(az storage account check-name --name $AZURE_STORAGE_ACCOUNT_NAME --query "nameAvailable" --output tsv)
if [ "$STORAGE_ACCOUNT_EXISTS" == "true" ]; then
echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME does not exist. Creating it now..."
az storage account create --resource-group $AZURE_RESOURCE_GROUP --name $AZURE_STORAGE_ACCOUNT_NAME
else
- echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME exists"
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME already exists. Skipping creation."
fi
# Get the storage account key
diff --git a/scenarios/covid/data/2-create-akv.sh b/scenarios/covid/deployment/azure/2-create-akv.sh
similarity index 94%
rename from scenarios/covid/data/2-create-akv.sh
rename to scenarios/covid/deployment/azure/2-create-akv.sh
index 07cf2ce..545b08b 100755
--- a/scenarios/covid/data/2-create-akv.sh
+++ b/scenarios/covid/deployment/azure/2-create-akv.sh
@@ -1,8 +1,6 @@
#!/bin/bash
set -e
-
- echo CREATING $AZURE_KEYVAULT_ENDPOINT in resouce group $AZURE_RESOURCE_GROUP
if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
@@ -14,6 +12,7 @@ if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
--body "{\"name\": \"$AZURE_AKV_RESOURCE_NAME\", \"type\": \"Microsoft.KeyVault/vaults\"}" | jq -r '.nameAvailable')
if [ "$NAME_AVAILABLE" == true ]; then
echo "Key Vault $KEY_VAULT_NAME does not exist. Creating it now..."
+        echo CREATING $AZURE_KEYVAULT_ENDPOINT in resource group $AZURE_RESOURCE_GROUP
# Create Azure key vault with RBAC authorization
az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
# Assign RBAC roles to the resource owner so they can import keys
diff --git a/scenarios/covid/deployment/azure/3-import-keys.sh b/scenarios/covid/deployment/azure/3-import-keys.sh
new file mode 100755
index 0000000..c1bcae9
--- /dev/null
+++ b/scenarios/covid/deployment/azure/3-import-keys.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# Function to import a key with a given key ID and key material into AKV
+# The key is bound to a key release policy with host data defined in the environment variable CCE_POLICY_HASH
+function import_key() {
+ export KEYID=$1
+ export KEYFILE=$2
+
+ # For RSA-HSM keys, we need to set a salt and label which will be used in the symmetric key derivation
+ if [ "$AZURE_AKV_KEY_TYPE" = "RSA-HSM" ]; then
+ export AZURE_AKV_KEY_DERIVATION_LABEL=$KEYID
+ fi
+
+ CONFIG=$(jq '.claims[0][0].equals = env.CCE_POLICY_HASH' importkey-config-template.json)
+ CONFIG=$(echo $CONFIG | jq '.key.kid = env.KEYID')
+ CONFIG=$(echo $CONFIG | jq '.key.kty = env.AZURE_AKV_KEY_TYPE')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.label = env.AZURE_AKV_KEY_DERIVATION_LABEL')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.bearer_token = env.BEARER_TOKEN')
+ echo $CONFIG > /tmp/importkey-config.json
+ echo "Importing $KEYID key with key release policy"
+ jq '.key.akv.bearer_token = "REDACTED"' /tmp/importkey-config.json
+ pushd . && cd $TOOLS_HOME/importkey && go run main.go -c /tmp/importkey-config.json -out && popd
+ mv $TOOLS_HOME/importkey/keyfile.bin $KEYFILE
+}
+
+echo Obtaining contract service parameters...
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"http://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+# Obtain the token based on the AKV resource endpoint subdomain
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+ echo "Importing keys to AKV key vaults can be only of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+fi
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+import_key "ICMRFilesystemEncryptionKey" $DATADIR/icmr_key.bin
+import_key "COWINFilesystemEncryptionKey" $DATADIR/cowin_key.bin
+import_key "IndexFilesystemEncryptionKey" $DATADIR/index_key.bin
+import_key "ModelFilesystemEncryptionKey" $MODELDIR/model_key.bin
+import_key "OutputFilesystemEncryptionKey" $MODELDIR/output_key.bin
+
+## Cleanup
+rm /tmp/importkey-config.json
+rm /tmp/policy-in.json
diff --git a/scenarios/covid/deployment/azure/4-encrypt-data.sh b/scenarios/covid/deployment/azure/4-encrypt-data.sh
new file mode 100755
index 0000000..9e6da82
--- /dev/null
+++ b/scenarios/covid/deployment/azure/4-encrypt-data.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+./generatefs.sh -d $DATADIR/icmr/preprocessed -k $DATADIR/icmr_key.bin -i $DATADIR/icmr.img
+./generatefs.sh -d $DATADIR/cowin/preprocessed -k $DATADIR/cowin_key.bin -i $DATADIR/cowin.img
+./generatefs.sh -d $DATADIR/index/preprocessed -k $DATADIR/index_key.bin -i $DATADIR/index.img
+./generatefs.sh -d $MODELDIR/models -k $MODELDIR/model_key.bin -i $MODELDIR/model.img
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+./generatefs.sh -d $MODELDIR/output -k $MODELDIR/output_key.bin -i $MODELDIR/output.img
\ No newline at end of file
diff --git a/scenarios/covid/data/5-upload-encrypted-data.sh b/scenarios/covid/deployment/azure/5-upload-encrypted-data.sh
similarity index 82%
rename from scenarios/covid/data/5-upload-encrypted-data.sh
rename to scenarios/covid/deployment/azure/5-upload-encrypted-data.sh
index 94abf2f..8f5bf67 100755
--- a/scenarios/covid/data/5-upload-encrypted-data.sh
+++ b/scenarios/covid/deployment/azure/5-upload-encrypted-data.sh
@@ -1,11 +1,14 @@
#!/bin/bash
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_ICMR_CONTAINER_NAME \
- --file icmr.img \
+ --file $DATADIR/icmr.img \
--name data.img \
--type page \
--overwrite \
@@ -14,7 +17,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_COWIN_CONTAINER_NAME \
- --file cowin.img \
+ --file $DATADIR/cowin.img \
--name data.img \
--type page \
--overwrite \
@@ -23,7 +26,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_INDEX_CONTAINER_NAME \
- --file index.img \
+ --file $DATADIR/index.img \
--name data.img \
--type page \
--overwrite \
@@ -32,7 +35,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_MODEL_CONTAINER_NAME \
- --file model.img \
+ --file $MODELDIR/model.img \
--name data.img \
--type page \
--overwrite \
@@ -41,7 +44,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_OUTPUT_CONTAINER_NAME \
- --file output.img \
+ --file $MODELDIR/output.img \
--name data.img \
--type page \
--overwrite \
diff --git a/scenarios/covid/deployment/azure/6-download-decrypt-model.sh b/scenarios/covid/deployment/azure/6-download-decrypt-model.sh
new file mode 100755
index 0000000..b6d043a
--- /dev/null
+++ b/scenarios/covid/deployment/azure/6-download-decrypt-model.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob download \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODELDIR/output.img \
+ --name data.img \
+ --account-key $ACCOUNT_KEY
+
+encryptedImage=$MODELDIR/output.img
+keyFilePath=$MODELDIR/output_key.bin
+
+echo Decrypting $encryptedImage with key $keyFilePath
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+mountPoint=`mktemp -d`
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+cp -r $mountPoint/* $MODELDIR/output/
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
\ No newline at end of file
diff --git a/scenarios/covid/deployment/azure/aci-parameters-template.json b/scenarios/covid/deployment/azure/aci-parameters-template.json
new file mode 100644
index 0000000..8eb11fc
--- /dev/null
+++ b/scenarios/covid/deployment/azure/aci-parameters-template.json
@@ -0,0 +1,23 @@
+{
+ "containerRegistry": {
+ "value": ""
+ },
+ "ccePolicy": {
+ "value": ""
+ },
+ "EncfsSideCarArgs": {
+ "value": ""
+ },
+ "ContractService": {
+ "value": ""
+ },
+ "ContractServiceParameters": {
+ "value": ""
+ },
+ "Contracts": {
+ "value": ""
+ },
+ "PipelineConfiguration": {
+ "value": ""
+ }
+}
\ No newline at end of file
diff --git a/scenarios/covid/deployment/aci/arm-template.json b/scenarios/covid/deployment/azure/arm-template.json
similarity index 97%
rename from scenarios/covid/deployment/aci/arm-template.json
rename to scenarios/covid/deployment/azure/arm-template.json
index caa3a95..c87c88c 100644
--- a/scenarios/covid/deployment/aci/arm-template.json
+++ b/scenarios/covid/deployment/azure/arm-template.json
@@ -10,7 +10,7 @@
}
},
"location": {
- "defaultValue": "[resourceGroup().location]",
+ "defaultValue": "northeurope",
"type": "string",
"metadata": {
"description": "Location for all resources."
@@ -154,9 +154,9 @@
"mountPath": "/mnt/remote"
}
],
- "securityContext": {
- "privileged": "true"
- },
+ "securityContext": {
+ "privileged": "true"
+ },
"resources": {
"requests": {
"cpu": 0.5,
@@ -178,4 +178,4 @@
}
}
]
-}
+}
\ No newline at end of file
diff --git a/scenarios/covid/deployment/aci/deploy.sh b/scenarios/covid/deployment/azure/deploy.sh
similarity index 100%
rename from scenarios/covid/deployment/aci/deploy.sh
rename to scenarios/covid/deployment/azure/deploy.sh
diff --git a/scenarios/covid/deployment/aci/encrypted-filesystem-config-template.json b/scenarios/covid/deployment/azure/encrypted-filesystem-config-template.json
similarity index 100%
rename from scenarios/covid/deployment/aci/encrypted-filesystem-config-template.json
rename to scenarios/covid/deployment/azure/encrypted-filesystem-config-template.json
diff --git a/scenarios/covid/deployment/azure/generatefs.sh b/scenarios/covid/deployment/azure/generatefs.sh
new file mode 100755
index 0000000..df8833e
--- /dev/null
+++ b/scenarios/covid/deployment/azure/generatefs.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+while getopts ":d:k:i:" options; do
+ case $options in
+ d)dataPath=$OPTARG;;
+ k)keyFilePath=$OPTARG;;
+ i)encryptedImage=$OPTARG;;
+ esac
+done
+
+echo Encrypting $dataPath with key $keyFilePath and generating $encryptedImage
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+if [ -f "$keyFilePath" ]; then
+ echo "[!] Encrypting dataset using $keyFilePath"
+else
+ echo "[!] Generating keyfile..."
+ dd if=/dev/random of="$keyFilePath" count=1 bs=32
+ truncate -s 32 "$keyFilePath"
+fi
+
+echo "[!] Creating encrypted image..."
+
+response=`du -s $dataPath`
+read -ra arr <<< "$response"
+size=`echo "x=l($arr)/l(2); scale=0; 2^((x+0.5)/1)*2" | bc -l;`
+
+# cryptsetup requires 16M or more
+
+if (($((size)) < 65536)); then
+ size="65536"
+fi
+size=$size"K"
+
+echo "Data size: $size"
+
+rm -f "$encryptedImage"
+touch "$encryptedImage"
+truncate --size $size "$encryptedImage"
+
+sudo cryptsetup luksFormat --type luks2 "$encryptedImage" \
+ --key-file "$keyFilePath" -v --batch-mode --sector-size 4096 \
+ --cipher aes-xts-plain64 \
+ --pbkdf pbkdf2 --pbkdf-force-iterations 1000
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+echo "[!] Formatting as ext4..."
+
+sudo mkfs.ext4 "$deviceNamePath"
+
+echo "[!] Mounting..."
+
+mountPoint=`mktemp -d`
+echo "Mounting to $mountPoint"
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+echo "[!] Copying contents to encrypted device..."
+
+# The /* is needed to copy folder contents instead of the folder + contents
+sudo cp -r $dataPath/* "$mountPoint"
+sudo rm -rf "$mountPoint/lost+found"
+ls "$mountPoint"
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
diff --git a/scenarios/covid/deployment/azure/importkey-config-template.json b/scenarios/covid/deployment/azure/importkey-config-template.json
new file mode 100644
index 0000000..42ed7ee
--- /dev/null
+++ b/scenarios/covid/deployment/azure/importkey-config-template.json
@@ -0,0 +1,29 @@
+{
+ "key":{
+ "kid": "",
+ "kty": "oct-HSM",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ },
+ "claims": [
+ [{
+ "claim": "x-ms-sevsnpvm-hostdata",
+ "equals": ""
+ },
+ {
+ "claim": "x-ms-compliance-status",
+ "equals": "azure-compliant-uvm"
+ }]
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/covid/deployment/docker/docker-compose-modelsave.yml b/scenarios/covid/deployment/docker/docker-compose-modelsave.yml
deleted file mode 100644
index 872a640..0000000
--- a/scenarios/covid/deployment/docker/docker-compose-modelsave.yml
+++ /dev/null
@@ -1,6 +0,0 @@
-services:
- model_save:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}ccr-model-save:latest
- volumes:
- - $MODEL_OUTPUT_PATH:/mnt/model
- command: ["python3.9", "ccr_dpsgd_model_saving_template_v2.py"]
diff --git a/scenarios/covid/deployment/docker/docker-compose-preprocess.yml b/scenarios/covid/deployment/docker/docker-compose-preprocess.yml
deleted file mode 100644
index af7a100..0000000
--- a/scenarios/covid/deployment/docker/docker-compose-preprocess.yml
+++ /dev/null
@@ -1,19 +0,0 @@
-services:
- icmr:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-icmr:latest
- volumes:
- - $ICMR_INPUT_PATH:/mnt/depa_ccr_poc/data
- - $ICMR_OUTPUT_PATH:/mnt/output/icmr
- command: ["python3", "ccr_depa_covid_poc_dp_data_prep_icmr.py"]
- index:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-index:latest
- volumes:
- - $INDEX_INPUT_PATH:/mnt/depa_ccr_poc/data
- - $INDEX_OUTPUT_PATH:/mnt/output/index
- command: ["python3", "ccr_depa_covid_poc_dp_data_prep_index.py"]
- cowin:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-cowin:latest
- volumes:
- - $COWIN_INPUT_PATH:/mnt/depa_ccr_poc/data
- - $COWIN_OUTPUT_PATH:/mnt/output/cowin
- command: ["python3", "ccr_depa_covid_poc_dp_data_prep_cowin.py"]
diff --git a/scenarios/covid/deployment/docker/save-model.sh b/scenarios/covid/deployment/docker/save-model.sh
deleted file mode 100755
index 4f0d788..0000000
--- a/scenarios/covid/deployment/docker/save-model.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-export MODEL_OUTPUT_PATH=$PWD/../../data/modeller/model
-mkdir -p $MODEL_OUTPUT_PATH
-docker compose -f docker-compose-modelsave.yml up --remove-orphans
diff --git a/scenarios/covid/deployment/docker/train.sh b/scenarios/covid/deployment/docker/train.sh
deleted file mode 100755
index 37ee8eb..0000000
--- a/scenarios/covid/deployment/docker/train.sh
+++ /dev/null
@@ -1,10 +0,0 @@
-export DATA_DIR=$PWD/../../data
-export ICMR_INPUT_PATH=$DATA_DIR/icmr/preprocessed
-export INDEX_INPUT_PATH=$DATA_DIR/index/preprocessed
-export COWIN_INPUT_PATH=$DATA_DIR/cowin/preprocessed
-export MODEL_INPUT_PATH=$DATA_DIR/modeller/model
-export MODEL_OUTPUT_PATH=/tmp/output
-mkdir -p $MODEL_OUTPUT_PATH
-export CONFIGURATION_PATH=/tmp
-cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
-docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/covid/deployment/local/docker-compose-modelsave.yml b/scenarios/covid/deployment/local/docker-compose-modelsave.yml
new file mode 100644
index 0000000..64f2061
--- /dev/null
+++ b/scenarios/covid/deployment/local/docker-compose-modelsave.yml
@@ -0,0 +1,6 @@
+services:
+ model_save:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}covid-model-save:latest
+ volumes:
+ - $MODEL_OUTPUT_PATH:/mnt/model
+ command: ["python3", "save_base_model.py"]
diff --git a/scenarios/covid/deployment/local/docker-compose-preprocess.yml b/scenarios/covid/deployment/local/docker-compose-preprocess.yml
new file mode 100644
index 0000000..4ea049a
--- /dev/null
+++ b/scenarios/covid/deployment/local/docker-compose-preprocess.yml
@@ -0,0 +1,19 @@
+services:
+ icmr:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-icmr:latest
+ volumes:
+ - $ICMR_INPUT_PATH:/mnt/input/data
+ - $ICMR_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_icmr.py"]
+ index:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-index:latest
+ volumes:
+ - $INDEX_INPUT_PATH:/mnt/input/data
+ - $INDEX_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_index.py"]
+ cowin:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-cowin:latest
+ volumes:
+ - $COWIN_INPUT_PATH:/mnt/input/data
+ - $COWIN_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_cowin.py"]
diff --git a/scenarios/covid/deployment/docker/docker-compose-train.yml b/scenarios/covid/deployment/local/docker-compose-train.yml
similarity index 90%
rename from scenarios/covid/deployment/docker/docker-compose-train.yml
rename to scenarios/covid/deployment/local/docker-compose-train.yml
index 6b7288c..91b3f72 100644
--- a/scenarios/covid/deployment/docker/docker-compose-train.yml
+++ b/scenarios/covid/deployment/local/docker-compose-train.yml
@@ -1,6 +1,8 @@
services:
train:
image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-training:latest
+ environment:
+ - PYTHONUNBUFFERED=1
volumes:
- $ICMR_INPUT_PATH:/mnt/remote/icmr
- $COWIN_INPUT_PATH:/mnt/remote/cowin
diff --git a/scenarios/covid/deployment/docker/preprocess.sh b/scenarios/covid/deployment/local/preprocess.sh
similarity index 71%
rename from scenarios/covid/deployment/docker/preprocess.sh
rename to scenarios/covid/deployment/local/preprocess.sh
index 46c4f4e..b4401e0 100755
--- a/scenarios/covid/deployment/docker/preprocess.sh
+++ b/scenarios/covid/deployment/local/preprocess.sh
@@ -1,4 +1,8 @@
-export DATA_DIR=$PWD/../../data
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="covid"
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
export ICMR_INPUT_PATH=$DATA_DIR/icmr
export ICMR_OUTPUT_PATH=$DATA_DIR/icmr/preprocessed
export INDEX_INPUT_PATH=$DATA_DIR/index
diff --git a/scenarios/covid/deployment/local/save-model.sh b/scenarios/covid/deployment/local/save-model.sh
new file mode 100755
index 0000000..49edf0a
--- /dev/null
+++ b/scenarios/covid/deployment/local/save-model.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="covid"
+export MODEL_OUTPUT_PATH=$REPO_ROOT/scenarios/$SCENARIO/modeller/models
+mkdir -p $MODEL_OUTPUT_PATH
+docker compose -f docker-compose-modelsave.yml up --remove-orphans
\ No newline at end of file
diff --git a/scenarios/covid/deployment/local/train.sh b/scenarios/covid/deployment/local/train.sh
new file mode 100755
index 0000000..e5c9992
--- /dev/null
+++ b/scenarios/covid/deployment/local/train.sh
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="covid"
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+export ICMR_INPUT_PATH=$DATA_DIR/icmr/preprocessed
+export INDEX_INPUT_PATH=$DATA_DIR/index/preprocessed
+export COWIN_INPUT_PATH=$DATA_DIR/cowin/preprocessed
+
+export MODEL_INPUT_PATH=$MODEL_DIR/models
+
+# export MODEL_OUTPUT_PATH=/tmp/output
+export MODEL_OUTPUT_PATH=$MODEL_DIR/output
+sudo rm -rf $MODEL_OUTPUT_PATH
+mkdir -p $MODEL_OUTPUT_PATH
+
+export CONFIGURATION_PATH=$REPO_ROOT/scenarios/$SCENARIO/config
+# export CONFIGURATION_PATH=/tmp
+# cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
+
+# Run consolidate_pipeline.sh to create pipeline_config.json
+$REPO_ROOT/scenarios/$SCENARIO/config/consolidate_pipeline.sh
+
+docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/covid/export-variables.sh b/scenarios/covid/export-variables.sh
new file mode 100755
index 0000000..da94633
--- /dev/null
+++ b/scenarios/covid/export-variables.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# Azure Naming Rules:
+#
+# Resource Group:
+# - 1-90 characters
+# - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+# - Cannot end with a period (.)
+# - Case-insensitive, unique within subscription
+#
+# Key Vault:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with letter or number
+#
+# Storage Account:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters and numbers only
+#
+# Storage Container:
+# - 3-63 characters
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with a letter or number
+# - No consecutive hyphens
+# - Unique within storage account
+
+# For cloud resource creation:
+declare -x SCENARIO=covid
+declare -x REPO_ROOT="$(git rev-parse --show-toplevel)"
+declare -x CONTAINER_REGISTRY=ispirt.azurecr.io
+declare -x AZURE_LOCATION=centralindia
+declare -x AZURE_SUBSCRIPTION_ID=
+declare -x AZURE_RESOURCE_GROUP=
+declare -x AZURE_KEYVAULT_ENDPOINT=
+declare -x AZURE_STORAGE_ACCOUNT_NAME=
+
+declare -x AZURE_ICMR_CONTAINER_NAME=icmrcontainer
+declare -x AZURE_COWIN_CONTAINER_NAME=cowincontainer
+declare -x AZURE_INDEX_CONTAINER_NAME=indexcontainer
+declare -x AZURE_MODEL_CONTAINER_NAME=modelcontainer
+declare -x AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+
+# For key import:
+declare -x CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+declare -x TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+# Export all variables to make them available to other scripts
+export SCENARIO
+export REPO_ROOT
+export CONTAINER_REGISTRY
+export AZURE_LOCATION
+export AZURE_SUBSCRIPTION_ID
+export AZURE_RESOURCE_GROUP
+export AZURE_KEYVAULT_ENDPOINT
+export AZURE_STORAGE_ACCOUNT_NAME
+export AZURE_ICMR_CONTAINER_NAME
+export AZURE_COWIN_CONTAINER_NAME
+export AZURE_INDEX_CONTAINER_NAME
+export AZURE_MODEL_CONTAINER_NAME
+export AZURE_OUTPUT_CONTAINER_NAME
+export CONTRACT_SERVICE_URL
+export TOOLS_HOME
\ No newline at end of file
diff --git a/scenarios/covid/src/ccr_dpsgd_model_saving_template_v2.py b/scenarios/covid/src/ccr_dpsgd_model_saving_template_v2.py
deleted file mode 100644
index 58ebc74..0000000
--- a/scenarios/covid/src/ccr_dpsgd_model_saving_template_v2.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import onnx
-import numpy as np
-import pandas as pd
-import torch
-import torch.nn as nn
-import torch.optim as optim
-from torch.utils.data import Dataset, DataLoader
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import StandardScaler
-
-# Step 1: Load the CSV data
-#data = pd.read_csv('/tmp/sandbox_icmr_cowin_index_without_key_identifiers.csv')
-data = pd.DataFrame(np.random.randint(0,100,size=(2119, 11)), columns=['A','B','C','D','E','F','G','H','I','J','K'])
-
-features = data.drop(columns=["K"])
-target = data["K"]
-
-# Step 2: Preprocess the data
-# Assuming your target column is named 'target'
-#features = data.drop(columns=['icmr_a_icmr_test_result'])
-#target = data['icmr_a_icmr_test_result']
-
-# Split the data into training and validation sets
-train_features, val_features, train_target, val_target = train_test_split(features, target, test_size=0.2, random_state=42)
-
-# Standardize the features
-scaler = StandardScaler()
-train_features = scaler.fit_transform(train_features)
-val_features = scaler.transform(val_features)
-
-# Step 3: Create a PyTorch dataset and data loader
-class CustomDataset(Dataset):
- def __init__(self, features, target):
- self.features = torch.tensor(features, dtype=torch.float32)
- self.target = torch.tensor(target.values, dtype=torch.float32)
-
- def __len__(self):
- return len(self.features)
-
- def __getitem__(self, idx):
- return self.features[idx], self.target[idx]
-
-train_dataset = CustomDataset(train_features, train_target)
-val_dataset = CustomDataset(val_features, val_target)
-
-batch_size = 32
-train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
-val_loader = DataLoader(val_dataset, batch_size=batch_size)
-
-# Step 4: Define a simple neural network model
-class SimpleModel(nn.Module):
- def __init__(self, input_dim):
- super(SimpleModel, self).__init__()
- self.fc1 = nn.Linear(input_dim, 128)
- self.fc2 = nn.Linear(128, 64)
- self.fc3 = nn.Linear(64, 1)
-
- def forward(self, x):
- x = torch.relu(self.fc1(x))
- x = torch.relu(self.fc2(x))
- x = self.fc3(x)
- return x
-
-# Step 5: Choose a loss function and optimizer
-model = SimpleModel(input_dim=train_features.shape[1])
-criterion = nn.MSELoss()
-optimizer = optim.Adam(model.parameters(), lr=0.001)
-
-# Step 6: Train the model
-num_epochs = 10
-
-for epoch in range(num_epochs):
- model.train()
- for inputs, targets in train_loader:
- optimizer.zero_grad()
- outputs = model(inputs)
- loss = criterion(outputs, targets)
- loss.backward()
- optimizer.step()
-
- model.eval()
- val_loss = 0.0
- with torch.no_grad():
- for inputs, targets in val_loader:
- outputs = model(inputs)
- val_loss += criterion(outputs, targets).item()
-
- #val_loss /= len(val_loader
- #print(f'Epoch [{epoch+1}/{num_epochs}] - Validation Loss: {val_loss:.4f}')
-
-print('Training finished.')
-
-torch.onnx.export(model, torch.randn(1, train_features.shape[1]), "/mnt/model/model.onnx", verbose=True)
-print('Model saved as ONNX.')
-
-#model.save('/mnt/model/dpsgd_model')
-
-#model.save_weights('/mnt/model/model_weights')
-
-#model.save('/mnt/model/dpsgd_model.h5')
diff --git a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_cowin.py b/scenarios/covid/src/preprocess_cowin.py
similarity index 66%
rename from scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_cowin.py
rename to scenarios/covid/src/preprocess_cowin.py
index cba6b80..dc0370b 100644
--- a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_cowin.py
+++ b/scenarios/covid/src/preprocess_cowin.py
@@ -1,32 +1,31 @@
-# 2023, The DEPA CCR DP Training Reference Implementation
-# authors shyam@ispirt.in, sridhar.avs@ispirt.in
+# 2025 DEPA Foundation
#
-# Licensed TBD
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Key references / Attributions: https://depa.world/training/contracts
-# Key frameworks used : pyspark
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
-#Crirical Library Imports
from pyspark.sql import SparkSession
-from pyspark.sql.functions import *
-from pyspark.sql.types import *
-from pyspark.sql.functions import sha2, concat_ws # Hashing Related functions
-from pyspark.sql.functions import col , column
+from pyspark.sql.functions import sha2, concat_ws
"""##**Key Configuration Variables**"""
# Debug Enabled
-debug_poc=True
+debug_poc=False
# Model Input and output folders
-model_input_folder='/mnt/depa_ccr_poc/data/'
+model_input_folder='/mnt/input/data/'
cowin_file='poc_data_cowin_data.csv'
@@ -34,7 +33,7 @@
load_process_dp_data=True
# Data Provider Level - Preprocess Locations
-dp_cowin_output_folder='/mnt/output/cowin/'
+dp_cowin_output_folder='/mnt/output/preprocessed/'
# DP Standardisation Non Anon Files
dp_cowin_std_nonanon_file ='dp_cowin_standardised_nonanon.csv'
@@ -42,12 +41,11 @@
# DP Standardisation Anon | Tokenised Files
dp_cowin_std_anon_file ='dp_cowin_standardised_anon.csv'
-# In the CCR
-model_output_folder='/mnt/depa_ccr_poc/output/'
"""# Setting up spark session"""
spark = SparkSession.builder.appName('CCR_DEPA_COVID_POC_Code').getOrCreate()
+spark.sparkContext.setLogLevel("ERROR")
"""# Common Utility Functions
@@ -83,15 +81,16 @@ def dp_load_data(input_folder,data_file,load=True,debug=True):
"""
if load:
- input_file=input_folder+data_file
- if debug:
- print("Debug | input_file",input_file)
- data_loaded= spark.read.csv(
- input_file,
- header=True,
- inferSchema=True,
- mode="DROPMALFORMED"
- )
+ input_file=input_folder+data_file
+ if debug:
+ print("Debug | input_file",input_file)
+ # Prefer schema-on-read via CSV or Parquet; keep CSV for compatibility
+ data_loaded = spark.read.csv(
+ input_file,
+ header=True,
+ inferSchema=True,
+ mode="DROPMALFORMED"
+ )
if debug:
print("Debug | input_file",data_loaded.count())
data_loaded.show()
@@ -103,21 +102,21 @@ def dp_process_cowin_full(input_folder,data_file,load=True,debug=True):
Suggested to have some standardisations defined by the DEP/ Data Consumer (DC) for easier joining process
"""
if load:
- input_file=input_folder+data_file
+ input_file=input_folder+data_file
if debug:
- print("Debug | input_file",input_file)
+ print("Debug | input_file",input_file)
- data_loaded= spark.read.csv(
- input_file,
- header=True,
- inferSchema=True,
- mode="DROPMALFORMED"
- )
+ data_loaded = spark.read.csv(
+ input_file,
+ header=True,
+ inferSchema=True,
+ mode="DROPMALFORMED"
+ )
if debug:
- print("Debug | input_file",data_loaded.count())
- data_loaded.show()
+ print("Debug | input_file",data_loaded.count())
+ data_loaded.show()
# Standardisations
@@ -134,7 +133,7 @@ def dp_process_cowin_full(input_folder,data_file,load=True,debug=True):
if i not in do_not_change:
sandbox_dp_cowin = sandbox_dp_cowin.withColumnRenamed(i,'cowin_'+i)
- sandbox_dp_cowin.toPandas().to_csv(dp_cowin_output_folder+dp_cowin_std_nonanon_file)
+ sandbox_dp_cowin.toPandas().to_csv(dp_cowin_output_folder + dp_cowin_std_nonanon_file, index=False)
# Anonymisation of key identifiers
sandbox_dp_cowin_anon = sandbox_dp_cowin.withColumn('pk_beneficiary_id_hashed', sha2(concat_ws("", sandbox_dp_cowin.pk_beneficiary_id),256)) \
@@ -143,7 +142,7 @@ def dp_process_cowin_full(input_folder,data_file,load=True,debug=True):
.withColumn("ref_id_verified_hashed", sha2(concat_ws("", sandbox_dp_cowin.ref_id_verified),256)) \
.drop("pk_mobno").drop("pk_beneficiary_id").drop("ref_uhid").drop("ref_id_verified").cache()
- sandbox_dp_cowin_anon.toPandas().to_csv(dp_cowin_output_folder + dp_cowin_std_anon_file)
+ sandbox_dp_cowin_anon.toPandas().to_csv(dp_cowin_output_folder + dp_cowin_std_anon_file, index=False)
if debug_poc:
print("Debug | Dataset Created ", "sandbox_cowin_processed_anon")
diff --git a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_icmr.py b/scenarios/covid/src/preprocess_icmr.py
similarity index 76%
rename from scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_icmr.py
rename to scenarios/covid/src/preprocess_icmr.py
index 88f4e37..4598e74 100644
--- a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_icmr.py
+++ b/scenarios/covid/src/preprocess_icmr.py
@@ -1,35 +1,36 @@
-# 2023, The DEPA CCR DP Training Reference Implementation
-# authors shyam@ispirt.in, sridhar.avs@ispirt.in
+# 2025 DEPA Foundation
#
-# Licensed TBD
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Key references / Attributions: https://depa.world/training/contracts
-# Key frameworks used : pyspark
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
from pyspark.sql import SparkSession
-from pyspark.sql.functions import *
-from pyspark.sql.types import *
-from pyspark.sql.functions import sha2, concat_ws # Hashing Related functions
-from pyspark.sql.functions import col , column
+from pyspark.sql.functions import sha2, concat_ws
##**Key Configuration Variables**"""
# Debug Enabled
-debug_poc=True
+debug_poc=False
# Model Input folder
-icmr_input_folder='/mnt/depa_ccr_poc/data/'
+icmr_input_folder='/mnt/input/data/'
# POC State-wise dummy data
icmr_file='poc_data_icmr_data.csv'
# Data Provider Level - Preprocess Locations
-dp_icmr_output_folder='/mnt/output/icmr/'
+dp_icmr_output_folder='/mnt/output/preprocessed/'
# DP Standardisation Non Anon Files
dp_icmr_std_nonanon_file ='dp_icmr_standardised_nonanon.csv'
@@ -40,6 +41,7 @@
"""# Setting up spark session"""
spark = SparkSession.builder.appName('CCR_DEPA_COVID_POC_Code').getOrCreate()
+spark.sparkContext.setLogLevel("ERROR")
"""# Common Utility Functions
@@ -78,11 +80,11 @@ def dp_load_data(input_folder,data_file,load=True,debug=True):
if debug:
print("Debug | input_file",input_file)
data_loaded= spark.read.csv(
- input_file,
- header=True,
- inferSchema=True,
- mode="DROPMALFORMED"
- )
+ input_file,
+ header=True,
+ inferSchema=True,
+ mode="DROPMALFORMED"
+ )
if debug:
print("Debug | input_file",data_loaded.count())
data_loaded.show()
@@ -107,7 +109,7 @@ def dp_process_icmr_full(input_folder,data_file,load=True,debug=True):
header=True,
inferSchema=True,
mode="DROPMALFORMED"
- )
+ )
if debug:
print("Debug | input_file",data_loaded.count())
@@ -133,7 +135,7 @@ def dp_process_icmr_full(input_folder,data_file,load=True,debug=True):
sandbox_dp_icmr = sandbox_dp_icmr.withColumnRenamed(i,'icmr_'+i)
# Create the Output
- sandbox_dp_icmr.toPandas().to_csv(dp_icmr_output_folder+ dp_icmr_std_nonanon_file)
+ sandbox_dp_icmr.toPandas().to_csv(dp_icmr_output_folder+ dp_icmr_std_nonanon_file, index=False)
# Anonymisation of key identifiers
sandbox_dp_icmr_anon = sandbox_dp_icmr.withColumn('pk_mobno_hashed', sha2(concat_ws("", sandbox_dp_icmr.pk_mobno),256)) \
@@ -142,7 +144,7 @@ def dp_process_icmr_full(input_folder,data_file,load=True,debug=True):
.withColumn('fk_icmr_labid_hashed', sha2(concat_ws("", sandbox_dp_icmr.fk_icmr_labid),256)) \
.drop("pk_mobno").drop("ref_srfno").drop("pk_icmrno").drop("fk_icmr_labid").cache()
- sandbox_dp_icmr_anon.toPandas().to_csv(dp_icmr_output_folder + dp_icmr_std_anon_file)
+ sandbox_dp_icmr_anon.toPandas().to_csv(dp_icmr_output_folder + dp_icmr_std_anon_file, index=False)
if debug_poc:
print("Debug | sandbox_icmr_anon count =", sandbox_dp_icmr_anon.count())
print("Debug | Dataset Created ", "sandbox_icmr_processed_nonanon")
diff --git a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_index.py b/scenarios/covid/src/preprocess_index.py
similarity index 78%
rename from scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_index.py
rename to scenarios/covid/src/preprocess_index.py
index 83dfe0c..94e90ee 100644
--- a/scenarios/covid/src/ccr_depa_covid_poc_dp_data_prep_index.py
+++ b/scenarios/covid/src/preprocess_index.py
@@ -1,30 +1,30 @@
-# 2023, The DEPA CCR DP Training Reference Implementation
-# authors shyam@ispirt.in, sridhar.avs@ispirt.in
+# 2025 DEPA Foundation
#
-# Licensed TBD
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Key references / Attributions: https://depa.world/training/contracts
-# Key frameworks used : pyspark
-#Crirical Library Imports
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
from pyspark.sql import SparkSession
-from pyspark.sql.functions import *
-from pyspark.sql.types import *
-from pyspark.sql.functions import sha2, concat_ws # Hashing Related functions
-from pyspark.sql.functions import col , column
+from pyspark.sql.functions import sha2, concat_ws
##**Key Configuration Variables**"""
# Debug Enabled
-debug_poc=True
+debug_poc=False
# Model Input and output folders
-model_input_folder='/mnt/depa_ccr_poc/data/'
+model_input_folder='/mnt/input/data/'
# POC State-wise dummy data
index_file ='poc_data_statewarroom_data.csv'
@@ -33,7 +33,7 @@
load_process_dp_data=True
# Data Provider Level - Preprocess Locations
-dp_index_output_folder='/mnt/output/index/'
+dp_index_output_folder='/mnt/output/preprocessed/'
# DP Standardisation Non Anon Files
dp_index_std_nonanon_file ='dp_index_standardised_nonanon.csv'
@@ -41,9 +41,6 @@
# DP Standardisation Anon | Tokenised Files
dp_index_std_anon_file ='dp_index_standardised_anon.csv'
-# In the CCR
-model_output_folder='/mnt/depa_ccr_poc/output/'
-
dp_joined_dataset_identifiers_file='sandbox_icmr_cowin_index_linked_anon.csv'
dp_joined_dataset_wo_identifiers_file='sandbox_icmr_cowin_index_without_key_identifiers.csv'
@@ -53,6 +50,7 @@
"""# Setting up spark session"""
spark = SparkSession.builder.appName('CCR_DEPA_COVID_POC_Code').getOrCreate()
+spark.sparkContext.setLogLevel("ERROR")
"""# Common Utility Functions
@@ -91,11 +89,11 @@ def dp_load_data(input_folder,data_file,load=True,debug=True):
if debug:
print("Debug | input_file",input_file)
data_loaded= spark.read.csv(
- input_file,
- header=True,
- inferSchema=True,
- mode="DROPMALFORMED"
- )
+ input_file,
+ header=True,
+ inferSchema=True,
+ mode="DROPMALFORMED"
+ )
if debug:
print("Debug | input_file",data_loaded.count())
data_loaded.show()
@@ -119,7 +117,7 @@ def dp_process_index_full(input_folder,data_file,load=True,debug=True):
header=True,
inferSchema=True,
mode="DROPMALFORMED"
- )
+ )
if debug:
print("Debug | input_file",data_loaded.count())
@@ -146,7 +144,7 @@ def dp_process_index_full(input_folder,data_file,load=True,debug=True):
sandbox_dp_index = sandbox_dp_index.withColumnRenamed(i,'index_'+i)
# Create the Output
- sandbox_dp_index.toPandas().to_csv(dp_index_output_folder+ dp_index_std_nonanon_file)
+ sandbox_dp_index.toPandas().to_csv(dp_index_output_folder+ dp_index_std_nonanon_file, index=False)
# Anonymisation of key identifiers
sandbox_dp_index_anon = sandbox_dp_index.withColumn('pk_icmrno_hashed', sha2(concat_ws("", sandbox_dp_index.pk_icmrno),256)) \
@@ -155,7 +153,7 @@ def dp_process_index_full(input_folder,data_file,load=True,debug=True):
.withColumn("ref_labid_hashed", sha2(concat_ws("", sandbox_dp_index.ref_labid),256)) \
.drop("pk_mobno").drop("ref_srfno").drop("pk_icmrno").drop("ref_labid").cache()
- sandbox_dp_index_anon.toPandas().to_csv(dp_index_output_folder + dp_index_std_anon_file)
+ sandbox_dp_index_anon.toPandas().to_csv(dp_index_output_folder + dp_index_std_anon_file, index=False)
if debug_poc:
print("Debug | Dataset Created ", "sandbox_index_processed_anon")
diff --git a/scenarios/covid/src/save_base_model.py b/scenarios/covid/src/save_base_model.py
new file mode 100644
index 0000000..b5f96db
--- /dev/null
+++ b/scenarios/covid/src/save_base_model.py
@@ -0,0 +1,153 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import onnx
+import numpy as np
+import pandas as pd
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import Dataset, DataLoader
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+
+
+class CustomDataset(Dataset):
+ """PyTorch dataset for feature-target pairs."""
+
+ def __init__(self, features, target):
+ self.features = torch.tensor(features, dtype=torch.float32)
+ self.target = torch.tensor(target.values, dtype=torch.float32)
+
+ def __len__(self):
+ return len(self.features)
+
+ def __getitem__(self, idx):
+ return self.features[idx], self.target[idx]
+
+
+class BaseModel(nn.Module):
+ """Binary classification neural network model."""
+
+ def __init__(self, input_dim):
+ super(BaseModel, self).__init__()
+ self.fc1 = nn.Linear(input_dim, 128)
+ self.fc2 = nn.Linear(128, 64)
+ self.fc3 = nn.Linear(64, 1)
+ self.dropout = nn.Dropout(0.3)
+
+ def forward(self, x):
+ x = torch.relu(self.fc1(x))
+ x = self.dropout(x)
+ x = torch.relu(self.fc2(x))
+ x = self.dropout(x)
+ x = self.fc3(x)
+ x = torch.sigmoid(x)
+ return x
+
+
+def main():
+ # Load and preprocess data
+ data = pd.DataFrame(np.random.randint(0, 100, size=(2119, 10)),
+ columns=['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J'])
+
+ features = data.drop(columns=["J"])
+ target = (data["J"] > 50).astype(float) # Binary classification
+
+ # Split data
+ train_features, val_features, train_target, val_target = train_test_split(
+ features, target, test_size=0.2, random_state=42
+ )
+
+ # Standardize features
+ scaler = StandardScaler()
+ train_features = scaler.fit_transform(train_features)
+ val_features = scaler.transform(val_features)
+
+ # Create datasets and dataloaders
+ train_dataset = CustomDataset(train_features, train_target)
+ val_dataset = CustomDataset(val_features, val_target)
+
+ batch_size = 32
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+ val_loader = DataLoader(val_dataset, batch_size=batch_size)
+
+ # Initialize model, loss, and optimizer
+ model = BaseModel(input_dim=train_features.shape[1])
+ criterion = nn.BCELoss()
+ optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
+
+ # Training loop
+ num_epochs = 10
+
+ for epoch in range(num_epochs):
+ # Training phase
+ model.train()
+ train_loss = 0.0
+ correct = 0
+ total = 0
+
+ for inputs, targets in train_loader:
+ optimizer.zero_grad()
+ outputs = model(inputs)
+ targets = targets.unsqueeze(1)
+
+ loss = criterion(outputs, targets)
+ loss.backward()
+ optimizer.step()
+
+ train_loss += loss.item()
+ predicted = (outputs >= 0.5).float()
+ total += targets.size(0)
+ correct += (predicted == targets).sum().item()
+
+ # Validation phase
+ model.eval()
+ val_loss = 0.0
+ val_correct = 0
+ val_total = 0
+
+ with torch.no_grad():
+ for inputs, targets in val_loader:
+ outputs = model(inputs)
+ targets = targets.unsqueeze(1)
+ val_loss += criterion(outputs, targets).item()
+
+ predicted = (outputs >= 0.5).float()
+ val_total += targets.size(0)
+ val_correct += (predicted == targets).sum().item()
+
+ # Calculate metrics
+ train_loss /= len(train_loader)
+ val_loss /= len(val_loader)
+ train_acc = 100 * correct / total
+ val_acc = 100 * val_correct / val_total
+
+ print(f'Epoch [{epoch+1}/{num_epochs}] - Train Loss: {train_loss:.4f}, '
+ f'Train Acc: {train_acc:.2f}%, Val Loss: {val_loss:.4f}, '
+ f'Val Acc: {val_acc:.2f}%')
+
+ print('Training finished.')
+
+ # Save model
+ torch.onnx.export(model, torch.randn(1, train_features.shape[1]),
+ "/mnt/model/model.onnx", verbose=True)
+ print('Model saved as ONNX.')
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scenarios/credit-risk/.gitignore b/scenarios/credit-risk/.gitignore
new file mode 100644
index 0000000..db1ef7f
--- /dev/null
+++ b/scenarios/credit-risk/.gitignore
@@ -0,0 +1,19 @@
+**/preprocessed/*
+*.csv
+*.parquet
+*.bin
+*.img
+*.pth
+*.pt
+*.onnx
+*.npy
+*.gz
+
+# Ignore modeller output folder (relative to repo root)
+modeller/output/
+
+data/
+
+**/kaggle.json
+
+**/__pycache__/
\ No newline at end of file
diff --git a/scenarios/credit-risk/README.md b/scenarios/credit-risk/README.md
new file mode 100644
index 0000000..a9d9772
--- /dev/null
+++ b/scenarios/credit-risk/README.md
@@ -0,0 +1,380 @@
+# Home Credit Default Risk Prediction
+
+## Scenario Type
+
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| Credit Risk | Training - Classical ML | Binary Classification | Differentially Private | 4 | PII tabular data (Parquet) | XGBoost (JSON) | Horizontal (4)|
+
+---
+
+## Scenario Description
+
+This scenario involves training an XGBoost model on the [Home Credit Default Risk](https://www.kaggle.com/c/home-credit-default-risk) datasets [[1, 2]](README.md#references). We frame this scenario as involving four Training Data Providers (TDPs) - Bank A providing data for clients' credit applications, previous applications and payment installments, Bank B providing data on credit card balance, the Credit Bureau providing data on previous loans, and a Fintech providing data on point of sale (POS) cash balance. Here, Bank A is also the Training Data Consumer (TDC) who wishes to train the model on the joined datasets, in order to build a default risk prediction model.
+
+The end-to-end training pipeline consists of the following phases:
+
+1. Data pre-processing
+2. Data packaging, encryption and upload
+3. Encryption key import with key release policies
+4. Deployment and execution of CCR
+5. Trained model decryption
+
+## Build container images
+
+Build container images required for this sample as follows.
+
+```bash
+export SCENARIO=credit-risk
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/build.sh
+```
+
+This script builds the following container images.
+
+- ```preprocess-bank-a```: Container for pre-processing Bank A's dataset.
+- ```preprocess-bank-b```: Container for pre-processing Bank B's dataset.
+- ```preprocess-bureau```: Container for pre-processing Bureau's dataset.
+- ```preprocess-fintech```: Container for pre-processing Fintech's dataset.
+
+Alternatively, you can pull and use pre-built container images from the ispirt container registry by setting the following environment variable. Docker Hub has started throttling, which may affect the upload/download time, especially when images are of bigger size. So, it is advisable to use other container registries. We are using Azure Container Registry (ACR) as shown below:
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/pull-containers.sh
+```
+
+## Data pre-processing
+
+The folder ```scenarios/credit-risk/src``` contains scripts for downloading and pre-processing the datasets. Acting as a Training Data Provider (TDP), prepare your datasets.
+
+Since these datasets are downloaded from [Kaggle.com](https://kaggle.com), set your Kaggle credentials as environment variables before running the preprocess script. The Kaggle credentials can be obtained from your Kaggle account settings > API > Create new token.
+
+```bash
+export KAGGLE_USERNAME=
+export KAGGLE_KEY=
+```
+
+Then, run the preprocess script.
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/local
+./preprocess.sh
+```
+
+The datasets are saved to the [data](./data/) directory.
+
+## Deploy locally
+
+Assuming you have cleartext access to all the datasets, you can train the model _locally_ as follows:
+
+```bash
+./train.sh
+```
+
+The script joins the datasets and trains the model using a pipeline configuration. To modify the various components of the training pipeline, you can edit the training config files in the [config](./config/) directory. The training config files are used to create the pipeline configuration ([pipeline_config.json](./config/pipeline_config.json)) created by consolidating all the TDC's training config files, namely the [model config](./config/model_config.json), [dataset config](./config/dataset_config.json), [loss function config](./config/loss_config.json), [training config](./config/train_config_template.json), [evaluation config](./config/eval_config.json), and if multiple datasets are used, the [data join config](./config/join_config.json). These enable the TDC to design highly customized training pipelines without requiring review and approval of new custom code for each use case—reducing risks from potentially malicious or non-compliant code. The consolidated pipeline configuration is then attested against the signed contract using the TDP’s policy-as-code. If approved, it is executed in the CCR to train the model, which we will deploy in the next section.
+
+```mermaid
+flowchart TD
+
+ subgraph Config Files
+ C1[model_config.json]
+ C2[dataset_config.json]
+ C3[loss_config.json]
+ C4[train_config_template.json]
+ C5[eval_config.json]
+ C6[join_config.json]
+ end
+
+ B[Consolidated into
+    pipeline_config.json]
+
+ C1 --> B
+ C2 --> B
+ C3 --> B
+ C4 --> B
+ C5 --> B
+ C6 --> B
+
+ B --> D[Attested against contract
+    using policy-as-code]
+ D --> E{Approved?}
+ E -- Yes --> F[CCR training begins]
+ E -- No --> H[Rejected: fix config]
+```
+
+If all goes well, you should see output similar to the following, and the trained model and evaluation metrics will be saved under the folder [output](./modeller/output).
+
+```
+train-1 | Training samples: 43636
+train-1 | Validation samples: 10909
+train-1 | Test samples: 5455
+train-1 | Dataset constructed from config
+train-1 | Model loaded from ONNX file
+train-1 | Optimizer Adam loaded from config
+train-1 | Scheduler CyclicLR loaded from config
+train-1 | Custom loss function loaded from config
+train-1 | Epoch 1/1 completed | Training Loss: 0.1586
+train-1 | Epoch 1/1 completed | Validation Loss: 0.0860
+train-1 | Saving trained model to /mnt/remote/output/trained_model.onnx
+train-1 | Evaluation Metrics: {'test_loss': 0.08991911436687393, 'accuracy': 0.9523373052245646, 'f1_score': 0.9522986646537908}
+train-1 | CCR Training complete!
+train-1 |
+train-1 exited with code 0
+```
+
+## Deploy on CCR
+
+In a more realistic scenario, these datasets will not be available in the clear to the TDC, and the TDC will be required to use a CCR for training. The following steps describe the process of sharing an encrypted dataset with TDCs and setting up a CCR in Azure for training. Please stay tuned for CCR on other cloud platforms.
+
+To deploy in Azure, you will need the following.
+
+- Docker Hub account to store container images. Alternatively, you can use pre-built images from the ```ispirt``` container registry.
+- [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault/) to store encryption keys and implement secure key release to CCR. You can either use Azure Key Vault Premium (lower cost), or [Azure Key Vault managed HSM](https://learn.microsoft.com/en-us/azure/key-vault/managed-hsm/overview) for enhanced security. Please see instructions below on how to create and set up your AKV instance.
+- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances (ACI).
+
+If you are using your own development environment instead of a dev container or codespaces, you will need to install the following dependencies.
+
+- [Azure CLI](https://learn.microsoft.com/en-us/cli/azure/install-azure-cli-linux).
+- [Azure CLI Confidential containers extension](https://learn.microsoft.com/en-us/cli/azure/confcom?view=azure-cli-latest). After installing Azure CLI, you can install this extension using ```az extension add --name confcom -y```
+- [Go](https://go.dev/doc/install). Follow the instructions to install Go. After installing, ensure that the PATH environment variable is set to include ```go``` runtime.
+- ```jq```. You can install jq using ```sudo apt-get install -y jq```
+
+We will be creating the following resources as part of the deployment.
+
+- Azure Key Vault
+- Azure Storage account
+- Storage containers to host encrypted datasets
+- Azure Container Instances (ACI) to deploy the CCR and train the model
+
+### 1. Push Container Images
+
+Pre-built container images are available in iSPIRT's container registry, which can be pulled by setting the following environment variable.
+
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+```
+
+If you wish to use your own container images, login to docker hub (or your container registry of choice) and then build and push the container images to it, so that they can be pulled by the CCR. This is a one-time operation, and you can skip this step if you have already pushed the images to your container registry.
+
+```bash
+export CONTAINER_REGISTRY=
+docker login -u <username> -p <password> ${CONTAINER_REGISTRY}
+cd $REPO_ROOT
+./ci/push-containers.sh
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/push-containers.sh
+```
+
+> **Note:** Replace `CONTAINER_REGISTRY`, `<username>` and `<password>` with your container registry name, Docker Hub username and password respectively. Preferably use registry services other than Docker Hub as throttling restrictions will cause delays (or) image push/pull failures.
+
+### 2. Create Resources
+
+First, set up the necessary environment variables for your deployment.
+
+```bash
+az login
+
+export SCENARIO=credit-risk
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+export AZURE_LOCATION=northeurope
+export AZURE_SUBSCRIPTION_ID=
+export AZURE_RESOURCE_GROUP=
+export AZURE_KEYVAULT_ENDPOINT=.vault.azure.net
+export AZURE_STORAGE_ACCOUNT_NAME=
+
+export AZURE_BANK_A_CONTAINER_NAME=bankacontainer
+export AZURE_BANK_B_CONTAINER_NAME=bankbcontainer
+export AZURE_BUREAU_CONTAINER_NAME=bureaucontainer
+export AZURE_FINTECH_CONTAINER_NAME=fintechcontainer
+export AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+```
+
+Alternatively, you can edit the values in the [export-variables.sh](./export-variables.sh) script and run it to set the environment variables.
+
+```bash
+./export-variables.sh
+source export-variables.sh
+```
+
+Azure Naming Rules:
+- Resource Group:
+ - 1–90 characters
+ - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+ - Cannot end with a period (.)
+  - Case-insensitive, unique within subscription
+- Key Vault:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with letter or number
+- Storage Account:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters and numbers only
+- Storage Container:
+ - 3-63 characters
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with a letter or number
+ - No consecutive hyphens
+ - Unique within storage account
+
+---
+
+**Important:**
+
+The values for the environment variables listed below must precisely match the namesake environment variables used during contract signing (next step). Any mismatch will lead to execution failure.
+
+- `SCENARIO`
+- `AZURE_KEYVAULT_ENDPOINT`
+- `CONTRACT_SERVICE_URL`
+- `AZURE_STORAGE_ACCOUNT_NAME`
+- `AZURE_BANK_A_CONTAINER_NAME`
+- `AZURE_BANK_B_CONTAINER_NAME`
+- `AZURE_BUREAU_CONTAINER_NAME`
+- `AZURE_FINTECH_CONTAINER_NAME`
+
+---
+With the environment variables set, we are ready to create the resources -- Azure Key Vault and Azure Storage containers.
+
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/azure
+./1-create-storage-containers.sh
+./2-create-akv.sh
+```
+---
+
+### 3\. Contract Signing
+
+Navigate to the [contract-ledger](https://github.com/kapilvgit/contract-ledger/blob/main/README.md) repository and follow the instructions for contract signing.
+
+Once the contract is signed, export the contract sequence number as an environment variable in the same terminal where you set the environment variables for the deployment.
+
+```bash
+export CONTRACT_SEQ_NO=
+```
+
+---
+
+### 4\. Data Encryption and Upload
+
+Using their respective keys, the TDPs and TDC encrypt their datasets and model (respectively) and upload them to the Storage containers created in the previous step.
+
+Navigate to the [Azure deployment](./deployment/azure/) directory and execute the scripts for key import, data encryption and upload to Azure Blob Storage, in preparation of the CCR deployment.
+
+The import-keys script generates and imports encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys. The generated keys are available as files with the extension `.bin`.
+
+```bash
+export CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+export TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+./3-import-keys.sh
+```
+
+The data and model are then packaged as encrypted filesystems by the TDPs and TDC using their respective keys, which are saved as `.img` files.
+
+```bash
+./4-encrypt-data.sh
+```
+
+The encrypted data and model are then uploaded to the Storage containers created in the previous step. The `.img` files are uploaded to the Storage containers as blobs.
+
+```bash
+./5-upload-encrypted-data.sh
+```
+
+---
+
+### 5\. CCR Deployment
+
+With the resources ready, we are ready to deploy the Confidential Clean Room (CCR) for executing the privacy-preserving model training.
+
+```bash
+export CONTRACT_SEQ_NO=
+./deploy.sh -c $CONTRACT_SEQ_NO -p ../../config/pipeline_config.json
+```
+
+Set the `$CONTRACT_SEQ_NO` variable to the numeric suffix of the contract sequence number (which is of format 2.XX). For example, if the number was 2.15, export as:
+
+```bash
+export CONTRACT_SEQ_NO=15
+```
+
+This script will deploy the container images from your container registry, including the encrypted filesystem sidecar. The sidecar will generate an SEV-SNP attestation report, generate an attestation token using the Microsoft Azure Attestation (MAA) service, retrieve dataset, model and output encryption keys from the TDP and TDC's Azure Key Vault, train the model, and save the resulting model into TDC's output filesystem image, which the TDC can later decrypt.
+
+
+
+**Note:** The completion of this script's execution simply creates a CCR instance, and doesn't indicate whether training has completed or not. The training process might still be ongoing. Poll the container logs (see below) to track progress until training is complete.
+
+### 6\. Monitor Container Logs
+
+Use the following commands to monitor the logs of the deployed containers. You might have to repeatedly poll this command to monitor the training progress:
+
+```bash
+az container logs \
+ --name "depa-training-$SCENARIO" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --container-name depa-training
+```
+
+You will know training has completed when the logs print "CCR Training complete!".
+
+#### Troubleshooting
+
+In case training fails, you might want to monitor the logs of the encrypted storage sidecar container to see if the encryption process completed successfully:
+
+```bash
+az container logs --name depa-training-$SCENARIO --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+```
+
+And to further debug, inspect the logs of the encrypted filesystem sidecar container:
+
+```bash
+az container exec \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --name depa-training-$SCENARIO \
+ --container-name encrypted-storage-sidecar \
+ --exec-command "/bin/sh"
+```
+
+Once inside the sidecar container shell, view the logs:
+
+```bash
+cat log.txt
+```
+Or inspect the individual mounted directories in `mnt/remote/`:
+
+```bash
+cd mnt/remote && ls
+```
+
+### 7\. Download and Decrypt Model
+
+Once training has completed successfully (the training container logs will mention it explicitly), download and decrypt the trained model and other training outputs.
+
+```bash
+./6-download-decrypt-model.sh
+```
+
+The outputs will be saved to the [output](./modeller/output/) directory.
+
+To check if the trained model is fresh, you can run the following command:
+
+```bash
+stat $REPO_ROOT/scenarios/$SCENARIO/modeller/output/trained_model.onnx
+```
+
+---
+### Clean-up
+
+You can use the following command to delete the resource group and clean-up all resources used in the demo. Alternatively, you can navigate to the Azure portal and delete the resource group created for this demo.
+
+```bash
+az group delete --yes --name $AZURE_RESOURCE_GROUP
+```
+
+## References
+
+[1] [Anna Montoya, inversion, KirillOdintsov, and Martin Kotek. Home Credit Default Risk. https://kaggle.com/competitions/home-credit-default-risk, 2018. Kaggle.](https://www.kaggle.com/c/home-credit-default-risk)
+
+[2] [XGBoost](https://xgboost.readthedocs.io/en/stable/)
diff --git a/scenarios/credit-risk/ci/Dockerfile.bankA b/scenarios/credit-risk/ci/Dockerfile.bankA
new file mode 100644
index 0000000..6f0896c
--- /dev/null
+++ b/scenarios/credit-risk/ci/Dockerfile.bankA
@@ -0,0 +1,16 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install kaggle pandas numpy pyarrow
+
+COPY preprocess_bank_a.py preprocess_bank_a.py
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/Dockerfile.bankB b/scenarios/credit-risk/ci/Dockerfile.bankB
new file mode 100644
index 0000000..eebb613
--- /dev/null
+++ b/scenarios/credit-risk/ci/Dockerfile.bankB
@@ -0,0 +1,16 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install kaggle pandas numpy pyarrow
+
+COPY preprocess_bank_b.py preprocess_bank_b.py
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/Dockerfile.bureau b/scenarios/credit-risk/ci/Dockerfile.bureau
new file mode 100644
index 0000000..f8513d4
--- /dev/null
+++ b/scenarios/credit-risk/ci/Dockerfile.bureau
@@ -0,0 +1,16 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install kaggle pandas numpy pyarrow
+
+COPY preprocess_bureau.py preprocess_bureau.py
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/Dockerfile.fintech b/scenarios/credit-risk/ci/Dockerfile.fintech
new file mode 100644
index 0000000..e93f1a0
--- /dev/null
+++ b/scenarios/credit-risk/ci/Dockerfile.fintech
@@ -0,0 +1,16 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install kaggle pandas numpy pyarrow
+
+COPY preprocess_fintech.py preprocess_fintech.py
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/build.sh b/scenarios/credit-risk/ci/build.sh
new file mode 100755
index 0000000..d93b1d6
--- /dev/null
+++ b/scenarios/credit-risk/ci/build.sh
@@ -0,0 +1,6 @@
+#!/bin/bash
+
+docker build -f ci/Dockerfile.bankA src -t preprocess-bank-a:latest
+docker build -f ci/Dockerfile.bankB src -t preprocess-bank-b:latest
+docker build -f ci/Dockerfile.bureau src -t preprocess-bureau:latest
+docker build -f ci/Dockerfile.fintech src -t preprocess-fintech:latest
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/pull-containers.sh b/scenarios/credit-risk/ci/pull-containers.sh
new file mode 100755
index 0000000..fedf456
--- /dev/null
+++ b/scenarios/credit-risk/ci/pull-containers.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+containers=("preprocess-bank-a:latest" "preprocess-bank-b:latest" "preprocess-bureau:latest" "preprocess-fintech:latest")
+for container in "${containers[@]}"
+do
+ docker pull $CONTAINER_REGISTRY"/"$container
+done
\ No newline at end of file
diff --git a/scenarios/credit-risk/ci/push-containers.sh b/scenarios/credit-risk/ci/push-containers.sh
new file mode 100755
index 0000000..1332d99
--- /dev/null
+++ b/scenarios/credit-risk/ci/push-containers.sh
@@ -0,0 +1,6 @@
+containers=("preprocess-bank-a:latest" "preprocess-bank-b:latest" "preprocess-bureau:latest" "preprocess-fintech:latest")
+for container in "${containers[@]}"
+do
+ docker tag $container $CONTAINER_REGISTRY"/"$container
+ docker push $CONTAINER_REGISTRY"/"$container
+done
diff --git a/scenarios/credit-risk/config/consolidate_pipeline.sh b/scenarios/credit-risk/config/consolidate_pipeline.sh
new file mode 100755
index 0000000..00311e2
--- /dev/null
+++ b/scenarios/credit-risk/config/consolidate_pipeline.sh
@@ -0,0 +1,58 @@
+#! /bin/bash
+
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCENARIO=credit-risk
+
+template_path="$REPO_ROOT/scenarios/$SCENARIO/config/templates"
+model_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json"
+data_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/dataset_config.json"
+loss_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/loss_config.json"
+train_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/train_config.json"
+eval_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/eval_config.json"
+join_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/join_config.json"
+pipeline_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/pipeline_config.json"
+
+# populate "model_config", "data_config", and "loss_config" keys in train config
+train_config=$(cat $template_path/train_config_template.json)
+
+# Only merge if the file exists
+if [[ -f "$model_config_path" ]]; then
+ model_config=$(cat $model_config_path)
+ train_config=$(echo "$train_config" | jq --argjson model "$model_config" '.config.model_config = $model')
+fi
+
+if [[ -f "$data_config_path" ]]; then
+ data_config=$(cat $data_config_path)
+ train_config=$(echo "$train_config" | jq --argjson data "$data_config" '.config.dataset_config = $data')
+fi
+
+if [[ -f "$loss_config_path" ]]; then
+ loss_config=$(cat $loss_config_path)
+ train_config=$(echo "$train_config" | jq --argjson loss "$loss_config" '.config.loss_config = $loss')
+fi
+
+if [[ -f "$eval_config_path" ]]; then
+ eval_config=$(cat $eval_config_path)
+ # Get all keys from eval_config and copy them to train_config
+ for key in $(echo "$eval_config" | jq -r 'keys[]'); do
+ train_config=$(echo "$train_config" | jq --argjson eval "$eval_config" --arg key "$key" '.config[$key] = $eval[$key]')
+ done
+fi
+
+# save train_config
+echo "$train_config" > $train_config_path
+
+# prepare pipeline config from join_config.json (first dict "config") and train_config.json (second dict "config")
+pipeline_config=$(cat $template_path/pipeline_config_template.json)
+
+# Only merge join_config if the file exists
+if [[ -f "$join_config_path" ]]; then
+ join_config=$(cat $join_config_path)
+ pipeline_config=$(echo "$pipeline_config" | jq --argjson join "$join_config" '.pipeline += [$join]')
+fi
+
+# Always merge train_config as it's required
+pipeline_config=$(echo "$pipeline_config" | jq --argjson train "$train_config" '.pipeline += [$train]')
+
+# save pipeline_config to pipeline_config.json
+echo "$pipeline_config" > $pipeline_config_path
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/dataset_config.json b/scenarios/credit-risk/config/dataset_config.json
new file mode 100644
index 0000000..3532347
--- /dev/null
+++ b/scenarios/credit-risk/config/dataset_config.json
@@ -0,0 +1,13 @@
+{
+ "type": "tabular",
+ "target_variable": "TARGET",
+ "missing_strategy": "drop",
+ "splits": {
+ "train": 0.8,
+ "val": 0.1,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "numpy"
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/eval_config.json b/scenarios/credit-risk/config/eval_config.json
new file mode 100644
index 0000000..3f9b22b
--- /dev/null
+++ b/scenarios/credit-risk/config/eval_config.json
@@ -0,0 +1,9 @@
+{
+ "task_type": "classification",
+ "metrics": [
+ "accuracy",
+ "roc_auc",
+ "confusion_matrix",
+ "classification_report"
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/join_config.json b/scenarios/credit-risk/config/join_config.json
new file mode 100644
index 0000000..06e3f73
--- /dev/null
+++ b/scenarios/credit-risk/config/join_config.json
@@ -0,0 +1,108 @@
+{
+ "name": "SparkJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "credit_applications",
+ "file": "bank_a_app_learn.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "TARGET",
+ "INCOME_PER_PERSON",
+ "CREDIT_TO_INCOME",
+ "NAME_CONTRACT_TYPE",
+ "CODE_GENDER",
+ "FLAG_OWN_CAR",
+ "FLAG_OWN_REALTY",
+ "CNT_CHILDREN"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "previous_applications",
+ "file": "bank_a_prev_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_APPLICATION_count",
+ "AMT_APPLICATION_mean",
+ "AMT_CREDIT_mean",
+ "AMT_DOWN_PAYMENT_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "payment_installments",
+ "file": "bank_a_inst_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_PAYMENT_sum",
+ "AMT_INSTALMENT_sum",
+ "PAYMENT_DIFF_mean",
+ "LATE_PAYMENT_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "3187712c-bab8-11ed-afa1-0242ac120002",
+ "provider": "bureau",
+ "name": "bureau_records",
+ "file": "bureau_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "BUREAU_COUNT_sum",
+ "AMT_CREDIT_SUM_mean",
+ "AMT_CREDIT_SUM_DEBT_mean",
+ "STATUS_PAST_DUE_RATE"
+ ],
+ "mount_path": "/mnt/remote/bureau/"
+ },
+ {
+ "id": "40e82c74-bab8-11ed-afa1-0242ac120002",
+ "provider": "fintech",
+ "name": "pos_cash_balance",
+ "file": "pos_fintech_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "SK_DPD_max",
+ "SK_DPD_DEF_mean",
+ "MONTHS_BALANCE_count"
+ ],
+ "mount_path": "/mnt/remote/fintech/"
+ },
+ {
+ "id": "45647e76-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_b",
+ "name": "credit_card_balance",
+ "file": "bank_b_cc_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_BALANCE_mean",
+ "AMT_CREDIT_LIMIT_ACTUAL_max",
+ "SK_DPD_max",
+ "SK_DPD_DEF_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_b/"
+ }
+ ],
+ "joined_dataset": {
+ "joined_dataset": "/tmp/credit_risk_joined.parquet",
+ "joining_query": "SELECT * FROM credit_applications a LEFT JOIN previous_applications p ON a.SK_ID_CURR = p.SK_ID_CURR LEFT JOIN payment_installments i ON a.SK_ID_CURR = i.SK_ID_CURR LEFT JOIN bureau_records b ON a.SK_ID_CURR = b.SK_ID_CURR LEFT JOIN pos_cash_balance pos ON a.SK_ID_CURR = pos.SK_ID_CURR LEFT JOIN credit_card_balance cc ON a.SK_ID_CURR = cc.SK_ID_CURR",
+ "joining_key": "SK_ID_CURR",
+ "drop_columns": [
+ "NAME_CONTRACT_TYPE",
+ "CODE_GENDER",
+ "FLAG_OWN_CAR",
+ "FLAG_OWN_REALTY"
+ ],
+ "identifiers": [
+ "SK_ID_CURR"
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/model_config.json b/scenarios/credit-risk/config/model_config.json
new file mode 100644
index 0000000..7a96174
--- /dev/null
+++ b/scenarios/credit-risk/config/model_config.json
@@ -0,0 +1,9 @@
+{
+ "num_boost_round": 250,
+ "booster_params": {
+ "max_depth": 6,
+ "learning_rate": 0.05,
+ "objective": "binary:logistic",
+ "eval_metric": "auc"
+ }
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/pipeline_config.json b/scenarios/credit-risk/config/pipeline_config.json
new file mode 100644
index 0000000..49fff3a
--- /dev/null
+++ b/scenarios/credit-risk/config/pipeline_config.json
@@ -0,0 +1,157 @@
+{
+ "pipeline": [
+ {
+ "name": "SparkJoin",
+ "config": {
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "credit_applications",
+ "file": "bank_a_app_learn.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "TARGET",
+ "INCOME_PER_PERSON",
+ "CREDIT_TO_INCOME",
+ "NAME_CONTRACT_TYPE",
+ "CODE_GENDER",
+ "FLAG_OWN_CAR",
+ "FLAG_OWN_REALTY",
+ "CNT_CHILDREN"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "previous_applications",
+ "file": "bank_a_prev_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_APPLICATION_count",
+ "AMT_APPLICATION_mean",
+ "AMT_CREDIT_mean",
+ "AMT_DOWN_PAYMENT_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_a",
+ "name": "payment_installments",
+ "file": "bank_a_inst_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_PAYMENT_sum",
+ "AMT_INSTALMENT_sum",
+ "PAYMENT_DIFF_mean",
+ "LATE_PAYMENT_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_a/"
+ },
+ {
+ "id": "3187712c-bab8-11ed-afa1-0242ac120002",
+ "provider": "bureau",
+ "name": "bureau_records",
+ "file": "bureau_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "BUREAU_COUNT_sum",
+ "AMT_CREDIT_SUM_mean",
+ "AMT_CREDIT_SUM_DEBT_mean",
+ "STATUS_PAST_DUE_RATE"
+ ],
+ "mount_path": "/mnt/remote/bureau/"
+ },
+ {
+ "id": "40e82c74-bab8-11ed-afa1-0242ac120002",
+ "provider": "fintech",
+ "name": "pos_cash_balance",
+ "file": "pos_fintech_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "SK_DPD_max",
+ "SK_DPD_DEF_mean",
+ "MONTHS_BALANCE_count"
+ ],
+ "mount_path": "/mnt/remote/fintech/"
+ },
+ {
+ "id": "45647e76-bab8-11ed-afa1-0242ac120002",
+ "provider": "bank_b",
+ "name": "credit_card_balance",
+ "file": "bank_b_cc_agg.parquet",
+ "select_variables": [
+ "SK_ID_CURR",
+ "AMT_BALANCE_mean",
+ "AMT_CREDIT_LIMIT_ACTUAL_max",
+ "SK_DPD_max",
+ "SK_DPD_DEF_mean"
+ ],
+ "mount_path": "/mnt/remote/bank_b/"
+ }
+ ],
+ "joined_dataset": {
+ "joined_dataset": "/tmp/credit_risk_joined.parquet",
+ "joining_query": "SELECT * FROM credit_applications a LEFT JOIN previous_applications p ON a.SK_ID_CURR = p.SK_ID_CURR LEFT JOIN payment_installments i ON a.SK_ID_CURR = i.SK_ID_CURR LEFT JOIN bureau_records b ON a.SK_ID_CURR = b.SK_ID_CURR LEFT JOIN pos_cash_balance pos ON a.SK_ID_CURR = pos.SK_ID_CURR LEFT JOIN credit_card_balance cc ON a.SK_ID_CURR = cc.SK_ID_CURR",
+ "joining_key": "SK_ID_CURR",
+ "drop_columns": [
+ "NAME_CONTRACT_TYPE",
+ "CODE_GENDER",
+ "FLAG_OWN_CAR",
+ "FLAG_OWN_REALTY"
+ ],
+ "identifiers": [
+ "SK_ID_CURR"
+ ]
+ }
+ }
+ },
+ {
+ "name": "Train_XGB",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/credit_risk_joined.parquet",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "is_private": true,
+ "privacy_params": {
+ "epsilon": 4.0,
+ "mechanism": "gaussian",
+ "delta": 0.00001,
+ "clip_value": 1.0
+ },
+ "model_config": {
+ "num_boost_round": 250,
+ "booster_params": {
+ "max_depth": 6,
+ "learning_rate": 0.05,
+ "objective": "binary:logistic",
+ "eval_metric": "auc"
+ }
+ },
+ "dataset_config": {
+ "type": "tabular",
+ "target_variable": "TARGET",
+ "missing_strategy": "drop",
+ "splits": {
+ "train": 0.8,
+ "val": 0.1,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "numpy"
+ },
+ "metrics": [
+ "accuracy",
+ "roc_auc",
+ "confusion_matrix",
+ "classification_report"
+ ],
+ "task_type": "classification"
+ }
+ }
+ ]
+}
diff --git a/scenarios/credit-risk/config/templates/pipeline_config_template.json b/scenarios/credit-risk/config/templates/pipeline_config_template.json
new file mode 100644
index 0000000..43e9e84
--- /dev/null
+++ b/scenarios/credit-risk/config/templates/pipeline_config_template.json
@@ -0,0 +1,3 @@
+{
+ "pipeline": []
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/templates/train_config_template.json b/scenarios/credit-risk/config/templates/train_config_template.json
new file mode 100644
index 0000000..0f01ff1
--- /dev/null
+++ b/scenarios/credit-risk/config/templates/train_config_template.json
@@ -0,0 +1,16 @@
+{
+ "name": "Train_XGB",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/credit_risk_joined.parquet",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "is_private": true,
+ "privacy_params": {
+ "epsilon": 4.0,
+ "mechanism": "gaussian",
+ "delta": 1e-5,
+ "clip_value": 1.0
+ }
+ }
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/config/train_config.json b/scenarios/credit-risk/config/train_config.json
new file mode 100644
index 0000000..b54c77c
--- /dev/null
+++ b/scenarios/credit-risk/config/train_config.json
@@ -0,0 +1,45 @@
+{
+ "name": "Train_XGB",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/tmp/credit_risk_joined.parquet",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "is_private": true,
+ "privacy_params": {
+ "epsilon": 4.0,
+ "mechanism": "gaussian",
+ "delta": 0.00001,
+ "clip_value": 1.0
+ },
+ "model_config": {
+ "num_boost_round": 250,
+ "booster_params": {
+ "max_depth": 6,
+ "learning_rate": 0.05,
+ "objective": "binary:logistic",
+ "eval_metric": "auc"
+ }
+ },
+ "dataset_config": {
+ "type": "tabular",
+ "target_variable": "TARGET",
+ "missing_strategy": "drop",
+ "splits": {
+ "train": 0.8,
+ "val": 0.1,
+ "test": 0.1,
+ "random_state": 42,
+ "stratify": true
+ },
+ "data_type": "numpy"
+ },
+ "metrics": [
+ "accuracy",
+ "roc_auc",
+ "confusion_matrix",
+ "classification_report"
+ ],
+ "task_type": "classification"
+ }
+}
diff --git a/scenarios/credit-risk/contract/contract.json b/scenarios/credit-risk/contract/contract.json
new file mode 100644
index 0000000..2cbc573
--- /dev/null
+++ b/scenarios/credit-risk/contract/contract.json
@@ -0,0 +1,148 @@
+{
+ "id": "f4f72a88-bab1-11ed-afa1-0242ac120002",
+ "schemaVersion": "0.1",
+ "startTime": "2023-03-14T00:00:00.000Z",
+ "expiryTime": "2024-03-14T00:00:00.000Z",
+ "tdc": "",
+ "tdps": [],
+ "ccrp": "did:web:$CCRP_USERNAME.github.io",
+ "datasets": [
+ {
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "name": "bank_a",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BANK_A_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BankAFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "name": "bank_a",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BANK_A_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BankAFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "name": "bank_a",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BANK_A_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BankAFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "45647e76-bab8-11ed-afa1-0242ac120002",
+ "name": "bank_b",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BANK_B_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BankBFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "3187712c-bab8-11ed-afa1-0242ac120002",
+ "name": "bureau",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_BUREAU_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "BureauFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ },
+ {
+ "id": "40e82c74-bab8-11ed-afa1-0242ac120002",
+ "name": "fintech",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_FINTECH_CONTAINER_NAME/data.img",
+ "provider": "",
+ "key": {
+ "type": "azure",
+ "properties": {
+ "kid": "FintechFilesystemEncryptionKey",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "endpoint": ""
+ }
+ }
+ }
+ ],
+ "purpose": "TRAINING",
+ "constraints": [
+ {
+ "privacy": [
+ {
+ "dataset": "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "7.5",
+ "delta": "0.01"
+ },
+ {
+ "dataset": "216d5cc6-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "7.5",
+ "delta": "0.01"
+ },
+ {
+ "dataset": "2830a144-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "9",
+ "delta": "0.01"
+ },
+ {
+ "dataset": "45647e76-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "5",
+ "delta": "0.01"
+ },
+ {
+ "dataset": "3187712c-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "4.5",
+ "delta": "0.01"
+ },
+ {
+ "dataset": "40e82c74-bab8-11ed-afa1-0242ac120002",
+ "epsilon_threshold": "4",
+ "delta": "0.01"
+ }
+ ]
+ }
+ ],
+ "terms": {
+ "payment": {},
+ "revocation": {}
+ }
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/0-create-acr.sh b/scenarios/credit-risk/deployment/azure/0-create-acr.sh
new file mode 100755
index 0000000..4719bad
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/0-create-acr.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Only to be run when creating a new ACR
+
+# Ensure required env vars are set
+if [[ -z "$CONTAINER_REGISTRY" || -z "$AZURE_RESOURCE_GROUP" || -z "$AZURE_LOCATION" ]]; then
+ echo "ERROR: CONTAINER_REGISTRY, AZURE_RESOURCE_GROUP, and AZURE_LOCATION environment variables must be set."
+ exit 1
+fi
+
+echo "Checking if ACR '$CONTAINER_REGISTRY' exists in resource group '$AZURE_RESOURCE_GROUP'..."
+
+# Check if ACR exists
+ACR_EXISTS=$(az acr show --name "$CONTAINER_REGISTRY" --resource-group "$AZURE_RESOURCE_GROUP" --query "name" -o tsv 2>/dev/null)
+
+if [[ -n "$ACR_EXISTS" ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' already exists."
+else
+ echo "⏳ ACR '$CONTAINER_REGISTRY' does not exist. Creating..."
+
+ # Create ACR with premium SKU and admin enabled
+ az acr create \
+ --name "$CONTAINER_REGISTRY" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --location "$AZURE_LOCATION" \
+ --sku Premium \
+ --admin-enabled true \
+ --output table
+
+ # Enable anonymous pull
+ az acr update --name "$CONTAINER_REGISTRY" --anonymous-pull-enabled true
+
+ if [[ $? -eq 0 ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' created successfully."
+ else
+ echo "❌ Failed to create ACR."
+ exit 1
+ fi
+fi
+
+# Login to the ACR
+az acr login --name "$CONTAINER_REGISTRY"
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/1-create-storage-containers.sh b/scenarios/credit-risk/deployment/azure/1-create-storage-containers.sh
new file mode 100755
index 0000000..6a103d2
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/1-create-storage-containers.sh
@@ -0,0 +1,66 @@
+#!/bin/bash
+#
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+echo "Check if storage account $AZURE_STORAGE_ACCOUNT_NAME exists..."
+STORAGE_ACCOUNT_EXISTS=$(az storage account check-name --name $AZURE_STORAGE_ACCOUNT_NAME --query "nameAvailable" --output tsv)
+
+if [ "$STORAGE_ACCOUNT_EXISTS" == "true" ]; then
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME does not exist. Creating it now..."
+ az storage account create --resource-group $AZURE_RESOURCE_GROUP --name $AZURE_STORAGE_ACCOUNT_NAME
+else
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME already exists. Skipping creation."
+fi
+
+# Get the storage account key
+ACCOUNT_KEY=$(az storage account keys list --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --query "[0].value" --output tsv)
+
+
+# Check if the Bank A container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BANK_A_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BANK_A_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BANK_A_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the Bank B container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BANK_B_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BANK_B_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BANK_B_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the Bureau container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_BUREAU_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_BUREAU_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_BUREAU_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the Fintech container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_FINTECH_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_FINTECH_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_FINTECH_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the OUTPUT container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_OUTPUT_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_OUTPUT_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_OUTPUT_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/2-create-akv.sh b/scenarios/credit-risk/deployment/azure/2-create-akv.sh
new file mode 100755
index 0000000..c20a75e
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/2-create-akv.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
+ # Check if the Key Vault already exists
+ echo "Checking if Key Vault $AZURE_AKV_RESOURCE_NAME exists..."
+ NAME_AVAILABLE=$(az rest --method post \
+ --uri "https://management.azure.com/subscriptions/$AZURE_SUBSCRIPTION_ID/providers/Microsoft.KeyVault/checkNameAvailability?api-version=2019-09-01" \
+ --headers "Content-Type=application/json" \
+ --body "{\"name\": \"$AZURE_AKV_RESOURCE_NAME\", \"type\": \"Microsoft.KeyVault/vaults\"}" | jq -r '.nameAvailable')
+ if [ "$NAME_AVAILABLE" == true ]; then
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME does not exist. Creating it now..."
+ echo CREATING $AZURE_KEYVAULT_ENDPOINT in resource group $AZURE_RESOURCE_GROUP
+ # Create Azure key vault with RBAC authorization
+ az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
+ # Assign RBAC roles to the resource owner so they can import keys
+ AKV_SCOPE=`az keyvault show --name $AZURE_AKV_RESOURCE_NAME --query id --output tsv`
+ az role assignment create --role "Key Vault Crypto Officer" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ az role assignment create --role "Key Vault Crypto User" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ else
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME already exists. Skipping creation."
+ fi
+else
+ echo "Automated creation of key vaults is supported only for vaults"
+fi
diff --git a/scenarios/credit-risk/deployment/azure/3-import-keys.sh b/scenarios/credit-risk/deployment/azure/3-import-keys.sh
new file mode 100755
index 0000000..71abccf
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/3-import-keys.sh
@@ -0,0 +1,58 @@
+#!/bin/bash
+
+# Function to import a key with a given key ID and key material into AKV
+# The key is bound to a key release policy with host data defined in the environment variable CCE_POLICY_HASH
+function import_key() {
+ export KEYID=$1
+ export KEYFILE=$2
+
+ # For RSA-HSM keys, we need to set a salt and label which will be used in the symmetric key derivation
+ if [ "$AZURE_AKV_KEY_TYPE" = "RSA-HSM" ]; then
+ export AZURE_AKV_KEY_DERIVATION_LABEL=$KEYID
+ fi
+
+ CONFIG=$(jq '.claims[0][0].equals = env.CCE_POLICY_HASH' importkey-config-template.json)
+ CONFIG=$(echo $CONFIG | jq '.key.kid = env.KEYID')
+ CONFIG=$(echo $CONFIG | jq '.key.kty = env.AZURE_AKV_KEY_TYPE')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"')
+ CONFIG=$(echo $CONFIG | jq '.key_derivation.label = env.AZURE_AKV_KEY_DERIVATION_LABEL')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT')
+ CONFIG=$(echo $CONFIG | jq '.key.akv.bearer_token = env.BEARER_TOKEN')
+ echo $CONFIG > /tmp/importkey-config.json
+ echo "Importing $KEYID key with key release policy"
+ jq '.key.akv.bearer_token = "REDACTED"' /tmp/importkey-config.json
+ pushd . && cd $TOOLS_HOME/importkey && go run main.go -c /tmp/importkey-config.json -out && popd
+ mv $TOOLS_HOME/importkey/keyfile.bin $KEYFILE
+}
+
+echo Obtaining contract service parameters...
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"http://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+# Obtain the token based on the AKV resource endpoint subdomain
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+ echo "Importing keys to AKV key vaults can be only of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+fi
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+import_key "BankAFilesystemEncryptionKey" $DATADIR/bank_a_key.bin
+import_key "BankBFilesystemEncryptionKey" $DATADIR/bank_b_key.bin
+import_key "BureauFilesystemEncryptionKey" $DATADIR/bureau_key.bin
+import_key "FintechFilesystemEncryptionKey" $DATADIR/fintech_key.bin
+import_key "OutputFilesystemEncryptionKey" $MODELDIR/output_key.bin
+
+## Cleanup
+rm /tmp/importkey-config.json
+rm /tmp/policy-in.json
diff --git a/scenarios/credit-risk/deployment/azure/4-encrypt-data.sh b/scenarios/credit-risk/deployment/azure/4-encrypt-data.sh
new file mode 100755
index 0000000..82e55f0
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/4-encrypt-data.sh
@@ -0,0 +1,13 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+./generatefs.sh -d $DATADIR/bank_a/preprocessed -k $DATADIR/bank_a_key.bin -i $DATADIR/bank_a.img
+./generatefs.sh -d $DATADIR/bank_b/preprocessed -k $DATADIR/bank_b_key.bin -i $DATADIR/bank_b.img
+./generatefs.sh -d $DATADIR/bureau/preprocessed -k $DATADIR/bureau_key.bin -i $DATADIR/bureau.img
+./generatefs.sh -d $DATADIR/fintech/preprocessed -k $DATADIR/fintech_key.bin -i $DATADIR/fintech.img
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+./generatefs.sh -d $MODELDIR/output -k $MODELDIR/output_key.bin -i $MODELDIR/output.img
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/5-upload-encrypted-data.sh b/scenarios/credit-risk/deployment/azure/5-upload-encrypted-data.sh
new file mode 100755
index 0000000..b12e1d2
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/5-upload-encrypted-data.sh
@@ -0,0 +1,51 @@
+#!/bin/bash
+
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BANK_A_CONTAINER_NAME \
+ --file $DATA_DIR/bank_a.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BANK_B_CONTAINER_NAME \
+ --file $DATA_DIR/bank_b.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_BUREAU_CONTAINER_NAME \
+ --file $DATA_DIR/bureau.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_FINTECH_CONTAINER_NAME \
+ --file $DATA_DIR/fintech.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
+
+az storage blob upload \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODEL_DIR/output.img \
+ --name data.img \
+ --type page \
+ --overwrite \
+ --account-key $ACCOUNT_KEY
diff --git a/scenarios/credit-risk/deployment/azure/6-download-decrypt-model.sh b/scenarios/credit-risk/deployment/azure/6-download-decrypt-model.sh
new file mode 100755
index 0000000..b6d043a
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/6-download-decrypt-model.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob download \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODELDIR/output.img \
+ --name data.img \
+ --account-key $ACCOUNT_KEY
+
+encryptedImage=$MODELDIR/output.img
+keyFilePath=$MODELDIR/output_key.bin
+
+echo Decrypting $encryptedImage with key $keyFilePath
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+mountPoint=`mktemp -d`
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+cp -r $mountPoint/* $MODELDIR/output/
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/aci-parameters-template.json b/scenarios/credit-risk/deployment/azure/aci-parameters-template.json
new file mode 100644
index 0000000..8eb11fc
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/aci-parameters-template.json
@@ -0,0 +1,23 @@
+{
+ "containerRegistry": {
+ "value": ""
+ },
+ "ccePolicy": {
+ "value": ""
+ },
+ "EncfsSideCarArgs": {
+ "value": ""
+ },
+ "ContractService": {
+ "value": ""
+ },
+ "ContractServiceParameters": {
+ "value": ""
+ },
+ "Contracts": {
+ "value": ""
+ },
+ "PipelineConfiguration": {
+ "value": ""
+ }
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/arm-template.json b/scenarios/credit-risk/deployment/azure/arm-template.json
new file mode 100644
index 0000000..1e76214
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/arm-template.json
@@ -0,0 +1,181 @@
+{
+ "$schema": "https://schema.management.azure.com/schemas/2019-04-01/deploymentTemplate.json#",
+ "contentVersion": "1.0.0.0",
+ "parameters": {
+ "name": {
+ "defaultValue": "depa-training-credit-risk",
+ "type": "string",
+ "metadata": {
+ "description": "Name for the container group"
+ }
+ },
+ "location": {
+ "defaultValue": "northeurope",
+ "type": "string",
+ "metadata": {
+ "description": "Location for all resources."
+ }
+ },
+ "port": {
+ "defaultValue": 8080,
+ "type": "int",
+ "metadata": {
+ "description": "Port to open on the container and the public IP address."
+ }
+ },
+ "containerRegistry": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "The container registry login server."
+ }
+ },
+ "restartPolicy": {
+ "defaultValue": "Never",
+ "allowedValues": [
+ "Always",
+ "Never",
+ "OnFailure"
+ ],
+ "type": "string",
+ "metadata": {
+ "description": "The behavior of Azure runtime if container has stopped."
+ }
+ },
+ "ccePolicy": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "cce policy"
+ }
+ },
+ "EncfsSideCarArgs": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Remote file system information for storage sidecar."
+ }
+ },
+ "ContractService": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "URL of contract service"
+ }
+ },
+ "Contracts": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "List of contracts"
+ }
+ },
+ "ContractServiceParameters": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Contract service parameters"
+ }
+ },
+ "PipelineConfiguration": {
+ "defaultValue": "secureString",
+ "type": "string",
+ "metadata": {
+ "description": "Pipeline configuration"
+ }
+ }
+ },
+ "resources": [
+ {
+ "type": "Microsoft.ContainerInstance/containerGroups",
+ "apiVersion": "2023-05-01",
+ "name": "[parameters('name')]",
+ "location": "[parameters('location')]",
+ "properties": {
+ "confidentialComputeProperties": {
+ "ccePolicy": "[parameters('ccePolicy')]"
+ },
+ "containers": [
+ {
+ "name": "depa-training",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training:latest')]",
+ "command": [
+ "/bin/bash",
+ "run.sh"
+ ],
+ "environmentVariables": [],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "resources": {
+ "requests": {
+ "cpu": 3,
+ "memoryInGB": 12
+ }
+ }
+ }
+ },
+ {
+ "name": "encrypted-storage-sidecar",
+ "properties": {
+ "image": "[concat(parameters('containerRegistry'), '/depa-training-encfs:latest')]",
+ "command": [
+ "/encfs.sh"
+ ],
+ "environmentVariables": [
+ {
+ "name": "EncfsSideCarArgs",
+ "value": "[parameters('EncfsSideCarArgs')]"
+ },
+ {
+ "name": "ContractService",
+ "value": "[parameters('ContractService')]"
+ },
+ {
+ "name": "Contracts",
+ "value": "[parameters('Contracts')]"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "[parameters('ContractServiceParameters')]"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": "[parameters('PipelineConfiguration')]"
+ }
+ ],
+ "volumeMounts": [
+ {
+ "name": "remotemounts",
+ "mountPath": "/mnt/remote"
+ }
+ ],
+ "securityContext": {
+ "privileged": true
+ },
+ "resources": {
+ "requests": {
+ "cpu": 0.5,
+ "memoryInGB": 2
+ }
+ }
+ }
+ }
+ ],
+ "sku": "Confidential",
+ "osType": "Linux",
+ "restartPolicy": "[parameters('restartPolicy')]",
+ "volumes": [
+ {
+ "name": "remotemounts",
+ "emptyDir": {}
+ }
+ ]
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/deploy.sh b/scenarios/credit-risk/deployment/azure/deploy.sh
new file mode 100755
index 0000000..e69f399
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/deploy.sh
@@ -0,0 +1,159 @@
+#!/bin/bash
+
+set -e
+
+while getopts ":c:p:" options; do
+ case $options in
+ c)contract=$OPTARG;;
+ p)pipelineConfiguration=$OPTARG;;
+ esac
+done
+
+if [[ -z "${contract}" ]]; then
+ echo "No contract specified"
+ exit 1
+fi
+
+if [[ -z "${pipelineConfiguration}" ]]; then
+ echo "No pipeline configuration specified"
+ exit 1
+fi
+
+if [[ -z "${AZURE_KEYVAULT_ENDPOINT}" ]]; then
+ echo "Environment variable AZURE_KEYVAULT_ENDPOINT not defined"; exit 1
+fi
+
+echo Obtaining contract service parameters...
+
+CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"https://localhost:8000"}
+export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
+
+echo Computing CCE policy...
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
+export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
+export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
+echo "Training container policy hash $CCE_POLICY_HASH"
+
+export CONTRACTS=$contract
+export PIPELINE_CONFIGURATION=`cat $pipelineConfiguration | base64 --wrap=0`
+
+function generate_encrypted_filesystem_information() {
+ end=`date -u -d "60 minutes" '+%Y-%m-%dT%H:%MZ'`
+ BANK_A_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BANK_A_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BANK_A_SAS_TOKEN="$(echo -n $BANK_A_SAS_TOKEN | tr -d \")"
+ export BANK_A_SAS_TOKEN="?$BANK_A_SAS_TOKEN"
+
+ BANK_B_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BANK_B_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BANK_B_SAS_TOKEN=$(echo $BANK_B_SAS_TOKEN | tr -d \")
+ export BANK_B_SAS_TOKEN="?$BANK_B_SAS_TOKEN"
+
+ BUREAU_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_BUREAU_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export BUREAU_SAS_TOKEN=$(echo $BUREAU_SAS_TOKEN | tr -d \")
+ export BUREAU_SAS_TOKEN="?$BUREAU_SAS_TOKEN"
+
+ FINTECH_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_FINTECH_CONTAINER_NAME --permissions r --name data.img --expiry $end --only-show-errors)
+ export FINTECH_SAS_TOKEN=$(echo $FINTECH_SAS_TOKEN | tr -d \")
+ export FINTECH_SAS_TOKEN="?$FINTECH_SAS_TOKEN"
+
+ OUTPUT_SAS_TOKEN=$(az storage blob generate-sas --account-name $AZURE_STORAGE_ACCOUNT_NAME --container-name $AZURE_OUTPUT_CONTAINER_NAME --permissions rw --name data.img --expiry $end --only-show-errors)
+ export OUTPUT_SAS_TOKEN=$(echo $OUTPUT_SAS_TOKEN | tr -d \")
+ export OUTPUT_SAS_TOKEN="?$OUTPUT_SAS_TOKEN"
+
+ # Obtain the token based on the AKV resource endpoint subdomain
+ if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://vault.azure.net | jq -r .accessToken)
+ echo "Importing keys to AKV key vaults can be only of type RSA-HSM"
+ export AZURE_AKV_KEY_TYPE="RSA-HSM"
+ elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
+ export BEARER_TOKEN=$(az account get-access-token --resource https://managedhsm.azure.net | jq -r .accessToken)
+ export AZURE_AKV_KEY_TYPE="oct-HSM"
+ fi
+
+ TMP=$(jq . encrypted-filesystem-config-template.json)
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[0].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BANK_A_CONTAINER_NAME + "/data.img" + env.BANK_A_SAS_TOKEN' | \
+ jq '.azure_filesystems[0].mount_point = "/mnt/remote/bank_a"' | \
+ jq '.azure_filesystems[0].key.kid = "BankAFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[0].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[0].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[0].key_derivation.label = "BankAFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[0].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[1].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BANK_B_CONTAINER_NAME + "/data.img" + env.BANK_B_SAS_TOKEN' | \
+ jq '.azure_filesystems[1].mount_point = "/mnt/remote/bank_b"' | \
+ jq '.azure_filesystems[1].key.kid = "BankBFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[1].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[1].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[1].key_derivation.label = "BankBFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[1].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[2].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_BUREAU_CONTAINER_NAME + "/data.img" + env.BUREAU_SAS_TOKEN' | \
+ jq '.azure_filesystems[2].mount_point = "/mnt/remote/bureau"' | \
+ jq '.azure_filesystems[2].key.kid = "BureauFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[2].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[2].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[2].key_derivation.label = "BureauFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[2].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[3].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_FINTECH_CONTAINER_NAME + "/data.img" + env.FINTECH_SAS_TOKEN' | \
+ jq '.azure_filesystems[3].mount_point = "/mnt/remote/fintech"' | \
+ jq '.azure_filesystems[3].key.kid = "FintechFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[3].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[3].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[3].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[3].key_derivation.label = "FintechFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[3].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ TMP=`echo $TMP | \
+ jq '.azure_filesystems[4].azure_url = "https://" + env.AZURE_STORAGE_ACCOUNT_NAME + ".blob.core.windows.net/" + env.AZURE_OUTPUT_CONTAINER_NAME + "/data.img" + env.OUTPUT_SAS_TOKEN' | \
+ jq '.azure_filesystems[4].mount_point = "/mnt/remote/output"' | \
+ jq '.azure_filesystems[4].key.kid = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[4].key.kty = env.AZURE_AKV_KEY_TYPE' | \
+ jq '.azure_filesystems[4].key.akv.endpoint = env.AZURE_KEYVAULT_ENDPOINT' | \
+ jq '.azure_filesystems[4].key.akv.bearer_token = env.BEARER_TOKEN' | \
+ jq '.azure_filesystems[4].key_derivation.label = "OutputFilesystemEncryptionKey"' | \
+ jq '.azure_filesystems[4].key_derivation.salt = "9b53cddbe5b78a0b912a8f05f341bcd4dd839ea85d26a08efaef13e696d999f4"'`
+
+ ENCRYPTED_FILESYSTEM_INFORMATION=`echo $TMP | base64 --wrap=0`
+}
+
+echo Generating encrypted file system information...
+generate_encrypted_filesystem_information
+echo $ENCRYPTED_FILESYSTEM_INFORMATION > /tmp/encrypted-filesystem-config.json
+export ENCRYPTED_FILESYSTEM_INFORMATION
+
+echo Generating parameters for ACI deployment...
+TMP=$(jq '.containerRegistry.value = env.CONTAINER_REGISTRY' aci-parameters-template.json)
+TMP=`echo $TMP | jq '.ccePolicy.value = env.CCE_POLICY'`
+TMP=`echo $TMP | jq '.EncfsSideCarArgs.value = env.ENCRYPTED_FILESYSTEM_INFORMATION'`
+TMP=`echo $TMP | jq '.ContractService.value = env.CONTRACT_SERVICE_URL'`
+TMP=`echo $TMP | jq '.ContractServiceParameters.value = env.CONTRACT_SERVICE_PARAMETERS'`
+TMP=`echo $TMP | jq '.Contracts.value = env.CONTRACTS'`
+TMP=`echo $TMP | jq '.PipelineConfiguration.value = env.PIPELINE_CONFIGURATION'`
+echo $TMP > /tmp/aci-parameters.json
+
+echo Deploying training clean room...
+
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+az deployment group create \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --template-file arm-template.json \
+ --parameters @/tmp/aci-parameters.json
+
+echo Deployment complete.
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/encrypted-filesystem-config-template.json b/scenarios/credit-risk/deployment/azure/encrypted-filesystem-config-template.json
new file mode 100644
index 0000000..ba58980
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/encrypted-filesystem-config-template.json
@@ -0,0 +1,114 @@
+{
+ "azure_filesystems": [
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": true,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/azure/generatefs.sh b/scenarios/credit-risk/deployment/azure/generatefs.sh
new file mode 100755
index 0000000..df8833e
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/generatefs.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+while getopts ":d:k:i:" options; do
+ case $options in
+ d)dataPath=$OPTARG;;
+ k)keyFilePath=$OPTARG;;
+ i)encryptedImage=$OPTARG;;
+ esac
+done
+
+echo Encrypting $dataPath with key $keyFilePath and generating $encryptedImage
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+if [ -f "$keyFilePath" ]; then
+ echo "[!] Encrypting dataset using $keyFilePath"
+else
+ echo "[!] Generating keyfile..."
+ dd if=/dev/random of="$keyFilePath" count=1 bs=32
+ truncate -s 32 "$keyFilePath"
+fi
+
+echo "[!] Creating encrypted image..."
+
+response=`du -s $dataPath`
+read -ra arr <<< "$response"
+size=`echo "x=l($arr)/l(2); scale=0; 2^((x+0.5)/1)*2" | bc -l;`
+
+# cryptsetup requires 16M or more
+
+if (($((size)) < 65536)); then
+ size="65536"
+fi
+size=$size"K"
+
+echo "Data size: $size"
+
+rm -f "$encryptedImage"
+touch "$encryptedImage"
+truncate --size $size "$encryptedImage"
+
+sudo cryptsetup luksFormat --type luks2 "$encryptedImage" \
+ --key-file "$keyFilePath" -v --batch-mode --sector-size 4096 \
+ --cipher aes-xts-plain64 \
+ --pbkdf pbkdf2 --pbkdf-force-iterations 1000
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+echo "[!] Formatting as ext4..."
+
+sudo mkfs.ext4 "$deviceNamePath"
+
+echo "[!] Mounting..."
+
+mountPoint=`mktemp -d`
+echo "Mounting to $mountPoint"
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+echo "[!] Copying contents to encrypted device..."
+
+# The /* is needed to copy folder contents instead of the folder + contents
+sudo cp -r $dataPath/* "$mountPoint"
+sudo rm -rf "$mountPoint/lost+found"
+ls "$mountPoint"
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
diff --git a/scenarios/credit-risk/deployment/azure/importkey-config-template.json b/scenarios/credit-risk/deployment/azure/importkey-config-template.json
new file mode 100644
index 0000000..42ed7ee
--- /dev/null
+++ b/scenarios/credit-risk/deployment/azure/importkey-config-template.json
@@ -0,0 +1,29 @@
+{
+ "key":{
+ "kid": "",
+ "kty": "oct-HSM",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ },
+ "claims": [
+ [{
+ "claim": "x-ms-sevsnpvm-hostdata",
+ "equals": ""
+ },
+ {
+ "claim": "x-ms-compliance-status",
+ "equals": "azure-compliant-uvm"
+ }]
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/local/docker-compose-preprocess.yml b/scenarios/credit-risk/deployment/local/docker-compose-preprocess.yml
new file mode 100644
index 0000000..523fedf
--- /dev/null
+++ b/scenarios/credit-risk/deployment/local/docker-compose-preprocess.yml
@@ -0,0 +1,37 @@
+services:
+ bank_a:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-bank-a:latest
+ volumes:
+ - $BANK_A_INPUT_PATH:/mnt/input/data
+ - $BANK_A_OUTPUT_PATH:/mnt/output/preprocessed
+ environment:
+ - KAGGLE_USERNAME=${KAGGLE_USERNAME}
+ - KAGGLE_KEY=${KAGGLE_KEY}
+ command: ["python3", "preprocess_bank_a.py"]
+ bank_b:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-bank-b:latest
+ volumes:
+ - $BANK_B_INPUT_PATH:/mnt/input/data
+ - $BANK_B_OUTPUT_PATH:/mnt/output/preprocessed
+ environment:
+ - KAGGLE_USERNAME=${KAGGLE_USERNAME}
+ - KAGGLE_KEY=${KAGGLE_KEY}
+ command: ["python3", "preprocess_bank_b.py"]
+ bureau:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-bureau:latest
+ volumes:
+ - $BUREAU_INPUT_PATH:/mnt/input/data
+ - $BUREAU_OUTPUT_PATH:/mnt/output/preprocessed
+ environment:
+ - KAGGLE_USERNAME=${KAGGLE_USERNAME}
+ - KAGGLE_KEY=${KAGGLE_KEY}
+ command: ["python3", "preprocess_bureau.py"]
+ fintech:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-fintech:latest
+ volumes:
+ - $FINTECH_INPUT_PATH:/mnt/input/data
+ - $FINTECH_OUTPUT_PATH:/mnt/output/preprocessed
+ environment:
+ - KAGGLE_USERNAME=${KAGGLE_USERNAME}
+ - KAGGLE_KEY=${KAGGLE_KEY}
+ command: ["python3", "preprocess_fintech.py"]
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/local/docker-compose-train.yml b/scenarios/credit-risk/deployment/local/docker-compose-train.yml
new file mode 100644
index 0000000..da07f25
--- /dev/null
+++ b/scenarios/credit-risk/deployment/local/docker-compose-train.yml
@@ -0,0 +1,12 @@
+services:
+ train:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-training:latest
+ volumes:
+ - $BANK_A_INPUT_PATH:/mnt/remote/bank_a
+ - $BANK_B_INPUT_PATH:/mnt/remote/bank_b
+ - $BUREAU_INPUT_PATH:/mnt/remote/bureau
+ - $FINTECH_INPUT_PATH:/mnt/remote/fintech
+ - $MODEL_OUTPUT_PATH:/mnt/remote/output
+ - $CONFIGURATION_PATH:/mnt/remote/config
+ command: ["/bin/bash", "run.sh"]
+
\ No newline at end of file
diff --git a/scenarios/credit-risk/deployment/local/preprocess.sh b/scenarios/credit-risk/deployment/local/preprocess.sh
new file mode 100755
index 0000000..f1888cc
--- /dev/null
+++ b/scenarios/credit-risk/deployment/local/preprocess.sh
@@ -0,0 +1,18 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="credit-risk"
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export BANK_A_INPUT_PATH=$DATA_DIR/bank_a
+export BANK_B_INPUT_PATH=$DATA_DIR/bank_b
+export BUREAU_INPUT_PATH=$DATA_DIR/bureau
+export FINTECH_INPUT_PATH=$DATA_DIR/fintech
+export BANK_A_OUTPUT_PATH=$DATA_DIR/bank_a/preprocessed
+export BANK_B_OUTPUT_PATH=$DATA_DIR/bank_b/preprocessed
+export BUREAU_OUTPUT_PATH=$DATA_DIR/bureau/preprocessed
+export FINTECH_OUTPUT_PATH=$DATA_DIR/fintech/preprocessed
+mkdir -p $BANK_A_OUTPUT_PATH
+mkdir -p $BANK_B_OUTPUT_PATH
+mkdir -p $BUREAU_OUTPUT_PATH
+mkdir -p $FINTECH_OUTPUT_PATH
+docker compose -f docker-compose-preprocess.yml up --remove-orphans
diff --git a/scenarios/credit-risk/deployment/local/train.sh b/scenarios/credit-risk/deployment/local/train.sh
new file mode 100755
index 0000000..c5a5f91
--- /dev/null
+++ b/scenarios/credit-risk/deployment/local/train.sh
@@ -0,0 +1,23 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="credit-risk"
+
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+export BANK_A_INPUT_PATH=$DATA_DIR/bank_a/preprocessed
+export BANK_B_INPUT_PATH=$DATA_DIR/bank_b/preprocessed
+export BUREAU_INPUT_PATH=$DATA_DIR/bureau/preprocessed
+export FINTECH_INPUT_PATH=$DATA_DIR/fintech/preprocessed
+
+export MODEL_OUTPUT_PATH=$MODEL_DIR/output
+sudo rm -rf $MODEL_OUTPUT_PATH
+mkdir -p $MODEL_OUTPUT_PATH
+
+export CONFIGURATION_PATH=$REPO_ROOT/scenarios/$SCENARIO/config
+
+# Run consolidate_pipeline.sh to create pipeline_config.json
+$REPO_ROOT/scenarios/$SCENARIO/config/consolidate_pipeline.sh
+
+docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/credit-risk/export-variables.sh b/scenarios/credit-risk/export-variables.sh
new file mode 100755
index 0000000..d79c6a0
--- /dev/null
+++ b/scenarios/credit-risk/export-variables.sh
@@ -0,0 +1,64 @@
+#!/bin/bash
+
+# Azure Naming Rules:
+#
+# Resource Group:
+# - 1–90 characters
+# - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+# - Cannot end with a period (.)
+# - Case-insensitive, unique within subscription
+#
+# Key Vault:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with letter or number
+#
+# Storage Account:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters and numbers only
+#
+# Storage Container:
+# - 3-63 characters
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with a letter or number
+# - No consecutive hyphens
+# - Unique within storage account
+
+# For cloud resource creation:
+declare -x SCENARIO=credit-risk
+declare -x REPO_ROOT="$(git rev-parse --show-toplevel)"
+declare -x CONTAINER_REGISTRY=ispirt.azurecr.io
+declare -x AZURE_LOCATION=centralindia
+declare -x AZURE_SUBSCRIPTION_ID=
+declare -x AZURE_RESOURCE_GROUP=
+declare -x AZURE_KEYVAULT_ENDPOINT=
+declare -x AZURE_STORAGE_ACCOUNT_NAME=
+
+declare -x AZURE_BANK_A_CONTAINER_NAME=bankacontainer
+declare -x AZURE_BANK_B_CONTAINER_NAME=bankbcontainer
+declare -x AZURE_BUREAU_CONTAINER_NAME=bureaucontainer
+declare -x AZURE_FINTECH_CONTAINER_NAME=fintechcontainer
+declare -x AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+
+# For key import:
+declare -x CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+declare -x TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+# Export all variables to make them available to other scripts
+export SCENARIO
+export REPO_ROOT
+export CONTAINER_REGISTRY
+export AZURE_LOCATION
+export AZURE_SUBSCRIPTION_ID
+export AZURE_RESOURCE_GROUP
+export AZURE_KEYVAULT_ENDPOINT
+export AZURE_STORAGE_ACCOUNT_NAME
+export AZURE_BANK_A_CONTAINER_NAME
+export AZURE_BANK_B_CONTAINER_NAME
+export AZURE_BUREAU_CONTAINER_NAME
+export AZURE_FINTECH_CONTAINER_NAME
+export AZURE_OUTPUT_CONTAINER_NAME
+export CONTRACT_SERVICE_URL
+export TOOLS_HOME
\ No newline at end of file
diff --git a/scenarios/credit-risk/policy/policy-in-template.json b/scenarios/credit-risk/policy/policy-in-template.json
new file mode 100644
index 0000000..c093bdd
--- /dev/null
+++ b/scenarios/credit-risk/policy/policy-in-template.json
@@ -0,0 +1,63 @@
+{
+ "version": "1.0",
+ "containers": [
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training:latest",
+ "command": [
+ "/bin/bash",
+ "run.sh"
+ ],
+ "environmentVariables": [],
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ },
+ {
+ "containerImage": "$CONTAINER_REGISTRY/depa-training-encfs:latest",
+ "environmentVariables": [
+ {
+ "name" : "EncfsSideCarArgs",
+ "value" : ".+",
+ "strategy" : "re2"
+ },
+ {
+ "name": "ContractService",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "ContractServiceParameters",
+ "value": "$CONTRACT_SERVICE_PARAMETERS",
+ "strategy": "string"
+ },
+ {
+ "name": "Contracts",
+ "value": ".+",
+ "strategy": "re2"
+ },
+ {
+ "name": "PipelineConfiguration",
+ "value": ".+",
+ "strategy": "re2"
+ }
+ ],
+ "command": [
+ "/encfs.sh"
+ ],
+ "securityContext": {
+ "privileged": true
+ },
+ "mounts": [
+ {
+ "mountType": "emptyDir",
+ "mountPath": "/mnt/remote",
+ "readonly": false
+ }
+ ]
+ }
+ ]
+}
diff --git a/scenarios/credit-risk/src/preprocess_bank_a.py b/scenarios/credit-risk/src/preprocess_bank_a.py
new file mode 100644
index 0000000..477a889
--- /dev/null
+++ b/scenarios/credit-risk/src/preprocess_bank_a.py
@@ -0,0 +1,125 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import zipfile
+from kaggle import KaggleApi
+import pandas as pd
+import numpy as np
+
+# # Get the KAGGLE_USERNAME and KAGGLE_KEY from your kaggle.json file downloaded from Kaggle.com > Account > Settings > API > Create new token
+# os.environ["KAGGLE_USERNAME"] = "your_username"
+# os.environ["KAGGLE_KEY"] = "your_key"
+
+DATA_DIR = '/mnt/input/data'
+OUT_DIR = '/mnt/output/preprocessed'
+
+api = KaggleApi()
+api.authenticate()
+
+FILES_TO_GET = [
+'application_train.csv',
+'application_test.csv',
+'previous_application.csv',
+'installments_payments.csv'
+]
+
+COMP = 'home-credit-default-risk'
+
+# helper to download (will overwrite if exists)
+def download_file(fname):
+ target = os.path.join(DATA_DIR, fname)
+ if os.path.exists(target):
+ print(f"{fname} already exists, skipping download")
+ return target
+ api.competition_download_file(COMP, fname, path=DATA_DIR, force=True)
+ # add .zip extension
+ os.rename(os.path.join(DATA_DIR, fname), os.path.join(DATA_DIR, fname + '.zip'))
+ print(f"Downloading {fname} to {DATA_DIR} ...")
+ # API sometimes zips single files; if zipped create extraction logic
+ zipped = target + '.zip'
+ if os.path.exists(zipped):
+ with zipfile.ZipFile(zipped, 'r') as z:
+ z.extractall(DATA_DIR)
+ os.remove(zipped)
+ return target
+
+# download
+for f in FILES_TO_GET:
+ download_file(f)
+
+###########################
+
+
+app_cols = [
+'SK_ID_CURR','TARGET','NAME_CONTRACT_TYPE','CODE_GENDER','FLAG_OWN_CAR','FLAG_OWN_REALTY',
+'CNT_CHILDREN','AMT_INCOME_TOTAL','AMT_CREDIT','AMT_ANNUITY','AMT_GOODS_PRICE',
+'NAME_INCOME_TYPE','NAME_EDUCATION_TYPE','NAME_FAMILY_STATUS','NAME_HOUSING_TYPE',
+'DAYS_BIRTH','DAYS_EMPLOYED','FLAG_MOBIL','REGION_POPULATION_RELATIVE','OBS_30_CNT_SOCIAL_CIRCLE',
+'OBS_60_CNT_SOCIAL_CIRCLE','DAYS_ID_PUBLISH'
+]
+# Get header first
+available_cols = pd.read_csv(os.path.join(DATA_DIR, "application_train.csv"), nrows=0).columns
+
+# Load with column filtering
+app_train = pd.read_csv(
+ os.path.join(DATA_DIR, "application_train.csv"),
+ usecols=[c for c in app_cols if c in available_cols]
+)
+
+app_train = app_train.dropna()
+
+# Feature examples
+app_train['INCOME_PER_PERSON'] = app_train['AMT_INCOME_TOTAL'] / (app_train['CNT_CHILDREN'].replace({0:1}) + 1)
+app_train['CREDIT_TO_INCOME'] = app_train['AMT_CREDIT'] / (app_train['AMT_INCOME_TOTAL'].replace({0:1}))
+
+cat_cols = app_train.select_dtypes(include=['object']).columns.tolist()
+app_train[['SK_ID_CURR','TARGET','INCOME_PER_PERSON','CREDIT_TO_INCOME'] + [c for c in cat_cols]].to_parquet(os.path.join(OUT_DIR, 'bank_a_app_learn.parquet'), index=False)
+
+# previous_application aggregation (per SK_ID_CURR)
+prev = pd.read_csv(os.path.join(DATA_DIR, 'previous_application.csv'))
+# create a couple of useful aggs
+prev_agg = prev.groupby('SK_ID_CURR').agg({
+'AMT_APPLICATION': ['count','mean','sum'],
+'AMT_CREDIT': ['mean','max'],
+'AMT_DOWN_PAYMENT': ['mean'],
+'CNT_PAYMENT': ['mean']
+})
+# flatten
+prev_agg.columns = ['_'.join(col).strip() for col in prev_agg.columns.values]
+prev_agg.reset_index(inplace=True)
+prev_agg.to_parquet(os.path.join(OUT_DIR, 'bank_a_prev_agg.parquet'), index=False)
+
+# installments_payments aggregation
+inst = pd.read_csv(os.path.join(DATA_DIR, 'installments_payments.csv'))
+# create basic payment discipline features
+inst['PAYMENT_DIFF'] = inst['AMT_PAYMENT'] - inst['AMT_INSTALMENT']
+inst['LATE_PAYMENT'] = (inst['DAYS_ENTRY_PAYMENT'] > 0).astype(int)
+inst_agg = inst.groupby('SK_ID_CURR').agg({
+'AMT_PAYMENT': ['sum','mean'],
+'AMT_INSTALMENT': ['sum','mean'],
+'PAYMENT_DIFF': ['mean'],
+'LATE_PAYMENT': ['sum','mean']
+})
+inst_agg.columns = ['_'.join(col).strip() for col in inst_agg.columns.values]
+inst_agg.reset_index(inplace=True)
+inst_agg.to_parquet(os.path.join(OUT_DIR, 'bank_a_inst_agg.parquet'), index=False)
+
+print('bank_a preprocessing finished -> processed/*.parquet')
+
+##########################
+
diff --git a/scenarios/credit-risk/src/preprocess_bank_b.py b/scenarios/credit-risk/src/preprocess_bank_b.py
new file mode 100644
index 0000000..28d8534
--- /dev/null
+++ b/scenarios/credit-risk/src/preprocess_bank_b.py
@@ -0,0 +1,70 @@
+import os
+import zipfile
+from kaggle import KaggleApi
+import pandas as pd
+import numpy as np
+
+# # Get the KAGGLE_USERNAME and KAGGLE_KEY from your kaggle.json file downloaded from Kaggle.com > Account > Settings > API > Create new token
+# os.environ["KAGGLE_USERNAME"] = "your_username"
+# os.environ["KAGGLE_KEY"] = "your_key"
+
+DATA_DIR = '/mnt/input/data'
+OUT_DIR = '/mnt/output/preprocessed'
+
+api = KaggleApi()
+api.authenticate()
+
+FILES_TO_GET = ['credit_card_balance.csv']
+
+COMP = 'home-credit-default-risk'
+
+# helper to download (will overwrite if exists)
+def download_file(fname):
+ target = os.path.join(DATA_DIR, fname)
+ if os.path.exists(target):
+ print(f"{fname} already exists, skipping download")
+ return target
+ api.competition_download_file(COMP, fname, path=DATA_DIR, force=True)
+ # add .zip extension
+ os.rename(os.path.join(DATA_DIR, fname), os.path.join(DATA_DIR, fname + '.zip'))
+ print(f"Downloading {fname} to {DATA_DIR} ...")
+ # API sometimes zips single files; if zipped create extraction logic
+ zipped = target + '.zip'
+ if os.path.exists(zipped):
+ with zipfile.ZipFile(zipped, 'r') as z:
+ z.extractall(DATA_DIR)
+ os.remove(zipped)
+ return target
+
+# download
+for f in FILES_TO_GET:
+ download_file(f)
+
+##########################
+
+cc = pd.read_csv(os.path.join(DATA_DIR,'credit_card_balance.csv'))
+
+# important derived features: utilization ratio, mean balance, max days past due
+cols_for_agg = {}
+if 'AMT_BALANCE' in cc.columns:
+ cols_for_agg['AMT_BALANCE'] = ['mean','max']
+if 'AMT_CREDIT_LIMIT_ACTUAL' in cc.columns:
+ cols_for_agg['AMT_CREDIT_LIMIT_ACTUAL'] = ['mean','max']
+if 'AMT_DRAWINGS_CURRENT' in cc.columns:
+ cols_for_agg['AMT_DRAWINGS_CURRENT'] = ['mean','max']
+if 'SK_DPD' in cc.columns:
+ cols_for_agg['SK_DPD'] = ['max','mean']
+if 'SK_DPD_DEF' in cc.columns:
+ cols_for_agg['SK_DPD_DEF'] = ['max','mean']
+
+if not cols_for_agg:
+ numcols = cc.select_dtypes(include=[np.number]).columns.tolist()
+ for c in numcols[:5]:
+ cols_for_agg[c] = ['mean','max']
+
+cc_agg = cc.groupby('SK_ID_CURR').agg(cols_for_agg)
+cc_agg.columns = ['_'.join(col).strip() for col in cc_agg.columns.values]
+cc_agg.reset_index(inplace=True)
+cc_agg.to_parquet(os.path.join(OUT_DIR,'bank_b_cc_agg.parquet'), index=False)
+
+print('bank_b preprocessing finished -> processed/bank_b_cc_agg.parquet')
\ No newline at end of file
diff --git a/scenarios/credit-risk/src/preprocess_bureau.py b/scenarios/credit-risk/src/preprocess_bureau.py
new file mode 100644
index 0000000..d5b2dec
--- /dev/null
+++ b/scenarios/credit-risk/src/preprocess_bureau.py
@@ -0,0 +1,80 @@
+import os
+import zipfile
+from kaggle import KaggleApi
+import pandas as pd
+import numpy as np
+
+# # Get the KAGGLE_USERNAME and KAGGLE_KEY from your kaggle.json file downloaded from Kaggle.com > Account > Settings > API > Create new token
+# os.environ["KAGGLE_USERNAME"] = "your_username"
+# os.environ["KAGGLE_KEY"] = "your_key"
+
+DATA_DIR = '/mnt/input/data'
+OUT_DIR = '/mnt/output/preprocessed'
+
+api = KaggleApi()
+api.authenticate()
+
+FILES_TO_GET = ['bureau.csv','bureau_balance.csv']
+
+COMP = 'home-credit-default-risk'
+
+# helper to download (will overwrite if exists)
+def download_file(fname):
+ target = os.path.join(DATA_DIR, fname)
+ if os.path.exists(target):
+ print(f"{fname} already exists, skipping download")
+ return target
+ api.competition_download_file(COMP, fname, path=DATA_DIR, force=True)
+ # add .zip extension
+ os.rename(os.path.join(DATA_DIR, fname), os.path.join(DATA_DIR, fname + '.zip'))
+ print(f"Downloading {fname} to {DATA_DIR} ...")
+ # API sometimes zips single files; if zipped create extraction logic
+ zipped = target + '.zip'
+ if os.path.exists(zipped):
+ with zipfile.ZipFile(zipped, 'r') as z:
+ z.extractall(DATA_DIR)
+ os.remove(zipped)
+ return target
+
+# download
+for f in FILES_TO_GET:
+ download_file(f)
+
+##########################
+
+bureau = pd.read_csv(os.path.join(DATA_DIR,'bureau.csv'))
+balance = pd.read_csv(os.path.join(DATA_DIR,'bureau_balance.csv'))
+
+bureau = bureau.dropna()
+balance = balance.dropna()
+
+# aggregate bureau_balance by SK_ID_BUREAU first: e.g., proportion of months past-due
+if 'STATUS' in balance.columns:
+ bal_agg = balance.groupby('SK_ID_BUREAU').agg({'MONTHS_BALANCE': ['count'], 'STATUS': lambda x: (x.astype(str).str.contains('2|3|4|5')).mean()})
+ bal_agg.columns = ['MONTHS_COUNT','STATUS_PAST_DUE_RATE']
+ bal_agg.reset_index(inplace=True)
+ bureau = bureau.merge(bal_agg, how='left', left_on='SK_ID_BUREAU', right_on='SK_ID_BUREAU')
+
+# now aggregate bureau by SK_ID_CURR
+agg_map = {}
+# defensive column existence checks
+for col in ['AMT_CREDIT_SUM','AMT_CREDIT_SUM_DEBT','AMT_CREDIT_SUM_OVERDUE','DAYS_CREDIT','DAYS_ENDDATE_FACT']:
+ if col in bureau.columns:
+ agg_map[col] = ['mean','max']
+
+# also count number of bureau records
+if 'SK_ID_BUREAU' in bureau.columns:
+ bureau['BUREAU_COUNT'] = 1
+ agg_map['BUREAU_COUNT'] = ['sum']
+
+if not agg_map:
+ raise RuntimeError('No expected columns found in bureau.csv; inspect file')
+
+bureau_agg = bureau.groupby('SK_ID_CURR').agg(agg_map)
+# flatten
+bureau_agg.columns = ['_'.join(col).strip() for col in bureau_agg.columns.values]
+bureau_agg.reset_index(inplace=True)
+bureau_agg.to_parquet(os.path.join(OUT_DIR, 'bureau_agg.parquet'), index=False)
+
+print('bureau preprocessing finished -> processed/bureau_agg.parquet')
+
diff --git a/scenarios/credit-risk/src/preprocess_fintech.py b/scenarios/credit-risk/src/preprocess_fintech.py
new file mode 100644
index 0000000..cf81b4c
--- /dev/null
+++ b/scenarios/credit-risk/src/preprocess_fintech.py
@@ -0,0 +1,66 @@
+import os
+import zipfile
+from kaggle import KaggleApi
+import pandas as pd
+import numpy as np
+
+# # Get the KAGGLE_USERNAME and KAGGLE_KEY from your kaggle.json file downloaded from Kaggle.com > Account > Settings > API > Create new token
+# os.environ["KAGGLE_USERNAME"] = "your_username"
+# os.environ["KAGGLE_KEY"] = "your_key"
+
+DATA_DIR = '/mnt/input/data'
+OUT_DIR = '/mnt/output/preprocessed'
+
+api = KaggleApi()
+api.authenticate()
+
+FILES_TO_GET = ['POS_CASH_balance.csv']
+
+COMP = 'home-credit-default-risk'
+
+# helper to download a competition file (skips the download if the file already exists locally)
+def download_file(fname):
+ target = os.path.join(DATA_DIR, fname)
+ if os.path.exists(target):
+ print(f"{fname} already exists, skipping download")
+ return target
+ api.competition_download_file(COMP, fname, path=DATA_DIR, force=True)
+ # add .zip extension
+ os.rename(os.path.join(DATA_DIR, fname), os.path.join(DATA_DIR, fname + '.zip'))
+ print(f"Downloading {fname} to {DATA_DIR} ...")
+ # the Kaggle API sometimes zips single files; if a zip exists, extract it and remove the archive
+ zipped = target + '.zip'
+ if os.path.exists(zipped):
+ with zipfile.ZipFile(zipped, 'r') as z:
+ z.extractall(DATA_DIR)
+ os.remove(zipped)
+ return target
+
+# download
+for f in FILES_TO_GET:
+ download_file(f)
+
+##########################
+
+pos = pd.read_csv(os.path.join(DATA_DIR,'POS_CASH_balance.csv'))
+
+agg_map = {}
+if 'SK_DPD' in pos.columns:
+ agg_map['SK_DPD'] = ['max','mean']
+if 'SK_DPD_DEF' in pos.columns:
+ agg_map['SK_DPD_DEF'] = ['max','mean']
+if 'MONTHS_BALANCE' in pos.columns:
+ agg_map['MONTHS_BALANCE'] = ['count']
+
+if not agg_map:
+ # fallback generic numeric aggs
+ numcols = pos.select_dtypes(include=[np.number]).columns.tolist()
+ for c in numcols[:5]:
+ agg_map[c] = ['mean','max']
+
+pos_agg = pos.groupby('SK_ID_CURR').agg(agg_map)
+pos_agg.columns = ['_'.join(col).strip() for col in pos_agg.columns.values]
+pos_agg.reset_index(inplace=True)
+pos_agg.to_parquet(os.path.join(OUT_DIR,'pos_fintech_agg.parquet'), index=False)
+
+print('pos_fintech preprocessing finished -> processed/pos_fintech_agg.parquet')
\ No newline at end of file
diff --git a/scenarios/mnist/.gitignore b/scenarios/mnist/.gitignore
index c6ca4c5..0f7a6d2 100644
--- a/scenarios/mnist/.gitignore
+++ b/scenarios/mnist/.gitignore
@@ -1,4 +1,15 @@
-data/cifar*
-*.pth
+**/preprocessed/*
*.bin
-*.img
\ No newline at end of file
+*.img
+*.pth
+*.pt
+*.onnx
+*.npy
+*.gz
+
+# Ignore modeller output folder (relative to repo root)
+modeller/output/
+
+data/
+
+**/__pycache__/
\ No newline at end of file
diff --git a/scenarios/mnist/README.md b/scenarios/mnist/README.md
index 9b3bb60..b604360 100644
--- a/scenarios/mnist/README.md
+++ b/scenarios/mnist/README.md
@@ -1,86 +1,127 @@
-# Convolution Neural Network using MNIST
+# MNIST Handwritten Digits Classification
-This scenario involves training a CNN using the MNIST dataset. It involves one training data provider (TDP), and a TDC who wishes the train a model.
+## Scenario Type
-The end-to-end training pipeline consists of the following phases.
+| Scenario name | Scenario type | Task type | Privacy | No. of TDPs* | Data type (format) | Model type (format) | Join type (No. of datasets) |
+|--------------|---------------|-----------------|--------------|-----------|------------|------------|------------|
+| MNIST | Training - Deep Learning | Multi-class Image Classification | NA | 1 | Non-PII image data (HDF5) | CNN (ONNX) | NA (1)|
-1. Data pre-processing and de-identification
+---
+
+## Scenario Description
+
+This scenario involves training a CNN for image classification using the MNIST handwritten digits dataset. It involves one Training Data Provider (TDP), and a Training Data Consumer (TDC) who wishes to train a model on the dataset. The MNIST dataset is a collection of 60,000 28x28 grayscale images of handwritten digits, with 10 classes (0-9), with 6,000 images per class.
+
+The end-to-end training pipeline consists of the following phases:
+
+1. Data pre-processing
2. Data packaging, encryption and upload
-3. Model packaging, encryption and upload
+3. Model packaging, encryption and upload
4. Encryption key import with key release policies
5. Deployment and execution of CCR
-6. Model decryption
+6. Trained model decryption
## Build container images
-Build container images required for this sample as follows.
+Build container images required for this sample as follows:
```bash
-cd scenarios/mnist
+export SCENARIO=mnist
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+cd $REPO_ROOT/scenarios/$SCENARIO
./ci/build.sh
-./ci/push-containers.sh
```
-These scripts build the following containers and push them to the container registry set in $CONTAINER_REGISTRY.
+This script builds the following container images:
-- ```depa-mnist-preprocess```: Container for pre-processing MNIST dataset.
-- ```depa-mnist-save-model```: Container that saves the model to be trained in ONNX format.
+- ```preprocess-mnist```: Container for pre-processing MNIST dataset.
+- ```mnist-model-save```: Container that saves the model to be trained in ONNX format.
-## Data pre-processing and de-identification
+Alternatively, you can pull and use pre-built container images from the iSPIRT container registry by setting the following environment variable. Docker Hub has started throttling, which may affect the upload/download time, especially when images are larger. It is therefore advisable to use other container registries. We are using Azure Container Registry (ACR) as shown below:
-The folders ```scenarios/mnist/data``` contains scripts for downloading and pre-processing the MNIST dataset. Acting as a TDP for this dataset, run the following script.
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+cd $REPO_ROOT/scenarios/$SCENARIO
+./ci/pull-containers.sh
+```
+
+## Data pre-processing
+
+The folder ```scenarios/mnist/src``` contains scripts for downloading and pre-processing the MNIST dataset. Acting as a Training Data Provider (TDP), prepare your datasets:
```bash
-cd scenarios/mnist/deployment/docker
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/local
./preprocess.sh
```
+The datasets are saved to the [data](./data/) directory.
+
## Prepare model for training
-Next, acting as a TDC, save a sample model using the following script.
+Next, acting as a Training Data Consumer (TDC), define and save your base model for training using the following script. This calls [save_base_model.py](./src/save_base_model.py), which saves the model as an ONNX file to the [models](./modeller/models) directory:
```bash
./save-model.sh
```
-This script will save the model as ```scenarios/mnist/data/model/model.onnx.```
-
## Deploy locally
-Assuming you have cleartext access to the pre-processed dataset, you can train a CNN as follows.
+Assuming you have cleartext access to all the datasets, you can train the model _locally_ as follows:
```bash
./train.sh
```
-The script trains a model using a pipeline configuration defined in [pipeline_config.json](./config/pipeline_config.json). If all goes well, you should see output similar to the following output, and the trained model will be saved under the folder `/tmp/output`.
-
-```
-docker-train-1 | /usr/local/lib/python3.9/dist-packages/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension: 'libc10_cuda.so: cannot open shared object file: No such file or directory'If you don't plan on using image functionality from `torchvision.io`, you can ignore this warning. Otherwise, there might be something wrong with your environment. Did you have `libjpeg` or `libpng` installed before building `torchvision` from source?
-docker-train-1 | warn(
-docker-train-1 | /usr/local/lib/python3.9/dist-packages/onnx2pytorch/convert/layer.py:30: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at ../torch/csrc/utils/tensor_numpy.cpp:206.)
-docker-train-1 | layer.weight.data = torch.from_numpy(numpy_helper.to_array(weight))
-docker-train-1 | /usr/local/lib/python3.9/dist-packages/onnx2pytorch/convert/model.py:147: UserWarning: Using experimental implementation that allows 'batch_size > 1'.Batchnorm layers could potentially produce false outputs.
-docker-train-1 | warnings.warn(
-docker-train-1 | [1, 2000] loss: 2.242
-docker-train-1 | [1, 4000] loss: 1.972
-docker-train-1 | [1, 6000] loss: 1.799
-docker-train-1 | [1, 8000] loss: 1.695
-docker-train-1 | [1, 10000] loss: 1.642
-docker-train-1 | [1, 12000] loss: 1.581
-docker-train-1 | [1, 14000] loss: 1.545
-docker-train-1 | [1, 16000] loss: 1.502
-docker-train-1 | [1, 18000] loss: 1.520
-docker-train-1 | [1, 20000] loss: 1.471
-docker-train-1 | [1, 22000] loss: 1.438
-docker-train-1 | [1, 24000] loss: 1.435
-docker-train-1 | [2, 2000] loss: 1.402
-docker-train-1 | [2, 4000] loss: 1.358
-docker-train-1 | [2, 6000] loss: 1.379
-docker-train-1 | [2, 8000] loss: 1.355
-...
-```
-
-## Deploy to Azure
+
+The script joins the datasets and trains the model using a pipeline configuration. To modify the various components of the training pipeline, you can edit the training config files in the [config](./config/) directory. The training config files are used to create the pipeline configuration ([pipeline_config.json](./config/pipeline_config.json)) created by consolidating all the TDC's training config files, namely the [model config](./config/model_config.json), [dataset config](./config/dataset_config.json), [loss function config](./config/loss_config.json), [training config](./config/train_config_template.json), [evaluation config](./config/eval_config.json), and if multiple datasets are used, the [data join config](./config/join_config.json). These enable the TDC to design highly customized training pipelines without requiring review and approval of new custom code for each use case—reducing risks from potentially malicious or non-compliant code. The consolidated pipeline configuration is then attested against the signed contract using the TDP’s policy-as-code. If approved, it is executed in the CCR to train the model, which we will deploy in the next section.
+
+```mermaid
+flowchart TD
+
+ subgraph Config Files
+ C1[model_config.json]
+ C2[dataset_config.json]
+ C3[loss_config.json]
+ C4[train_config_template.json]
+ C5[eval_config.json]
+ C6[join_config.json]
+ end
+
+ B[Consolidated into
pipeline_config.json]
+
+ C1 --> B
+ C2 --> B
+ C3 --> B
+ C4 --> B
+ C5 --> B
+ C6 --> B
+
+ B --> D[Attested against contract
using policy-as-code]
+ D --> E{Approved?}
+ E -- Yes --> F[CCR training begins]
+ E -- No --> H[Rejected: fix config]
+```
+
+If all goes well, you should see output similar to the following output, and the trained model and evaluation metrics will be saved under the folder [output](./modeller/output).
+
+```
+train-1 | Training samples: 43636
+train-1 | Validation samples: 10909
+train-1 | Test samples: 5455
+train-1 | Dataset constructed from config
+train-1 | Model loaded from ONNX file
+train-1 | Optimizer Adam loaded from config
+train-1 | Scheduler CyclicLR loaded from config
+train-1 | Custom loss function loaded from config
+train-1 | Epoch 1/1 completed | Training Loss: 0.1586
+train-1 | Epoch 1/1 completed | Validation Loss: 0.0860
+train-1 | Saving trained model to /mnt/remote/output/trained_model.onnx
+train-1 | Evaluation Metrics: {'test_loss': 0.08991911436687393, 'accuracy': 0.9523373052245646, 'f1_score': 0.9522986646537908}
+train-1 | CCR Training complete!
+train-1 |
+train-1 exited with code 0
+```
+
+## Deploy on CCR
In a more realistic scenario, this datasets will not be available in the clear to the TDC, and the TDC will be required to use a CCR for training. The following steps describe the process of sharing an encrypted dataset with TDCs and setting up a CCR in Azure for training. Please stay tuned for CCR on other cloud platforms.
@@ -88,7 +129,7 @@ To deploy in Azure, you will need the following.
- Docker Hub account to store container images. Alternatively, you can use pre-built images from the ```ispirt``` container registry.
- [Azure Key Vault](https://azure.microsoft.com/en-us/products/key-vault/) to store encryption keys and implement secure key release to CCR. You can either you Azure Key Vault Premium (lower cost), or [Azure Key Vault managed HSM](https://learn.microsoft.com/en-us/azure/key-vault/managed-hsm/overview) for enhanced security. Please see instructions below on how to create and setup your AKV instance.
-- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances.
+- Valid Azure subscription with sufficient access to create key vault, storage accounts, storage containers, and Azure Container Instances (ACI).
If you are using your own development environment instead of a dev container or codespaces, you will to install the following dependencies.
@@ -102,120 +143,225 @@ We will be creating the following resources as part of the deployment.
- Azure Key Vault
- Azure Storage account
- Storage containers to host encrypted datasets
-- Azure Container Instances to deploy the CCR and train the model
+- Azure Container Instances (ACI) to deploy the CCR and train the model
-### Push Container Images
+### 1. Push Container Images
-If you wish to use your own container images, login to docker hub and push containers to your container registry.
+Pre-built container images are available in iSPIRT's container registry, which can be pulled by setting the following environment variable.
-> **Note:** Replace `` the name of your docker hub registry name.
+```bash
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+```
+
+If you wish to use your own container images, login to docker hub (or your container registry of choice) and then build and push the container images to it, so that they can be pulled by the CCR. This is a one-time operation, and you can skip this step if you have already pushed the images to your container registry.
```bash
-export CONTAINER_REGISTRY=
-docker login
+export CONTAINER_REGISTRY=
+docker login -u <username> -p <password> ${CONTAINER_REGISTRY}
+cd $REPO_ROOT
./ci/push-containers.sh
-cd scenarios/mnist
+cd $REPO_ROOT/scenarios/$SCENARIO
./ci/push-containers.sh
```
-### Create Resources
+> **Note:** Replace `<registry>`, `<username>` and `<password>` with your container registry name, Docker Hub username and password respectively. Preferably use registry services other than Docker Hub, as throttling restrictions may cause delays or image push/pull failures.
-Acting as the TDP, we will create a resource group, a key vault instance and storage containers to host the encrypted MNIST training dataset and encryption keys. In a real deployments, TDPs and TDCs will use their own key vault instance. However, for this sample, we will use one key vault instance to store keys for all datasets and models.
+### 2. Create Resources
-> **Note:** At this point, automated creation of AKV managed HSMs is not supported.
-
-> **Note:** Replace `` and `` with names of your choice. Storage account names must not container any special characters. Key vault endpoints are of the form `.vault.azure.net` (for Azure Key Vault Premium) and `.managedhsm.azure.net` for AKV managed HSM, **with no leading https**. This endpoint must be the same endpoint you used while creating the contract.
+First, set up the necessary environment variables for your deployment.
```bash
az login
+export SCENARIO=mnist
+export CONTAINER_REGISTRY=ispirt.azurecr.io
+export AZURE_LOCATION=northeurope
+export AZURE_SUBSCRIPTION_ID=
export AZURE_RESOURCE_GROUP=
-export AZURE_KEYVAULT_ENDPOINT=
-export AZURE_STORAGE_ACCOUNT_NAME=
-export AZURE_MNIST_CONTAINER_NAME=mnistdatacontainer
-export AZURE_MODEL_CONTAINER_NAME=mnistmodelcontainer
-export AZURE_OUTPUT_CONTAINER_NAME=mnistoutputcontainer
-
-cd scenarios/mnist/data
-./1-create-storage-containers.sh
-./2-create-akv.sh
+export AZURE_KEYVAULT_ENDPOINT=.vault.azure.net
+export AZURE_STORAGE_ACCOUNT_NAME=
+
+export AZURE_MNIST_CONTAINER_NAME=mnistcontainer
+export AZURE_MODEL_CONTAINER_NAME=modelcontainer
+export AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
```
-### Sign and Register Contract
+Alternatively, you can edit the values in the [export-variables.sh](./export-variables.sh) script and run it to set the environment variables.
-Next, follow instructions [here](./../../external/contract-ledger/README.md) to sign and register a contract with the contract service. You can either deploy your own contract service or use a test contract service hosted at ```https://contract-service.westeurope.cloudapp.azure.com:8000```. The registered contract must contain references to the datasets with matching names, keyIDs and Azure Key Vault endpoints used in this sample. A sample contract template for this scenario is provided [here](./contract/contract.json). After updating, signing and registering the contract, retain the contract service URL and sequence number of the contract for the rest of this sample.
+```bash
+./export-variables.sh
+source export-variables.sh
+```
-### Import encryption keys
+Azure Naming Rules:
+- Resource Group:
+ - 1–90 characters
+ - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+ - Cannot end with a period (.)
+ - Case-insensitive, unique within subscription
+- Key Vault:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with letter or number
+- Storage Account:
+ - 3-24 characters
+ - Globally unique name
+ - Lowercase letters and numbers only
+- Storage Container:
+ - 3-63 characters
+ - Lowercase letters, numbers, hyphens only
+ - Must start and end with a letter or number
+ - No consecutive hyphens
+ - Unique within storage account
+
+---
+
+**Important:**
+
+The values for the environment variables listed below must precisely match the namesake environment variables used during contract signing (next step). Any mismatch will lead to execution failure.
+
+- `SCENARIO`
+- `AZURE_KEYVAULT_ENDPOINT`
+- `CONTRACT_SERVICE_URL`
+- `AZURE_STORAGE_ACCOUNT_NAME`
+- `AZURE_MNIST_CONTAINER_NAME`
+
+---
+With the environment variables set, we are ready to create the resources -- Azure Key Vault and Azure Storage containers.
-Next, use the following script to generate and import encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys.
+```bash
+cd $REPO_ROOT/scenarios/$SCENARIO/deployment/azure
+./1-create-storage-containers.sh
+./2-create-akv.sh
+```
+---
+
+### 3\. Contract Signing
+
+Navigate to the [contract-ledger](https://github.com/kapilvgit/contract-ledger/blob/main/README.md) repository and follow the instructions for contract signing.
-> **Note:** Replace `` with the path to and including the `depa-training` folder where the repository was cloned.
+Once the contract is signed, export the contract sequence number as an environment variable in the same terminal where you set the environment variables for the deployment.
```bash
-export CONTRACT_SERVICE_URL=
-export TOOLS_HOME=/external/confidential-sidecar-containers/tools
-./3-import-keys.sh
+export CONTRACT_SEQ_NO=
```
-The generated keys are available as files with the extension `.bin`.
+---
+
+### 4\. Data Encryption and Upload
-### Encrypt Dataset and Model
+Using their respective keys, the TDPs and TDC encrypt their datasets and model (respectively) and upload them to the Storage containers created in the previous step.
-Next, encrypt the dataset and models using keys generated in the previous step.
+Navigate to the [Azure deployment](./deployment/azure/) directory and execute the scripts for key import, data encryption and upload to Azure Blob Storage, in preparation of the CCR deployment.
+
+The import-keys script generates and imports encryption keys into Azure Key Vault with a policy based on [policy-in-template.json](./policy/policy-in-template.json). The policy requires that the CCRs run specific containers with a specific configuration which includes the public identity of the contract service. Only CCRs that satisfy this policy will be granted access to the encryption keys. The generated keys are available as files with the extension `.bin`.
```bash
-cd scenarios/mnist/data
-./4-encrypt-data.sh
+export CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+export TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+./3-import-keys.sh
```
-This step will generate three encrypted file system images (with extension `.img`), one for the dataset, one encrypted file system image containing the model, and one image where the trained model will be stored.
+The data and model are then packaged as encrypted filesystems by the TDPs and TDC using their respective keys, which are saved as `.img` files.
-### Upload Datasets
+```bash
+./4-encrypt-data.sh
+```
-Now upload encrypted datasets to Azure storage containers.
+The encrypted data and model are then uploaded to the Storage containers created in the previous step. The `.img` files are uploaded to the Storage containers as blobs.
```bash
./5-upload-encrypted-data.sh
```
-### Deploy CCR in Azure
+---
-Acting as a TDC, use the following script to deploy the CCR using Confidential Containers on Azure Container Instances.
+### 5\. CCR Deployment
+
+With the resources ready, we are ready to deploy the Confidential Clean Room (CCR) for executing the privacy-preserving model training.
+
+```bash
+export CONTRACT_SEQ_NO=
+./deploy.sh -c $CONTRACT_SEQ_NO -p ../../config/pipeline_config.json
+```
-> **Note:** Replace `` with the sequence number of the contract registered with the contract service.
+Set the `$CONTRACT_SEQ_NO` variable to the numeric suffix of the contract sequence number (which has the format 2.XX). For example, if the sequence number is 2.15, export it as:
```bash
-cd scenarios/mnist/deployment/aci
-./deploy.sh -c -m ../../config/model_config.json -q ../../config/query_config.json
+export CONTRACT_SEQ_NO=15
```
This script will deploy the container images from your container registry, including the encrypted filesystem sidecar. The sidecar will generate an SEV-SNP attestation report, generate an attestation token using the Microsoft Azure Attestation (MAA) service, retrieve dataset, model and output encryption keys from the TDP and TDC's Azure Key Vault, train the model, and save the resulting model into TDC's output filesystem image, which the TDC can later decrypt.
-Once the deployment is complete, you can obtain logs from the CCR using the following commands. Note there may be some delay in getting the logs are deployment is complete.
+
+
+**Note:** The completion of this script's execution simply creates a CCR instance, and doesn't indicate whether training has completed or not. The training process might still be ongoing. Poll the container logs (see below) to track progress until training is complete.
+
+### 6\. Monitor Container Logs
+
+Use the following commands to monitor the logs of the deployed containers. You might have to repeatedly poll this command to monitor the training progress:
```bash
-# Obtain logs from the training container
-az container logs --name depa-training-mnist --resource-group $AZURE_RESOURCE_GROUP --container-name depa-training
+az container logs \
+ --name "depa-training-$SCENARIO" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --container-name depa-training
+```
-# Obtain logs from the encrypted filesystem sidecar
-az container logs --name depa-training-mnist --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
+You will know training has completed when the logs print "CCR Training complete!".
+
+#### Troubleshooting
+
+In case training fails, you might want to monitor the logs of the encrypted storage sidecar container to see if the encryption process completed successfully:
+
+```bash
+az container logs --name depa-training-$SCENARIO --resource-group $AZURE_RESOURCE_GROUP --container-name encrypted-storage-sidecar
```
-### Download and decrypt trained model
+And to further debug, inspect the logs of the encrypted filesystem sidecar container:
-You can download and decrypt the trained model using the following script.
+```bash
+az container exec \
+ --resource-group $AZURE_RESOURCE_GROUP \
+ --name depa-training-$SCENARIO \
+ --container-name encrypted-storage-sidecar \
+ --exec-command "/bin/sh"
+```
+
+Once inside the sidecar container shell, view the logs:
+
+```bash
+cat log.txt
+```
+Or inspect the individual mounted directories in `/mnt/remote/`:
+
+```bash
+cd /mnt/remote && ls
+```
+
+### 7\. Download and Decrypt Model
+
+Once training has completed successfully (the training container logs will state this explicitly), download and decrypt the trained model and other training outputs.
```bash
-cd scenarios/mnist/data
./6-download-decrypt-model.sh
```
-The trained model is available in `output` folder.
+The outputs will be saved to the [output](./modeller/output/) directory.
+
+To check if the trained model is fresh, you can run the following command:
+```bash
+stat $REPO_ROOT/scenarios/$SCENARIO/modeller/output/trained_model.onnx
+```
+
+---
### Clean-up
-You can use the following command to delete the resource group and clean-up all resources used in the demo.
+You can use the following command to delete the resource group and clean-up all resources used in the demo. Alternatively, you can navigate to the Azure portal and delete the resource group created for this demo.
```bash
az group delete --yes --name $AZURE_RESOURCE_GROUP
-```
+```
\ No newline at end of file
diff --git a/scenarios/mnist/ci/Dockerfile.mnist b/scenarios/mnist/ci/Dockerfile.mnist
new file mode 100644
index 0000000..c85c798
--- /dev/null
+++ b/scenarios/mnist/ci/Dockerfile.mnist
@@ -0,0 +1,19 @@
+FROM ubuntu:22.04
+
+ENV DEBIAN_FRONTEND="noninteractive"
+
+RUN apt-get update && apt-get -y upgrade \
+ && apt-get install -y curl \
+ && apt-get install -y python3 python3-dev python3-distutils
+
+## Install pip
+RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
+RUN python3 get-pip.py
+
+## Install dependencies
+RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
+RUN pip3 install h5py
+
+RUN apt-get install -y jq
+
+COPY preprocess_mnist.py preprocess_mnist.py
\ No newline at end of file
diff --git a/scenarios/mnist/ci/Dockerfile.savemodel b/scenarios/mnist/ci/Dockerfile.modelsave
similarity index 59%
rename from scenarios/mnist/ci/Dockerfile.savemodel
rename to scenarios/mnist/ci/Dockerfile.modelsave
index 1b08717..d58eb4a 100644
--- a/scenarios/mnist/ci/Dockerfile.savemodel
+++ b/scenarios/mnist/ci/Dockerfile.modelsave
@@ -1,17 +1,17 @@
-FROM ubuntu:20.04
+FROM ubuntu:22.04
ENV DEBIAN_FRONTEND="noninteractive"
RUN apt-get update && apt-get -y upgrade \
&& apt-get install -y gcc g++ curl \
- && apt-get install -y python3.9 python3.9-dev python3.9-distutils
+ && apt-get install -y python3 python3-dev python3-distutils
## Install pip
RUN curl https://bootstrap.pypa.io/get-pip.py -o get-pip.py
-RUN python3.9 get-pip.py
+RUN python3 get-pip.py
## Install dependencies
RUN pip3 install torch --index-url https://download.pytorch.org/whl/cpu
-RUN pip3 --default-timeout=1000 install onnx onnx2pytorch
+RUN pip3 --default-timeout=1000 install onnx
-COPY save_model.py save_model.py
+COPY save_base_model.py save_base_model.py
diff --git a/scenarios/mnist/ci/build.sh b/scenarios/mnist/ci/build.sh
index a4d6d02..2c9a8ef 100755
--- a/scenarios/mnist/ci/build.sh
+++ b/scenarios/mnist/ci/build.sh
@@ -1,4 +1,4 @@
#!/bin/bash
-docker build -f ci/Dockerfile.preprocess src -t depa-mnist-preprocess:latest
-docker build -f ci/Dockerfile.savemodel src -t depa-mnist-save-model:latest
+docker build -f ci/Dockerfile.mnist src -t preprocess-mnist:latest
+docker build -f ci/Dockerfile.modelsave src -t mnist-model-save:latest
\ No newline at end of file
diff --git a/scenarios/mnist/ci/pull-containers.sh b/scenarios/mnist/ci/pull-containers.sh
new file mode 100755
index 0000000..1c10596
--- /dev/null
+++ b/scenarios/mnist/ci/pull-containers.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+containers=("preprocess-mnist:latest" "mnist-model-save:latest")
+for container in "${containers[@]}"
+do
+ docker pull $CONTAINER_REGISTRY"/"$container
+done
\ No newline at end of file
diff --git a/scenarios/mnist/ci/push-containers.sh b/scenarios/mnist/ci/push-containers.sh
index 9542648..1a7e038 100755
--- a/scenarios/mnist/ci/push-containers.sh
+++ b/scenarios/mnist/ci/push-containers.sh
@@ -1,4 +1,4 @@
-containers=("depa-mnist-save-model:latest" "depa-mnist-preprocess:latest")
+containers=("mnist-model-save:latest" "preprocess-mnist:latest")
for container in "${containers[@]}"
do
docker tag $container $CONTAINER_REGISTRY"/"$container
diff --git a/scenarios/mnist/config/consolidate_pipeline.sh b/scenarios/mnist/config/consolidate_pipeline.sh
new file mode 100755
index 0000000..38af49c
--- /dev/null
+++ b/scenarios/mnist/config/consolidate_pipeline.sh
@@ -0,0 +1,58 @@
+#! /bin/bash
+
+REPO_ROOT="$(git rev-parse --show-toplevel)"
+SCENARIO=mnist
+
+template_path="$REPO_ROOT/scenarios/$SCENARIO/config/templates"
+model_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/model_config.json"
+data_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/dataset_config.json"
+loss_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/loss_config.json"
+train_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/train_config.json"
+eval_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/eval_config.json"
+join_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/join_config.json"
+pipeline_config_path="$REPO_ROOT/scenarios/$SCENARIO/config/pipeline_config.json"
+
+# populate "model_config", "data_config", and "loss_config" keys in train config
+train_config=$(cat $template_path/train_config_template.json)
+
+# Only merge if the file exists
+if [[ -f "$model_config_path" ]]; then
+ model_config=$(cat $model_config_path)
+ train_config=$(echo "$train_config" | jq --argjson model "$model_config" '.config.model_config = $model')
+fi
+
+if [[ -f "$data_config_path" ]]; then
+ data_config=$(cat $data_config_path)
+ train_config=$(echo "$train_config" | jq --argjson data "$data_config" '.config.dataset_config = $data')
+fi
+
+if [[ -f "$loss_config_path" ]]; then
+ loss_config=$(cat $loss_config_path)
+ train_config=$(echo "$train_config" | jq --argjson loss "$loss_config" '.config.loss_config = $loss')
+fi
+
+if [[ -f "$eval_config_path" ]]; then
+ eval_config=$(cat $eval_config_path)
+ # Get all keys from eval_config and copy them to train_config
+ for key in $(echo "$eval_config" | jq -r 'keys[]'); do
+ train_config=$(echo "$train_config" | jq --argjson eval "$eval_config" --arg key "$key" '.config[$key] = $eval[$key]')
+ done
+fi
+
+# save train_config
+echo "$train_config" > $train_config_path
+
+# prepare pipeline config from join_config.json (first dict "config") and train_config.json (second dict "config")
+pipeline_config=$(cat $template_path/pipeline_config_template.json)
+
+# Only merge join_config if the file exists
+if [[ -f "$join_config_path" ]]; then
+ join_config=$(cat $join_config_path)
+ pipeline_config=$(echo "$pipeline_config" | jq --argjson join "$join_config" '.pipeline += [$join]')
+fi
+
+# Always merge train_config as it's required
+pipeline_config=$(echo "$pipeline_config" | jq --argjson train "$train_config" '.pipeline += [$train]')
+
+# save pipeline_config to pipeline_config.json
+echo "$pipeline_config" > $pipeline_config_path
\ No newline at end of file
diff --git a/scenarios/mnist/config/dataset_config.json b/scenarios/mnist/config/dataset_config.json
new file mode 100644
index 0000000..31f8c69
--- /dev/null
+++ b/scenarios/mnist/config/dataset_config.json
@@ -0,0 +1,17 @@
+{
+ "type": "serialized",
+ "format": "hdf5",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+}
\ No newline at end of file
diff --git a/scenarios/mnist/config/eval_config.json b/scenarios/mnist/config/eval_config.json
new file mode 100644
index 0000000..2b4b375
--- /dev/null
+++ b/scenarios/mnist/config/eval_config.json
@@ -0,0 +1,20 @@
+{
+ "task_type": "classification",
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/mnist/config/loss_config.json b/scenarios/mnist/config/loss_config.json
new file mode 100644
index 0000000..68b5c0f
--- /dev/null
+++ b/scenarios/mnist/config/loss_config.json
@@ -0,0 +1,22 @@
+{
+ "expression": "alpha * (1 - pt)**gamma * ce_loss",
+ "components": {
+ "ce_loss": {
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "none"
+ }
+ },
+ "pt": {
+ "class": "torch.exp",
+ "params": {
+ "input": "-$ce_loss"
+ }
+ }
+ },
+ "variables": {
+ "alpha": 1,
+ "gamma": 2
+ },
+ "reduction": "mean"
+}
\ No newline at end of file
diff --git a/scenarios/mnist/config/pipeline_config.json b/scenarios/mnist/config/pipeline_config.json
index 6aaea77..1be9c91 100644
--- a/scenarios/mnist/config/pipeline_config.json
+++ b/scenarios/mnist/config/pipeline_config.json
@@ -1,24 +1,89 @@
{
"pipeline": [
{
- "name": "Train",
+ "name": "Train_DL",
"config": {
- "input_dataset_path": "/mnt/remote/mnist/cifar10-dataset.pth",
- "saved_model_path": "/mnt/remote/model/model.onnx",
- "saved_model_optimizer": "/mnt/remote/model/dpsgd_model_opimizer.pth",
- "trained_model_output_path":"/mnt/remote/output/model.onnx",
- "saved_weights_path": "",
- "batch_size": 4,
- "total_epochs": 2,
- "max_grad_norm": 0.1,
- "sample_size": 60000,
- "target_variable": "icmr_a_icmr_test_result",
- "test_train_split": 0.2,
+ "paths": {
+ "input_dataset_path": "/mnt/remote/mnist/mnist-dataset.h5",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 0.0001,
+ "max_lr": 0.01,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1,
+ "dataset_config": {
+ "type": "serialized",
+ "format": "hdf5",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "expression": "alpha * (1 - pt)**gamma * ce_loss",
+ "components": {
+ "ce_loss": {
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "none"
+ }
+ },
+ "pt": {
+ "class": "torch.exp",
+ "params": {
+ "input": "-$ce_loss"
+ }
+ }
+ },
+ "variables": {
+ "alpha": 1,
+ "gamma": 2
+ },
+ "reduction": "mean"
+ },
"metrics": [
"accuracy",
- "precision",
- "recall"
- ]
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
}
}
]
diff --git a/scenarios/mnist/config/templates/pipeline_config_template.json b/scenarios/mnist/config/templates/pipeline_config_template.json
new file mode 100644
index 0000000..43e9e84
--- /dev/null
+++ b/scenarios/mnist/config/templates/pipeline_config_template.json
@@ -0,0 +1,3 @@
+{
+ "pipeline": []
+}
\ No newline at end of file
diff --git a/scenarios/mnist/config/templates/train_config_template.json b/scenarios/mnist/config/templates/train_config_template.json
new file mode 100644
index 0000000..d2c68c7
--- /dev/null
+++ b/scenarios/mnist/config/templates/train_config_template.json
@@ -0,0 +1,29 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/mnt/remote/mnist/mnist-dataset.h5",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 1e-4
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 1e-4,
+ "max_lr": 1e-2,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1
+ }
+}
\ No newline at end of file
diff --git a/scenarios/mnist/config/train_config.json b/scenarios/mnist/config/train_config.json
new file mode 100644
index 0000000..8f3e0cb
--- /dev/null
+++ b/scenarios/mnist/config/train_config.json
@@ -0,0 +1,86 @@
+{
+ "name": "Train_DL",
+ "config": {
+ "paths": {
+ "input_dataset_path": "/mnt/remote/mnist/mnist-dataset.h5",
+ "base_model_path": "/mnt/remote/model/model.onnx",
+ "trained_model_output_path": "/mnt/remote/output"
+ },
+ "model_type": "onnx",
+ "is_private": false,
+ "device": "cpu",
+ "batch_size": 16,
+ "optimizer": {
+ "name": "Adam",
+ "params": {
+ "lr": 0.0001
+ }
+ },
+ "scheduler": {
+ "name": "CyclicLR",
+ "params": {
+ "base_lr": 0.0001,
+ "max_lr": 0.01,
+ "cycle_momentum": false
+ }
+ },
+ "total_epochs": 1,
+ "dataset_config": {
+ "type": "serialized",
+ "format": "hdf5",
+ "structure": "list_of_tuples",
+ "features_key": "features",
+ "targets_key": "targets",
+ "transforms": {
+ "normalize": true,
+ "augment": false
+ },
+ "splits": {
+ "train": 0.7,
+ "val": 0.2,
+ "test": 0.1,
+ "random_state": 42
+ }
+ },
+ "loss_config": {
+ "expression": "alpha * (1 - pt)**gamma * ce_loss",
+ "components": {
+ "ce_loss": {
+ "class": "nn.CrossEntropyLoss",
+ "params": {
+ "reduction": "none"
+ }
+ },
+ "pt": {
+ "class": "torch.exp",
+ "params": {
+ "input": "-$ce_loss"
+ }
+ }
+ },
+ "variables": {
+ "alpha": 1,
+ "gamma": 2
+ },
+ "reduction": "mean"
+ },
+ "metrics": [
+ "accuracy",
+ {
+ "name": "confusion_matrix",
+ "params": {}
+ },
+ {
+ "name": "f1_score",
+ "params": {
+ "average": "macro"
+ }
+ },
+ {
+ "name": "classification_report",
+ "params": {}
+ }
+ ],
+ "task_type": "classification"
+ }
+}
diff --git a/scenarios/mnist/contract/contract.json b/scenarios/mnist/contract/contract.json
index b8362cf..91b3f32 100644
--- a/scenarios/mnist/contract/contract.json
+++ b/scenarios/mnist/contract/contract.json
@@ -3,21 +3,21 @@
"schemaVersion": "0.1",
"startTime": "2023-03-14T00:00:00.000Z",
"expiryTime": "2024-03-14T00:00:00.000Z",
- "tdc" : "",
- "tdps" : [],
- "ccrp": "did:web:ccrprovider.github.io",
+ "tdc": "",
+ "tdps": [],
+ "ccrp": "did:web:$CCRP_USERNAME.github.io",
"datasets": [
{
- "id" : "19517ba8-bab8-11ed-afa1-0242ac120002",
+ "id": "19517ba8-bab8-11ed-afa1-0242ac120002",
"name": "mnist",
- "url" : "https://ccrcontainer.blob.core.windows.net/mnist/data.img",
+ "url": "https://$AZURE_STORAGE_ACCOUNT_NAME.blob.core.windows.net/$AZURE_MNIST_CONTAINER_NAME/data.img",
"provider": "",
- "key" : {
+ "key": {
"type": "azure",
"properties": {
"kid": "MNISTFilesystemEncryptionKey",
"authority": {
- "endpoint": "sharedneu.neu.attest.azure.net"
+ "endpoint": "sharedneu.neu.attest.azure.net"
},
"endpoint": ""
}
@@ -26,9 +26,7 @@
],
"purpose": "TRAINING",
"terms": {
- "payment" : {
- },
- "revocation": {
- }
+ "payment": {},
+ "revocation": {}
}
-}
+}
\ No newline at end of file
diff --git a/scenarios/mnist/data/1-create-storage-containers.sh b/scenarios/mnist/data/1-create-storage-containers.sh
deleted file mode 100755
index 95d2299..0000000
--- a/scenarios/mnist/data/1-create-storage-containers.sh
+++ /dev/null
@@ -1,24 +0,0 @@
-#!/bin/bash
-
-az group create \
- --location westeurope \
- --name $AZURE_RESOURCE_GROUP
-
-az storage account create \
- --resource-group $AZURE_RESOURCE_GROUP \
- --name $AZURE_STORAGE_ACCOUNT_NAME
-
-az storage container create \
- --resource-group $AZURE_RESOURCE_GROUP \
- --account-name $AZURE_STORAGE_ACCOUNT_NAME \
- --name $AZURE_MNIST_CONTAINER_NAME
-
-az storage container create \
- --resource-group $AZURE_RESOURCE_GROUP \
- --account-name $AZURE_STORAGE_ACCOUNT_NAME \
- --name $AZURE_MODEL_CONTAINER_NAME
-
-az storage container create \
- --resource-group $AZURE_RESOURCE_GROUP \
- --account-name $AZURE_STORAGE_ACCOUNT_NAME \
- --name $AZURE_OUTPUT_CONTAINER_NAME
diff --git a/scenarios/mnist/data/2-create-akv.sh b/scenarios/mnist/data/2-create-akv.sh
deleted file mode 100755
index 6cbcb46..0000000
--- a/scenarios/mnist/data/2-create-akv.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-
-set -e
-
-echo CREATING $AZURE_KEYVAULT_ENDPOINT in resouce group $AZURE_RESOURCE_GROUP
-echo $AZURE_RESOURCE_GROUP
-
-if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
- # Create Azure key vault with RBAC authorization
- AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
- az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
- # Assign RBAC roles to the resource owner so they can import keys
- AKV_SCOPE=`az keyvault show --name $AZURE_AKV_RESOURCE_NAME --query id --output tsv`
- az role assignment create --role "Key Vault Crypto Officer" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
- az role assignment create --role "Key Vault Crypto User" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
-else
- echo "Automated creation of key vaults is supported only for vaults"
-fi
\ No newline at end of file
diff --git a/scenarios/mnist/data/4-encrypt-data.sh b/scenarios/mnist/data/4-encrypt-data.sh
deleted file mode 100755
index 9c5332e..0000000
--- a/scenarios/mnist/data/4-encrypt-data.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-#!/bin/bash
-
-./generatefs.sh -d preprocessed -k mnistkey.bin -i mnist.img
-./generatefs.sh -d model -k modelkey.bin -i model.img
-mkdir -p output
-./generatefs.sh -d output -k outputkey.bin -i output.img
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/azure/0-create-acr.sh b/scenarios/mnist/deployment/azure/0-create-acr.sh
new file mode 100755
index 0000000..4719bad
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/0-create-acr.sh
@@ -0,0 +1,42 @@
+#!/bin/bash
+
+# Only to be run when creating a new ACR
+
+# Ensure required env vars are set
+if [[ -z "$CONTAINER_REGISTRY" || -z "$AZURE_RESOURCE_GROUP" || -z "$AZURE_LOCATION" ]]; then
+ echo "ERROR: CONTAINER_REGISTRY, AZURE_RESOURCE_GROUP, and AZURE_LOCATION environment variables must be set."
+ exit 1
+fi
+
+echo "Checking if ACR '$CONTAINER_REGISTRY' exists in resource group '$AZURE_RESOURCE_GROUP'..."
+
+# Check if ACR exists
+ACR_EXISTS=$(az acr show --name "$CONTAINER_REGISTRY" --resource-group "$AZURE_RESOURCE_GROUP" --query "name" -o tsv 2>/dev/null)
+
+if [[ -n "$ACR_EXISTS" ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' already exists."
+else
+ echo "⏳ ACR '$CONTAINER_REGISTRY' does not exist. Creating..."
+
+ # Create ACR with premium SKU and admin enabled
+ az acr create \
+ --name "$CONTAINER_REGISTRY" \
+ --resource-group "$AZURE_RESOURCE_GROUP" \
+ --location "$AZURE_LOCATION" \
+ --sku Premium \
+ --admin-enabled true \
+ --output table
+
+ # Enable anonymous pull
+ az acr update --name "$CONTAINER_REGISTRY" --anonymous-pull-enabled true
+
+ if [[ $? -eq 0 ]]; then
+ echo "✅ ACR '$CONTAINER_REGISTRY' created successfully."
+ else
+ echo "❌ Failed to create ACR."
+ exit 1
+ fi
+fi
+
+# Login to the ACR
+az acr login --name "$CONTAINER_REGISTRY"
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/azure/1-create-storage-containers.sh b/scenarios/mnist/deployment/azure/1-create-storage-containers.sh
new file mode 100755
index 0000000..49c6ba6
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/1-create-storage-containers.sh
@@ -0,0 +1,50 @@
+#!/bin/bash
+#
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
+
+echo "Checking if storage account $AZURE_STORAGE_ACCOUNT_NAME exists..."
+STORAGE_NAME_AVAILABLE=$(az storage account check-name --name $AZURE_STORAGE_ACCOUNT_NAME --query "nameAvailable" --output tsv)
+
+if [ "$STORAGE_NAME_AVAILABLE" == "true" ]; then
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME does not exist. Creating it now..."
+ az storage account create --resource-group $AZURE_RESOURCE_GROUP --name $AZURE_STORAGE_ACCOUNT_NAME
+else
+ echo "Storage account $AZURE_STORAGE_ACCOUNT_NAME already exists. Skipping creation."
+fi
+
+# Get the storage account key
+ACCOUNT_KEY=$(az storage account keys list --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --query "[0].value" --output tsv)
+
+
+# Check if the MNIST container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_MNIST_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_MNIST_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_MNIST_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the MODEL container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_MODEL_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_MODEL_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_MODEL_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
+
+# Check if the OUTPUT container exists
+CONTAINER_EXISTS=$(az storage container exists --name $AZURE_OUTPUT_CONTAINER_NAME --account-name $AZURE_STORAGE_ACCOUNT_NAME --account-key $ACCOUNT_KEY --query "exists" --output tsv)
+
+if [ "$CONTAINER_EXISTS" == "false" ]; then
+ echo "Container $AZURE_OUTPUT_CONTAINER_NAME does not exist. Creating it now..."
+ az storage container create --resource-group $AZURE_RESOURCE_GROUP --account-name $AZURE_STORAGE_ACCOUNT_NAME --name $AZURE_OUTPUT_CONTAINER_NAME --account-key $ACCOUNT_KEY
+fi
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/azure/2-create-akv.sh b/scenarios/mnist/deployment/azure/2-create-akv.sh
new file mode 100755
index 0000000..c20a75e
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/2-create-akv.sh
@@ -0,0 +1,27 @@
+#!/bin/bash
+
+set -e
+
+if [[ "$AZURE_KEYVAULT_ENDPOINT" == *".vault.azure.net" ]]; then
+ AZURE_AKV_RESOURCE_NAME=`echo $AZURE_KEYVAULT_ENDPOINT | awk '{split($0,a,"."); print a[1]}'`
+ # Check if the Key Vault already exists
+ echo "Checking if Key Vault $AZURE_AKV_RESOURCE_NAME exists..."
+ NAME_AVAILABLE=$(az rest --method post \
+ --uri "https://management.azure.com/subscriptions/$AZURE_SUBSCRIPTION_ID/providers/Microsoft.KeyVault/checkNameAvailability?api-version=2019-09-01" \
+ --headers "Content-Type=application/json" \
+ --body "{\"name\": \"$AZURE_AKV_RESOURCE_NAME\", \"type\": \"Microsoft.KeyVault/vaults\"}" | jq -r '.nameAvailable')
+ if [ "$NAME_AVAILABLE" == true ]; then
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME does not exist. Creating it now..."
+ echo CREATING $AZURE_KEYVAULT_ENDPOINT in resource group $AZURE_RESOURCE_GROUP
+ # Create Azure key vault with RBAC authorization
+ az keyvault create --name $AZURE_AKV_RESOURCE_NAME --resource-group $AZURE_RESOURCE_GROUP --sku "Premium" --enable-rbac-authorization
+ # Assign RBAC roles to the resource owner so they can import keys
+ AKV_SCOPE=`az keyvault show --name $AZURE_AKV_RESOURCE_NAME --query id --output tsv`
+ az role assignment create --role "Key Vault Crypto Officer" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ az role assignment create --role "Key Vault Crypto User" --assignee `az account show --query user.name --output tsv` --scope $AKV_SCOPE
+ else
+ echo "Key Vault $AZURE_AKV_RESOURCE_NAME already exists. Skipping creation."
+ fi
+else
+ echo "Automated key vault creation is supported only for *.vault.azure.net endpoints"
+fi
diff --git a/scenarios/mnist/data/3-import-keys.sh b/scenarios/mnist/deployment/azure/3-import-keys.sh
similarity index 87%
rename from scenarios/mnist/data/3-import-keys.sh
rename to scenarios/mnist/deployment/azure/3-import-keys.sh
index 268f4df..7e6abed 100755
--- a/scenarios/mnist/data/3-import-keys.sh
+++ b/scenarios/mnist/deployment/azure/3-import-keys.sh
@@ -29,7 +29,7 @@ echo Obtaining contract service parameters...
CONTRACT_SERVICE_URL=${CONTRACT_SERVICE_URL:-"http://localhost:8000"}
export CONTRACT_SERVICE_PARAMETERS=$(curl -k -f $CONTRACT_SERVICE_URL/parameters | base64 --wrap=0)
-envsubst < ../policy/policy-in-template.json > /tmp/policy-in.json
+envsubst < ../../policy/policy-in-template.json > /tmp/policy-in.json
export CCE_POLICY=$(az confcom acipolicygen -i /tmp/policy-in.json --debug-mode)
export CCE_POLICY_HASH=$(go run $TOOLS_HOME/securitypolicydigest/main.go -p $CCE_POLICY)
echo "Training container policy hash $CCE_POLICY_HASH"
@@ -44,10 +44,12 @@ elif [[ "$AZURE_KEYVAULT_ENDPOINT" == *".managedhsm.azure.net" ]]; then
export AZURE_AKV_KEY_TYPE="oct-HSM"
fi
-DATADIR=`pwd`
-import_key "MNISTFilesystemEncryptionKey" $DATADIR/mnistkey.bin
-import_key "ModelFilesystemEncryptionKey" $DATADIR/modelkey.bin
-import_key "OutputFilesystemEncryptionKey" $DATADIR/outputkey.bin
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+import_key "MNISTFilesystemEncryptionKey" $DATADIR/mnist_key.bin
+import_key "ModelFilesystemEncryptionKey" $MODELDIR/model_key.bin
+import_key "OutputFilesystemEncryptionKey" $MODELDIR/output_key.bin
## Cleanup
rm /tmp/importkey-config.json
diff --git a/scenarios/mnist/deployment/azure/4-encrypt-data.sh b/scenarios/mnist/deployment/azure/4-encrypt-data.sh
new file mode 100755
index 0000000..6705565
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/4-encrypt-data.sh
@@ -0,0 +1,11 @@
+#!/bin/bash
+
+DATADIR=$REPO_ROOT/scenarios/$SCENARIO/data
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+./generatefs.sh -d $DATADIR/preprocessed -k $DATADIR/mnist_key.bin -i $DATADIR/mnist.img
+./generatefs.sh -d $MODELDIR/models -k $MODELDIR/model_key.bin -i $MODELDIR/model.img
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+./generatefs.sh -d $MODELDIR/output -k $MODELDIR/output_key.bin -i $MODELDIR/output.img
\ No newline at end of file
diff --git a/scenarios/mnist/data/5-upload-encrypted-data.sh b/scenarios/mnist/deployment/azure/5-upload-encrypted-data.sh
similarity index 78%
rename from scenarios/mnist/data/5-upload-encrypted-data.sh
rename to scenarios/mnist/deployment/azure/5-upload-encrypted-data.sh
index 4629190..b7f1ad0 100755
--- a/scenarios/mnist/data/5-upload-encrypted-data.sh
+++ b/scenarios/mnist/deployment/azure/5-upload-encrypted-data.sh
@@ -1,11 +1,14 @@
#!/bin/bash
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_MNIST_CONTAINER_NAME \
- --file mnist.img \
+ --file $DATA_DIR/mnist.img \
--name data.img \
--type page \
--overwrite \
@@ -14,7 +17,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_MODEL_CONTAINER_NAME \
- --file model.img \
+ --file $MODEL_DIR/model.img \
--name data.img \
--type page \
--overwrite \
@@ -23,7 +26,7 @@ az storage blob upload \
az storage blob upload \
--account-name $AZURE_STORAGE_ACCOUNT_NAME \
--container $AZURE_OUTPUT_CONTAINER_NAME \
- --file output.img \
+ --file $MODEL_DIR/output.img \
--name data.img \
--type page \
--overwrite \
diff --git a/scenarios/mnist/deployment/azure/6-download-decrypt-model.sh b/scenarios/mnist/deployment/azure/6-download-decrypt-model.sh
new file mode 100755
index 0000000..b6d043a
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/6-download-decrypt-model.sh
@@ -0,0 +1,37 @@
+#!/bin/bash
+
+MODELDIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+sudo rm -rf $MODELDIR/output
+mkdir -p $MODELDIR/output
+
+ACCOUNT_KEY=$(az storage account keys list --account-name $AZURE_STORAGE_ACCOUNT_NAME --only-show-errors | jq -r .[0].value)
+
+az storage blob download \
+ --account-name $AZURE_STORAGE_ACCOUNT_NAME \
+ --container $AZURE_OUTPUT_CONTAINER_NAME \
+ --file $MODELDIR/output.img \
+ --name data.img \
+ --account-key $ACCOUNT_KEY
+
+encryptedImage=$MODELDIR/output.img
+keyFilePath=$MODELDIR/output_key.bin
+
+echo Decrypting $encryptedImage with key $keyFilePath
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+mountPoint=`mktemp -d`
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+cp -r $mountPoint/* $MODELDIR/output/
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/azure/aci-parameters-template.json b/scenarios/mnist/deployment/azure/aci-parameters-template.json
new file mode 100644
index 0000000..8eb11fc
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/aci-parameters-template.json
@@ -0,0 +1,23 @@
+{
+ "containerRegistry": {
+ "value": ""
+ },
+ "ccePolicy": {
+ "value": ""
+ },
+ "EncfsSideCarArgs": {
+ "value": ""
+ },
+ "ContractService": {
+ "value": ""
+ },
+ "ContractServiceParameters": {
+ "value": ""
+ },
+ "Contracts": {
+ "value": ""
+ },
+ "PipelineConfiguration": {
+ "value": ""
+ }
+}
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/aci/arm-template.json b/scenarios/mnist/deployment/azure/arm-template.json
similarity index 95%
rename from scenarios/mnist/deployment/aci/arm-template.json
rename to scenarios/mnist/deployment/azure/arm-template.json
index 9661dab..93f1acc 100644
--- a/scenarios/mnist/deployment/aci/arm-template.json
+++ b/scenarios/mnist/deployment/azure/arm-template.json
@@ -10,7 +10,7 @@
}
},
"location": {
- "defaultValue": "[resourceGroup().location]",
+ "defaultValue": "northeurope",
"type": "string",
"metadata": {
"description": "Location for all resources."
@@ -81,7 +81,7 @@
"defaultValue": "secureString",
"type": "string",
"metadata": {
- "description": "Configuration representing the pipeline to be trained"
+ "description": "Pipeline configuration"
}
}
},
@@ -154,9 +154,9 @@
"mountPath": "/mnt/remote"
}
],
- "securityContext": {
- "privileged": "true"
- },
+ "securityContext": {
+ "privileged": "true"
+ },
"resources": {
"requests": {
"cpu": 0.5,
@@ -178,4 +178,4 @@
}
}
]
-}
+}
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/aci/deploy.sh b/scenarios/mnist/deployment/azure/deploy.sh
similarity index 93%
rename from scenarios/mnist/deployment/aci/deploy.sh
rename to scenarios/mnist/deployment/azure/deploy.sh
index 4cf21b7..60eab5c 100755
--- a/scenarios/mnist/deployment/aci/deploy.sh
+++ b/scenarios/mnist/deployment/azure/deploy.sh
@@ -112,9 +112,16 @@ echo $TMP > /tmp/aci-parameters.json
echo Deploying training clean room...
-az group create \
- --location westeurope \
- --name $AZURE_RESOURCE_GROUP
+echo "Checking if resource group $AZURE_RESOURCE_GROUP exists..."
+RG_EXISTS=$(az group exists --name $AZURE_RESOURCE_GROUP)
+
+if [ "$RG_EXISTS" == "false" ]; then
+ echo "Resource group $AZURE_RESOURCE_GROUP does not exist. Creating it now..."
+ # Create the resource group
+ az group create --name $AZURE_RESOURCE_GROUP --location $AZURE_LOCATION
+else
+ echo "Resource group $AZURE_RESOURCE_GROUP already exists. Skipping creation."
+fi
az deployment group create \
--resource-group $AZURE_RESOURCE_GROUP \
diff --git a/scenarios/mnist/deployment/azure/encrypted-filesystem-config-template.json b/scenarios/mnist/deployment/azure/encrypted-filesystem-config-template.json
new file mode 100644
index 0000000..2af9e95
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/encrypted-filesystem-config-template.json
@@ -0,0 +1,70 @@
+{
+ "azure_filesystems": [
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": false,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ },
+ {
+ "azure_url": "",
+ "azure_url_private": false,
+ "read_write": true,
+ "mount_point": "",
+ "key": {
+ "kid": "",
+ "kty": "",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation": {
+ "salt": "",
+ "label": ""
+ }
+ }
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/azure/generatefs.sh b/scenarios/mnist/deployment/azure/generatefs.sh
new file mode 100755
index 0000000..df8833e
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/generatefs.sh
@@ -0,0 +1,75 @@
+#!/bin/bash
+
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT License.
+
+while getopts ":d:k:i:" options; do
+ case $options in
+ d)dataPath=$OPTARG;;
+ k)keyFilePath=$OPTARG;;
+ i)encryptedImage=$OPTARG;;
+ esac
+done
+
+echo Encrypting $dataPath with key $keyFilePath and generating $encryptedImage
+deviceName=cryptdevice1
+deviceNamePath="/dev/mapper/$deviceName"
+
+if [ -f "$keyFilePath" ]; then
+ echo "[!] Encrypting dataset using $keyFilePath"
+else
+ echo "[!] Generating keyfile..."
+ dd if=/dev/random of="$keyFilePath" count=1 bs=32
+ truncate -s 32 "$keyFilePath"
+fi
+
+echo "[!] Creating encrypted image..."
+
+response=`du -s $dataPath`
+read -ra arr <<< "$response"
+size=`echo "x=l($arr)/l(2); scale=0; 2^((x+0.5)/1)*2" | bc -l;`
+
+# cryptsetup requires 16M or more
+
+if (($((size)) < 65536)); then
+ size="65536"
+fi
+size=$size"K"
+
+echo "Data size: $size"
+
+rm -f "$encryptedImage"
+touch "$encryptedImage"
+truncate --size $size "$encryptedImage"
+
+sudo cryptsetup luksFormat --type luks2 "$encryptedImage" \
+ --key-file "$keyFilePath" -v --batch-mode --sector-size 4096 \
+ --cipher aes-xts-plain64 \
+ --pbkdf pbkdf2 --pbkdf-force-iterations 1000
+
+sudo cryptsetup luksOpen "$encryptedImage" "$deviceName" \
+ --key-file "$keyFilePath" \
+ --integrity-no-journal --persistent
+
+echo "[!] Formatting as ext4..."
+
+sudo mkfs.ext4 "$deviceNamePath"
+
+echo "[!] Mounting..."
+
+mountPoint=`mktemp -d`
+echo "Mounting to $mountPoint"
+sudo mount -t ext4 "$deviceNamePath" "$mountPoint" -o loop
+
+echo "[!] Copying contents to encrypted device..."
+
+# The /* is needed to copy folder contents instead of the folder + contents
+sudo cp -r $dataPath/* "$mountPoint"
+sudo rm -rf "$mountPoint/lost+found"
+ls "$mountPoint"
+
+echo "[!] Closing device..."
+
+sudo umount "$mountPoint"
+sleep 2
+sudo cryptsetup luksClose "$deviceName"
diff --git a/scenarios/mnist/deployment/azure/importkey-config-template.json b/scenarios/mnist/deployment/azure/importkey-config-template.json
new file mode 100644
index 0000000..42ed7ee
--- /dev/null
+++ b/scenarios/mnist/deployment/azure/importkey-config-template.json
@@ -0,0 +1,29 @@
+{
+ "key":{
+ "kid": "",
+ "kty": "oct-HSM",
+ "authority": {
+ "endpoint": "sharedneu.neu.attest.azure.net"
+ },
+ "akv": {
+ "endpoint": "",
+ "api_version": "api-version=7.3-preview",
+ "bearer_token": ""
+ }
+ },
+ "key_derivation":
+ {
+ "salt": "",
+ "label": ""
+ },
+ "claims": [
+ [{
+ "claim": "x-ms-sevsnpvm-hostdata",
+ "equals": ""
+ },
+ {
+ "claim": "x-ms-compliance-status",
+ "equals": "azure-compliant-uvm"
+ }]
+ ]
+}
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/docker/docker-compose-preprocess.yml b/scenarios/mnist/deployment/docker/docker-compose-preprocess.yml
deleted file mode 100644
index 519a904..0000000
--- a/scenarios/mnist/deployment/docker/docker-compose-preprocess.yml
+++ /dev/null
@@ -1,7 +0,0 @@
-services:
- mnist:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-mnist-preprocess:latest
- volumes:
- - $MNIST_INPUT_PATH:/mnt/input/mnist
- - $MNIST_OUTPUT_PATH:/mnt/output/mnist
- command: ["python3.9", "preprocess.py"]
diff --git a/scenarios/mnist/deployment/docker/docker-compose-save-model.yml b/scenarios/mnist/deployment/docker/docker-compose-save-model.yml
deleted file mode 100644
index 65e7c9d..0000000
--- a/scenarios/mnist/deployment/docker/docker-compose-save-model.yml
+++ /dev/null
@@ -1,6 +0,0 @@
-services:
- model_save:
- image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}depa-mnist-save-model:latest
- volumes:
- - $MODEL_OUTPUT_PATH:/mnt/model
- command: ["python3.9", "save_model.py"]
diff --git a/scenarios/mnist/deployment/docker/save-model.sh b/scenarios/mnist/deployment/docker/save-model.sh
deleted file mode 100755
index 472ae96..0000000
--- a/scenarios/mnist/deployment/docker/save-model.sh
+++ /dev/null
@@ -1,3 +0,0 @@
-export MODEL_OUTPUT_PATH=$PWD/../../data/model
-mkdir -p $MODEL_OUTPUT_PATH
-docker compose -f docker-compose-save-model.yml up --remove-orphans
diff --git a/scenarios/mnist/deployment/docker/train.sh b/scenarios/mnist/deployment/docker/train.sh
deleted file mode 100755
index 742f7f7..0000000
--- a/scenarios/mnist/deployment/docker/train.sh
+++ /dev/null
@@ -1,8 +0,0 @@
-export DATA_DIR=$PWD/../../data
-export MNIST_INPUT_PATH=$DATA_DIR/preprocessed
-export MODEL_INPUT_PATH=$DATA_DIR/model
-export MODEL_OUTPUT_PATH=/tmp/output
-mkdir -p $MODEL_OUTPUT_PATH
-export CONFIGURATION_PATH=/tmp
-cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
-docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/mnist/deployment/local/docker-compose-modelsave.yml b/scenarios/mnist/deployment/local/docker-compose-modelsave.yml
new file mode 100644
index 0000000..98757fb
--- /dev/null
+++ b/scenarios/mnist/deployment/local/docker-compose-modelsave.yml
@@ -0,0 +1,6 @@
+services:
+ model_save:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}mnist-model-save:latest
+ volumes:
+ - $MODEL_OUTPUT_PATH:/mnt/model
+ command: ["python3", "save_base_model.py"]
diff --git a/scenarios/mnist/deployment/local/docker-compose-preprocess.yml b/scenarios/mnist/deployment/local/docker-compose-preprocess.yml
new file mode 100644
index 0000000..913be85
--- /dev/null
+++ b/scenarios/mnist/deployment/local/docker-compose-preprocess.yml
@@ -0,0 +1,7 @@
+services:
+ mnist:
+ image: ${CONTAINER_REGISTRY:+$CONTAINER_REGISTRY/}preprocess-mnist:latest
+ volumes:
+ - $MNIST_INPUT_PATH:/mnt/input/data
+ - $MNIST_OUTPUT_PATH:/mnt/output/preprocessed
+ command: ["python3", "preprocess_mnist.py"]
diff --git a/scenarios/mnist/deployment/docker/docker-compose-train.yml b/scenarios/mnist/deployment/local/docker-compose-train.yml
similarity index 100%
rename from scenarios/mnist/deployment/docker/docker-compose-train.yml
rename to scenarios/mnist/deployment/local/docker-compose-train.yml
diff --git a/scenarios/mnist/deployment/docker/preprocess.sh b/scenarios/mnist/deployment/local/preprocess.sh
similarity index 55%
rename from scenarios/mnist/deployment/docker/preprocess.sh
rename to scenarios/mnist/deployment/local/preprocess.sh
index aee3967..aab2d6c 100755
--- a/scenarios/mnist/deployment/docker/preprocess.sh
+++ b/scenarios/mnist/deployment/local/preprocess.sh
@@ -1,4 +1,8 @@
-export DATA_DIR=$PWD/../../data
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="mnist"
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
export MNIST_INPUT_PATH=$DATA_DIR
export MNIST_OUTPUT_PATH=$DATA_DIR/preprocessed
mkdir -p $MNIST_OUTPUT_PATH
diff --git a/scenarios/mnist/deployment/local/save-model.sh b/scenarios/mnist/deployment/local/save-model.sh
new file mode 100755
index 0000000..0eacc39
--- /dev/null
+++ b/scenarios/mnist/deployment/local/save-model.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="mnist"
+export MODEL_OUTPUT_PATH=$REPO_ROOT/scenarios/$SCENARIO/modeller/models
+mkdir -p $MODEL_OUTPUT_PATH
+docker compose -f docker-compose-modelsave.yml up --remove-orphans
\ No newline at end of file
diff --git a/scenarios/mnist/deployment/local/train.sh b/scenarios/mnist/deployment/local/train.sh
new file mode 100755
index 0000000..cd2da50
--- /dev/null
+++ b/scenarios/mnist/deployment/local/train.sh
@@ -0,0 +1,25 @@
+#!/bin/bash
+
+export REPO_ROOT="$(git rev-parse --show-toplevel)"
+export SCENARIO="mnist"
+
+export DATA_DIR=$REPO_ROOT/scenarios/$SCENARIO/data
+export MODEL_DIR=$REPO_ROOT/scenarios/$SCENARIO/modeller
+
+export MNIST_INPUT_PATH=$DATA_DIR/preprocessed
+
+export MODEL_INPUT_PATH=$MODEL_DIR/models
+
+# export MODEL_OUTPUT_PATH=/tmp/output
+export MODEL_OUTPUT_PATH=$MODEL_DIR/output
+sudo rm -rf $MODEL_OUTPUT_PATH
+mkdir -p $MODEL_OUTPUT_PATH
+
+export CONFIGURATION_PATH=$REPO_ROOT/scenarios/$SCENARIO/config
+# export CONFIGURATION_PATH=/tmp
+# cp $PWD/../../config/pipeline_config.json /tmp/pipeline_config.json
+
+# Run consolidate_pipeline.sh to create pipeline_config.json
+$REPO_ROOT/scenarios/$SCENARIO/config/consolidate_pipeline.sh
+
+docker compose -f docker-compose-train.yml up --remove-orphans
diff --git a/scenarios/mnist/export-variables.sh b/scenarios/mnist/export-variables.sh
new file mode 100755
index 0000000..6f9ae8c
--- /dev/null
+++ b/scenarios/mnist/export-variables.sh
@@ -0,0 +1,60 @@
+#!/bin/bash
+
+# Azure Naming Rules:
+#
+# Resource Group:
+# - 1–90 characters
+# - Letters, numbers, underscores, parentheses, hyphens, periods allowed
+# - Cannot end with a period (.)
+# - Case-insensitive, unique within subscription
+#
+# Key Vault:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with letter or number
+#
+# Storage Account:
+# - 3-24 characters
+# - Globally unique name
+# - Lowercase letters and numbers only
+#
+# Storage Container:
+# - 3-63 characters
+# - Lowercase letters, numbers, hyphens only
+# - Must start and end with a letter or number
+# - No consecutive hyphens
+# - Unique within storage account
+
+# For cloud resource creation:
+declare -x SCENARIO=mnist
+declare -x REPO_ROOT="$(git rev-parse --show-toplevel)"
+declare -x CONTAINER_REGISTRY=ispirt.azurecr.io
+declare -x AZURE_LOCATION=centralindia
+declare -x AZURE_SUBSCRIPTION_ID=
+declare -x AZURE_RESOURCE_GROUP=
+declare -x AZURE_KEYVAULT_ENDPOINT=
+declare -x AZURE_STORAGE_ACCOUNT_NAME=
+
+declare -x AZURE_MNIST_CONTAINER_NAME=mnistcontainer
+declare -x AZURE_MODEL_CONTAINER_NAME=modelcontainer
+declare -x AZURE_OUTPUT_CONTAINER_NAME=outputcontainer
+
+# For key import:
+declare -x CONTRACT_SERVICE_URL=https://depa-training-contract-service.centralindia.cloudapp.azure.com:8000
+declare -x TOOLS_HOME=$REPO_ROOT/external/confidential-sidecar-containers/tools
+
+# Export all variables to make them available to other scripts
+export SCENARIO
+export REPO_ROOT
+export CONTAINER_REGISTRY
+export AZURE_LOCATION
+export AZURE_SUBSCRIPTION_ID
+export AZURE_RESOURCE_GROUP
+export AZURE_KEYVAULT_ENDPOINT
+export AZURE_STORAGE_ACCOUNT_NAME
+export AZURE_MNIST_CONTAINER_NAME
+export AZURE_MODEL_CONTAINER_NAME
+export AZURE_OUTPUT_CONTAINER_NAME
+export CONTRACT_SERVICE_URL
+export TOOLS_HOME
\ No newline at end of file
diff --git a/scenarios/mnist/src/preprocess.py b/scenarios/mnist/src/preprocess.py
deleted file mode 100644
index d742162..0000000
--- a/scenarios/mnist/src/preprocess.py
+++ /dev/null
@@ -1,18 +0,0 @@
-import torch
-import torchvision
-import torchvision.transforms as transforms
-
-mnist_input_folder='/mnt/input/mnist/'
-
-# Location of preprocessed MNIST dataset
-mnist_output_folder='/mnt/output/mnist/'
-
-transform = transforms.Compose(
- [transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
-trainset = torchvision.datasets.CIFAR10(root=mnist_input_folder, train=True,
- download=True, transform=transform)
-
-# Save the CIFAR10 dataset
-torch.save(trainset, mnist_output_folder + 'cifar10-dataset.pth')
diff --git a/scenarios/mnist/src/preprocess_mnist.py b/scenarios/mnist/src/preprocess_mnist.py
new file mode 100644
index 0000000..54598f5
--- /dev/null
+++ b/scenarios/mnist/src/preprocess_mnist.py
@@ -0,0 +1,56 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import torch
+import torchvision
+import torchvision.transforms as transforms
+import h5py
+
+mnist_input_folder='/mnt/input/data/'
+
+# Location of preprocessed MNIST dataset
+mnist_output_folder='/mnt/output/preprocessed/'
+
+transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.1307,), (0.3081,)) # MNIST mean and std
+])
+
+trainset = torchvision.datasets.MNIST(root=mnist_input_folder, train=True, download=True, transform=transform)
+
+# Build tensors (N, C, H, W) and labels (N,)
+features = []
+targets = []
+for img, label in trainset:
+ # img is a tensor (1, 28, 28)
+ features.append(img)
+ targets.append(label)
+
+features = torch.stack(features).to(torch.float32)
+targets = torch.tensor(targets, dtype=torch.int64)
+
+# Ensure output directory exists
+os.makedirs(mnist_output_folder, exist_ok=True)
+
+# Save as HDF5 with keys 'features' and 'targets'
+out_path = os.path.join(mnist_output_folder, 'mnist-dataset.h5')
+with h5py.File(out_path, 'w') as f:
+ f.create_dataset('features', data=features.numpy())
+ f.create_dataset('targets', data=targets.numpy())
+
+print(f"Saved MNIST dataset to {out_path} as HDF5 with keys 'features' and 'targets'.")
diff --git a/scenarios/mnist/src/save_base_model.py b/scenarios/mnist/src/save_base_model.py
new file mode 100644
index 0000000..ea0c725
--- /dev/null
+++ b/scenarios/mnist/src/save_base_model.py
@@ -0,0 +1,48 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+model_path="/mnt/model/"
+
+class Net(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv1 = nn.Conv2d(1, 32, 3)
+ self.pool = nn.MaxPool2d(2, 2)
+ self.conv2 = nn.Conv2d(32, 64, 3)
+ self.fc1 = nn.Linear(64 * 5 * 5, 128)
+ self.fc2 = nn.Linear(128, 10) # 10 classes for MNIST digits
+
+ def forward(self, x):
+ x = self.pool(F.relu(self.conv1(x)))
+ x = self.pool(F.relu(self.conv2(x)))
+ x = torch.flatten(x, 1)
+ x = F.relu(self.fc1(x))
+ x = self.fc2(x)
+ return x
+
+
+net = Net()
+
+# Define the input size for MNIST (batch_size, channels, height, width)
+dummy_input = torch.randn(1, 1, 28, 28)
+
+# Export the model
+torch.onnx.export(net, dummy_input, model_path + "model.onnx")
\ No newline at end of file
diff --git a/scenarios/mnist/src/save_model.py b/scenarios/mnist/src/save_model.py
deleted file mode 100644
index e4dd3af..0000000
--- a/scenarios/mnist/src/save_model.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-model_path="/mnt/model/"
-
-class Net(nn.Module):
- def __init__(self):
- super().__init__()
- self.conv1 = nn.Conv2d(3, 6, 5)
- self.pool = nn.MaxPool2d(2, 2)
- self.conv2 = nn.Conv2d(6, 16, 5)
- self.fc1 = nn.Linear(16 * 5 * 5, 120)
- self.fc2 = nn.Linear(120, 84)
- self.fc3 = nn.Linear(84, 10)
-
- def forward(self, x):
- x = self.pool(F.relu(self.conv1(x)))
- x = self.pool(F.relu(self.conv2(x)))
- x = torch.flatten(x, 1) # flatten all dimensions except batch
- x = F.relu(self.fc1(x))
- x = F.relu(self.fc2(x))
- x = self.fc3(x)
- return x
-
-
-net = Net()
-
-# Define the input size for the model
-dummy_input = torch.randn(4, 3, 32, 32)
-
-# Export the model
-torch.onnx.export(net, dummy_input, model_path + "model.onnx")
\ No newline at end of file
diff --git a/src/encfs/encfs.sh b/src/encfs/encfs.sh
old mode 100644
new mode 100755
diff --git a/src/policy/policy.rego b/src/policy/policy.rego
index 0bdde75..64e8592 100644
--- a/src/policy/policy.rego
+++ b/src/policy/policy.rego
@@ -5,7 +5,7 @@ import future.keywords.if
allowed if {
all_datasets_in_contract_included
- output_filesystem_mounted
+ modeller_filesystem_mounted
valid_pipeline
}
@@ -27,9 +27,11 @@ all_datasets_in_contract_included if {
}
}
-output_filesystem_mounted if {
- # expect two additional filesystems
- count(input.azure_filesystems) == count(data.datasets) + 2
+modeller_filesystem_mounted if {
+ # expect one or two additional filesystems (one when the model is instantiated from config inside the CCR)
+ providers = {p | p = data.datasets[_].name}
+ count(input.azure_filesystems) <= count(providers) + 2
+ count(input.azure_filesystems) > count(providers)
}
valid_pipeline if {
@@ -48,15 +50,15 @@ data_has_privacy_constraints if {
}
last_stage_is_private_training if {
- input.pipeline[count(input.pipeline) - 1].name == "PrivateTrain"
+ input.pipeline[count(input.pipeline) - 1].config.is_private == true
}
last_stage_is_training if {
- input.pipeline[count(input.pipeline) - 1].name == "Train"
+ input.pipeline[count(input.pipeline) - 1].config.is_private == false
}
min_privacy_budget_allocated if {
threshold = min({t | t = to_number(data.constraints[_].privacy[_].epsilon_threshold)})
last := count(input.pipeline)
- input.pipeline[last - 1].config.epsilon_threshold <= threshold
+ input.pipeline[last - 1].config.privacy_params.epsilon <= threshold
}
diff --git a/src/train/.gitignore b/src/train/.gitignore
index 13b2d12..fa702ce 100644
--- a/src/train/.gitignore
+++ b/src/train/.gitignore
@@ -11,4 +11,4 @@ workspace/
**/*.pid
.mypy_cache/
**/*.cose
-**/*.cbor
+**/*.cbor
\ No newline at end of file
diff --git a/src/train/README.md b/src/train/README.md
index da426d5..27b6560 100644
--- a/src/train/README.md
+++ b/src/train/README.md
@@ -1,3 +1,3 @@
# Training runtime
-Python based runtime for training models on private data
\ No newline at end of file
+Configurable Python based runtime for training models on private data
\ No newline at end of file
diff --git a/src/train/pytrain/__init__.py b/src/train/pytrain/__init__.py
index e69de29..ed4f153 100644
--- a/src/train/pytrain/__init__.py
+++ b/src/train/pytrain/__init__.py
@@ -0,0 +1,16 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
\ No newline at end of file
diff --git a/src/train/pytrain/dl_train.py b/src/train/pytrain/dl_train.py
new file mode 100644
index 0000000..477e338
--- /dev/null
+++ b/src/train/pytrain/dl_train.py
@@ -0,0 +1,363 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+
+# Torch for datasets and training tools
+import torch
+import torch.optim as optim
+import torch.optim.lr_scheduler as lr_scheduler
+from torch.utils.data import DataLoader
+
+# Opacus for differential privacy
+import opacus
+from opacus import PrivacyEngine # For differential privacy
+from opacus.utils.batch_memory_manager import BatchMemoryManager # For large batch sizes
+
+# Onnx for loading and saving trained models
+import onnx
+from onnx2pytorch import ConvertModel
+from safetensors.torch import load_file as st_load, save_file as st_save
+
+from .task_base import TaskBase
+
+from .utilities.model_constructor import *
+from .utilities.dataset_constructor import *
+from .utilities.loss_constructor import *
+from .utilities.eval_tools import *
+
+
+class Train_DL(TaskBase):
+ """
+ Args:
+ config: training configuration
+
+ Methods:
+ init: initializes the model, dataloader, optimizer, scheduler and privacy engine
+ load_data: loads data from csv as data loaders
+ load_model: loads model object from model config
+ load_optimizer: loads model optimizer from model config
+ load_loss_fn: loads loss function from config
+ make_dprivate: makes model,dataloader and optimizer private
+ train: trains the model
+ inference: inference on the validation set
+ execute: main function which includes all the above functions
+
+ Attributes:
+ config: training configuration
+ device: device to train on
+ is_private: whether to use differential privacy
+ privacy_config: privacy configuration
+ model: model object
+ train_loader: training data loader
+ val_loader: validation data loader
+ test_loader: test data loader
+ custom_loss_fn: custom loss function
+ optimizer: optimizer object
+ scheduler: learning rate scheduler object
+ privacy_engine: privacy engine object
+ model_non_dp: non-private model object
+ """
+
+ def init(self, config):
+ self.config = config
+ self.device = torch.device(config.get("device"))
+ self.is_private = config.get("is_private", False)
+ self.privacy_config = config.get("privacy_params")
+ self.paths = config.get("paths", {})
+ self.model = None
+ self.train_loader = None
+ self.val_loader = None
+ self.test_loader = None
+ self.custom_loss_fn = None
+ self.optimizer = None
+ self.scheduler = None
+ self.privacy_engine = None
+ self.model_non_dp = None
+
+ def load_data(self):
+ dataset_config = self.config.get("dataset_config")
+ input_path = self.paths.get("input_dataset_path")
+ all_splits = create_dataset(dataset_config, input_path)
+
+ if "train" not in all_splits:
+ raise ValueError("Dataset must provide at least a 'train' split")
+
+ train_dataset = all_splits["train"]
+ self.train_loader = DataLoader(train_dataset, batch_size=self.config.get("batch_size"), shuffle=True, num_workers=0)
+ if "val" in all_splits:
+ val_dataset = all_splits["val"]
+ self.val_loader = DataLoader(val_dataset, batch_size=self.config.get("batch_size"), shuffle=True, num_workers=0)
+ if "test" in all_splits:
+ test_dataset = all_splits["test"]
+ self.test_loader = DataLoader(test_dataset, batch_size=self.config.get("batch_size"), shuffle=True, num_workers=0)
+
+ print(f"Loaded dataset splits | train: {self.train_loader.dataset.__len__()} | val: {None if self.val_loader is None else self.val_loader.dataset.__len__()} | test: {None if self.test_loader is None else self.test_loader.dataset.__len__()}")
+
+
+ def load_model(self):
+ model_type = self.config.get("model_type")
+ if model_type == "onnx":
+ onnx_model = onnx.load(self.config.get("paths", {}).get("base_model_path"))
+ model = ConvertModel(onnx_model, experimental=True)
+
+ print("Model loaded from ONNX file")
+ self.model = model.to(self.device)
+
+ if self.is_private:
+ model_non_dp = onnx.load(self.config.get("paths", {}).get("base_model_path"))
+ model_non_dp = ConvertModel(model_non_dp, experimental=True)
+ self.model_non_dp = model_non_dp.to(self.device)
+ print("Created non-private baseline model for comparison")
+
+ elif model_type == "safetensors":
+ self.model = ModelFactory.load_from_dict(self.config.get("model_config"))
+ print("Custom model loaded from PyTorch config")
+ if self.config.get("paths", {}).get("saved_weights_path") is not None:
+ self.model.load_state_dict(st_load(self.config.get("paths", {}).get("saved_weights_path")))
+ print("Loaded weights from " + self.config.get("paths", {}).get("saved_weights_path"))
+ self.model = self.model.to(self.device)
+
+ if self.is_private:
+ self.model_non_dp = ModelFactory.load_from_dict(self.config.get("model_config"))
+ if self.config.get("paths", {}).get("saved_weights_path") is not None:
+ self.model_non_dp.load_state_dict(st_load(self.config.get("paths", {}).get("saved_weights_path")))
+ self.model_non_dp = self.model_non_dp.to(self.device)
+ print("Created non-private baseline model for comparison")
+
+
+ def load_optimizer(self):
+ optimizer_name = self.config.get("optimizer", {}).get("name", "adam")
+ optimizer_params = self.config.get("optimizer", {}).get("params", {})
+ optimizer_class = getattr(optim, optimizer_name)
+ self.optimizer = optimizer_class(self.model.parameters(), **optimizer_params)
+ if self.is_private:
+ self.optimizer_non_dp = optimizer_class(self.model_non_dp.parameters(), **optimizer_params)
+
+ print(f"Optimizer {optimizer_name} loaded from config")
+
+ if self.config.get("scheduler") is not None:
+ scheduler_name = self.config.get("scheduler", {}).get("name", "cyclic")
+ scheduler_params = self.config.get("scheduler", {}).get("params", {})
+ scheduler_class = getattr(lr_scheduler, scheduler_name)
+ self.scheduler = scheduler_class(self.optimizer, **scheduler_params)
+ print(f"Scheduler {scheduler_name} loaded from config")
+
+
+ def load_loss_fn(self):
+ if self.config.get("loss_config") is not None:
+ self.custom_loss_fn = LossComposer.load_from_dict(self.config.get("loss_config"))
+ print("Custom loss function loaded from config")
+ else:
+ # Raise an error if no loss function configuration is found
+ raise ValueError("No loss function configuration found. Please provide a loss function configuration.")
+
+
+ def make_dprivate(self):
+ # Ensure delta is not too large to avoid privacy breach
+ max_delta = 1/len(self.train_loader.dataset)
+ if self.privacy_config.get("delta") > max_delta:
+ self.privacy_config["delta"] = max_delta
+ print(f"Delta set to {max_delta} (1/train_samples) to avoid privacy breach")
+
+ self.privacy_engine = PrivacyEngine() # secure_mode=True requires torchcsprng to be installed
+
+ self.model, self.optimizer, self.train_loader = self.privacy_engine.make_private_with_epsilon(
+ module=self.model,
+ optimizer=self.optimizer,
+ data_loader=self.train_loader,
+ epochs=self.config.get("total_epochs"),
+ target_delta=self.privacy_config.get("delta"), # Probability of privacy breach
+ target_epsilon=self.privacy_config.get("epsilon"), # Privacy budget
+ max_grad_norm=self.privacy_config.get("max_grad_norm"), # threshold for clipping the norm of per-sample gradients
+ batch_first=True
+ )
+
+
+ def train(self):
+ run_val = True if self.val_loader is not None else False
+
+ for epoch in range(self.config.get("total_epochs")):
+ # set model to train mode
+ self.model.train()
+
+ train_loss = 0
+ for [inputs, labels] in self.train_loader:
+ # move inputs and labels to device
+ inputs, labels = inputs.to(self.device), labels.to(self.device)
+
+ # zero the gradients
+ self.optimizer.zero_grad()
+
+ # forward pass
+ pred = self.model(inputs)
+
+ # compute loss
+ loss = self.custom_loss_fn.calculate_loss(pred, labels)
+ train_loss += loss.item()
+
+ # backward pass
+ loss.backward()
+
+ # update weights
+ self.optimizer.step()
+
+ # update learning rate per batch - optional
+ if self.scheduler is not None:
+ self.scheduler.step()
+
+ epsilon = None
+ if self.privacy_engine is not None:
+ epsilon = self.privacy_engine.get_epsilon(self.config.get("privacy_params", {}).get("delta"))
+
+ # update learning rate per epoch - optional
+ # self.scheduler.step()
+
+ eps_str = f"| Epsilon: {epsilon:.4f}" if epsilon is not None else ""
+ print(f"Epoch {epoch+1}/{self.config.get('total_epochs')} completed | Training Loss: {train_loss/len(self.train_loader):.4f} {eps_str}")
+
+ if run_val:
+ val_loss = 0
+ with torch.no_grad():
+ for [inputs, labels] in self.val_loader:
+ inputs, labels = inputs.to(self.device), labels.to(self.device)
+ pred = self.model(inputs)
+ loss = self.custom_loss_fn.calculate_loss(pred, labels)
+ val_loss += loss.item()
+ print(f"Epoch {epoch+1}/{self.config.get('total_epochs')} completed | Validation Loss: {val_loss/len(self.val_loader):.4f}")
+
+ # --- END OF TRAINING ---
+
+ # If privacy is enabled, train a non-private replica model for comparison
+ if self.is_private:
+ print("\nTraining non-private replica model for comparison...")
+
+ # Train non-private model
+ self.model_non_dp.train()
+ for epoch in range(self.config.get("total_epochs")):
+ train_loss = 0
+ for [inputs, labels] in self.train_loader:
+ inputs, labels = inputs.to(self.device), labels.to(self.device)
+ self.optimizer_non_dp.zero_grad()
+ pred = self.model_non_dp(inputs)
+ loss = self.custom_loss_fn.calculate_loss(pred, labels)
+ train_loss += loss.item()
+ loss.backward()
+ self.optimizer_non_dp.step()
+
+ print(f"Non-private baseline model - Epoch {epoch+1}/{self.config.get('total_epochs')} completed | Training Loss: {train_loss/len(self.train_loader):.4f}")
+
+ if run_val:
+ val_loss = 0
+ with torch.no_grad():
+ for [inputs, labels] in self.val_loader:
+ inputs, labels = inputs.to(self.device), labels.to(self.device)
+ pred = self.model_non_dp(inputs)
+ loss = self.custom_loss_fn.calculate_loss(pred, labels)
+ val_loss += loss.item()
+ print(f"Non-private baseline model - Epoch {epoch+1}/{self.config.get('total_epochs')} completed | Validation Loss: {val_loss/len(self.val_loader):.4f}")
+
+
+ def save_model(self):
+
+ # If using differential privacy, extract the underlying model from GradSampleModule
+ if isinstance(self.model, opacus.grad_sample.GradSampleModule):
+ self.model = self.model._module
+
+ # set model to eval mode
+ self.model.eval()
+
+ # save the model
+ if self.config.get("model_type") == "safetensors":
+ output_path = os.path.join(self.config.get("paths", {}).get("trained_model_output_path"), "trained_model.safetensors")
+ print("Saving trained model to " + output_path)
+ st_save(self.model.state_dict(), output_path)
+
+ elif self.config.get("model_type") == "onnx":
+ output_path = os.path.join(self.config.get("paths", {}).get("trained_model_output_path"), "trained_model.onnx")
+ print("Saving trained model to " + output_path)
+ in_shape = (1,) + tuple(self.train_loader.dataset[0][0].shape)
+ torch.onnx.export(
+ self.model,
+ torch.randn(in_shape),
+ output_path,
+ verbose=False,
+ )
+
+
+ def inference_eval(self):
+ if self.test_loader is None:
+ print("Test loader is not defined. Skipping inference.")
+ return
+
+ self.model.eval()
+ preds_list = []
+ targets_list = []
+
+ test_loss = 0
+
+ with torch.no_grad():
+ for batch in self.test_loader:
+ if isinstance(batch, (list, tuple)) and len(batch) == 2:
+ x, y = batch
+ else:
+ x, y = batch["input"], batch.get("target", None)
+
+ x = x.to(self.device)
+ y = y.to(self.device) if y is not None else None
+ pred = self.model(x)
+
+ loss = self.custom_loss_fn.calculate_loss(pred, y)
+ test_loss += loss.item()
+
+ preds_list.extend([p.detach().squeeze().cpu().numpy() for p in pred])
+ if y is not None:
+ targets_list.extend([t.detach().squeeze().cpu().numpy() for t in y])
+
+ # compute test loss
+ test_loss = test_loss / len(self.test_loader)
+
+ # compute metrics
+ numeric_metrics = compute_metrics(preds_list, targets_list, test_loss, self.config)
+
+ print(f"Evaluation Metrics: {numeric_metrics}")
+
+
+ def execute(self, config):
+ try:
+ self.init(config)
+ self.load_data()
+ self.load_model()
+ self.load_optimizer()
+ self.load_loss_fn()
+ if self.is_private:
+ self.make_dprivate()
+
+ # --- START OF TRAINING ---
+ self.train()
+ # --- END OF TRAINING ---
+
+ self.save_model()
+
+ # run evaluation on test set
+ self.inference_eval()
+
+ print("CCR Training complete!\n")
+
+ except Exception as e:
+ print(f"An error occurred: {e}")
+ raise e
\ No newline at end of file
diff --git a/src/train/pytrain/join.py b/src/train/pytrain/join.py
index 06695b9..9d1f4fc 100644
--- a/src/train/pytrain/join.py
+++ b/src/train/pytrain/join.py
@@ -1,260 +1,471 @@
-# 2023, The DEPA CCR DP Training Reference Implementation
-# authors shyam@ispirt.in, sridhar.avs@ispirt.in
+# 2025 DEPA Foundation
#
-# Licensed TBD
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Key references / Attributions: https://depa.world/training/contracts
-# Key frameworks used : pyspark
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
import os
import json
import argparse
+import shutil
from pathlib import Path
from pyspark.sql import SparkSession
-from pyspark.sql.functions import *
-from pyspark.sql.types import *
+from pyspark.sql.functions import col
+from pyspark.sql.types import StructType
from pyspark.sql.functions import col, column
from .task_base import TaskBase
-# Debug Enabled
-debug_poc = True
-class Join(TaskBase):
+# class SparkJoin(TaskBase):
- def load_tdp_list(self, config, debug=True):
- """
- Extract List of TDP configurations for data joining process from query config
- """
- return config["datasets"]
-
- def load_joined_dataset_config(cself, config, debug=True):
- """
- Extract List of query join configurations for data joining process from query config.
- """
- return config["joined_dataset"]
-
- def get_name(self, tdp_config_list, debug=True):
- """
- Extract list of names of all TDP's from query config
- """
- name_list = []
- for c in tdp_config_list:
- name_list.append(c["name"])
- if debug:
- print("Debug |get_name")
- return name_list
-
- def create_spark_context(self, tdp_config_list, debug=True):
- """
- Create a spark session with app context auto generated from TDP names
- """
- name_list = self.get_name(tdp_config_list)
- context = ""
- for c in name_list:
- context = context + c + "_"
- if debug:
- print("Debug |create_spark_context")
- return SparkSession.builder.appName(context).getOrCreate()
-
- def generate_query(self, tdp_config_list, joined_dataset_config, debug=True):
- """
- Extract the query logic from the query config. Current implementation extracts query from query config.
- """
- if debug:
- print("Debug |generate_query")
- return joined_dataset_config["joining_query"]
-
- def dropDupeDfCols(self, df, debug=True):
+# def load_tdp_list(self, config, debug=True):
+# """
+# Extract List of TDP configurations for data joining process from query config
+# """
+# return config["datasets"]
+
+# def load_joined_dataset_config(cself, config, debug=True):
+# """
+# Extract List of query join configurations for data joining process from query config.
+# """
+# return config["joined_dataset"]
+
+# def get_name(self, tdp_config_list, debug=True):
+# """
+#         Extract list of names of all TDPs from query config
+# """
+# name_list = []
+# for c in tdp_config_list:
+# name_list.append(c["name"])
+
+# return name_list
+
+# def create_spark_context(self, tdp_config_list, debug=True):
+# """
+# Create a spark session with app context auto generated from TDP names
+# """
+# name_list = self.get_name(tdp_config_list)
+# context = ""
+# for c in name_list:
+# context = context + c + "_"
+
+# return SparkSession.builder.appName(context).getOrCreate()
+
+# def generate_query(self, tdp_config_list, joined_dataset_config, debug=True):
+# """
+# Extract the query logic from the query config. Current implementation extracts query from query config.
+# """
+# # Very light validation to disallow multiple statements / dangerous tokens
+# q = joined_dataset_config["joining_query"]
+# if not isinstance(q, str):
+# raise ValueError("joining_query must be a string")
+# # Disallow semicolons and common DDL/DML keywords
+# banned = [";", "DROP ", "TRUNCATE ", "ALTER ", "CREATE ", "INSERT ", "UPDATE ", "DELETE "]
+# uq = q.upper()
+# if any(b in uq for b in banned):
+# raise ValueError("Disallowed SQL construct in joining_query")
+# return q
+
+# def dropDupeDfCols(self, df, debug=True):
+# """
+# Drops Duplicate Columns from the dataframe passed
+# """
+
+# newcols = []
+# dupcols = []
+
+# for i in range(len(df.columns)):
+# if df.columns[i] not in newcols:
+# newcols.append(df.columns[i])
+# else:
+# dupcols.append(i)
+
+# df = df.toDF(*[str(i) for i in range(len(df.columns))])
+# for dupcol in dupcols:
+# df = df.drop(str(dupcol))
+
+# return df.toDF(*newcols)
+
+# def dp_load_data(self, spark, input_folder, data_file, load=True, debug=True):
+# """
+# Generic Data Loading Function at Data Provider
+# """
+
+# if load:
+# input_file = input_folder + data_file
+
+# data_loaded = spark.read.csv(
+# input_file, header=True, inferSchema=True, mode="DROPMALFORMED"
+# )
+
+# return data_loaded
+
+# def create_view(self, dataset_name, view_name):
+# """
+# Create the temp query-able views for spark query processing
+# """
+# return dataset_name.createOrReplaceTempView(view_name)
+
+# def ccr_prepare_joined_dataset(self, dataset_info, query, save_path, debug=True):
+# """
+#         In CCR/Sandbox we are creating the joined dataset anon from the three TDP and joined data configurations. This function abstracts this and returns the dataset.
+# """
+# # Create the temp query-able views
+# for c in dataset_info:
+# self.create_view(c[0], c[1])
+# # Query Execution
+#         return_dataset = spark.sql(query).cache()
+#         return_dataset = return_dataset.drop("_c0")
+#         # Save the dataset
+#         return_dataset.toPandas().to_csv(save_path)
+
+#         return return_dataset
+
+# def ccr_prepare_joined_dataset_full(
+# self, spark, dataset_info, joined_dataset_config, query, save_path, debug=True
+# ):
+# """
+#         In CCR/Sandbox we are creating the joined dataset anon from the three TDPs. This function abstracts this and returns the dataset.
+# """
+# # Create the temp query-able views
+# for c in dataset_info:
+# self.create_view(c[0], c[1])
+
+# # Query Execution
+# return_dataset = spark.sql(query).cache()
+# return_dataset = return_dataset.drop("_c0")
+# return_dataset_step1 = return_dataset.dropDuplicates()
+# return_dataset_step2 = self.dropDupeDfCols(return_dataset_step1)
+# drop_columns = joined_dataset_config["drop_columns"]
+# return_dataset_step3 = return_dataset_step2.drop(*drop_columns)
+# return_dataset_final = return_dataset_step3
+
+# return_dataset_final.toPandas().to_csv(save_path)
+
+# return return_dataset_final
+
+# def ccr_create_joined_dataset_wo_identifiers(
+# self, joined_dataset, joined_dataset_config, modelfile, debug=True
+# ):
+# identifiers = joined_dataset_config["identifiers"]
+# return_dataset = joined_dataset.drop(*identifiers)
+# # Modeling dataset
+# return_dataset.toPandas().to_csv(modelfile)
+
+# return return_dataset
+
+# def generate_data_info(self, spark, tdp_config_list):
+# """
+# Extracts the list of loaded data set along with its alias from TDP config.
+# Required for Spark view creation-create_view
+# """
+# lis = []
+# for c in tdp_config_list:
+# l = []
+# l.append(self.dp_load_data(spark, c["mount_path"], c["file"]))
+# l.append(c["name"])
+# lis.append(l)
+# return lis
+
+# def generate_base_query_dataset(self, dataset_info):
+# """
+# This will generate the base joined dataset for ccr_prepare_joined_dataset_full function.
+# """
+# # sandbox_icmr_cowin_index_linked_anon
+# file_str = "sandbox_"
+# for c in dataset_info:
+# file_str = file_str + c[1] + "_"
+# file_str = file_str + "linked_anon.csv"
+
+# return file_str
+
+# def execute(self, config):
+# """
+# Final Execution Function
+# """
+# tdp_config_list = self.load_tdp_list(config)
+# joined_dataset_config = self.load_joined_dataset_config(config)
+# spark = self.create_spark_context(
+# tdp_config_list
+# ) # currently treated as a global instance but can be converted into a specific instance for multiple pipelines
+# spark.sparkContext.setLogLevel("ERROR")
+# query = self.generate_query(tdp_config_list, joined_dataset_config)
+# dataset_info = self.generate_data_info(spark, tdp_config_list)
+# # sandbox_joined_anon_simplified=ccr_prepare_joined_dataset_full(dataset_info,query,joined_dataset_config["save_path"],debug=True)
+# save_path = joined_dataset_config["joined_dataset"]
+# sandbox_joined_anon_simplified = self.ccr_prepare_joined_dataset_full(
+# spark, dataset_info, joined_dataset_config, query, save_path, debug=False
+# )
+# dataset_names = [dataset["name"] for dataset in config["datasets"]]
+# print(f"Joined datasets {dataset_names} in {save_path}")
+# sandbox_joined_without_key_identifiers = (
+# self.ccr_create_joined_dataset_wo_identifiers(
+# sandbox_joined_anon_simplified,
+# joined_dataset_config,
+# save_path,
+# True,
+# )
+# )
+
+
+class SparkJoin(TaskBase):
+ """Utility class for loading, joining, and exporting Spark datasets based on config."""
+
+ # -------------------------------------------------------------------------
+ # Config Loaders
+ # -------------------------------------------------------------------------
+
+ def load_tdp_list(self, config):
+ """Extract list of dataset (TDP) configurations."""
+ return config.get("datasets", [])
+
+ def load_joined_dataset_config(self, config):
+ """Extract joined dataset configuration."""
+ return config.get("joined_dataset", {})
+
+ # -------------------------------------------------------------------------
+ # Spark Context
+ # -------------------------------------------------------------------------
+
+ def get_dataset_names(self, tdp_config_list):
+ """Return dataset names from config."""
+ return [c["name"] for c in tdp_config_list]
+
+ def create_spark_session(self, tdp_config_list):
+ """Create Spark session with app name based on dataset names."""
+ app_name = "_".join(self.get_dataset_names(tdp_config_list))
+ return SparkSession.builder.appName(app_name).getOrCreate()
+
+ # -------------------------------------------------------------------------
+ # Query Handling
+ # -------------------------------------------------------------------------
+
+ def validate_and_get_query(self, joined_dataset_config):
+ """Validate and return query string from config."""
+ query = joined_dataset_config.get("joining_query")
+ if not isinstance(query, str):
+ raise ValueError("joining_query must be a string")
+
+ banned_tokens = [";", "DROP ", "TRUNCATE ", "ALTER ", "CREATE ",
+ "INSERT ", "UPDATE ", "DELETE "]
+ if any(token in query.upper() for token in banned_tokens):
+ raise ValueError("Disallowed SQL construct in joining_query")
+
+ return query
+
+ # -------------------------------------------------------------------------
+ # Data I/O
+ # -------------------------------------------------------------------------
+
+ def read_data(
+ self,
+ spark,
+ path,
+ fmt
+ ):
+ """Read dataset from given path with specified format."""
+ fmt = fmt.lower()
+ if fmt == "csv":
+ return spark.read.csv(path, header=True, inferSchema=True, mode="DROPMALFORMED")
+ elif fmt == "parquet":
+ return spark.read.parquet(path)
+ elif fmt == "hdf5":
+ # Spark does not natively support HDF5 — fallback to pandas then parallelize
+ import pandas as pd
+ pdf = pd.read_hdf(path)
+ return spark.createDataFrame(pdf)
+ else:
+ raise ValueError(f"Unsupported format: {fmt}")
+
+ def write_data(
+ self,
+ df,
+ path,
+ fmt
+ ) -> None:
+ """Write dataset to given path with specified format."""
+ fmt = fmt.lower()
+ if fmt == "csv":
+ df.toPandas().to_csv(path, index=False)
+ elif fmt == "parquet":
+ df.write.mode("overwrite").parquet(path)
+ elif fmt == "hdf5":
+ import pandas as pd
+ df.toPandas().to_hdf(path, key="data", mode="w")
+ else:
+ raise ValueError(f"Unsupported format: {fmt}")
+
+ # -------------------------------------------------------------------------
+ # Data Processing
+ # -------------------------------------------------------------------------
+
+ def drop_duplicate_columns(self, df):
+ """Drop duplicate columns from DataFrame."""
+ seen, duplicates = [], []
+ for idx, col_name in enumerate(df.columns):
+ if col_name in seen:
+ duplicates.append(idx)
+ else:
+ seen.append(col_name)
+
+ df = df.toDF(*map(str, range(len(df.columns))))
+ for dup in duplicates:
+ df = df.drop(str(dup))
+ return df.toDF(*seen)
+
+ def create_temp_views(self, dataset_info):
+ """Register each dataset as a temporary SQL view."""
+ for df, alias in dataset_info:
+ df.createOrReplaceTempView(alias)
+
+ def prepare_joined_dataset(
+ self,
+ spark,
+ dataset_info,
+ joined_dataset_config,
+ query,
+ save_path,
+ fmt
+ ):
"""
- Drops Duplicate Columns from the dataframe passed
+ Create joined dataset by running SQL query and applying cleaning rules.
+ - Drops `_c0` (index col from CSVs).
+ - Drops duplicates.
+ - Removes explicitly configured columns.
+ - Saves to disk in requested format.
"""
+ self.create_temp_views(dataset_info)
- newcols = []
- dupcols = []
+ df = spark.sql(query).cache()
+ if "_c0" in df.columns:
+ df = df.drop("_c0")
- for i in range(len(df.columns)):
- if df.columns[i] not in newcols:
- newcols.append(df.columns[i])
- else:
- dupcols.append(i)
+ df = df.dropDuplicates()
+ df = self.drop_duplicate_columns(df)
- df = df.toDF(*[str(i) for i in range(len(df.columns))])
- for dupcol in dupcols:
- df = df.drop(str(dupcol))
- if debug:
- print("Debug |dropDupeDfCols")
- return df.toDF(*newcols)
+ drop_cols = joined_dataset_config.get("drop_columns", [])
+ if drop_cols:
+ df = df.drop(*drop_cols)
- def dp_load_data(self, spark, input_folder, data_file, load=True, debug=True):
- """
- Generic Data Loading Function at Data Provider
- """
-
- if load:
- input_file = input_folder + data_file
- if debug:
- print("Debug | input_file", input_file)
- data_loaded = spark.read.csv(
- input_file, header=True, inferSchema=True, mode="DROPMALFORMED"
- )
- if debug:
- print("Debug |dp_load_data | input_file", data_loaded.count())
- data_loaded.show(2)
- return data_loaded
-
- def create_view(self, dataset_name, view_name):
- """
- Create the temp query-able views for spark query processing
- """
- return dataset_name.createOrReplaceTempView(view_name)
+ self.write_data(df, save_path, fmt)
+ return df
- def ccr_prepare_joined_dataset(self, dataset_info, query, model_file, debug=True):
- """
- In CCR/Sandbox we are creating the joined dataset anon from the three TDP and joined data configurations . This function abstracts this and returns the dataet.
- """
- # Create the temp query-able views
- for c in dataset_info:
- self.create_view(c[0], c[1])
- # Query Execution
- retun_dataset = spark.sql(query).cache()
- retun_dataset = retun_dataset.drop("_c0")
- # Save the dataset
- retun_dataset.toPandas().to_csv(model_file)
- if debug:
- print("Debug | ccr_prepare_joined_dataset | Dataset Created ", model_file)
- retun_dataset.show(2)
- return retun_dataset
-
- def ccr_prepare_joined_dataset_full(
- self, spark, dataset_info, joined_dataset_config, query, model_file, debug=True
+ def remove_identifiers(
+ self,
+ df,
+ joined_dataset_config,
+ save_path,
+ fmt
):
- """
- In CCR/Sandbox we are creating the joined dataset anon from the three TDPs. This function abstracts this and returns the dataet.
- """
- # Create the temp query-able views
- for c in dataset_info:
- self.create_view(c[0], c[1])
-
- # Query Execution
- return_dataset = spark.sql(query).cache()
- return_dataset = return_dataset.drop("_c0")
- return_dataset_step1 = return_dataset.dropDuplicates()
- return_dataset_step2 = self.dropDupeDfCols(return_dataset_step1)
- drop_columns = joined_dataset_config["drop_columns"]
- return_dataset_step3 = return_dataset_step2.drop(*drop_columns)
- return_dataset_final = return_dataset_step3
-
- return_dataset_final.toPandas().to_csv(model_file)
-
- if debug:
- print(
- "Debug | ccr_prepare_joined_dataset_full|joint_dataset| count =",
- return_dataset.count(),
- )
- print(
- "Debug | ccr_prepare_joined_dataset_full|joint_dataset|step1 count =",
- return_dataset_step1.count(),
- )
- print(
- "Debug | ccr_prepare_joined_dataset_full|joint_dataset|step2 count =",
- return_dataset_step2.count(),
- )
- print(
- "Debug | ccr_prepare_joined_dataset_full|joint_dataset|step3 count =",
- return_dataset_step3.count(),
- )
- print(
- "Debug | ccr_prepare_joined_dataset_full|joint_dataset|final count =",
- return_dataset_final.count(),
- )
- return_dataset.show(2)
- print(
- "Debug | ccr_prepare_joined_dataset_full |Dataset Created ", model_file
- )
-
- return return_dataset_final
-
- def ccr_create_joined_dataset_wo_identifiers(
- self, joined_dataset, joined_dataset_config, modelfile, debug=True
+ """Remove identifier columns before modeling dataset creation."""
+ identifiers = joined_dataset_config.get("identifiers", [])
+ df = df.drop(*identifiers)
+ self.write_data(df, save_path, fmt)
+ return df
+
+ # -------------------------------------------------------------------------
+ # Dataset Info Builders
+ # -------------------------------------------------------------------------
+
+ def build_dataset_info(
+ self,
+ spark,
+ fmt,
+ tdp_config_list
):
- identifiers = joined_dataset_config["identifiers"]
- return_dataset = joined_dataset.drop(*identifiers)
- # Modeling dataset
- return_dataset.toPandas().to_csv(modelfile)
-
- if debug_poc:
- print(
- "Debug | ccr_create_joined_dataset_wo_identifiers|joint_dataset| count =",
- joined_dataset.count(),
- )
- print(
- "Debug | ccr_create_joined_dataset_wo_identifiers|return_dataset| count =",
- return_dataset.count(),
- )
- return_dataset.show(2)
-
- return return_dataset
-
- def generate_data_info(self, spark, tdp_config_list):
- """
- Extracts the list of loaded data set along with its alias from TDP config.
- Required for Spark view creation-create_view
- """
- lis = []
- for c in tdp_config_list:
- l = []
- l.append(self.dp_load_data(spark, c["mount_path"], c["file"]))
- l.append(c["name"])
- lis.append(l)
- return lis
-
- def generate_base_query_dataset(self, dataset_info):
- """
- This will generate the base joined dataset for ccr_prepare_joined_dataset_full function.
- """
- # sandbox_icmr_cowin_index_linked_anon
- file_str = "sandbox_"
- for c in dataset_info:
- file_str = file_str + c[1] + "_"
- file_str = file_str + "linked_anon.csv"
-
- if debug_poc:
- print("Debug | generate_base_query_dataset | file generated ", file_str)
-
- return file_str
+ """Load datasets and return list of (DataFrame, alias)."""
+ dataset_info = []
+ for conf in tdp_config_list:
+ path = os.path.join(conf["mount_path"], conf["file"])
+ df = self.read_data(spark, path, fmt)
+ dataset_info.append((df, conf["name"]))
+ return dataset_info
+
+ def generate_output_filename(self, dataset_info):
+ """Generate joined dataset filename."""
+ return "sandbox_" + "_".join([name for _, name in dataset_info]) + "_linked_anon.csv"
+
+ # -------------------------------------------------------------------------
+ # Orchestration
+ # -------------------------------------------------------------------------
def execute(self, config):
- """
- Final Execution Function
- """
+ """Run full Spark join pipeline based on config."""
tdp_config_list = self.load_tdp_list(config)
joined_dataset_config = self.load_joined_dataset_config(config)
- spark = self.create_spark_context(
- tdp_config_list
- ) # currently treated as a global instance but can be converted into a specific instance for multiple pipelines
- query = self.generate_query(tdp_config_list, joined_dataset_config)
- dataset_info = self.generate_data_info(spark, tdp_config_list)
- model_output_folder = joined_dataset_config["model_output_folder"]
- # sandbox_joined_anon_simplified=ccr_prepare_joined_dataset_full(dataset_info,query,joined_dataset_config["model_file"],debug=True)
- model_file = joined_dataset_config["joined_dataset"]
- sandbox_joined_anon_simplified = self.ccr_prepare_joined_dataset_full(
- spark, dataset_info, joined_dataset_config, query, model_file, debug=True
+
+ spark = self.create_spark_session(tdp_config_list)
+ spark.sparkContext.setLogLevel("ERROR")
+
+ query = self.validate_and_get_query(joined_dataset_config)
+
+ save_path = joined_dataset_config.get("joined_dataset")
+ fmt = save_path.split(".")[-1]
+
+ dataset_info = self.build_dataset_info(spark, fmt, tdp_config_list)
+
+ joined_dataset = self.prepare_joined_dataset(
+ spark, dataset_info, joined_dataset_config, query, save_path, fmt
)
- print("Generating aggregated data in " + model_output_folder + model_file)
- sandbox_joined_without_key_identifiers = (
- self.ccr_create_joined_dataset_wo_identifiers(
- sandbox_joined_anon_simplified,
- joined_dataset_config,
- model_output_folder + model_file,
- True,
- )
+
+ dataset_names = [d["name"] for d in tdp_config_list]
+ print(f"Joined datasets: {dataset_names}")
+
+ self.remove_identifiers(
+ joined_dataset,
+ joined_dataset_config,
+ save_path,
+ fmt
)
+
+
+class DirectoryJoin(TaskBase):
+
+ def join_datasets(self, config):
+ output_path = config["joined_dataset"]
+ Path(output_path).mkdir(parents=True, exist_ok=True)
+
+ for dataset in config["datasets"]:
+ dataset_path = dataset["mount_path"]
+ dataset_name = dataset["name"]
+
+ if os.path.isdir(dataset_path):
+ for root, dirs, files in os.walk(dataset_path):
+ rel_path = os.path.relpath(root, dataset_path)
+ target_root = os.path.join(output_path, rel_path)
+ os.makedirs(target_root, exist_ok=True)
+
+ for file in files:
+ src_file = os.path.join(root, file)
+ dst_file = os.path.join(target_root, file)
+
+ if not os.path.exists(dst_file):
+ # Avoid following symlinks
+ if os.path.islink(src_file):
+ continue
+ shutil.copy2(src_file, dst_file, follow_symlinks=False)
+ print(f"Merged dataset '{dataset_name}' into '{output_path}'")
+ else:
+ print(f"Dataset '{dataset_name}' is not a valid directory.")
+
+ print(f"\nAll datasets joined in: {output_path}")
+
+
+ def execute(self, config):
+ # Join the datasets
+ self.join_datasets(config)
+
diff --git a/src/train/pytrain/pipeline_executor.py b/src/train/pytrain/pipeline_executor.py
index 384f7e6..2cbf4a9 100644
--- a/src/train/pytrain/pipeline_executor.py
+++ b/src/train/pytrain/pipeline_executor.py
@@ -1,9 +1,26 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
import json
import argparse
from .task_base import TaskBase
-from .join import Join
-from .private_train import PrivateTrain
-from .train import Train
+from .join import *
+from .dl_train import Train_DL
+from .xgb_train import Train_XGB
class PipelineExecutor:
def __init__(self):
@@ -28,7 +45,7 @@ def execute_pipeline(self):
step_instance = step_class()
step_instance.execute(step_config)
else:
- print("Error: Class {step_name} not found or does not inherit from TaskBase.")
+ print(f"Error: Class {step_name} not found or does not inherit from TaskBase.")
def main():
diff --git a/src/train/pytrain/private_train.py b/src/train/pytrain/private_train.py
deleted file mode 100644
index fa6abd6..0000000
--- a/src/train/pytrain/private_train.py
+++ /dev/null
@@ -1,193 +0,0 @@
-# 2023, The DEPA CCR DP Training Reference Implementation
-# authors shyam@ispirt.in, sridhar.avs@ispirt.in
-#
-# Licensed TBD
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# Key references / Attributions: https://depa.world/training/reference-implementation
-# Key frameworks used : DEPA CCR,Opacus, PyTorch,ONNX, onnx2pytorch
-
-# torch related imports
-from typing import Optional
-import torch
-from torchvision import datasets, transforms
-
-# from tqdm import tqdm
-import torch.utils.data as data
-from torch.utils.data import DataLoader
-import torch.nn as nn
-import torch.optim as optim
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader
-
-# sklearn,pandas,numpy related imports
-from sklearn.model_selection import train_test_split
-from sklearn.preprocessing import StandardScaler
-from sklearn.metrics import accuracy_score
-import numpy as np
-import pandas as pd
-
-# opacus related imports
-from opacus.accountants import create_accountant
-from opacus import PrivacyEngine
-
-# onnx related imports
-import onnx
-from onnx2pytorch import ConvertModel
-
-# other imports
-import os
-import json
-import argparse
-from pathlib import Path
-
-from .task_base import TaskBase
-
-logger = {
- "epochs_per_report": 1,
- "metrics": [
- "tdp_config",
- "tdc_config",
- "model_architecture",
- "model_hyperparameters",
- "model_config",
- "accuracy",
- "precision",
- "recall",
- ],
- "ccr_pbt_logger_file": "/mnt/remote/output/ccr_depa_trg_model_logger.json",
-}
-
-def compute_delta(ccr_context):
- return 1 / ccr_context["sample_size"]
-
-
-class CustomDataset(Dataset):
- """
- Class to convert dataset columns to tensors
- """
-
- def __init__(self, features, target):
- self.features = torch.tensor(features, dtype=torch.float32)
- self.target = torch.tensor(target.values, dtype=torch.float32)
-
- def __len__(self):
- return len(self.features)
-
- def __getitem__(self, idx):
- return self.features[idx], self.target[idx]
-
-
-class PrivateTrain(TaskBase):
- """
- Args:
- cofig:training configuration
-
- Methods:
- load_data:loads data from csv as data loaders
- load_model:loads model object from model config
- load_optimizer:loads model optimizer from model config
- make_dprivate:make model,dataloader and optimizer private
- execute_model:mega function which includes all the above functions
-
- """
-
- def ccr_logger_function(ccr_tracking_object, ccr_model):
- """
- Function to implement logging for audit/model cert
- """
- file_path = ccr_tracking_object["ccr_pbt_logger_file"]
- with open(file_path, "w") as file:
- file.write("Model Architecture\n")
- string = str(ccr_model.model)
- file.write(string)
- for c in ccr_model.logger_list:
- file.write(c)
-
- def load_data(self, config):
- # path from config
- data = pd.read_csv(config["input_dataset_path"])
- features = data.drop(columns=[config["target_variable"]])
- target = data[config["target_variable"]]
- train_features, val_features, train_target, val_target = train_test_split(
- features,
- target,
- test_size=config["test_train_split"],
- random_state=42,
- )
- scaler = StandardScaler()
- self.train_features = scaler.fit_transform(train_features)
- self.val_features = scaler.transform(val_features)
-
- train_dataset = CustomDataset(self.train_features, train_target)
- val_dataset = CustomDataset(self.val_features, val_target)
-
- batch_size = config["batch_size"]
- self.train_loader = DataLoader(
- train_dataset, batch_size=batch_size, shuffle=True
- )
- self.val_loader = DataLoader(val_dataset, batch_size=batch_size)
-
- def load_model(self, config):
- onnx_model = onnx.load(config["saved_model_path"])
- model = ConvertModel(onnx_model, experimental=True)
- self.model = model
-
- def load_optimizer(self, config):
- # self.optimizer = optim.Adam(self.model.parameters(), lr=0.001)
- self.optimizer = optim.Adam(self.model.parameters())
- # optimizer=torch.load(self.model_config["saved_model_optimizer"])
- # self.optimizer=optimizer
-
- def make_dprivate(self, config):
- privacy_engine = PrivacyEngine()
- modules = privacy_engine.make_private_with_epsilon(
- module=self.model,
- optimizer=self.optimizer,
- data_loader=self.train_loader,
- target_epsilon=config["epsilon_threshold"],
- target_delta=config["delta"],
- epochs=config["total_epochs"],
- max_grad_norm=config["max_grad_norm"],
- batch_first="True",
- )
- self.model = modules[0]
- self.optimizer = modules[1]
- self.train_loader = modules[2]
-
- def train(self, config):
- self.logger_list = []
- criterion = nn.MSELoss()
- for epoch in range(config["total_epochs"]):
- for inputs, labels in self.train_loader:
- self.optimizer.zero_grad()
- outputs = self.model(inputs)
- loss = criterion(outputs, labels.unsqueeze(1))
- loss.backward()
- self.optimizer.step()
- self.logger_list.append(
- 'Epoch [{epoch+1}/{config["total_epochs"]}], Loss: {loss.item():.4f}'
- )
- print(
- f'Epoch [{epoch+1}/{config["total_epochs"]}], Loss: {loss.item():.4f}'
- )
- output_path = config["trained_model_output_path"]
- print("Writing training model to " + output_path)
- torch.onnx.export(
- self.model,
- torch.randn(1, self.train_features.shape[1]),
- output_path,
- verbose=True,
- )
-
- def execute(self, config):
- self.load_data(config)
- self.load_model(config)
- self.load_optimizer(config)
- self.make_dprivate(config)
- self.train(config)
diff --git a/src/train/pytrain/task_base.py b/src/train/pytrain/task_base.py
index a9ab838..0d57046 100644
--- a/src/train/pytrain/task_base.py
+++ b/src/train/pytrain/task_base.py
@@ -1,3 +1,19 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
class TaskBase:
def execute(self, config):
diff --git a/src/train/pytrain/train.py b/src/train/pytrain/train.py
deleted file mode 100644
index 52068ae..0000000
--- a/src/train/pytrain/train.py
+++ /dev/null
@@ -1,59 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.optim as optim
-
-# onnx related imports
-import onnx
-from onnx2pytorch import ConvertModel
-
-from .task_base import TaskBase
-
-class Train(TaskBase):
- def load_data(self, config):
- batch_size = config["batch_size"]
-
- # Load the dataset from a .pth file
- trainset = torch.load(config["input_dataset_path"])
-
- self.trainloader = torch.utils.data.DataLoader(
- trainset, batch_size=batch_size, shuffle=True, num_workers=2
- )
-
- def load_model(self, config):
- onnx_model = onnx.load(config["saved_model_path"])
- model = ConvertModel(onnx_model, experimental=True)
- self.model = model
-
- def load_optimizer(self, config):
- self.optimizer = optim.SGD(self.model.parameters(), lr=0.001, momentum=0.9)
-
- def train(self, config):
- criterion = nn.CrossEntropyLoss()
- for epoch in range(
- config["total_epochs"]
- ): # loop over the dataset multiple times
- running_loss = 0.0
- for i, data in enumerate(self.trainloader, 0):
- # get the inputs; data is a list of [inputs, labels]
- inputs, labels = data
-
- # zero the parameter gradients
- self.optimizer.zero_grad()
-
- # forward + backward + optimize
- outputs = self.model(inputs)
- loss = criterion(outputs, labels)
- loss.backward()
- self.optimizer.step()
-
- # print statistics
- running_loss += loss.item()
- if i % 2000 == 1999: # print every 2000 mini-batches
- print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
- running_loss = 0.0
-
- def execute(self, config):
- self.load_data(config)
- self.load_model(config)
- self.load_optimizer(config)
- self.train(config)
diff --git a/src/train/pytrain/utilities/__init__.py b/src/train/pytrain/utilities/__init__.py
new file mode 100644
index 0000000..ed4f153
--- /dev/null
+++ b/src/train/pytrain/utilities/__init__.py
@@ -0,0 +1,16 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
\ No newline at end of file
diff --git a/src/train/pytrain/utilities/dataset_constructor.py b/src/train/pytrain/utilities/dataset_constructor.py
new file mode 100644
index 0000000..2d430c9
--- /dev/null
+++ b/src/train/pytrain/utilities/dataset_constructor.py
@@ -0,0 +1,759 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import json
+import pandas as pd
+import numpy as np
+import torch
+from torch.utils.data import Dataset, random_split
+from sklearn.preprocessing import StandardScaler, MinMaxScaler, RobustScaler, Normalizer, OrdinalEncoder
+from sklearn.model_selection import train_test_split
+import cv2
+from PIL import Image
+from torchvision.transforms import ToTensor
+from glob import glob
+from pathlib import Path
+from typing import Dict, Any, Tuple, Optional, List, Union
+from safetensors.torch import load_file as st_load
+import h5py
+
def encode_categoricals(df: pd.DataFrame, cat_cols: list):
    """
    Encode categorical columns:
      - low-cardinality columns (nunique <= max(2, 1% of row count)) -> one-hot
        via pd.get_dummies (with an explicit NaN indicator column)
      - remaining high-cardinality columns -> a single OrdinalEncoder pass,
        with unseen categories mapped to -1
    Returns (encoded_df, new_column_names); encoded_df keeps df's index.
    """
    threshold = max(2, int(0.01 * len(df)))
    frame = df.copy()

    pieces = []
    new_columns = []
    ordinal_candidates = []

    for column in cat_cols:
        # dropna=False counts NaN as its own category, matching dummy_na=True below
        if frame[column].nunique(dropna=False) <= threshold:
            dummies = pd.get_dummies(frame[column].astype(str), prefix=column, dummy_na=True)
            pieces.append(dummies)
            new_columns.extend(dummies.columns.tolist())
        else:
            ordinal_candidates.append(column)

    if ordinal_candidates:
        encoder = OrdinalEncoder(handle_unknown="use_encoded_value", unknown_value=-1)
        # OrdinalEncoder requires 2-D input, so all high-card columns go in one pass
        codes = encoder.fit_transform(frame[ordinal_candidates].astype(str))
        ordinal_df = pd.DataFrame(
            codes,
            columns=[f"{column}_ord" for column in ordinal_candidates],
            index=frame.index,
        )
        pieces.append(ordinal_df)
        new_columns.extend(ordinal_df.columns.tolist())

    encoded = pd.concat(pieces, axis=1) if pieces else pd.DataFrame(index=frame.index)
    return encoded, new_columns
+
+
def build_feature_matrix(df: pd.DataFrame, num_cols: list, cat_encoded_df: pd.DataFrame):
    """Concatenate the numeric columns of df with the encoded-categorical frame.

    Both pieces are reset to a fresh RangeIndex before concatenation so the
    horizontal stack lines up row-by-row regardless of the original index.

    Raises:
        ValueError: when neither numeric columns nor encoded categoricals exist.
    """
    blocks = []
    if num_cols:
        blocks.append(df[num_cols].reset_index(drop=True))
    if not cat_encoded_df.empty:
        blocks.append(cat_encoded_df.reset_index(drop=True))
    if not blocks:
        raise ValueError("No features available after preprocessing.")
    return pd.concat(blocks, axis=1)
+
def create_dataset(config: Dict[str, Any], data_path: str) -> Dict[str, Dataset]:
    """
    Create all dataset splits based on configuration file

    Args:
        config: JSON configuration dictionary (the code calls config.get, so a
            dict is expected — the previous `str` annotation was incorrect)
        data_path: path to the dataset file or directory the config refers to

    Returns:
        Dict containing all available splits {'train': dataset, 'val': dataset, 'test': dataset}

    Raises:
        ValueError: when config['type'] is not 'tabular', 'directory' or 'serialized'.
    """

    # 'tabular' is the default when the config omits an explicit type
    dataset_type = config.get('type', 'tabular')

    if dataset_type == 'tabular':
        dataset = TabularDataset(config, data_path)
    elif dataset_type == 'directory':
        dataset = DirectoryDataset(config, data_path)
    elif dataset_type == 'serialized':
        dataset = SerializedDataset(config, data_path)
    else:
        raise ValueError(f"Unknown dataset type: {dataset_type}")

    return dataset.get_all_splits()
+
class BaseDataset:
    """Base class with common functionality for all dataset types.

    Holds the config, builds the optional sklearn scaler from
    config['preprocessing'], and splits loaded data into train/val/test
    according to config['splits']. Subclasses populate self.splits_data.
    """

    def __init__(self, config: Dict[str, Any]):
        # Raw configuration dictionary driving all subsequent behavior
        self.config = config
        # Optional per-sample transform; never assigned here (see _setup_transforms)
        self.transform = None
        # Optional sklearn scaler, or None when no scaling is configured
        self.scaler = None
        # Mapping split name ('train'/'val'/'test') -> dataset, filled by subclasses
        self.splits_data = {}

        # Initialize transforms if specified
        self._setup_transforms()
        self._setup_preprocessing()

    def _setup_transforms(self):
        """Setup data transforms based on config (currently a no-op placeholder)."""
        transform_config = self.config.get('transforms', {})
        # This is a placeholder - you can extend with actual transform implementations
        pass

    def _setup_preprocessing(self):
        """Setup preprocessing/normalization based on config.

        Recognized config['preprocessing']['scaler'] values: 'standard',
        'minmax', 'robust', 'normalizer'. Any other value (or none) leaves
        self.scaler as None, i.e. no scaling.
        """
        preprocessing_config = self.config.get('preprocessing', {})
        scaler_type = preprocessing_config.get('scaler', None)

        if scaler_type == 'standard':
            self.scaler = StandardScaler()
        elif scaler_type == 'minmax':
            self.scaler = MinMaxScaler()
        elif scaler_type == 'robust':
            self.scaler = RobustScaler()
        elif scaler_type == 'normalizer':
            self.scaler = Normalizer()

    def _create_splits(self, data, targets=None):
        """Create train/val/test splits based on config.

        Args:
            data: feature matrix (supervised) or indexable sequence of samples.
            targets: optional target array; when given, data and targets are
                split together (optionally stratified).

        Returns:
            dict mapping split names to (features, targets) tuples when targets
            are supplied, otherwise to lists of samples.
        """
        split_config = self.config.get('splits', {})

        if not split_config:
            # If no splits specified, return all data as train
            return {'train': (data, targets) if targets is not None else data}

        train_ratio = split_config.get('train', 1.0)
        val_ratio = split_config.get('val', 0.0)
        test_ratio = split_config.get('test', 0.0)

        # Ensure ratios sum to 1
        # NOTE(review): if all three ratios are 0 this raises ZeroDivisionError —
        # confirm configs are validated upstream.
        total = train_ratio + val_ratio + test_ratio
        train_ratio /= total
        val_ratio /= total
        test_ratio /= total

        if targets is not None:
            # For supervised learning
            splits = {}
            remaining_data = data
            remaining_targets = targets

            if test_ratio > 0:
                # Create test split first if specified
                remaining_data, X_test, remaining_targets, y_test = train_test_split(
                    data, targets, test_size=test_ratio,
                    random_state=split_config.get('random_state', 42),
                    stratify=targets if split_config.get('stratify', False) else None
                )
                splits['test'] = (X_test, y_test)

            if val_ratio > 0:
                # Create validation split if specified.
                # test_size is rescaled because the test portion was already removed.
                X_train, X_val, y_train, y_val = train_test_split(
                    remaining_data, remaining_targets,
                    test_size=val_ratio/(val_ratio + train_ratio),
                    random_state=split_config.get('random_state', 42),
                    stratify=remaining_targets if split_config.get('stratify', False) else None
                )
                splits['train'] = (X_train, y_train)
                splits['val'] = (X_val, y_val)
            else:
                # Just use remaining data as train
                splits['train'] = (remaining_data, remaining_targets)

            # Re-assemble in a stable train/val/test insertion order
            final_splits = {'train': splits['train']}
            if 'val' in splits:
                final_splits['val'] = splits['val']
            if 'test' in splits:
                final_splits['test'] = splits['test']

            return final_splits

        else:
            # For unsupervised learning or when targets are not separate:
            # split index lists, then materialize each split by indexing data
            splits = {}
            indices = list(range(len(data)))
            remaining_indices = indices

            if test_ratio > 0:
                # Create test split first if specified
                remaining_indices, test_indices = train_test_split(
                    indices, test_size=test_ratio,
                    random_state=split_config.get('random_state', 42)
                )
                splits['test'] = [data[i] for i in test_indices]

            if val_ratio > 0:
                # Create validation split if specified
                train_indices, val_indices = train_test_split(
                    remaining_indices,
                    test_size=val_ratio/(val_ratio + train_ratio),
                    random_state=split_config.get('random_state', 42)
                )
                splits['train'] = [data[i] for i in train_indices]
                splits['val'] = [data[i] for i in val_indices]
            else:
                # Just use remaining indices as train
                splits['train'] = [data[i] for i in remaining_indices]

            return splits

    def get_all_splits(self):
        """Return all available dataset splits (name -> dataset or tuple)."""
        return self.splits_data
+
class SplitDataset(Dataset):
    """Individual dataset for a specific split.

    Optionally scales the features (fitting the scaler only on the training
    split), then coerces features/targets to the representation selected by
    data_type ('numpy' -> ndarrays, 'tensor' -> float32 torch tensors).
    """

    def __init__(self, features, targets=None, transform=None, scaler=None, fit_scaler=False, data_type='numpy', encoding_info: Dict = None):
        self.transform = transform
        self.scaler = scaler
        self.data_type = data_type
        self.encoding_info = encoding_info or {}

        # Scale up-front; fit_scaler is True only for the training split
        if self.scaler:
            self.features = (self.scaler.fit_transform(features) if fit_scaler
                             else self.scaler.transform(features))
        else:
            self.features = features

        if self.data_type == 'tensor':
            # ndarray features become float32 tensors; ndarray targets are also
            # reshaped to (n, 1) to match the model's output shape. Non-ndarray
            # inputs are kept as-is.
            if isinstance(self.features, np.ndarray):
                self.features = torch.tensor(self.features, dtype=torch.float32)
            if targets is None:
                self.targets = None
            elif isinstance(targets, np.ndarray):
                self.targets = torch.tensor(targets, dtype=torch.float32).reshape(-1, 1)
            else:
                self.targets = targets

        if self.data_type == 'numpy':
            # Coerce both features and targets to ndarrays
            if not isinstance(self.features, np.ndarray):
                self.features = np.array(self.features)
            if targets is None:
                self.targets = None
            elif isinstance(targets, np.ndarray):
                self.targets = targets
            else:
                self.targets = np.array(targets)
        # NOTE(review): for any other data_type, self.targets is never assigned
        # and later access would raise AttributeError — confirm intended.

    def get_encoding_info(self) -> Dict:
        """Return a copy of the categorical-encoding metadata."""
        return self.encoding_info.copy()

    def get_feature_names(self) -> List[str]:
        """All feature names after encoding: numerical columns first, then encoded."""
        if not self.encoding_info:
            return []
        return (self.encoding_info.get('numerical_columns', [])
                + self.encoding_info.get('encoded_columns', []))

    def get_categorical_columns(self) -> List[str]:
        """Original categorical column names (before encoding)."""
        return self.encoding_info.get('categorical_columns', [])

    def get_numerical_columns(self) -> List[str]:
        """Numerical column names."""
        return self.encoding_info.get('numerical_columns', [])

    def get_encoded_columns(self) -> List[str]:
        """Encoded categorical column names (after one-hot/ordinal encoding)."""
        return self.encoding_info.get('encoded_columns', [])

    def get_feature_dimensions(self) -> Dict[str, int]:
        """Counts of total / numerical / categorical / encoded features."""
        return {
            'total_features': len(self.get_feature_names()),
            'numerical_features': len(self.get_numerical_columns()),
            'categorical_features': len(self.get_categorical_columns()),
            'encoded_features': len(self.get_encoded_columns())
        }

    def __len__(self):
        # Sequences and array-likes both answer len(); anything without
        # __len__ falls back to its first dimension via .size(0)
        if isinstance(self.features, (list, tuple)) or hasattr(self.features, '__len__'):
            return len(self.features)
        return self.features.size(0)

    def __getitem__(self, idx):
        # Plain indexing works identically for lists, ndarrays and tensors
        sample = self.features[idx]
        if self.transform:
            sample = self.transform(sample)
        if self.targets is None:
            return sample
        return sample, self.targets[idx]
+
class TabularDataset(BaseDataset):
    """Dataset for tabular data (CSV, Excel, parquet).

    Loads a single file, handles missing values per config['missing_strategy'],
    encodes categorical columns, assembles the feature matrix, and materialises
    one SplitDataset per configured split.
    """

    def __init__(self, config: Dict[str, Any], data_path: str):
        super().__init__(config)
        self.data_path = data_path
        self.target_variable = config.get('target_variable')
        self.feature_columns = config.get('feature_columns', None)
        self.data_type = config.get('data_type', 'numpy')
        self.categorical_columns = []
        self.numerical_columns = []
        self.encoding_info = {}
        self._load_and_split_data()

    def _load_and_split_data(self):
        """Load and preprocess tabular data, then create splits.

        Raises:
            ValueError: for unsupported file extensions.
        """
        file_ext = Path(self.data_path).suffix.lower()

        # Load data based on file extension
        if file_ext == '.csv':
            data = pd.read_csv(self.data_path)
        elif file_ext in ['.xlsx', '.xls']:
            data = pd.read_excel(self.data_path)
        elif file_ext == '.parquet':
            data = pd.read_parquet(self.data_path)
        else:
            raise ValueError(f"Unsupported file format: {file_ext}")

        # Handle missing values ('drop' is the default strategy)
        missing_strategy = self.config.get('missing_strategy', 'drop')
        if missing_strategy == 'drop':
            data = data.dropna()
        elif missing_strategy == 'fill':
            fill_value = self.config.get('fill_value', 0)
            data = data.fillna(fill_value)
        elif missing_strategy == 'forward_fill':
            # fillna(method=...) is deprecated and removed in pandas 3.0;
            # ffill()/bfill() are the supported equivalents
            data = data.ffill()
        elif missing_strategy == 'backward_fill':
            data = data.bfill()

        # Select features: an explicit column list wins; otherwise everything
        # except the target column (or every column when no target is set)
        if self.feature_columns:
            features = data[self.feature_columns]
        elif self.target_variable:
            features = data.drop(columns=[self.target_variable])
        else:
            features = data

        # Extract target
        targets = data[self.target_variable].values if self.target_variable else None

        # Identify categorical and numerical columns
        # NOTE(review): only int64/float64 are treated as numerical — int32/float32
        # columns (e.g. from parquet) would be silently dropped; confirm intended.
        self.categorical_columns = features.select_dtypes(include=['object', 'category']).columns.tolist()
        self.numerical_columns = features.select_dtypes(include=['int64', 'float64']).columns.tolist()

        # Encode categorical features and record how each column was handled
        if self.categorical_columns:
            encoded_df, kept_cols = encode_categoricals(features, self.categorical_columns)
            self.encoding_info['encoded_columns'] = kept_cols
            self.encoding_info['categorical_columns'] = self.categorical_columns
            self.encoding_info['numerical_columns'] = self.numerical_columns
        else:
            encoded_df = pd.DataFrame(index=features.index)
            self.encoding_info['encoded_columns'] = []
            self.encoding_info['categorical_columns'] = []
            self.encoding_info['numerical_columns'] = self.numerical_columns

        # Build the final feature matrix (numeric columns + encoded categoricals)
        X = build_feature_matrix(features, self.numerical_columns, encoded_df)

        # Create splits
        splits = self._create_splits(X, targets)

        # Create dataset objects for each split; the scaler is fit on train only
        for split_name, (split_features, split_targets) in splits.items():
            fit_scaler = (split_name == 'train')
            self.splits_data[split_name] = SplitDataset(
                features=split_features,
                targets=split_targets,
                transform=self.transform,
                scaler=self.scaler,
                fit_scaler=fit_scaler,
                data_type=self.data_type,
                encoding_info=self.encoding_info
            )
+
class DirectoryDataset(BaseDataset):
    """Dataset for directory-based data (images, documents, etc.)

    Supports three layouts via config['structure_type']: 'flat' (files in one
    folder), 'nested' (one sub-folder per class), and 'paired' (input/target
    file pairs, e.g. image/mask).
    """

    def __init__(self, config: Dict[str, Any], data_path: str):
        super().__init__(config)
        self.data_dir = data_path
        # Glob pattern used to select files within the directory tree
        self.file_pattern = config.get('file_pattern', '*')
        self.data_type = config.get('data_type', 'image')

        self._load_and_split_data()

    def _load_and_split_data(self):
        """Load directory-based data and create splits.

        Raises:
            ValueError: for an unknown config['structure_type'].
        """
        structure_type = self.config.get('structure_type', 'flat')

        if structure_type == 'flat':
            data = self._load_flat_structure()
        elif structure_type == 'nested':
            data = self._load_nested_structure()
        elif structure_type == 'paired':
            data = self._load_paired_structure()
        else:
            raise ValueError(f"Unknown structure type: {structure_type}")

        # Create splits
        splits = self._create_splits(data)

        # Create dataset objects for each split; files are loaded lazily per item
        for split_name, split_data in splits.items():
            self.splits_data[split_name] = DirectorySplitDataset(
                split_data,
                data_type=self.data_type,
                config=self.config,
                transform=self.transform
            )

    def _load_flat_structure(self):
        """Load files from flat directory structure (returns a list of paths)."""
        return glob(os.path.join(self.data_dir, self.file_pattern))

    def _load_nested_structure(self):
        """Load files from nested directory structure (e.g., class folders).

        Returns a list of (file_path, class_index) tuples; class indices come
        from the alphabetical order of the sub-folder names.
        """
        class_folders = [d for d in os.listdir(self.data_dir)
                        if os.path.isdir(os.path.join(self.data_dir, d))]

        samples = []
        self.class_to_idx = {cls: idx for idx, cls in enumerate(sorted(class_folders))}

        for class_name in class_folders:
            class_dir = os.path.join(self.data_dir, class_name)
            file_paths = glob(os.path.join(class_dir, self.file_pattern))

            for file_path in file_paths:
                samples.append((file_path, self.class_to_idx[class_name]))

        return samples

    def _load_paired_structure(self):
        """Load paired files (e.g., image-mask pairs like BRATS).

        Returns a list of (input_path, target_path) tuples.
        NOTE(review): only the 'folder_pattern' layout is implemented; without
        pairing.folder_pattern this returns an empty list — confirm intended.
        """
        pairing_config = self.config.get('pairing', {})
        input_pattern = pairing_config.get('input_pattern', '*')
        target_pattern = pairing_config.get('target_pattern', '*')

        samples = []

        # Handle BRATS-like structure
        if 'folder_pattern' in pairing_config:
            folder_pattern = pairing_config['folder_pattern']
            patient_folders = glob(os.path.join(self.data_dir, folder_pattern))

            for patient_folder in patient_folders:
                patient_id = os.path.basename(patient_folder)
                input_files = sorted(glob(os.path.join(patient_folder,
                    input_pattern.replace('*', patient_id))))

                for input_file in input_files:
                    # Extract identifier for pairing
                    base_name = os.path.basename(input_file)
                    identifier = self._extract_identifier(base_name, patient_id, pairing_config)

                    # Find corresponding target file
                    target_file = os.path.join(patient_folder,
                        target_pattern.replace('*', patient_id).replace('{id}', identifier))

                    if os.path.exists(target_file):
                        # Optional filtering (e.g., non-empty masks)
                        if self._should_include_sample(input_file, target_file):
                            samples.append((input_file, target_file))

        return samples

    def _extract_identifier(self, filename: str, patient_id: str, pairing_config: Dict) -> str:
        """Extract identifier for file pairing.

        Strips the configured prefix/suffix from the filename; the remainder is
        substituted for '{id}' in the target pattern.
        """
        # Remove patient ID and extract slice/identifier info
        identifier_pattern = pairing_config.get('identifier_extraction', {})

        if 'remove_prefix' in identifier_pattern:
            filename = filename.replace(identifier_pattern['remove_prefix'], '')
        if 'remove_suffix' in identifier_pattern:
            filename = filename.replace(identifier_pattern['remove_suffix'], '')

        return filename

    def _should_include_sample(self, input_file: str, target_file: str) -> bool:
        """Determine if sample should be included (e.g., non-empty masks).

        With filtering.filter_empty_targets set and image data, the target is
        read with cv2 and rejected when it is entirely zero.
        """
        filter_config = self.config.get('filtering', {})

        if filter_config.get('filter_empty_targets', False):
            if self.data_type == 'image':
                target_img = cv2.imread(target_file)
                return not np.all(target_img == 0)

        return True
+
class DirectorySplitDataset(Dataset):
    """Individual dataset for directory-based data splits.

    Entries are bare paths, (path, label) pairs, or (input_path, target_path)
    pairs; files are loaded lazily on __getitem__ according to data_type.
    """

    def __init__(self, data, data_type='image', config=None, transform=None):
        self.data = data
        self.data_type = data_type
        self.config = config or {}
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]

        # A bare path yields a single loaded sample
        if not (isinstance(entry, tuple) and len(entry) == 2):
            return self._load_file(entry)

        first, second = entry
        if isinstance(second, (int, float)):
            # (file_path, class_label)
            return self._load_file(first), second
        # (input_path, target_path) — e.g. image/mask pairs
        return self._load_file(first), self._load_file(second)

    def _load_file(self, file_path: str):
        """Dispatch loading on self.data_type; unknown types return raw bytes."""
        loaders = {
            'image': self._load_image,
            'text': self._load_text,
            'audio': self._load_audio,
        }
        loader = loaders.get(self.data_type)
        if loader is not None:
            return loader(file_path)
        with open(file_path, 'rb') as f:
            return f.read()

    def _load_image(self, image_path: str):
        """Load one image honouring image_config (cv2/PIL, grayscale, tensor
        conversion, binarisation), then apply the optional transform."""
        opts = self.config.get('image_config', {})

        if opts.get('use_cv2', False):
            img = cv2.imread(image_path)
            if opts.get('convert_to_pil', False):
                img = Image.fromarray(img)
        else:
            img = Image.open(image_path)

        if opts.get('grayscale', False):
            # PIL images convert via .convert('L'); raw cv2 arrays via cvtColor
            img = img.convert('L') if hasattr(img, 'convert') else cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        if opts.get('to_tensor', True):
            img = ToTensor()(img)

        if opts.get('binarize', False):
            img = (img > opts.get('binarize_threshold', 0)).float()

        if self.transform:
            img = self.transform(img)

        return img

    def _load_text(self, text_path: str):
        """Read a text file using text_config.encoding (default utf-8)."""
        encoding = self.config.get('text_config', {}).get('encoding', 'utf-8')
        with open(text_path, 'r', encoding=encoding) as f:
            return f.read()

    def _load_audio(self, audio_path: str):
        """Audio loading is not implemented (would need librosa or similar)."""
        raise NotImplementedError("Audio loading requires additional dependencies")
+
class SerializedDataset(BaseDataset):
    """Dataset for serialized data (safetensors, HDF5, parquet)."""

    def __init__(self, config: Dict[str, Any], data_path: str):
        super().__init__(config)
        self.data_path = data_path
        self.serialization_format = config.get('format', 'torch')

        self._load_and_split_data()

    def _load_and_split_data(self):
        """Load serialized data and create splits using safe formats only.

        Supported formats:
        - safetensors: expects tensor keys (default 'features', 'targets')
        - hdf5: expects dataset keys provided via config ('features_key', 'targets_key')
        - parquet: expects columns provided via config ('features_key', 'targets_key') or defaults

        Deprecated/disabled formats for security: pickle, raw torch .pt/.pth (use safetensors instead).

        Raises:
            RuntimeError: for the disabled pickle / raw-torch formats.
            KeyError: when the expected feature/target keys are missing.
            ValueError: for unknown formats/structures or mismatched lengths.
        """
        fmt = self.serialization_format.lower()
        structure = self.config.get('structure', 'list_of_tuples')
        features_key = self.config.get('features_key', 'features')
        targets_key = self.config.get('targets_key', 'targets')

        if fmt in ('pt', 'pth', 'torch'):
            raise RuntimeError("Loading raw PyTorch checkpoint files is disabled. Please export as safetensors, hdf5, or parquet.")
        if fmt in ('pkl', 'pickle'):
            raise RuntimeError("Loading datasets via pickle is disabled for security. Use safetensors, hdf5, or parquet.")

        data = None

        if fmt == 'safetensors':
            tensors = st_load(self.data_path, device='cpu')
            if features_key not in tensors or targets_key not in tensors:
                raise KeyError(f"safetensors file must contain '{features_key}' and '{targets_key}' keys")
            feats = tensors[features_key]
            targs = tensors[targets_key]
            if structure == 'list_of_tuples':
                n = feats.shape[0]
                if targs.shape[0] != n:
                    raise ValueError("features and targets must have the same first dimension")
                data = [(feats[i], targs[i]) for i in range(n)]
            elif structure in ('dict', 'separate_tensors'):
                self.dataset = {features_key: feats, targets_key: targs}
            else:
                raise ValueError(f"Unsupported structure '{structure}' for safetensors")

        elif fmt == 'hdf5':
            with h5py.File(self.data_path, 'r') as f:
                if features_key not in f or targets_key not in f:
                    raise KeyError(f"HDF5 file must contain '{features_key}' and '{targets_key}' datasets")
                feats = f[features_key][...]
                targs = f[targets_key][...]
                if structure == 'list_of_tuples':
                    if len(feats) != len(targs):
                        raise ValueError("features and targets must have the same length")
                    data = list(zip(feats, targs))
                elif structure in ('dict', 'separate_tensors'):
                    self.dataset = {features_key: feats, targets_key: targs}
                else:
                    raise ValueError(f"Unsupported structure '{structure}' for hdf5")

        elif fmt == 'parquet':
            # Read with pandas; require feature/target column names
            df = pd.read_parquet(self.data_path)
            if features_key not in df.columns or targets_key not in df.columns:
                raise KeyError(f"Parquet file must contain columns '{features_key}' and '{targets_key}'")
            if structure == 'list_of_tuples':
                data = list(zip(df[features_key].to_list(), df[targets_key].to_list()))
            elif structure in ('dict', 'separate_tensors'):
                self.dataset = {features_key: df[features_key].to_numpy(), targets_key: df[targets_key].to_numpy()}
            else:
                raise ValueError(f"Unsupported structure '{structure}' for parquet")

        else:
            raise ValueError(f"Unsupported serialization format: {self.serialization_format}")

        # Normalise the loaded container into a flat list of (input, target)
        # pairs where required. BUGFIX: the 'dict' branch previously indexed
        # self.dataset with hard-coded 'data'/'targets' keys, which always
        # raised KeyError because the loaders above store under
        # features_key/targets_key.
        if structure == 'list_of_tuples':
            if data is None:
                if isinstance(self.dataset, dict):
                    feats = self.dataset.get(features_key)
                    targs = self.dataset.get(targets_key)
                    if feats is None or targs is None:
                        raise KeyError("Missing features/targets in dataset for list_of_tuples structure")
                    n = len(feats)
                    if len(targs) != n:
                        raise ValueError("features and targets must have the same length")
                    data = [(feats[i], targs[i]) for i in range(n)]
                else:
                    data = self.dataset
            # else: data already prepared above
        elif structure in ('dict', 'separate_tensors'):
            # Both store separate containers keyed by features_key/targets_key
            data = list(zip(self.dataset[features_key], self.dataset[targets_key]))

        # Create splits
        splits = self._create_splits(data)

        # Create dataset objects for each split
        for split_name, split_data in splits.items():
            self.splits_data[split_name] = SerializedSplitDataset(
                split_data,
                transform=self.transform
            )
+
class SerializedSplitDataset(Dataset):
    """Individual dataset for serialized data splits.

    Each entry is either an (input, target) pair or a single sample; the
    optional transform is applied only to the input of a pair.
    """

    def __init__(self, data, transform=None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        entry = self.data[idx]

        # Non-pair entries are returned untouched (no transform applied)
        if not (isinstance(entry, tuple) and len(entry) == 2):
            return entry

        inputs, target = entry
        if self.transform:
            inputs = self.transform(inputs)
        return inputs, target
\ No newline at end of file
diff --git a/src/train/pytrain/utilities/dp_xgboost.py b/src/train/pytrain/utilities/dp_xgboost.py
new file mode 100644
index 0000000..23b7f13
--- /dev/null
+++ b/src/train/pytrain/utilities/dp_xgboost.py
@@ -0,0 +1,281 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+from typing import Any, Dict, Optional, Sequence, Tuple
+import json
+import math
+import logging
+
+import numpy as np
+import xgboost as xgb
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
class DPXGBoost:
    """Wrapper around ``xgb.train`` that can perturb trained leaf values with
    differential-privacy noise (leaf-level output perturbation).

    When ``privacy_params`` is supplied, a noisy copy of every leaf value is
    computed after training and ``predict(..., dp=True)`` aggregates those
    noisy leaves instead of delegating to the booster.
    """

    def __init__(self, xgb_params, privacy_params=None):
        """
        xgb_params: config passed to xgb.train
        privacy_params: optional config.
            mechanism: 'laplace' or 'gaussian'
            epsilon: float (required)
            delta: float (required when gaussian)
            clip_value: float (required)
        """
        self.xgb_params = xgb_params
        self.privacy_params = privacy_params
        self.privacy_enabled = privacy_params is not None

        # Validate privacy params only if enabled
        if self.privacy_enabled:
            mech = self.privacy_params.get('mechanism', 'gaussian')
            if mech not in ('laplace', 'gaussian'):
                raise ValueError("mechanism must be 'laplace' or 'gaussian'")
            if 'epsilon' not in self.privacy_params:
                raise ValueError('epsilon is required in privacy_params when privacy enabled')
            if mech == 'gaussian' and 'delta' not in self.privacy_params:
                raise ValueError('delta is required for gaussian mechanism')
            if 'clip_value' not in self.privacy_params:
                raise ValueError('clip_value is required in privacy_params')

        # bookkeeping
        self.bst: Optional[xgb.Booster] = None
        # (tree_idx, leaf_nodeid) -> leaf value; noisy and original variants
        self._noisy_leaf_values: Optional[Dict[Tuple[int, int], float]] = None
        self._orig_leaf_values: Optional[Dict[Tuple[int, int], float]] = None
        self._trained_num_trees: Optional[int] = None

    # Delegation: let wrapper expose Booster attributes once a booster exists
    def __getattr__(self, name: str):
        if name.startswith("_"):
            raise AttributeError(name)
        # Read 'bst' from __dict__ rather than via self.bst: if the instance
        # has no 'bst' attribute yet (e.g. during unpickling/copying before
        # __init__ has populated state), self.bst would re-enter __getattr__
        # and recurse forever.
        bst = self.__dict__.get("bst")
        if bst is not None and hasattr(bst, name):
            return getattr(bst, name)
        raise AttributeError(f"{self.__class__.__name__} has no attribute {name}")

    def fit(self,
            X: Any = None,
            y: Any = None,
            num_boost_round: int = 100,
            dtrain: Optional[xgb.DMatrix] = None,
            evals: Optional[Sequence[Tuple[xgb.DMatrix, str]]] = None,
            **train_kwargs) -> xgb.Booster:
        """
        Train via xgb.train. If privacy enabled, compute noisy leaf mapping.

        Either pass a ready-made ``dtrain`` DMatrix or raw ``X``/``y``.
        Returns the trained Booster.
        """
        if dtrain is None:
            if X is None or y is None:
                raise ValueError("Provide either dtrain or X and y")
            dtrain = xgb.DMatrix(X, label=y)

        if self.privacy_enabled:
            logger.info("Privacy enabled: DP noise will be computed after training.")
        else:
            logger.info("Privacy disabled: running standard non-DP training.")

        # Train using xgboost's train (delegates heavy training to xgboost)
        self.bst = xgb.train(self.xgb_params, dtrain, num_boost_round=num_boost_round, evals=evals, **train_kwargs)

        # cache number of trees (dump length)
        dump = self.bst.get_dump(dump_format='json')
        self._trained_num_trees = len(dump)
        logger.info("Trained %d trees", self._trained_num_trees)

        # original leaf values
        self._orig_leaf_values = self._extract_leaf_values(self.bst)

        # compute noisy mapping only if enabled, otherwise copy original map
        if self.privacy_enabled:
            self._noisy_leaf_values = self._compute_dp_leaf_noise(self.bst, dtrain)
        else:
            self._noisy_leaf_values = dict(self._orig_leaf_values)

        return self.bst

    def predict(self, X: Any, dp: bool = False, **predict_kwargs) -> np.ndarray:
        """
        If dp=False: delegate to underlying Booster.predict.
        If dp=True:
            - If privacy disabled -> returns same as non-DP predictions (we use noisy_map==orig_map).
            - If privacy enabled -> compute predictions by summing noisy leaf values.
        NOTE: For binary classification with objective 'binary:logistic', booster.predict(dp=False)
        returns probabilities, whereas predict(dp=True) returns the raw margin from base_score +
        sum(eta * noisy_leaf). Convert with sigmoid if you need probability-like output.
        """
        if self.bst is None:
            raise ValueError("Model not trained yet. Call fit(...) first.")

        dmat = xgb.DMatrix(X)
        if not dp:
            return self.bst.predict(dmat, **predict_kwargs)

        # dp prediction (noisy leaf aggregation)
        if self._noisy_leaf_values is None or self._orig_leaf_values is None:
            raise ValueError("No leaf mappings available. Call fit(...) first.")

        leaves = np.asarray(self.bst.predict(dmat, pred_leaf=True))
        if leaves.ndim == 1:
            # xgboost returns a 1-D array for a single tree; normalize to 2-D
            # so the unpack below works for any model size.
            leaves = leaves.reshape(-1, 1)
        n_samples, n_trees = leaves.shape
        if self._trained_num_trees is not None and n_trees != self._trained_num_trees:
            logger.warning("pred_leaf reported %d trees but trained had %d", n_trees, self._trained_num_trees)

        # NOTE(review): xgboost's dumped leaf values normally already include
        # eta shrinkage; multiplying by lr again below may double-apply it.
        # Verify against booster margins before changing.
        lr = float(self.xgb_params.get("eta", self.xgb_params.get("learning_rate", 0.3)))
        base = float(self.xgb_params.get("base_score", 0.5))

        preds = np.full(n_samples, base, dtype=float)

        for t in range(n_trees):
            # build vector of leaf values for tree t
            tree_vals = np.zeros(n_samples, dtype=float)
            for i in range(n_samples):
                leaf_id = int(leaves[i, t])
                key = (t, leaf_id)
                val = self._noisy_leaf_values.get(key)
                if val is None:
                    # fall back to the un-noised value for leaves never noised
                    val = self._orig_leaf_values.get(key, 0.0)
                tree_vals[i] = val
            preds += lr * tree_vals

        return preds

    def save_model(self, path: str) -> None:
        """Persist the underlying booster; raises if nothing was trained."""
        if self.bst is None:
            raise ValueError("No trained model to save")
        return self.bst.save_model(path)

    def load_model(self, path: str) -> xgb.Booster:
        """Load a booster from disk. Noisy leaf values are NOT restored;
        only the original (un-noised) mapping is rebuilt."""
        self.bst = xgb.Booster()
        self.bst.load_model(path)
        dump = self.bst.get_dump(dump_format='json')
        self._trained_num_trees = len(dump)
        self._orig_leaf_values = self._extract_leaf_values(self.bst)
        self._noisy_leaf_values = None
        return self.bst

    def get_booster(self) -> Optional[xgb.Booster]:
        """Return the wrapped Booster (or None before fit/load)."""
        return self.bst

    # ----------------
    # Internal helpers
    # ----------------
    def _extract_leaf_values(self, booster: xgb.Booster) -> Dict[Tuple[int, int], float]:
        """
        Return mapping (tree_idx, nodeid) -> leaf_value.
        Prefer trees_to_dataframe, fallback to JSON dump parsing.
        """
        mapping: Dict[Tuple[int, int], float] = {}
        try:
            df = booster.trees_to_dataframe()
            if "Leaf" in df.columns:
                leaf_rows = df.loc[df["Leaf"].notnull(), ["Tree", "Node", "Leaf"]]
                for _, row in leaf_rows.iterrows():
                    t = int(row["Tree"])
                    nodeid = int(row["Node"])
                    val = float(row["Leaf"])
                    mapping[(t, nodeid)] = val
            if mapping:
                return mapping
        except Exception:
            logger.debug("trees_to_dataframe failed; falling back to JSON parse", exc_info=True)

        dump = booster.get_dump(dump_format="json")
        for t, tree_json in enumerate(dump):
            try:
                tree = json.loads(tree_json)
            except Exception:
                continue

            def _walk(node):
                # Record leaves; recurse into whichever child representation
                # this dump format uses ('children' list or yes/no/missing).
                if "leaf" in node:
                    nodeid = int(node.get("nodeid", -1))
                    mapping[(t, nodeid)] = float(node["leaf"])
                else:
                    if "children" in node and isinstance(node["children"], list):
                        for c in node["children"]:
                            _walk(c)
                    else:
                        for k in ("yes", "no", "missing"):
                            child = node.get(k)
                            if isinstance(child, dict):
                                _walk(child)

            if isinstance(tree, dict):
                _walk(tree)
            elif isinstance(tree, list):
                for node in tree:
                    _walk(node)

        return mapping

    def _compute_dp_leaf_noise(self, booster: xgb.Booster, dtrain: xgb.DMatrix) -> Dict[Tuple[int, int], float]:
        """
        Compute noisy leaf mapping using a simple per-tree uniform budget split.
        Sensitivity approximated as: clip_value / (leaf_count * min_hessian + l2_reg)
        """
        if self._orig_leaf_values is None:
            self._orig_leaf_values = self._extract_leaf_values(booster)

        leaves = np.asarray(booster.predict(dtrain, pred_leaf=True))
        if leaves.ndim == 1:
            # single tree -> 1-D output; normalize to (n_samples, 1)
            leaves = leaves.reshape(-1, 1)
        n_samples, n_trees = leaves.shape
        num_trees = n_trees

        # counts per (tree, leaf) -- the leaf count drives the sensitivity estimate
        leaf_counts: Dict[Tuple[int, int], int] = {}
        for i in range(n_samples):
            for t in range(num_trees):
                leafid = int(leaves[i, t])
                key = (t, leafid)
                leaf_counts[key] = leaf_counts.get(key, 0) + 1

        mech = self.privacy_params.get("mechanism", "gaussian")
        epsilon = float(self.privacy_params["epsilon"])
        delta = float(self.privacy_params.get("delta", 0.0))
        clip = float(self.privacy_params["clip_value"])
        min_hessian = float(self.privacy_params.get("min_hessian", 1.0))
        budget_alloc = self.privacy_params.get("budget_allocation", "uniform")
        privacy_seed = self.privacy_params.get("privacy_seed", None)

        l2_reg = float(self.xgb_params.get("lambda", self.xgb_params.get("reg_lambda", 1.0)))

        if budget_alloc != "uniform":
            logger.warning("Only uniform budget_allocation supported; falling back to uniform")
        eps_per_tree = epsilon / max(1, num_trees)
        delta_per_tree = (delta / max(1, num_trees)) if mech == "gaussian" and delta > 0 else None

        rng = np.random.default_rng(privacy_seed)

        noisy_map: Dict[Tuple[int, int], float] = {}
        for (t, leafid), orig_val in self._orig_leaf_values.items():
            count = leaf_counts.get((t, leafid), 0)
            denom = count * min_hessian + l2_reg
            if denom <= 0:
                denom = 1e-6
            sensitivity = clip / denom

            if mech == "laplace":
                scale = sensitivity / max(1e-12, eps_per_tree)
                noise = rng.laplace(0.0, scale)
            else:
                if delta_per_tree is None or delta_per_tree <= 0:
                    raise ValueError("delta must be positive for gaussian mechanism")
                # Classic Gaussian-mechanism calibration:
                # sigma >= sqrt(2 ln(1.25/delta)) * sensitivity / epsilon
                sigma = math.sqrt(2.0 * math.log(1.25 / delta_per_tree)) * sensitivity / max(1e-12, eps_per_tree)
                noise = rng.normal(0.0, sigma)

            noisy_map[(t, leafid)] = float(orig_val + noise)

        logger.info("Computed noisy leaf mapping for %d trees", num_trees)
        return noisy_map
\ No newline at end of file
diff --git a/src/train/pytrain/utilities/eval_tools.py b/src/train/pytrain/utilities/eval_tools.py
new file mode 100644
index 0000000..504896a
--- /dev/null
+++ b/src/train/pytrain/utilities/eval_tools.py
@@ -0,0 +1,384 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import json
+from typing import Any, Dict, List, Union
+from collections import defaultdict
+
+import numpy as np
+import torch
+import matplotlib.pyplot as plt
+from sklearn.metrics import (
+ accuracy_score, f1_score, precision_score, recall_score,
+ roc_auc_score, confusion_matrix, classification_report,
+ precision_recall_curve, roc_curve,
+ mean_squared_error, mean_absolute_error, r2_score
+)
+
+# --- JSON serialization helper ---
+def _to_json_safe(obj: Any):
+ """Recursively convert NumPy/Torch scalars/arrays to Python-native types for JSON serialization."""
+ if isinstance(obj, (np.floating, np.integer)):
+ return obj.item()
+ if isinstance(obj, np.ndarray):
+ return obj.tolist()
+ try:
+ if isinstance(obj, torch.Tensor):
+ return obj.detach().cpu().tolist()
+ except Exception:
+ pass
+ if isinstance(obj, dict):
+ return {k: _to_json_safe(v) for k, v in obj.items()}
+ if isinstance(obj, (list, tuple)):
+ return [_to_json_safe(v) for v in obj]
+ if isinstance(obj, (float, int, str, bool)) or obj is None:
+ return obj
+ # Fallback: try float conversion, else string
+ try:
+ return float(obj)
+ except Exception:
+ return str(obj)
+
+# --- Helpers for segmentation/regression etc. ---
def dice_score_np(y_pred: np.ndarray, y_true: np.ndarray, eps: float = 1e-6, threshold: float = 0.5) -> float:
    """Dice coefficient between thresholded prediction and ground-truth masks.

    Both inputs are binarized with ``> threshold``. Masks are kept as float64
    0/1 indicators: the previous ``uint8 * 255`` encoding overflowed on the
    elementwise product (255*255 wraps to 1 in uint8), deflating the score by
    a factor of 255.
    """
    p = (y_pred > threshold).astype(np.float64)
    t = (y_true > threshold).astype(np.float64)
    inter = (p * t).sum()
    return float((2.0 * inter) / (p.sum() + t.sum() + eps))
+
def jaccard_index_np(y_pred: np.ndarray, y_true: np.ndarray, eps: float = 1e-6, threshold: float = 0.5) -> float:
    """Jaccard index (IoU) between thresholded prediction and ground-truth masks.

    Masks are kept as float64 0/1 indicators: the previous ``uint8 * 255``
    encoding overflowed on the elementwise product (255*255 wraps to 1 in
    uint8), corrupting both intersection and union.
    """
    p = (y_pred > threshold).astype(np.float64)
    t = (y_true > threshold).astype(np.float64)
    inter = (p * t).sum()
    union = p.sum() + t.sum() - inter
    return float(inter / (union + eps))
+
def hausdorff_distance_np(y_pred: np.ndarray, y_true: np.ndarray, threshold: float = 0.5) -> float:
    """Symmetric Hausdorff distance between the thresholded point sets.

    Returns NaN when scipy is unavailable and +inf when either mask is empty.
    """
    try:
        from scipy.spatial.distance import directed_hausdorff
    except Exception:
        return float("nan")  # scipy not available
    pred_points = np.argwhere(y_pred > threshold)
    true_points = np.argwhere(y_true > threshold)
    if len(pred_points) == 0 or len(true_points) == 0:
        return float("inf")
    forward = directed_hausdorff(pred_points, true_points)[0]
    backward = directed_hausdorff(true_points, pred_points)[0]
    return float(max(forward, backward))
+
+# --- Metric registry format:
+# "metric_name": {
+# "fn": callable(y_pred, y_true, params, meta) -> scalar|array|tuple|string,
+# "output": "scalar"|"plot"|"text",
+# "requires_proba": bool (if True, uses probability/score column instead of argmax)
+# }
+# meta passed to fn includes {"is_binary": bool, "n_classes": int, "task": str}
+# ---
# Registry entries: "fn" is callable(y_pred, y_true, params, meta);
# "output" routes the result (scalar -> JSON, plot -> PNG, text -> TXT);
# "requires_proba" marks metrics that need scores instead of argmax labels.
# Fix: params may legitimately be None (default metric paths pass None), so
# every direct `p.get(...)` is guarded with `(p or {})` -- the segmentation
# metrics previously raised AttributeError on the default config.
METRIC_REGISTRY: Dict[str, Dict[str, Any]] = {
    # Classification - scalar
    "accuracy": {
        "fn": lambda yp, yt, p, m: accuracy_score(yt, np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int)),
        "output": "scalar",
        "requires_proba": False
    },
    "f1_score": {
        "fn": lambda yp, yt, p, m: f1_score(
            yt,
            np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int),
            **(p or {"average": ("binary" if m["is_binary"] else "macro")})
        ),
        "output": "scalar",
        "requires_proba": False
    },
    "precision": {
        "fn": lambda yp, yt, p, m: precision_score(
            yt,
            np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int),
            **(p or {"average": ("binary" if m["is_binary"] else "macro")})
        ),
        "output": "scalar",
        "requires_proba": False
    },
    "recall": {
        "fn": lambda yp, yt, p, m: recall_score(
            yt,
            np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int),
            **(p or {"average": ("binary" if m["is_binary"] else "macro")})
        ),
        "output": "scalar",
        "requires_proba": False
    },
    "roc_auc": {
        # binary with a score column -> use positive-class column; 1-D scores used
        # directly; multiclass needs multi_class param via `params` (ovr/ovo).
        # (The original first two branches were identical; merged here.)
        "fn": lambda yp, yt, p, m: (
            float(roc_auc_score(yt, yp[:, 1], **(p or {}))) if (yp.ndim > 1 and yp.shape[1] > 1 and (m["is_binary"] or yp.shape[1] == 2))
            else float(roc_auc_score(yt, yp, **(p or {}))) if yp.ndim == 1
            else (float(roc_auc_score(yt, yp, **(p or {}))) if not m["is_binary"] else float("nan"))
        ),
        "output": "scalar",
        "requires_proba": True
    },

    # Classification - plot/text
    "confusion_matrix": {
        "fn": lambda yp, yt, p, m: confusion_matrix(yt, np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int)),
        "output": "plot",
        "requires_proba": False,
        "filename": "confusion_matrix.png"
    },
    "classification_report": {
        "fn": lambda yp, yt, p, m: classification_report(yt, np.argmax(yp, axis=1) if yp.ndim > 1 else (yp > 0.5).astype(int)),
        "output": "text",
        "requires_proba": False,
        "filename": "classification_report.txt"
    },
    "precision_recall_curve": {
        "fn": lambda yp, yt, p, m: precision_recall_curve(yt, (yp[:, 1] if (yp.ndim > 1 and yp.shape[1] > 1) else yp).ravel()),
        "output": "plot",
        "requires_proba": True,
        "filename": "precision_recall_curve.png"
    },
    "roc_curve": {
        "fn": lambda yp, yt, p, m: roc_curve(yt, (yp[:, 1] if (yp.ndim > 1 and yp.shape[1] > 1) else yp).ravel()),
        "output": "plot",
        "requires_proba": True,
        "filename": "roc_curve.png"
    },

    # Segmentation -- `(p or {})` so the default config (params=None) works
    "dice_score": {
        "fn": lambda yp, yt, p, m: dice_score_np(yp, yt, threshold=(p or {}).get("threshold", 0.5)),
        "output": "scalar",
        "requires_proba": False
    },
    "jaccard_index": {
        "fn": lambda yp, yt, p, m: jaccard_index_np(yp, yt, threshold=(p or {}).get("threshold", 0.5)),
        "output": "scalar",
        "requires_proba": False
    },
    "hausdorff_distance": {
        "fn": lambda yp, yt, p, m: hausdorff_distance_np(yp, yt, threshold=(p or {}).get("threshold", 0.5)),
        "output": "scalar",
        "requires_proba": False
    },

    # Regression
    "mse": {
        "fn": lambda yp, yt, p, m: float(mean_squared_error(yt, yp)),
        "output": "scalar",
        "requires_proba": False
    },
    "mae": {
        "fn": lambda yp, yt, p, m: float(mean_absolute_error(yt, yp)),
        "output": "scalar",
        "requires_proba": False
    },
    "rmse": {
        # sqrt of MSE; `squared=False` was deprecated and removed in
        # scikit-learn 1.6, so compute the root explicitly.
        "fn": lambda yp, yt, p, m: float(np.sqrt(mean_squared_error(yt, yp))),
        "output": "scalar",
        "requires_proba": False
    },
    "r2_score": {
        "fn": lambda yp, yt, p, m: float(r2_score(yt, yp)),
        "output": "scalar",
        "requires_proba": False
    }
}
+
+# --- Utility: parse metrics config: allow string or dict {"name":.., "params":{...}} ---
def parse_metrics_config(metrics_config: Union[List[Any], None]) -> List[Dict[str, Any]]:
    """Normalize a metrics config into a list of {"name", "params"} dicts.

    Entries may be plain metric names (str) or dicts carrying optional params.
    """
    def _normalize(item):
        if isinstance(item, str):
            return {"name": item, "params": None}
        if isinstance(item, dict):
            return {"name": item.get("name"), "params": item.get("params", None)}
        raise ValueError("Metric config entries must be either str or dict")

    return [_normalize(item) for item in metrics_config] if metrics_config else []
+
+
def compute_metrics(preds_list, targets_list, test_loss=None, config=None):
    """Compute configured metrics over test predictions and persist results.

    Scalar metrics are collected into a dict (and dumped to
    evaluation_metrics.json when a save path is configured); plot/text metrics
    are written as files. Per-metric failures are captured as "Failed: ..."
    strings rather than raised.

    NOTE(review): config=None would crash on the first config.get(...) even
    though the `threshold` line guards for it -- confirm whether config is in
    fact mandatory.
    """
    metrics = parse_metrics_config(config.get("metrics", []))
    task_type = config.get("task_type", "")
    save_path = config.get("paths", {}).get("trained_model_output_path", "")
    # n_pred_samples = config.get("n_pred_samples", 0)
    # NOTE(review): `threshold` is read but never used in this function.
    threshold = config.get("threshold", 0.5) if config else 0.5

    if len(preds_list) == 0:
        raise ValueError("Predictions on test set are empty. Please check the test loader.")

    # y_pred_all = torch.cat(preds_list, dim=0).numpy()
    # y_true_all = torch.cat(targets_list, dim=0).numpy() if len(targets_list) else None
    y_pred_all = np.array(preds_list)
    y_true_all = np.array(targets_list) if len(targets_list) else None

    # ------------------------
    # Auto-choose default metrics if none provided
    # ------------------------
    if not metrics:
        if task_type == "classification" and y_true_all is not None:
            n_classes = len(np.unique(y_true_all))
            if n_classes == 2:
                metrics = [{"name": "classification_report"}, {"name": "roc_auc"}]
            else:
                metrics = [{"name": "classification_report"}]
        elif task_type == "segmentation":
            metrics = [{"name": "dice_score"}, {"name": "jaccard_index"}]
        elif task_type == "regression":
            metrics = [{"name": "mse"}, {"name": "mae"}, {"name": "r2_score"}]

    # meta info passed to metric fns
    meta = {}
    if y_true_all is not None and task_type == "classification":
        unique = np.unique(y_true_all)
        meta["n_classes"] = int(unique.size)
        meta["is_binary"] = (unique.size == 2)
    else:
        meta["n_classes"] = None
        meta["is_binary"] = False

    # ------------------------
    # Metric computation
    # ------------------------
    # Seed the result dict with the test loss when the caller provides one.
    if test_loss is not None:
        numeric_metrics = {"test_loss": test_loss}
    else:
        numeric_metrics = {}

    # NOTE(review): these two locals are never used afterwards (meta carries
    # the same information) -- candidates for removal.
    if y_true_all is not None:
        n_classes = len(np.unique(y_true_all)) if task_type == "classification" else None
        is_binary = (n_classes == 2)

    for m in metrics:
        name = m["name"]
        params = m.get("params", None)
        entry = METRIC_REGISTRY.get(name)
        if entry is None:
            numeric_metrics[name] = f"Metric {name} not implemented. Please raise an issue on GitHub."
            continue
        try:
            result = entry["fn"](y_pred_all, y_true_all, params, {"is_binary": meta["is_binary"], "n_classes": meta["n_classes"], "task": task_type})

            # Route outputs by declared type
            if entry["output"] == "scalar":
                # ensure JSON serializable float
                numeric_metrics[name] = float(result) if (isinstance(result, (int, float, np.floating, np.integer))) else result

            elif entry["output"] == "text":
                txt = str(result)
                fname = entry.get("filename", f"{name}.txt")
                if save_path:
                    with open(os.path.join(save_path, fname), "w") as f:
                        f.write(txt)
                # textual outputs not included in numeric JSON

            elif entry["output"] == "plot":
                # Expect result to be either array-like (matrix) or tuple for curve (x,y) or (precision,recall,_)
                fname = entry.get("filename", f"{name}.png")
                if save_path:
                    plt.figure()
                    # confusion matrix -> 2D array
                    if isinstance(result, (list, np.ndarray)) and np.asarray(result).ndim == 2:
                        arr = np.asarray(result)
                        plt.imshow(arr, interpolation='nearest', cmap='Blues')
                        plt.colorbar()
                        plt.title(name.replace("_", " ").title())
                        plt.xlabel("Predicted")
                        plt.ylabel("True")
                    # curve -> tuple (x,y,maybe thresholds)
                    elif isinstance(result, (tuple, list)) and len(result) >= 2:
                        x, y = result[0], result[1]
                        plt.plot(x, y)
                        plt.title(name.replace("_", " ").title())
                        plt.xlabel("x")
                        plt.ylabel("y")
                    else:
                        # fallback: try plotting 1D
                        arr = np.asarray(result)
                        if arr.ndim == 1:
                            plt.plot(arr)
                            plt.title(name.replace("_", " ").title())
                        else:
                            plt.text(0.1, 0.5, "Cannot plot result", fontsize=12)
                    plt.tight_layout()
                    plt.savefig(os.path.join(save_path, fname))
                    plt.close()

        except Exception as e:
            # Keep going: one broken metric must not abort the whole report.
            numeric_metrics[name] = f"Failed: {e}"

    # Save metrics
    if save_path and len(numeric_metrics) > 0:
        with open(os.path.join(save_path, "evaluation_metrics.json"), "w") as f:
            json.dump(numeric_metrics, f, indent=4, default=_to_json_safe)


    # TO EXPLORE: Can sample predictions be saved without breaking privacy constraints?

    """
    # Save sample predictions
    if save_path and n_pred_samples > 0 and y_true_all is not None:
        nsave = min(n_pred_samples, len(y_pred_all))

        # Pre-compute common paths and conversions
        if task_type == "segmentation":
            for i in range(nsave):
                pred_img = (y_pred_all[i] > threshold).astype(np.uint8) * 255

                # plt.imsave(os.path.join(save_path, f"pred_{i+1}_binarized.png"), pred_img, cmap="gray", vmin=0, vmax=255)
                plt.imsave(os.path.join(save_path, f"pred_{i+1}.png"), y_pred_all[i], cmap="gray", vmin=0, vmax=255)
                # plt.imsave(os.path.join(save_path, f"mask_{i+1}.png"), true_img, cmap="gray")

        elif task_type == "classification":
            is_multiclass = y_pred_all.ndim > 1 and y_pred_all.shape[1] > 1
            preds = y_pred_all[:nsave]
            trues = y_true_all[:nsave]

            for i in range(nsave):
                if is_multiclass:
                    pred_data = {
                        "pred_label": int(np.argmax(preds[i])),
                        "scores": preds[i].tolist(),
                        "true": trues[i].tolist() if hasattr(trues[i], "tolist") else int(trues[i])
                    }
                else:
                    pred_data = {
                        "pred_label": int((preds[i] > 0.5).astype(int)),
                        "scores": float(preds[i].ravel()[0]),
                        "true": int(trues[i]) if hasattr(trues[i], "__iter__") and not isinstance(trues[i], str) else trues[i]
                    }
                with open(os.path.join(save_path, f"pred_{i+1}.txt"), "w") as f:
                    json.dump(pred_data, f, default=_to_json_safe)

        elif task_type == "regression":
            preds = y_pred_all[:nsave]
            trues = y_true_all[:nsave]

            for i in range(nsave):
                pred_data = {
                    "pred": float(preds[i].ravel()[0]) if np.asarray(preds[i]).size==1 else preds[i].tolist(),
                    "true": float(trues[i].ravel()[0]) if np.asarray(trues[i]).size==1 else trues[i].tolist()
                }
                with open(os.path.join(save_path, f"pred_{i+1}.txt"), "w") as f:
                    json.dump(pred_data, f, default=_to_json_safe)
    """

    return numeric_metrics
\ No newline at end of file
diff --git a/src/train/pytrain/utilities/loss_constructor.py b/src/train/pytrain/utilities/loss_constructor.py
new file mode 100644
index 0000000..2722add
--- /dev/null
+++ b/src/train/pytrain/utilities/loss_constructor.py
@@ -0,0 +1,294 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import monai.losses
+# import kornia.losses
+# import torchgan.losses
+# import piq
+
+import importlib
+import ast
+
+# Unified approved namespace map.
+# Keys may be:
+# - root shorthands mapped to live objects: "torch", "nn", "F"
+# - fully-qualified approved prefixes mapped to module import strings
APPROVED_NAMESPACE_MAP = {
    # Live objects: shorthand roots resolved without an import.
    "torch": torch,
    "nn": nn,
    "F": F,
    # Import strings: modules imported lazily by _resolve_obj_from_approved_path.
    "torch.nn": "torch.nn",
    "torch.nn.functional": "torch.nn.functional",
    "monai.losses": "monai.losses",
    # NOTE(review): kornia is listed but its import at the top of this file is
    # commented out -- confirm it is installed wherever this path is used.
    "kornia.losses": "kornia.losses",
    "piq": "piq",
}
+
class LossComposer:
    """Builds a callable loss from a declarative config.

    A config is either atomic ({"class": <approved path>, "params": {...}})
    or composed ({"expression": ..., "components": {...}, "variables": {...},
    "reduction": ...}). The parsed callable takes (outputs, targets).
    """

    def __init__(self, config):
        self.config = config
        # Parse once at construction; calculate_loss only invokes the result.
        self.loss_fn = self._parse_config(config)

    @classmethod
    def load_from_dict(cls, config):
        # Alternate constructor; mirrors __init__ for dict configs.
        return cls(config)

    def calculate_loss(self, outputs, targets):
        # Evaluate the configured loss on model outputs and targets.
        return self.loss_fn(outputs, targets)

    # ----------------------------------------------------------------------
    # Internal methods
    # ----------------------------------------------------------------------
    def _parse_config(self, config):
        # Dispatch on config shape: atomic loss vs composed expression.
        if "class" in config:
            # Atomic loss
            return self._build_atomic_loss(config)
        elif "expression" in config:
            # Composed loss
            return self._build_composed_loss(config)
        else:
            raise ValueError("Invalid loss config: must contain 'class' or 'expression'")

    def _build_atomic_loss(self, config):
        cls_path = config["class"]
        # Resolve object from approved namespaces (supports deep paths like torch.nn.functional.binary_cross_entropy_with_logits)
        obj = _resolve_obj_from_approved_path(cls_path)

        if callable(obj) and not isinstance(obj, type):
            # It's a function (like torch.exp): params are re-resolved on
            # every call so placeholders ("outputs", "$ref", ...) can bind to
            # live values and cached component results.
            def fn(outputs, targets, cache=None):
                params = config.get("params", {})
                resolved = self._resolve_params(params, outputs, targets, cache)
                return obj(**resolved)
            return fn
        else:
            # It's a class, instantiate once at build time with static params;
            # the instance is then called as criterion(outputs, targets).
            instance = obj(**config.get("params", {}))
            def fn(outputs, targets, cache=None):
                return instance(outputs, targets)
            return fn

    def _build_composed_loss(self, config):
        expression = config["expression"]
        components_cfg = config.get("components", {})
        variables = config.get("variables", {})
        reduction = config.get("reduction", "mean")

        # Recursively build sub-components
        components = {
            name: self._parse_config(sub_cfg)
            for name, sub_cfg in components_cfg.items()
        }

        def fn(outputs, targets, cache=None):
            # cache memoizes component results within one evaluation so a
            # component referenced multiple times is computed only once.
            if cache is None:
                cache = {}

            local_ctx = {}

            # Evaluate components with caching
            for name, comp_fn in components.items():
                if name not in cache:
                    cache[name] = comp_fn(outputs, targets, cache)
                local_ctx[name] = cache[name]

            # Add variables/constants
            local_ctx.update(variables)

            # Evaluate expression with a safe arithmetic parser (no calls / attributes)
            try:
                loss_val = _safe_eval_expression(expression, local_ctx)
            except Exception as e:
                raise RuntimeError(f"Failed to evaluate expression '{expression}': {e}")

            # Apply reduction if needed (only for non-scalar tensors)
            if isinstance(loss_val, torch.Tensor) and loss_val.ndim > 0:
                if reduction == "mean":
                    loss_val = loss_val.mean()
                elif reduction == "sum":
                    loss_val = loss_val.sum()

            return loss_val

        return fn

    def _resolve_params(self, params, outputs, targets, cache=None):
        """Resolve config param values to live objects.

        Supports nested dicts/lists, the placeholders "output(s)"/"target(s)",
        and string references to cached components: "$name", "-$name", and the
        legacy "-name" form (which falls back to the literal string when
        'name' is not cached).
        """
        resolved = {}
        for k, v in params.items():
            # Recursive resolution for nested structures
            if isinstance(v, dict):
                resolved[k] = self._resolve_params(v, outputs, targets, cache)
                continue
            if isinstance(v, (list, tuple)):
                seq_type = list if isinstance(v, list) else tuple
                # Wrap each element in a one-key dict so the same resolver
                # handles sequence elements.
                resolved[k] = seq_type(self._resolve_params({"_": x}, outputs, targets, cache)["_"] for x in v)
                continue

            # Direct placeholders
            if v in ("output", "outputs"):
                resolved[k] = outputs
                continue
            if v in ("target", "targets"):
                resolved[k] = targets
                continue

            # References to cached component values
            if isinstance(v, str) and cache is not None:
                if v.startswith("-$"):
                    ref = v[2:]
                    if ref in cache:
                        resolved[k] = -cache[ref]
                        continue
                    raise KeyError(f"Referenced component '{ref}' not found in cache")
                if v.startswith("$"):
                    ref = v[1:]
                    if ref in cache:
                        resolved[k] = cache[ref]
                        continue
                    raise KeyError(f"Referenced component '{ref}' not found in cache")
                if v.startswith("-"):
                    # Backwards-compatible: "-name" refers to negative of cached 'name'
                    ref = v[1:]
                    if ref in cache:
                        resolved[k] = -cache[ref]
                        continue
                    # If not found, fall through to literal string

            # Literal value
            resolved[k] = v
        return resolved
+
+
+# ----------------------------------------------------------------------
+# Security helpers
+# ----------------------------------------------------------------------
def _resolve_obj_from_approved_path(path: str):
    """Resolve an attribute object from an approved module path.
    Chooses the longest approved namespace prefix and traverses attributes.
    """
    if not isinstance(path, str):
        raise TypeError("Expected string path")

    # Longest-prefix match against the unified approved map.
    matched = None
    for candidate in sorted(APPROVED_NAMESPACE_MAP, key=len, reverse=True):
        if path == candidate or path.startswith(candidate + "."):
            matched = candidate
            break
    if matched is None:
        raise ValueError(f"Path '{path}' is not under approved namespaces: {list(APPROVED_NAMESPACE_MAP.keys())}")

    provider = APPROVED_NAMESPACE_MAP[matched]
    if isinstance(provider, str):
        # String providers are module names imported lazily.
        current = importlib.import_module(provider)
        if path == matched:
            raise ValueError(f"Path '{path}' refers to a module, expected a class or function under it")
    else:
        # provider is a live module/object (torch, nn, F)
        current = provider
        if path == matched:
            raise ValueError(f"Path '{path}' refers to a namespace root, expected a class or function under it")

    # Walk the remaining dotted attributes under the approved root.
    for segment in path[len(matched) + 1:].split('.'):
        if segment == "":
            raise ValueError(f"Invalid path '{path}'")
        if not hasattr(current, segment):
            raise AttributeError(f"'{current}' has no attribute '{segment}' while resolving '{path}'")
        current = getattr(current, segment)
    return current
+
+
+def _safe_eval_expression(expression: str, names: dict):
+ """
+ Safely evaluate an arithmetic expression using AST.
+ - Supports PEDMAS: +, -, *, /, ** and parentheses
+ - Supports unary + and -
+ - Disallows function calls, attribute access, subscripting, comprehensions, etc.
+ - Names must exist in the provided names dict
+ """
+
+ def eval_node(node):
+ if isinstance(node, ast.Expression):
+ return eval_node(node.body)
+
+ # Constants / numbers
+ if isinstance(node, ast.Constant):
+ if isinstance(node.value, (int, float)):
+ return node.value
+ raise ValueError("Only numeric constants are allowed in expressions")
+
+ # Variables (component outputs or variables)
+ if isinstance(node, ast.Name):
+ if node.id in names:
+ return names[node.id]
+ raise NameError(f"Unknown name in expression: {node.id}")
+
+ # Parentheses are represented implicitly via AST structure
+
+ # Unary operations
+ if isinstance(node, ast.UnaryOp):
+ operand = eval_node(node.operand)
+ if isinstance(node.op, ast.UAdd):
+ return +operand
+ if isinstance(node.op, ast.USub):
+ return -operand
+ raise ValueError("Unsupported unary operator in expression")
+
+ # Binary operations
+ if isinstance(node, ast.BinOp):
+ left = eval_node(node.left)
+ right = eval_node(node.right)
+ if isinstance(node.op, ast.Add):
+ return left + right
+ if isinstance(node.op, ast.Sub):
+ return left - right
+ if isinstance(node.op, ast.Mult):
+ return left * right
+ if isinstance(node.op, ast.Div):
+ return left / right
+ if isinstance(node.op, ast.Pow):
+ return left ** right
+ # Optional: uncomment if you want to support floor-div or mod
+ # if isinstance(node.op, ast.FloorDiv):
+ # return left // right
+ # if isinstance(node.op, ast.Mod):
+ # return left % right
+ raise ValueError("Unsupported binary operator in expression")
+
+ # Anything else is forbidden
+ forbidden = (
+ ast.Call, ast.Attribute, ast.Subscript, ast.Dict, ast.List, ast.Tuple,
+ ast.BoolOp, ast.Compare, ast.IfExp, ast.Lambda, ast.ListComp, ast.DictComp,
+ ast.GeneratorExp, ast.SetComp, ast.Await, ast.Yield, ast.YieldFrom,
+ ast.FormattedValue, ast.JoinedStr
+ )
+ if isinstance(node, forbidden):
+ raise ValueError("Disallowed construct in expression")
+
+ raise ValueError(f"Unsupported expression element: {type(node).__name__}")
+
+ tree = ast.parse(expression, mode='eval')
+ return eval_node(tree)
diff --git a/src/train/pytrain/utilities/model_constructor.py b/src/train/pytrain/utilities/model_constructor.py
new file mode 100644
index 0000000..1d01585
--- /dev/null
+++ b/src/train/pytrain/utilities/model_constructor.py
@@ -0,0 +1,362 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+
+from typing import Any, Dict, List, Tuple, Callable
+import types
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+APPROVED_NAMESPACES = { # root namespaces that config dotted paths may resolve against; all others rejected by _resolve_submodule
+    "torch": torch,
+    "nn": nn,
+    "F": F,
+}
+
+# Security controls
+ALLOWED_OP_PREFIXES = {"F.", "torch.nn.functional."} # Allow only torch.nn.functional.* by default
+ALLOWED_OPS = { # explicit allowlist of torch.* ops permitted outside the allowed prefixes
+    "torch.cat", "torch.stack", "torch.concat", "torch.flatten", "torch.reshape", "torch.permute", "torch.transpose",
+    "torch.unsqueeze", "torch.squeeze", "torch.chunk", "torch.split", "torch.gather", "torch.index_select", "torch.narrow",
+    "torch.sum", "torch.mean", "torch.std", "torch.var", "torch.max", "torch.min", "torch.argmax", "torch.argmin", "torch.norm",
+    "torch.exp", "torch.log", "torch.log1p", "torch.sigmoid", "torch.tanh", "torch.softmax", "torch.log_softmax", "torch.relu", "torch.gelu",
+    "torch.matmul", "torch.mm", "torch.bmm", "torch.addmm", "torch.einsum",
+    "torch.roll", "torch.flip", "torch.rot90", "torch.rot180", "torch.rot270", "torch.rot360", # NOTE(review): torch.rot180/rot270/rot360 do not exist in PyTorch (only torch.rot90) — these entries pass the allowlist but would fail at resolution; confirm and remove
+}
+
+# Denylist of potentially dangerous kwarg names (case-insensitive)
+DENYLIST_ARG_NAMES = {
+    "out",  # in-place writes to user-provided buffers
+    "file", "filename", "path", "dir", "directory",  # filesystem
+    "map_location",  # avoid device remap surprises
+}
+
+# DoS safeguards
+MAX_FORWARD_STEPS = 200 # hard cap on the number of steps in a forward graph
+MAX_OPS_PER_STEP = 10 # hard cap on the number of ops inside one forward step
+
+
+def _resolve_submodule(path: str) -> Any:
+    """Resolve dotted path like 'nn.Conv2d' or 'torch.sigmoid' to an object.
+    Raises RuntimeError (wrapping the underlying AttributeError/TypeError) if resolution fails.
+    """
+    try:
+        if not isinstance(path, str):
+            raise TypeError("path must be a string")
+        parts = path.split(".")
+        if parts[0] in APPROVED_NAMESPACES: # only the vetted roots ('torch', 'nn', 'F') are walkable
+            obj = APPROVED_NAMESPACES[parts[0]]
+        else:
+            # allow direct module names like 'math' if needed
+            raise AttributeError(f"Unknown root namespace '{parts[0]}' in path '{path}'")
+        for p in parts[1:]: # walk the attribute chain one segment at a time
+            try:
+                obj = getattr(obj, p)
+            except AttributeError:
+                raise AttributeError(f"Could not resolve attribute '{p}' in path '{path}'")
+        return obj
+    except Exception as e:
+        raise RuntimeError(f"Error resolving dotted path '{path}': {str(e)}") from e
+
+
+def _replace_placeholders(obj: Any, params: Dict[str, Any]) -> Any:
+    """Recursively replace strings of the form '$name' using params mapping; other values pass through unchanged."""
+    try:
+        if isinstance(obj, str) and obj.startswith("$"):
+            key = obj[1:] # strip the leading '$' to get the params key
+            if key not in params:
+                raise KeyError(f"Placeholder '{obj}' not found in params {params}") # NOTE(review): message embeds the full params dict — may leak values into logs; confirm acceptable
+            return params[key]
+        elif isinstance(obj, dict):
+            try:
+                return {k: _replace_placeholders(v, params) for k, v in obj.items()} # recurse into values; keys are never substituted
+            except Exception as e:
+                raise RuntimeError(f"Error replacing placeholders in dict: {str(e)}") from e
+        elif isinstance(obj, (list, tuple)):
+            try:
+                seq_type = list if isinstance(obj, list) else tuple # preserve the original sequence type
+                return seq_type(_replace_placeholders(x, params) for x in obj)
+            except Exception as e:
+                raise RuntimeError(f"Error replacing placeholders in sequence: {str(e)}") from e
+        else:
+            return obj
+    except Exception as e:
+        raise RuntimeError(f"Error in placeholder replacement: {str(e)}") from e
+
+
+class ModelFactory:
+    """Factory for building PyTorch nn.Module instances from config dicts.
+
+    Public API:
+        ModelFactory.load_from_dict(config: dict) -> nn.Module
+    """
+
+    @classmethod
+    def load_from_dict(cls, config: Dict[str, Any]) -> nn.Module:
+        """Create an nn.Module instance from a top-level config.
+
+        The config may define 'submodules' (a dict of reusable component templates) and
+        a top-level 'layers' and 'forward' graph. Submodules are used by layers that have
+        a 'submodule' key and are instantiated with their provided params.
+        """
+        try:
+            if not isinstance(config, dict):
+                raise TypeError("Config must be a dictionary")
+
+            submodules_defs = config.get("submodules", {})
+
+            def create_instance_from_def(def_cfg: Dict[str, Any], provided_params: Dict[str, Any]): # NOTE(review): never called in this module — dead code; confirm and remove
+                try:
+                    # Replace placeholders in the def_cfg copy
+                    # Deep copy not strictly necessary since we replace on the fly
+                    replaced_cfg = {
+                        k: (_replace_placeholders(v, provided_params) if k in ("layers",) or isinstance(v, dict) else v)
+                        for k, v in def_cfg.items()
+                    }
+                    # Build module from replaced config (submodule templates should not themselves contain further 'submodules')
+                    return cls._build_module_from_config(replaced_cfg, submodules_defs)
+                except Exception as e:
+                    raise RuntimeError(f"Error creating instance from definition: {str(e)}") from e
+
+            # When a layer entry references a 'submodule', we instantiate it using template from submodules_defs
+            return cls._build_module_from_config(config, submodules_defs)
+        except Exception as e:
+            raise RuntimeError(f"Error loading model from config: {str(e)}") from e
+
+    @classmethod
+    def _build_module_from_config(cls, config: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+        try:
+            layers_cfg = config.get("layers", {}) # name -> {'class': ...} or {'submodule': ...} entries
+            forward_cfg = config.get("forward", []) # ordered list of step dicts with 'input'/'ops'/'output'
+            input_names = config.get("input", []) # positional input names for forward()
+            output_names = config.get("output", []) # names read back out of env at the end of forward()
+
+            # Create dynamic module class
+            class DynamicModule(nn.Module):
+                def __init__(self):
+                    try:
+                        super().__init__()
+                        # ModuleDict to register submodules / layers
+                        self._layers = nn.ModuleDict()
+                        # Save forward graph and io names
+                        self._forward_cfg = forward_cfg
+                        self._input_names = input_names
+                        self._output_names = output_names
+
+                        # Build each layer / submodule
+                        for name, entry in layers_cfg.items():
+                            try:
+                                if "class" in entry:
+                                    cls_obj = _resolve_submodule(entry["class"])  # e.g. nn.Conv2d
+                                    if not (isinstance(cls_obj, type) and issubclass(cls_obj, nn.Module)):
+                                        raise TypeError(f"Layer '{name}' class must be an nn.Module subclass, got {cls_obj}")
+                                    params = entry.get("params", {})
+                                    inst_params = _replace_placeholders(params, {})  # top-level layers likely have no placeholders
+                                    module = cls_obj(**inst_params)
+                                    self._layers[name] = module
+                                elif "submodule" in entry:
+                                    sub_name = entry["submodule"]
+                                    if sub_name not in submodules_defs:
+                                        raise KeyError(f"Submodule '{sub_name}' not found in submodules definitions")
+                                    sub_def = submodules_defs[sub_name]
+                                    provided_params = entry.get("params", {})
+                                    # Replace placeholders inside sub_def using provided_params
+                                    # We create a fresh instance of submodule by calling helper
+                                    sub_inst = cls._instantiate_submodule(sub_def, provided_params, submodules_defs)
+                                    self._layers[name] = sub_inst
+                                else:
+                                    raise KeyError(f"Layer '{name}' must contain either 'class' or 'submodule' key")
+                            except Exception as e:
+                                raise RuntimeError(f"Error building layer '{name}': {str(e)}") from e
+                    except Exception as e:
+                        raise RuntimeError(f"Error initializing DynamicModule: {str(e)}") from e
+
+                def forward(self, *args, **kwargs):
+                    try:
+                        # Map inputs
+                        env: Dict[str, Any] = {} # name -> tensor environment threaded through the graph
+                        # assign by position
+                        for i, in_name in enumerate(self._input_names):
+                            if i < len(args):
+                                env[in_name] = args[i]
+                            elif in_name in kwargs: # fall back to keyword lookup when positional args run out
+                                env[in_name] = kwargs[in_name]
+                            else:
+                                raise ValueError(f"Missing input '{in_name}' for forward; provided args={len(args)}, kwargs keys={list(kwargs.keys())}")
+
+                        # Execute forward graph
+                        if len(self._forward_cfg) > MAX_FORWARD_STEPS:
+                            raise RuntimeError(f"Too many forward steps: {len(self._forward_cfg)} > {MAX_FORWARD_STEPS}. This is a security feature to prevent infinite loops.")
+
+                        for idx, step in enumerate(self._forward_cfg):
+                            try:
+                                ops = step.get("ops", [])
+                                if isinstance(ops, (list, tuple)) and len(ops) > MAX_OPS_PER_STEP:
+                                    raise RuntimeError(f"Too many ops in step {idx}: {len(ops)} > {MAX_OPS_PER_STEP}")
+                                inputs_spec = step.get("input", [])
+                                out_name = step.get("output", None)
+
+                                # Resolve input tensors for this step
+                                # inputs_spec might be: ['x'] or ['x1','x2'] or [['x3','encoded_feature']]
+                                if len(inputs_spec) == 1 and isinstance(inputs_spec[0], (list, tuple)): # single nested list: flatten one level
+                                    args_list = [env[n] for n in inputs_spec[0]]
+                                else:
+                                    args_list = [env[n] for n in inputs_spec]
+
+                                # Apply ops sequentially
+                                current = args_list # invariant: 'current' is a list of pending arguments for the next op
+                                for op in ops:
+                                    try:
+                                        # op can be string like 'conv1' or dotted 'F.relu'
+                                        # or can be a list like ['torch.flatten', {'start_dim':1}]
+                                        op_callable, op_kwargs = self._resolve_op(op)
+                                        # Validate kwargs denylist
+                                        for k in op_kwargs.keys():
+                                            if isinstance(k, str) and k.lower() in DENYLIST_ARG_NAMES:
+                                                raise PermissionError(f"Denied kwarg '{k}' for op '{op}'")
+
+                                        # If op_callable is a module in self._layers, call with module semantics
+                                        if isinstance(op_callable, str) and op_callable in self._layers:
+                                            module = self._layers[op_callable]
+                                            # if current is list of multiple args, pass them all
+                                            if isinstance(current, (list, tuple)) and len(current) > 1:
+                                                result = module(*current)
+                                            else:
+                                                result = module(current[0])
+                                        else:
+                                            # op_callable is a real callable object
+
+                                            if op_callable in {torch.cat, torch.stack}: # Ops that require a sequence input (instead of varargs); NOTE(review): torch.concat is allowlisted but not wrapped here — it would receive varargs and fail; confirm
+                                                # Wrap current into a list
+                                                result = op_callable(list(current), **op_kwargs)
+                                            elif isinstance(current, (list, tuple)):
+                                                result = op_callable(*current, **op_kwargs)
+                                            else:
+                                                result = op_callable(current, **op_kwargs)
+
+                                        # prepare current for next op
+                                        current = [result]
+                                    except Exception as e:
+                                        raise RuntimeError(f"Error applying operation '{op}': {str(e)}") from e
+
+                                # write outputs back into env
+                                if out_name is None:
+                                    continue
+                                if isinstance(out_name, (list, tuple)):
+                                    # if step produces multiple outputs (rare), try unpacking
+                                    if len(out_name) == 1:
+                                        env[out_name[0]] = current[0]
+                                    else:
+                                        # try to unpack
+                                        try:
+                                            for k, v in zip(out_name, current[0]): # assumes current[0] is iterable with matching length — TODO confirm
+                                                env[k] = v
+                                        except Exception as e:
+                                            raise RuntimeError(f"Could not assign multiple outputs for step {step}: {e}")
+                                else:
+                                    env[out_name] = current[0]
+                            except Exception as e:
+                                raise RuntimeError(f"Error executing forward step: {str(e)}") from e
+
+                        # Build function return
+                        if len(self._output_names) == 0:
+                            return None
+                        if len(self._output_names) == 1:
+                            return env[self._output_names[0]] # single output: return bare value, not a 1-tuple
+                        return tuple(env[n] for n in self._output_names)
+                    except Exception as e:
+                        raise RuntimeError(f"Error in forward pass: {str(e)}") from e
+
+                def _resolve_op(self, op_spec):
+                    """Return (callable_or_module_name, kwargs)
+
+                    If op_spec is a string and matches a layer name -> returns (layer_name_str, {}).
+                    If op_spec is a string dotted path -> resolve dotted and return (callable, {}).
+                    If op_spec is a list like ["torch.flatten", {"start_dim":1}] -> resolve and return (callable, kwargs)
+                    """
+                    try:
+                        # module reference by name
+                        if isinstance(op_spec, str):
+                            if op_spec in self._layers: # layer names take precedence over dotted paths
+                                return (op_spec, {})
+                            # dotted function (F.relu, torch.sigmoid)
+                            if not _is_allowed_op_path(op_spec):
+                                raise PermissionError(f"Operation '{op_spec}' is not allowed")
+                            callable_obj = _resolve_submodule(op_spec)
+                            if not callable(callable_obj):
+                                raise TypeError(f"Resolved object for '{op_spec}' is not callable")
+                            return (callable_obj, {})
+                        elif isinstance(op_spec, (list, tuple)):
+                            if len(op_spec) == 0:
+                                raise ValueError("Empty op_spec list")
+                            path = op_spec[0]
+                            kwargs = op_spec[1] if len(op_spec) > 1 else {} # second element, when present, is the kwargs dict
+                            if not _is_allowed_op_path(path):
+                                raise PermissionError(f"Operation '{path}' is not allowed")
+                            callable_obj = _resolve_submodule(path)
+                            if not callable(callable_obj):
+                                raise TypeError(f"Resolved object for '{path}' is not callable")
+                            return (callable_obj, kwargs)
+                        else:
+                            raise TypeError(f"Unsupported op spec type: {type(op_spec)}")
+                    except Exception as e:
+                        raise RuntimeError(f"Error resolving operation '{op_spec}': {str(e)}") from e
+
+            # Instantiate dynamic module and return
+            dyn = DynamicModule()
+            return dyn
+        except Exception as e:
+            raise RuntimeError(f"Error building module from config: {str(e)}") from e
+
+    @classmethod
+    def _instantiate_submodule(cls, sub_def: Dict[str, Any], provided_params: Dict[str, Any], submodules_defs: Dict[str, Any]) -> nn.Module:
+        """Instantiate a submodule defined in 'submodules' using provided_params to replace placeholders.
+
+        provided_params are used to replace occurrences of strings like '$in_ch' inside the sub_def's 'layers' params.
+        """
+        try:
+            # Deep replace placeholders within sub_def copy
+            # We'll construct a new config where the "layers"->"params" are substituted
+            replaced = {}
+            for k, v in sub_def.items():
+                try:
+                    if k == "layers":
+                        new_layers = {}
+                        for lname, lentry in v.items():
+                            new_entry = dict(lentry) # shallow copy so the template itself is not mutated
+                            if "params" in lentry:
+                                new_entry["params"] = _replace_placeholders(lentry["params"], provided_params)
+                            new_layers[lname] = new_entry
+                        replaced[k] = new_layers
+                    else:
+                        # copy other keys directly (input/forward/output)
+                        replaced[k] = v
+                except Exception as e:
+                    raise RuntimeError(f"Error processing key '{k}': {str(e)}") from e
+
+            # Now build a module from this replaced config. This call may in turn instantiate nested submodules.
+            return cls._build_module_from_config(replaced, submodules_defs)
+        except Exception as e:
+            raise RuntimeError(f"Error instantiating submodule: {str(e)}") from e
+
+
+def _is_allowed_op_path(path: str) -> bool: # security gate: is this dotted op path permitted in forward graphs?
+    if any(path.startswith(p) for p in ALLOWED_OP_PREFIXES): # anything under F.* / torch.nn.functional.* passes wholesale
+        return True
+    return path in ALLOWED_OPS # otherwise it must be an exact allowlisted torch.* name
\ No newline at end of file
diff --git a/src/train/pytrain/xgb_train.py b/src/train/pytrain/xgb_train.py
new file mode 100644
index 0000000..c3ab57d
--- /dev/null
+++ b/src/train/pytrain/xgb_train.py
@@ -0,0 +1,190 @@
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
+
+import os
+import json
+from typing import Any, Dict, Optional, Tuple
+
+import numpy as np
+import torch
+
+from .task_base import TaskBase
+from .utilities.dataset_constructor import create_dataset
+from .utilities.eval_tools import compute_metrics
+
+import xgboost as xgb
+from .utilities.dp_xgboost import DPXGBoost
+from sklearn.metrics import mean_squared_error
+
+def _to_numpy(array_like: Any) -> np.ndarray: # best-effort conversion of tensor/frame/array-like to a numpy array
+    try:
+        if isinstance(array_like, torch.Tensor):
+            return array_like.detach().cpu().numpy() # detach from autograd and move to host before converting
+    except Exception:
+        pass # deliberate best-effort: fall through to the generic paths below on any tensor-conversion failure
+    if hasattr(array_like, "to_numpy"):
+        return array_like.to_numpy() # duck-typed path for objects exposing to_numpy() (e.g. pandas-style containers)
+    return np.asarray(array_like)
+
+
+def _extract_xy_from_split(split_ds: Any) -> Tuple[np.ndarray, Optional[np.ndarray]]: # returns (X, y); y is None when the split has no targets
+    if hasattr(split_ds, "features"): # fast path: dataset exposes 'features'/'targets' attributes directly
+        X = _to_numpy(getattr(split_ds, "features"))
+        y = getattr(split_ds, "targets")
+        y_np = _to_numpy(y) if y is not None else None
+        if y_np is not None and y_np.ndim > 1 and y_np.shape[1] == 1:
+            y_np = y_np.ravel() # flatten (n, 1) column targets to (n,)
+        return X, y_np
+
+    features_list = []
+    targets_list = []
+    has_target = None # decided by the first item: (x, y) pairs vs. bare features
+    for i in range(len(split_ds)): # slow path: iterate item-by-item over a map-style dataset
+        item = split_ds[i]
+        if isinstance(item, (tuple, list)) and len(item) == 2:
+            x, y = item
+            features_list.append(_to_numpy(x))
+            targets_list.append(_to_numpy(y))
+            has_target = True
+        else:
+            features_list.append(_to_numpy(item))
+            if has_target is None:
+                has_target = False
+
+    X = np.vstack([x.reshape(1, -1) if x.ndim == 1 else x for x in features_list]) if len(features_list) else np.empty((0,)) # NOTE(review): empty fallback is 1-D, not (0, n_features) — confirm downstream handles this
+    y = None
+    if has_target:
+        y_arr = np.asarray(targets_list)
+        if y_arr.ndim > 1 and y_arr.shape[1] == 1:
+            y_arr = y_arr.ravel() # flatten (n, 1) column targets to (n,)
+        y = y_arr
+    return X, y
+
+
+class Train_XGB(TaskBase):
+    """Train (DP-)XGBoost models.
+
+    Expected config snippet:
+    {
+      "task_type": "classification" | "regression",
+      "dataset_config": {...},
+      "paths": {"input_dataset_path": "/path", "trained_model_output_path": "/out"},
+      "model_config": {
+          "n_estimators": 300,
+          "max_depth": 6,
+          "learning_rate": 0.1,
+          "objective": "binary:logistic",
+          "num_class": 2,
+          "booster_params": {"tree_method": "hist", ...}   # optional, passed to underlying booster
+      },
+      "is_private": true,
+      "privacy_params": {"mechanism": "gaussian", "epsilon": 2.0, "delta": 1e-5, "clip_value": 1.0}   # used by some dp-xgboost forks
+    }
+    """
+
+    def init(self, config: Dict[str, Any]):
+        self.config = config
+        self.task_type = config.get("task_type")
+        self.paths = config.get("paths", {})
+
+        self.is_private = bool(config["is_private"]) # required key — raises KeyError when absent
+        self.privacy_config = config["privacy_params"] if self.is_private else None # required when is_private is true
+
+        self.X_train = None # train/val/test splits populated by load_data()
+        self.y_train = None
+        self.X_val = None
+        self.y_val = None
+        self.X_test = None
+        self.y_test = None
+
+        self.model = None # DPXGBoost instance created by load_model()
+
+    def load_data(self):
+        dataset_cfg = self.config.get("dataset_config", {})
+        input_path = self.paths.get("input_dataset_path")
+        all_splits = create_dataset(dataset_cfg, input_path)
+
+        if "train" not in all_splits:
+            raise ValueError("Dataset must provide at least a 'train' split")
+
+        self.X_train, self.y_train = _extract_xy_from_split(all_splits["train"])
+        if "val" in all_splits: # optional splits
+            self.X_val, self.y_val = _extract_xy_from_split(all_splits["val"])
+        if "test" in all_splits:
+            self.X_test, self.y_test = _extract_xy_from_split(all_splits["test"])
+
+        print(f"Loaded dataset splits | train: {self.X_train.shape} | val: {None if self.X_val is None else self.X_val.shape} | test: {None if self.X_test is None else self.X_test.shape}")
+
+    def load_model(self):
+        xgb_params = self.config["model_config"]["booster_params"] # required keys — raises KeyError when absent
+        self.model = DPXGBoost(xgb_params, privacy_params=self.privacy_config) # privacy_params is None in the non-private case
+
+    def train(self):
+        num_boost_round = self.config["model_config"]["num_boost_round"] # required key — raises KeyError when absent
+        self.model.fit(X=self.X_train, y=self.y_train, num_boost_round=num_boost_round)
+        if self.is_private:
+            eps_string = f"Epsilon: {self.privacy_config['epsilon']}"
+        else:
+            eps_string = "Non DP"
+        print(f"Trained Gradient Boosting model with {num_boost_round} boosting rounds | {eps_string}")
+
+    def save_model(self):
+        save_path = os.path.join(self.paths["trained_model_output_path"], "trained_model.json")
+        self.model.save_model(save_path)
+        print(f"Saved model to {self.paths['trained_model_output_path']}")
+
+    def inference_and_eval(self):
+        if self.X_test is None: # evaluation is optional — only runs when a test split was loaded
+            print("No test split provided; skipping evaluation.")
+            return
+
+        preds_list = [] # placeholders; reassigned by the branches below
+        targets_list = []
+        non_dp_preds_list = []
+
+        if self.is_private: # private run also produces a non-DP baseline for comparison
+            preds_list = self.model.predict(X=self.X_test, dp=True)
+            non_dp_preds_list = self.model.predict(X=self.X_test, dp=False)
+        else:
+            preds_list = self.model.predict(X=self.X_test, dp=False)
+
+        targets_list = self.y_test
+
+        numeric_metrics = compute_metrics(preds_list, targets_list, None, self.config)
+        print(f"Evaluation Metrics: {numeric_metrics}")
+
+        if self.is_private:
+            os.makedirs(os.path.join(self.paths["trained_model_output_path"], "non_dp"), exist_ok=True)
+            self.config["paths"]["trained_model_output_path"] = os.path.join(self.paths["trained_model_output_path"], "non_dp") # NOTE(review): mutates shared config in place so compute_metrics writes under non_dp/ — later readers of config see the changed path; confirm intended
+            non_dp_numeric_metrics = compute_metrics(non_dp_preds_list, targets_list, None, self.config)
+            print(f"Non-DP Evaluation Metrics: {non_dp_numeric_metrics}")
+
+    def execute(self, config: Dict[str, Any]):
+        try:
+            # Full task pipeline: configure, load data/model, train, persist, evaluate.
+            self.init(config)
+            self.load_data()
+            self.load_model()
+            self.train()
+            self.save_model()
+            self.inference_and_eval()
+            print("CCR Training complete!\n")
+
+        except Exception as e:
+            print(f"An error occurred: {e}")
+            raise e
+
+
diff --git a/src/train/run.sh b/src/train/run.sh
old mode 100644
new mode 100755
index 07f5f05..6272439
--- a/src/train/run.sh
+++ b/src/train/run.sh
@@ -9,4 +9,5 @@ echo "pipeline configuration is available"
echo "Running pipeline with configuration:"
cat /mnt/remote/config/pipeline_config.json
+echo ""
pytrain /mnt/remote/config/pipeline_config.json
diff --git a/src/train/setup.py b/src/train/setup.py
index 92459d8..2b2d7bd 100644
--- a/src/train/setup.py
+++ b/src/train/setup.py
@@ -1,5 +1,19 @@
-# Copyright (c) Microsoft Corporation.
-# Licensed under the MIT License.
+# 2025 DEPA Foundation
+#
+# This work is dedicated to the public domain under the CC0 1.0 Universal license.
+# To the extent possible under law, DEPA Foundation has waived all copyright and
+# related or neighboring rights to this work.
+# CC0 1.0 Universal (https://creativecommons.org/publicdomain/zero/1.0/)
+#
+# This software is provided "as is", without warranty of any kind, express or implied,
+# including but not limited to the warranties of merchantability, fitness for a
+# particular purpose and noninfringement. In no event shall the authors or copyright
+# holders be liable for any claim, damages or other liability, whether in an action
+# of contract, tort or otherwise, arising from, out of or in connection with the
+# software or the use or other dealings in the software.
+#
+# For more information about this framework, please visit:
+# https://depa.world/training/depa_training_framework/
from setuptools import find_packages, setup