Hi Frightera:
I am doing some tests with PointNet to classify point cloud objects, and at inference time there may be unknown objects that are not in the training set.
So I tried to convert the PointNet model (https://keras.io/examples/vision/pointnet/) into a TensorFlow Probability model by referring to your Simple Fully Probabilistic Bayesian CNN, hoping the model could say "I don't know" when the target object is not in the training set, but the training results were less than ideal.
Is it possible to convert a point cloud model into a TensorFlow Probability deep learning model?
Here is my testing model:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ReduceLROnPlateau

tfd = tfp.distributions
tfpl = tfp.layers

NUM_EXAMPLES = 83820
NUM_CLASSES = 10
BATCH_SIZE = 64
NUM_POINTS = 231

# Revise reinterpreted_batch_ndims arg. Needed for custom prior & posterior.
shape = (3, )
dtype = tf.float64
distribution = tfd.Normal(loc=tf.zeros(shape, dtype), scale=tf.ones(shape, dtype))
# batch_ndims = tf.size(distribution.batch_shape_tensor())
for i in range(len(shape) + 1):
    print('reinterpreted_batch_ndims: %d:' % i)
    independent_dist = tfd.Independent(distribution, reinterpreted_batch_ndims=i)
    samples = independent_dist.sample()
    print('batch_shape: {}'
          ' event_shape: {}'
          ' Sample shape: {}'.format(independent_dist.batch_shape,
                                     independent_dist.event_shape,
                                     samples.shape))
    print('Samples:', samples.numpy(), '\n')
# The default posterior is Normal; if we use a Laplace prior, we need to
# approximate the KL. If we try to compute KL for those two distributions
# analytically, we get:
#   Error: No KL(distribution_a || distribution_b) registered for
#   distribution_a type Normal and distribution_b type Laplace
def approximate_kl(q, p, q_tensor):
    return tf.reduce_mean(q.log_prob(q_tensor) - p.log_prob(q_tensor))

divergence_fn = lambda q, p, q_tensor: approximate_kl(q, p, q_tensor) / NUM_EXAMPLES
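As a side note, this Monte Carlo approximation can be sanity-checked against a pair of distributions where the exact KL is registered. A minimal sketch (my own check, not part of the model; the numbers are only illustrative):

q = tfd.Normal(loc=0.0, scale=1.0)
p = tfd.Normal(loc=1.0, scale=2.0)
q_samples = q.sample(10000)
print('MC estimate:', approximate_kl(q, p, q_samples).numpy())
print('Exact KL:   ', tfd.kl_divergence(q, p).numpy())
# The two values should agree up to Monte Carlo noise.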
def conv_reparameterization_layer(filters, kernel_size, padding):
    # For simplicity, we use the default prior and posterior.
    # In the next parts, we will use custom mixture priors and posteriors.
    return tfpl.Convolution1DReparameterization(
        filters=filters,
        kernel_size=kernel_size,
        activation=None,
        padding=padding,
        kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_divergence_fn=divergence_fn,
        bias_divergence_fn=divergence_fn)
def dense_reparameterization_layer(filters):
    return tfpl.DenseReparameterization(
        units=filters, activation=None,
        kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_divergence_fn=divergence_fn,
        bias_divergence_fn=divergence_fn)
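As far as I understand tfp.layers, these factories behave like their deterministic Keras counterparts except that each call also registers its kernel and bias KL terms via add_loss, which Keras then folds into the total loss. A quick sketch to see that (shapes are arbitrary):

probe = dense_reparameterization_layer(16)
_ = probe(tf.zeros((1, 8)))
print(probe.losses)  # two scalars: kernel KL and bias KL, already scaled by 1/NUM_EXAMPLES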
def nll(y_true, y_pred):
    return -y_pred.log_prob(y_true)
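This works because the final layer outputs a distribution rather than a tensor, so y_pred has a log_prob method. A tiny illustration with made-up logits:

demo_dist = tfd.OneHotCategorical(logits=[[2.0, 0.5, -1.0]], dtype=tf.float64)
demo_label = tf.constant([[1.0, 0.0, 0.0]], dtype=tf.float64)
print(nll(demo_label, demo_dist).numpy())  # NLL of the true class under the predicted distribution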
def conv_bn(x, filters):
    x = conv_reparameterization_layer(filters, kernel_size=1, padding='valid')(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)

def dense_bn(x, filters):
    x = dense_reparameterization_layer(filters)(x)
    x = layers.BatchNormalization(momentum=0.0)(x)
    return layers.Activation("relu")(x)
class OrthogonalRegularizer(keras.regularizers.Regularizer):
    def __init__(self, num_features, l2reg=0.001):
        self.num_features = num_features
        self.l2reg = l2reg
        self.eye = tf.eye(num_features)

    def __call__(self, x):
        x = tf.reshape(x, (-1, self.num_features, self.num_features))
        xxt = tf.tensordot(x, x, axes=(2, 2))
        xxt = tf.reshape(xxt, (-1, self.num_features, self.num_features))
        return tf.reduce_sum(self.l2reg * tf.square(xxt - self.eye))
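This is the regularizer from the original PointNet example: it penalises ||XXᵀ - I||², pushing each learned feature transform toward an orthogonal matrix. A quick check of that intent (my own illustration):

reg3 = OrthogonalRegularizer(3)
print(reg3(tf.reshape(tf.eye(3), (1, 9))).numpy())  # ~0.0 for an orthogonal (identity) transform
print(reg3(tf.ones((1, 9))).numpy())                # > 0 for a non-orthogonal one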
def tnet(inputs, num_features):
    # Initialise the bias as the identity matrix
    bias = keras.initializers.Constant(np.eye(num_features).flatten())
    reg = OrthogonalRegularizer(num_features)

    x = conv_bn(inputs, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = dense_bn(x, 128)
    x = layers.Dense(
        num_features * num_features,
        kernel_initializer="zeros",
        bias_initializer=bias,
        activity_regularizer=reg,
    )(x)
    feat_T = layers.Reshape((num_features, num_features))(x)
    # Apply affine transformation to input features
    return layers.Dot(axes=(2, 1))([inputs, feat_T])
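Dot(axes=(2, 1)) here is a batched matrix multiply: every point (row of inputs) is multiplied by the learned num_features x num_features matrix. A minimal sketch of that behaviour, using an identity transform so the points come back unchanged:

pts = tf.random.normal((1, 4, 3))            # 1 cloud, 4 points, 3 features
T = tf.eye(3, batch_shape=[1])               # identity "transform"
out = layers.Dot(axes=(2, 1))([pts, T])
print(np.allclose(out.numpy(), pts.numpy()))  # True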
def create_model_normal(num_points):
    inputs = keras.Input(shape=(num_points, 3))

    x = tnet(inputs, 3)
    x = conv_bn(x, 32)
    x = conv_bn(x, 32)
    x = tnet(x, 32)
    x = conv_bn(x, 32)
    x = conv_bn(x, 64)
    x = conv_bn(x, 512)
    x = layers.GlobalMaxPooling1D()(x)
    x = dense_bn(x, 256)
    x = layers.Dropout(0.25)(x)
    x = dense_bn(x, 128)
    x = layers.Dropout(0.25)(x)

    dense = tfpl.DenseReparameterization(
        units=tfpl.OneHotCategorical.params_size(NUM_CLASSES), activation=None,
        kernel_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_prior_fn=tfpl.default_multivariate_normal_fn,
        bias_posterior_fn=tfpl.default_mean_field_normal_fn(is_singular=False),
        kernel_divergence_fn=divergence_fn,
        bias_divergence_fn=divergence_fn)(x)
    outputs = tfpl.OneHotCategorical(NUM_CLASSES)(dense)

    model = keras.Model(inputs=inputs, outputs=outputs, name="pointnet")
    return model

######################################################################################
def create_model(num_points):
    return create_model_normal(num_points)
model = create_model(NUM_POINTS)
model.summary()
>>Model: "pointnet"
>>__________________________________________________________________________________________________
>>Layer (type) Output Shape Param # Connected to
>>==================================================================================================
>>input_3 (InputLayer) [(None, 231, 3)] 0
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_22 (C (None, 231, 32) 256 input_3[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_34 (BatchNo (None, 231, 32) 128 conv1d_reparameterization_22[0][0
>>__________________________________________________________________________________________________
>>activation_34 (Activation) (None, 231, 32) 0 batch_normalization_34[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_23 (C (None, 231, 64) 4224 activation_34[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_35 (BatchNo (None, 231, 64) 256 conv1d_reparameterization_23[0][0
>>__________________________________________________________________________________________________
>>activation_35 (Activation) (None, 231, 64) 0 batch_normalization_35[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_24 (C (None, 231, 512) 66560 activation_35[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_36 (BatchNo (None, 231, 512) 2048 conv1d_reparameterization_24[0][0
>>__________________________________________________________________________________________________
>>activation_36 (Activation) (None, 231, 512) 0 batch_normalization_36[0][0]
>>__________________________________________________________________________________________________
>>global_max_pooling1d_6 (GlobalM (None, 512) 0 activation_36[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_18 (De (None, 256) 262656 global_max_pooling1d_6[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_37 (BatchNo (None, 256) 1024 dense_reparameterization_18[0][0]
>>__________________________________________________________________________________________________
>>activation_37 (Activation) (None, 256) 0 batch_normalization_37[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_19 (De (None, 128) 65792 activation_37[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_38 (BatchNo (None, 128) 512 dense_reparameterization_19[0][0]
>>__________________________________________________________________________________________________
>>activation_38 (Activation) (None, 128) 0 batch_normalization_38[0][0]
>>__________________________________________________________________________________________________
>>dense (Dense) (None, 9) 1161 activation_38[0][0]
>>__________________________________________________________________________________________________
>>reshape_4 (Reshape) (None, 3, 3) 0 dense[0][0]
>>__________________________________________________________________________________________________
>>dot_4 (Dot) (None, 231, 3) 0 input_3[0][0]
>> reshape_4[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_25 (C (None, 231, 32) 256 dot_4[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_39 (BatchNo (None, 231, 32) 128 conv1d_reparameterization_25[0][0
>>__________________________________________________________________________________________________
>>activation_39 (Activation) (None, 231, 32) 0 batch_normalization_39[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_26 (C (None, 231, 32) 2112 activation_39[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_40 (BatchNo (None, 231, 32) 128 conv1d_reparameterization_26[0][0
>>__________________________________________________________________________________________________
>>activation_40 (Activation) (None, 231, 32) 0 batch_normalization_40[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_27 (C (None, 231, 32) 2112 activation_40[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_41 (BatchNo (None, 231, 32) 128 conv1d_reparameterization_27[0][0
>>__________________________________________________________________________________________________
>>activation_41 (Activation) (None, 231, 32) 0 batch_normalization_41[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_28 (C (None, 231, 64) 4224 activation_41[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_42 (BatchNo (None, 231, 64) 256 conv1d_reparameterization_28[0][0
>>__________________________________________________________________________________________________
>>activation_42 (Activation) (None, 231, 64) 0 batch_normalization_42[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_29 (C (None, 231, 512) 66560 activation_42[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_43 (BatchNo (None, 231, 512) 2048 conv1d_reparameterization_29[0][0
>>__________________________________________________________________________________________________
>>activation_43 (Activation) (None, 231, 512) 0 batch_normalization_43[0][0]
>>__________________________________________________________________________________________________
>>global_max_pooling1d_7 (GlobalM (None, 512) 0 activation_43[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_20 (De (None, 256) 262656 global_max_pooling1d_7[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_44 (BatchNo (None, 256) 1024 dense_reparameterization_20[0][0]
>>__________________________________________________________________________________________________
>>activation_44 (Activation) (None, 256) 0 batch_normalization_44[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_21 (De (None, 128) 65792 activation_44[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_45 (BatchNo (None, 128) 512 dense_reparameterization_21[0][0]
>>__________________________________________________________________________________________________
>>activation_45 (Activation) (None, 128) 0 batch_normalization_45[0][0]
>>__________________________________________________________________________________________________
>>dense_1 (Dense) (None, 1024) 132096 activation_45[0][0]
>>__________________________________________________________________________________________________
>>reshape_5 (Reshape) (None, 32, 32) 0 dense_1[0][0]
>>__________________________________________________________________________________________________
>>dot_5 (Dot) (None, 231, 32) 0 activation_40[0][0]
>> reshape_5[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_30 (C (None, 231, 32) 2112 dot_5[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_46 (BatchNo (None, 231, 32) 128 conv1d_reparameterization_30[0][0
>>__________________________________________________________________________________________________
>>activation_46 (Activation) (None, 231, 32) 0 batch_normalization_46[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_31 (C (None, 231, 64) 4224 activation_46[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_47 (BatchNo (None, 231, 64) 256 conv1d_reparameterization_31[0][0
>>__________________________________________________________________________________________________
>>activation_47 (Activation) (None, 231, 64) 0 batch_normalization_47[0][0]
>>__________________________________________________________________________________________________
>>conv1d_reparameterization_32 (C (None, 231, 512) 66560 activation_47[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_48 (BatchNo (None, 231, 512) 2048 conv1d_reparameterization_32[0][0
>>__________________________________________________________________________________________________
>>activation_48 (Activation) (None, 231, 512) 0 batch_normalization_48[0][0]
>>__________________________________________________________________________________________________
>>global_max_pooling1d_8 (GlobalM (None, 512) 0 activation_48[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_22 (De (None, 256) 262656 global_max_pooling1d_8[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_49 (BatchNo (None, 256) 1024 dense_reparameterization_22[0][0]
>>__________________________________________________________________________________________________
>>activation_49 (Activation) (None, 256) 0 batch_normalization_49[0][0]
>>__________________________________________________________________________________________________
>>dropout_4 (Dropout) (None, 256) 0 activation_49[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_23 (De (None, 128) 65792 dropout_4[0][0]
>>__________________________________________________________________________________________________
>>batch_normalization_50 (BatchNo (None, 128) 512 dense_reparameterization_23[0][0]
>>__________________________________________________________________________________________________
>>activation_50 (Activation) (None, 128) 0 batch_normalization_50[0][0]
>>__________________________________________________________________________________________________
>>dropout_5 (Dropout) (None, 128) 0 activation_50[0][0]
>>__________________________________________________________________________________________________
>>dense_reparameterization_24 (De (None, 10) 2580 dropout_5[0][0]
>>__________________________________________________________________________________________________
>>one_hot_categorical_2 (OneHotCa multiple 0 dense_reparameterization_24[0][0]
>>==================================================================================================
>>Total params: 1,352,541
>>Trainable params: 1,346,461
>>Non-trainable params: 6,080
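As a smoke test (my own addition, with random stand-in points), calling the model returns a OneHotCategorical distribution object rather than a plain tensor:

pts = tf.random.normal((2, NUM_POINTS, 3))
pred_dist = model(pts)
print(pred_dist)                  # tfp.distributions.OneHotCategorical
print(pred_dist.mean().numpy())   # per-class probabilities, shape (2, NUM_CLASSES)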
Train model:
EPOCHS = 30

# Adjusting the learning rate. Note: with metrics=["accuracy"], the logged
# validation metric is named 'val_accuracy', so we monitor that key.
learning_rate_reduction = ReduceLROnPlateau(monitor='val_accuracy',
                                            patience=3,
                                            verbose=1,
                                            factor=0.5,
                                            min_lr=0.00001)

model.compile(
    loss=nll,
    optimizer=keras.optimizers.Adam(learning_rate=0.0001),
    metrics=["accuracy"]
)
history = model.fit(train_dataset, epochs=EPOCHS, validation_data=test_dataset,
                    verbose=2, callbacks=[learning_rate_reduction])
>>train_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>test_dataset <BatchDataset shapes: ((None, 231, 3), (None, 10)), types: (tf.float64, tf.float64)>
>>Epoch 1/30
>>1310/1310 - 60s - loss: 21.9061 - accuracy: 0.3590 - val_loss: 182635040.0000 - val_accuracy: 0.0833
>>Epoch 2/30
>>1310/1310 - 44s - loss: 20.6475 - accuracy: 0.7531 - val_loss: 60451084.0000 - val_accuracy: 0.1429
>>Epoch 3/30
>>1310/1310 - 44s - loss: 20.2906 - accuracy: 0.8938 - val_loss: 14749431.0000 - val_accuracy: 0.1071
>>Epoch 4/30
>>1310/1310 - 44s - loss: 20.0421 - accuracy: 0.9462 - val_loss: 22719518.0000 - val_accuracy: 0.1024
>>Epoch 5/30
>>1310/1310 - 43s - loss: 19.8038 - accuracy: 0.9664 - val_loss: 26609368.0000 - val_accuracy: 0.1381
>>Epoch 6/30
>>1310/1310 - 44s - loss: 19.5439 - accuracy: 0.9753 - val_loss: 2488617.5000 - val_accuracy: 0.0905
>>Epoch 7/30
>>1310/1310 - 44s - loss: 19.2616 - accuracy: 0.9819 - val_loss: 5732172.0000 - val_accuracy: 0.1048
>>Epoch 8/30
>>1310/1310 - 43s - loss: 18.9594 - accuracy: 0.9854 - val_loss: 3599270.0000 - val_accuracy: 0.1095
>>Epoch 9/30
>>1310/1310 - 44s - loss: 18.6469 - accuracy: 0.9876 - val_loss: 1626159.1250 - val_accuracy: 0.0976
>>Epoch 10/30
>>1310/1310 - 44s - loss: 18.3281 - accuracy: 0.9896 - val_loss: 2010936.1250 - val_accuracy: 0.1048
>>Epoch 11/30
>>1310/1310 - 44s - loss: 18.0074 - accuracy: 0.9907 - val_loss: 2303319.7500 - val_accuracy: 0.1190
>>Epoch 12/30
>>1310/1310 - 45s - loss: 17.6917 - accuracy: 0.9915 - val_loss: 6043947.5000 - val_accuracy: 0.0905
>>Epoch 13/30
>>1310/1310 - 44s - loss: 17.3721 - accuracy: 0.9932 - val_loss: 7441842.0000 - val_accuracy: 0.1071
>>Epoch 14/30
>>1310/1310 - 44s - loss: 17.0469 - accuracy: 0.9932 - val_loss: 8680777.0000 - val_accuracy: 0.1071
>>Epoch 15/30
>>1310/1310 - 44s - loss: 16.7201 - accuracy: 0.9942 - val_loss: 5293146.0000 - val_accuracy: 0.0810
>>Epoch 16/30
>>1310/1310 - 44s - loss: 16.3949 - accuracy: 0.9944 - val_loss: 3354950.5000 - val_accuracy: 0.1119
>>Epoch 17/30
>>1310/1310 - 44s - loss: 16.0781 - accuracy: 0.9944 - val_loss: 89280024.0000 - val_accuracy: 0.1167
>>Epoch 18/30
>>1310/1310 - 44s - loss: 15.7646 - accuracy: 0.9953 - val_loss: 2574807.5000 - val_accuracy: 0.0905
>>Epoch 19/30
>>1310/1310 - 43s - loss: 15.4577 - accuracy: 0.9952 - val_loss: 1522962.6250 - val_accuracy: 0.1143
>>Epoch 20/30
>>1310/1310 - 45s - loss: 15.1654 - accuracy: 0.9950 - val_loss: 26839674.0000 - val_accuracy: 0.0714
>>Epoch 21/30
>>1310/1310 - 43s - loss: 14.8791 - accuracy: 0.9954 - val_loss: 68192664.0000 - val_accuracy: 0.0976
>>Epoch 22/30
>>1310/1310 - 43s - loss: 14.5917 - accuracy: 0.9952 - val_loss: 69933488.0000 - val_accuracy: 0.0976
>>Epoch 23/30
>>1310/1310 - 44s - loss: 14.3119 - accuracy: 0.9955 - val_loss: 916111680.0000 - val_accuracy: 0.1286
>>Epoch 24/30
>>1310/1310 - 44s - loss: 14.0397 - accuracy: 0.9953 - val_loss: 1315758208.0000 - val_accuracy: 0.0952
>>Epoch 25/30
>>1310/1310 - 44s - loss: 13.7770 - accuracy: 0.9954 - val_loss: 176787792.0000 - val_accuracy: 0.1095
>>Epoch 26/30
>>1310/1310 - 44s - loss: 13.5258 - accuracy: 0.9958 - val_loss: 209021744.0000 - val_accuracy: 0.1000
>>Epoch 27/30
>>1310/1310 - 44s - loss: 13.2780 - accuracy: 0.9955 - val_loss: 217535568.0000 - val_accuracy: 0.1238
>>Epoch 28/30
>>1310/1310 - 44s - loss: 13.0325 - accuracy: 0.9960 - val_loss: 62926640.0000 - val_accuracy: 0.1024
>>Epoch 29/30
>>1310/1310 - 44s - loss: 12.7926 - accuracy: 0.9960 - val_loss: 111481688.0000 - val_accuracy: 0.0905
>>Epoch 30/30
>>1310/1310 - 44s - loss: 12.5660 - accuracy: 0.9951 - val_loss: 61252780.0000 - val_accuracy: 0.0810
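For the "I don't know" behaviour I am after, my understanding is that a common recipe with this kind of model is to run several stochastic forward passes (each pass resamples the weight posteriors) and look at the predictive entropy of the averaged probabilities; entropy near log(NUM_CLASSES) would flag an object unlike any training class. A sketch, where unknown_points is a hypothetical batch of shape (N, NUM_POINTS, 3):

def predictive_entropy(model, points, n_passes=50):
    # Each call resamples the weights, so the stacked probabilities
    # capture both weight and class uncertainty.
    probs = np.stack([model(points).mean().numpy() for _ in range(n_passes)])
    mean_probs = probs.mean(axis=0)  # (N, NUM_CLASSES)
    return -(mean_probs * np.log(mean_probs + 1e-12)).sum(axis=-1)

# entropy = predictive_entropy(model, unknown_points)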