Skip to content

Commit 0dfc6e0

Browse files
bottlerfacebook-github-bot
authored andcommitted
CPU function for points2vols
Summary: Single C++ function for the core of points2vols, not used anywhere yet. Added ability to control align_corners and the weight of each point, which may be useful later. Reviewed By: nikhilaravi Differential Revision: D29548607 fbshipit-source-id: a5cda7ec2c14836624e7dfe744c4bbb3f3d3dfe2
1 parent c7c6dea commit 0dfc6e0

File tree

5 files changed

+767
-0
lines changed

5 files changed

+767
-0
lines changed

pytorch3d/csrc/ext.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include "mesh_normal_consistency/mesh_normal_consistency.h"
2626
#include "packed_to_padded_tensor/packed_to_padded_tensor.h"
2727
#include "point_mesh/point_mesh_cuda.h"
28+
#include "points_to_volumes/points_to_volumes.h"
2829
#include "rasterize_meshes/rasterize_meshes.h"
2930
#include "rasterize_points/rasterize_points.h"
3031
#include "sample_farthest_points/sample_farthest_points.h"
@@ -47,6 +48,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
4748
m.def(
4849
"mesh_normal_consistency_find_verts", &MeshNormalConsistencyFindVertices);
4950
m.def("gather_scatter", &GatherScatter);
51+
m.def("points_to_volumes_forward", PointsToVolumesForward);
52+
m.def("points_to_volumes_backward", PointsToVolumesBackward);
5053
m.def("rasterize_points", &RasterizePoints);
5154
m.def("rasterize_points_backward", &RasterizePointsBackward);
5255
m.def("rasterize_meshes_backward", &RasterizeMeshesBackward);
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
/*
2+
* Copyright (c) Facebook, Inc. and its affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
#include <torch/extension.h>
11+
#include <cstdio>
12+
#include <tuple>
13+
#include "utils/pytorch3d_cutils.h"
14+
15+
/*
16+
volume_features and volume_densities are modified in place.
17+
18+
Args:
19+
points_3d: Batch of 3D point cloud coordinates of shape
20+
`(minibatch, N, 3)` where N is the number of points
21+
in each point cloud. Coordinates have to be specified in the
22+
local volume coordinates (ranging in [-1, 1]).
23+
points_features: Features of shape `(minibatch, N, feature_dim)`
24+
corresponding to the points of the input point cloud `points_3d`.
25+
volume_features: Batch of input feature volumes
26+
of shape `(minibatch, feature_dim, D, H, W)`
27+
volume_densities: Batch of input feature volume densities
28+
of shape `(minibatch, 1, D, H, W)`. Each voxel should
29+
contain a non-negative number corresponding to its
30+
opaqueness (the higher, the less transparent).
31+
32+
grid_sizes: `LongTensor` of shape (minibatch, 3) representing the
33+
spatial resolutions of each of the the non-flattened `volumes`
34+
tensors. Note that the following has to hold:
35+
`torch.prod(grid_sizes, dim=1)==N_voxels`.
36+
37+
point_weight: A scalar controlling how much weight a single point has.
38+
39+
mask: A binary mask of shape `(minibatch, N)` determining
40+
which 3D points are going to be converted to the resulting
41+
volume. Set to `None` if all points are valid.
42+
43+
align_corners: as for grid_sample.
44+
45+
splat: if true, trilinear interpolation. If false all the weight goes in
46+
the nearest voxel.
47+
*/
48+
49+
void PointsToVolumesForwardCpu(
50+
const torch::Tensor& points_3d,
51+
const torch::Tensor& points_features,
52+
const torch::Tensor& volume_densities,
53+
const torch::Tensor& volume_features,
54+
const torch::Tensor& grid_sizes,
55+
const torch::Tensor& mask,
56+
float point_weight,
57+
bool align_corners,
58+
bool splat);
59+
60+
inline void PointsToVolumesForward(
61+
const torch::Tensor& points_3d,
62+
const torch::Tensor& points_features,
63+
const torch::Tensor& volume_densities,
64+
const torch::Tensor& volume_features,
65+
const torch::Tensor& grid_sizes,
66+
const torch::Tensor& mask,
67+
float point_weight,
68+
bool align_corners,
69+
bool splat) {
70+
if (points_3d.is_cuda()) {
71+
#ifdef WITH_CUDA
72+
AT_ERROR("CUDA not implemented yet");
73+
#else
74+
AT_ERROR("Not compiled with GPU support.");
75+
#endif
76+
}
77+
PointsToVolumesForwardCpu(
78+
points_3d,
79+
points_features,
80+
volume_densities,
81+
volume_features,
82+
grid_sizes,
83+
mask,
84+
point_weight,
85+
align_corners,
86+
splat);
87+
}
88+
89+
// grad_points_3d and grad_points_features are modified in place.
90+
91+
void PointsToVolumesBackwardCpu(
92+
const torch::Tensor& points_3d,
93+
const torch::Tensor& points_features,
94+
const torch::Tensor& grid_sizes,
95+
const torch::Tensor& mask,
96+
float point_weight,
97+
bool align_corners,
98+
bool splat,
99+
const torch::Tensor& grad_volume_densities,
100+
const torch::Tensor& grad_volume_features,
101+
const torch::Tensor& grad_points_3d,
102+
const torch::Tensor& grad_points_features);
103+
104+
inline void PointsToVolumesBackward(
105+
const torch::Tensor& points_3d,
106+
const torch::Tensor& points_features,
107+
const torch::Tensor& grid_sizes,
108+
const torch::Tensor& mask,
109+
float point_weight,
110+
bool align_corners,
111+
bool splat,
112+
const torch::Tensor& grad_volume_densities,
113+
const torch::Tensor& grad_volume_features,
114+
const torch::Tensor& grad_points_3d,
115+
const torch::Tensor& grad_points_features) {
116+
if (points_3d.is_cuda()) {
117+
#ifdef WITH_CUDA
118+
AT_ERROR("CUDA not implemented yet");
119+
#else
120+
AT_ERROR("Not compiled with GPU support.");
121+
#endif
122+
}
123+
PointsToVolumesBackwardCpu(
124+
points_3d,
125+
points_features,
126+
grid_sizes,
127+
mask,
128+
point_weight,
129+
align_corners,
130+
splat,
131+
grad_volume_densities,
132+
grad_volume_features,
133+
grad_points_3d,
134+
grad_points_features);
135+
}

0 commit comments

Comments
 (0)