Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Libs/Optimize/Domain/MeshDomain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,16 @@ void MeshDomain::SetMesh(std::shared_ptr<Surface> mesh, double geodesic_remesh_p
}
}

//-------------------------------------------------------------------
void MeshDomain::SetMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area) {
m_FixedDomain = false;
surface_ = surface;
geodesics_mesh_ = geodesics_surface;
sw_mesh_ = sw_mesh;
surface_area_ = surface_area;
}

//-------------------------------------------------------------------
void MeshDomain::InvalidateParticlePosition(int idx) const { this->surface_->invalidate_particle(idx); }

Expand Down
2 changes: 2 additions & 0 deletions Libs/Optimize/Domain/MeshDomain.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class MeshDomain : public ParticleDomain {
}

void SetMesh(std::shared_ptr<Surface> mesh_, double geodesic_remesh_percent);
void SetMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area);

std::shared_ptr<Mesh> GetSWMesh() const { return sw_mesh_; }

Expand Down
12 changes: 12 additions & 0 deletions Libs/Optimize/Optimize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1757,6 +1757,18 @@ void Optimize::AddMesh(vtkSmartPointer<vtkPolyData> poly_data) {
this->m_spacing = 0.5;
}

//---------------------------------------------------------------------------
void Optimize::AddMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area) {
if (!surface) {
m_sampler->AddMesh(nullptr);
} else {
m_sampler->AddMesh(surface, geodesics_surface, sw_mesh, surface_area);
}
this->m_num_shapes++;
this->m_spacing = 0.5;
}

//---------------------------------------------------------------------------
void Optimize::AddContour(vtkSmartPointer<vtkPolyData> poly_data) {
m_sampler->AddContour(poly_data);
Expand Down
2 changes: 2 additions & 0 deletions Libs/Optimize/Optimize.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class Optimize {
//! Set the shape input images
void AddImage(ImageType::Pointer image, std::string name = "");
void AddMesh(vtkSmartPointer<vtkPolyData> poly_data);
void AddMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area);
void AddContour(vtkSmartPointer<vtkPolyData> poly_data);

//! Set the shape filenames (TODO: details)
Expand Down
225 changes: 184 additions & 41 deletions Libs/Optimize/OptimizeParameters.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,15 @@
#include <Particles/ParticleFile.h>
#include <Utils/StringUtils.h>

#include <atomic>
#include <boost/algorithm/string.hpp>
#include <boost/filesystem.hpp>
#include <functional>
#include <optional>
#include <tbb/parallel_for.h>
#include <tbb/blocked_range.h>

#include "Libs/Optimize/Domain/Surface.h"
#include "Optimize.h"

using namespace shapeworks;
Expand Down Expand Up @@ -683,9 +688,47 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
}
}

// === Phase A: Collect work items (sequential, fast I/O) ===
struct DomainWorkItem {
std::string filename;
DomainType domain_type;
std::shared_ptr<Subject> subject;
std::vector<std::vector<double>> transforms;
int domain_index_in_subject; // index of this domain within the subject's files
int global_domain_index; // sequential domain index across all subjects
int subject_index; // sequential subject index (non-excluded)
bool is_fixed;

// For mesh domains — prepared poly_data from Phase A
vtkSmartPointer<vtkPolyData> poly_data;
bool is_contour_hack = false; // true if mesh cells have 2 points (contour detection hack)

// For image domains
std::optional<Image> image;

// Results from Phase B (parallel pre-compute)
std::shared_ptr<shapeworks::Surface> surface;
std::shared_ptr<shapeworks::Surface> geodesics_surface;
std::shared_ptr<Mesh> sw_mesh;
double surface_area = 0.0;
};

std::vector<DomainWorkItem> work_items;
std::vector<std::string> filenames;
// Track which work items belong to each subject for callback and particle filename assignment
struct SubjectInfo {
std::shared_ptr<Subject> subject;
std::vector<size_t> work_item_indices; // indices into work_items
};
std::vector<SubjectInfo> subject_infos;

int count = 0;
domain_count = 0;
bool geodesics_enabled = get_use_geodesic_distance();
double geodesic_remesh_percent = get_geodesic_remesh_percent();
int geodesic_cache_multiplier = get_geodesic_cache_multiplier();
bool use_geodesics_to_landmarks = get_use_geodesics_to_landmarks();

