#include <assert.h>
#include <pthread.h>
#include <stdatomic.h>
#include <stdint.h>
#include <string.h>

#include "connect4.h"

// Claim a free node slot and fill in its search metadata.
void node_push(Node *node,
               int32_t parent_i,
               int32_t parent_action_i,
               const Connect4State *state,
               int32_t depth) {
  assert(node->in_use == false);
  node->in_use = true;

  node->parent_i = parent_i;
  node->parent_action_i = parent_action_i;
  node->depth = depth;
  node->state = *state;

  // Actions are not computed on push.
  // They are computed by node_init_actions.
  node->actions_count = 0;
  atomic_store(&node->finished_actions_count, 0);

  for (int col = 0; col < COLS; col++) {
    node->actions[col] = 0;
    node->action_scores[col] = 0;
    node->action_pushed[col] = false;
  }
}

// Release a node slot so it can be reused.
void node_pop(Node *node) {
  assert(node->in_use == true);
  node->in_use = false;
}
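
// Node lifecycle: a slot is claimed with node_push, its legal actions are
// filled in lazily by node_init_actions the first time a worker scans it,
// and once every action has a score the slot is released with node_pop.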

// Compute the legal actions for a node, scoring any that end the game
// (or that sit on the last layer) immediately.
void node_init_actions(Node *node) {
  assert(node->in_use == true);
  assert(node->actions_count == 0);
  assert(atomic_load(&node->finished_actions_count) == 0);
  // Find valid actions.
  for (uint8_t col = 0; col < COLS; col++) {
    uint8_t current_player = node->state.current_player;
    if (connect4_check_action(&node->state, current_player, col)) {
      int32_t action_i = node->actions_count++;
      node->actions[action_i] = col;

      // Nodes on the last layer cannot be expanded further, so score
      // their actions right away.
      if (node->depth == LAYER_COUNT - 1) {
        node->action_scores[action_i] = 0;
        node->action_pushed[action_i] = true;
        atomic_fetch_add(&node->finished_actions_count, 1);
      } else {
        connect4_apply_action(&node->state, current_player, col);
        if (node->state.status.kind == CONNECT4_OVER) {
          float score;
          if (node->state.status.winner == current_player) {
            score = 1;
          } else {
            score = -1;
          }
          node->action_scores[action_i] = score;
          node->action_pushed[action_i] = true;
          atomic_fetch_add(&node->finished_actions_count, 1);
        }
        connect4_undo_action(&node->state, current_player, col);
      }
    }
  }
  assert(node->actions_count > 0);
}

// Count the actions that still need a child node in the next layer.
int32_t node_count_children_to_push(Node *node) {
  assert(node->depth < LAYER_COUNT - 1);
  int32_t children_to_push = 0;
  for (int action_i = 0; action_i < node->actions_count; action_i++) {
    if (!node->action_pushed[action_i]) {
      children_to_push++;
    }
  }
  return children_to_push;
}

// Pick the highest-scoring action of the root node. Ties are broken in
// favor of the later action because of the >= comparison.
int32_t node_best_action(Node *node) {
  assert(node->parent_i == -1);
  assert(node->parent_action_i == -1);
  assert(node->depth == 0);
  assert(atomic_load(&node->finished_actions_count) == node->actions_count);

  int best_action = node->actions[0];
  float best_score = node->action_scores[0];
  for (int action_i = 1; action_i < node->actions_count; action_i++) {
    if (node->action_scores[action_i] >= best_score) {
      best_action = node->actions[action_i];
      best_score = node->action_scores[action_i];
    }
  }

  return best_action;
}

// Score a non-root node as the average of its action scores. The caller
// negates the result when propagating to the parent, since parent and
// child are played by opposing players.
float node_score(Node *node) {
  assert(node->parent_i != -1);
  assert(node->parent_action_i != -1);
  assert(node->depth > 0);
  assert(atomic_load(&node->finished_actions_count) == node->actions_count);

  float total_score = 0;
  for (int action_i = 0; action_i < node->actions_count; action_i++) {
    total_score += node->action_scores[action_i];
  }
  float score = total_score / node->actions_count;
  return score;
}
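
// For example (hypothetical numbers): a node whose three finished actions
// scored {1, 1, -1} gets score 1/3; its parent then records -1/3 for the
// action leading to it, because a good position for this node's player is
// a bad one for the parent's player.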

void connect4_ai_action(int32_t thread_i,
                        _Atomic(int32_t) *current_layer,
                        const Connect4State *state,
                        Layers *layers,
                        NewNodes *new_nodes,
                        pthread_barrier_t *thread_barrier,
                        _Atomic(int32_t) *result) {
  assert(state->status.kind != CONNECT4_OVER);

  int rc;
  rc = pthread_barrier_wait(thread_barrier);
  if (rc == PTHREAD_BARRIER_SERIAL_THREAD) {
    // Exactly one thread initializes the shared search state.
    memset(layers, 0, sizeof(*layers));
    memset(new_nodes, 0, sizeof(*new_nodes));

    layers->layer[0].count = 1;
    Node *root = &layers->layer[0].node[0];
    node_push(root, -1, -1, state, 0);

    atomic_store(result, -1);

    atomic_store(current_layer, 0);
  } else {
    assert(rc == 0);
  }

  // @TODO: Prune on wins and losses. Their scores are precomputed in
  // node_init_actions, so when one action is already a known win there is
  // no reason to search the other actions.

  while (true) {
    rc = pthread_barrier_wait(thread_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);

    if (atomic_load(result) != -1) {
      break;
    }

    int32_t layer_i = atomic_load(current_layer);
    assert(layer_i >= 0);
    assert(layer_i < LAYER_COUNT);
    Layer *layer = &layers->layer[layer_i];

    // Each thread gets a distinct slice of the current layer.
    int32_t values_per_thread = layer->count / THREAD_COUNT;
    int32_t leftover_count = layer->count % THREAD_COUNT;
    bool thread_has_leftover = thread_i < leftover_count;
    int32_t leftovers_before_me =
        (thread_has_leftover ? thread_i : leftover_count);
    int32_t thread_start = values_per_thread * thread_i + leftovers_before_me;
    int32_t thread_after =
        thread_start + values_per_thread + !!thread_has_leftover;
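    // For example (hypothetical numbers): with layer->count == 10 and
    // THREAD_COUNT == 4, values_per_thread is 2 and leftover_count is 2,
    // so the threads scan the ranges [0,3), [3,6), [6,8), and [8,10).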

    int32_t unpushed_child_count = 0;

    rc = pthread_barrier_wait(thread_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);

    for (int node_i = thread_start; node_i < thread_after; node_i++) {
      Node *node = &layer->node[node_i];
      if (!node->in_use) {
        continue;
      }

      assert(node->state.status.kind != CONNECT4_OVER);

      if (node->actions_count == 0) {
        node_init_actions(node);
      }
      assert(node->actions_count > 0);

      if (atomic_load(&node->finished_actions_count) == node->actions_count) {
        if (node->depth == 0) {
          int32_t best_action = node_best_action(node);
          atomic_store(result, best_action);
        } else {
          float score = node_score(node);

          assert(node->parent_i != -1);
          Node *parent = &layers->layer[layer_i - 1].node[node->parent_i];

          assert(node->parent_action_i < parent->actions_count);
          assert(atomic_load(&parent->finished_actions_count) <
                 parent->actions_count);

          // Negate the score: what is good for this node's player is bad
          // for the parent's player.
          parent->action_scores[node->parent_action_i] = -score;
          atomic_fetch_add(&parent->finished_actions_count, 1);
        }
        node_pop(node);
      } else {
        unpushed_child_count += node_count_children_to_push(node);
      }
    }

    new_nodes->unpushed_child_count[thread_i] = unpushed_child_count;

    // On one thread, allocate each thread slots for its child nodes.
    rc = pthread_barrier_wait(thread_barrier);
    if (rc == PTHREAD_BARRIER_SERIAL_THREAD) {
      uint32_t total_unpushed_child_count = 0;

      for (int alloc_thread_i = 0; alloc_thread_i < THREAD_COUNT;
           alloc_thread_i++) {
        total_unpushed_child_count +=
            new_nodes->unpushed_child_count[alloc_thread_i];
        new_nodes->new_node_index[alloc_thread_i].count = 0;
      }

      if (layer_i == LAYER_COUNT - 1) {
        assert(total_unpushed_child_count == 0);
      } else {
        Layer *next_layer = &layers->layer[layer_i + 1];

        int32_t new_node_i = 0;

        // Scan the next layer for free slots and hand them out to the
        // threads in order.
        for (int alloc_thread_i = 0; alloc_thread_i < THREAD_COUNT;
             alloc_thread_i++) {
          while (new_nodes->new_node_index[alloc_thread_i].count <
                     new_nodes->unpushed_child_count[alloc_thread_i] &&
                 new_node_i < MAX_NODES_PER_LAYER) {
            if (!next_layer->node[new_node_i].in_use) {
              int32_t index_i =
                  new_nodes->new_node_index[alloc_thread_i].count++;
              new_nodes->new_node_index[alloc_thread_i].index[index_i] =
                  new_node_i;
            }
            new_node_i++;
          }
        }
        // Update count if we allocated nodes past the current count.
        if (new_node_i > next_layer->count) {
          next_layer->count = new_node_i;
        }
      }
    } else {
      assert(rc == 0);
    }

    // Wait for child node slots to be allocated, then push as many
    // children as were allocated.
    rc = pthread_barrier_wait(thread_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);

    int32_t children_to_push = new_nodes->new_node_index[thread_i].count;
    int32_t children_pushed = 0;

    // children_to_push can be smaller than unpushed_child_count when the
    // next layer ran out of free slots; the remaining actions stay
    // unpushed and are retried on a later pass over this layer.
    assert(children_to_push <= unpushed_child_count);

    // Loop through the nodes a second time to push new nodes to the
    // next layer.
    for (int node_i = thread_start; node_i < thread_after; node_i++) {
      if (children_pushed == children_to_push) {
        break;
      }

      Node *node = &layer->node[node_i];
      if (!node->in_use) {
        continue;
      }
      assert(node->actions_count != 0);

      for (int action_i = 0; action_i < node->actions_count; action_i++) {
        if (!node->action_pushed[action_i]) {
          if (children_pushed == children_to_push) {
            break;
          }

          int32_t child_i =
              new_nodes->new_node_index[thread_i].index[children_pushed++];
          Node *child = &layers->layer[layer_i + 1].node[child_i];

          Connect4State child_state = node->state;
          connect4_apply_action(&child_state,
                                child_state.current_player,
                                node->actions[action_i]);

          node_push(child, node_i, action_i, &child_state, node->depth + 1);
          node->action_pushed[action_i] = true;
        }
      }
    }

    // One thread decides whether to descend (children were pushed) or to
    // back up a layer (this layer is fully expanded).
    rc = pthread_barrier_wait(thread_barrier);
    if (rc == PTHREAD_BARRIER_SERIAL_THREAD) {
      bool any_children_pushed = false;
      for (int i = 0; i < THREAD_COUNT; i++) {
        if (new_nodes->new_node_index[i].count > 0) {
          any_children_pushed = true;
          break;
        }
      }
      if (any_children_pushed) {
        atomic_store(current_layer, layer_i + 1);
      } else {
        atomic_store(current_layer, layer_i - 1);
      }
    } else {
      assert(rc == 0);
    }
  }
}

// @NOTE:
// Assumes that the barriers are already initialized and are the same
// reference across threads, and that all atomic references point to the
// same atomics across threads. These threads run in the background and
// wait for a new input to process.
// To trigger a run:
// - Set ai_turn_started and ai_turn_input.
// - Wait on the turn barrier to release the threads to begin processing.
// - Then do whatever you want.
// To check whether the run is done, look at ai_turn_completed: when it
// equals the ai_turn_started value you set, the run is done.
// To get the result, wait on the turn barrier again to be sure the result
// was set; it will then be available in ai_turn_result.
// To shut down the threads (when they are waiting for a new input, not
// working):
// - Set shutdown to 1.
// - Wait on the barrier once more to start them back up; they'll see the
//   shutdown flag and return. Then nobody will be waiting on the barrier,
//   so you can clean it and the threads up.
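//
// A minimal driver sketch of that protocol (hypothetical: assumes the
// main thread shares one Connect4ThreadState `ts` with the workers and
// that ai_turn_started/ai_turn_completed are atomic turn counters):
//
//   atomic_fetch_add(ts.ai_turn_started, 1); // New turn number.
//   // (fill in *ts.ai_turn_input with the position to search)
//   pthread_barrier_wait(ts.turn_barrier);   // Release the workers.
//   while (atomic_load(ts.ai_turn_completed) !=
//          atomic_load(ts.ai_turn_started)) {
//     // Render, poll input, etc.
//   }
//   pthread_barrier_wait(ts.turn_barrier);   // Sync on the result.
//   int32_t action = atomic_load(ts.ai_turn_result);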

void *connect4_ai_worker_thread_main(void *arg) {
  Connect4ThreadState *thread_state = (Connect4ThreadState *)arg;

  // printf("thread %i: init\n", thread_state->thread_i);

  while (true) {
    // Wait for a new AI turn.
    int rc;
    rc = pthread_barrier_wait(thread_state->turn_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);

    if (atomic_load(thread_state->shutdown)) {
      break;
    }

    rc = pthread_barrier_wait(thread_state->thread_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);

    // Call the function that calculates the answer.
    connect4_ai_action(thread_state->thread_i,
                       thread_state->current_layer,
                       thread_state->ai_turn_input,
                       thread_state->layers,
                       thread_state->new_nodes,
                       thread_state->thread_barrier,
                       thread_state->ai_turn_result);
    // printf("thread %i: workin\n", thread_state->thread_i);

    rc = pthread_barrier_wait(thread_state->thread_barrier);
    if (rc == PTHREAD_BARRIER_SERIAL_THREAD) {
      // Signal that we're done.
      atomic_store(thread_state->ai_turn_completed,
                   atomic_load(thread_state->ai_turn_started));
    } else {
      assert(rc == 0);
    }

    // Sync with the main thread to "return".
    rc = pthread_barrier_wait(thread_state->turn_barrier);
    assert(rc == 0 || rc == PTHREAD_BARRIER_SERIAL_THREAD);
  }

  // printf("thread %i: shutdown\n", thread_state->thread_i);

  return NULL;
}