Skip to content

Is it possible to convert a point cloud model to a TensorFlow probabilistic deep learning model? #1

@kuangzy2011

Description

@kuangzy2011

Hi Frightera:
I am running some tests with PointNet to classify point cloud objects, and there may be unknown objects that are not represented in the training set.
So I tried to convert the PointNet model (https://keras.io/examples/vision/pointnet/) to a TensorFlow Probability model by referring to your Simple Fully Probabilistic Bayesian CNN, hoping the model can say "I don't know" when the target object is not among the training classes, but the training results were less than ideal.
Is it possible to convert a point cloud model to a TensorFlow probabilistic deep learning model?

Here is my testing model:
# Number of training examples; the per-layer KL terms are divided by this value.
NUM_EXAMPLES = 83820
# Number of target classes; sizes the OneHotCategorical output head.
NUM_CLASSES = 10
# Mini-batch size (presumably used when batching the datasets, which is not shown here).
BATCH_SIZE = 64
# Points per point cloud; passed to create_model() as the input length.
NUM_POINTS = 231

# Revise the reinterpreted_batch_ndims arg (needed for custom prior & posterior):
# wrap the same Normal in tfd.Independent with each possible value and print
# the resulting batch shape, event shape and a sample.
shape = (3, )
dtype = tf.float64
distribution = tfd.Normal(loc = tf.zeros(shape, dtype), scale = tf.ones(shape, dtype))

for i in range(len(shape) + 1):
    print('reinterpreted_batch_ndims: %d:' %(i))
    independent_dist = tfd.Independent(distribution, reinterpreted_batch_ndims = i)
    samples = independent_dist.sample()
    # Read the shapes through the public properties; the underscore-prefixed
    # _batch_shape() / _event_shape() methods are private implementation hooks.
    print('batch_shape: {}'
          ' event_shape: {}'
          ' Sample shape: {}'.format(independent_dist.batch_shape,
                                independent_dist.event_shape,
                                samples.shape))
    print('Samples:', samples.numpy(), '\n')
    

# The default posterior is Normal, so with a Laplace prior the KL has to be
# approximated. Asking for the exact KL between those two distributions fails:
# Error:
# No KL(distribution_a || distribution_b) registered for distribution_a type Normal and distribution_b type Laplace
# Call arguments received:
#   • inputs=tf.Tensor(shape=(None, 28, 28, 1), dtype=float32)
def approximate_kl(q, p, q_tensor):
    """Approximate KL(q || p) as the mean of log q - log p at q_tensor.

    Args:
        q: distribution exposing log_prob (the posterior).
        p: distribution exposing log_prob (the prior).
        q_tensor: tensor at which both log-densities are evaluated
            (presumably a sample drawn from q).

    Returns:
        The mean of the log-density difference.
    """
    log_ratio = q.log_prob(q_tensor) - p.log_prob(q_tensor)
    return tf.reduce_mean(log_ratio)


def divergence_fn(q, p, q_tensor):
    """Return the approximate KL divided by NUM_EXAMPLES."""
    return approximate_kl(q, p, q_tensor) / NUM_EXAMPLES

def conv_reparameterization_layer(filters, kernel_size, padding):
    """Build a Convolution1DReparameterization layer with no activation.

    Kernel and bias both use the library's default mean-field normal
    posterior and default multivariate normal prior; their KL terms are
    computed by divergence_fn.

    Args:
        filters: number of output filters.
        kernel_size: convolution kernel size.
        padding: padding mode forwarded to the layer.

    Returns:
        The configured (not yet called) layer.
    """
    layer_kwargs = dict(
        filters=filters,
        kernel_size=kernel_size,
        activation=None,
        padding=padding,
        # Kernel distribution pair.
        kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn=tfpl.default_multivariate_normal_fn,
        # Bias distribution pair.
        bias_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        # Same divergence function for both sets of weights.
        kernel_divergence_fn=divergence_fn,
        bias_divergence_fn=divergence_fn,
    )
    return tfpl.Convolution1DReparameterization(**layer_kwargs)

def dense_reparameterization_layer(filters):
    """Build a DenseReparameterization layer with `filters` units and no activation.

    Kernel and bias both use the library's default mean-field normal
    posterior and default multivariate normal prior; their KL terms are
    computed by divergence_fn.

    Args:
        filters: number of output units.

    Returns:
        The configured (not yet called) layer.
    """
    layer_kwargs = dict(
        units=filters,
        activation=None,
        # Kernel distribution pair.
        kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn=tfpl.default_multivariate_normal_fn,
        # Bias distribution pair.
        bias_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        # Same divergence function for both sets of weights.
        kernel_divergence_fn=divergence_fn,
        bias_divergence_fn=divergence_fn,
    )
    return tfpl.DenseReparameterization(**layer_kwargs)
    
def nll(y_true, y_pred):
    """Negative log-likelihood loss: minus the log-probability of y_true under y_pred.

    Args:
        y_true: observed targets.
        y_pred: object exposing log_prob (the model's output distribution).

    Returns:
        The negated result of y_pred.log_prob(y_true).
    """
    log_likelihood = y_pred.log_prob(y_true)
    return -log_likelihood
        
        
def conv_bn(x, filters):
    """Apply a kernel-size-1 reparameterization conv, batch norm, then ReLU.

    Args:
        x: input tensor.
        filters: number of convolution filters.

    Returns:
        The activated output tensor.
    """
    conv = conv_reparameterization_layer(filters, kernel_size=1, padding='valid')
    features = conv(x)
    batch_norm = layers.BatchNormalization(momentum=0.0)
    features = batch_norm(features)
    relu = layers.Activation("relu")
    return relu(features)

def dense_bn(x, filters):
    """Apply a reparameterization dense layer, batch norm, then ReLU.

    Args:
        x: input tensor.
        filters: number of dense units.

    Returns:
        The activated output tensor.
    """
    dense = dense_reparameterization_layer(filters)
    features = dense(x)
    batch_norm = layers.BatchNormalization(momentum=0.0)
    features = batch_norm(features)
    relu = layers.Activation("relu")
    return relu(features)
    
    
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    """Penalise deviation of each num_features x num_features matrix from orthogonality.

    The input is reshaped to a batch of square matrices A and the penalty is
    l2reg * sum((A @ A^T - I)^2) over the whole batch.

    Args:
        num_features: side length of each square matrix.
        l2reg: weight applied to the squared deviation.
    """

    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        # Batched A @ A^T, one Gram matrix per sample. tf.tensordot with
        # axes=(2, 2) would instead contract every sample against every other
        # sample, giving a (B, n, B, n) tensor whose reshape to (-1, n, n)
        # no longer holds per-sample products.
        xxt = tf.matmul(x, x, transpose_b=True)
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))

    def get_config(self):
        """Return the constructor arguments so the regularizer can be serialized."""
        return {"num_features": self.num_features, "l2reg": self.l2reg}



