You could try using a recurrent model to identify the parts, so that it reads the string "letter by letter". I would extract the circuit number and parse it manually, then feed only the text part to the network.
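For instance, the numeric prefix could be split off with a regular expression before classification; a minimal sketch (the helper name and pattern are my own, not part of the model below):

import re

def split_tag(raw_tag):
    # Split a raw tag like '22LBLB-F' into (circuit number, text part).
    match = re.match(r'(\d*)(.*)', raw_tag)
    number = int(match.group(1)) if match.group(1) else None
    return number, match.group(2)

print(split_tag('22LBLB-F'))  # (22, 'LBLB-F')

A basic example of the classifier itself in TensorFlow could then look like this: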
import tensorflow as tf
# Words used by people
TAGS = [
    'BT',
    'SW',
    'LBLB-F',
    'LBLB',
    # ...
]
# "Component id"
LABELS = [
    0,  # 0 => Battery
    1,  # 1 => Circuit switch
    2,  # 2 => Fluorescent light bulb
    3,  # 3 => Light bulb
    # ...
]
# Number of different components
NUM_CLASSES = 4
NUM_LAYERS = 1
LAYER_SIZE = 64
BATCH_SIZE = 100
NUM_EPOCHS = 100
# Inputs must be strings of the same size padded with '\0'
input_tags = tf.placeholder(tf.string, [None], name='Input')
# Convert to ASCII byte values
tag_ascii = tf.decode_raw(input_tags, tf.uint8)
# Get actual lengths
mask = tf.not_equal(tag_ascii, 0)
tag_length = tf.reduce_sum(tf.cast(mask, tf.int32), axis=1)
# Convert to one-hot encoding
tag_1h = tf.one_hot(tag_ascii, 256, dtype=tf.float32)
# RNN
cells = [tf.nn.rnn_cell.BasicLSTMCell(LAYER_SIZE) for _ in range(NUM_LAYERS)]
rnn = tf.nn.rnn_cell.MultiRNNCell(cells)
rnn_output, _ = tf.nn.dynamic_rnn(rnn, tag_1h, sequence_length=tag_length, dtype=tf.float32)
# Get last RNN output
last_rnn_indices = tf.stack([tf.range(tf.shape(rnn_output)[0]), tag_length - 1], axis=-1)
rnn_last_output = tf.gather_nd(rnn_output, last_rnn_indices)
# Output layer
output_weights = tf.get_variable('OutputWeights', (LAYER_SIZE, NUM_CLASSES))
output_logit = rnn_last_output @ output_weights
# Final output as distribution and highest-scoring class
output_dist = tf.nn.softmax(output_logit)
output_class = tf.argmax(output_logit, axis=-1)
# Loss and training
input_labels = tf.placeholder(tf.int32, [None], name='Class')
loss = tf.losses.sparse_softmax_cross_entropy(labels=input_labels, logits=output_logit)
# Choose optimizer and hyperparameters
train_op = tf.train.AdamOptimizer().minimize(loss)
# Variable initialization
init_op = tf.global_variables_initializer()
# Preprocess words so all have the same size
max_tag_len = max(len(tag) for tag in TAGS)
tags_padded = [tag + '\0' * (max_tag_len - len(tag)) for tag in TAGS]
num_examples = len(tags_padded)
with tf.Session() as session:
    session.run(init_op)
    # Train
    for i_epoch in range(NUM_EPOCHS):
        for idx_batch in range(0, num_examples, BATCH_SIZE):
            tags_batch = tags_padded[idx_batch:idx_batch + BATCH_SIZE]
            labels_batch = LABELS[idx_batch:idx_batch + BATCH_SIZE]
            session.run(train_op, feed_dict={input_tags: tags_batch, input_labels: labels_batch})
    # Check results
    predictions, dist = session.run([output_class, output_dist], feed_dict={input_tags: tags_padded})
    for tag, label, prediction in zip(TAGS, LABELS, predictions):
        print('Tag {} is class {} and was predicted to be class {}.'.format(tag, label, prediction))
    # Test for an unknown tag: 22LBLB-T should be class 3
    tag = '22LBLB-T'
    prediction = session.run(output_class, feed_dict={input_tags: [tag]})[0]
    print('Tag {} was predicted to be class {}.'.format(tag, prediction))
Output:
Tag BT is class 0 and was predicted to be class 0.
Tag SW is class 1 and was predicted to be class 1.
Tag LBLB-F is class 2 and was predicted to be class 2.
Tag LBLB is class 3 and was predicted to be class 3.
Tag 22LBLB-T was predicted to be class 2.
It basically takes each string, converts it into a vector of byte values, turns those into one-hot encodings and feeds them to a recurrent network (plus an output layer). In this particular case, the strings are padded at the end with null characters so that they all have the same length.
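To make that preprocessing concrete, this is roughly what the padding and byte conversion produce for one tag (a plain-Python illustration of what tf.decode_raw sees, not part of the model):

# Preprocessing of the tag 'SW' with a maximum tag length of 6:
tag = 'SW'
padded = tag + '\0' * (6 - len(tag))     # 'SW\x00\x00\x00\x00'
ascii_values = [ord(c) for c in padded]  # [83, 87, 0, 0, 0, 0]
# tf.one_hot then turns each byte into a length-256 indicator vector,
# and the padding positions are ignored thanks to sequence_length.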
I added a test at the end for 22LBLB-T, an unseen tag that should be classified as a light bulb. In this case the model fails and says it is a fluorescent light bulb, but to be fair, given the data it did not have many clues to find the right answer (in fact, that tag does look more like a fluorescent light bulb tag, since it contains a hyphen; you could consider filtering out characters such as hyphens if you think they will "confuse" the model). In any case, the model's prediction "makes sense" for the data provided (it did not predict a battery or a circuit switch, which would have made no sense).
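If you want to try the filtering idea, a minimal sketch could look like this (the set of characters to drop is an assumption; the same filter would have to be applied at both training and prediction time):

DROP_CHARS = set('-')  # characters assumed to confuse the model

def clean_tag(tag):
    # Remove potentially confusing characters before feeding the tag to the model.
    return ''.join(c for c in tag if c not in DROP_CHARS)

print(clean_tag('LBLB-F'))  # 'LBLBF'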