Skip to content
Merged

go #519

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions ocean/go/binding.c
Original file line number Diff line number Diff line change
@@ -1,25 +1,26 @@
#include "go.h"
#define OBS_SIZE 100
// 9x9 - obs 326, act 82
// 13x13 - obs 678, act 170
// 19x19 - obs 1446, act 362
#define OBS_SIZE 326
#define NUM_ATNS 1
#define ACT_SIZES {50}
#define OBS_TYPE FLOAT
#define ACT_TYPE DOUBLE
#define ACT_SIZES {82}
#define OBS_TENSOR_T FloatTensor

#define Env CGo
#include "vecenv.h"

void my_init(Env* env, Dict* kwargs) {
env->num_agents = 1;
env->side = (rand_r(&env->rng) % 2) + 1;
env->selfplay = dict_get(kwargs, "selfplay")->value;
env->width = dict_get(kwargs, "width")->value;
env->height = dict_get(kwargs, "height")->value;
env->grid_size = dict_get(kwargs, "grid_size")->value;
env->board_width = dict_get(kwargs, "board_width")->value;
env->board_height = dict_get(kwargs, "board_height")->value;
env->grid_square_size = dict_get(kwargs, "grid_square_size")->value;
env->moves_made = dict_get(kwargs, "moves_made")->value;
env->komi = dict_get(kwargs, "komi")->value;
env->score = dict_get(kwargs, "score")->value;
env->last_capture_position = dict_get(kwargs, "last_capture_position")->value;
env->reward_move_pass = dict_get(kwargs, "reward_move_pass")->value;
env->reward_move_invalid = dict_get(kwargs, "reward_move_invalid")->value;
env->reward_move_valid = dict_get(kwargs, "reward_move_valid")->value;
Expand Down
Loading
Loading