Can someone explain the purpose of this sample function?

data-mining python data-cleaning probability numpy text-generation
2021-09-16 19:25:34

In the Dino_Name_Generator notebook from Deeplearning.ai there is this function:

import numpy as np  # softmax is assumed to come from the notebook's helper utilities

def sample(parameters, char_to_ix, seed):
    # Retrieve parameters and relevant shapes from "parameters" dictionary
    Waa, Wax, Wya, by, b = parameters['Waa'], parameters['Wax'],parameters['Wya'], parameters['by'], parameters['b']
    vocab_size = by.shape[0]
    n_a = Waa.shape[1]

    ### START CODE HERE ###
    # Step 1: Create the one-hot vector x for the first character (initializing the sequence generation). (≈1 line)
    x = np.zeros((vocab_size, 1))

    # Step 1': Initialize a_prev as zeros (≈1 line)
    a_prev = np.zeros((n_a, 1))

    # Create an empty list of indices, this is the list which will contain the list of indices of the characters to generate (≈1 line)
    indices = []

    # Idx is a flag to detect a newline character, we initialize it to -1
    idx = -1 

    # Loop over time-steps t. At each time-step, sample a character from a probability distribution and append 
    # its index to "indices". We'll stop if we reach 50 characters (which should be very unlikely with a well 
    # trained model), which helps debugging and prevents entering an infinite loop. 
    counter = 0
    newline_character = char_to_ix['\n']

    while (idx != newline_character and counter != 50):

        # Step 2: Forward propagate x using the equations (1), (2) and (3)
        a = np.tanh(np.dot(Wax, x) + np.dot(Waa, a_prev) + b)
        z = np.dot(Wya, a) + by
        y = softmax(z)

        # for grading purposes
        np.random.seed(counter+seed) 

        # Step 3: Sample the index of a character within the vocabulary from the probability distribution y
        idx = np.random.choice(vocab_size, size=None, p = y.ravel())

        # Append the index to "indices"
        indices.append(idx)

        # Step 4: Overwrite the input character as the one corresponding to the sampled index.
        x = np.zeros((vocab_size, 1))
        x[[idx]] = 1

        # Update "a_prev" to be "a"
        a_prev = a

        # for grading purposes
        seed += 1
        counter +=1


    ### END CODE HERE ###

    if (counter == 50):
        indices.append(char_to_ix['\n'])

    return indices

Can someone explain what advantage the returned indices have over the ordinary char_to_integer indices?

I want to understand the text processing in the linked notebook that is performed before the data is fed into the network.

1 Answer

From the link you provided:

Sample a sequence of characters according to a sequence of probability distributions output by the RNN

Arguments: parameters -- Python dictionary containing the parameters Waa, Wax, Wya, by, and b. char_to_ix -- Python dictionary mapping each character to an index.

Returns: indices -- a list of length n containing the indices of the sampled characters
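To make the docstring concrete, here is a minimal sketch of one sampling step and of turning a list of returned indices back into text. The four-character vocabulary, the scores in `z`, the `softmax` helper, and the `ix_to_char` name are all assumptions made up for illustration; they are not taken from the notebook.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over a column vector (illustrative helper).
    e = np.exp(z - np.max(z))
    return e / e.sum(axis=0)

# Hypothetical tiny vocabulary; the real one is built from the training text.
chars = ["\n", "a", "b", "c"]
char_to_ix = {ch: i for i, ch in enumerate(chars)}
ix_to_char = {i: ch for ch, i in char_to_ix.items()}

# One sampling step: draw an index according to the distribution y.
rng = np.random.default_rng(0)
z = np.array([[0.1], [2.0], [0.5], [-1.0]])  # made-up output scores
y = softmax(z)
idx = int(rng.choice(len(chars), p=y.ravel()))
sampled_char = ix_to_char[idx]

# Decoding a list of indices, such as the one sample() returns, into text.
indices = [3, 1, 2, 0]
name = "".join(ix_to_char[i] for i in indices)
```

The point is that the sampled index and the dictionary index are the same kind of number: position `i` of the output distribution corresponds to the character that `char_to_ix` maps to `i`.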

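The indices you return are indices into the same dictionary that is passed in as an argument (char_to_ix). Why would you want a separate char_to_integer index?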
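On the preprocessing you asked about, the following is only a sketch of how a character-level pipeline of this kind is commonly set up, not a copy of the notebook's code. The `names` list and the `encode` and `one_hot` helpers are hypothetical; the assumption is that the vocabulary is the sorted set of characters in the training text, with a newline marking the end of each name.

```python
import numpy as np

# Hypothetical stand-in for the training data (one name per line).
names = ["rex", "tri"]
data = "\n".join(names) + "\n"

# Build the vocabulary and the character-to-index mapping.
chars = sorted(set(data))
char_to_ix = {ch: i for i, ch in enumerate(chars)}
vocab_size = len(chars)

def encode(name, char_to_ix):
    # Map each character to its index and append the end-of-name newline.
    return [char_to_ix[ch] for ch in name + "\n"]

def one_hot(idx, vocab_size):
    # Column vector with a single 1 at position idx, the form fed to the RNN.
    x = np.zeros((vocab_size, 1))
    x[idx] = 1
    return x

encoded = encode("rex", char_to_ix)
x0 = one_hot(encoded[0], vocab_size)
```

Under these assumptions, sample() runs the same mapping in reverse: the model produces a distribution over `vocab_size` positions, an index is drawn from it, and that index is one-hot encoded again as the next input.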