https://github.com/fdalvi/NeuroX/blob/f2314ccd964f744eb3209f1438fa53704a6665b2/neurox/interpretation/iou_probe.py#L37 assumes that `y_train` contains only 0s and 1s. For multiclass labels, we should run IoU separately for each tag and then combine the per-tag results into a global ranking.
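
A rough sketch of what this could look like, to make the proposal concrete. Everything here is an assumption rather than the existing NeuroX code: `binary_iou_scores` is a hypothetical stand-in for the current binary scoring (thresholded activations compared against a 0/1 label mask), `multiclass_iou_ranking` is a made-up name, and taking the maximum score across tags is only one possible way to build the global ranking.

```python
import numpy as np


def binary_iou_scores(X, y_binary, threshold=0.05):
    """Hypothetical stand-in for a binary IoU scorer (not the NeuroX implementation).

    X: (n_samples, n_neurons) activations; y_binary: (n_samples,) array of 0/1.
    Returns one IoU score per neuron.
    """
    active = np.abs(X) > threshold        # neuron counts as "on" for a sample
    positive = (y_binary == 1)[:, None]   # samples carrying the tag, as a column
    intersection = np.logical_and(active, positive).sum(axis=0)
    union = np.logical_or(active, positive).sum(axis=0)
    # Neurons with an empty union get a score of 0 instead of dividing by zero.
    return np.where(union > 0, intersection / np.maximum(union, 1), 0.0)


def multiclass_iou_ranking(X, y, threshold=0.05):
    """Run the binary scorer once per tag, then merge into one ranking.

    Assumption: a neuron's global score is its best IoU over all tags.
    Returns (ranking, per_tag_scores), where ranking lists neuron indices
    from highest to lowest global score.
    """
    tags = np.unique(y)
    per_tag = np.stack(
        [binary_iou_scores(X, (y == tag).astype(int), threshold) for tag in tags]
    )
    global_scores = per_tag.max(axis=0)
    # Stable sort keeps the original neuron order among ties.
    ranking = np.argsort(-global_scores, kind="stable")
    return ranking, dict(zip(tags.tolist(), per_tag))


if __name__ == "__main__":
    X = np.array(
        [[1, 0, 0], [1, 0, 0], [0, 1, 0], [0, 1, 1], [0, 0, 1], [0, 0, 0]],
        dtype=float,
    )
    y = np.array([0, 0, 1, 1, 2, 2])
    ranking, per_tag = multiclass_iou_ranking(X, y)
    print(ranking, per_tag)
```

Other combination strategies (mean over tags, or interleaving the per-tag top neurons) would fit the same structure by swapping out the `per_tag.max(axis=0)` line; returning the per-tag scores alongside the global ranking keeps that choice open to the caller.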