Ram+enlargement #23

Open
049940 opened this issue Dec 8, 2023 · 0 comments

049940 commented Dec 8, 2023

# imports used throughout this script
import os
import sys
import time
import random

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import tf_mnist_loader

# Python 2/3 compatibility
xrange = range

dataset = tf_mnist_loader.read_data_sets("mnist_data")
save_dir = "chckPts/"
save_prefix = "save"
summaryFolderName = "summary/"

if len(sys.argv) == 2:
    simulationName = str(sys.argv[1])
    print("Simulation name = " + simulationName)
    summaryFolderName = summaryFolderName + simulationName + "/"
    saveImgs = True
    imgsFolderName = "imgs/" + simulationName + "/"
    if not os.path.isdir(summaryFolderName):
        os.mkdir(summaryFolderName)
    # if not os.path.isdir(imgsFolderName):
    #     os.mkdir(imgsFolderName)
else:
    saveImgs = False
    print("Testing... image files will not be saved.")

start_step = 0
#load_path = None
load_path = save_dir + save_prefix + str(start_step) + ".ckpt"

# to enable visualization, set draw to True

eval_only = False
draw = False
animate = False

# conditions

translateMnist = 1
eyeCentered = 0

preTraining = 0
preTraining_epoch = 20000
drawReconsturction = 0

# about translation

MNIST_SIZE = 28
translated_img_size = 60 # side length of the picture

fixed_learning_rate = 0.001

if translateMnist:
    print("TRANSLATED MNIST")
    img_size = translated_img_size
    depth = 3  # number of zooms
    sensorBandwidth = 12
    minRadius = 8  # zooms -> minRadius * 2**<depth_level>

    initLr = 1e-3
    lr_min = 1e-4
    lrDecayRate = .999
    lrDecayFreq = 200
    momentumValue = .9
    batch_size = 64

else:
    print("CENTERED MNIST")
    img_size = MNIST_SIZE
    depth = 1  # number of zooms
    sensorBandwidth = 8
    minRadius = 4  # zooms -> minRadius * 2**<depth_level>

    initLr = 1e-3
    lrDecayRate = .99
    lrDecayFreq = 200
    momentumValue = .9
    batch_size = 20

# model parameters

channels = 1 # mnist are grayscale images
totalSensorBandwidth = depth * channels * (sensorBandwidth **2)
nGlimpses = 6 # number of glimpses
loc_sd = 0.22 # std when setting the location
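
# Worked example for the sensor size above: with translated MNIST (depth = 3,
# channels = 1, sensorBandwidth = 12), totalSensorBandwidth = 3 * 1 * 12**2 = 432
# features per glimpse; with centered MNIST it is 1 * 1 * 8**2 = 64.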

# network units

hg_size = 128 #
hl_size = 128 #
g_size = 256 #
cell_size = 256 #
cell_out_size = cell_size #

# parameters about the training examples

n_classes = 10 # card(Y)

# training parameters

max_iters = 1000000
SMALL_NUM = 1e-10

# resource preallocation

mean_locs = [] # expectation of locations
sampled_locs = [] # sampled locations ~N(mean_locs[.], loc_sd)
baselines = [] # baseline, the value prediction
glimpse_images = [] # to show in window

# set the weights to small random values (uniform in [-0.1, 0.1])

def weight_variable(shape, myname, train):
    initial = tf.random_uniform(shape, minval=-0.1, maxval=0.1)
    return tf.Variable(initial, name=myname, trainable=train)
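
# e.g. the call weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True)
# made at graph-construction time below returns a trainable 2 x 128 variable
# initialized uniformly in [-0.1, 0.1].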

# get local glimpses

def glimpseSensor(img, normLoc):
    loc = tf.round(((normLoc + 1) / 2.0) * img_size)  # normLoc coordinates are between -1 and 1
    loc = tf.cast(loc, tf.int32)

    img = tf.reshape(img, (batch_size, img_size, img_size, channels))

    # process each image individually
    zooms = []
    for k in range(batch_size):
        imgZooms = []
        one_img = img[k, :, :, :]
        max_radius = minRadius * (2 ** (depth - 1))
        offset = 2 * max_radius

        # pad image with zeros
        one_img = tf.image.pad_to_bounding_box(one_img, offset, offset,
                                               max_radius * 4 + img_size, max_radius * 4 + img_size)

        for i in range(depth):
            r = int(minRadius * (2 ** (i)))

            d_raw = 2 * r
            d = tf.constant(d_raw, shape=[1])
            d = tf.tile(d, [2])
            loc_k = loc[k, :]
            adjusted_loc = offset + loc_k - r
            one_img2 = tf.reshape(one_img, (one_img.get_shape()[0].value, one_img.get_shape()[1].value))

            # crop image to (d x d)
            zoom = tf.slice(one_img2, adjusted_loc, d)

            # resize cropped image to (sensorBandwidth x sensorBandwidth)
            zoom = tf.image.resize_bilinear(tf.reshape(zoom, (1, d_raw, d_raw, 1)), (sensorBandwidth, sensorBandwidth))
            zoom = tf.reshape(zoom, (sensorBandwidth, sensorBandwidth))
            imgZooms.append(zoom)

        zooms.append(tf.stack(imgZooms))

    zooms = tf.stack(zooms)

    glimpse_images.append(zooms)

    return zooms
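
# Worked sizes for the translated-MNIST settings above (img_size = 60,
# minRadius = 8, depth = 3, sensorBandwidth = 12): max_radius = 8 * 2**2 = 32,
# so each image is zero-padded by offset = 64 on every side and square crops of
# side 16, 32 and 64 are taken around the location, each resized to 12 x 12.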

# implements the input network

def get_glimpse(loc):
    # get the input using the previous location
    glimpse_input = glimpseSensor(inputs_placeholder, loc)
    glimpse_input = tf.reshape(glimpse_input, (batch_size, totalSensorBandwidth))

    # the hidden units that process the location & the input
    act_glimpse_hidden = tf.nn.relu(tf.matmul(glimpse_input, Wg_g_h) + Bg_g_h)
    act_loc_hidden = tf.nn.relu(tf.matmul(loc, Wg_l_h) + Bg_l_h)

    # the hidden units that integrate the location & the glimpses
    glimpseFeature1 = tf.nn.relu(tf.matmul(act_glimpse_hidden, Wg_hg_gf1) + tf.matmul(act_loc_hidden, Wg_hl_gf1) + Bg_hlhg_gf1)
    # glimpseFeature2 = tf.matmul(glimpseFeature1, Wg_gf1_gf2) + Bg_gf1_gf2
    return glimpseFeature1

def get_next_input(output):
    # the next location is computed by the location network
    core_net_out = tf.stop_gradient(output)

    baseline = tf.sigmoid(tf.matmul(core_net_out, Wb_h_b) + Bb_h_b)
    baselines.append(baseline)

    # compute the next location, then impose noise
    if eyeCentered:
        # add the last sampled glimpse location
        # TODO max(-1, min(1, u + N(output, sigma) + prevLoc))
        mean_loc = tf.maximum(-1.0, tf.minimum(1.0, tf.matmul(core_net_out, Wl_h_l) + sampled_locs[-1]))
    else:
        mean_loc = tf.matmul(core_net_out, Wl_h_l) + Bl_h_l
        mean_loc = tf.clip_by_value(mean_loc, -1, 1)
    # mean_loc = tf.stop_gradient(mean_loc)
    mean_locs.append(mean_loc)

    # add noise
    # sample_loc = tf.tanh(mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd))
    sample_loc = tf.maximum(-1.0, tf.minimum(1.0, mean_loc + tf.random_normal(mean_loc.get_shape(), 0, loc_sd)))

    # don't propagate through the locations
    sample_loc = tf.stop_gradient(sample_loc)
    sampled_locs.append(sample_loc)

    return get_glimpse(sample_loc)
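
# For reference, the sampling above implements, per coordinate,
#   l_t = clip(mu_t + eps, -1, 1),  eps ~ N(0, loc_sd**2),  loc_sd = 0.22;
# the stop_gradient on sample_loc means the location network is trained only
# through the REINFORCE term in calc_reward, not by backprop through the sample.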

def affineTransform(x, output_dim):
    """
    affine transformation Wx + b
    assumes x.shape = (batch_size, num_features)
    """
    w = tf.get_variable("w", [x.get_shape()[1], output_dim])
    b = tf.get_variable("b", [output_dim], initializer=tf.constant_initializer(0.0))
    return tf.matmul(x, w) + b
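
# Note: affineTransform uses tf.get_variable, so it must run inside a
# tf.variable_scope with reuse managed across calls; model() below does this
# via the "coreNetwork" scope and the REUSE flag.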

def model():

    # initialize the location under unif[-1,1], for all examples in the batch
    initial_loc = tf.random_uniform((batch_size, 2), minval=-1, maxval=1)
    mean_locs.append(initial_loc)

    # initial_loc = tf.tanh(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd))
    initial_loc = tf.clip_by_value(initial_loc + tf.random_normal(initial_loc.get_shape(), 0, loc_sd), -1, 1)

    sampled_locs.append(initial_loc)

    # get the input using the input network
    initial_glimpse = get_glimpse(initial_loc)

    # set up the recurrent structure
    inputs = [0] * nGlimpses
    outputs = [0] * nGlimpses
    glimpse = initial_glimpse
    REUSE = None
    for t in range(nGlimpses):
        if t == 0:  # initialize the hidden state to be the zero vector
            hiddenState_prev = tf.zeros((batch_size, cell_size))
        else:
            hiddenState_prev = outputs[t - 1]

        # forward prop
        with tf.variable_scope("coreNetwork", reuse=REUSE):
            # the next hidden state is a function of the previous hidden state and the current glimpse
            hiddenState = tf.nn.relu(affineTransform(hiddenState_prev, cell_size) + (tf.matmul(glimpse, Wc_g_h) + Bc_g_h))

        # save the current glimpse and the hidden state
        inputs[t] = glimpse
        outputs[t] = hiddenState
        # get the next input glimpse
        if t != nGlimpses - 1:
            glimpse = get_next_input(hiddenState)
        else:
            first_hiddenState = tf.stop_gradient(hiddenState)
            baseline = tf.sigmoid(tf.matmul(first_hiddenState, Wb_h_b) + Bb_h_b)
            baselines.append(baseline)
        REUSE = True  # share variables for later recurrence

    return outputs

def dense_to_one_hot(labels_dense, num_classes=10):
    """Convert class labels from scalars to one-hot vectors."""
    # copied from the TensorFlow tutorial
    num_labels = labels_dense.shape[0]
    index_offset = np.arange(num_labels) * num_classes
    labels_one_hot = np.zeros((num_labels, num_classes))
    labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
    return labels_one_hot
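
# Example: dense_to_one_hot(np.array([3, 0])) returns
#   [[0 0 0 1 0 0 0 0 0 0]
#    [1 0 0 0 0 0 0 0 0 0]]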

# to use for maximum likelihood with the input location

def gaussian_pdf(mean, sample):
    Z = 1.0 / (loc_sd * tf.sqrt(2.0 * np.pi))
    a = -tf.square(sample - mean) / (2.0 * tf.square(loc_sd))
    return Z * tf.exp(a)
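
# This is the elementwise Gaussian density with sigma = loc_sd:
#   N(x; mu, sigma) = exp(-(x - mu)**2 / (2 * sigma**2)) / (sigma * sqrt(2 * pi))
# It serves as the location-policy likelihood in calc_reward below.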

def calc_reward(outputs):

    # consider the action at the last time step
    outputs = outputs[-1]  # look at ONLY THE END of the sequence
    outputs = tf.reshape(outputs, (batch_size, cell_out_size))

    # get the baseline
    b = tf.stack(baselines)
    b = tf.concat(axis=2, values=[b, b])
    b = tf.reshape(b, (batch_size, nGlimpses * 2))
    no_grad_b = tf.stop_gradient(b)

    # get the action (classification)
    p_y = tf.nn.softmax(tf.matmul(outputs, Wa_h_a) + Ba_h_a)
    max_p_y = tf.arg_max(p_y, 1)
    correct_y = tf.cast(labels_placeholder, tf.int64)

    # reward for all examples in the batch
    R = tf.cast(tf.equal(max_p_y, correct_y), tf.float32)
    reward = tf.reduce_mean(R)  # mean reward
    R = tf.reshape(R, (batch_size, 1))
    R = tf.tile(R, [1, nGlimpses * 2])

    # get the location likelihood
    p_loc = gaussian_pdf(mean_locs, sampled_locs)
    # p_loc = tf.tanh(p_loc)

    p_loc_orig = p_loc
    p_loc = tf.reshape(p_loc, (batch_size, nGlimpses * 2))

    # define the cost function
    J = tf.concat(axis=1, values=[tf.log(p_y + SMALL_NUM) * (onehot_labels_placeholder), tf.log(p_loc + SMALL_NUM) * (R - no_grad_b)])
    J = tf.reduce_sum(J, 1)
    J = J - tf.reduce_sum(tf.square(R - b), 1)
    J = tf.reduce_mean(J, 0)
    cost = -J
    var_list = tf.trainable_variables()
    grads = tf.gradients(cost, var_list)
    grads, _ = tf.clip_by_global_norm(grads, 0.5)

    # define the optimizer
    # lr_max = tf.maximum(lr, lr_min)
    optimizer = tf.train.AdamOptimizer(lr)
    # optimizer = tf.train.MomentumOptimizer(lr, momentumValue)
    # train_op = optimizer.minimize(cost, global_step)
    train_op = optimizer.apply_gradients(zip(grads, var_list), global_step=global_step)

    return cost, reward, max_p_y, correct_y, train_op, b, tf.reduce_mean(b), tf.reduce_mean(R - b), lr
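
# For reference, the objective maximized above is, per example,
#   J = sum_c y_c * log p(y_c)                      (supervised cross-entropy term)
#     + sum_t log N(l_t; mu_t, loc_sd) * (R - b)    (REINFORCE term for the location policy)
#     - sum_t (R - b)**2                            (baseline regression term)
# where R is the 0/1 classification reward, b the learned baseline, and the
# REINFORCE/baseline sums run over both coordinates of all nGlimpses locations;
# cost = -J is minimized with Adam after clipping gradients to global norm 0.5.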

def preTrain(outputs):
    lr_r = 1e-3
    # consider the action at the last time step
    outputs = outputs[-1]  # look at ONLY THE END of the sequence
    outputs = tf.reshape(outputs, (batch_size, cell_out_size))

    # reconstruct the input image from the final hidden state
    reconstruction = tf.sigmoid(tf.matmul(outputs, Wr_h_r) + Br_h_r)
    reconstructionCost = tf.reduce_mean(tf.square(inputs_placeholder - reconstruction))

    train_op_r = tf.train.RMSPropOptimizer(lr_r).minimize(reconstructionCost)
    return reconstructionCost, reconstruction, train_op_r

def evaluate():
    data = dataset.test
    batches_in_epoch = len(data._images) // batch_size
    accuracy = 0

    for i in range(batches_in_epoch):
        nextX, nextY = dataset.test.next_batch(batch_size)
        if translateMnist:
            nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size)
        feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY,
                     onehot_labels_placeholder: dense_to_one_hot(nextY)}
        r = sess.run(reward, feed_dict=feed_dict)
        accuracy += r

    accuracy /= batches_in_epoch
    print("ACCURACY: " + str(accuracy))

def convertTranslated(images, initImgSize, finalImgSize):
    size_diff = finalImgSize - initImgSize
    newimages = np.zeros([batch_size, finalImgSize * finalImgSize])
    imgCoord = np.zeros([batch_size, 2])
    for k in range(batch_size):
        image = images[k, :]
        image = np.reshape(image, (initImgSize, initImgSize))
        # generate and save random coordinates
        randX = random.randint(0, size_diff)
        randY = random.randint(0, size_diff)
        imgCoord[k, :] = np.array([randX, randY])
        # padding
        image = np.lib.pad(image, ((randX, size_diff - randX), (randY, size_diff - randY)),
                           'constant', constant_values=(0))
        newimages[k, :] = np.reshape(image, (finalImgSize * finalImgSize))

    return newimages, imgCoord
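
# Example use (as in the training loop below): convertTranslated(nextX, 28, 60)
# pads each 28x28 digit into a random position on a 60x60 canvas and returns
# the flattened images together with the sampled top-left coordinates.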

def toMnistCoordinates(coordinate_tanh):
    '''
    Transform a coordinate in [-1,1] to MNIST pixel coordinates
    :param coordinate_tanh: vector in [-1,1] x [-1,1]
    :return: vector in the corresponding MNIST coordinate
    '''
    return np.round(((coordinate_tanh + 1) / 2.0) * img_size)
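
# Example: with translated MNIST (img_size = 60), (-1, -1) maps to pixel (0, 0)
# and (0, 0) maps to (30, 30).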

def variable_summaries(var, name):
    """Attach a lot of summaries to a Tensor."""
    with tf.name_scope('param_summaries'):
        mean = tf.reduce_mean(var)
        tf.summary.scalar('param_mean/' + name, mean)
        with tf.name_scope('param_stddev'):
            stddev = tf.sqrt(tf.reduce_sum(tf.square(var - mean)))
        tf.summary.scalar('param_stddev/' + name, stddev)
        tf.summary.scalar('param_max/' + name, tf.reduce_max(var))
        tf.summary.scalar('param_min/' + name, tf.reduce_min(var))
        tf.summary.histogram(name, var)

def plotWholeImg(img, img_size, sampled_locs_fetched):
    plt.imshow(np.reshape(img, [img_size, img_size]),
               cmap=plt.get_cmap('gray'), interpolation="nearest")

    plt.ylim((img_size - 1, 0))
    plt.xlim((0, img_size - 1))

    # transform the coordinates to the MNIST map
    sampled_locs_mnist_fetched = toMnistCoordinates(sampled_locs_fetched)
    # visualize the trace of successive nGlimpses (note that x and y coordinates are "flipped")
    plt.plot(sampled_locs_mnist_fetched[0, :, 1], sampled_locs_mnist_fetched[0, :, 0], '-o',
             color='lawngreen')
    plt.plot(sampled_locs_mnist_fetched[0, -1, 1], sampled_locs_mnist_fetched[0, -1, 0], 'o',
             color='red')

# build the graph and place its ops on GPU 1 (device must be set inside the graph context)
with tf.Graph().as_default(), tf.device('/gpu:1'):

    # set the learning rate
    global_step = tf.Variable(0, trainable=False)
    lr = tf.train.exponential_decay(initLr, global_step, lrDecayFreq, lrDecayRate, staircase=True)

    # preallocate x, y, baseline
    labels = tf.placeholder("float32", shape=[batch_size, n_classes])
    labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size), name="labels_raw")
    onehot_labels_placeholder = tf.placeholder(tf.float32, shape=(batch_size, 10), name="labels_onehot")
    inputs_placeholder = tf.placeholder(tf.float32, shape=(batch_size, img_size * img_size), name="images")

    # declare the model parameters; here's the naming rule:
    # the 1st capital letter: weights or bias (W = weights, B = bias)
    # the 2nd lowercase letter: the network (e.g.: g = glimpse network)
    # the 3rd and 4th letter(s): the input-output mapping, which is spelled out in the variable name argument

    Wg_l_h = weight_variable((2, hl_size), "glimpseNet_wts_location_hidden", True)
    Bg_l_h = weight_variable((1,hl_size), "glimpseNet_bias_location_hidden", True)

    Wg_g_h = weight_variable((totalSensorBandwidth, hg_size), "glimpseNet_wts_glimpse_hidden", True)
    Bg_g_h = weight_variable((1,hg_size), "glimpseNet_bias_glimpse_hidden", True)

    Wg_hg_gf1 = weight_variable((hg_size, g_size), "glimpseNet_wts_hiddenGlimpse_glimpseFeature1", True)
    Wg_hl_gf1 = weight_variable((hl_size, g_size), "glimpseNet_wts_hiddenLocation_glimpseFeature1", True)
    Bg_hlhg_gf1 = weight_variable((1,g_size), "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1", True)

    Wc_g_h = weight_variable((cell_size, g_size), "coreNet_wts_glimpse_hidden", True)
    Bc_g_h = weight_variable((1,g_size), "coreNet_bias_glimpse_hidden", True)

    Wr_h_r = weight_variable((cell_out_size, img_size**2), "reconstructionNet_wts_hidden_action", True)
    Br_h_r = weight_variable((1, img_size**2), "reconstructionNet_bias_hidden_action", True)

    Wb_h_b = weight_variable((g_size, 1), "baselineNet_wts_hiddenState_baseline", True)
    Bb_h_b = weight_variable((1,1), "baselineNet_bias_hiddenState_baseline", True)

    Wl_h_l = weight_variable((cell_out_size, 2), "locationNet_wts_hidden_location", True)
    Bl_h_l = weight_variable((1, 2), "locationNet_bias_hidden_location", True)

    Wa_h_a = weight_variable((cell_out_size, n_classes), "actionNet_wts_hidden_action", True)
    Ba_h_a = weight_variable((1,n_classes),  "actionNet_bias_hidden_action", True)

    # query the model output
    outputs = model()

    # convert list of tensors to one big tensor
    sampled_locs = tf.concat(axis=0, values=sampled_locs)
    sampled_locs = tf.reshape(sampled_locs, (nGlimpses, batch_size, 2))
    sampled_locs = tf.transpose(sampled_locs, [1, 0, 2])
    mean_locs = tf.concat(axis=0, values=mean_locs)
    mean_locs = tf.reshape(mean_locs, (nGlimpses, batch_size, 2))
    mean_locs = tf.transpose(mean_locs, [1, 0, 2])
    glimpse_images = tf.concat(axis=0, values=glimpse_images)

    # compute the reward
    reconstructionCost, reconstruction, train_op_r = preTrain(outputs)
    cost, reward, predicted_labels, correct_labels, train_op, b, avg_b, rminusb, lr = calc_reward(outputs)

    # tensorboard visualization for the parameters
    variable_summaries(Wg_l_h, "glimpseNet_wts_location_hidden")
    variable_summaries(Bg_l_h, "glimpseNet_bias_location_hidden")
    variable_summaries(Wg_g_h, "glimpseNet_wts_glimpse_hidden")
    variable_summaries(Bg_g_h, "glimpseNet_bias_glimpse_hidden")
    variable_summaries(Wg_hg_gf1, "glimpseNet_wts_hiddenGlimpse_glimpseFeature1")
    variable_summaries(Wg_hl_gf1, "glimpseNet_wts_hiddenLocation_glimpseFeature1")
    variable_summaries(Bg_hlhg_gf1, "glimpseNet_bias_hGlimpse_hLocs_glimpseFeature1")

    variable_summaries(Wc_g_h, "coreNet_wts_glimpse_hidden")
    variable_summaries(Bc_g_h, "coreNet_bias_glimpse_hidden")

    variable_summaries(Wb_h_b, "baselineNet_wts_hiddenState_baseline")
    variable_summaries(Bb_h_b, "baselineNet_bias_hiddenState_baseline")

    variable_summaries(Wl_h_l, "locationNet_wts_hidden_location")

    variable_summaries(Wa_h_a, 'actionNet_wts_hidden_action')
    variable_summaries(Ba_h_a, 'actionNet_bias_hidden_action')

    # tensorboard visualization for the performance metrics
    tf.summary.scalar("reconstructionCost", reconstructionCost)
    tf.summary.scalar("reward", reward)
    tf.summary.scalar("cost", cost)
    tf.summary.scalar("mean(b)", avg_b)
    tf.summary.scalar("mean(R - b)", rminusb)
    summary_op = tf.summary.merge_all()


    ####################################### START RUNNING THE MODEL #######################################

    sess_config = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
    sess_config.gpu_options.allow_growth = True
    sess = tf.Session(config=sess_config)

    saver = tf.train.Saver()
    b_fetched = np.zeros((batch_size, (nGlimpses)*2))

    init = tf.global_variables_initializer()
    sess.run(init)

    if eval_only:
        evaluate()
    else:
        summary_writer = tf.summary.FileWriter(summaryFolderName, graph=sess.graph)

        if draw:
            fig = plt.figure(1)
            txt = fig.suptitle("-", fontsize=36, fontweight='bold')
            plt.ion()
            plt.show()
            plt.subplots_adjust(top=0.7)
            plotImgs = []

        if drawReconsturction:
            fig = plt.figure(2)
            txt = fig.suptitle("-", fontsize=36, fontweight='bold')
            plt.ion()
            plt.show()

        if preTraining:
            for epoch_r in range(1,preTraining_epoch):
                nextX, _ = dataset.train.next_batch(batch_size)
                nextX_orig = nextX
                if translateMnist:
                    nextX, _ = convertTranslated(nextX, MNIST_SIZE, img_size)

                fetches_r = [reconstructionCost, reconstruction, train_op_r]

                reconstructionCost_fetched, reconstruction_fetched, train_op_r_fetched = sess.run(fetches_r, feed_dict={inputs_placeholder: nextX})

                if epoch_r % 20 == 0:
                    print('Step %d: reconstructionCost = %.5f' % (epoch_r, reconstructionCost_fetched))
                    if epoch_r % 100 == 0:
                        if drawReconsturction:
                            fig = plt.figure(2)

                            plt.subplot(1, 2, 1)
                            plt.imshow(np.reshape(nextX[0, :], [img_size, img_size]),
                                       cmap=plt.get_cmap('gray'), interpolation="nearest")
                            plt.ylim((img_size - 1, 0))
                            plt.xlim((0, img_size - 1))

                            plt.subplot(1, 2, 2)
                            plt.imshow(np.reshape(reconstruction_fetched[0, :], [img_size, img_size]),
                                       cmap=plt.get_cmap('gray'), interpolation="nearest")
                            plt.ylim((img_size - 1, 0))
                            plt.xlim((0, img_size - 1))
                            plt.draw()
                            plt.pause(0.0001)
                            # plt.show()


        # training
        for epoch in range(start_step + 1, max_iters):
            start_time = time.time()

            # get the next batch of examples
            nextX, nextY = dataset.train.next_batch(batch_size)
            nextX_orig = nextX
            if translateMnist:
                nextX, nextX_coord = convertTranslated(nextX, MNIST_SIZE, img_size)

            feed_dict = {inputs_placeholder: nextX, labels_placeholder: nextY, \
                         onehot_labels_placeholder: dense_to_one_hot(nextY)}

            fetches = [train_op, cost, reward, predicted_labels, correct_labels, glimpse_images, avg_b, rminusb, \
                       mean_locs, sampled_locs, lr]
            # feed them to the model
            results = sess.run(fetches, feed_dict=feed_dict)

            _, cost_fetched, reward_fetched, prediction_labels_fetched, correct_labels_fetched, glimpse_images_fetched, \
            avg_b_fetched, rminusb_fetched, mean_locs_fetched, sampled_locs_fetched, lr_fetched = results


            duration = time.time() - start_time

            if epoch % 20 == 0:
                print('Step %d: cost = %.5f reward = %.5f (%.3f sec) b = %.5f R-b = %.5f, LR = %.5f'
                      % (epoch, cost_fetched, reward_fetched, duration, avg_b_fetched, rminusb_fetched, lr_fetched))
                summary_str = sess.run(summary_op, feed_dict=feed_dict)
                summary_writer.add_summary(summary_str, epoch)
                # if saveImgs:
                #     plt.savefig(imgsFolderName + simulationName + '_ep%.6d.png' % (epoch))

                if epoch % 5000 == 0:
                    saver.save(sess, save_dir + save_prefix + str(epoch) + ".ckpt")
                    evaluate()

                ##### DRAW WINDOW ################
                f_glimpse_images = np.reshape(glimpse_images_fetched, \
                                              (nGlimpses, batch_size, depth, sensorBandwidth, sensorBandwidth))

                if draw:
                    if animate:
                        fillList = False
                        if len(plotImgs) == 0:
                            fillList = True

                        # display the first image in the mini-batch
                        nCols = depth+1
                        plt.subplot2grid((depth, nCols), (0, 1), rowspan=depth, colspan=depth)
                        # display the entire image
                        plotWholeImg(nextX[0, :], img_size, sampled_locs_fetched)

                        # display the glimpses
                        for y in range(nGlimpses):
                            txt.set_text('Epoch: %.6d \nPrediction: %i -- Truth: %i\nStep: %i/%i'
                                         % (epoch, prediction_labels_fetched[0], correct_labels_fetched[0], (y + 1), nGlimpses))

                            for x in range(depth):
                                plt.subplot(depth, nCols, 1 + nCols * x)
                                if fillList:
                                    plotImg = plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'),
                                                         interpolation="nearest")
                                    plotImg.autoscale()
                                    plotImgs.append(plotImg)
                                else:
                                    plotImgs[x].set_data(f_glimpse_images[y, 0, x])
                                    plotImgs[x].autoscale()
                            fillList = False

                            # fig.canvas.draw()
                            time.sleep(0.1)
                            plt.pause(0.00005)

                    else:
                        txt.set_text('PREDICTION: %i\nTRUTH: %i' % (prediction_labels_fetched[0], correct_labels_fetched[0]))
                        for x in range(depth):
                            for y in range(nGlimpses):
                                plt.subplot(depth, nGlimpses, x * nGlimpses + y + 1)
                                plt.imshow(f_glimpse_images[y, 0, x], cmap=plt.get_cmap('gray'), interpolation="nearest")
