-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathbuild.scala
More file actions
126 lines (106 loc) · 3.77 KB
/
build.scala
File metadata and controls
126 lines (106 loc) · 3.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
//> using scala "3.3"
//> using dep "ch.unibas.cs.gravis::scalismo-ui:0.92.0"
import scalismo.geometry.{Point, EuclideanVector, _3D}
import scalismo.common._
import scalismo.common.interpolation._
import scalismo.statisticalmodel._
import scalismo.statisticalmodel.dataset.DataCollection
import scalismo.utils.Random
import scalismo.io.StatisticalModelIO
import java.io.File
import scala.io.Source
/** Builds a point distribution model (PDM) from the `.pts` point clouds in a directory.
  *
  * Every `.pts` file in `dataPath` is loaded as one training shape. The point-wise mean of
  * all shapes serves as the reference, each shape is expressed as a deformation field over
  * that reference, and a PCA model is fitted to those fields. The model is written to
  * `models/<directory name>_pdm.h5.json`, relative to the working directory.
  *
  * The process exits with status 1 when the directory is invalid, contains no usable point
  * clouds, contains point clouds of differing sizes, or the model cannot be written.
  *
  * @param dataPath directory containing the `.pts` files
  */
@main
def buildPDM(dataPath: String = ""): Unit = {
  // Initialize Scalismo
  implicit val rng: Random = Random(42)
  scalismo.initialize()

  // Step 1: Load point clouds from the directory given on the command line
  val dataDir = new File(dataPath)
  if (!dataDir.exists() || !dataDir.isDirectory) {
    println(s"Error: ${dataDir.getAbsolutePath} is not a valid directory.")
    sys.exit(1)
  }
  println(s"Loading point clouds from ${dataDir.getAbsolutePath}")

  // listFiles() returns null on an I/O error and makes no ordering guarantee; sort by name
  // so that repeated runs process the shapes in the same order.
  val pointCloudFiles = Option(dataDir.listFiles())
    .getOrElse(Array.empty[File])
    .filter(file => file.isFile && file.getName.endsWith(".pts"))
    .sortBy(_.getName)
    .toIndexedSeq
  if (pointCloudFiles.isEmpty) {
    println(s"Error: no .pts files found in ${dataDir.getAbsolutePath}.")
    sys.exit(1)
  }
  val allPointClouds = pointCloudFiles.map(file => loadPointclouds(file))

  // The shapes are combined point by point, so they must all have the same number of
  // points. Without this check the zips below would silently truncate to the shortest one.
  val expectedSize = allPointClouds.head.length
  if (expectedSize == 0) {
    println(s"Error: ${pointCloudFiles.head.getName} contains no valid points.")
    sys.exit(1)
  }
  val mismatched = pointCloudFiles.zip(allPointClouds).filter { case (_, points) =>
    points.length != expectedSize
  }
  if (mismatched.nonEmpty) {
    println(
      s"Error: expected $expectedSize points per file (as in ${pointCloudFiles.head.getName}), but:"
    )
    mismatched.foreach { case (file, points) =>
      println(s"  ${file.getName} has ${points.length} points")
    }
    sys.exit(1)
  }

  // Step 2: Use the point-wise mean of all point clouds as the reference
  val referencePoints = computeMeanPointCloud(allPointClouds)
  val referenceDomain = UnstructuredPointsDomain(referencePoints)

  // Step 3: Convert point clouds to deformation fields relative to the reference
  val deformationFields = allPointClouds.map { points =>
    val deformations = points.zip(referencePoints).map { case (pt, refPt) =>
      pt - refPt
    }
    DiscreteField(referenceDomain, deformations)
  }

  // Step 4: Build the PDM using PCA
  val dataset = DataCollection(deformationFields)
  val pdm = PointDistributionModel.createUsingPCA(dataset)
  println(s"PDM built successfully with ${pdm.rank} principal components")

  // Step 5: Save the PDM to a file named after the input directory
  val outputDir = new File("models")
  if (!outputDir.exists()) {
    outputDir.mkdirs()
  }
  val dirName = dataDir.getName
  val outputFile = new File(outputDir, s"${dirName}_pdm.h5.json")

  // Use the specialized method for UnstructuredPointsDomain PDMs
  StatisticalModelIO.writeStatisticalPointModel3D(pdm, outputFile) match {
    case scala.util.Success(_) =>
      println(s"PDM successfully saved to: ${outputFile.getAbsolutePath}")
    case scala.util.Failure(ex) =>
      println(s"Failed to save PDM: ${ex.getMessage}")
      // Signal the failure to the calling shell instead of exiting with status 0.
      sys.exit(1)
  }
}
/** Reads a 3D point cloud from a whitespace-separated text file.
  *
  * Each line is trimmed and split on runs of whitespace. A line that yields exactly three
  * tokens becomes one point (x, y, z); blank lines and lines with any other token count are
  * skipped. A three-token line whose tokens are not numeric makes `toDouble` throw a
  * `NumberFormatException`.
  *
  * @param file the text file to read
  * @return the parsed points, in file order
  */
def loadPointclouds(file: File): IndexedSeq[Point[_3D]] = {
  val reader = Source.fromFile(file)
  try {
    reader
      .getLines()
      .map(_.trim)
      .filter(_.nonEmpty)
      .map(_.split("\\s+"))
      .collect { case Array(x, y, z) =>
        Point(x.toDouble, y.toDouble, z.toDouble)
      }
      // The iterator is lazy: materialize it here, before the source is closed.
      .toIndexedSeq
  } finally {
    reader.close()
  }
}
/** Computes the point-wise mean of a collection of point clouds.
  *
  * The i-th output point is the sum of the i-th points of the clouds, divided by the number
  * of clouds. Clouds are paired up with `zip`, which stops at the shorter operand, so the
  * result has as many points as the shortest cloud.
  *
  * @param pointClouds the point clouds to average
  * @return the mean point cloud
  * @throws IllegalArgumentException if `pointClouds` is empty
  */
def computeMeanPointCloud(
    pointClouds: Seq[IndexedSeq[Point[_3D]]]
): IndexedSeq[Point[_3D]] = {
  if (pointClouds.isEmpty) {
    throw new IllegalArgumentException(
      "Cannot compute mean of empty point cloud collection"
    )
  }

  val scale = 1.0 / pointClouds.length

  // One zero accumulator per point of the first cloud.
  val origin = Array.fill(pointClouds.head.length)(EuclideanVector(0, 0, 0))

  // Add every cloud onto the accumulators, position by position.
  val totals = pointClouds.foldLeft(origin) { (running, cloud) =>
    running.zip(cloud).map { case (partialSum, p) =>
      partialSum + EuclideanVector(p.x, p.y, p.z)
    }
  }

  // Scale each sum down to a mean and turn it back into a point.
  totals.map { total =>
    val mean = total * scale
    Point(mean.x, mean.y, mean.z)
  }.toIndexedSeq
}