Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions SEGMENTATION/SEGMENTATION.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,8 @@ def run(self, meshes_dir, mrk_dir, out_dir,
labV = self._project_labels(V, P, labels)
labV = self._smooth_labels(V, F, labV, smooth_iters)
if write_vtp:
self._write_vtp_with_labels_world(meshPath, out_dir, labV, labV); wrote += 1
#self._write_vtp_with_labels_world(meshPath, out_dir, labV, labV); wrote += 1 # TO DO
self._write_ply_segments(meshPath, out_dir, labV); wrote += 1
if i < previewN:
self._preview_in_scene(meshPath, labV)

Expand Down Expand Up @@ -219,7 +220,7 @@ def _load_model_node(self, path):
before = self._ids_of_class("vtkMRMLModelNode")
node = None
try:
node = slicer.util.loadModel(path, properties={"coordinateSystem":"LPS"})
node = slicer.util.loadModel(path, properties={"coordinateSystem":"RAS"})
except TypeError: # very old Slicer: no properties kw
ok = slicer.util.loadModel(path)
if not ok: raise RuntimeError(f"Failed to load model: {path}")
Expand All @@ -232,7 +233,7 @@ def _load_markups_node(self, path):
before = self._ids_of_class("vtkMRMLMarkupsNode")
node = None
try:
node = slicer.util.loadMarkups(path, properties={"coordinateSystem":"LPS"})
node = slicer.util.loadMarkups(path, properties={"coordinateSystem":"RAS"})
except TypeError:
ok = slicer.util.loadMarkups(path)
if not ok: raise RuntimeError(f"Failed to load markups: {path}")
Expand Down Expand Up @@ -422,13 +423,49 @@ def _smooth_labels(self,V,F,lab,iters):
maj=np.argmax(votes,1); lab=np.where(deg>0,maj,lab)
return lab

# TODO: superseded by _write_ply_segments(), which writes one PLY per segment instead of a single colored VTP; remove once the VTP output is no longer needed.
def _write_vtp_with_labels_world(self, meshPath, out_dir, labV_raw, labV_smooth):
    """Write the mesh as a .vtp with per-point segment labels (world coordinates).

    Adds two int point-data arrays -- "SegID" (raw labels) and "SegID_smooth"
    (smoothed labels, set as the active scalars) -- and writes
    <basename>_seg.vtp into out_dir.

    Args:
        meshPath: path of the source mesh file to load and annotate.
        out_dir: directory receiving the output .vtp file.
        labV_raw: per-vertex integer labels before smoothing.
        labV_smooth: per-vertex integer labels after smoothing.
    """
    node = self._load_model_node(meshPath)
    try:
        pdW = self._polydata_world(node)

        raw = vtk_np.numpy_to_vtk(labV_raw, deep=True, array_type=vtk.VTK_INT)
        raw.SetName("SegID")
        pdW.GetPointData().AddArray(raw)

        smooth = vtk_np.numpy_to_vtk(labV_smooth, deep=True, array_type=vtk.VTK_INT)
        smooth.SetName("SegID_smooth")
        pdW.GetPointData().AddArray(smooth)
        pdW.GetPointData().SetActiveScalars("SegID_smooth")

        base = os.path.splitext(os.path.basename(meshPath))[0]
        out = os.path.join(out_dir, base + "_seg.vtp")
        writer = vtk.vtkXMLPolyDataWriter()
        writer.SetFileName(out)
        writer.SetInputData(pdW)
        writer.Write()
    finally:
        # Remove the temporary node even if writing fails, so repeated runs
        # don't accumulate models in the Slicer scene.
        slicer.mrmlScene.RemoveNode(node)

# NOTE: new helper that writes one binary PLY file per segment label.
def _write_ply_segments(self, meshPath, out_dir, labels):
    """Write one binary .ply per segment label of the mesh (world coordinates).

    Attaches *labels* as point scalars named "SegID", thresholds the surface
    around each non-negative label value, and writes each extracted surface
    to <basename>_seg_<id>.ply in out_dir. Labels < 0 are treated as
    background and skipped; labels with no surviving geometry produce no file.

    Args:
        meshPath: path of the source mesh file to load and split.
        out_dir: directory receiving the per-segment .ply files.
        labels: per-vertex integer label array (numpy).
    """
    node = self._load_model_node(meshPath)
    try:
        pd = self._polydata_world(node)
        # Attach labels as the active point scalars so vtkThreshold can select on them.
        arr = vtk_np.numpy_to_vtk(labels.astype(np.int32), deep=True, array_type=vtk.VTK_INT)
        arr.SetName("SegID")
        pd.GetPointData().SetScalars(arr)

        base = os.path.splitext(os.path.basename(meshPath))[0]
        for seg in np.unique(labels):
            if seg < 0:
                continue  # background / unlabeled points

            t = vtk.vtkThreshold()
            t.SetInputData(pd)
            # Threshold on the point scalars "SegID" explicitly.
            t.SetInputArrayToProcess(
                0, 0, 0,
                vtk.vtkDataObject.FIELD_ASSOCIATION_POINTS,
                "SegID",
            )
            # A half-unit band around the integer id is safer than exact
            # equality in floating-point pipelines.
            t.SetLowerThreshold(float(seg) - 0.5)
            t.SetUpperThreshold(float(seg) + 0.5)
            t.Update()

            g = vtk.vtkGeometryFilter()
            g.SetInputConnection(t.GetOutputPort())
            g.Update()
            surf = g.GetOutput()
            # Don't emit empty PLY files when no cells survived the threshold.
            if surf is None or surf.GetNumberOfPoints() == 0:
                continue

            out_path = os.path.join(out_dir, f"{base}_seg_{int(seg):02d}.ply")
            w = vtk.vtkPLYWriter()
            w.SetFileName(out_path)
            w.SetInputData(surf)
            w.SetFileTypeToBinary()
            w.Write()
    finally:
        # Remove the temporary node even on failure so the scene stays clean.
        slicer.mrmlScene.RemoveNode(node)

def _preview_in_scene(self, meshPath, labV):
n = self._load_model_node(meshPath)
if n is None: return
Expand Down Expand Up @@ -837,9 +874,4 @@ def _assert_finite(self, what, arr, path):
raise RuntimeError(
f"Non-finite values in {what} of '{os.path.basename(path)}' "
f"(first bad indices: {bad})."
)





)