Skip to content

Commit dd76b41

Browse files
bottlerfacebook-github-bot
authored andcommitted
save colors as uint8 in PLY
Summary: Allow saving colors as 8bit when writing .ply files. Reviewed By: patricklabatut, nikitos9000 Differential Revision: D30905312 fbshipit-source-id: 44500982c9ed6d6ee901e04f9623e22792a0e7f7
1 parent 1b1ba56 commit dd76b41

File tree

2 files changed

+131
-20
lines changed

2 files changed

+131
-20
lines changed

pytorch3d/io/ply_io.py

Lines changed: 86 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1067,18 +1067,18 @@ def load_ply(
10671067
return verts, faces
10681068

10691069

1070-
def _save_ply(
1070+
def _write_ply_header(
10711071
f,
10721072
*,
10731073
verts: torch.Tensor,
10741074
faces: Optional[torch.LongTensor],
10751075
verts_normals: Optional[torch.Tensor],
10761076
verts_colors: Optional[torch.Tensor],
10771077
ascii: bool,
1078-
decimal_places: Optional[int] = None,
1078+
colors_as_uint8: bool,
10791079
) -> None:
10801080
"""
1081-
Internal implementation for saving 3D data to a .ply file.
1081+
Internal implementation for writing header when saving to a .ply file.
10821082
10831083
Args:
10841084
f: File object to which the 3D data should be written.
@@ -1087,7 +1087,8 @@ def _save_ply(
10871087
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
10881088
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
10891089
ascii: (bool) whether to use the ascii ply format.
1090-
decimal_places: Number of decimal places for saving if ascii=True.
1090+
colors_as_uint8: Whether to save colors as numbers in the range
1091+
[0, 255] instead of float32.
10911092
"""
10921093
assert not len(verts) or (verts.dim() == 2 and verts.size(1) == 3)
10931094
assert faces is None or not len(faces) or (faces.dim() == 2 and faces.size(1) == 3)
@@ -1113,33 +1114,88 @@ def _save_ply(
11131114
f.write(b"property float ny\n")
11141115
f.write(b"property float nz\n")
11151116
if verts_colors is not None:
1116-
f.write(b"property float red\n")
1117-
f.write(b"property float green\n")
1118-
f.write(b"property float blue\n")
1117+
color_ply_type = b"uchar" if colors_as_uint8 else b"float"
1118+
for color in (b"red", b"green", b"blue"):
1119+
f.write(b"property " + color_ply_type + b" " + color + b"\n")
11191120
if len(verts) and faces is not None:
11201121
f.write(f"element face {faces.shape[0]}\n".encode("ascii"))
11211122
f.write(b"property list uchar int vertex_index\n")
11221123
f.write(b"end_header\n")
11231124

1125+
1126+
def _save_ply(
1127+
f,
1128+
*,
1129+
verts: torch.Tensor,
1130+
faces: Optional[torch.LongTensor],
1131+
verts_normals: Optional[torch.Tensor],
1132+
verts_colors: Optional[torch.Tensor],
1133+
ascii: bool,
1134+
decimal_places: Optional[int] = None,
1135+
colors_as_uint8: bool,
1136+
) -> None:
1137+
"""
1138+
Internal implementation for saving 3D data to a .ply file.
1139+
1140+
Args:
1141+
f: File object to which the 3D data should be written.
1142+
verts: FloatTensor of shape (V, 3) giving vertex coordinates.
1143+
faces: LongTensor of shape (F, 3) giving faces.
1144+
verts_normals: FloatTensor of shape (V, 3) giving vertex normals.
1145+
verts_colors: FloatTensor of shape (V, 3) giving vertex colors.
1146+
ascii: (bool) whether to use the ascii ply format.
1147+
decimal_places: Number of decimal places for saving if ascii=True.
1148+
colors_as_uint8: Whether to save colors as numbers in the range
1149+
[0, 255] instead of float32.
1150+
"""
1151+
_write_ply_header(
1152+
f,
1153+
verts=verts,
1154+
faces=faces,
1155+
verts_normals=verts_normals,
1156+
verts_colors=verts_colors,
1157+
ascii=ascii,
1158+
colors_as_uint8=colors_as_uint8,
1159+
)
1160+
11241161
if not (len(verts)):
11251162
warnings.warn("Empty 'verts' provided")
11261163
return
11271164

1128-
verts_tensors = [verts]
1165+
color_np_type = np.ubyte if colors_as_uint8 else np.float32
1166+
verts_dtype = [("verts", np.float32, 3)]
11291167
if verts_normals is not None:
1130-
verts_tensors.append(verts_normals)
1168+
verts_dtype.append(("normals", np.float32, 3))
11311169
if verts_colors is not None:
1132-
verts_tensors.append(verts_colors)
1170+
verts_dtype.append(("colors", color_np_type, 3))
1171+
1172+
vert_data = np.zeros(verts.shape[0], dtype=verts_dtype)
1173+
vert_data["verts"] = verts.detach().cpu().numpy()
1174+
if verts_normals is not None:
1175+
vert_data["normals"] = verts_normals.detach().cpu().numpy()
1176+
if verts_colors is not None:
1177+
color_data = verts_colors.detach().cpu().numpy()
1178+
if colors_as_uint8:
1179+
vert_data["colors"] = np.rint(color_data * 255)
1180+
else:
1181+
vert_data["colors"] = color_data
11331182

1134-
vert_data = torch.cat(verts_tensors, dim=1).detach().cpu().numpy()
11351183
if ascii:
11361184
if decimal_places is None:
1137-
float_str = "%f"
1185+
float_str = b"%f"
11381186
else:
1139-
float_str = "%" + ".%df" % decimal_places
1140-
np.savetxt(f, vert_data, float_str)
1187+
float_str = b"%" + b".%df" % decimal_places
1188+
float_group_str = (float_str + b" ") * 3
1189+
formats = [float_group_str]
1190+
if verts_normals is not None:
1191+
formats.append(float_group_str)
1192+
if verts_colors is not None:
1193+
formats.append(b"%d %d %d " if colors_as_uint8 else float_group_str)
1194+
formats[-1] = formats[-1][:-1] + b"\n"
1195+
for line_data in vert_data:
1196+
for data, format in zip(line_data, formats):
1197+
f.write(format % tuple(data))
11411198
else:
1142-
assert vert_data.dtype == np.float32
11431199
if isinstance(f, BytesIO):
11441200
# tofile only works with real files, but is faster than this.
11451201
f.write(vert_data.tobytes())
@@ -1189,7 +1245,6 @@ def save_ply(
11891245
ascii: (bool) whether to use the ascii ply format.
11901246
decimal_places: Number of decimal places for saving if ascii=True.
11911247
path_manager: PathManager for interpreting f if it is a str.
1192-
11931248
"""
11941249

11951250
if len(verts) and not (verts.dim() == 2 and verts.size(1) == 3):
@@ -1227,6 +1282,7 @@ def save_ply(
12271282
verts_colors=None,
12281283
ascii=ascii,
12291284
decimal_places=decimal_places,
1285+
colors_as_uint8=False,
12301286
)
12311287

12321288

@@ -1272,8 +1328,14 @@ def save(
12721328
path_manager: PathManager,
12731329
binary: Optional[bool],
12741330
decimal_places: Optional[int] = None,
1331+
colors_as_uint8: bool = False,
12751332
**kwargs,
12761333
) -> bool:
1334+
"""
1335+
Extra optional args:
1336+
colors_as_uint8: (bool) Whether to save colors as numbers in the
1337+
range [0, 255] instead of float32.
1338+
"""
12771339
if not endswith(path, self.known_suffixes):
12781340
return False
12791341

@@ -1307,6 +1369,7 @@ def save(
13071369
verts_normals=verts_normals,
13081370
ascii=binary is False,
13091371
decimal_places=decimal_places,
1372+
colors_as_uint8=colors_as_uint8,
13101373
)
13111374
return True
13121375

@@ -1342,8 +1405,14 @@ def save(
13421405
path_manager: PathManager,
13431406
binary: Optional[bool],
13441407
decimal_places: Optional[int] = None,
1408+
colors_as_uint8: bool = False,
13451409
**kwargs,
13461410
) -> bool:
1411+
"""
1412+
Extra optional args:
1413+
colors_as_uint8: (bool) Whether to save colors as numbers in the
1414+
range [0, 255] instead of float32.
1415+
"""
13471416
if not endswith(path, self.known_suffixes):
13481417
return False
13491418

@@ -1360,5 +1429,6 @@ def save(
13601429
faces=None,
13611430
ascii=binary is False,
13621431
decimal_places=decimal_places,
1432+
colors_as_uint8=colors_as_uint8,
13631433
)
13641434
return True

tests/test_io_ply.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -528,29 +528,70 @@ def test_save_pointcloud(self):
528528
).encode("ascii")
529529
data = struct.pack("<" + "f" * 48, *range(48))
530530
points = torch.FloatTensor([0, 1, 2]) + 6 * torch.arange(8)[:, None]
531-
features = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None]
531+
features_large = torch.FloatTensor([3, 4, 5]) + 6 * torch.arange(8)[:, None]
532+
features = features_large / 255.0
533+
pointcloud_largefeatures = Pointclouds(
534+
points=[points], features=[features_large]
535+
)
532536
pointcloud = Pointclouds(points=[points], features=[features])
533537

534538
io = IO()
535539
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
536-
io.save_pointcloud(data=pointcloud, path=f.name)
540+
io.save_pointcloud(data=pointcloud_largefeatures, path=f.name)
537541
f.flush()
538542
f.seek(0)
539543
actual_data = f.read()
540544
reloaded_pointcloud = io.load_pointcloud(f.name)
541545

542546
self.assertEqual(header + data, actual_data)
543547
self.assertClose(reloaded_pointcloud.points_list()[0], points)
544-
self.assertClose(reloaded_pointcloud.features_list()[0], features)
548+
self.assertClose(reloaded_pointcloud.features_list()[0], features_large)
549+
# Test the load-save cycle leaves file completely unchanged
550+
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
551+
io.save_pointcloud(
552+
data=reloaded_pointcloud,
553+
path=f.name,
554+
)
555+
f.flush()
556+
f.seek(0)
557+
data2 = f.read()
558+
self.assertEqual(data2, actual_data)
545559

546560
with NamedTemporaryFile(mode="r", suffix=".ply") as f:
547-
io.save_pointcloud(data=pointcloud, path=f.name, binary=False)
561+
io.save_pointcloud(
562+
data=pointcloud, path=f.name, binary=False, decimal_places=9
563+
)
548564
reloaded_pointcloud2 = io.load_pointcloud(f.name)
549565
self.assertEqual(f.readline(), "ply\n")
550566
self.assertEqual(f.readline(), "format ascii 1.0\n")
551567
self.assertClose(reloaded_pointcloud2.points_list()[0], points)
552568
self.assertClose(reloaded_pointcloud2.features_list()[0], features)
553569

570+
for binary in [True, False]:
571+
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
572+
io.save_pointcloud(
573+
data=pointcloud, path=f.name, colors_as_uint8=True, binary=binary
574+
)
575+
f.flush()
576+
f.seek(0)
577+
actual_data = f.read()
578+
reloaded_pointcloud3 = io.load_pointcloud(f.name)
579+
self.assertClose(reloaded_pointcloud3.features_list()[0], features)
580+
self.assertIn(b"property uchar green", actual_data)
581+
582+
# Test the load-save cycle leaves file completely unchanged
583+
with NamedTemporaryFile(mode="rb", suffix=".ply") as f:
584+
io.save_pointcloud(
585+
data=reloaded_pointcloud3,
586+
path=f.name,
587+
binary=binary,
588+
colors_as_uint8=True,
589+
)
590+
f.flush()
591+
f.seek(0)
592+
data2 = f.read()
593+
self.assertEqual(data2, actual_data)
594+
554595
def test_load_pointcloud_bad_order(self):
555596
"""
556597
Ply file with a strange property order

0 commit comments

Comments
 (0)