博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
递归神经网络 简单示例
阅读量:5877 次
发布时间:2019-06-19

本文共 3857 字,大约阅读时间需要 12 分钟。

找到一个递归神经网络的例子,没看懂。

先保存,慢慢看。

# Recurrent Neural Networksimport copy, numpy as npnp.random.seed(0)# compute sigmoid nonlinearitydef sigmoid(x):    output = 1/(1+np.exp(-x))    return output# convert output of sigmoid function to its derivativedef sigmoid_output_to_derivative(output):    return output*(1-output)# training dataset generationint2binary = {}binary_dim = 8largest_number = pow(2,binary_dim)binary = np.unpackbits(    np.array([range(largest_number)],dtype=np.uint8).T,axis=1)for i in range(largest_number):    int2binary[i] = binary[i]# input variablesalpha = 0.1input_dim = 2hidden_dim = 16output_dim = 1# initialize neural network weightssynapse_0 = 2*np.random.random((input_dim,hidden_dim)) - 1synapse_1 = 2*np.random.random((hidden_dim,output_dim)) - 1synapse_h = 2*np.random.random((hidden_dim,hidden_dim)) - 1synapse_0_update = np.zeros_like(synapse_0)synapse_1_update = np.zeros_like(synapse_1)synapse_h_update = np.zeros_like(synapse_h)# training logicfor j in range(10000):        # generate a simple addition problem (a + b = c)    a_int = np.random.randint(largest_number/2) # int version    a = int2binary[a_int] # binary encoding    b_int = np.random.randint(largest_number/2) # int version    b = int2binary[b_int] # binary encoding    # true answer    c_int = a_int + b_int    c = int2binary[c_int]        # where we'll store our best guess (binary encoded)    d = np.zeros_like(c)    overallError = 0        layer_2_deltas = list()    layer_1_values = list()    layer_1_values.append(np.zeros(hidden_dim))        # moving along the positions in the binary encoding    for position in range(binary_dim):                # generate input and output        X = np.array([[a[binary_dim - position - 1],b[binary_dim - position - 1]]])        y = np.array([[c[binary_dim - position - 1]]]).T        # hidden layer (input ~+ prev_hidden)        layer_1 = sigmoid(np.dot(X,synapse_0) + np.dot(layer_1_values[-1],synapse_h))        # output layer (new binary representation)        layer_2 = sigmoid(np.dot(layer_1,synapse_1))        # did we miss?... if so, by how much?        layer_2_error = y - layer_2        layer_2_deltas.append((layer_2_error)*sigmoid_output_to_derivative(layer_2))        overallError += np.abs(layer_2_error[0])            # decode estimate so we can print(it out)        d[binary_dim - position - 1] = np.round(layer_2[0][0])                # store hidden layer so we can use it in the next timestep        layer_1_values.append(copy.deepcopy(layer_1))        future_layer_1_delta = np.zeros(hidden_dim)        for position in range(binary_dim):                X = np.array([[a[position],b[position]]])        layer_1 = layer_1_values[-position-1]        prev_layer_1 = layer_1_values[-position-2]                # error at output layer        layer_2_delta = layer_2_deltas[-position-1]        # error at hidden layer        layer_1_delta = (future_layer_1_delta.dot(synapse_h.T) + layer_2_delta.dot(synapse_1.T)) * sigmoid_output_to_derivative(layer_1)        # let's update all our weights so we can try again        synapse_1_update += np.atleast_2d(layer_1).T.dot(layer_2_delta)        synapse_h_update += np.atleast_2d(prev_layer_1).T.dot(layer_1_delta)        synapse_0_update += X.T.dot(layer_1_delta)                future_layer_1_delta = layer_1_delta        synapse_0 += synapse_0_update * alpha    synapse_1 += synapse_1_update * alpha    synapse_h += synapse_h_update * alpha        synapse_0_update *= 0    synapse_1_update *= 0    synapse_h_update *= 0        # print(out progress)    if j % 1000 == 0:        print("Error:" + str(overallError))        print("Pred:" + str(d))        print("True:" + str(c))        out = 0        for index,x in enumerate(reversed(d)):            out += x*pow(2,index)        print(str(a_int) + " + " + str(b_int) + " = " + str(out))        print("------------")

转载地址:http://rkuix.baihongyu.com/

你可能感兴趣的文章
常用弹出对话框函数
查看>>
2018暑期生活指导第三周
查看>>
php中字符串函数
查看>>
mysql中影响myisam引擎写入性能的三项设置
查看>>
python 中初始化二维数组的方法
查看>>
关于父类引用指向子类对象
查看>>
Oracle连接出现TNS:no listener或者ORA-12514: TNS:listener does not currently know
查看>>
第三次作业中遇到的困难和解决方法
查看>>
透过HT for Web 3D看动画Easing函数本质
查看>>
WebGL实现HTML5的3D贪吃蛇游戏
查看>>
JS 二维数组排序
查看>>
软件包管理 之 file.src.rpm 使用方法的简单介绍
查看>>
Perl语言的多线程(二)
查看>>
职场必须要会的餐桌礼仪
查看>>
Google Chrome Resize Plugin
查看>>
最优化 KKT条件
查看>>
Seekbar扩大点击区域
查看>>
angular开发环境搭建及新建项目
查看>>
ps6-图层基础与操作技巧
查看>>
git分支管理
查看>>