稀疏 3D 图像自动编码器的损失函数

数据挖掘 损失函数 自动编码器 稀疏性
2022-02-24 19:55:08

我有分子的 3D 结构数据。我将原子表示为 100*100*100 网格中的点,并应用高斯模糊来抵消稀疏性。(几乎所有网格单元都包含零)我正在尝试构建一个自动编码器来获得一个有意义的“分子结构到矢量”编码器。

我目前的方法是使用卷积层和最大池化层,然后展平和一些密集层来获得向量表示。然后我再次重塑并增加维度,直到模型预测在具有 sigmoid 的网格像素中存在原子的概率(参见下面的代码)。

如果我使用二元交叉熵,我担心模型不会学习,因为数据太稀疏了。我想要一个损失函数来惩罚“甚至不接近”的原子预测,而不是仅仅偏离几个网格单元的预测。

latent_dim= 512
input_mol = Input(shape=(100, 100, 100, 8))  # 8 channels for the different atom types

x = DepthwiseConv3D(kernel_size=(9,9,9), depth_multiplier=1,groups=8, padding ="same", use_bias=False)(input_mol) #gaussian blur
x = Conv3D(64, (3, 3, 3), activation='relu')(x)
x = MaxPooling3D((5, 5, 5))(x)
x = Conv3D(32, (3, 3, 3), activation='relu')(x)
x = MaxPooling3D((2, 2, 2))(x)
x = Conv3D(16, (3, 3, 3), activation='relu')(x)
x = MaxPooling3D((2, 2, 2))(x)
x = Flatten()(x)
x = Dense(1000, activation = 'relu')(x)
x = Dropout(rate=0.4)(x)
encoded = Dense(latent_dim, activation = 'relu')(x)

# add noise (variational autoencoder)
z_mean = Dense(latent_dim)(encoded)
z_log_sigma = Dense(latent_dim)(encoded)
z = Lambda(sampling, output_shape=(512,))([z_mean, z_log_sigma])


x= Reshape((8, 8, 8, 1))(encoded) 

x = Conv3D(32, (3,3, 3), activation='relu', padding='same')(x)
x = UpSampling3D((2, 2,2))(x)
x = Conv3D(32, (3,3, 3), activation='relu', padding='valid')(x)
x = UpSampling3D((2, 2,2))(x)
x = Conv3D(32, (3, 3,3), activation='relu', padding='valid')(x)
x = UpSampling3D((2, 2, 2))(x)
x = Conv3D(8, (3, 3,3), activation='relu', padding='valid')(x)
x = UpSampling3D((2, 2, 2))(x)
decoded = Conv3D(8, (10, 10, 10), activation='sigmoid', padding='same')(x)

autoencoder = Model(input_mol, decoded)
0个回答
没有发现任何回复~