def tnet(inputs, num_features):
    """Predict a num_features x num_features transform and apply it to `inputs`.

    A stack of conv_bn blocks, a global max pool and two dense_bn blocks feed
    a plain Dense layer whose output is reshaped to a square matrix; the
    result is the Dot product of `inputs` with that matrix.

    Args:
        inputs: input tensor to be transformed.
        num_features: side length of the predicted transform matrix.

    Returns:
        The transformed tensor.
    """
    # Initialise the bias as the identity matrix; with a zero kernel the
    # predicted transform therefore starts out as the identity.
    identity_bias = keras.initializers.Constant(np.eye(num_features).flatten())
    orthogonality_penalty = OrthogonalRegularizer(num_features)

    features = inputs
    for width in (32, 64, 512):
        features = conv_bn(features, width)
    features = layers.GlobalMaxPooling1D()(features)
    for width in (256, 128):
        features = dense_bn(features, width)

    flat_transform = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=identity_bias,
        activity_regularizer=orthogonality_penalty,
    )(features)
    transform = layers.Reshape((num_features, num_features))(flat_transform)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, transform])

def create_model_normal(num_points, num_classes=None):
    """Build the probabilistic PointNet classifier.

    The network applies an input T-Net, two conv_bn blocks, a feature T-Net,
    three more conv_bn blocks, a global max pool and two dense_bn blocks with
    dropout, then a reparameterization dense head feeding a OneHotCategorical
    output distribution.

    Args:
        num_points: number of points per input cloud; inputs have shape
            (num_points, 3).
        num_classes: number of output classes. Defaults to the module-level
            NUM_CLASSES, read at call time.

    Returns:
        A keras.Model named "pointnet" whose output is a OneHotCategorical
        distribution layer.
    """
    if num_classes is None:
        num_classes = NUM_CLASSES

    inputs = keras.Input(shape=(num_points, 3))

    x = tnet(inputs, 3)
    x = conv_bn(x, 32)
    x = conv_bn(x, 32)
    x = tnet(x, 32)
    x = conv_bn(x, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = layers.Dropout(0.25)(x)
    x = dense_bn(x, 128)
    x = layers.Dropout(0.25)(x)

    # The output head uses the same prior/posterior/divergence configuration
    # as every other dense layer, so reuse the shared factory instead of
    # repeating the argument list.
    head = dense_reparameterization_layer(
        tfpl.OneHotCategorical.params_size(num_classes))
    dense = head(x)
    outputs = tfpl.OneHotCategorical(num_classes)(dense)
    model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
    return model

######################################################################################

def create_model(num_points):
    """Return the model built by create_model_normal for `num_points` points."""
    model = create_model_normal(num_points)
    return model

# Build the network for NUM_POINTS-point clouds and print its layer table.
model = create_model(num_points=NUM_POINTS)
model.summary()

>>Model: "pointnet"
>>__________________________________________________________________________________________________
>>Layer (type)                    Output Shape         Param #     Connected to                     
>>==================================================================================================
>>input_3 (InputLayer)            [(None, 231, 3)]     0                                            
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_22 (C (None, 231, 32)      256         input_3[0][0]                    
>>__________________________________________________________________________________________________
>>batch_normalization_34 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_22[0][0
>>__________________________________________________________________________________________________
>>activation_34 (Activation)      (None, 231, 32)      0           batch_normalization_34[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_23 (C (None, 231, 64)      4224        activation_34[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_35 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_23[0][0
>>__________________________________________________________________________________________________
>>activation_35 (Activation)      (None, 231, 64)      0           batch_normalization_35[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_24 (C (None, 231, 512)     66560       activation_35[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_36 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_24[0][0
>>__________________________________________________________________________________________________
>>activation_36 (Activation)      (None, 231, 512)     0           batch_normalization_36[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_6 (GlobalM (None, 512)          0           activation_36[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_18 (De (None, 256)          262656      global_max_pooling1d_6[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_37 (BatchNo (None, 256)          1024        dense_reparameterization_18[0][0]
>>__________________________________________________________________________________________________
>>activation_37 (Activation)      (None, 256)          0           batch_normalization_37[0][0]     
>>__________________________________________________________________________________________________
>>dense_reparameterization_19 (De (None, 128)          65792       activation_37[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_38 (BatchNo (None, 128)          512         dense_reparameterization_19[0][0]
>>__________________________________________________________________________________________________
>>activation_38 (Activation)      (None, 128)          0           batch_normalization_38[0][0]     
>>__________________________________________________________________________________________________
>>dense (Dense)                   (None, 9)            1161        activation_38[0][0]              
>>__________________________________________________________________________________________________
>>reshape_4 (Reshape)             (None, 3, 3)         0           dense[0][0]                      
>>__________________________________________________________________________________________________
>>dot_4 (Dot)                     (None, 231, 3)       0           input_3[0][0]                    
>>                                                                 reshape_4[0][0]                  
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_25 (C (None, 231, 32)      256         dot_4[0][0]                      
>>__________________________________________________________________________________________________
>>batch_normalization_39 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_25[0][0
>>__________________________________________________________________________________________________
>>activation_39 (Activation)      (None, 231, 32)      0           batch_normalization_39[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_26 (C (None, 231, 32)      2112        activation_39[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_40 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_26[0][0
>>__________________________________________________________________________________________________
>>activation_40 (Activation)      (None, 231, 32)      0           batch_normalization_40[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_27 (C (None, 231, 32)      2112        activation_40[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_41 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_27[0][0
>>__________________________________________________________________________________________________
>>activation_41 (Activation)      (None, 231, 32)      0           batch_normalization_41[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_28 (C (None, 231, 64)      4224        activation_41[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_42 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_28[0][0
>>__________________________________________________________________________________________________
>>activation_42 (Activation)      (None, 231, 64)      0           batch_normalization_42[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_29 (C (None, 231, 512)     66560       activation_42[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_43 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_29[0][0
>>__________________________________________________________________________________________________
>>activation_43 (Activation)      (None, 231, 512)     0           batch_normalization_43[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_7 (GlobalM (None, 512)          0           activation_43[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_20 (De (None, 256)          262656      global_max_pooling1d_7[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_44 (BatchNo (None, 256)          1024        dense_reparameterization_20[0][0]
>>__________________________________________________________________________________________________
>>activation_44 (Activation)      (None, 256)          0           batch_normalization_44[0][0]     
>>__________________________________________________________________________________________________
>>dense_reparameterization_21 (De (None, 128)          65792       activation_44[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_45 (BatchNo (None, 128)          512         dense_reparameterization_21[0][0]
>>__________________________________________________________________________________________________
>>activation_45 (Activation)      (None, 128)          0           batch_normalization_45[0][0]     
>>__________________________________________________________________________________________________
>>dense_1 (Dense)                 (None, 1024)         132096      activation_45[0][0]              
>>__________________________________________________________________________________________________
>>reshape_5 (Reshape)             (None, 32, 32)       0           dense_1[0][0]                    
>>__________________________________________________________________________________________________
>>dot_5 (Dot)                     (None, 231, 32)      0           activation_40[0][0]              
>>                                                                 reshape_5[0][0]                  
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_30 (C (None, 231, 32)      2112        dot_5[0][0]                      
>>__________________________________________________________________________________________________
>>batch_normalization_46 (BatchNo (None, 231, 32)      128         conv1d_reparameterization_30[0][0
>>__________________________________________________________________________________________________
>>activation_46 (Activation)      (None, 231, 32)      0           batch_normalization_46[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_31 (C (None, 231, 64)      4224        activation_46[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_47 (BatchNo (None, 231, 64)      256         conv1d_reparameterization_31[0][0
>>__________________________________________________________________________________________________
>>activation_47 (Activation)      (None, 231, 64)      0           batch_normalization_47[0][0]     
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_32 (C (None, 231, 512)     66560       activation_47[0][0]              
>>__________________________________________________________________________________________________
>>batch_normalization_48 (BatchNo (None, 231, 512)     2048        conv1d_reparameterization_32[0][0
>>__________________________________________________________________________________________________
>>activation_48 (Activation)      (None, 231, 512)     0           batch_normalization_48[0][0]     
>>__________________________________________________________________________________________________
>>global_max_pooling1d_8 (GlobalM (None, 512)          0           activation_48[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_22 (De (None, 256)          262656      global_max_pooling1d_8[0][0]     
>>__________________________________________________________________________________________________
>>batch_normalization_49 (BatchNo (None, 256)          1024        dense_reparameterization_22[0][0]
>>__________________________________________________________________________________________________
>>activation_49 (Activation)      (None, 256)          0           batch_normalization_49[0][0]     
>>__________________________________________________________________________________________________
>>dropout_4 (Dropout)             (None, 256)          0           activation_49[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_23 (De (None, 128)          65792       dropout_4[0][0]                  
>>__________________________________________________________________________________________________
>>batch_normalization_50 (BatchNo (None, 128)          512         dense_reparameterization_23[0][0]
>>__________________________________________________________________________________________________
>>activation_50 (Activation)      (None, 128)          0           batch_normalization_50[0][0]     
>>__________________________________________________________________________________________________
>>dropout_5 (Dropout)             (None, 128)          0           activation_50[0][0]              
>>__________________________________________________________________________________________________
>>dense_reparameterization_24 (De (None, 10)           2580        dropout_5[0][0]                  
>>__________________________________________________________________________________________________
>>one_hot_categorical_2 (OneHotCa multiple             0           dense_reparameterization_24[0][0]
>>==================================================================================================
>>Total params: 1,352,541
>>Trainable params: 1,346,461
>>Non-trainable params: 6,080


>>Train model
EPOCHS = 30

# Adjusting learning rate: halve it (down to min_lr) after 3 epochs without
# improvement in validation accuracy.
# The monitored key must match what Keras actually logs: compiling with
# metrics=["accuracy"] produces "val_accuracy", not "val_acc". With the wrong
# key the callback cannot find the metric and never reduces the rate.
learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy',
                                            patience=3,
                                            verbose=1,
                                            factor=0.5,
                                            min_lr=0.00001)


# The model outputs a distribution, so the loss is its negative log-likelihood.
model.compile(
        loss=nll,
        optimizer=keras.optimizers.Adam(learning_rate=0.0001),
        metrics=["accuracy"]
    )

history = model.fit(train_dataset, epochs=EPOCHS, validation_data=test_dataset, verbose=2, callbacks=[learning_rate_reduction])

>>train_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>test_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>Epoch 1/30
>>1310/1310 - 60s - loss: 21.9061 - accuracy: 0.3590 - val_loss: 182635040.0000 - val_accuracy: 0.0833
>>Epoch 2/30
>>1310/1310 - 44s - loss: 20.6475 - accuracy: 0.7531 - val_loss: 60451084.0000 - val_accuracy: 0.1429
>>Epoch 3/30
>>1310/1310 - 44s - loss: 20.2906 - accuracy: 0.8938 - val_loss: 14749431.0000 - val_accuracy: 0.1071
>>Epoch 4/30
>>1310/1310 - 44s - loss: 20.0421 - accuracy: 0.9462 - val_loss: 22719518.0000 - val_accuracy: 0.1024
>>Epoch 5/30
>>1310/1310 - 43s - loss: 19.8038 - accuracy: 0.9664 - val_loss: 26609368.0000 - val_accuracy: 0.1381
>>Epoch 6/30
>>1310/1310 - 44s - loss: 19.5439 - accuracy: 0.9753 - val_loss: 2488617.5000 - val_accuracy: 0.0905
>>Epoch 7/30
>>1310/1310 - 44s - loss: 19.2616 - accuracy: 0.9819 - val_loss: 5732172.0000 - val_accuracy: 0.1048
>>Epoch 8/30
>>1310/1310 - 43s - loss: 18.9594 - accuracy: 0.9854 - val_loss: 3599270.0000 - val_accuracy: 0.1095
>>Epoch 9/30
>>1310/1310 - 44s - loss: 18.6469 - accuracy: 0.9876 - val_loss: 1626159.1250 - val_accuracy: 0.0976
>>Epoch 10/30
>>1310/1310 - 44s - loss: 18.3281 - accuracy: 0.9896 - val_loss: 2010936.1250 - val_accuracy: 0.1048
>>Epoch 11/30
>>1310/1310 - 44s - loss: 18.0074 - accuracy: 0.9907 - val_loss: 2303319.7500 - val_accuracy: 0.1190
>>Epoch 12/30
>>1310/1310 - 45s - loss: 17.6917 - accuracy: 0.9915 - val_loss: 6043947.5000 - val_accuracy: 0.0905
>>Epoch 13/30
>>1310/1310 - 44s - loss: 17.3721 - accuracy: 0.9932 - val_loss: 7441842.0000 - val_accuracy: 0.1071
>>Epoch 14/30
>>1310/1310 - 44s - loss: 17.0469 - accuracy: 0.9932 - val_loss: 8680777.0000 - val_accuracy: 0.1071
>>Epoch 15/30
>>1310/1310 - 44s - loss: 16.7201 - accuracy: 0.9942 - val_loss: 5293146.0000 - val_accuracy: 0.0810
>>Epoch 16/30
>>1310/1310 - 44s - loss: 16.3949 - accuracy: 0.9944 - val_loss: 3354950.5000 - val_accuracy: 0.1119
>>Epoch 17/30
>>1310/1310 - 44s - loss: 16.0781 - accuracy: 0.9944 - val_loss: 89280024.0000 - val_accuracy: 0.1167
>>Epoch 18/30
>>1310/1310 - 44s - loss: 15.7646 - accuracy: 0.9953 - val_loss: 2574807.5000 - val_accuracy: 0.0905
>>Epoch 19/30
>>1310/1310 - 43s - loss: 15.4577 - accuracy: 0.9952 - val_loss: 1522962.6250 - val_accuracy: 0.1143
>>Epoch 20/30
>>1310/1310 - 45s - loss: 15.1654 - accuracy: 0.9950 - val_loss: 26839674.0000 - val_accuracy: 0.0714
>>Epoch 21/30
>>1310/1310 - 43s - loss: 14.8791 - accuracy: 0.9954 - val_loss: 68192664.0000 - val_accuracy: 0.0976
>>Epoch 22/30
>>1310/1310 - 43s - loss: 14.5917 - accuracy: 0.9952 - val_loss: 69933488.0000 - val_accuracy: 0.0976
>>Epoch 23/30
>>1310/1310 - 44s - loss: 14.3119 - accuracy: 0.9955 - val_loss: 916111680.0000 - val_accuracy: 0.1286
>>Epoch 24/30
>>1310/1310 - 44s - loss: 14.0397 - accuracy: 0.9953 - val_loss: 1315758208.0000 - val_accuracy: 0.0952
>>Epoch 25/30
>>1310/1310 - 44s - loss: 13.7770 - accuracy: 0.9954 - val_loss: 176787792.0000 - val_accuracy: 0.1095
>>Epoch 26/30
>>1310/1310 - 44s - loss: 13.5258 - accuracy: 0.9958 - val_loss: 209021744.0000 - val_accuracy: 0.1000
>>Epoch 27/30
>>1310/1310 - 44s - loss: 13.2780 - accuracy: 0.9955 - val_loss: 217535568.0000 - val_accuracy: 0.1238
>>Epoch 28/30
>>1310/1310 - 44s - loss: 13.0325 - accuracy: 0.9960 - val_loss: 62926640.0000 - val_accuracy: 0.1024
>>Epoch 29/30
>>1310/1310 - 44s - loss: 12.7926 - accuracy: 0.9960 - val_loss: 111481688.0000 - val_accuracy: 0.0905
>>Epoch 30/30
>>1310/1310 - 44s - loss: 12.5660 - accuracy: 0.9951 - val_loss: 61252780.0000 - val_accuracy: 0.0810

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions