Ways to combine embeddings

data-mining  machine-learning  tensorflow  word-embeddings
2022-01-23 20:17:49

I am building a recommender for an online store, and I have categorical inputs that fall into one of the following categories:

  • Features of the user's current session (e.g. current_product_brand, current_product_id)
  • The user's previous activity, i.e. the 10 most recent sessions (e.g. previous_product_id, previous_product_brand)
  • 2D features of the user's previous activity (e.g. product_semantics_embedding and product_tags_embedding; assume each product has 20 tags/semantic tokens)

The relevant part of the model summary:

_________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
current_product_brand (InputLaye   (None, 1)            0                                            
__________________________________________________________________________________________________
current_product_id (InputLayer)    (None, 1)            0                                            
__________________________________________________________________________________________________
previous_product_id (InputLayer)   (None, 10)           0                                            
__________________________________________________________________________________________________
previous_product_brand (InputLayer)(None, 10)           0                                            
__________________________________________________________________________________________________
previous_product_semantics (InputL (None, 200)          0                                            
__________________________________________________________________________________________________
previous_product_tags (InputLayer) (None, 200)          0                                            
__________________________________________________________________________________________________
product_brand_embedding (Embeddi   (None, 10, 60)       60       current_product_brand[0][0]       
                                                                 previous_product_brand[0][0]      
__________________________________________________________________________________________________
product_id_embedding (Embedding)   (None, 10, 60)       98820    current_product_id[0][0]            
                                                                 previous_product_id[0][0]
__________________________________________________________________________________________________
product_semantics_embedding (Embed (None, 200, 60)      104880   previous_product_semantics[0][0]    
__________________________________________________________________________________________________
product_tags_embedding (Embedding) (None, 200, 60)      2760     previous_product_tags[0][0]         
__________________________________________________________________________________________________      

Features such as product brand, which appear in both the current and previous sessions, are embedded in the same space.

Note that the output dimension of every embedding is the same (60 in this case).

Now I want to combine all the embeddings into a single tensor so that I can feed them to another layer, e.g. a Dense layer. I think my options are the following:

  • Concatenate all the embeddings: I cannot use axis 1, because product_semantics and product_tags have a different shape. Would it make sense to concatenate them on axis 2?
  • Concatenate them per group, i.e. concat product_brand_embedding with product_id_embedding, and product_semantics_embedding with product_tags_embedding, apply global average pooling to each result, and then concatenate the two outputs of the global-average-pooling nodes.

Which way is the right one? Are there other options?
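For reference, Keras/TensorFlow concatenation only requires the shapes to match on the axes that are *not* being concatenated, which can be checked quickly on dummy tensors of the shapes above (this is a standalone shape check, not part of the model):

```python
import tensorflow as tf

# Dummy batches shaped like two of the embedding outputs in the summary
a = tf.zeros((2, 10, 60))   # e.g. previous_product_brand embedding
b = tf.zeros((2, 200, 60))  # e.g. previous_product_semantics embedding

# Axis 1 (the sequence axis): mechanically allowed, since the
# embedding dimension (60) matches on both inputs.
seq = tf.concat([a, b], axis=1)
print(seq.shape)  # (2, 210, 60)

# Axis 2 (the feature axis) would require equal sequence lengths
# (10 vs. 200 here), so tf.concat([a, b], axis=2) raises an error.
```

So concatenating on axis 2 is only possible within a group whose sequence lengths agree; across groups it fails outright.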

1 Answer

I have solved this exact problem by using GlobalAveragePooling1D to flatten the output of these multi-valued embeddings and then concatenating them with the 1D embeddings, which is how YouTube handles this in its own recommendation engine:

[Figure: architecture of YouTube's deep recommendation network, with averaged embeddings concatenated into dense layers]

So in your place I would go with your second approach. When I did this, the model structure looked like this:

Model: "model_8"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
history-input (InputLayer)      (None, 3063)         0                                            
__________________________________________________________________________________________________
manufacturer-history-input (Inp (None, 3063)         0                                            
__________________________________________________________________________________________________
history-embedding (Embedding)   (None, 3063, 32)     538144      history-input[0][0]              
__________________________________________________________________________________________________
manufacturer-history-embedding  (None, 3063, 32)     32224       manufacturer-history-input[0][0] 
__________________________________________________________________________________________________
average-history-embedding (Glob (None, 32)           0           history-embedding[0][0]          
__________________________________________________________________________________________________
manufacturer-average-history-em (None, 32)           0           manufacturer-history-embedding[0]
__________________________________________________________________________________________________
numeric-inputs (InputLayer)     (None, 49)           0                                            
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 113)          0           average-history-embedding[0][0]  
                                                                 manufacturer-average-history-embe
                                                                 numeric-inputs[0][0]             
__________________________________________________________________________________________________
dense-512 (Dense)               (None, 1024)         116736      concatenate[0][0]                
__________________________________________________________________________________________________
dropout_13 (Dropout)            (None, 1024)         0           dense-512[0][0]                  
__________________________________________________________________________________________________
dense-256 (Dense)               (None, 512)          524800      dropout_13[0][0]                 
__________________________________________________________________________________________________
dropout_15 (Dropout)            (None, 512)          0           dense-256[0][0]                  
__________________________________________________________________________________________________
target (Dense)                  (None, 16817)        8627121     dropout_15[0][0]                 
==================================================================================================
Total params: 9,839,025
Trainable params: 9,839,025
Non-trainable params: 0
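The same pattern applied to the question's "previous activity" inputs can be sketched with the functional API as below. This is a minimal sketch, not the answerer's actual code: it covers only the previous-session branch, and the vocabulary sizes are back-derived from the question's param counts (param # divided by 60).

```python
import tensorflow as tf
from tensorflow.keras import layers, Model

EMB_DIM = 60  # all embeddings share the same output dimension

# Inputs, shaped as in the question's model summary
prev_ids = layers.Input(shape=(10,), name="previous_product_id")
prev_brands = layers.Input(shape=(10,), name="previous_product_brand")
prev_sem = layers.Input(shape=(200,), name="previous_product_semantics")
prev_tags = layers.Input(shape=(200,), name="previous_product_tags")

# Vocabulary sizes inferred from the summary (98820/60, 60/60, etc.)
id_emb = layers.Embedding(1647, EMB_DIM)(prev_ids)     # (None, 10, 60)
brand_emb = layers.Embedding(1, EMB_DIM)(prev_brands)  # (None, 10, 60)
sem_emb = layers.Embedding(1748, EMB_DIM)(prev_sem)    # (None, 200, 60)
tag_emb = layers.Embedding(46, EMB_DIM)(prev_tags)     # (None, 200, 60)

# Concatenate per group on the sequence axis, pool each group,
# then concatenate the pooled vectors
session_seq = layers.Concatenate(axis=1)([id_emb, brand_emb])  # (None, 20, 60)
content_seq = layers.Concatenate(axis=1)([sem_emb, tag_emb])   # (None, 400, 60)
session_vec = layers.GlobalAveragePooling1D()(session_seq)     # (None, 60)
content_vec = layers.GlobalAveragePooling1D()(content_seq)     # (None, 60)

combined = layers.Concatenate()([session_vec, content_vec])    # (None, 120)
out = layers.Dense(512, activation="relu")(combined)

model = Model([prev_ids, prev_brands, prev_sem, prev_tags], out)
```

The key point is that pooling reduces every group to a fixed-length vector of the shared embedding dimension, after which plain feature-axis concatenation into dense layers works regardless of the original sequence lengths.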