菜单 学习猿地 - LMONKEY

VIP

开通学习猿地VIP

尊享10项VIP特权 持续新增

知识通关挑战

打卡带练!告别无效练习

接私单赚外快

VIP优先接,累计金额超百万

学习猿地私房课免费学

大厂实战课仅对VIP开放

你的一对一导师

每月可免费咨询大牛30次

领取更多软件工程师实用特权

入驻
277
0

测试1

原创
05/13 14:22
阅读数 44575

 

# encoding: UTF-8
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data as mnist_data
import tensorflow as tf
from tensorflow.python.platform import gfile
import os

print("Tensorflow version " + tf.__version__)
print(tf.__path__)

# ---------------------------------------------------------------------------
# Disabled reference script (not executed): single-layer softmax regression
# on MNIST using the TF1 graph API. Uncomment to run it. It is written with
# Python 3 print() and tf.global_variables_initializer() (the older
# tf.initialize_all_variables() is deprecated), matching train() below.
# ---------------------------------------------------------------------------

# tf.set_random_seed(0)

# # Load the MNIST data (stored under the "data" directory), one-hot labels.
# mnist = mnist_data.read_data_sets("data", one_hot=True)

# # Input placeholders: 784-pixel images and 10-class labels.
# x = tf.placeholder("float", [None, 784])
# y_ = tf.placeholder("float", [None, 10])

# # Model parameters (weights and bias).
# W = tf.Variable(tf.zeros([784, 10]))
# b = tf.Variable(tf.zeros([10]))
# # Network output.
# y = tf.nn.softmax(tf.matmul(x, W) + b)

# # Cross-entropy loss.
# # NOTE: tf.log(y) hits log(0) if the softmax saturates;
# # tf.nn.softmax_cross_entropy_with_logits is the numerically stable form.
# cross_entropy = -tf.reduce_sum(y_ * tf.log(y))

# # Training op.
# learningRate = 0.005
# train_step = tf.train.GradientDescentOptimizer(learningRate).minimize(cross_entropy)

# init = tf.global_variables_initializer()
# sess = tf.Session()
# sess.run(init)

# itnum = 1000
# batch_size = 100
# for i in range(itnum):
#     if i % 100 == 0:
#         print("the index " + str(i + 1) + " train")
#     batch_xs, batch_ys = mnist.train.next_batch(batch_size)
#     sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

# correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
# accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))


def train(export_dir="log", step=200):
    """Run one 3x3 ``tf.nn.conv2d`` on synthetic data, then checkpoint it.

    Builds a deterministic 1x1x28x28 input (values 0..783) and a 2x1x3x3
    weight tensor (values 0..17). Both are laid out channels-first: NCHW
    input and OIHW weights. They are permuted to the NHWC / HWIO layout
    that ``tf.nn.conv2d`` uses by default, then convolved with stride 1,
    'SAME' padding and no bias. The result is printed after permuting it
    back to NCHW.

    The graph and variables are then saved as a checkpoint in
    ``export_dir``, and the saved meta-graph is written there as a
    TensorBoard event file.

    Args:
        export_dir: Directory for the checkpoint and event files. It is
            created if missing, because TF1's ``Saver.save`` raises when
            the parent directory does not exist. Defaults to "log".
        step: Global step appended to the checkpoint prefix, giving
            ``model.ckpt-<step>``. Defaults to 200.
    """
    height = 28
    width = 28
    inchannel = 1
    outchannel = 2
    wkernel = 3
    stride = 1
    # With a 3x3 kernel and stride 1, 'SAME' padding adds 1 pixel on each
    # side. This is equivalent to an explicit pad=1, dilation=1 conv.

    # Reference tensors in channels-first layout: data is NCHW, w is OIHW.
    w = np.arange(wkernel * wkernel * inchannel * outchannel).reshape(
        (outchannel, inchannel, wkernel, wkernel))
    data = np.arange(height * width * inchannel).reshape(
        (1, inchannel, height, width))
    print('input:', data)
    print('weight:', w)

    # Permute to TF's default layouts: NCHW -> NHWC and OIHW -> HWIO.
    # Swapping H and W in both tensors (transpose(0,3,2,1) / (3,2,1,0))
    # yields the same numbers but is a needlessly confusing layout.
    data_nhwc = data.transpose(0, 2, 3, 1).astype(np.float32)
    w_hwio = w.transpose(2, 3, 1, 0).astype(np.float32)

    # Local names avoid shadowing the builtins input()/filter(); the graph
    # op names ("input", "weight") are unchanged.
    input_var = tf.Variable(data_nhwc, dtype=np.float32, name="input")
    filter_var = tf.Variable(w_hwio, dtype=np.float32, name="weight")

    conv = tf.nn.conv2d(input_var, filter_var,
                        strides=[1, stride, stride, 1],
                        padding='SAME', name="conv")
    init = tf.global_variables_initializer()

    os.makedirs(export_dir, exist_ok=True)
    with tf.Session() as sess:
        sess.run(init)
        # NHWC -> NCHW so the printout matches the input's layout.
        conv_reshape = sess.run(conv).transpose(0, 3, 1, 2)
        print("conv_reshape: \n", conv_reshape)

        saver = tf.train.Saver()
        # save() returns the real checkpoint prefix
        # ("<export_dir>/model.ckpt-<step>"), so the .meta path derived
        # below stays correct for any `step`.
        save_path = saver.save(sess, os.path.join(export_dir, 'model.ckpt'),
                               global_step=step)

    # Import the saved meta-graph into a fresh graph. Importing it into the
    # default graph, which already holds these ops, would re-add every op,
    # and that polluted graph is what TensorBoard would show.
    with tf.Graph().as_default() as saved_graph:
        tf.train.import_meta_graph(save_path + '.meta')
    summary_writer = tf.summary.FileWriter(export_dir, saved_graph)
    # Close explicitly so the event file is flushed to disk.
    summary_writer.close()


# Script entry point: run the convolution demo only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
    train()

 

发表评论

0/200
277 点赞
0 评论
收藏
为你推荐 换一批