diff --git a/include/tiny-cuda-nn/encodings/identity.h b/include/tiny-cuda-nn/encodings/identity.h index 7518775d..7a47c82a 100644 --- a/include/tiny-cuda-nn/encodings/identity.h +++ b/include/tiny-cuda-nn/encodings/identity.h @@ -60,7 +60,7 @@ __global__ void identity( const uint32_t j = encoded_index - i * fan_out; if (j >= num_to_encode) { - data_out(j, i) = 1; + data_out(j, i) = 0; // data_out(j, i) = 0; } else { data_out(j, i) = data_in(j, i) * scale + offset; } @@ -84,6 +84,25 @@ __global__ void identity_backward( dL_dx(j, i) = (T)((float)dL_dy(j, i) * scale); } +template +__global__ void identity_backward_backward( + const uint32_t num_outputs, + const uint32_t num_elements, + const uint32_t n_dims_to_encode, + const float scale, + MatrixView dL_ddLdy, + MatrixView dL_ddLdx) +{ + const uint32_t output_index = threadIdx.x + blockIdx.x * blockDim.x; + if (output_index >= num_outputs) return; + + const uint32_t i = output_index / n_dims_to_encode; + const uint32_t j = output_index - i * n_dims_to_encode; + + // The identity encoding can simply pass through the derivative. + dL_ddLdx(j, i) = (T)(dL_ddLdy(j, i) * scale); +} + template class IdentityEncoding : public Encoding { public: @@ -139,6 +158,33 @@ class IdentityEncoding : public Encoding { ); } + void backward_backward_input_impl( + cudaStream_t stream, + const Context& ctx, + const GPUMatrixDynamic& input, + const GPUMatrixDynamic& dL_ddLdinput, + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic* dL_ddLdoutput = nullptr, + GPUMatrixDynamic* dL_dinput = nullptr, + bool use_inference_params = false, + GradientMode param_gradients_mode = GradientMode::Overwrite + ) override { + if (!dL_dinput || !dL_ddLdoutput || padded_output_width() == 0) { + return; + } + + linear_kernel(identity_backward_backward, 0, stream, + input.n() * m_n_dims_to_encode, + input.n(), + m_n_dims_to_encode, + m_scale, + dL_ddLdinput.view(), + dL_ddLdoutput->view() + ); + + // dL_dinput: don't need to calculate this term, it's default set as 0.0 + } + uint32_t input_width() const override { return m_n_dims_to_encode; } diff --git a/include/tiny-cuda-nn/network_with_input_encoding.h b/include/tiny-cuda-nn/network_with_input_encoding.h index 525cc625..202c1c44 100644 --- a/include/tiny-cuda-nn/network_with_input_encoding.h +++ b/include/tiny-cuda-nn/network_with_input_encoding.h @@ -38,6 +38,33 @@ namespace tcnn { +// element-wise convert float* to T* +template +__global__ void element_wise_convert(uint32_t n_elements, float* in, T* out) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + out[i] = (T)in[i]; +} + +// element-wise convert T* to float* and then add back to *out +template +__global__ void element_wise_convert_float(uint32_t n_elements, T* in, float* out) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + out[i] += (float)in[i]; +} + +// element-wise add +template +__global__ void element_wise_add(uint32_t n_elements, T* in, T* out) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + out[i] += in[i]; +} + template class NetworkWithInputEncoding : public Network { public: @@ -90,8 +117,8 @@ class NetworkWithInputEncoding : public Network { bool use_inference_params = false, GradientMode param_gradients_mode = GradientMode::Overwrite ) override { - GPUMatrixDynamic dL_dnetwork_input; if (m_encoding->n_params() > 0 || dL_dinput) { + // dL_dnetwork_input becomes a member of the class instance dL_dnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, m_encoding->preferred_output_layout()}; } @@ -112,6 +139,89 @@ class NetworkWithInputEncoding : public Network { } } + void backward_backward_input_impl( + cudaStream_t stream, + const Context& ctx, + const GPUMatrixDynamic& input, + const GPUMatrixDynamic& dL_ddLdinput, + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic* dL_ddLdoutput = nullptr, + GPUMatrixDynamic* dL_dinput = nullptr, + bool use_inference_params = false, + GradientMode param_gradients_mode = GradientMode::Overwrite + ) override { + const auto& forward = dynamic_cast(ctx); + + // dL_ddLdinput of m_network->backward_baward_input equals to dL_dLdencoding_output (different names) + GPUMatrixDynamic dL_dLdnetwork_input; + + if (m_encoding->n_params() > 0) { + dL_dLdnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, dL_ddLdinput.layout()}; + // cudaMemsetAsync: set dL_dLdnetwork_input.data() with 0.0 to avoid NaN initialization + CUDA_CHECK_THROW(cudaMemsetAsync(dL_dLdnetwork_input.data(), 0, dL_dLdnetwork_input.n() * dL_dLdnetwork_input.m() * sizeof(T), stream)); + + // encoding backward backward + m_encoding->backward_backward_input( + stream, + *forward.encoding_ctx, + input, + dL_ddLdinput, + dL_dnetwork_input, // dL1_denc_output + &dL_dLdnetwork_input, // dL2_ddL1_denc_output + dL_dinput, + use_inference_params, + param_gradients_mode + ); + } else { // copy dL_ddLdinput (float) to dL_dLdnetwork_input (T) + dL_dLdnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, dL_ddLdinput.layout()}; + linear_kernel(element_wise_convert, 0, stream, dL_dLdnetwork_input.n() * dL_dLdnetwork_input.m(), dL_ddLdinput.data(), dL_dLdnetwork_input.data()); + } + + // dL2_dinput of m_network->backward_backward_input + GPUMatrixDynamic dL2_dnetwork_input; + if (m_encoding->n_params() > 0 || dL_dinput) { + dL2_dnetwork_input = {m_encoding->padded_output_width(), input.n(), stream, m_encoding->preferred_output_layout()}; + } + + // network backward backward + m_network->backward_backward_input( + stream, + *forward.network_ctx, + forward.network_input, // enc_output i.e. network_input + dL_dLdnetwork_input, // dL2_dL1dnetwork_input + dL_doutput, + dL_ddLdoutput ? dL_ddLdoutput : nullptr, + dL2_dnetwork_input.data() ? &dL2_dnetwork_input : nullptr, // dL2_dinput of network + use_inference_params, + param_gradients_mode + ); + + // dL2dnetwork_input backward to dL2dinput, first order backward + GPUMatrixDynamic dL2_dinput; + if (m_encoding->n_params() > 0 || dL2_dnetwork_input.data()) { + dL2_dinput = {m_encoding->input_width(), input.n(), stream, input.layout()}; + } + + if (m_encoding->n_params() > 0) { + // backward dL2dnetwork_input to dL2dinput + m_encoding->backward( + stream, + *forward.encoding_ctx, + input, + forward.network_input, // enc_output + dL2_dnetwork_input, // dL2_dencoding_output + &dL2_dinput, + use_inference_params, + GradientMode::Accumulate // dL2denc_w : add up 1st order term + ); + + linear_kernel(element_wise_add, 0, stream, dL_dinput->n() * dL_dinput->m(), dL2_dinput.data(), dL_dinput->data()); + + } else if (dL2_dnetwork_input.data()) { + linear_kernel(element_wise_convert_float, 0, stream, dL_dinput->n() * dL_dinput->m(), dL2_dnetwork_input.data(), dL_dinput->data()); + } + } + void set_params_impl(T* params, T* inference_params, T* gradients) override { size_t offset = 0; m_network->set_params(params + offset, inference_params + offset, gradients + offset); @@ -181,6 +291,7 @@ class NetworkWithInputEncoding : public Network { private: std::shared_ptr> m_encoding; std::shared_ptr> m_network; + GPUMatrixDynamic dL_dnetwork_input; struct ForwardContext : public Context { GPUMatrixDynamic network_input; diff --git a/include/tiny-cuda-nn/networks/cutlass_mlp.h b/include/tiny-cuda-nn/networks/cutlass_mlp.h index 2cca2d3c..8dcd32c6 100644 --- a/include/tiny-cuda-nn/networks/cutlass_mlp.h +++ b/include/tiny-cuda-nn/networks/cutlass_mlp.h @@ -66,6 +66,28 @@ class CutlassMLP : public Network { GradientMode param_gradients_mode = GradientMode::Overwrite ) override; + void backward_backward_input_impl( + cudaStream_t stream, + const Context& ctx, + const GPUMatrixDynamic& input, + const GPUMatrixDynamic& dL_ddLdinput, + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic* dL_ddLdoutput = nullptr, + GPUMatrixDynamic* dL_dinput = nullptr, + bool use_inference_params = false, + GradientMode param_gradients_mode = GradientMode::Overwrite + ) override; + + bool prepare_backward_variables( + cudaStream_t stream, + const std::vector>& output, + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic& backward_output_tmp, + std::vector>& dL1dp, + std::vector>& dL1doutput, + bool use_inference_params + ); + void set_params_impl(T* params, T* inference_params, T* gradients) override; void initialize_params(pcg32& rnd, float* params_full_precision, float scale = 1) override; diff --git a/scripts/test_armadillo_bwdbwd.py b/scripts/test_armadillo_bwdbwd.py new file mode 100644 index 00000000..4d6a234a --- /dev/null +++ b/scripts/test_armadillo_bwdbwd.py @@ -0,0 +1,599 @@ +#!/usr/bin/env python3 + +import torch +import trimesh +import torch.nn as nn +from torch import autograd +from torch.optim import Adam, SGD +import torch.nn.functional as F +import skimage +import random +import apex + +import sys +import os +import numpy as np +import time +NoLog = False + +try: + import tinycudann as tcnn +except ImportError: + print("This script requires the tiny-cuda-nn extension for PyTorch.") + print("You can install it by running:") + print("============================================================") + print("tiny-cuda-nn$ cd bindings/torch") + print("tiny-cuda-nn/bindings/torch$ python setup.py install") + print("============================================================") + sys.exit() + +torch.set_printoptions(precision=10) + +def GenerateRasterPoints(midpoint, extents, resolution): + max_axis = max(extents) + voxel_size = max_axis / resolution +# 021 120 201 210 + points = np.meshgrid( + np.linspace(midpoint[0] - extents[0]*.501, midpoint[0] + extents[0]*.501, int(resolution * extents[0]/max_axis)), + np.linspace(midpoint[1] - extents[1]*.501, midpoint[1] + extents[1]*.501, int(resolution * extents[1]/max_axis)), + np.linspace(midpoint[2] - extents[2]*.501, midpoint[2] + extents[2]*.501, int(resolution * extents[2]/max_axis)) + ) + points = np.stack(points) + points = np.swapaxes(points, 1, 2) + points = points.reshape(3, -1).transpose().astype(np.float32) + res = np.array([int(resolution * extents[0]/max_axis),int(resolution * extents[1]/max_axis),int(resolution * extents[2]/max_axis)]) + return points, res#voxel_size, res + +class SDF(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }) + b_flag = False + self.decoder = nn.Sequential( + nn.Linear(self.encoder.n_output_dims, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 1, bias=b_flag) + ) + + # write into numpy + #params_enc = self.encoder.params.data.clone().cpu().numpy() + #np.save('numpy/params_sdf_enc.npy', params_enc) + + params_enc = np.load('./numpy/params_sdf_enc.npy') + self.encoder.params.data = torch.from_numpy(params_enc) + + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + params_input = np.load("./numpy/params_input.npy") + m.weight.data = torch.from_numpy(params_input) + elif idx == 1: + params_hidden_1 = np.load("./numpy/params_hidden_1.npy") + m.weight.data = torch.from_numpy(params_hidden_1) + elif idx == 2: + params_hidden_2 = np.load("./numpy/params_hidden_2.npy") + m.weight.data = torch.from_numpy(params_hidden_2) + else: + params_output = np.load("./numpy/params_output.npy") + m.weight.data = torch.from_numpy(params_output) + idx += 1 + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in Pytorch 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + def loadMesh(self,mesh,batch,iteration): + sdf=[] + points=[] + self.mesh = trimesh.load_mesh(mesh) + self.name = mesh + scale_fac = np.max(self.mesh.extents) + self.mesh = trimesh.Trimesh.apply_translation(self.mesh,-1 * (self.mesh.bounds[0]+self.mesh.bounds[1])/2) + self.mesh = trimesh.Trimesh.apply_scale(self.mesh,1/scale_fac) + + + if not os.path.exists(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz"): + print("Sample SDF form Model") + for i in range(iteration): + points_ = np.random.rand(batch,3) - 0.5 + query = trimesh.proximity.ProximityQuery(self.mesh) + sdf_ = query.signed_distance(points_) + sdf.append(sdf_) + points.append(points_) + print("Batch Size:",i) + self.sdf = np.array(sdf).reshape(-1,1) * -1 + self.points = np.array(points).reshape(-1,3) + np.savetxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz",self.points) + np.savetxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz",self.sdf) + else: + self.sdf = np.loadtxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.points = np.loadtxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.sdf = self.sdf.reshape(iteration,-1,1) + self.points = self.points.reshape(iteration,-1,3) + + def saveMesh(self, gridres, file, it, level=0): + + grid,res = GenerateRasterPoints(np.array([0,0,0]), np.array([1,1,1]), gridres) + print(grid.shape) + print(res) + grid_size = 64 + sdf = [] + spacing = 64*64*64 + for i in range(0,grid.shape[0],spacing): + grid_portion = grid[i:i+spacing] + sdf_portion = self(torch.tensor(grid_portion, dtype=torch.float32).cuda()).detach().cpu().numpy()[:,0] + sdf.append(sdf_portion) + sdf = np.concatenate(sdf) + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=0) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it) +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-5) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-5' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-4) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-4' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + def forward(self, x): + encoded = self.encoder(x).to(dtype=torch.float) + sdf = self.decoder(encoded) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return sdf, nablas + +class SDF_TCNN(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder_decoder = tcnn.NetworkWithInputEncoding( + encoding_config = { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }, + n_input_dims=3, #self.encoder.n_output_dims, #3 + n_output_dims=1, #64, + network_config={ + "otype": "CutlassMLP", #"FullyFusedMLP", + "activation": "Softplus", + "output_activation": "None", #"Softplus", + "n_neurons": 64, + "n_hidden_layers": 3 #7 + }, + ) + + + # init encoder params from file + file_dir = os.getcwd() + + #enc_params = np.load(os.path.join(file_dir, 'numpy/params_sdf_enc.npy')) + #enc_p_torch = torch.from_numpy(enc_params) + + num_params = len(self.encoder_decoder.params) + num_enc_params = 12196240 + start_params = num_params - num_enc_params + # for i in range(num_enc_params): + # tmp = enc_p_torch[i].detach() + # self.encoder_decoder.params.data.index_fill_(0, torch.tensor(i+start_params, dtype=torch.int64).cuda(), torch.tensor(tmp).cuda()) + + # init decoder params from file + params_input = np.load(os.path.join(file_dir, './numpy/params_input.npy')) + params_input_tensor = torch.from_numpy(params_input) + + ## init input layer, notice NOT to fill unneccessary blanks (0.0) + idx_tcnn = 0 + for i in range(params_input_tensor.shape[0]): + for j in range(params_input_tensor.shape[1]): + self.encoder_decoder.params[j+idx_tcnn].data.copy_(params_input_tensor[i, j]) + idx_tcnn += 32 # column num in TCNN + + ## init hidden layers + params = np.load(os.path.join(file_dir, './numpy/params_hidden_1.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 2048 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, './numpy/params_hidden_2.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 6144 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, './numpy/params_output.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 10240 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in TCNN 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + def loadMesh(self,mesh,batch,iteration): + sdf=[] + points=[] + self.mesh = trimesh.load_mesh(mesh) + self.name = mesh + scale_fac = np.max(self.mesh.extents) + self.mesh = trimesh.Trimesh.apply_translation(self.mesh,-1 * (self.mesh.bounds[0]+self.mesh.bounds[1])/2) + self.mesh = trimesh.Trimesh.apply_scale(self.mesh,1/scale_fac) + + if not os.path.exists(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz"): + print("Sample SDF form Model") + for i in range(iteration): + points_ = np.random.rand(batch,3) - 0.5 + query = trimesh.proximity.ProximityQuery(self.mesh) + sdf_ = query.signed_distance(points_) + sdf.append(sdf_) + points.append(points_) + print("Batch Size:",i) + self.sdf = np.array(sdf).reshape(-1,1) * -1 + self.points = np.array(points).reshape(-1,3) + np.savetxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz",self.points) + np.savetxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz",self.sdf) + else: + self.sdf = np.loadtxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.points = np.loadtxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.sdf = self.sdf.reshape(iteration,-1,1) + self.points = self.points.reshape(iteration,-1,3) + + def saveMesh(self, gridres, file, it, level=0): + + grid,res = GenerateRasterPoints(np.array([0,0,0]), np.array([1,1,1]), gridres) + print(grid.shape) + print(res) + grid_size = 64 + sdf = [] + spacing = 64*64*64 + for i in range(0,grid.shape[0],spacing): + grid_portion = grid[i:i+spacing] + sdf_portion = self(torch.tensor(grid_portion, dtype=torch.float32).cuda()).detach().cpu().numpy()[:,0] + sdf.append(sdf_portion) + sdf = np.concatenate(sdf) + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=0) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it) +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-5) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-5' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-4) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-4' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + def forward(self, x): + sdf = self.encoder_decoder(x) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return sdf, nablas + + +if __name__ == '__main__': + """ + NOTE: Jianfei: I provide three testing tools for backward_backward functionality. + Play around as you want :) + 1. test_train(): train a toy SDF model with eikonal term. + 2. grad_check(): check backward_backward numerical correctness via torch.autograd.gradcheck. + 3. vis_graph(): visualize torch compute graph + """ + + def print_torch_SDF(model): + # for pytorch SDF + #print("encoder grad: ", model.decoder.params.grad) + print("========= ========= ========= ========= ========= ========") + print("decoder grad: ") + for i in range(0, len(model.decoder), 2): + if i % 2 == 0: + w = model.decoder[i].weight.grad + print("grad.shape: ", w.shape) + output, input = w.shape + if output == 1: + row = 1 + else: + row = 3 + for j in range(0, row): + print("grad[",j*input,", ",(j+1)*input,"]: ", w[j]) + print("========= ========= ========= ========= ========= ========") + + return + + def print_TCNN_layer_weight(prefix, weight, row, col): + # for printing weight of each linear layer in TCNN + print("decoder grad of layer - ", prefix, " - [",row,",", col,"]: ") + idx = 0 + for i in range(row): + print(prefix,"[", idx, ":", idx+col, "]: ", weight[idx:idx+col]) + idx = idx + col + + if idx > 128: + break + print("========= ========= ========= ========= ========= ========") + + return + + def print_TCNN_SDF(model): + # for SDF in TCNN + # print grad + print("encoder_decoder grad.shape: ", model.encoder_decoder.params.grad.shape) + enc_grad = model.encoder_decoder.params.grad[-512:] + print("encoder grad: ", enc_grad) + + dec_grad_0 = model.encoder_decoder.params.grad[0:2048] + dec_grad_1 = model.encoder_decoder.params.grad[2048:6144] # 1st hidden layer + dec_grad_2 = model.encoder_decoder.params.grad[6144:10240] # 2nd hidden layer + dec_grad_3 = model.encoder_decoder.params.grad[10240:11264] # 3rd hidden layer + + print_TCNN_layer_weight("dec_grad_0", dec_grad_0, 32, 64) + print_TCNN_layer_weight("dec_grad_1", dec_grad_1, 64, 64) + print_TCNN_layer_weight("dec_grad_2", dec_grad_2, 64, 64) + print_TCNN_layer_weight("dec_grad_3", dec_grad_3, 64, 1) + + return + + + def print_bwdbwd_time(scaler, loss, optimizer, total_time, num): + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + total_time += end - start + print("SDF in backward_backward耗时: ", " 调用次数: ", num, " 累计时间:", total_time) + + return total_time, num + + def compute_normal(x:torch.Tensor, y): #[N,3] + x.requires_grad_(True) + y.requires_grad_(True) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + return gradients + + def test_train(): + """ + train a toy SDF model with eikonal term. + """ + from tqdm import tqdm + device = torch.device("cuda") + + # leverage TCNN SDF + # NOTE: if leveraging TCNN, please leverage print_TCNN_SDF(model) to print temporary variables + model = SDF_TCNN(True, n_levels=1, log2_hashmap_size=15, base_resolution=4, smoothstep=False).to(device) + + # leverage Pytorch SDF + # NOTE: if leveraging Pytorch, please leverage print_torch_SDF(model) to print temporary variables + #model = SDF(True, n_levels=1, base_resolution=4).to(device) + + model.loadMesh("./data/Armadillo.ply", 1024, 3000) + + torch.cuda.nvtx.range_push('training_preparation') + torch.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + #optimizer = Adam(model.parameters(), 2.0e-4) + optimizer = apex.optimizers.FusedAdam(model.parameters(), lr = 2.0e-4) + + iter_i = 0 + fake_nablas = torch.ones([1], device='cuda') + torch.cuda.nvtx.range_pop() + + from torch.cuda.amp import GradScaler + from torch.cuda.amp import autocast + scaler = GradScaler() + + num = 0 + total_time = 0 + + with tqdm(range(10000)) as pbar: + for i in pbar: + sdf_list = [] + point_list = [] + for _ in range(1): + it = int(random.random() * 3000) + sdf_list.append(model.sdf[0]) + point_list.append(model.points[0]) + ref = torch.tensor(np.concatenate(sdf_list),device='cuda',dtype=torch.float32).squeeze() + x = torch.tensor(np.concatenate(point_list),device='cuda',dtype=torch.float32) + + torch.cuda.nvtx.range_push('python_forward_start') + sdf, nablas = model.forward_with_nablas(x) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push('python_nablas_norm') + nablas_norm: torch.Tensor = nablas.norm(dim=-1) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push('python_loss_compute') + optimizer.zero_grad() + + r = torch.rand([1024, 3], device='cuda', dtype=torch.float32) - 0.5 + sdf_, nablas = model.forward_with_nablas(r) + eikonal_loss = torch.sum(torch.abs(torch.norm(nablas, p=2, dim=1) - 1)) * 0.0001 + loss = F.mse_loss(sdf[..., 0], ref) + eikonal_loss + + loss.backward() + optimizer.step() + + # print time consumed + # NOTE: if use print_bwdbwd_time, one needs to comment loss.backward() and optimizer.step() + #total_time, num = print_bwdbwd_time(scaler, loss, optimizer, total_time, num) + + torch.cuda.nvtx.range_pop() + + pbar.set_postfix(loss=loss.item()) + + if (torch.isnan(loss)): + break + + if(iter_i%1000==0): + model.saveMesh(128, "result/test", iter_i) + iter_i = iter_i + 1 + + + return + + def save_model_param(model): + file_dir = os.getcwd() + params = [] + params_input = [] + flag_layer = 0 + for m in model.decoder.modules(): + if isinstance(m, nn.Linear): + if flag_layer == 0: + params_input.append(m.weight.clone().detach().cpu().numpy()) + print("params of input layer m.weight: ", m.weight.clone().detach().cpu().numpy()) + elif flag_layer > 0: + params.append(m.weight.clone().detach().cpu().numpy()) + print("params of m.weight: ", m.weight.clone().detach().cpu().numpy()) + # params.append(m.bias.detach().numpy()) + flag_layer = 1 + np.save("test_params_iter_1_input.npy", params_input) + np.save("test_params_iter_1.npy", params) + + return + +if __name__ == "__main__": + + # test tcnn + test_train() + + diff --git a/scripts/test_armadillo_numeric_align.py b/scripts/test_armadillo_numeric_align.py new file mode 100644 index 00000000..3e478c45 --- /dev/null +++ b/scripts/test_armadillo_numeric_align.py @@ -0,0 +1,616 @@ +#!/usr/bin/env python3 + +import torch +import trimesh +import torch.nn as nn +from torch import autograd +from torch.optim import Adam, SGD +import torch.nn.functional as F +import skimage +import random + +import sys +import os +import numpy as np +import time +NoLog = False + +try: + import tinycudann as tcnn +except ImportError: + print("This script requires the tiny-cuda-nn extension for PyTorch.") + print("You can install it by running:") + print("============================================================") + print("tiny-cuda-nn$ cd bindings/torch") + print("tiny-cuda-nn/bindings/torch$ python setup.py install") + print("============================================================") + sys.exit() + +def GenerateRasterPoints(midpoint, extents, resolution): + max_axis = max(extents) + voxel_size = max_axis / resolution +# 021 120 201 210 + points = np.meshgrid( + np.linspace(midpoint[0] - extents[0]*.501, midpoint[0] + extents[0]*.501, int(resolution * extents[0]/max_axis)), + np.linspace(midpoint[1] - extents[1]*.501, midpoint[1] + extents[1]*.501, int(resolution * extents[1]/max_axis)), + np.linspace(midpoint[2] - extents[2]*.501, midpoint[2] + extents[2]*.501, int(resolution * extents[2]/max_axis)) + ) + points = np.stack(points) + points = np.swapaxes(points, 1, 2) + points = points.reshape(3, -1).transpose().astype(np.float32) + res = np.array([int(resolution * extents[0]/max_axis),int(resolution * extents[1]/max_axis),int(resolution * extents[2]/max_axis)]) + return points, res#voxel_size, res + +class SDF(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }) + + b_flag = False + self.decoder = nn.Sequential( + nn.Linear(self.encoder.n_output_dims, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 1, bias=b_flag), + ) + + # write into numpy + params_enc = self.encoder.params.data.clone().cpu().numpy() + print("params_enc.length: ", len(params_enc), params_enc[0:64]) + np.save('numpy/params_sdf_enc.npy', params_enc) + + # load param from file + #params_enc = np.load('numpy/params_sdf_enc.npy') + #self.encoder.params.data = torch.from_numpy(params_enc) + + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + #params_input = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_input.npy", params_input) + + params_input = np.load("numpy/params_input.npy") + m.weight.data = torch.from_numpy(params_input) + elif idx == 1: + #params_hidden_1 = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_hidden_1.npy", params_hidden_1) + + params_hidden_1 = np.load("numpy/params_hidden_1.npy") + m.weight.data = torch.from_numpy(params_hidden_1) + elif idx == 2: + #params_hidden_2 = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_hidden_2.npy", params_hidden_2) + + params_hidden_2 = np.load("numpy/params_hidden_2.npy") + m.weight.data = torch.from_numpy(params_hidden_2) + else: + #params_output = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_output.npy", params_output) + + params_output = np.load("numpy/params_output.npy") + m.weight.data = torch.from_numpy(params_output) + idx += 1 + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in Pytorch 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + def loadMesh(self,mesh,batch,iteration): + sdf=[] + points=[] + self.mesh = trimesh.load_mesh(mesh) + self.name = mesh + scale_fac = np.max(self.mesh.extents) + self.mesh = trimesh.Trimesh.apply_translation(self.mesh,-1 * (self.mesh.bounds[0]+self.mesh.bounds[1])/2) + self.mesh = trimesh.Trimesh.apply_scale(self.mesh,1/scale_fac) + + + if not os.path.exists(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz"): + print("Sample SDF form Model") + for i in range(iteration): + points_ = np.random.rand(batch,3) - 0.5 + query = trimesh.proximity.ProximityQuery(self.mesh) + sdf_ = query.signed_distance(points_) + sdf.append(sdf_) + points.append(points_) + print("Batch Size:",i) + self.sdf = np.array(sdf).reshape(-1,1) * -1 + self.points = np.array(points).reshape(-1,3) + np.savetxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz",self.points) + np.savetxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz",self.sdf) + else: + self.sdf = np.loadtxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.points = np.loadtxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.sdf = self.sdf.reshape(iteration,-1,1) + self.points = self.points.reshape(iteration,-1,3) + + def saveMesh(self, gridres, file, it, level=0): + + grid,res = GenerateRasterPoints(np.array([0,0,0]), np.array([1,1,1]), gridres) + print("saveMesh - grid.shape:", grid.shape) + print("saveMesh - res:", res) + grid_size = 64 + sdf = [] + spacing = 64*64*64 + for i in range(0,grid.shape[0],spacing): + grid_portion = grid[i:i+spacing] + sdf_portion = self(torch.tensor(grid_portion, dtype=torch.float32).cuda()).detach().cpu().numpy()[:,0] + sdf.append(sdf_portion) + sdf = np.concatenate(sdf) + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=0) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it) +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-5) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-5' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-4) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-4' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + def forward(self, x): + encoded = self.encoder(x).to(dtype=torch.float) + sdf = self.decoder(encoded) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + print("within forward_with_nablas: x.shape: ", x.shape) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return sdf, nablas + +class SDF_TCNN(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder_decoder = tcnn.NetworkWithInputEncoding( + encoding_config = { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + }, + n_input_dims=3, #self.encoder.n_output_dims, #3 + n_output_dims=1, #64, + network_config={ + "otype": "CutlassMLP", #"FullyFusedMLP", + "activation": "Softplus", + "output_activation": "None", + "n_neurons": 64, + "n_hidden_layers": 3 #7 + }, + ) + + # init encoder params from file + file_dir = os.getcwd() + # load encoder params but initialization params are the same no need to assign + enc_params = np.load(os.path.join(file_dir, 'numpy/params_sdf_enc.npy')) + enc_p_torch = torch.from_numpy(enc_params) + + num_params = len(self.encoder_decoder.params) + num_enc_params = 12196240 # 12196240 + start_params = num_params - num_enc_params + #for i in range(num_enc_params): + # tmp = enc_p_torch[i].detach() + # self.encoder_decoder.params.data.index_fill_(0, torch.tensor(i+start_params, dtype=torch.int64).cuda(), torch.tensor(tmp).cuda()) + + # init decoder params from file + params_input = np.load(os.path.join(file_dir, 'numpy/params_input.npy')) + params_input_tensor = torch.from_numpy(params_input) + + ## init input layer, notice NOT to fill unneccessary blanks (0.0) + idx_tcnn = 0 + for i in range(params_input_tensor.shape[0]): + for j in range(params_input_tensor.shape[1]): + self.encoder_decoder.params[j+idx_tcnn].data.copy_(params_input_tensor[i, j]) + idx_tcnn += 32 # column num in TCNN + + ## init hidden layers + params = np.load(os.path.join(file_dir, 'numpy/params_hidden_1.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 2048 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, 'numpy/params_hidden_2.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 6144 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, 'numpy/params_output.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 10240 # skip the input layer params + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in TCNN 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + def loadMesh(self,mesh,batch,iteration): + sdf=[] + points=[] + self.mesh = trimesh.load_mesh(mesh) + self.name = mesh + scale_fac = np.max(self.mesh.extents) + self.mesh = trimesh.Trimesh.apply_translation(self.mesh,-1 * (self.mesh.bounds[0]+self.mesh.bounds[1])/2) + self.mesh = trimesh.Trimesh.apply_scale(self.mesh,1/scale_fac) + + if not os.path.exists(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz"): + print("Sample SDF form Model") + for i in range(iteration): + points_ = np.random.rand(batch,3) - 0.5 + query = trimesh.proximity.ProximityQuery(self.mesh) + sdf_ = query.signed_distance(points_) + sdf.append(sdf_) + points.append(points_) + print("Batch Size:",i) + self.sdf = np.array(sdf).reshape(-1,1) * -1 + self.points = np.array(points).reshape(-1,3) + np.savetxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz",self.points) + np.savetxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz",self.sdf) + else: + self.sdf = np.loadtxt(mesh+"_sdf"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.points = np.loadtxt(mesh+"_points"+"_"+str(batch)+"_"+str(iteration)+".gz") + self.sdf = self.sdf.reshape(iteration,-1,1) + self.points = self.points.reshape(iteration,-1,3) + + def saveMesh(self, gridres, file, it, level=0): + + grid, res = GenerateRasterPoints(np.array([0,0,0]), np.array([1,1,1]), gridres) + print("saveMesh - grid.shape:", grid.shape) + print("saveMesh - res:", res) + grid_size = 64 + sdf = [] + spacing = 64*64*64 + for i in range(0,grid.shape[0],spacing): + grid_portion = grid[i:i+spacing] + sdf_portion = self(torch.tensor(grid_portion, dtype=torch.float32).cuda()).detach().cpu().numpy()[:,0] + sdf.append(sdf_portion) + sdf = np.concatenate(sdf) + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=0) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it) +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-5) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-5' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + try: + vertices, faces, normals, _ = skimage.measure.marching_cubes_lewiner(sdf.reshape(res[0],res[1],res[2]), allow_degenerate=True, level=5e-4) + except Exception as e: + vertices = np.array(((0,0,0),(0.5,0.5,0.5),(0.2,0.3,0.8))) + faces = np.array([0,1,2]) + normals = np.array((0.6,0.5,0.6)) + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, vertex_normals=normals) + file_ = open(file+str(it)+'_+5e-4' +'.ply','wb') + file_.write(trimesh.exchange.ply.export_ply(mesh)) + + def forward(self, x): + sdf = self.encoder_decoder(x) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + return sdf, nablas + +if __name__ == '__main__': + """ + NOTE: Jianfei: I provide three testing tools for backward_backward functionality. + Play around as you want :) + 1. test_train(): train a toy SDF model with eikonal term. + 2. grad_check(): check backward_backward numerical correctness via torch.autograd.gradcheck. + 3. vis_graph(): visualize torch compute graph + """ + + def print_torch_SDF(model): + # for pytorch SDF + print("encoder grad: ", model.encoder.params.grad) + print("========= ========= ========= ========= ========= ========") + print("decoder grad: ") + for i in range(0, len(model.decoder), 2): + if i % 2 == 0: + w_grad = model.decoder[i].weight.grad + print("w_grad.shape: ", w_grad.shape) + r, c = w_grad.shape + if c >= 64 and r > 3: + r = 3 + #else: + # r = 8 + for i in range(0, r): + print("grad[",i*c,", ",(i+1)*c,"]: ", w_grad[i]) + print("========= ========= ========= ========= ========= ========") + + return + + def print_TCNN_layer_weight(prefix, weight, row, col): + # for printing weight of each linear layer in TCNN + print("decoder grad of layer - ", prefix, " - [",row,",", col,"]: ") + idx = 0 + for i in range(row): + print(prefix,"[", idx, ":", idx+col, "]: ", weight[idx:idx+col]) + idx = idx + col + + if idx > 128: + break + print("========= ========= ========= ========= ========= ========") + + return + + def print_TCNN_SDF(model): + # for SDF in TCNN + # print grad + print("encoder_decoder grad.shape: ", model.encoder_decoder.params.grad.shape) + enc_grad = model.encoder_decoder.params.grad[-512:] + print("encoder grad: ", enc_grad) + + dec_grad_0 = model.encoder_decoder.params.grad[0:2048] + dec_grad_1 = model.encoder_decoder.params.grad[2048:6144] # 1st hidden layer + dec_grad_2 = model.encoder_decoder.params.grad[6144:10240] # 2nd hidden layer + dec_grad_3 = model.encoder_decoder.params.grad[10240:10304] # 3rd hidden layer + + print_TCNN_layer_weight("dec_grad_0", dec_grad_0, 32, 64) + print_TCNN_layer_weight("dec_grad_1", dec_grad_1, 64, 64) + print_TCNN_layer_weight("dec_grad_2", dec_grad_2, 64, 64) + print_TCNN_layer_weight("dec_grad_3", dec_grad_3, 64, 1) + + return + + def print_TCNN_SDF_MLP(model): + # for SDF MLP in TCNN + print("decoder grad.shape: ", model.decoder.params.grad.shape) + + dec_grad_0 = model.decoder.params.grad[0:1024] + dec_grad_1 = model.decoder.params.grad[1024:5120] # 1st hidden layer + dec_grad_2 = model.decoder.params.grad[5120:9216] # 1st hidden layer + dec_grad_3 = model.decoder.params.grad[9216:13312] # 1st hidden layer + + print_TCNN_layer_weight("dec_grad_0", dec_grad_0, 16, 64) + print_TCNN_layer_weight("dec_grad_1", dec_grad_1, 64, 64) + print_TCNN_layer_weight("dec_grad_2", dec_grad_2, 64, 64) + print_TCNN_layer_weight("dec_grad_3", dec_grad_3, 64, 64) + + return + + def print_bwdbwd_time(scaler, loss, optimizer, total_time, num): + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + total_time += end - start + print("SDF in backward_backward耗时: ", " 调用次数: ", num, " 累计时间:", total_time) + + return total_time, num + + def compute_normal(x:torch.Tensor, y): #[N,3] + x.requires_grad_(True) + y.requires_grad_(True) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + return gradients + + def test_train(): + """ + train a toy SDF model with eikonal term. + """ + from tqdm import tqdm + device = torch.device("cuda") + # leverage TCNN SDF + # NOTE: if leveraging TCNN, please leverage print_TCNN_SDF(model) to print temporary variables + model = SDF_TCNN(True, n_levels=1, log2_hashmap_size=15, base_resolution=4, smoothstep=False).to(device) + + # leverage Pytorch SDF + # NOTE: if leveraging Pytorch, please leverage print_torch_SDF(model) to print temporary variables + #model = SDF(True, n_levels=1, base_resolution=4).to(device) + + model.loadMesh("data/Armadillo.ply", 1024, 3000) + + torch.cuda.nvtx.range_push('training_preparation') + torch.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + optimizer = Adam(model.parameters(), 2.0e-4) + + iter_i = 0 + fake_nablas = torch.ones([1], device='cuda') + torch.cuda.nvtx.range_pop() + + from torch.cuda.amp import GradScaler + from torch.cuda.amp import autocast + scaler = GradScaler() + + num = 0 + total_time = 0 + + with tqdm(range(10000)) as pbar: + for _ in pbar: + + # leverage 2 sets of data for numeric alignment + #r = torch.tensor([[-0.4805, 0.2734, 0.3652], [ 0.3096, 0.1665, -0.1348]], device='cuda', dtype=torch.float16) + r = torch.tensor([[-0.4805, 0.2734, 0.3652]], device='cuda', dtype=torch.float16) + + sdf_, nablas = model.forward_with_nablas(r) + eikonal_loss = torch.sum(torch.abs(torch.norm(nablas, p=2, dim=1) - 1)) + + loss = eikonal_loss + optimizer.zero_grad() + loss.backward() + optimizer.step() + + print("===== within the ", iter_i, "-th iteration: =====") + print("r: ", r) + print("sdf_: ", sdf_) + print("nablas: ", nablas) + print("loss: ", loss) + print("===== ========= ======== ========= ======== =====") + + torch.cuda.nvtx.range_pop() + + # print grad details when using pytorch SDF in line 528 + #print_torch_SDF(model) + + # print grad details when using TCNN SDF in line 524 + print_TCNN_SDF(model) + + pbar.set_postfix(loss=loss.item()) + + if (torch.isnan(loss)): + break + + if(iter_i % 1000 == 0): + model.saveMesh(128, "result/test", iter_i) + iter_i = iter_i + 1 + if iter_i > 0: + break + + return + + def save_model_param(model): + file_dir = os.getcwd() + params = [] + params_input = [] + flag_layer = 0 + for m in model.decoder.modules(): + if isinstance(m, nn.Linear): + if flag_layer == 0: + params_input.append(m.weight.clone().detach().cpu().numpy()) + print("params of input layer m.weight: ", m.weight.clone().detach().cpu().numpy()) + elif flag_layer > 0: + params.append(m.weight.clone().detach().cpu().numpy()) + print("params of m.weight: ", m.weight.clone().detach().cpu().numpy()) + flag_layer = 1 + np.save("test_params_iter_1_input.npy", params_input) + np.save("test_params_iter_1.npy", params) + + return + +if __name__ == "__main__": + + # test tcnn + test_train() + + diff --git a/scripts/test_sdf_derivation.py b/scripts/test_sdf_derivation.py new file mode 100755 index 00000000..e4d1f443 --- /dev/null +++ b/scripts/test_sdf_derivation.py @@ -0,0 +1,865 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +from torch import autograd +from torch.optim import Adam +import torch.nn.functional as F + +import sys +import os +import numpy as np + +try: + import tinycudann as tcnn +except ImportError: + print("This script requires the tiny-cuda-nn extension for PyTorch.") + print("You can install it by running:") + print("============================================================") + print("tiny-cuda-nn$ cd bindings/torch") + print("tiny-cuda-nn/bindings/torch$ python setup.py install") + print("============================================================") + sys.exit() + + +class SDF(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": n_levels, + "n_features_per_level": 8, + "log2_hashmap_size": log2_hashmap_size, + "base_resolution": base_resolution, + "per_level_scale": 1.5, + "interpolation": "Smoothstep" if smoothstep else "Linear" + }) + + # 8 layers in total, 7 hidden layers + self.decoder = nn.Sequential( + nn.Linear(self.encoder.n_output_dims, 64), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64), + nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 1), + #nn.Softplus(beta = 10.0) + ) + + for i in range(0, len(self.decoder), 2): + if i % 2 == 0: + self.decoder[i].weight.data.fill_(0.005) + #self.decoder[i].bias.data.fill_(0.0) + + def forward(self, x): + enc_output = self.encoder(x).float() + mlp_input = enc_output + + p = [] + output = [] + for i in range(0, len(self.decoder), 2): + tmp_p = self.decoder[i](mlp_input) + tmp_output = self.decoder[i+1](tmp_p) + + p.append(tmp_p) + output.append(tmp_output) + mlp_input = tmp_output + + return enc_output, p, output + + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + mlp_input, p, output = self.forward(x) + output_last = output[-1] + + device = torch.device('cuda:0') + L1 = 2.0 * torch.sum(output_last) #torch.sum(torch.pow(output - GT_64, 2.0)) + dL1doutput = (autograd.grad( + L1, + output_last, + retain_graph=True + )[0]).requires_grad_(True) + + nablas = autograd.grad( # dL1dinput aka dL1dmlp_input + output_last, + mlp_input, # x + dL1doutput, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0].requires_grad_(True) + + dL1_dinput = autograd.grad( # dL1dinput aka dL1dmlp_input + output_last, + mlp_input, # x + dL1doutput, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0].requires_grad_(True) + + dL1doutput_list = [] + for i in range(0, len(self.decoder), 2): + idx = int(i / 2) + if idx == int(len(self.decoder)/2 - 1): + input = mlp_input + else: + input = output[-(idx+2)] + output_current = output[-(idx+1)] + + dL1doutput_list.append(dL1doutput) + dL1dinput = (autograd.grad( + output_current, + input, + dL1doutput, + retain_graph=True, + create_graph=True, + )[0]).requires_grad_(True) + + dL1doutput = dL1dinput + + dL1_dinput = dL1doutput + + # second order + L2 = torch.sum(dL1_dinput) #torch.sum(dL1_dinput * dL1_dinput) + dL2_ddL1dinput = autograd.grad( + L2, + dL1_dinput, + create_graph=True, + retain_graph=True, + )[0].requires_grad_(True) + + dL2_ddL1doutput_list = [] + dL2_dinput_list = [] + for i in range(0, len(self.decoder), 2): + idx = int(i / 2) + dL1doutput = dL1doutput_list[-(idx+1)] + if idx == 0: + input = mlp_input + else: + input = output[idx-1] + + dL2_ddL1doutput = autograd.grad( + L2, + dL1doutput, + #create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + + dL2_dinput = autograd.grad( + L2, + input, + #create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + + dL2_ddL1doutput_list.append(dL2_ddL1doutput) + dL2_dinput_list.append(dL2_dinput) + + return mlp_input, p, output, nablas, dL1doutput_list, dL2_ddL1doutput_list, dL2_dinput_list + +class SDF_multi(nn.Module): # for derivation + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }) + + # 8 layers in total, 7 hidden layers + b_flag = False + self.decoder = nn.Sequential( + nn.Linear(self.encoder.n_output_dims, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 1, bias=b_flag), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 64), + #nn.Softplus(beta = 10.0), + #nn.Linear(64, 1), + #nn.Softplus(beta = 10.0) + ) + ''' + # read params from numpy files + params_input = np.load('params_input.npy', allow_pickle=True) # init for input layers + params = np.load('params.npy', allow_pickle=True) # init for hidden + output layers + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + m.weight.data = torch.from_numpy(params_input[idx]) + else: + m.weight.data = torch.from_numpy(params[idx-1]) + #m.bias.data = torch.from_numpy(params[idx + 1]) + idx += 1 + + #import numpy as np + file_dir = os.getcwd() + enc_params = np.load(os.path.join(file_dir, 'params_encoder.npy')) + self.encoder.params = torch.nn.Parameter(torch.from_numpy(enc_params)) + #print("encoder.params: ", self.encoder.params) + ''' + + params_enc = np.load('numpy/params_sdf_enc.npy') + self.encoder.params.data = torch.from_numpy(params_enc) + + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + params_input = np.load("numpy/params_input.npy") + m.weight.data = torch.from_numpy(params_input) + elif idx == 1: + params_hidden_1 = np.load("numpy/params_hidden_1.npy") + m.weight.data = torch.from_numpy(params_hidden_1) + elif idx == 2: + params_hidden_2 = np.load("numpy/params_hidden_2.npy") + m.weight.data = torch.from_numpy(params_hidden_2) + + else: + params_output = np.load("numpy/params_output.npy") + m.weight.data = torch.from_numpy(params_output) + idx += 1 + + + def forward(self, x): + enc_output = self.encoder(x).float() + mlp_input = enc_output + + p = [] + output = [] + for i in range(0, len(self.decoder), 2): + tmp_p = self.decoder[i](mlp_input) + if (len(self.decoder) % 2 and i+1 == len(self.decoder)): + tmp_output = tmp_p + else: + tmp_output = self.decoder[i+1](tmp_p) + + p.append(tmp_p) + output.append(tmp_output) + mlp_input = tmp_output + + return enc_output, p, output + + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + mlp_input, p, output = self.forward(x) + output_last = output[-1] + + device = torch.device('cuda:0') + #L1 = 2.0 * torch.sum(output_last) + L1 = torch.sum(output_last) + dL1doutput = (autograd.grad( + L1, + output_last, + retain_graph=True + )[0]).requires_grad_(True) + + nablas = autograd.grad( # dL1dinput aka dL1dmlp_input + output_last, + x, # mlp_input + dL1doutput, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0].requires_grad_(True) + + dL1dinput = autograd.grad( # dL1dinput aka dL1dmlp_input + output_last, + x, # mlp_input + dL1doutput, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0].requires_grad_(True) + # compute dL1denc_input + dL1denc_input = dL1dinput + + # compute dL1doutput in backward order + dL1doutput_list = [] + for i in range(0, len(self.decoder), 2): + idx = int(i / 2) + if idx == int((len(self.decoder)+1)/2 - 1): #if idx == int(len(self.decoder)/2 - 1): + input = mlp_input + else: + input = output[-(idx+2)] + output_current = output[-(idx+1)] + + dL1doutput_list.append(dL1doutput) + + dL1dinput = (autograd.grad( + output_current, + input, + dL1doutput, + retain_graph=True, + create_graph=True, + )[0]).requires_grad_(True) + + dL1doutput = dL1dinput + # append dL1doutput for encoder output + dL1denc_output = dL1doutput + + dL1denc_input = autograd.grad( # dL1dinput aka dL1dmlp_input + mlp_input, + x, #mlp_input, # x + dL1denc_output, + create_graph=True, + retain_graph=True, + only_inputs=True + )[0].requires_grad_(True) + + # second order + dL1_dinput = dL1denc_input + #L2 = torch.sum(nablas) + L2 = torch.sum(torch.abs(torch.norm(nablas, p=2, dim=1) - 1)) + + dL2_ddL1dinput = autograd.grad( + L2, + nablas, # x: the original input + create_graph=True, + retain_graph=True, + )[0].requires_grad_(True) + + # encoder second order + dL2_ddL1denc_input = dL2_ddL1dinput + + w = self.encoder.params + dL2_denc_w = autograd.grad( + L2, + w, + create_graph=True, # False + retain_graph=True + )[0] + + dL2_ddL1denc_output = autograd.grad( + dL1_dinput, + dL1denc_output, + dL2_ddL1dinput, + create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + + dL2_denc_input = autograd.grad( + L2, + x, + create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + + # mlp second order + dL1_dinput = dL1denc_output + dL2_ddL1dinput = dL2_ddL1denc_output + + dL2_ddL1doutput_list = [] + dL2_dinput_list = [] + dL2_dw_list = [] + for i in range(0, len(self.decoder), 2): + idx = int(i / 2) + if idx == 0: + input = mlp_input + dL1_dinput = dL1denc_output + else: + input = output[idx-1] + dL1dinput = dL1doutput + dL1doutput = dL1doutput_list[-(idx+1)] + + w = self.decoder[i].weight + + dL2_dw = autograd.grad( + L2, + w, + create_graph=False, + retain_graph=True + )[0] + + #print("dL1dinput.shape: ", dL1dinput.shape) + #print("dL1doutput.shape: ", dL1doutput.shape) + dL2_ddL1doutput = autograd.grad( + dL1dinput, + dL1doutput, + dL2_ddL1dinput, + create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + dL2_ddL1dinput = dL2_ddL1doutput + + if (len(self.decoder) % 2 and idx == int((len(self.decoder)+1)/2 - 1)): + dL2_dinput = torch.zeros_like(input).to(device) + else: + dL2_dinput = autograd.grad( + L2, + input, + create_graph=True, + retain_graph=True + )[0].requires_grad_(True) + + dL2_ddL1doutput_list.append(dL2_ddL1doutput) + dL2_dinput_list.append(dL2_dinput) + dL2_dw_list.append(dL2_dw) + + # writing outputs to file + file_dir = os.getcwd() + file_path = os.path.join(file_dir, "python_output.txt") + f = open(file_path, "w") + print("&&&&&&& python output &&&&&&&", file=f) + print("x = ", x, file=f) + print("mlp_input = ", mlp_input, file=f) + print("p = ", p, file=f) + print("output = ", output, file=f) + print("nablas = ", nablas, file=f) + print("dL2_ddL1denc_input = ", dL2_ddL1denc_input, file=f) + print("dL2_ddL1denc_output = ", dL2_ddL1denc_output, file=f) + #print("dL2_denc_input = ", dL2_denc_input, file=f) + #print("dL2_denc_w = ", dL2_denc_w, file=f) + #print("dL1denc_output = ", dL1denc_output, file=f) + #print("dL1doutput = ", dL1doutput_list, file=f) + #print("dL1denc_input = ", dL1denc_input, file=f) + print("dL2_ddL1doutput_list: ", dL2_ddL1doutput_list, file=f) + print("dL2dw = ", dL2_dw_list[-1], file=f) + for i in range(dL2_dw_list[-1].shape[0]): + print("dL2dw[",i,"]:", dL2_dw_list[-1][i], file=f) + print("end of dL2dw", file=f) + print("dL2_dinput_list = ", dL2_dinput_list, file=f) + print("dL2_ddL1doutput = ", dL2_ddL1doutput, file=f) + # mlp_input, p, output + print("&&&&&&& &&&&&& &&&&&& &&&&&&&", file=f) + + return mlp_input, p, output, nablas, dL1doutput_list, dL2_ddL1doutput_list, dL2_dinput_list, dL2_dw_list, dL1denc_input, dL2_ddL1denc_output, dL2_denc_w + +if __name__ == '__main__': + """ + NOTE: Jianfei: I provide three testing tools for backward_backward functionality. + Play around as you want :) + 1. test_train(): train a toy SDF model with eikonal term. + 2. grad_check(): check backward_backward numerical correctness via torch.autograd.gradcheck. + 3. vis_graph(): visualize torch compute graph + """ + + def test_grad_grad_mlp_(): + ## ================ params declaration ================ + device = torch.device("cuda") + model = SDF_multi(True, n_levels=1, log2_hashmap_size=15, base_resolution=4, smoothstep=False).to(device) + + print("======= within python verification ======= ") + print("model.encoder.params.shape(): ", len(model.encoder.params)) + print("model.encoder.params: ", model.encoder.params[0:8]) + + #x = (torch.tensor([[0.0679, 0.1012, 0.1586]], dtype=torch.float, device=device)).requires_grad_(True) + #x = torch.tensor([[-0.4805, 0.2734, 0.3652], [ 0.3096, 0.1665, -0.1348]], device='cuda', dtype=torch.float16).requires_grad_(True) + x = torch.tensor([[-0.4805, 0.2734, 0.3652]], device='cuda', dtype=torch.float16).requires_grad_(True) + #x = (torch.rand((1, 3), dtype=torch.float, device=device)).requires_grad_(True) + + mlp_input, p_list, output_list, nablas, dL1doutput_n_torch_list, \ + dL2_ddL1doutput_list, dL2_dinput_list, dL2_dw_list, \ + dL1denc_input, dL2_ddL1denc_output, dL2_denc_w = model.forward_with_nablas(x) + + ## ================== file output ===================== + file_dir = os.getcwd() + # save model params + import numpy as np + enc_params = model.encoder.params.clone().detach().cpu().numpy() + + np.save(os.path.join(file_dir, 'params_encoder.npy'), np.float32(enc_params)) + enc_params = np.load(os.path.join(file_dir, 'params_encoder.npy')) + print("enc_params: ", enc_params) + + file_path = os.path.join(file_dir, "output.txt") + f = open(file_path, "w") + + + ## ================ first order backward: dL1dinput comparison ================ + output = output_list[-1] + + # L1 loss + L1 = torch.sum(output) # torch.sum(torch.pow(output - GT_64, 2.0)) # 2.0 * torch.sum(output) + dL1doutput = autograd.grad(L1, output, create_graph=True, retain_graph=True)[0].requires_grad_(True) + K_ACT = 10.0 + + dL1doutput_list = [] + dL1dp_list = [] + + for i in range(0, len(model.decoder), 2): + idx = int(i / 2) + if idx == int((len(model.decoder)+1)/2 - 1): + input = mlp_input + else: + input = output_list[-(idx+2)] + p = p_list[-(idx+1)] + output = output_list[-(idx+1)] + #w = model.decoder[-(i+2)].weight # last layer with activation + w = model.decoder[-(i+1)].weight # last layer without activation + dL1doutput_list.append(dL1doutput) + + #print("==================== input ====================", file=f) + #print("idx: ", idx, "input.shape: ", input.shape, "p.shape: ", p.shape, "output.shape: ", output.shape, file=f) + + # if it's the last layer, p equals to output and there is no activation + # therefore we don't have to compute dL1dp + if idx == 0: + dL1dp = dL1doutput + dL1dp_torch = (autograd.grad(L1, p, retain_graph=True)[0]).requires_grad_(True) + dL1dp_list.append(dL1dp) + else: + # compute doutputdp + p_m = p.shape[1] + doutputdp = torch.zeros([p_m, p_m], dtype=torch.float32, device=device) + for j in range(p_m): + # 1 – 1/(e^(p * K_ACT) + 1.0) + doutputdp_j = 1.0 - 1.0/(torch.exp(p[0, j] * K_ACT) + 1.0) + doutputdp[j, j] = doutputdp_j + + dL1dp_torch = (autograd.grad(L1, p, retain_graph=True)[0]).requires_grad_(True) + dL1dp = torch.matmul(dL1doutput, doutputdp) + # save in reverse order + dL1dp_list.append(dL1dp) + + #print("==================== first order derivative ====================") + #print("==================== first order derivative ====================", file=f) + er = torch.allclose(dL1dp, dL1dp_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL1dp and dL1dp_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL1dp and dL1dp_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + # compare the difference between dL1doutput and dL1doutput_torch + er = torch.allclose(dL1doutput, dL1doutput_n_torch_list[idx], rtol=1e-07, atol=1e-07) + #print("allclose of dL1doutput and dL1doutput_n_torch_list in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL1doutput and dL1doutput_n_torch_list in atol=1e-07 rtol=1e-07: ", er, file=f) + + + # compute dpdinput = w + input_m = input.shape[1] + p_m = p.shape[1] + dpdinput = torch.ones([p_m, input_m], dtype=torch.float32, device=device) + for j in range(p_m): + dpdinput_j = autograd.grad(p[0, j], input, retain_graph=True)[0] + dpdinput[j] = dpdinput_j + + ## compute dL1dinput + dL1dinput = torch.matmul(dL1dp, dpdinput).requires_grad_(True) + # pytorch dL1dinput + dL1dinput_torch = autograd.grad(L1, input, retain_graph=True, create_graph=True)[0].requires_grad_(True) + + er = torch.allclose(dL1dinput, dL1dinput_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL1dinput and dL1dinput_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL1dinput and dL1dinput_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + ## compute dL1dw + dL1dw = torch.matmul(torch.transpose(dL1dp, 0, 1), input) + dL1dw_torch = autograd.grad(L1, w, retain_graph=True)[0] + er = torch.allclose(dL1dw, dL1dw_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL1dw and dL1dw_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL1dw and dL1dw_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + #print("==================== ====================== ====================") + #print("==================== ====================== ====================", file=f) + + + dL1doutput = dL1dinput + + ## ================ second order backward: dL2d_dL1dp comparison ================ + # L2 loss + L2 = torch.sum(nablas) #torch.sum(nablas * nablas) #torch.sum(nablas) + print("==================== Loss 2 ====================", file=f) + print("L2: ", L2, file=f) + print("==================== ====== ====================", file=f) + + # MLP 2nd order derivative + dL2d_dL1dinput = autograd.grad( + L2, + nablas, + retain_graph=True, + create_graph=False + )[0] + dL2d_dL1dinput = dL2_ddL1denc_output + + dL2dinput_self_list = [] + dL2dw_self_list = [] + dL2dw_torch_list = [] + ''' + print("================= dL1dp_list ================", file=f) + print("dL1dp_list: ", dL1dp_list, file=f) + print("+++++++++++++++++++++++++++++++++++++++++++++", file=f) + ''' + for i in range(0, len(model.decoder), 2): + idx = int(i / 2) + if i == 0: + input = mlp_input + else: + input = output_list[idx-1] # after activation + p = p_list[idx] + output = output_list[idx] + w = model.decoder[i].weight + + dL1dp = dL1dp_list[-(idx+1)] + dL1doutput = dL1doutput_n_torch_list[-(idx+1)] + dL1doutput_torch = dL1doutput_n_torch_list[-(idx+1)] + input_m = input.shape[1] + p_m = p.shape[1] + + ## compute dL2dw + # the 1st term: dL1doutput x diag[e^(p*K_ACT) / (e^p*K_ACT + 1.0)^2] * K_ACT * p + d_dL1dinput_dw = torch.transpose(dL1dp, 0, 1) + dL2dw_1 = torch.matmul(d_dL1dinput_dw, dL2d_dL1dinput) + + # the 2nd term: dL1doutput x d_doutputdp_dw x dpdinput + # compute d_doutputdp_dw + + # =========== new modification ============ # + # dL2_ddL1dinput_2 x w_2 shape: [1, 64] + wT = w.transpose(0, 1) + dL2_ddL1dinput_x_w = torch.matmul(dL2d_dL1dinput, wT) + dL1doutput_x_dL2_ddL1dinput_x_w = torch.mul(dL1doutput, dL2_ddL1dinput_x_w) + + doutputdp_2 = torch.zeros([p_m, p_m], dtype=torch.float32, device=device) + for i in range(doutputdp_2.shape[0]): + tmp = torch.exp(p[0, i] * K_ACT)/torch.pow(torch.exp(p[0, i] * K_ACT)+1.0, 2.0) * K_ACT + doutputdp_2[i, i] = tmp + + ddoutputdp_dp_2 = torch.matmul(dL1doutput_x_dL2_ddL1dinput_x_w, doutputdp_2) + dL2dw_2 = torch.matmul(torch.transpose(ddoutputdp_dp_2, 0, 1), input) + + if idx == int(len(model.decoder)/2): + dL2dw_2 = torch.zeros_like(dL2dw_2) + + # dL2dw add up + dL2dw = dL2dw_1 + dL2dw_2 + dL2dw_self_list.append(dL2dw) + + dL2dw_torch = autograd.grad( + L2, + w, + retain_graph=True, + create_graph=False)[0] + dL2dw_torch_list.append(dL2dw_torch) + + er = torch.allclose(dL2dw, dL2dw_torch, rtol=1e-07, atol=1e-07) + + ''' + print("++++++++++ validation of dL2dw ++++++++++", file=f) + #print("p: ", p, file=f) + #print("d_dL1dinput_dw: ", d_dL1dinput_dw, file=f) + #print("dL1doutput: ", dL1doutput, file=f) + print("dL2dw_1: ", dL2dw_1, file=f) + #print("ddoutputdp_dp_2: ", ddoutputdp_dp_2, file=f) + #print("dL2_ddL1dinput_x_w: ", dL2_ddL1dinput_x_w, file=f) + #print("dL1doutput_x_dL2_ddL1dinput_x_w: ", dL1doutput_x_dL2_ddL1dinput_x_w, file=f) + #print("doutputdp_2: ", doutputdp_2, file=f) + + #print("input: ", input, file=f) + #print("output: ", output, file=f) + #print("ddoutputdp_dp_2: ", ddoutputdp_dp_2, file=f) + print("dL2dw_2: ", dL2dw_2, file=f) + print("dL2dw: ", file=f) + for ii in range(dL2dw.shape[0]): + print(dL2dw[ii], file = f) + + + print("dL2dw_torch: ", dL2dw_torch, file=f) + print("++++++++++ =================== ++++++++++", file=f) + ''' + + ## compute dL2dinput + dL2dinput_torch = dL2_dinput_list[idx] #autograd.grad(L2, input, retain_graph=True, create_graph=False)[0] + + d_doutputdp_dinput = torch.zeros([p_m, input_m], dtype=torch.float32, device=device) + #print("================== validation of dL2dinput_1 doutputdp ===================", file=f) + #print("p: ", p, file=f) + for j in range(p_m): + #print("within compute_dL2dinput - doutputdp_2 = ", torch.exp(p[0, j] * K_ACT)/torch.pow(torch.exp(p[0, j] * K_ACT)+1.0, 2.0) * K_ACT) + tmp = torch.exp(p[0, j] * K_ACT)/torch.pow(torch.exp(p[0, j] * K_ACT)+1.0, 2.0) * K_ACT * dL1doutput[0, j] + #print("doutputdp_2[",j,"] = ", torch.exp(p[0, j] * K_ACT)/torch.pow(torch.exp(p[0, j] * K_ACT)+1.0, 2.0) * K_ACT * dL1doutput[0, j], file=f) + for k in range(input_m): + d_doutputdp_dinput[j, k] = tmp * w[j, k] #torch.sum(w[i, :]) #x[0, j] + #print("===========================================================================", file=f) + dL2dinput_1 = torch.matmul(dL2d_dL1dinput, torch.matmul(torch.transpose(d_doutputdp_dinput, 0, 1), w)) + + if idx == int(len(model.decoder)/2): + dL2dinput_1 = torch.zeros_like(dL2dinput_1) + + + print("======== validation of dL2dinput_1: 2nd order term ========", file=f) + print("p: ", p, file=f) + print("dL1doutput: ", dL1doutput, file=f) + print("d_doutputdp_dinput: ", d_doutputdp_dinput, file=f) + print("dL2dinput_1: ", dL2dinput_1, file=f) + print("dL2dinput_torch: ", dL2dinput_torch, file=f) + #print("dL2d_dL1dinput: ", dL2d_dL1dinput, file=f) + #print("torch.matmul(torch.transpose(d_doutputdp_dinput, 0, 1), w): ", torch.matmul(torch.transpose(d_doutputdp_dinput, 0, 1), w), file=f) + print("======== ================================ ========", file=f) + + + # einsum for dL2_dinput + ddoutputdp_dp = torch.zeros([p_m, p_m, p_m], dtype=torch.float32, device=device) + for j in range(p_m): + ddoutputdp_dp[j, j, j] = torch.exp(p[0, j] * K_ACT)/torch.pow(torch.exp(p[0, j] * K_ACT)+1.0, 2.0) * K_ACT + # ddoutputdp_dinput = ddoutputdp_dp x w + ddoutputdp_dinput = torch.zeros([p_m, p_m, input_m], dtype=torch.float32, device=device) + ddoutputdp_dinput = torch.einsum("ijk, kh->ijh", ddoutputdp_dp, w) + ddL1doutput_dinput = torch.einsum("bi, ijh->jh", dL1doutput, ddoutputdp_dinput) + wT = w.transpose(0, 1) + ddL1dinput_dinput = torch.matmul(wT, ddL1doutput_dinput) + dL2dinput = torch.matmul(dL2d_dL1dinput, ddL1dinput_dinput) + + if idx == int(len(model.decoder)/2): # last layer without activation + dL2dinput = torch.zeros_like(dL2dinput) + + dL2dinput_self_list.append(dL2dinput) + + er = torch.allclose(dL2dinput_1, dL2dinput_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL2dinput and dL2dinput_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL2dinput and dL2dinput_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + print("======== validation of dL2dinput 2nd order term ========", file=f) + print("dL2dinput_1: ", dL2dinput_1, file=f) + print("dL2dinput_torch: ", dL2dinput_torch, file=f) + print("======== dL2dinput 2nd order term ends ========", file=f) + + + # compute doutputdp + p_m = p.shape[1] + doutputdp = torch.zeros([p_m, p_m], dtype=torch.float32, device=device) + for j in range(p_m): + # 1 – 1/(e^(p * K_ACT) + 1.0) + doutputdp_j = 1.0 - 1.0/(torch.exp(p[0, j] * K_ACT) + 1.0) + doutputdp[j, j] = doutputdp_j + + ## compute dL2d_dL1doutput + dL2d_dL1dp = torch.matmul(dL2d_dL1dinput, torch.transpose(w, 0, 1)) + dL2d_dL1doutput = torch.matmul(dL2d_dL1dp, doutputdp) + + if idx == int(len(model.decoder)/2): # last layer without activation + dL2d_dL1doutput = dL2d_dL1dp + + + # pytorch dL2d_dL1doutput + dL2d_dL1doutput_torch = dL2_ddL1doutput_list[idx] + er = torch.allclose(dL2d_dL1doutput, dL2d_dL1doutput_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL2d_dL1doutput and dL2d_dL1doutput_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL2d_dL1doutput and dL2d_dL1doutput_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + # for iteration + dL2d_dL1dinput = dL2d_dL1doutput_torch # dL2d_dL1doutput + + + print("======== validation of dL2d_dL1doutput ========", file=f) + #print("dL2d_dL1dp i.e. 1st step: ", dL2d_dL1dp, file=f) + #print("doutputdp: ", doutputdp, file=f) + print("dL2d_dL1doutput: ", dL2d_dL1doutput, file=f) + print("dL2d_dL1doutput_torch: ", dL2d_dL1doutput_torch, file=f) + print("======== ================================ ========", file=f) + + + dL2dinput_2_list = [] + dL2dinput_2_list.append(dL2dinput_self_list[-1]) + print("dL2_dinput_list: ", dL2_dinput_list, file=f) + #print("=============== add up L1 backward within L2 loss ===============", file=f) + for i in range(0, len(model.decoder)-2, 2): + idx = int(i / 2) + print("within 1st order in 2nd order derivative:") + print("idx: ", idx) + print("len(output_list): ", len(output_list)) + print("-(idx+2+1): ", -(idx+2+1)) + print("=================================================================") + print("i: ", i, "len(model.decoder) - 2 - 2: ", len(model.decoder) - 2 - 2) + if i == len(model.decoder) - 1 - 2: + input = mlp_input + else: + input = output_list[-(idx+2+1)] + + dL2doutput = dL2dinput_self_list[-(idx+1)] # the last dL2dinput + + p = p_list[-(idx+1+1)] + w = model.decoder[-(i+1+2)].weight + p_m = p.shape[1] + + doutputdp = torch.zeros(p_m, p_m) + doutputdp = torch.zeros([p_m, p_m], dtype=torch.float32, device=device) + for j in range(p_m): + # 1 – 1/(e^(p * K_ACT) + 1.0) + doutputdp_j = 1.0 - 1.0/(torch.exp(p[0, j] * K_ACT) + 1.0) + doutputdp[j, j] = doutputdp_j + + dL2dp = torch.matmul(dL2doutput, doutputdp) + dL2dinput_2 = torch.matmul(dL2doutput, torch.matmul(doutputdp, w)) + dL2dinput_2_list.append(dL2dinput_2) + dL2dinput_1 = dL2dinput_self_list[-(idx+2)] + + dL2dinput = dL2dinput_1 + dL2dinput_2 + dL2dinput_self_list[-(idx+2)] = dL2dinput + + dL2dinput_torch = dL2_dinput_list[-(idx+2)] + + er = torch.allclose(dL2dinput, dL2dinput_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL2dinput and dL2dinput_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL2dinput and dL2dinput_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + dL2dp = torch.matmul(dL2doutput, doutputdp) + dL2dw_2 = torch.matmul(dL2dp.transpose(0, 1), input) + dL2dw_1 = dL2dw_self_list[-(idx+2)] + + #print("dL2dw_1: ", dL2dw_1, file=f) + #print("dL2dw_2: ", dL2dw_2, file=f) + dL2dw = dL2dw_1 + dL2dw_2 + + dL2dw_torch = dL2_dw_list[-(idx+2)] #dL2dw_torch_list[-(idx+2)] + #print("dL2dw: ", dL2dw) + #print("dL2dw_torch: ", dL2dw_torch) + #print("dL2dw: ", dL2dw, file=f) + #print("dL2dw_torch: ", dL2dw_torch, file=f) + er = torch.allclose(dL2dw, dL2dw_torch, rtol=1e-07, atol=1e-07) + #print("allclose of dL2dw and dL2dw_torch in atol=1e-07 rtol=1e-07: ", er) + #print("allclose of dL2dw and dL2dw_torch in atol=1e-07 rtol=1e-07: ", er, file=f) + + + print("======== validation of dL2dinput_2 ========", file=f) + print("dL2dinput_1: ", dL2dinput_1, file=f) + #print("doutputdp: ", doutputdp, file=f) + print("dL2doutput: ", dL2doutput, file=f) + #print("dL2dp: ", dL2dp, file=f) + print("dL2dinput_2: ", dL2dinput_2, file=f) + print("dL2dinput: ", dL2dinput, file=f) + print("dL2dinput_torch: ", dL2dinput_torch, file=f) + print("======== ================================ ========", file=f) + + print("======== validation of dL2dw ========", file=f) + print("dL2dp: ", dL2dp, file=f) + print("input: ", input, file=f) + print("dL2dw_2: ", dL2dw_2, file=f) + print("dL2dw: ", dL2dw, file=f) + print("dL2dw_torch: ", file=f) + for ii in range(dL2dw_torch.shape[0]): + print(dL2dw_torch[ii], file=f) + print("======== ================================ ========", file=f) + + + output = output_list[-1] + + +if __name__ == "__main__": + # test encoding + multi-layer perceptron + test_grad_grad_mlp_() + + diff --git a/scripts/test_sphere_bwdbwd.py b/scripts/test_sphere_bwdbwd.py new file mode 100755 index 00000000..86bafa25 --- /dev/null +++ b/scripts/test_sphere_bwdbwd.py @@ -0,0 +1,663 @@ +#!/usr/bin/env python3 + +import torch +import torch.nn as nn +from torch import autograd +from torch.optim import Adam, SGD +import torch.nn.functional as F + +import sys +import os +import numpy as np +import time +NoLog = False + +try: + import tinycudann as tcnn +except ImportError: + print("This script requires the tiny-cuda-nn extension for PyTorch.") + print("You can install it by running:") + print("============================================================") + print("tiny-cuda-nn$ cd bindings/torch") + print("tiny-cuda-nn/bindings/torch$ python setup.py install") + print("============================================================") + sys.exit() + +class SDF(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + ''' + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": n_levels, + "n_features_per_level": 8, + "log2_hashmap_size": log2_hashmap_size, + "base_resolution": base_resolution, + "per_level_scale": 1.5, + "interpolation": "Smoothstep" if smoothstep else "Linear" + }) + ''' + self.encoder = tcnn.Encoding(3, { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }) + b_flag = False + self.decoder = nn.Sequential( + nn.Linear(self.encoder.n_output_dims, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 1, bias=b_flag), + #nn.Softplus(beta = 10.0), + ) + + # write into numpy + #params_enc = self.encoder.params.data.clone().cpu().numpy() + #np.save('numpy/params_sdf_enc.npy', params_enc) + #print("params_enc[0:512]: ", params_enc[0:512]) + + params_enc = np.load('numpy/params_sdf_enc.npy') + #print("after loading from numpy: params[0:512]: ", params_enc[0:512]) + #print("params_enc.shape: ", params_enc.shape) + self.encoder.params.data = torch.from_numpy(params_enc) + + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + #params_input = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_input.npy", params_input) + #print("params_input.shape: ", params_input.shape) + + params_input = np.load("numpy/params_input.npy") + m.weight.data = torch.from_numpy(params_input) + elif idx == 1: + #params_hidden_1 = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_hidden_1.npy", params_hidden_1) + #print("params_hidden_1.shape: ", params_hidden_1.shape) + + params_hidden_1 = np.load("numpy/params_hidden_1.npy") + m.weight.data = torch.from_numpy(params_hidden_1) + elif idx == 2: + #params_hidden_2 = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_hidden_2.npy", params_hidden_2) + #print("params_hidden_2.shape: ", params_hidden_2.shape) + + params_hidden_2 = np.load("numpy/params_hidden_2.npy") + m.weight.data = torch.from_numpy(params_hidden_2) + + else: + #params_output = m.weight.data.clone().cpu().numpy() + #np.save("numpy/params_output.npy", params_output) + #print("params_output.shape: ", params_output.shape) + + params_output = np.load("numpy/params_output.npy") + m.weight.data = torch.from_numpy(params_output) + idx += 1 + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in Pytorch 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + def forward(self, x): + encoded = self.encoder(x).to(dtype=torch.float) + sdf = self.decoder(encoded) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return sdf, nablas + +class SDF_MLP(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + b_flag = False + self.decoder = nn.Sequential( + nn.Linear(8, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + nn.Linear(64, 64, bias=b_flag), + nn.Softplus(beta = 10.0), + ) + + # read params from numpy files + params_input = np.load('params_input.npy', allow_pickle=True) # init for input layers + params = np.load('params.npy', allow_pickle=True) # init for hidden + output layers + idx = 0 + for m in self.decoder.modules(): + if isinstance(m, nn.Linear): + if idx == 0: + m.weight.data = torch.from_numpy(params_input[idx]) + else: + m.weight.data = torch.from_numpy(params[idx-1]) + idx += 1 + #print("m.weight.data: ", m.weight.data) + + def forward(self, x): + #encoded = self.encoder(x).to(dtype=torch.float) + sdf = self.decoder(x) + return sdf + + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + return sdf, nablas + +class SDF_TCNN(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.encoder_decoder = tcnn.NetworkWithInputEncoding( + encoding_config = { + "otype": "HashGrid" if hash else "DenseGrid", + "n_levels": 16, #n_levels, + "n_features_per_level": 2, #8, + "log2_hashmap_size": 19, #log2_hashmap_size, + "base_resolution": 16, #base_resolution, + "per_level_scale": np.exp2(np.log2(2048 * 1 * 1 / 16) / (16 - 1)), #1.5, + #"interpolation": "Smoothstep" if smoothstep else "Linear" + }, + n_input_dims=3, #self.encoder.n_output_dims, #3 + n_output_dims=1, #64, + network_config={ + "otype": "CutlassMLP", #"FullyFusedMLP", + "activation": "Softplus", + "output_activation": "None", #"Softplus", + "n_neurons": 64, + "n_hidden_layers": 3 #7 + }, + #dtype=torch.float32, + ) + + + # init encoder params from file + file_dir = os.getcwd() + + enc_params = np.load(os.path.join(file_dir, 'numpy/params_sdf_enc.npy')) + #print("enc_params: ", enc_params) + enc_p_torch = torch.from_numpy(enc_params) + + num_params = len(self.encoder_decoder.params) + num_enc_params = 12196240 + start_params = num_params - num_enc_params + # for i in range(num_enc_params): + # tmp = enc_p_torch[i].detach() + # self.encoder_decoder.params.data.index_fill_(0, torch.tensor(i+start_params, dtype=torch.int64).cuda(), torch.tensor(tmp).cuda()) + #print("self.encoder_decoder.params.data of encoder: ", self.encoder_decoder.params[-512:]) + + + # init decoder params from file + params_input = np.load(os.path.join(file_dir, 'numpy/params_input.npy')) + params_input_tensor = torch.from_numpy(params_input) + print("params_input.shape: ", params_input.shape) + print("total params: ", self.encoder_decoder.params.shape) + + ## init input layer, notice NOT to fill unneccessary blanks (0.0) + idx_tcnn = 0 + for i in range(params_input_tensor.shape[0]): + for j in range(params_input_tensor.shape[1]): + self.encoder_decoder.params[j+idx_tcnn].data.copy_(params_input_tensor[i, j]) + #print("params_input[",j+idx_tcnn,"] = ", self.encoder_decoder.params[j+idx_tcnn]) + idx_tcnn += 32 # column num in TCNN + + ## init hidden layers + params = np.load(os.path.join(file_dir, 'numpy/params_hidden_1.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 2048 # skip the input layer params + print("params_hidden_1.shape: ", params_tensor.shape) + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, 'numpy/params_hidden_2.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 6144 # skip the input layer params + print("params_hidden_2.shape: ", params_tensor.shape) + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + + params = np.load(os.path.join(file_dir, 'numpy/params_output.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 10240 # skip the input layer params + print("params_output.shape: ", params_tensor.shape) + for j in range(params_tensor.shape[0]): + for k in range(params_tensor.shape[1]): + self.encoder_decoder.params[k+idx_tcnn].data.copy_(params_tensor[j, k]) + idx_tcnn += params_tensor.shape[1] + #print("params from numpy: ", params_tensor[-1, -1]) + #print("self.encoder_decoder.params.data of output layer: ", self.encoder_decoder.params[-576:-512]) + + def set_cuda_fun(func): + num = 0 # 初始化次数 + total_time = 0 + + def call_fun(*args, **kwargs): + nonlocal num + nonlocal total_time + if NoLog: + res = func(*args, **kwargs) + return res + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + res = func(*args, **kwargs) + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + # longtime = end - start + total_time += end - start + print("SDF in TCNN 前向耗时: ", func.__name__, " 调用次数: ", num, " 累计时间:", total_time) + return res + return call_fun + + + def forward(self, x): + sdf = self.encoder_decoder(x) + return sdf + + #@set_cuda_fun + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + #print("sdf.shape: ", sdf.shape, "x.shape: ", x.shape, "nablas.shape: ", nablas.shape) + + return sdf, nablas + +class SDF_MLP_TCNN(nn.Module): + def __init__(self, hash=True, n_levels=12, log2_hashmap_size=15, base_resolution=16, smoothstep=False) -> None: + super().__init__() + + self.decoder = tcnn.Network( + n_input_dims=8, + n_output_dims=64, + network_config={ + "otype": "CutlassMLP", + "activation": "Softplus", + "output_activation": "Softplus", + "n_neurons": 64, + "n_hidden_layers": 3 + } + ) + + # init decoder params from file + file_dir = os.getcwd() + params_input = np.load(os.path.join(file_dir, 'params_input.npy')) + params_input_tensor = torch.from_numpy(params_input) + + ## init input layer, notice NOT to fill unneccessary blanks (0.0) + idx_tcnn = 0 + for i in range(params_input_tensor.shape[1]): + for j in range(params_input_tensor.shape[2]): + self.decoder.params[j+idx_tcnn].data.copy_(params_input_tensor[0, i, j]) + #print("params_input[",j+idx_tcnn,"] = ", self.encoder_decoder.params[j+idx_tcnn]) + idx_tcnn += 16 # column num in TCNN + + ## init hidden + output layers + params = np.load(os.path.join(file_dir, 'params.npy')) + params_tensor = torch.from_numpy(params) + idx_tcnn = 1024 # skip the input layer params + for i in range(params_tensor.shape[0]): + for j in range(params_tensor.shape[1]): + for k in range(params_tensor.shape[2]): + self.decoder.params[k+idx_tcnn].data.copy_(params_tensor[i, j, k]) + idx_tcnn += params_tensor.shape[2] + + #for i in range(1024, 2048, 64): + # print("params_input[",i,",",i+63,"]: ", self.decoder.params[i:i+64]) + + def forward(self, x): + #sdf = self.encoder_decoder(x) + sdf = self.decoder(x) + + return sdf + + def forward_with_nablas(self, x): + with torch.enable_grad(): + x = x.requires_grad_(True) + sdf = self.forward(x) + nablas = autograd.grad( + sdf, + x, + torch.ones_like(sdf, device=x.device), + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + #print("sdf.shape: ", sdf.shape, "x.shape: ", x.shape, "nablas.shape: ", nablas.shape) + + return sdf, nablas + + + +if __name__ == '__main__': + """ + NOTE: Jianfei: I provide three testing tools for backward_backward functionality. + Play around as you want :) + 1. test_train(): train a toy SDF model with eikonal term. + 2. grad_check(): check backward_backward numerical correctness via torch.autograd.gradcheck. + 3. vis_graph(): visualize torch compute graph + """ + + def print_torch_SDF(model): + # for pytorch SDF + print("encoder grad: ", model.encoder.params.grad) + print("========= ========= ========= ========= ========= ========") + print("decoder grad: ") + for i in range(0, len(model.decoder), 2): + if i % 2 == 0: + #print("decoder w grad ", i, "-th layer:", model.decoder[i].weight.grad) + #print("decoder b grad ", i, "-th layer:", model.decoder[i].bias.grad) + w_grad = model.decoder[i].weight.grad + print("w_grad.shape: ", w_grad.shape) + r, c = w_grad.shape + if c >= 64: + r = 3 + else: + r = 8 + for i in range(0, r): + print("grad[",i*c,", ",(i+1)*c,"]: ", w_grad[i]) + print("========= ========= ========= ========= ========= ========") + + return + + def print_torch_SDF_MLP(model): + # for pytorch SDF_MLP + print("decoder grad: ") + for i in range(0, len(model.decoder), 2): + if i % 2 == 0: + print("decoder w ", i, "-th layer:") + w = model.decoder[i].weight.data + for d in range(2): + print("w[",d,"]:", w[d]) + print("========= ========= ========= ========= ========= ========") + + ''' + #print("decoder w grad ", i, "-th layer:", model.decoder[i].weight.grad) + print("decoder w grad ", i, "-th layer:") + w_grad = model.decoder[i].weight.grad + for d in range(2): #range(w_grad.shape[0]): + print("w_grad[",d,"]:", w_grad[d]) + print("========= ========= ========= ========= ========= ========") + ''' + + return + + def print_TCNN_layer_weight(prefix, weight, row, col): + # for printing weight of each linear layer in TCNN + print("decoder grad of layer - ", prefix, " - [",row,",", col,"]: ") + idx = 0 + for i in range(row): + print(prefix,"[", idx, ":", idx+col, "]: ", weight[idx:idx+col]) + idx = idx + col + + if idx > 128: + break + print("========= ========= ========= ========= ========= ========") + + return + + def print_TCNN_SDF(model): + # for SDF in TCNN + # print grad + print("encoder_decoder grad.shape: ", model.encoder_decoder.params.grad.shape) + enc_grad = model.encoder_decoder.params.grad[-512:] + print("encoder grad: ", enc_grad) + + dec_grad_0 = model.encoder_decoder.params.grad[0:1024] + dec_grad_1 = model.encoder_decoder.params.grad[1024:5120] # 1st hidden layer + dec_grad_2 = model.encoder_decoder.params.grad[5120:9216] # 2nd hidden layer + dec_grad_3 = model.encoder_decoder.params.grad[9216:13312] # 3rd hidden layer + + print_TCNN_layer_weight("dec_grad_0", dec_grad_0, 16, 64) + print_TCNN_layer_weight("dec_grad_1", dec_grad_1, 64, 64) + print_TCNN_layer_weight("dec_grad_2", dec_grad_2, 64, 64) + print_TCNN_layer_weight("dec_grad_3", dec_grad_3, 64, 64) + + # print weight + weight_0 = model.encoder_decoder.params.data[0:1024] + weight_1 = model.encoder_decoder.params.data[1024:5120] + weight_2 = model.encoder_decoder.params.data[5120:9216] + weight_3 = model.encoder_decoder.params.data[9216:13312] + + #print_TCNN_layer_weight("weight_0", weight_0, 16, 64) + #print_TCNN_layer_weight("weight_1", weight_1, 64, 64) + #print_TCNN_layer_weight("weight_2", weight_2, 64, 64) + #print_TCNN_layer_weight("weight_3", weight_3, 64, 64) + + return + + def print_TCNN_SDF_MLP(model): + # for SDF MLP in TCNN + print("decoder grad.shape: ", model.decoder.params.grad.shape) + + dec_grad_0 = model.decoder.params.grad[0:1024] + dec_grad_1 = model.decoder.params.grad[1024:5120] # 1st hidden layer + dec_grad_2 = model.decoder.params.grad[5120:9216] # 1st hidden layer + dec_grad_3 = model.decoder.params.grad[9216:13312] # 1st hidden layer + + print_TCNN_layer_weight("dec_grad_0", dec_grad_0, 16, 64) + print_TCNN_layer_weight("dec_grad_1", dec_grad_1, 64, 64) + print_TCNN_layer_weight("dec_grad_2", dec_grad_2, 64, 64) + print_TCNN_layer_weight("dec_grad_3", dec_grad_3, 64, 64) + + return + + def print_bwdbwd_time(scaler, loss, optimizer, total_time, num): + + torch.cuda.synchronize() + start = time.perf_counter() # 代码执行开始时间 + num += 1 # 每次调用次数加1 + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + + torch.cuda.synchronize() + end = time.perf_counter() # 代码执行结束时间 + total_time += end - start + print("SDF in backward_backward耗时: ", " 调用次数: ", num, " 累计时间:", total_time) + + return total_time, num + + def compute_normal(x:torch.Tensor, y): #[N,3] + x.requires_grad_(True) + y.requires_grad_(True) + d_output = torch.ones_like(y, requires_grad=False, device=y.device) + gradients = torch.autograd.grad( + outputs=y, + inputs=x, + grad_outputs=d_output, + create_graph=True, + retain_graph=True, + only_inputs=True)[0] + + return gradients + + def test_train(): + """ + train a toy SDF model with eikonal term. + """ + from tqdm import tqdm + device = torch.device("cuda") + #model = SDF_TCNN(True, n_levels=1, log2_hashmap_size=15, base_resolution=4, smoothstep=False).to(device) + model = SDF(True, n_levels=1, base_resolution=4).to(device) + + torch.cuda.nvtx.range_push('training_preparation') + torch.manual_seed(0) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + optimizer = Adam(model.parameters(), 2.0e-4) + + iter_i = 0 + fake_nablas = torch.ones([1], device='cuda') + torch.cuda.nvtx.range_pop() + + from torch.cuda.amp import GradScaler + from torch.cuda.amp import autocast + scaler = GradScaler() + + num = 0 + total_time = 0 + #with tqdm(range(10000)) as pbar: + with tqdm(range(10000)) as pbar: + for _ in pbar: + torch.cuda.nvtx.range_push('var_declaration') + # pytorch input + #x = torch.rand([102400, 3], dtype=torch.float32, device=device) - 0.5 + # TCNN input + x = torch.rand([102400, 3], dtype=torch.float16, device=device) - 0.5 + + # for detailed number comparison + #x = (torch.tensor([[0.3, 0.4, 0.5]], dtype=torch.float, device=device))#.requires_grad_(True) + #x = (torch.tensor([[0.3, 0.4, 0.5], [0.3, 0.4, 0.5], [0.3, 0.4, 0.5], [0.3, 0.4, 0.5]], dtype=torch.float16, device=device)).requires_grad_(True) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push('python_forward_start') + sdf, nablas = model.forward_with_nablas(x) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push('python_nablas_norm') + nablas_norm: torch.Tensor = nablas.norm(dim=-1) + torch.cuda.nvtx.range_pop() + + torch.cuda.nvtx.range_push('python_loss_compute') + optimizer.zero_grad() + ref_value = torch.sqrt((x**2).sum(-1)) - 0.5 + #normal_surface = compute_normal(x, sdf) + + eikonal_loss = torch.sum(torch.abs(torch.norm(nablas, p=2, dim=1) - 1)) * 0.0001 + #loss = eikonal_loss + loss = eikonal_loss #F.mse_loss(sdf[..., 0], ref_value) + eikonal_loss + + #loss.backward() + #optimizer.step() + + # print("===== within the ", iter_i, "-th iteration: =====") + # print("x: ", x) + # print("sdf: ", sdf) + # print("nablas: ", nablas) + # print("normal_surface: ", normal_surface) + # print("eikonal_loss: ", eikonal_loss) + #print("loss: ", loss) + # print("===== ========= ======== ========= ======== =====") + + ''' + # autocast training + optimizer.zero_grad() + print("sdf.shape: ", sdf.shape, "ref_value.shape: ", ref_value.shape) + with autocast(): + #loss = F.mse_loss(nablas_norm, fake_nablas, reduction='mean') + loss = F.mse_loss(sdf[..., 0], ref_value) * 100 + + scaler.scale(loss).backward() + scaler.step(optimizer) + scaler.update() + ''' + + # print time consumed + # NOTE: if use print_bwdbwd_time, one needs to comment loss.backward() and optimizer.step() + total_time, num = print_bwdbwd_time(scaler, loss, optimizer, total_time, num) + + torch.cuda.nvtx.range_pop() + + # print grad details for pytorch SDF + #print_torch_SDF(model) + + # print grad details for TCNN SDF + #print_TCNN_SDF(model) + + pbar.set_postfix(loss=loss.item()) + + if (torch.isnan(loss)): + break + + iter_i = iter_i + 1 + # if iter_i > 2: + # break + + return + + def save_model_param(model): + file_dir = os.getcwd() + params = [] + params_input = [] + flag_layer = 0 + for m in model.decoder.modules(): + if isinstance(m, nn.Linear): + if flag_layer == 0: + params_input.append(m.weight.clone().detach().cpu().numpy()) + print("params of input layer m.weight: ", m.weight.clone().detach().cpu().numpy()) + elif flag_layer > 0: + params.append(m.weight.clone().detach().cpu().numpy()) + print("params of m.weight: ", m.weight.clone().detach().cpu().numpy()) + # params.append(m.bias.detach().numpy()) + flag_layer = 1 + np.save("test_params_iter_1_input.npy", params_input) + np.save("test_params_iter_1.npy", params) + + return + +if __name__ == "__main__": + + # test tcnn + test_train() + + diff --git a/src/cutlass_mlp.cu b/src/cutlass_mlp.cu index e9aa81d4..bb7db427 100644 --- a/src/cutlass_mlp.cu +++ b/src/cutlass_mlp.cu @@ -115,6 +115,19 @@ bool compute_layer( return can_fuse_activation; } +template +bool compute_fc_layer( + cudaStream_t stream, + const GPUMatrix& weights, + const GPUMatrixDynamic& input, + GPUMatrixDynamic& p +) { + // compute for forward values before activation + fc_multiply(stream, weights, input, p); + + return true; +} + template bool compute_inference_layer( cudaStream_t stream, @@ -314,6 +327,606 @@ void CutlassMLP::backward_impl( } } +// ======================= backward_backward_input_impl ======================= +// compute 2nd order dact +template +__global__ void compute_activation_backward_backward(uint32_t n_elements, Activation activation, T* p, T* res) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + switch (activation) { + case Activation::Softplus: + float K_ACT = 10.0; + float tmp = (float)p[i] * K_ACT;// + if (tmp > 10.0) { + tmp = 10.0; + } else if (tmp < -15.0) { + tmp = -15.0; + } + + float exp_tmp = expf(tmp); + float pow_tmp = (exp_tmp + 1.0) * (exp_tmp + 1.0); + float ddoutputdp_dp = exp_tmp / pow_tmp * K_ACT; + res[i] = (T)ddoutputdp_dp; + return; + + case Activation::ReLU: + res[i] = 0.0; + return; + + default: + // ERROR: this activation currently is not supported + res[i] = 0.0; + return; + } + + return; +} + +// activation Softplus 2nd order derivative in 1D: doutputdp_2 +template +__global__ void compute_ddoutputdp_dp_dLdoutput(uint32_t n_elements, T* dL_doutput, T* ddoutputdp_dp, T* doutputdp_2) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + doutputdp_2[i] = ddoutputdp_dp[i] * dL_doutput[i]; + + return; +} + +// fuse the process of computing ddoutputdp_dp_2 +template +__global__ void fuse_ddoutputdp_dp(uint32_t n_elements, T* dL_doutput, T* w_x_dL2_ddL1dinput, T* ddoutputdp_dp, T* result) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + T dL1doutput_x_w_x_dL2_ddL1dinput = dL_doutput[i] * w_x_dL2_ddL1dinput[i]; + result[i] = ddoutputdp_dp[i] * dL1doutput_x_w_x_dL2_ddL1dinput; + + return; +} + +// element-wise add back to dL_dinput +template +__global__ void element_wise_add(uint32_t n_elements, T* tmp, T* res) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + res[i] += tmp[i]; +} + +// element-wise copy from an CM tmp to RM dL_dinput +// row: the tmp rows +// col: the tmp cols +template +__global__ void element_wise_copy_CM_RM(uint32_t n_elements, uint32_t row, uint32_t col, T* tmp, T* res) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + uint32_t idx_i = i % row; + uint32_t idx_j = i / row; + uint32_t i_RM = idx_i * col + idx_j; + res[i_RM] = tmp[i]; +} + +// compute d_doutputdp2_dinput +template +__global__ void multiply_w_RM(uint32_t n_elements, uint32_t w_row, uint32_t w_col, T* weights, T* doutputdp_2, T* result) { + uint32_t i = blockIdx.x * blockDim.x + threadIdx.x; + if (i >= n_elements) return; + + uint32_t idx_doutputdp = i % w_row; + T tmp = doutputdp_2[idx_doutputdp]; + uint32_t i_RM = (i % w_row) * w_col + i / w_row; + result[i] = weights[i_RM] * tmp; +} + +// compute dL2_ddL1doutput: activation Softplus/ReLU derivative, doutputdp and then multiply dL2_ddL1doutput_tmp +template +bool compute_dL2_ddL1doutput( + cudaStream_t stream, + bool is_inference, + Activation activation, + GPUMatrix& weights, + const GPUMatrix& p, + const GPUMatrixDynamic& dL_ddLdinput, + GPUMatrixDynamic& dL2_ddL1doutput, + GradientMode param_gradients_mode +) { + // dL2_ddL1doutput = weight x dL_ddLdinput where x is matrix multiply + if (weights.layout() == CM) { + fc_multiply(stream, weights.cm(), dL_ddLdinput, dL2_ddL1doutput); + } else { + fc_multiply(stream, weights.rm(), dL_ddLdinput, dL2_ddL1doutput); + } + + // if activation is None, dL2_ddL1doutput = weight x dL2_ddL1dinput + if (activation == Activation::None) { + return true; + } + + // dL2_ddL1doutput = dL2_ddL1doutput · doutputdp, where p is linear output, · is dot product + activation_backward_gpu(stream, activation, p, dL2_ddL1doutput); + + return true; +} + +template +bool compute_dL2dw( + cudaStream_t stream, + bool is_inference, + Activation activation, + GPUMatrix& weights, + const GPUMatrixDynamic& input, + const GPUMatrixDynamic& p, // const GPUMatrix& p + const GPUMatrixDynamic& dL_dp, + const GPUMatrixDynamic& dL_doutput, + const GPUMatrixDynamic& dL_ddLdinput, // dL2_ddL1dinput + const GPUMatrixDynamic& ddoutputdp_dp, // ddoutputdp_dp + GPUMatrixDynamic& dL2_ddL1doutput, // pointer better + GPUMatrix& weight_gradient, // pointer better + GradientMode param_gradients_mode +) { + // dL2dw_1 = ddL1dinput_dw x dL2_ddL1dinput + uint32_t batch_size = dL_ddLdinput.n(); + int split_k_factor = batch_size / std::min((uint32_t)(1 << 12), batch_size); + const float param_gradient_beta = param_gradients_mode == GradientMode::Accumulate ? 1.0f : 0.0f; + + //cudaStream_t stream_dL2dw_1; + if (param_gradients_mode != GradientMode::Ignore) { + fc_multiply_split_k(stream, dL_dp, dL_ddLdinput.transposed(), weight_gradient, split_k_factor, param_gradient_beta); + } + + // when activation is None, don't have to compute dL2dw_2 + if (activation == Activation::None) { + return true; + } + + //dL2dw_2 = torch.matmul(torch.transpose(ddoutputdp_dp_2, 0, 1), input) + GPUMatrixDynamic w_x_dL2_ddL1dinput = {weights.m(), dL_ddLdinput.n(), stream}; + + if (weights.layout() == CM) { + fc_multiply(stream, weights.cm(), dL_ddLdinput, w_x_dL2_ddL1dinput); + } else { + fc_multiply(stream, weights.rm(), dL_ddLdinput, w_x_dL2_ddL1dinput); + } + + // fuse kernels to compute ddoutputdp_dp_2 + GPUMatrixDynamic ddoutputdp_dp_2_fuse = {p.rows(), p.cols(), stream}; + linear_kernel(fuse_ddoutputdp_dp, 0, stream, ddoutputdp_dp_2_fuse.n_elements(), dL_doutput.data(), w_x_dL2_ddL1dinput.data(), ddoutputdp_dp.data(), ddoutputdp_dp_2_fuse.data()); + + if (param_gradients_mode != GradientMode::Ignore) { + fc_multiply_split_k(stream, ddoutputdp_dp_2_fuse, input.transposed(), weight_gradient, split_k_factor, 1.0); + } + + return true; +} + +template +bool compute_dL2dinput( + cudaStream_t stream, + bool is_inference, + Activation activation, + GPUMatrix& weights, + const GPUMatrixDynamic& p, + const GPUMatrixDynamic& dL_doutput, + const GPUMatrixDynamic& dL_ddLdinput, // dL2_ddL1dinput + const GPUMatrixDynamic& ddoutputdp_dp, // ddoutputdp_dp + GPUMatrixDynamic& dL_dinput, + GradientMode param_gradients_mode +) { + // no dL2dinput when activation is None + if (activation == Activation::None) { + return true; + } + + // compute weights x dL2_ddL1dinput in advance + GPUMatrixDynamic w_x_dL2ddL1dinput = {weights.rows(), dL_ddLdinput.cols(), stream}; + + if (weights.layout() == CM) { + fc_multiply(stream, weights.cm(), dL_ddLdinput, w_x_dL2ddL1dinput); + } else { + fc_multiply(stream, weights.rm(), dL_ddLdinput, w_x_dL2ddL1dinput); + } + + // doutputdp_2 = (p.rows(), batch_size) + GPUMatrixDynamic doutputdp_2 = {p.rows(), p.cols(), stream}; + + // compute doutputdp_2 in 1 dimension (1, 64) x batch_size, and multiply dL_doutput + linear_kernel(compute_ddoutputdp_dp_dLdoutput, 0, stream, doutputdp_2.n_elements(), dL_doutput.data(), ddoutputdp_dp.data(), doutputdp_2.data()); + + GPUMatrixDynamic ddoutputdp_dinput = {weights.rows(), weights.cols(), stream}; + // ddoutputdp_dinput[i, j] = tmp * w[i, j] where tmp = doutputdp_2_sum[i, 0] + linear_kernel(multiply_w_RM, 0, stream, weights.n_elements(), weights.rows(), weights.cols(), weights.data(), doutputdp_2.data(), ddoutputdp_dinput.data()); + + // dL_dinput = ddoutputdp_dinput_xw x dL2d_dL1dinput + if (ddoutputdp_dinput.transposed().layout() == CM) { + fc_multiply(stream, ddoutputdp_dinput.transposed().cm(), w_x_dL2ddL1dinput, dL_dinput); + } else { + fc_multiply(stream, ddoutputdp_dinput.transposed().rm(), w_x_dL2ddL1dinput, dL_dinput); + } + + return true; +} + +// prepare variables needed for backward temporary +template +bool CutlassMLP::prepare_backward_variables( + cudaStream_t stream, + const std::vector>& output, // const GPUMatrix& p + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic& backward_output_tmp, + std::vector>& dL1dp, + std::vector>& dL1doutput, + bool use_inference_params +) { + // compute dL1dp and dL1doutput + uint32_t batch_size = dL_doutput.n(); + uint32_t bwd_tmp_idx = m_n_hidden_matmuls + 1 - 1; + uint32_t bwd_dL1dp_idx = 0; + + // compute dL1dp and dL1dinput of output layer + const GPUMatrixDynamic& tmp_dL_doutput = m_output_activation == Activation::None ? dL_doutput : backward_output_tmp; + + // directly compute dL1dp_i-1 + fc_multiply(stream, output_weight_matrix(use_inference_params).transposed(), tmp_dL_doutput, output.at(bwd_tmp_idx), dL1dp.at(bwd_dL1dp_idx), m_activation, true); + // extra computing once to save dL1doutput of each layer + fc_multiply(stream, output_weight_matrix(use_inference_params).transposed(), tmp_dL_doutput, dL1doutput.at(bwd_dL1dp_idx)); + + bwd_tmp_idx -= m_can_fuse_activation ? 1 : 2; + ++bwd_dL1dp_idx; + + // layers + for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) { + uint32_t matrix_idx = m_n_hidden_matmuls - i - 1; + + fc_multiply(stream, weight_matrix_at(use_inference_params, matrix_idx).transposed(), dL1dp.at(bwd_dL1dp_idx-1), output.at(bwd_tmp_idx), dL1dp.at(bwd_dL1dp_idx), m_activation, true); + // extra computing once to save dL1doutput of each layer + fc_multiply(stream, weight_matrix_at(use_inference_params, matrix_idx).transposed(), dL1dp.at(bwd_dL1dp_idx-1), dL1doutput.at(bwd_dL1dp_idx)); + + bwd_tmp_idx -= m_can_fuse_activation ? 1 : 2; + ++bwd_dL1dp_idx; + } + + return true; +} + +template +void CutlassMLP::backward_backward_input_impl( + cudaStream_t stream, + const Context& ctx, + const GPUMatrixDynamic& input, + const GPUMatrixDynamic& dL_ddLdinput, + const GPUMatrixDynamic& dL_doutput, + GPUMatrixDynamic* dL_ddLdoutput, + GPUMatrixDynamic* dL_dinput, + bool use_inference_params, + GradientMode param_gradients_mode +) { + uint32_t batch_size = dL_doutput.n(); + + // dL2_ddL1dinput vector of each layer + std::vector> dL2_ddL1doutput(num_forward_activations()); + for (uint32_t i = 0; i < num_forward_activations(); ++i) { + dL2_ddL1doutput[i] = GPUMatrix{m_network_width, batch_size, stream}; + } + + // dL2dinput vector of each layer + std::vector> dL2dinput; + if (m_n_hidden_layers == 0) { + dL2dinput.resize(1); + dL2dinput[0] = GPUMatrixDynamic{m_input_width, batch_size, stream}; + } else { + dL2dinput.resize(m_n_hidden_matmuls + 2); + dL2dinput[0] = GPUMatrixDynamic{m_input_width, batch_size, stream}; + for (uint32_t i = 0; i < m_n_hidden_matmuls+1; ++i) { + dL2dinput[i+1] = GPUMatrixDynamic{m_network_width, batch_size, stream}; + } + } + // NOTE: this code is for removing NaN initialization of dL2dinput[last_layer], must keep this code + CUDA_CHECK_THROW(cudaMemsetAsync(dL2dinput[m_n_hidden_matmuls+1].data(), 0, dL2dinput[m_n_hidden_matmuls+1].n_elements() * sizeof(T), stream)); + + // 2nd order derivative of activation for each layer + std::vector> ddoutputdp_dp; + if (m_n_hidden_layers == 0) { + ddoutputdp_dp.resize(1); + ddoutputdp_dp[0] = GPUMatrixDynamic{m_padded_output_width, batch_size, stream}; + } else { + ddoutputdp_dp.resize(m_n_hidden_matmuls + 2); + for (uint32_t i = 0; i < m_n_hidden_matmuls+1; ++i) { + ddoutputdp_dp[i] = GPUMatrixDynamic{m_network_width, batch_size, stream}; + } + ddoutputdp_dp[m_n_hidden_matmuls+1] = GPUMatrixDynamic{m_padded_output_width, batch_size, stream}; + } + CUDA_CHECK_THROW(cudaMemsetAsync(ddoutputdp_dp[m_n_hidden_matmuls+1].data(), 0, ddoutputdp_dp[m_n_hidden_matmuls+1].n_elements() * sizeof(T), stream)); + + // declare variables for fc_output, aka p, the result right after linear layer + std::vector> fc_output(num_forward_activations()); + for (uint32_t i = 0; i < num_forward_activations(); ++i) { + fc_output[i] = GPUMatrix{m_network_width, batch_size, stream}; // GPUMatrix{m_network_width, batch_size, stream}; + } + GPUMatrix fc_last_output(m_padded_output_width, batch_size, stream); // p of output layer + + // declare variables for dL1dp and dL1doutput + GPUMatrixDynamic backward_output_tmp; // dL1dp of input layer + std::vector> dL1dp; // dL1dp: reverse order of all layers except for input layer + std::vector> dL1doutput; // dL1doutput of each layer: reverse order + + // initialization for dL1dp, dL1doutput + dL1dp.resize(num_forward_activations()); + dL1doutput.resize(num_forward_activations()); + for (uint32_t i = 0; i < num_forward_activations(); ++i) { + dL1dp[i] = GPUMatrix{m_network_width, batch_size, stream}; + dL1doutput[i] = GPUMatrix{m_network_width, batch_size, stream}; + } + + // prepare temporary variables for 2nd order derivative + const auto& forward = dynamic_cast(ctx); + + // multi-stream to compute + { + std::vector multi_streams_pre; + multi_streams_pre.emplace_back(stream, 2); + + // compute dL1dp and dL1doutput + bool bwd_prep = prepare_backward_variables ( + multi_streams_pre.back().get(1), + forward.hidden, + dL_doutput, + backward_output_tmp, + dL1dp, + dL1doutput, + use_inference_params + ); + + // compute fc_output(p) values + uint32_t forward_idx = 0; + multi_streams_pre.emplace_back(stream, 2); + bool is_returned = compute_fc_layer( + multi_streams_pre.back().get(1), //stream, + input_weight_matrix(use_inference_params), + input, + fc_output.at(forward_idx) // p: linear output + ); + + // compute ddoutputdp_dp of input layer + linear_kernel(compute_activation_backward_backward, 0, + multi_streams_pre.back().get(1), + fc_output.at(forward_idx).n_elements(), + m_activation, + fc_output.at(forward_idx).data(), + ddoutputdp_dp[forward_idx].data() + ); + forward_idx++; + + for (uint32_t i = 0; i < m_n_hidden_matmuls; i++) { + multi_streams_pre.emplace_back(stream, 2); + is_returned = compute_fc_layer( + multi_streams_pre.back().get(1), //stream, + weight_matrix_at(use_inference_params, i), + forward.hidden.at(i), // input + fc_output.at(forward_idx) + ); + linear_kernel(compute_activation_backward_backward, 0, + multi_streams_pre.back().get(1), + fc_output.at(forward_idx).n_elements(), + m_activation, + fc_output.at(forward_idx).data(), + ddoutputdp_dp[forward_idx].data() + ); + forward_idx++; + } + // output layer + multi_streams_pre.emplace_back(stream, 2); + is_returned = compute_fc_layer( + multi_streams_pre.back().get(1), //stream, + output_weight_matrix(use_inference_params), + forward.hidden.at(m_n_hidden_matmuls), + fc_last_output // p value + ); + if (m_output_activation != Activation::None) { + linear_kernel(compute_activation_backward_backward, 0, + multi_streams_pre.back().get(1), + fc_last_output.n_elements(), + m_output_activation, + fc_last_output.data(), + ddoutputdp_dp[forward_idx].data() + ); + } + forward_idx++; + } + + { // 2nd order derivative computation: local definition for multi-stream + // init for backward_backward computing + std::vector multi_streams; + uint32_t tmp_idx = 0, bwd_idx = 0, bwd_bwd_idx = 0; + + // input layer + // dL2dw for input layer + if (param_gradients_mode != GradientMode::Ignore) { + multi_streams.emplace_back(stream, 2); + bool return_tmp_dL2dw = compute_dL2dw( + multi_streams.back().get(1), + false, + m_activation, + input_weight_matrix(use_inference_params), + input, // input + fc_output.at(tmp_idx), // p + dL1dp.at(m_n_hidden_matmuls), // dL1dp + dL1doutput.at(m_n_hidden_matmuls), // dL1doutput + dL_ddLdinput, // dL2_ddL1dinput + ddoutputdp_dp.at(0), // ddoutputdp_dp + dL2_ddL1doutput.at(bwd_bwd_idx), // dL2_ddL1doutput + input_gradient_matrix(), // gradient matrix + param_gradients_mode + ); + } + + // 2nd order to dL2dinput of the 1st layer + if (dL_dinput) { + bool return_tmp = compute_dL2dinput( + stream, + false, + m_activation, + input_weight_matrix(use_inference_params), // weights + fc_output.at(0), // p + dL1doutput.at(m_n_hidden_matmuls), // dL_doutput + dL_ddLdinput, + ddoutputdp_dp.at(0), // ddoutputdp_dp + dL2dinput.at(0), // *dL_dinput, + param_gradients_mode + ); + } + + bool return_tmp_dL2_ddL1doutput = compute_dL2_ddL1doutput( + stream, + false, + m_activation, + input_weight_matrix(use_inference_params), + fc_output.at(tmp_idx), + dL_ddLdinput, + dL2_ddL1doutput.at(bwd_bwd_idx), + param_gradients_mode + ); + + // TODO: hidden_layer == 0 + tmp_idx ++; + bwd_bwd_idx++; + // 2nd order derivative to dL2dinput and dL2dw + for (uint32_t i = 0; i < m_n_hidden_matmuls; ++i) { + // 2nd order impact to dL2dw + multi_streams.emplace_back(stream, 2); + bool return_hidden_dL2dw = compute_dL2dw( + multi_streams.back().get(1), // stream + false, + m_activation, + weight_matrix_at(use_inference_params, i), + forward.hidden.at(tmp_idx-1), // input + fc_output.at(tmp_idx), // p + dL1dp.at(m_n_hidden_matmuls - bwd_idx - 1), // dL1dp + dL1doutput.at(m_n_hidden_matmuls - bwd_idx - 1), // dL1doutput + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2d_dL1dinput + ddoutputdp_dp.at(i+1), // ddoutputdp_dp + dL2_ddL1doutput.at(bwd_bwd_idx), // dL2_ddL1doutput + gradient_matrix_at(i), // gradient matrix + param_gradients_mode + ); + + //multi_streams.emplace_back(stream, 2); + bool return_hidden_dL2dinput = compute_dL2dinput( + stream, //multi_streams.back().get(3), // stream, + false, + m_activation, + weight_matrix_at(use_inference_params, i), // weights + fc_output.at(i+1), // p + dL1doutput.at(m_n_hidden_matmuls - i - 1), // dL_doutput + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2_ddL1dinput of current layer + ddoutputdp_dp.at(i+1), // ddoutputdp_dp + dL2dinput.at(i+1), // dL2dinput of current layer + param_gradients_mode + ); + + bool return_hidden_dL2_ddL1doutput = compute_dL2_ddL1doutput( + stream, //multi_streams.back().get(2), //stream, + false, // is_inference + m_activation, + weight_matrix_at(use_inference_params, i), + fc_output.at(tmp_idx), // p, + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2_ddL1dinput + dL2_ddL1doutput.at(bwd_bwd_idx), // dL2_ddL1doutput + param_gradients_mode + ); + + tmp_idx++; // tmp_idx += can_fused ? 1 : 2; + bwd_bwd_idx++; + bwd_idx++; + } + + // Output layer weight + multi_streams.emplace_back(stream, 2); + bool return_output_dL2dw = compute_dL2dw( + multi_streams.back().get(1), //stream, + false, + m_output_activation, + output_weight_matrix(use_inference_params), + forward.hidden.at(tmp_idx-1), // input + fc_last_output, // p + m_output_activation == Activation::None ? dL_doutput : backward_output_tmp, // dL1dp + dL_doutput, // dL1doutput + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2d_dL1dinput + ddoutputdp_dp.at(m_n_hidden_matmuls+1), // ddoutputdp_dp + *dL_ddLdoutput, // dL2_ddL1doutput + output_gradient_matrix(), + param_gradients_mode + ); + + // Output layer dL2dinput + bool return_output_dL2dinput = compute_dL2dinput( + stream, + false, // is_inference + m_output_activation, + output_weight_matrix(use_inference_params), // weights + fc_last_output, // p + m_output_activation == Activation::None ? dL_doutput : backward_output_tmp, // dL_doutput + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2_ddL1dinput of current layer + ddoutputdp_dp.at(m_n_hidden_matmuls+1), // ddoutputdp_dp + dL2dinput.at(m_n_hidden_matmuls+1), // dL2dinput of current layer + param_gradients_mode + ); + + if (dL_ddLdoutput) { // if dL_ddLdoutput is not nullptr + bool return_output_dL2_ddL1doutput = compute_dL2_ddL1doutput( + stream, + false, // is_inference + m_output_activation, + output_weight_matrix(use_inference_params), + fc_last_output, // p + dL2_ddL1doutput.at(bwd_bwd_idx-1), // dL2d_dL1dinput + *dL_ddLdoutput, // dL2_ddL1doutput + param_gradients_mode + ); + } + } + + // 1st order backward of dL2dinput and dL2dw in reverse order + std::vector multi_streams_1st; + multi_streams_1st.emplace_back(stream, 2); + int split_k_factor = batch_size / std::min((uint32_t)(1 << 12), batch_size); + const float param_gradient_beta = 1.0f; // param_gradients_mode == GradientMode::Accumulate + + for (uint32_t i = m_n_hidden_layers; i > 0; i--) { + + GPUMatrixDynamic dL2dp = GPUMatrix{fc_output.at(i-1).rows(), fc_output.at(i-1).cols(), stream}; + activation_backward_gpu(stream, dL2dp.n_elements(), m_activation, fc_output.at(i-1).data(), dL2dinput.at(i).data(), dL2dp.data()); + + // compute 1st order dL2dw + if (param_gradients_mode != GradientMode::Ignore) { + if (i - 1) { + fc_multiply_split_k(multi_streams_1st.back().get(1), dL2dp, forward.hidden.at(i-2).transposed(), gradient_matrix_at(i-2), split_k_factor, param_gradient_beta); + } else { // i - 1 == 0: input layer + fc_multiply_split_k(multi_streams_1st.back().get(1), dL2dp, input.transposed(), input_gradient_matrix(), split_k_factor, param_gradient_beta); + } + } + + if (i - 1) { + GPUMatrixDynamic dL_dinput_1st_order = GPUMatrix{dL2dinput.at(i-1).rows(), dL2dinput.at(i-1).cols(), stream}; + // weight_matrix_at(use_inference_params, i-2).transposed().layout() == CM + fc_multiply(stream, weight_matrix_at(use_inference_params, i-2).transposed().cm(), dL2dp, dL_dinput_1st_order); + linear_kernel(element_wise_add, 0, stream, dL2dinput.at(i-1).n() * dL2dinput.at(i-1).m(), dL_dinput_1st_order.data(), dL2dinput.at(i-1).data()); + } else if (i - 1 == 0 && dL_dinput) { // input layer + GPUMatrixDynamic dL_dinput_1st_order = GPUMatrix{dL2dinput.at(i-1).rows(), dL2dinput.at(i-1).cols(), stream}; + // input_weight_matrix(use_inference_params).transposed().layout() == CM + fc_multiply(stream, input_weight_matrix(use_inference_params).transposed().cm(), dL2dp, dL_dinput_1st_order); + linear_kernel(element_wise_add, 0, stream, dL2dinput.at(i-1).n() * dL2dinput.at(i-1).m(), dL_dinput_1st_order.data(), dL2dinput.at(i-1).data()); + } + } + + if (dL_dinput) { + // sync up back to the gradient and *dL_dinput + linear_kernel(element_wise_copy_CM_RM, 0, stream, dL_dinput->rows() * dL_dinput->cols(), dL2dinput.at(0).rows(), dL2dinput.at(0).cols(), dL2dinput.at(0).data(), dL_dinput->data()); + } +} + template std::unique_ptr::ForwardContext> CutlassMLP::allocate_forward_buffers(cudaStream_t stream, uint32_t batch_size) { auto forward = std::make_unique();