for (auto s : subjects) {
if (abort_load_) {
return false;
Expand All @@ -700,8 +743,9 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
throw std::invalid_argument("No groomed inputs for optimization");
}
auto transforms = s->get_groomed_transforms();
std::vector<std::string> local_particle_filenames;
std::vector<std::string> world_particle_filenames;

SubjectInfo si;
si.subject = s;

for (int i = 0; i < files.size(); i++) {
auto filename = files[i];
Expand All @@ -713,75 +757,174 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {
auto domain_type = project_->get_groomed_domain_types()[i];
filenames.push_back(filename);

DomainWorkItem item;
item.filename = filename;
item.domain_type = domain_type;
item.subject = s;
item.transforms = transforms;
item.domain_index_in_subject = i;
item.global_domain_index = domain_count;
item.subject_index = count;
item.is_fixed = s->is_fixed();

if (domain_type == DomainType::Mesh) {
Mesh mesh = MeshUtils::threadSafeReadMesh(filename.c_str());
if (domain_count < constraints.size()) {
Constraints constraint = constraints[domain_count];
constraint.clipMesh(mesh);
auto poly_data = mesh.getVTKMesh();
if (poly_data->GetNumberOfCells() == 0) {
auto pd = mesh.getVTKMesh();
if (pd->GetNumberOfCells() == 0) {
throw std::invalid_argument("Mesh has zero cells after constraint clipping: " + filename);
}
}

if (get_use_geodesics_to_landmarks()) {
auto filenames = s->get_landmarks_filenames();
if (filenames.empty()) {
throw std::invalid_argument("Geodesic distance from landmarks is enabled, but subject has no landmark files");
if (use_geodesics_to_landmarks) {
auto landmark_filenames = s->get_landmarks_filenames();
if (landmark_filenames.empty()) {
throw std::invalid_argument(
"Geodesic distance from landmarks is enabled, but subject has no landmark files");
}
Eigen::VectorXd points;
if (!ParticleSystemEvaluation::read_particle_file(filenames[0], points)) {
SW_ERROR("Unable to read landmark file: {}", filenames[0]);
if (!ParticleSystemEvaluation::read_particle_file(landmark_filenames[0], points)) {
SW_ERROR("Unable to read landmark file: {}", landmark_filenames[0]);
}

// convert points to landmarks
std::vector<Point3> landmarks;
for (int i = 0; i < points.size() / 3; ++i) {
for (int j = 0; j < points.size() / 3; ++j) {
Point3 p;
p[0] = points(3 * i);
p[1] = points(3 * i + 1);
p[2] = points(3 * i + 2);
p[0] = points(3 * j);
p[1] = points(3 * j + 1);
p[2] = points(3 * j + 2);
landmarks.push_back(p);
}
mesh.computeLandmarkGeodesics(landmarks);
}

auto poly_data = mesh.getVTKMesh();

if (poly_data) {
if (poly_data->GetNumberOfCells() == 0) {
throw std::invalid_argument("Error, mesh had zero cells: " + filename);
}
// TODO This is a HACK for detecting contours
if (poly_data->GetCell(0)->GetNumberOfPoints() == 2) {
optimize->AddContour(poly_data);
} else {
optimize->AddMesh(poly_data);
}
} else {
auto pd = mesh.getVTKMesh();
if (!pd) {
throw std::invalid_argument("Error loading mesh: " + filename);
}
if (pd->GetNumberOfCells() == 0) {
throw std::invalid_argument("Error, mesh had zero cells: " + filename);
}
// TODO This is a HACK for detecting contours
item.is_contour_hack = (pd->GetCell(0)->GetNumberOfPoints() == 2);
item.poly_data = pd;
} else if (domain_type == DomainType::Contour) {
Mesh mesh = MeshUtils::threadSafeReadMesh(filename.c_str());
auto poly_data = mesh.getVTKMesh();
if (poly_data) {
optimize->AddContour(poly_data);
} else {
auto pd = mesh.getVTKMesh();
if (!pd) {
throw std::invalid_argument("Error loading contour: " + filename);
}
item.poly_data = pd;
} else {
Image image(filename);
if (s->is_fixed()) {
optimize->AddImage(nullptr, filename);
// Image domain — read now
item.image.emplace(filename);
}

si.work_item_indices.push_back(work_items.size());
work_items.push_back(std::move(item));
domain_count++;
}

subject_infos.push_back(std::move(si));
count++;
}

// === Phase B: Parallel pre-compute Surface objects for mesh domains (expensive) ===
// Collect indices of mesh work items that need Surface construction
std::vector<size_t> mesh_work_indices;
for (size_t idx = 0; idx < work_items.size(); idx++) {
auto& item = work_items[idx];
if (item.domain_type == DomainType::Mesh && !item.is_contour_hack && item.poly_data) {
mesh_work_indices.push_back(idx);
}
}

if (!mesh_work_indices.empty()) {
std::atomic<int> completed{0};
int total = static_cast<int>(mesh_work_indices.size());

double effective_remesh_percent = geodesics_enabled ? geodesic_remesh_percent : 100.0;
bool remeshing = effective_remesh_percent < 100.0;
std::string loading_message = remeshing ? "Loading and remeshing meshes" : "Loading meshes";
SW_PROGRESS(0, "{}: 0 / {}", loading_message, total);

tbb::parallel_for(tbb::blocked_range<size_t>(0, mesh_work_indices.size()),
[&](const tbb::blocked_range<size_t>& r) {
for (size_t i = r.begin(); i < r.end(); ++i) {
auto& item = work_items[mesh_work_indices[i]];

// Construct the main Surface (triangulation, cleaning, normals, cell locator, geodesics)
auto surface = std::make_shared<shapeworks::Surface>(
item.poly_data, geodesics_enabled,
static_cast<size_t>(geodesic_cache_multiplier));

// Create the Mesh wrapper and compute surface area
auto sw_mesh = std::make_shared<Mesh>(surface->get_polydata());
double surface_area = sw_mesh->getSurfaceArea();

// Create remeshed Surface for geodesics if needed
std::shared_ptr<shapeworks::Surface> geodesics_surface;
if (effective_remesh_percent >= 100.0) {
geodesics_surface = surface;
} else {
auto poly_copy = surface->get_polydata();
Mesh mesh_copy(poly_copy);
mesh_copy.remeshPercent(effective_remesh_percent, 1.0);
geodesics_surface =
std::make_shared<shapeworks::Surface>(mesh_copy.getVTKMesh());
}

item.surface = surface;
item.geodesics_surface = geodesics_surface;
item.sw_mesh = sw_mesh;
item.surface_area = surface_area;

int done = ++completed;
double progress = static_cast<double>(done) / static_cast<double>(total) * 100.0;
SW_PROGRESS(progress, "{}: {} / {}", loading_message, done, total);
}
});
}

// === Phase C: Sequential registration (fast) ===
domain_count = 0;
int subject_count = 0;
for (auto& si : subject_infos) {
if (abort_load_) {
return false;
}

std::vector<std::string> local_particle_filenames;
std::vector<std::string> world_particle_filenames;

for (size_t wi_idx : si.work_item_indices) {
auto& item = work_items[wi_idx];

if (item.domain_type == DomainType::Mesh && !item.is_contour_hack && item.poly_data) {
// Use pre-built Surface data
optimize->AddMesh(item.surface, item.geodesics_surface, item.sw_mesh, item.surface_area);
} else if (item.domain_type == DomainType::Mesh && item.is_contour_hack) {
optimize->AddContour(item.poly_data);
} else if (item.domain_type == DomainType::Contour) {
optimize->AddContour(item.poly_data);
} else {
// Image domain
if (item.is_fixed) {
optimize->AddImage(nullptr, item.filename);
} else {
optimize->AddImage(image, filename);
optimize->AddImage(*item.image, item.filename);
}
}

using TransformType = vnl_matrix_fixed<double, 4, 4>;
TransformType prefix_transform;
prefix_transform.set_identity();

int i = item.domain_index_in_subject;
auto& transforms = item.transforms;

if (i < transforms.size() && transforms[i].size() >= 12) {
prefix_transform[0][3] = transforms[i][9];
prefix_transform[1][3] = transforms[i][10];
Expand All @@ -801,19 +944,19 @@ bool OptimizeParameters::set_up_optimize(Optimize* optimize) {

domain_count++;

auto name = StringUtils::getBaseFilenameWithoutExtension(filename);
auto name = StringUtils::getBaseFilenameWithoutExtension(item.filename);

auto extension = get_particle_format();
auto prefix = get_output_prefix();
local_particle_filenames.push_back(prefix + name + "_local." + extension);
world_particle_filenames.push_back(prefix + name + "_world." + extension);
}
s->set_local_particle_filenames(local_particle_filenames);
s->set_world_particle_filenames(world_particle_filenames);
si.subject->set_local_particle_filenames(local_particle_filenames);
si.subject->set_world_particle_filenames(world_particle_filenames);

count++;
subject_count++;
if (load_callback_) {
load_callback_(count);
load_callback_(subject_count);
}
}

Expand Down
12 changes: 12 additions & 0 deletions Libs/Optimize/Sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,18 @@ void Sampler::AddMesh(std::shared_ptr<shapeworks::Surface> mesh, double geodesic
m_DomainList.push_back(domain);
}

void Sampler::AddMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area) {
auto domain = std::make_shared<MeshDomain>();

if (surface) {
this->m_Spacing = 1;
domain->SetMesh(surface, geodesics_surface, sw_mesh, surface_area);
this->m_meshes.push_back(surface->get_polydata());
}
m_DomainList.push_back(domain);
}

void Sampler::AddContour(vtkSmartPointer<vtkPolyData> poly_data) {
auto domain = std::make_shared<ContourDomain>();

Expand Down
2 changes: 2 additions & 0 deletions Libs/Optimize/Sampler.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ class Sampler {
}

void AddMesh(std::shared_ptr<shapeworks::Surface> mesh, double geodesic_remesh_percent = 100);
void AddMesh(std::shared_ptr<Surface> surface, std::shared_ptr<Surface> geodesics_surface,
std::shared_ptr<Mesh> sw_mesh, double surface_area);

void AddContour(vtkSmartPointer<vtkPolyData> poly_data);

Expand Down
Loading