The data files to be used in this challenge should be downloaded from here: https://drive.google.com/drive/folders/1swCsdUeYnMYLIEKYZ5ed1YIU0X1vyt9u?usp=sharing
Install the required packages listed in requirements.txt:
pip install -r requirements.txt
The task is to design and implement a pipeline that generates graph-based embeddings for a biological knowledge graph and uses them to predict potential drug-disease associations. This task has been split into three subtasks:
- Constructing the knowledge graph, done in graph_builder.py. This component is not complete; see the next steps and caveats below.
- Extracting embeddings from the knowledge graph, done in embedding_extractor.py.
- Using the extracted embeddings to train and test a machine learning classifier, done in train_model.py.
The repository also contains build_test_graph.py, which can be used to generate random graphs for testing purposes. Each part of the pipeline can be run separately:
# graph_builder.py has not yet been tested successfully: the large data files appear to need batch processing, which has not been implemented.
Do not run: python graph_builder.py --node_file ./data/Nodes.csv --edge_file ./data/Edges.csv --output_file graph.pt
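If batch processing is added to graph_builder.py, one possible approach is pandas' chunked CSV reading. This is a minimal sketch only: the chunk size is arbitrary, and the column names "subject" and "object" are assumptions about the Edges.csv schema, not confirmed against the real file.

```python
import pandas as pd

# Read the large edge file in fixed-size chunks so the whole CSV is never
# held in memory at once. Column names here are assumed, not verified.
edge_chunks = []
for chunk in pd.read_csv("./data/Edges.csv", chunksize=100_000):
    # Keep only the columns needed to build the edge index, shrinking
    # each chunk's memory footprint before accumulating it.
    edge_chunks.append(chunk[["subject", "object"]])
edges = pd.concat(edge_chunks, ignore_index=True)
```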
To run the ML classifier using the provided embeddings (assuming the data is stored in the folder data), with the trained model written to output_model.pkl:
python train_model.py --embedding_file ./data/Embeddings.csv --label_file ./data/Ground\ Truth.csv --model_file output_model.pkl
A random test graph builder was written so that the embeddings script can be checked without the graph_builder functionality being complete. Random test data can be generated by running:
python build_test_graph.py
This will output five test files:
- Output_Graph: a PyTorch graph object
- edges.csv: the edges in the graph
- Class_Labels.csv: class labels for training
- node_features.csv: node features
- node_ids.csv: node IDs corresponding to the node order in Output_Graph
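For a quick sanity check of the generated graph, the object can be loaded directly. This assumes Output_Graph was written with torch.save() and holds a torch_geometric Data object, which matches the description above but has not been confirmed against build_test_graph.py.

```python
import torch

# weights_only=False is needed on recent PyTorch versions to unpickle
# arbitrary (non-tensor) objects such as a torch_geometric Data object.
graph = torch.load("Output_Graph", weights_only=False)
print(graph)  # e.g. Data(x=[num_nodes, num_features], edge_index=[2, num_edges])
```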
The output from this can be used to test the embedding extraction, with the final option, embedding_output, giving the name of the CSV file the embeddings should be written to.
python embedding_extractor.py --graph_object Output_Graph --node_ids node_ids.csv --embedding_output embeddings_output.csv
The resulting embeddings can then be used for model training:
python train_model.py --embedding_file embeddings_output.csv --label_file Class_Labels.csv --model_file test_model.pkl
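The saved model can later be reloaded for prediction. This sketch assumes the .pkl file holds a fitted scikit-learn-compatible XGBoost model saved with pickle, and that the embedding CSV's numeric columns are exactly the features the model was trained on; both are assumptions about the scripts' behaviour.

```python
import pickle

import pandas as pd

with open("test_model.pkl", "rb") as f:
    model = pickle.load(f)

# Hypothetical usage: keep only numeric columns in case the CSV also
# contains an ID column; adjust to match the real file layout.
X_new = pd.read_csv("embeddings_output.csv").select_dtypes("number")
predictions = model.predict(X_new)
```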
Overview of methodology
To extract embeddings from the knowledge graph, a graph attention network (GAT) model has been used. This method allows the model to assign different levels of importance to each node's neighbours in the network. XGBoost is used for the classification model. Part of the data is held out for later testing. Cross-validation with a randomised parameter search is used during training. Model evaluation reports accuracy metrics, including precision, recall, and F1, and the confusion matrix is output.
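For reference, this is the general shape of a GAT encoder of the kind described above, built on torch_geometric's GATConv. The layer sizes and head count are illustrative assumptions, not the repository's actual hyperparameters.

```python
import torch
import torch.nn.functional as F
from torch_geometric.nn import GATConv

class GATEncoder(torch.nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim, heads=4):
        super().__init__()
        # Multi-head attention lets each node weight its neighbours
        # differently; the heads' outputs are concatenated by default.
        self.conv1 = GATConv(in_dim, hidden_dim, heads=heads)
        self.conv2 = GATConv(hidden_dim * heads, out_dim, heads=1)

    def forward(self, x, edge_index):
        x = F.elu(self.conv1(x, edge_index))
        return self.conv2(x, edge_index)  # final node embeddings
```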
Next steps
The construction of the knowledge graph features needs to be completed. This involves parsing the edge file to extract the relevant edge attributes from the 'predicate' column. The node features also need to be parsed and formulated into features usable by machine learning models; this was set to use the equivalent_curies column. Testing of this script was not completed: the data appears to need batch processing.
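One possible way to handle the 'predicate' step is to encode each predicate string as a categorical integer edge type. Only the column name comes from the text above; the rest of this sketch is an assumed approach, not the repository's implementation.

```python
import pandas as pd
import torch

# Map each distinct predicate string to an integer code that can be
# used as an edge-type attribute in the graph object.
edges = pd.read_csv("./data/Edges.csv", usecols=["predicate"])
edge_type = torch.tensor(
    edges["predicate"].astype("category").cat.codes.to_numpy(),
    dtype=torch.long,
)
```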
Tuning of the hyperparameters for the GAT model. As this part of the pipeline is not currently linked directly to the classification prediction, an alternative would be to use 'sense check' measures to assess the impact of hyperparameters on the embeddings, for example checking whether the cosine similarity between connected nodes is higher than between unconnected nodes. Consider combining the GAT and classifier into one model so that learning is shared across both tasks.
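A sketch of that cosine-similarity sense check: embeddings of connected nodes should, on average, be more similar than those of random pairs. This assumes the first column of the embedding CSV is a node ID and that edges.csv holds integer row indices into the embedding matrix; both are assumptions about the file formats, not confirmed behaviour.

```python
import numpy as np
import pandas as pd

# Load embeddings and L2-normalise each row, so a dot product between
# two rows equals their cosine similarity.
vecs = pd.read_csv("embeddings_output.csv", index_col=0).to_numpy()
vecs = vecs / np.linalg.norm(vecs, axis=1, keepdims=True)

# Mean similarity over connected pairs (edges assumed to be row indices).
edges = pd.read_csv("edges.csv").to_numpy()
connected_sim = (vecs[edges[:, 0]] * vecs[edges[:, 1]]).sum(axis=1).mean()

# Mean similarity over an equal number of random pairs, as a baseline.
rng = np.random.default_rng(0)
rand_pairs = rng.integers(0, len(vecs), size=edges.shape)
random_sim = (vecs[rand_pairs[:, 0]] * vecs[rand_pairs[:, 1]]).sum(axis=1).mean()
print(f"connected: {connected_sim:.3f}, random: {random_sim:.3f}")
```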
Further arguments can be passed to the scripts, e.g. the number of parameter settings sampled in the XGBoost search, which is currently set to 5 (n_iter) for testing purposes. The main.py script needs to be updated with the correct parameters for all stages once each module is complete.
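For context, a minimal sketch of the training and evaluation loop the scripts describe, with n_iter=5 as noted; the parameter grid, file names, and column handling here are illustrative assumptions rather than the scripts' actual settings.

```python
import pandas as pd
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import RandomizedSearchCV, train_test_split
from xgboost import XGBClassifier

# Hypothetical file and column handling, chosen to mirror the test pipeline.
X = pd.read_csv("embeddings_output.csv").select_dtypes("number")
y = pd.read_csv("Class_Labels.csv").iloc[:, 0]

# Hold out part of the data for later testing, as the methodology describes.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

# Randomised parameter search with cross-validation; n_iter=5 is low and
# intended only for quick testing, as noted above.
search = RandomizedSearchCV(
    XGBClassifier(eval_metric="logloss"),
    param_distributions={
        "max_depth": [3, 5, 7],
        "learning_rate": [0.01, 0.1, 0.3],
        "n_estimators": [100, 200, 400],
    },
    n_iter=5,
    cv=5,
)
search.fit(X_train, y_train)

preds = search.predict(X_test)
print(classification_report(y_test, preds))
print(confusion_matrix(y_test, preds))
```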