Skip to content

Instantly share code, notes, and snippets.

@kimuchi1203
Created November 13, 2015 16:12
Show Gist options
  • Select an option

  • Save kimuchi1203/c49f14cb218754e96466 to your computer and use it in GitHub Desktop.

Select an option

Save kimuchi1203/c49f14cb218754e96466 to your computer and use it in GitHub Desktop.
Answer with the maximum of three random digits (0–9)
import numpy as np
from chainer import optimizers, FunctionSet, Variable
import chainer.functions as F
import random
import copy
# The whole "network" is a single linear layer: 30 inputs (three digits, each
# one-hot encoded over 10 classes) mapped to 10 outputs, one score per digit.
model = FunctionSet(l1=F.Linear(30, 10))

# Adam with Chainer's default hyper-parameters, bound to the model's parameters.
optimizer = optimizers.Adam()
optimizer.setup(model)
def forward(x):
    """Return the network's output for the input array ``x``.

    ``x`` is wrapped in a chainer ``Variable`` so the result stays attached to
    the computational graph and ``update`` can backpropagate through it.  In
    this script the caller passes a float32 array of shape (1, 30).
    """
    return model.l1(Variable(x))
def update(q, a, r):
    """Take one optimizer step pulling the chosen action's output toward ``r``.

    The regression target is a copy of the current prediction with only the
    entry for action ``a`` replaced by the reward.  Every other entry equals
    its own prediction, so the squared error (and hence the gradient) comes
    solely from output ``a``.

    Args:
        q: Variable returned by ``forward``.  Only row 0 of the target is
           modified, so this assumes a batch of one sample.
        a: index of the action that was taken.
        r: reward observed for that action.

    Side effects:
        Clears accumulated gradients, backpropagates the loss and applies one
        update to the model through the module-level ``optimizer``.
    """
    # ndarray.copy() gives an independent buffer; the target must not alias
    # q.data, or the loss would compare the prediction with itself.
    target = q.data.copy()
    target[0, a] = r
    loss = F.mean_squared_error(q, Variable(target))
    optimizer.zero_grads()
    loss.backward()
    optimizer.update()
# Online training: each step draws three random digits, asks the network which
# one is the largest, and rewards it with 1 for a correct answer, 0 otherwise.
# One line is printed per step, then the total number of correct answers.
# NOTE(review): actions are chosen purely greedily (no epsilon-exploration);
# exploration relies on update() pulling a wrongly chosen action's value toward
# 0 until some other action becomes the argmax.
ONE_HOT = np.eye(10, dtype=np.float32)  # row d is the one-hot code of digit d

point = 0  # number of correct answers so far
for _ in range(1000):
    i = random.randint(0, 9)
    j = random.randint(0, 9)
    k = random.randint(0, 9)
    # Stack the three one-hot rows into a (3, 10) state; it is flattened to a
    # single 30-feature sample for the linear layer below.
    state = ONE_HOT[[i, j, k]]
    answer = max(i, j, k)

    q = forward(state.reshape(1, 30))
    action = int(np.argmax(q.data))
    reward = 1 if answer == action else 0
    point += reward
    print('({0}, {1}, {2} -> {3} {4})'.format(i, j, k, action, q.data.tolist()[0]))
    update(q, action, reward)
print(point)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment