@@ -166,6 +166,73 @@ def mesh_structures_equal(mesh1, mesh2) -> bool:
166
166
return True
167
167
168
168
169
+ def to_sorted (mesh : Meshes ) -> "Meshes" :
170
+ """
171
+ Create a new Meshes object, where each sub-mesh's vertices are sorted
172
+ alphabetically.
173
+
174
+ Returns:
175
+ A Meshes object with the same topology as this mesh, with vertices sorted
176
+ alphabetically.
177
+
178
+ Example:
179
+
180
+ For a mesh with verts [[2.3, .2, .4], [.0, .1, .2], [.0, .0, .1]] and a single
181
+ face [[0, 1, 2]], to_sorted will create a new mesh with verts [[.0, .0, .1],
182
+ [.0, .1, .2], [2.3, .2, .4]] and a single face [[2, 1, 0]]. This is useful to
183
+ create a semi-canonical representation of the mesh that is invariant to vertex
184
+ permutations, but not invariant to coordinate frame changes.
185
+ """
186
+ if mesh .textures is not None :
187
+ raise NotImplementedError (
188
+ "to_sorted is not implemented for meshes with "
189
+ f"{ type (mesh .textures ).__name__ } textures."
190
+ )
191
+
192
+ verts_list = mesh .verts_list ()
193
+ faces_list = mesh .faces_list ()
194
+ verts_sorted_list = []
195
+ faces_sorted_list = []
196
+
197
+ for verts , faces in zip (verts_list , faces_list ):
198
+ # Argsort the vertices alphabetically: sort_ids[k] corresponds to the id of
199
+ # the vertex in the non-sorted mesh that should sit at index k in the sorted mesh.
200
+ sort_ids = torch .tensor (
201
+ [
202
+ idx_and_val [0 ]
203
+ for idx_and_val in sorted (
204
+ enumerate (verts .tolist ()),
205
+ key = lambda idx_and_val : idx_and_val [1 ],
206
+ )
207
+ ],
208
+ device = mesh .device ,
209
+ )
210
+
211
+ # Resort the vertices. index_select allocates new memory.
212
+ verts_sorted = verts [sort_ids ]
213
+ verts_sorted_list .append (verts_sorted )
214
+
215
+ # The `faces` tensor contains vertex ids. Substitute old vertex ids for the
216
+ # new ones. new_vertex_ids is the inverse of sort_ids: new_vertex_ids[k]
217
+ # corresponds to the id of the vertex in the sorted mesh that is the same as
218
+ # vertex k in the non-sorted mesh.
219
+ new_vertex_ids = torch .argsort (sort_ids )
220
+ faces_sorted = (
221
+ torch .gather (new_vertex_ids , 0 , faces .flatten ())
222
+ .reshape (faces .shape )
223
+ .clone ()
224
+ )
225
+ faces_sorted_list .append (faces_sorted )
226
+
227
+ other = mesh .__class__ (verts = verts_sorted_list , faces = faces_sorted_list )
228
+ for k in mesh ._INTERNAL_TENSORS :
229
+ v = getattr (mesh , k )
230
+ if torch .is_tensor (v ):
231
+ setattr (other , k , v .clone ())
232
+
233
+ return other
234
+
235
+
169
236
class TestMeshes (TestCaseMixin , unittest .TestCase ):
170
237
def setUp (self ) -> None :
171
238
np .random .seed (42 )
@@ -1223,6 +1290,57 @@ def test_equality(self):
1223
1290
self .assertFalse (mesh_structures_equal (meshes1 , meshes2 ))
1224
1291
self .assertFalse (mesh_structures_equal (meshes1 , meshes3 ))
1225
1292
1293
+ def test_to_sorted (self ):
1294
+ mesh = init_simple_mesh ()
1295
+ sorted_mesh = to_sorted (mesh )
1296
+
1297
+ expected_verts = [
1298
+ torch .tensor (
1299
+ [[0.1 , 0.3 , 0.5 ], [0.5 , 0.2 , 0.1 ], [0.6 , 0.8 , 0.7 ]],
1300
+ dtype = torch .float32 ,
1301
+ ),
1302
+ torch .tensor (
1303
+ # Vertex permutation: 0->0, 1->3, 2->2, 3->1
1304
+ [[0.1 , 0.3 , 0.3 ], [0.1 , 0.5 , 0.3 ], [0.2 , 0.3 , 0.4 ], [0.6 , 0.7 , 0.8 ]],
1305
+ dtype = torch .float32 ,
1306
+ ),
1307
+ torch .tensor (
1308
+ # Vertex permutation: 0->2, 1->1, 2->4, 3->0, 4->3
1309
+ [
1310
+ [0.2 , 0.3 , 0.4 ],
1311
+ [0.2 , 0.4 , 0.8 ],
1312
+ [0.7 , 0.3 , 0.6 ],
1313
+ [0.9 , 0.3 , 0.8 ],
1314
+ [0.9 , 0.5 , 0.2 ],
1315
+ ],
1316
+ dtype = torch .float32 ,
1317
+ ),
1318
+ ]
1319
+
1320
+ expected_faces = [
1321
+ torch .tensor ([[0 , 1 , 2 ]], dtype = torch .int64 ),
1322
+ torch .tensor ([[0 , 3 , 2 ], [3 , 2 , 1 ]], dtype = torch .int64 ),
1323
+ torch .tensor (
1324
+ [
1325
+ [1 , 4 , 2 ],
1326
+ [2 , 1 , 0 ],
1327
+ [4 , 0 , 1 ],
1328
+ [3 , 0 , 4 ],
1329
+ [3 , 2 , 1 ],
1330
+ [3 , 0 , 1 ],
1331
+ [3 , 4 , 1 ],
1332
+ ],
1333
+ dtype = torch .int64 ,
1334
+ ),
1335
+ ]
1336
+
1337
+ self .assertFalse (mesh_structures_equal (mesh , sorted_mesh ))
1338
+ self .assertTrue (
1339
+ mesh_structures_equal (
1340
+ Meshes (verts = expected_verts , faces = expected_faces ), sorted_mesh
1341
+ )
1342
+ )
1343
+
1226
1344
@staticmethod
1227
1345
def compute_packed_with_init (
1228
1346
num_meshes : int = 10 , max_v : int = 100 , max_f : int = 300 , device : str = "cpu"
0 commit comments