@@ -1067,18 +1067,18 @@ def load_ply(
1067
1067
return verts , faces
1068
1068
1069
1069
1070
- def _save_ply (
1070
+ def _write_ply_header (
1071
1071
f ,
1072
1072
* ,
1073
1073
verts : torch .Tensor ,
1074
1074
faces : Optional [torch .LongTensor ],
1075
1075
verts_normals : Optional [torch .Tensor ],
1076
1076
verts_colors : Optional [torch .Tensor ],
1077
1077
ascii : bool ,
1078
- decimal_places : Optional [ int ] = None ,
1078
+ colors_as_uint8 : bool ,
1079
1079
) -> None :
1080
1080
"""
1081
- Internal implementation for saving 3D data to a .ply file.
1081
+ Internal implementation for writing header when saving to a .ply file.
1082
1082
1083
1083
Args:
1084
1084
f: File object to which the 3D data should be written.
@@ -1087,7 +1087,8 @@ def _save_ply(
1087
1087
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
1088
1088
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
1089
1089
ascii: (bool) whether to use the ascii ply format.
1090
- decimal_places: Number of decimal places for saving if ascii=True.
1090
+ colors_as_uint8: Whether to save colors as numbers in the range
1091
+ [0, 255] instead of float32.
1091
1092
"""
1092
1093
assert not len (verts ) or (verts .dim () == 2 and verts .size (1 ) == 3 )
1093
1094
assert faces is None or not len (faces ) or (faces .dim () == 2 and faces .size (1 ) == 3 )
@@ -1113,33 +1114,88 @@ def _save_ply(
1113
1114
f .write (b"property float ny\n " )
1114
1115
f .write (b"property float nz\n " )
1115
1116
if verts_colors is not None :
1116
- f . write ( b"property float red \n " )
1117
- f . write (b"property float green\n " )
1118
- f .write (b"property float blue \n " )
1117
+ color_ply_type = b"uchar" if colors_as_uint8 else b"float"
1118
+ for color in (b"red" , b" green" , b"blue" ):
1119
+ f .write (b"property " + color_ply_type + b" " + color + b" \n " )
1119
1120
if len (verts ) and faces is not None :
1120
1121
f .write (f"element face { faces .shape [0 ]} \n " .encode ("ascii" ))
1121
1122
f .write (b"property list uchar int vertex_index\n " )
1122
1123
f .write (b"end_header\n " )
1123
1124
1125
+
1126
+ def _save_ply (
1127
+ f ,
1128
+ * ,
1129
+ verts : torch .Tensor ,
1130
+ faces : Optional [torch .LongTensor ],
1131
+ verts_normals : Optional [torch .Tensor ],
1132
+ verts_colors : Optional [torch .Tensor ],
1133
+ ascii : bool ,
1134
+ decimal_places : Optional [int ] = None ,
1135
+ colors_as_uint8 : bool ,
1136
+ ) -> None :
1137
+ """
1138
+ Internal implementation for saving 3D data to a .ply file.
1139
+
1140
+ Args:
1141
+ f: File object to which the 3D data should be written.
1142
+ verts: FloatTensor of shape (V, 3) giving vertex coordinates.
1143
+ faces: LongTensor of shape (F, 3) giving faces.
1144
+ verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
1145
+ verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
1146
+ ascii: (bool) whether to use the ascii ply format.
1147
+ decimal_places: Number of decimal places for saving if ascii=True.
1148
+ colors_as_uint8: Whether to save colors as numbers in the range
1149
+ [0, 255] instead of float32.
1150
+ """
1151
+ _write_ply_header (
1152
+ f ,
1153
+ verts = verts ,
1154
+ faces = faces ,
1155
+ verts_normals = verts_normals ,
1156
+ verts_colors = verts_colors ,
1157
+ ascii = ascii ,
1158
+ colors_as_uint8 = colors_as_uint8 ,
1159
+ )
1160
+
1124
1161
if not (len (verts )):
1125
1162
warnings .warn ("Empty 'verts' provided" )
1126
1163
return
1127
1164
1128
- verts_tensors = [verts ]
1165
+ color_np_type = np .ubyte if colors_as_uint8 else np .float32
1166
+ verts_dtype = [("verts" , np .float32 , 3 )]
1129
1167
if verts_normals is not None :
1130
- verts_tensors .append (verts_normals )
1168
+ verts_dtype .append (( "normals" , np . float32 , 3 ) )
1131
1169
if verts_colors is not None :
1132
- verts_tensors .append (verts_colors )
1170
+ verts_dtype .append (("colors" , color_np_type , 3 ))
1171
+
1172
+ vert_data = np .zeros (verts .shape [0 ], dtype = verts_dtype )
1173
+ vert_data ["verts" ] = verts .detach ().cpu ().numpy ()
1174
+ if verts_normals is not None :
1175
+ vert_data ["normals" ] = verts_normals .detach ().cpu ().numpy ()
1176
+ if verts_colors is not None :
1177
+ color_data = verts_colors .detach ().cpu ().numpy ()
1178
+ if colors_as_uint8 :
1179
+ vert_data ["colors" ] = np .rint (color_data * 255 )
1180
+ else :
1181
+ vert_data ["colors" ] = color_data
1133
1182
1134
- vert_data = torch .cat (verts_tensors , dim = 1 ).detach ().cpu ().numpy ()
1135
1183
if ascii :
1136
1184
if decimal_places is None :
1137
- float_str = "%f"
1185
+ float_str = b "%f"
1138
1186
else :
1139
- float_str = "%" + ".%df" % decimal_places
1140
- np .savetxt (f , vert_data , float_str )
1187
+ float_str = b"%" + b".%df" % decimal_places
1188
+ float_group_str = (float_str + b" " ) * 3
1189
+ formats = [float_group_str ]
1190
+ if verts_normals is not None :
1191
+ formats .append (float_group_str )
1192
+ if verts_colors is not None :
1193
+ formats .append (b"%d %d %d " if colors_as_uint8 else float_group_str )
1194
+ formats [- 1 ] = formats [- 1 ][:- 1 ] + b"\n "
1195
+ for line_data in vert_data :
1196
+ for data , format in zip (line_data , formats ):
1197
+ f .write (format % tuple (data ))
1141
1198
else :
1142
- assert vert_data .dtype == np .float32
1143
1199
if isinstance (f , BytesIO ):
1144
1200
# tofile only works with real files, but is faster than this.
1145
1201
f .write (vert_data .tobytes ())
@@ -1189,7 +1245,6 @@ def save_ply(
1189
1245
ascii: (bool) whether to use the ascii ply format.
1190
1246
decimal_places: Number of decimal places for saving if ascii=True.
1191
1247
path_manager: PathManager for interpreting f if it is a str.
1192
-
1193
1248
"""
1194
1249
1195
1250
if len (verts ) and not (verts .dim () == 2 and verts .size (1 ) == 3 ):
@@ -1227,6 +1282,7 @@ def save_ply(
1227
1282
verts_colors = None ,
1228
1283
ascii = ascii ,
1229
1284
decimal_places = decimal_places ,
1285
+ colors_as_uint8 = False ,
1230
1286
)
1231
1287
1232
1288
@@ -1272,8 +1328,14 @@ def save(
1272
1328
path_manager : PathManager ,
1273
1329
binary : Optional [bool ],
1274
1330
decimal_places : Optional [int ] = None ,
1331
+ colors_as_uint8 : bool = False ,
1275
1332
** kwargs ,
1276
1333
) -> bool :
1334
+ """
1335
+ Extra optional args:
1336
+ colors_as_uint8: (bool) Whether to save colors as numbers in the
1337
+ range [0, 255] instead of float32.
1338
+ """
1277
1339
if not endswith (path , self .known_suffixes ):
1278
1340
return False
1279
1341
@@ -1307,6 +1369,7 @@ def save(
1307
1369
verts_normals = verts_normals ,
1308
1370
ascii = binary is False ,
1309
1371
decimal_places = decimal_places ,
1372
+ colors_as_uint8 = colors_as_uint8 ,
1310
1373
)
1311
1374
return True
1312
1375
@@ -1342,8 +1405,14 @@ def save(
1342
1405
path_manager : PathManager ,
1343
1406
binary : Optional [bool ],
1344
1407
decimal_places : Optional [int ] = None ,
1408
+ colors_as_uint8 : bool = False ,
1345
1409
** kwargs ,
1346
1410
) -> bool :
1411
+ """
1412
+ Extra optional args:
1413
+ colors_as_uint8: (bool) Whether to save colors as numbers in the
1414
+ range [0, 255] instead of float32.
1415
+ """
1347
1416
if not endswith (path , self .known_suffixes ):
1348
1417
return False
1349
1418
@@ -1360,5 +1429,6 @@ def save(
1360
1429
faces = None ,
1361
1430
ascii = binary is False ,
1362
1431
decimal_places = decimal_places ,
1432
+ colors_as_uint8 = colors_as_uint8 ,
1363
1433
)
1364
1434
return True
0 commit comments