5 changes: 3 additions & 2 deletions .gitignore
@@ -20,5 +20,6 @@ cfg_constructor/out
**/bin
**/.venv
**/logdat

!**/demo_file_id.txt
**/.ipynb_checkpoints
**/*.swp
!**/demo_file_id.txt
70 changes: 70 additions & 0 deletions .ipynb_checkpoints/README-checkpoint.md
@@ -0,0 +1,70 @@
# Malware-Analysis

## Environment Setup

Install poetry
`pip install poetry`

Install dependencies
`poetry install`

Enter Poetry Shell
`poetry shell`

Before running scripts on the GraphIsomorphismNetwork/JAX-GIN branch, run `unset LD_LIBRARY_PATH` so that the JAX library can properly use CUDA devices.
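
A quick way to confirm that JAX can see the GPU after unsetting the variable (a minimal check run from inside the Poetry shell; the exact device repr varies by JAX version):
```python
# Minimal sanity check: list the accelerators JAX has detected.
# GPU devices should appear here when CUDA is set up correctly.
import jax

print(jax.devices())
```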

## Running the CFG Constructor Tool (`src/cfg_constructor`)

To generate Control Flow Graphs (CFGs) from binary files, use the `cfg_creator.py` script. This script analyzes binary files and creates CFGs, which can be visualized or saved in different formats.

### Usage

Run the script from the root directory using the following command:
```sh
python src/cfg_constructor/cfg_creator.py --data-dir <path_to_binary_files> --vis-mode <visualization_mode> --job-id <job_id> --id-list <fname.txt>
```

Args:
- `--data-dir`: Path to the directory containing the binary files, located in the repository root. (str, default='data')
- `--vis-mode`: Visualization mode. 0 = visualize in a window, 1 = save as HTML docs, 2 = save graphs without visualizing, as edgelists plus a `csv` of node values. (int, default=2)
- `--job-id`: Job ID used for logging and for skipping data already processed under the same job_id, vis_mode, and data dir. (int, default=0)
- `--id-list`: [OPTIONAL] `.txt` file listing the ids to restrict processing to, one id per line. (str, default='', meaning no file input)
- Note: if using the `--id-list` option, see `src/cfg_constructor/demo_file_id.txt` for how the input file should be formatted. Otherwise, the tool creates graphs for all files in the specified data directory.

Output is stored depending on the `vis_mode`:
- vis_mode=0: no saved output, graphs are displayed in a GUI
- vis_mode=1: HTML files in `cfg_constructor/out/out_html`
- vis_mode=2: CSV files in `cfg_constructor/out/out_adjacency_matrices` (a loading sketch follows this list)
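
As a reference point for working with the `vis_mode=2` output, here is a minimal sketch of loading one graph back into Python; the file names and id below are assumptions for illustration only (check `cfg_constructor/out/out_adjacency_matrices` for the names actually produced on your machine):
```python
# Hedged sketch: load a vis_mode=2 output (edgelist + node-value CSV) for inspection.
# File names here are hypothetical -- substitute the ones written on your machine.
import networkx as nx
import pandas as pd

file_id = "example_binary"  # hypothetical id of a processed file
graph = nx.read_edgelist(f"cfg_constructor/out/out_adjacency_matrices/{file_id}.edgelist",
                         create_using=nx.DiGraph)
nodes = pd.read_csv(f"cfg_constructor/out/out_adjacency_matrices/{file_id}.csv")

print(graph.number_of_nodes(), "nodes,", graph.number_of_edges(), "edges")
print(nodes.head())
```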

Example usage from my machine:
```sh
(malware-analysis-py3.12) me@mac Malware-Analysis %
python src/cfg_constructor/cfg_creator.py --data-dir data --vis-mode 2 --job-id 0
```

This runs the constructor from the root of the repository against a `data` dir (also in the root directory), uses visualization mode `2` (saving adjacency lists to the directory specified above), and assigns job_id `0` for logging (so if the program crashes, processing can easily resume).

If a log file for that job ID already exists, the script loads it and silently skips any files it marks as processed.

## Methodology
We use static analysis (deconstruction of binaries without execution) to extract Control Flow Graphs from each binary.

We then leverage Graph Neural Networks trained on these CFGs to classify an arbitrary binary as malicious or benign.
We primarily aim to use a dataset of 200k+ Windows PE binaries [linked here](https://practicalsecurityanalytics.com/pe-malware-machine-learning-dataset/).

## Goal
Produce a pipeline capable of performing deconstruction + inference **very fast**.

Feature-based models (e.g. XGBoost -> tree model, Yara rules -> condition matching) can run in <1s, and NLP tools (e.g. Kilogram paper -> n-gram analysis) can also run fairly fast.

Our hypothesis is that GNNs can capture more complex characteristics of malicious binaries via their CFGs, and that by training a large model and compressing it into a smaller downstream one, we can match the accuracy of feature-based approaches with comparable inference time.

## Compression Techniques

TBD - more research needed here. Potential techniques are listed below, but each requires more insight.
- Distillation (teaching a smaller downstream model to learn the behavior of the large model; famous from DistilBERT) - see the sketch after this list
- Quantization (possibly quantization-aware training to facilitate this approach)
- Pruning (self-explanatory; remove weights deemed irrelevant by some arbitrary technique)
- Theoretical optimization based on the BERT-of-Theseus paper
  - The paper centers on replacing large BERT modules with small modules during training, so the small modules learn to mimic the behavior of the large ones within the network
  - This depends on the GNN architecture, but could it be applied here? Could we use an optimization method where modules with 50% or 75% fewer parameters than the large-GNN modules are randomly inserted into a trained network and trained to mimic the role of the large modules (similar to distillation, but incorporated into the network)?
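
To make the distillation idea above concrete, here is a minimal sketch of a distillation loss in JAX, in line with the stack used in this repo. It is illustrative only: `student_logits`, `teacher_logits`, `temperature`, and `alpha` are assumed names, not part of the existing code.
```python
import jax
import jax.numpy as jnp

def distillation_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5):
    """Blend a soft-target KL term (teacher -> student) with the usual label loss."""
    # Soften both distributions with a temperature, as in classic distillation.
    soft_targets = jax.nn.softmax(teacher_logits / temperature, axis=-1)
    log_student = jax.nn.log_softmax(student_logits / temperature, axis=-1)
    kl = jnp.sum(soft_targets * (jnp.log(soft_targets + 1e-9) - log_student), axis=-1)

    # Standard negative log-likelihood on the true labels.
    log_probs = jax.nn.log_softmax(student_logits, axis=-1)
    nll = -jnp.take_along_axis(log_probs, labels[:, None], axis=-1).squeeze(-1)

    # temperature**2 rescales the KL gradient, per the original distillation formulation.
    return jnp.mean(alpha * (temperature ** 2) * kl + (1.0 - alpha) * nll)
```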
28 changes: 28 additions & 0 deletions .ipynb_checkpoints/pyproject-checkpoint.toml
@@ -0,0 +1,28 @@
[tool.poetry]
name = "malware-analysis"
version = "0.1.0"
description = "environment for static analysis of malware"
authors = ["Aarav Gupta <aaravgupta@gatech.edu>"]
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.12"
networkx = "^3.3"
matplotlib = "^3.9.1.post1"
tqdm = "^4.66.5"
pyvis = "^0.3.2"
pandas = "^2.2.2"
click = "^8.1.7"
setuptools = "^72.1.0"
wheel = "^0.44.0"
setuptools-rust = "^1.10.1"
iced-x86 = "^1.21.0"
scipy = "^1.14.0"
jax = "^0.5.3"
flax = "^0.10.4"
torch = "^2.7.1"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"
Empty file added .runJaxGin.sh.swp
Empty file.
256 changes: 256 additions & 0 deletions GraphIsomorphismNetwork/.ipynb_checkpoints/jax_gin-checkpoint.py
@@ -0,0 +1,256 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, checkpoints
from new_loader import get_paths, load_dataset_jax_new
import optax
from pathlib import Path
import pandas as pd
import numpy as np
import time
import os
import logging
import tqdm

log_file = Path(__file__).resolve().parent.parent / "logs" / "gin_training.log"
log_file.parent.mkdir(parents=True, exist_ok=True)

root_logger = logging.getLogger()
if root_logger.handlers:
    for handler in root_logger.handlers[:]:
        root_logger.removeHandler(handler)

logging.basicConfig(
    filename=log_file,
    level=logging.INFO,
    filemode="a",
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt='%Y-%m-%d %H:%M:%S'
)

logging.info("Starting GIN training...")

# Dense -> BatchNorm -> ReLU -> Dense.
class MLP(nn.Module):
    hidden_dim: int

    @nn.compact
    def __call__(self, x, training: bool):
        x = nn.Dense(2 * self.hidden_dim)(x)
        x = nn.BatchNorm(use_running_average=not training)(x)
        x = nn.relu(x)
        x = nn.Dense(self.hidden_dim)(x)
        return x

# Define a GIN convolution layer: x_i' = MLP((1+eps)*x_i + sum_{j in N(i)} x_j)
class GINConv(nn.Module):
    hidden_dim: int
    train_eps: bool = True

    @nn.compact
    def __call__(self, x, senders, receivers, training: bool):
        mlp = MLP(hidden_dim=self.hidden_dim)
        # Learnable epsilon
        if self.train_eps:
            eps = self.param("eps", lambda rng: jnp.zeros(()))
        else:
            eps = 0.0
        # Aggregate neighbor features using segment_sum.
        aggregated = jax.ops.segment_sum(x[senders], receivers, num_segments=x.shape[0])
        out = mlp((1 + eps) * x + aggregated, training=training)
        return out

class GIN(nn.Module):
    in_channels: int
    hidden_channels: int
    out_channels: int
    num_layers: int
    dropout_rate: float = 0.5
    train_eps: bool = True

    @nn.compact
    def __call__(self, x, edge_index, batch, training: bool):
        senders, receivers = edge_index  # edge_index is a tuple: (senders, receivers)
        for _ in range(self.num_layers):
            x = GINConv(hidden_dim=self.hidden_channels, train_eps=self.train_eps)(
                x, senders, receivers, training=training
            )
            x = nn.BatchNorm(use_running_average=not training)(x)
            x = nn.relu(x)

        # Global add pooling: sum node features per graph.
        # (num_segments=1 pools all nodes into a single graph embedding,
        # i.e. this assumes one graph per batch.)
        x = jax.ops.segment_sum(x, batch, num_segments=1)

        # Two-layer MLP for graph-level output.
        x = nn.Dense(self.hidden_channels)(x)
        x = nn.LayerNorm()(x)
        x = nn.relu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not training)
        x = nn.Dense(self.out_channels)(x)
        return jax.nn.log_softmax(x, axis=-1)

class TrainState(train_state.TrainState):
    batch_stats: dict

def create_train_state(rng, model, learning_rate, sample_input):
    variables = model.init(
        rng,
        sample_input["x"],
        sample_input["edge_index"],
        sample_input["batch"],
        training=True,
    )
    tx = optax.adam(learning_rate)
    return TrainState.create(
        apply_fn=model.apply,
        params=variables["params"],
        tx=tx,
        batch_stats=variables.get("batch_stats", {})
    )

@jax.jit
def train_step(state, batch, dropout_rng):
    def loss_fn(params):
        variables = {"params": params, "batch_stats": state.batch_stats}
        logits, new_model_state = state.apply_fn(
            variables,
            batch["x"],
            batch["edge_index"],
            batch["batch"],
            training=True,
            mutable=["batch_stats"],
            rngs={"dropout": dropout_rng},
        )
        labels = batch["y"]
        nll = -jnp.mean(jnp.take_along_axis(logits, labels[:, None], axis=-1).squeeze())
        return nll, new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads, batch_stats=new_model_state["batch_stats"])
    return state, loss

@jax.jit
def test_step(state, batch):
    variables = {"params": state.params, "batch_stats": state.batch_stats}
    logits = state.apply_fn(
        variables,
        batch["x"],
        batch["edge_index"],
        batch["batch"],
        training=False,
        mutable=False
    )
    pred = jnp.argmax(logits, axis=-1)
    correct = jnp.sum(pred == batch["y"])
    return correct

def save_model(state, checkpoint_dir, step):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoints.save_checkpoint(
        ckpt_dir=checkpoint_dir,
        target=state,
        step=step,
        overwrite=True
    )
    msg = f"Model saved at step {step} to {checkpoint_dir}"
    print(msg)
    logging.info(msg)

def load_model_if_exists(state, checkpoint_dir):
    latest_ckpt = checkpoints.latest_checkpoint(checkpoint_dir)
    if latest_ckpt:
        state = checkpoints.restore_checkpoint(
            ckpt_dir=checkpoint_dir,
            target=state
        )
        msg = f"Model restored from checkpoint: {latest_ckpt}"
        print(msg)
        logging.info(msg)
    else:
        msg = "No checkpoint found. Training from scratch."
        print(msg)
        logging.info(msg)
    return state

def main():
    # print(jax.devices())
    dev = jax.devices()[0] if jax.devices() else None

    paths = get_paths(samples_2000=True)
    data_loader, num_features, num_classes = load_dataset_jax_new(paths, max_files=None)

    split_ratio = 0.8
    n_total = len(data_loader)
    split_idx = int(n_total * split_ratio)
    train_loader = data_loader[:split_idx]
    test_loader = data_loader[split_idx:]
    print(f"Total batches: {n_total}; Training batches: {len(train_loader)}; Test batches: {len(test_loader)}")

    print(f"Number of Features: {num_features}, model hidden layer dim: {int(num_features * 1e-4)}")
    model = GIN(
        in_channels=num_features,
        hidden_channels=int(num_features * 1e-4),
        out_channels=num_classes,
        num_layers=4,
        dropout_rate=0.5
    )

    rng = jax.random.PRNGKey(0)
    dropout_rng, init_rng = jax.random.split(rng)
    # print(train_loader)
    print(n_total)
    sample_input = train_loader[0]
    state = create_train_state(init_rng, model, learning_rate=0.01, sample_input=sample_input)

    checkpoint_dir = Path(__file__).resolve().parent.parent / "weights"
    state = load_model_if_exists(state, checkpoint_dir)

    start_train = time.perf_counter()
    num_epochs = 10

    state = jax.device_put(state, device=dev)

    for epoch in range(1, num_epochs + 1):
        epoch_start = time.perf_counter()

        # Train Loop
        epoch_loss = 0.0
        total_graphs = 0
        for batch in tqdm.tqdm(train_loader, desc=f"training epoch {epoch}"):
            dropout_rng, new_dropout_rng = jax.random.split(dropout_rng)
            state, loss = train_step(state, batch, dropout_rng)
            state = jax.device_put(state, device=dev)
            epoch_loss += loss * batch["num_graphs"]
            total_graphs += batch["num_graphs"]
        avg_loss = epoch_loss / total_graphs

        # Test Loop
        total_correct = 0
        total_test_graphs = 0
        for batch in test_loader:
            correct = test_step(state, batch)
            total_correct += correct
            total_test_graphs += batch["num_graphs"]
        test_acc = total_correct / total_test_graphs

        epoch_end = time.perf_counter()
        msg1 = f"Time for epoch {epoch} was {(epoch_end - epoch_start):.6f} seconds"
        msg2 = f"Epoch: {epoch:03d}, Loss: {avg_loss:.4f}, Test Acc: {test_acc:.4f}"
        print(msg1)
        print(msg2)
        logging.info(msg1)
        logging.info(msg2)

        save_model(state, checkpoint_dir, epoch)

    end_train = time.perf_counter()
    msg3 = f"Total training time : {(end_train - start_train):.6f} seconds"
    msg4 = f"Average time per epoch: {((end_train - start_train) / num_epochs):.6f} seconds"
    print(msg3)
    print(msg4)
    logging.info(msg3)
    logging.info(msg4)

if __name__ == "__main__":
    main()