From ada77446c42ffc0e325b8dc67007de83ca4651b9 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Wed, 20 Jan 2021 18:45:57 +0100 Subject: [PATCH 01/11] fix compilation for pytorch 1.2 --- CMakeLists.txt | 2 +- src/core/actor.h | 4 ++++ src/distributed/distributed.cc | 1 + 3 files changed, 6 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 614f945f..43fb84be 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,7 @@ add_subdirectory(src/games/minesweeper_csp_vkms) # tests add_executable(test_state src/core/test_state.cc src/core/state.cc) -target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES}) +target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES} -lpython3) enable_testing() diff --git a/src/core/actor.h b/src/core/actor.h index fd560025..6906bfe8 100644 --- a/src/core/actor.h +++ b/src/core/actor.h @@ -116,7 +116,11 @@ class Actor { torch::Tensor policy; if (useValue_ && resultsAreValid) { if (logitValue_) { +#ifdef PYTORCH12 + float* begin = value_->data.data(); +#else float* begin = value_->data.data_ptr(); +#endif float* end = begin + 3; softmax_(begin, end); } diff --git a/src/distributed/distributed.cc b/src/distributed/distributed.cc index a3a21124..37bc046f 100644 --- a/src/distributed/distributed.cc +++ b/src/distributed/distributed.cc @@ -13,6 +13,7 @@ #include #include +#include #include #include #include From d7d5450691527bd9867a619b10f03a0a8858c904 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 13:59:08 +0100 Subject: [PATCH 02/11] CI + CMake + tests (wip) --- .circleci/config.yml | 6 +- CMakeLists.txt | 2 +- tests/CMakeLists.txt | 27 +++---- tests/connectfour-tests.cc | 31 ++++---- tests/havannah-state-tests.cc | 25 ++++--- tests/havannah-tests.cc | 2 +- tests/hex-state-tests.cc | 57 +++++++++------ tests/hex-tests.cc | 2 +- tests/ludii-game-tests.cc | 131 ++-------------------------------- tests/tests.cc | 3 + 10 files changed, 94 insertions(+), 192 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 327d9063..40927ad6 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -138,7 +138,7 @@ jobs: conda activate pypg mkdir build cd build - cmake .. + cmake -DPYTORCH12=ON .. make -j 2 - run: @@ -150,7 +150,7 @@ jobs: wget -P ludii https://ludii.games/downloads/Ludii.jar mkdir tests/build cd tests/build - cmake .. + cmake -DPYTORCH12=ON .. make -j 2 - run: @@ -161,7 +161,7 @@ jobs: - run: name: Test polygames-tests (unit tests) command: | - ./tests/build/polygames-tests + ./tests/build/polygames-tests ludii/Ludii.jar - run: name: Test Mcts diff --git a/CMakeLists.txt b/CMakeLists.txt index 43fb84be..40818077 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,7 @@ add_subdirectory(src/games/minesweeper_csp_vkms) # tests add_executable(test_state src/core/test_state.cc src/core/state.cc) -target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES} -lpython3) +target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES} python3) enable_testing() diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a2dd2ca5..fd1847f1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,6 +2,14 @@ cmake_minimum_required( VERSION 3.3 ) project( polygames-tests ) set(CMAKE_CXX_STANDARD 17) +OPTION(PYTORCH12 "Is PyTorch >= 1.2" OFF) +OPTION(PYTORCH15 "Is PyTorch >= 1.5" OFF) +IF(PYTORCH15) + ADD_DEFINITIONS(-DPYTORCH15 -DPYTORCH12) +ELSEIF(PYTORCH12) + ADD_DEFINITIONS(-DPYTORCH12) +ENDIF() + execute_process( COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__), end='')" OUTPUT_VARIABLE TorchPath @@ -22,20 +30,9 @@ include_directories( ${GTEST_INCLUDE_DIRS} ) find_package(JNI REQUIRED) include_directories( ${JNI_INCLUDE_DIRS}) -include_directories( - ../games - ../torchRL - ../torchRL/third_party/fmt/include - ../torchRL/tube/src_cpp - ) +add_subdirectory(../src build-libpolygames) add_executable( polygames-tests - ../core/game.cc - ../core/state.cc - ../torchRL/mcts/mcts.cc - ../torchRL/mcts/node.cc - ../torchRL/tube/src_cpp/data_channel.cc - ../torchRL/tube/src_cpp/replay_buffer.cc tests.cc # Include your tests here. @@ -44,14 +41,12 @@ add_executable( polygames-tests havannah-tests.cc hex-state-tests.cc hex-tests.cc - ludii-game-tests.cc - ../games/ludii/jni_utils.cc - ../games/ludii/ludii_game_wrapper.cc - ../games/ludii/ludii_state_wrapper.cc ) target_link_libraries( polygames-tests + libpolygames + python3 ${CMAKE_THREAD_LIBS_INIT} ${GTEST_LIBRARIES} ${JNI_LIBRARIES} diff --git a/tests/connectfour-tests.cc b/tests/connectfour-tests.cc index a6bbb306..1908dacc 100644 --- a/tests/connectfour-tests.cc +++ b/tests/connectfour-tests.cc @@ -11,8 +11,10 @@ #include #include "utils.h" -#include +#include +/* + TODO TEST(Connectfour, init_1) { StateForConnectFour state(0); @@ -23,12 +25,12 @@ TEST(Connectfour, init_1) { ASSERT_EQ(GameStatus::player0Turn, GameStatus(state.getCurrentPlayer())); for (int i=0; i<7; ++i) { - auto a_i = std::dynamic_pointer_cast(state.GetLegalActions()[i]); - ASSERT_EQ(i, a_i->GetX()); - ASSERT_EQ(0, a_i->GetY()); - ASSERT_EQ(0, a_i->GetZ()); - ASSERT_EQ(i, a_i->GetHash()); - ASSERT_EQ(i, a_i->GetIndex()); + auto a_i = state.GetLegalActions()[i]; + ASSERT_EQ(i, a_i.GetX()); + ASSERT_EQ(0, a_i.GetY()); + ASSERT_EQ(0, a_i.GetZ()); + ASSERT_EQ(i, a_i.GetHash()); + ASSERT_EQ(i, a_i.GetIndex()); } std::vector expectedFeatures { @@ -77,7 +79,7 @@ TEST(Connectfour, play_1) { StateForConnectFour state(0); state.Initialize(); - ActionForConnectFour action(1, 7); + _Action action(1, 7, 0, 0); state.ApplyAction(action); ASSERT_EQ((std::vector{3, 6, 7}), state.GetFeatureSize()); @@ -85,12 +87,12 @@ TEST(Connectfour, play_1) { ASSERT_EQ(GameStatus::player1Turn, GameStatus(state.getCurrentPlayer())); for (int i=0; i<7; ++i) { - auto a_i = std::dynamic_pointer_cast(state.GetLegalActions()[i]); - ASSERT_EQ(i, a_i->GetX()); - ASSERT_EQ(0, a_i->GetY()); - ASSERT_EQ(0, a_i->GetZ()); - ASSERT_EQ(i, a_i->GetHash()); - ASSERT_EQ(i, a_i->GetIndex()); + auto a_i = state.GetLegalActions()[i]; + ASSERT_EQ(i, a_i.GetX()); + ASSERT_EQ(0, a_i.GetY()); + ASSERT_EQ(0, a_i.GetZ()); + ASSERT_EQ(i, a_i.GetHash()); + ASSERT_EQ(i, a_i.GetIndex()); } std::vector expectedFeatures { @@ -131,4 +133,5 @@ TEST(Connectfour, play_1) { } +*/ diff --git a/tests/havannah-state-tests.cc b/tests/havannah-state-tests.cc index 78676c63..cfcd8c44 100644 --- a/tests/havannah-state-tests.cc +++ b/tests/havannah-state-tests.cc @@ -7,7 +7,7 @@ // Unit tests for Havannah Action/State. -#include +#include #include #include "utils.h" @@ -20,9 +20,14 @@ namespace Havannah { template class StateTest : public Havannah::State { public: + core::FeatureOptions _opts; StateTest(int seed, int history, bool turnFeatures) : - Havannah::State(seed, history, turnFeatures) {} - GameStatus GetStatus() { return ::State::_status; }; + Havannah::State(seed) { + _opts.history = history; + _opts.turnFeaturesMultiChannel = turnFeatures; + core::State::setFeatures(&_opts); + } + GameStatus GetStatus() { return core::State::_status; }; }; }; @@ -32,6 +37,9 @@ namespace Havannah { // unit tests /////////////////////////////////////////////////////////////////////////////// +/* +TODO + TEST(HavannahStateGroup, init_0) { const int size = 5; @@ -186,11 +194,11 @@ TEST(HavannahStateGroup, init_2) { int i = expectedAction.first; int j = expectedAction.second; int h = i*fullsize + j; - ASSERT_EQ(0, action->GetX()); - ASSERT_EQ(i, action->GetY()); - ASSERT_EQ(j, action->GetZ()); - ASSERT_EQ(h, action->GetHash()); - ASSERT_EQ(k, action->GetIndex()); + ASSERT_EQ(0, action.GetX()); + ASSERT_EQ(i, action.GetY()); + ASSERT_EQ(j, action.GetZ()); + ASSERT_EQ(h, action.GetHash()); + ASSERT_EQ(k, action.GetIndex()); } } @@ -812,4 +820,5 @@ TEST(HavannahStateGroup, features_3_nopie) { } +*/ diff --git a/tests/havannah-tests.cc b/tests/havannah-tests.cc index 4628ecb6..996ad20d 100644 --- a/tests/havannah-tests.cc +++ b/tests/havannah-tests.cc @@ -7,7 +7,7 @@ // Unit tests for the Havannah game. -#include +#include #include #include "utils.h" diff --git a/tests/hex-state-tests.cc b/tests/hex-state-tests.cc index 79e704ca..6f7b9273 100644 --- a/tests/hex-state-tests.cc +++ b/tests/hex-state-tests.cc @@ -7,7 +7,7 @@ // Unit tests for Hex Action/State. -#include +#include #include #include "utils.h" @@ -19,9 +19,15 @@ namespace Hex { template class StateTest : public Hex::State { public: + core::FeatureOptions _opts; StateTest(int seed, int history, bool turnFeatures) : - Hex::State(seed, history, turnFeatures) {} - GameStatus GetStatus() { return ::State::_status; }; + Hex::State(seed) { + _opts.history = history; + _opts.turnFeaturesMultiChannel = turnFeatures; + core::State::setFeatures(&_opts); + } + GameStatus GetStatus() { return core::State::_status; } + void addAction(int x, int y, int z) { core::State::addAction(x, y, z); } }; }; @@ -31,6 +37,9 @@ namespace Hex { // unit tests /////////////////////////////////////////////////////////////////////////////// +/* + TODO + TEST(HexStateGroup, init_1) { Hex::StateTest<7,true> state(0, 0, false); @@ -50,11 +59,11 @@ TEST(HexStateGroup, init_1) { int i = k / 7; int j = k % 7; auto a = state.GetLegalActions()[k]; - ASSERT_EQ(0, a->GetX()); - ASSERT_EQ(i, a->GetY()); - ASSERT_EQ(j, a->GetZ()); - ASSERT_EQ(k, a->GetHash()); - ASSERT_EQ(k, a->GetIndex()); + ASSERT_EQ(0, a.GetX()); + ASSERT_EQ(i, a.GetY()); + ASSERT_EQ(j, a.GetZ()); + ASSERT_EQ(k, a.GetHash()); + ASSERT_EQ(k, a.GetIndex()); } } @@ -63,8 +72,9 @@ TEST(HexStateGroup, play_1) { Hex::StateTest<7,true> state(0, 0, false); - Hex::Action<7> a(2, 3, 2*7+3); - state.ApplyAction(a); + //TODO Hex::Action<7> a(2, 3, 2*7+3); + state.addAction(0, 2, 3); + state.ApplyAction(state.GetLegalActions()[0]); ASSERT_EQ(GameStatus::player1Turn, state.GetStatus()); @@ -91,27 +101,26 @@ TEST(HexStateGroup, play_1) { int k = i*7+j; if (k<2*7+3) { auto a = state.GetLegalActions()[k]; - ASSERT_EQ(0, a->GetX()); - ASSERT_EQ(i, a->GetY()); - ASSERT_EQ(j, a->GetZ()); - ASSERT_EQ(k, a->GetHash()); - ASSERT_EQ(k, a->GetIndex()); + ASSERT_EQ(0, a.GetX()); + ASSERT_EQ(i, a.GetY()); + ASSERT_EQ(j, a.GetZ()); + ASSERT_EQ(k, a.GetHash()); + ASSERT_EQ(k, a.GetIndex()); } else if (k>2*7+3) { int k2 = k-1; auto a = state.GetLegalActions()[k2]; - ASSERT_EQ(0, a->GetX()); - ASSERT_EQ(i, a->GetY()); - ASSERT_EQ(j, a->GetZ()); - ASSERT_EQ(k2, a->GetHash()); - ASSERT_EQ(k2, a->GetIndex()); + ASSERT_EQ(0, a.GetX()); + ASSERT_EQ(i, a.GetY()); + ASSERT_EQ(j, a.GetZ()); + ASSERT_EQ(k2, a.GetHash()); + ASSERT_EQ(k2, a.GetIndex()); } } } } - TEST(HexStateGroup, clone_1) { try { @@ -123,8 +132,9 @@ TEST(HexStateGroup, clone_1) { ASSERT_EQ(49, state.GetLegalActions().size()); ASSERT_EQ(49, ptrClone->GetLegalActions().size()); - Hex::Action<7> a(2, 3, -1); - state.ApplyAction(a); + // TODO Hex::Action<7> a(2, 3, -1); + state.addAction(0, 2, 3); + state.ApplyAction(state.GetLegalActions()[0]); ASSERT_EQ(49, state.GetLegalActions().size()); ASSERT_EQ(49, ptrClone->GetLegalActions().size()); @@ -448,4 +458,5 @@ TEST(HexStateGroup, features_3) { // } } +*/ diff --git a/tests/hex-tests.cc b/tests/hex-tests.cc index a7348907..4ead2a73 100644 --- a/tests/hex-tests.cc +++ b/tests/hex-tests.cc @@ -7,7 +7,7 @@ // Unit tests for the Hex game. -#include +#include #include #include "utils.h" diff --git a/tests/ludii-game-tests.cc b/tests/ludii-game-tests.cc index 6e645f27..9fa26079 100644 --- a/tests/ludii-game-tests.cc +++ b/tests/ludii-game-tests.cc @@ -5,18 +5,20 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#include +#include #include "utils.h" #include +extern std::string LUDII_PATH; + /////////////////////////////////////////////////////////////////////////////// // unit tests /////////////////////////////////////////////////////////////////////////////// TEST(LudiiGameGroup, ludii_yavalath_0) { - Ludii::JNIUtils::InitJVM(""); // Use default /ludii/Ludii.jar path + Ludii::JNIUtils::InitJVM(LUDII_PATH); JNIEnv* jni_env = Ludii::JNIUtils::GetEnv(); EXPECT_TRUE(jni_env); @@ -121,127 +123,6 @@ TEST(LudiiGameGroup, ludii_yavalath_0) { ASSERT_EQ(0, features[i]); ++i; } -} /** - * Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. - * - * This source code is licensed under the MIT license found in the - * LICENSE file in the root directory of this source tree. - */ - -#include -#include - -#include "utils.h" -#include +} -/////////////////////////////////////////////////////////////////////////////// -// unit tests -/////////////////////////////////////////////////////////////////////////////// -TEST(LudiiGameGroup, ludii_yavalath_0) { - Ludii::JNIUtils::InitJVM(""); // Use default /ludii/Ludii.jar path - JNIEnv* jni_env = Ludii::JNIUtils::GetEnv(); - EXPECT_TRUE(jni_env); - - Ludii::LudiiGameWrapper game_wrapper("Yavalath.lud"); - Ludii::LudiiStateWrapper state = - Ludii::LudiiStateWrapper(0, std::move(game_wrapper)); - state.Initialize(); - - ASSERT_EQ((std::vector{10, 9, 17}), state.GetFeatureSize()); - ASSERT_EQ((std::vector{3, 9, 17}), state.GetActionSize()); - ASSERT_EQ(GameStatus::player0Turn, GameStatus(state.getCurrentPlayer())); - - // We expect the following meanings for Yavalath state tensor channels: - // 0: Piece Type 1 (Ball1) - // 1: Piece Type 2 (Ball2) - // 2: Is Player 1 the current mover? - // 3: Is Player 2 the current mover? - // 4: Did Swap Occur? - // 5: Does position exist in container 0 (Board)? - // 6: Last move's from-position - // 7: Last move's to-position - // 8: Second-to-last move's from-position - // 9: Second-to-last move's to-position - - // TODO guess we really need a channel to indicate that swap happened - const std::vector features = state.GetFeatures(); - - // We expect empty board initial state, so first two channels - // should be all-zero - size_t i = 0; - while (i < 2 * 9 * 17) { - ASSERT_EQ(0, features[i]); - ++i; - } - - // Player 1 should be mover, so expect channel filled with 1s next - while (i < 3 * 9 * 17) { - ASSERT_EQ(1, features[i]); - ++i; - } - - // Player 2 not current mover, so full channel of 0s - while (i < 4 * 9 * 17) { - ASSERT_EQ(0, features[i]); - ++i; - } - - // No swap occured yet, so expect full channel of 0s - while (i < 5 * 9 * 17) { - ASSERT_EQ(0, features[i]); - ++i; - } - - // Channel: Does position exist in container 0 (Board)? - // First and last column have 5 cells each, - // expected pattern: 0,0,0,0,1,0,1,0,1,0,1,0,1,0,0,0,0 - const float _5_cells_pattern[17] = { - 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0}; - for (size_t j = 0; j < 17; ++j) { - ASSERT_EQ(features[i + 0 * 17 + j], _5_cells_pattern[j]); - ASSERT_EQ(features[i + 8 * 17 + j], _5_cells_pattern[j]); - } - - // Second and second-to-last column have 6 cells each, - // expected pattern: 0,0,0,1,0,1,0,1,0,1,0,1,0,1,0,0,0 - const float _6_cells_pattern[17] = { - 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0}; - for (size_t j = 0; j < 17; ++j) { - ASSERT_EQ(features[i + 1 * 17 + j], _6_cells_pattern[j]); - ASSERT_EQ(features[i + 7 * 17 + j], _6_cells_pattern[j]); - } - - // Third and third-to-last column have 7 cells each, - // expected pattern: 0,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,0 - const float _7_cells_pattern[17] = { - 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0}; - for (size_t j = 0; j < 17; ++j) { - ASSERT_EQ(features[i + 2 * 17 + j], _7_cells_pattern[j]); - ASSERT_EQ(features[i + 6 * 17 + j], _7_cells_pattern[j]); - } - - // Fourth and fourth-to-last column have 8 cells each, - // expected pattern: 0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0 - const float _8_cells_pattern[17] = { - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0}; - for (size_t j = 0; j < 17; ++j) { - ASSERT_EQ(features[i + 3 * 17 + j], _8_cells_pattern[j]); - ASSERT_EQ(features[i + 5 * 17 + j], _8_cells_pattern[j]); - } - - // Middle column has 9 cells, - // expected pattern: 1,0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1 - const float _9_cells_pattern[17] = { - 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1}; - for (size_t j = 0; j < 17; ++j) { - ASSERT_EQ(features[i + 4 * 17 + j], _9_cells_pattern[j]); - } - i += 9 * 17; - - // All remaining channels should be all-zero; no moves played - while (i < 10 * 9 * 17) { - ASSERT_EQ(0, features[i]); - ++i; - } -} \ No newline at end of file diff --git a/tests/tests.cc b/tests/tests.cc index f5f9f61a..a0712494 100644 --- a/tests/tests.cc +++ b/tests/tests.cc @@ -9,7 +9,10 @@ #include +std::string LUDII_PATH = ""; + int main(int argc, char** argv) { + if (argc > 1) LUDII_PATH = argv[1]; ::testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } From 5a99b26acf00a89bc9c8f86675497f3b5c2e56f3 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 14:22:48 +0100 Subject: [PATCH 03/11] fix tests for connectfour --- tests/connectfour-tests.cc | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/connectfour-tests.cc b/tests/connectfour-tests.cc index 1908dacc..a28e5325 100644 --- a/tests/connectfour-tests.cc +++ b/tests/connectfour-tests.cc @@ -13,8 +13,6 @@ #include "utils.h" #include -/* - TODO TEST(Connectfour, init_1) { StateForConnectFour state(0); @@ -29,7 +27,7 @@ TEST(Connectfour, init_1) { ASSERT_EQ(i, a_i.GetX()); ASSERT_EQ(0, a_i.GetY()); ASSERT_EQ(0, a_i.GetZ()); - ASSERT_EQ(i, a_i.GetHash()); + ASSERT_EQ(0, a_i.GetHash()); ASSERT_EQ(i, a_i.GetIndex()); } @@ -73,13 +71,12 @@ TEST(Connectfour, init_1) { } - TEST(Connectfour, play_1) { StateForConnectFour state(0); state.Initialize(); - _Action action(1, 7, 0, 0); + _Action action(0, 1, 0, 0); state.ApplyAction(action); ASSERT_EQ((std::vector{3, 6, 7}), state.GetFeatureSize()); @@ -91,7 +88,7 @@ TEST(Connectfour, play_1) { ASSERT_EQ(i, a_i.GetX()); ASSERT_EQ(0, a_i.GetY()); ASSERT_EQ(0, a_i.GetZ()); - ASSERT_EQ(i, a_i.GetHash()); + ASSERT_EQ(0, a_i.GetHash()); ASSERT_EQ(i, a_i.GetIndex()); } @@ -133,5 +130,3 @@ TEST(Connectfour, play_1) { } -*/ - From 27ca2db7cf43eaf65f7a7c5c6c220656fc078aa0 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 15:56:13 +0100 Subject: [PATCH 04/11] fix tests for hex --- tests/hex-state-tests.cc | 157 +++++++++++++++------------------------ tests/utils.h | 11 +++ 2 files changed, 72 insertions(+), 96 deletions(-) diff --git a/tests/hex-state-tests.cc b/tests/hex-state-tests.cc index 6f7b9273..e75744b7 100644 --- a/tests/hex-state-tests.cc +++ b/tests/hex-state-tests.cc @@ -20,12 +20,18 @@ namespace Hex { template class StateTest : public Hex::State { public: core::FeatureOptions _opts; + StateTest(int seed, int history, bool turnFeatures) : Hex::State(seed) { _opts.history = history; _opts.turnFeaturesMultiChannel = turnFeatures; core::State::setFeatures(&_opts); } + + StateTest(int seed) : + Hex::State(seed) { + } + GameStatus GetStatus() { return core::State::_status; } void addAction(int x, int y, int z) { core::State::addAction(x, y, z); } }; @@ -37,12 +43,10 @@ namespace Hex { // unit tests /////////////////////////////////////////////////////////////////////////////// -/* - TODO - TEST(HexStateGroup, init_1) { Hex::StateTest<7,true> state(0, 0, false); + state.Initialize(); ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); @@ -62,82 +66,24 @@ TEST(HexStateGroup, init_1) { ASSERT_EQ(0, a.GetX()); ASSERT_EQ(i, a.GetY()); ASSERT_EQ(j, a.GetZ()); - ASSERT_EQ(k, a.GetHash()); + ASSERT_EQ(0, a.GetHash()); ASSERT_EQ(k, a.GetIndex()); } } -TEST(HexStateGroup, play_1) { - - Hex::StateTest<7,true> state(0, 0, false); - - //TODO Hex::Action<7> a(2, 3, 2*7+3); - state.addAction(0, 2, 3); - state.ApplyAction(state.GetLegalActions()[0]); - - ASSERT_EQ(GameStatus::player1Turn, state.GetStatus()); - - // features - ASSERT_EQ((std::vector{2, 7, 7}), state.GetFeatureSize()); - for (int p=0; p<2; ++p) { - for (int i=0; i<2; ++i) { - for (int j=0; j<2; ++j) { - int k = (p*2+i)*7+j; - auto f_k = state.GetFeatures()[k]; - if (p==0 and i==2 and j==3) - ASSERT_EQ(1, f_k); - else - ASSERT_EQ(0, f_k); - } - } - } - - // actions - ASSERT_EQ((std::vector{1, 7, 7}), state.GetActionSize()); - ASSERT_EQ(7*7, state.GetLegalActions().size()); - for (int i=0; i<2; ++i) { - for (int j=0; j<2; ++j) { - int k = i*7+j; - if (k<2*7+3) { - auto a = state.GetLegalActions()[k]; - ASSERT_EQ(0, a.GetX()); - ASSERT_EQ(i, a.GetY()); - ASSERT_EQ(j, a.GetZ()); - ASSERT_EQ(k, a.GetHash()); - ASSERT_EQ(k, a.GetIndex()); - } - else if (k>2*7+3) { - int k2 = k-1; - auto a = state.GetLegalActions()[k2]; - ASSERT_EQ(0, a.GetX()); - ASSERT_EQ(i, a.GetY()); - ASSERT_EQ(j, a.GetZ()); - ASSERT_EQ(k2, a.GetHash()); - ASSERT_EQ(k2, a.GetIndex()); - } - } - } - -} - TEST(HexStateGroup, clone_1) { try { Hex::State<7,true> state(0); + state.Initialize(); auto clone = state.clone(); auto ptrClone = dynamic_cast *>(clone.get()); + ASSERT_NE(nullptr, ptrClone); ASSERT_NE(&state, ptrClone); ASSERT_EQ(49, state.GetLegalActions().size()); ASSERT_EQ(49, ptrClone->GetLegalActions().size()); - - // TODO Hex::Action<7> a(2, 3, -1); - state.addAction(0, 2, 3); - state.ApplyAction(state.GetLegalActions()[0]); - - ASSERT_EQ(49, state.GetLegalActions().size()); - ASSERT_EQ(49, ptrClone->GetLegalActions().size()); } catch (std::bad_cast) { FAIL() << "not a Hex::State<7,true>"; @@ -149,33 +95,31 @@ TEST(HexStateGroup, clone_1) { TEST(HexStateGroup, features_1) { Hex::StateTest<3,true> state(0, 2, true); + state.Initialize(); // apply actions ASSERT_EQ((std::vector{1, 3, 3}), state.GetActionSize()); - std::vector> actions {{ - {1,0,-1}, - {0,0,-1} - }}; + // DEBUG printActions(state.GetLegalActions()); auto currentPlayer = GameStatus::player0Turn; auto nextPlayer = GameStatus::player1Turn; ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); ASSERT_EQ(9, state.GetLegalActions().size()); - state.ApplyAction(actions[0]); + state.ApplyAction( _Action(3, 0, 1, 0) ); ASSERT_EQ(GameStatus::player1Turn, state.GetStatus()); ASSERT_EQ(9, state.GetLegalActions().size()); - state.ApplyAction(actions[1]); + state.ApplyAction( _Action(0, 0, 0, 0) ); ASSERT_EQ(7, state.GetLegalActions().size()); ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); // check features - ASSERT_EQ((std::vector{7, 3, 3}), state.GetFeatureSize()); + ASSERT_EQ((std::vector{8, 3, 3}), state.GetFeatureSize()); std::vector expectedFeatures { @@ -203,6 +147,10 @@ TEST(HexStateGroup, features_1) { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f @@ -211,9 +159,9 @@ TEST(HexStateGroup, features_1) { // DEBUG // std::cout << "*** expected ***" << std::endl; - // printPlanes>(expectedFeatures, 7, 3, 3); + // printPlanes>(expectedFeatures, 8, 3, 3); // std::cout << "*** actual ***" << std::endl; - // printPlanes>(state.GetFeatures(), 7, 3, 3); + // printPlanes>(state.GetFeatures(), 8, 3, 3); ASSERT_EQ(expectedFeatures, state.GetFeatures()); @@ -223,16 +171,17 @@ TEST(HexStateGroup, features_1) { TEST(HexStateGroup, features_2) { Hex::StateTest<3,false> state(0, 2, true); + state.Initialize(); // apply actions ASSERT_EQ((std::vector{1, 3, 3}), state.GetActionSize()); - std::vector> actions {{ - {1,1,-1}, {0,0,-1}, - {2,2,-1}, {2,0,-1}, - {1,0,-1} - }}; + std::vector<_Action> actions { + _Action(0, 0, 1, 1), _Action(0, 0, 0, 0), + _Action(0, 0, 2, 2), _Action(0, 0, 2, 0), + _Action(0, 0, 1, 0) + }; auto currentPlayer = GameStatus::player0Turn; auto nextPlayer = GameStatus::player1Turn; @@ -249,7 +198,7 @@ TEST(HexStateGroup, features_2) { // check features - ASSERT_EQ((std::vector{7, 3, 3}), state.GetFeatureSize()); + ASSERT_EQ((std::vector{8, 3, 3}), state.GetFeatureSize()); std::vector expectedFeatures { @@ -283,6 +232,11 @@ TEST(HexStateGroup, features_2) { 0.f, 0.f, 0.f, 1.f, 0.f, 0.f, + // + 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, + // turn 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, @@ -292,9 +246,9 @@ TEST(HexStateGroup, features_2) { // DEBUG // std::cout << "*** expected ***" << std::endl; - // printPlanes>(expectedFeatures, 7, 3, 3); + // printPlanes>(expectedFeatures, 8, 3, 3); // std::cout << "*** actual ***" << std::endl; - // printPlanes>(state.GetFeatures(), 7, 3, 3); + // printPlanes>(state.GetFeatures(), 8, 3, 3); ASSERT_EQ(expectedFeatures, state.GetFeatures()); @@ -305,27 +259,28 @@ TEST(HexStateGroup, features_3) { const int history = 2; const int size = 9; const bool turnFeatures = true; - const int nbChannels = 2*(1 + history) + (turnFeatures ? 1 : 0); + const int nbChannels = 2*(1 + history) + (turnFeatures ? 1 : 0) + 1; Hex::StateTest state(0, history, turnFeatures); + state.Initialize(); // apply actions ASSERT_EQ((std::vector{1, size, size}), state.GetActionSize()); - std::vector> actions {{ - {0,0,-1}, {4,1,-1}, - {2,3,-1}, {5,2,-1}, - {2,5,-1}, {4,4,-1}, - {2,6,-1}, {5,5,-1}, - {7,4,-1}, {4,7,-1}, - {7,6,-1}, {3,8,-1}, - {5,6,-1}, {4,6,-1}, - {4,5,-1}, {5,4,-1}, - {5,3,-1}, {4,3,-1}, - {4,2,-1}, {5,1,-1}, - {5,0,-1}, {4,0,-1} - }}; + std::vector<_Action> actions { + _Action(0,0,0,0), _Action(0,0,4,1), + _Action(0,0,2,3), _Action(0,0,5,2), + _Action(0,0,2,5), _Action(0,0,4,4), + _Action(0,0,2,6), _Action(0,0,5,5), + _Action(0,0,7,4), _Action(0,0,4,7), + _Action(0,0,7,6), _Action(0,0,3,8), + _Action(0,0,5,6), _Action(0,0,4,6), + _Action(0,0,4,5), _Action(0,0,5,4), + _Action(0,0,5,3), _Action(0,0,4,3), + _Action(0,0,4,2), _Action(0,0,5,1), + _Action(0,0,5,0), _Action(0,0,4,0) + }; auto currentPlayer = GameStatus::player0Turn; auto nextPlayer = GameStatus::player1Turn; @@ -413,6 +368,17 @@ TEST(HexStateGroup, features_3) { 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, + // turn 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, @@ -458,5 +424,4 @@ TEST(HexStateGroup, features_3) { // } } -*/ diff --git a/tests/utils.h b/tests/utils.h index 4e5beae9..e94b7137 100644 --- a/tests/utils.h +++ b/tests/utils.h @@ -11,6 +11,7 @@ #include #include +#include // Print a feature plane: // printPlanes&>(state.GetFeatures(), indexChannels, nbRows, nbCols); @@ -43,3 +44,13 @@ void printData(T data) { std::cout << std::endl; } +template +void printActions(std::vector actions) { + for (const auto & a : actions) + std::cout << a.GetIndex() << " " + << a.GetX() << " " + << a.GetY() << " " + << a.GetZ() << std::endl; + std::cout << std::endl; +} + From e21e8ddc161830e1c40b07cfa43da6e5bb3046fe Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 16:15:37 +0100 Subject: [PATCH 05/11] fix CI (maybe) --- CMakeLists.txt | 2 +- nix/shell-cpu.nix | 2 +- nix/shell-cuda.nix | 2 +- tests/CMakeLists.txt | 1 - 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 40818077..614f945f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,7 +55,7 @@ add_subdirectory(src/games/minesweeper_csp_vkms) # tests add_executable(test_state src/core/test_state.cc src/core/state.cc) -target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES} python3) +target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES}) enable_testing() diff --git a/nix/shell-cpu.nix b/nix/shell-cpu.nix index c1cc3a5d..d761a3e9 100644 --- a/nix/shell-cpu.nix +++ b/nix/shell-cpu.nix @@ -37,7 +37,7 @@ in pkgs.mkShell { shellHook = '' export CFLAGS="-I${pybind11}/include -I${pytorch}/${python.sitePackages}/torch/include -I${pytorch}/${python.sitePackages}/torch/include/torch/csrc/api/include" export CXXFLAGS=$CFLAGS - export LDFLAGS="-L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages}" + export LDFLAGS="-lpython3 -L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages}" export PYTHONPATH="$PYTHONPATH:build:build/torchRL/mcts:build/torchRL/tube" export OMP_NUM_THREADS=1 ''; diff --git a/nix/shell-cuda.nix b/nix/shell-cuda.nix index 95af1df5..a27f90d6 100644 --- a/nix/shell-cuda.nix +++ b/nix/shell-cuda.nix @@ -73,7 +73,7 @@ in pkgs.mkShell { shellHook = '' export CFLAGS="-I${pybind11}/include -I${pytorch}/${python.sitePackages}/torch/include -I${pytorch}/${python.sitePackages}/torch/include/torch/csrc/api/include" export CXXFLAGS=$CFLAGS - export LDFLAGS="-L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages} -L${pkgs.cudatoolkit}/lib" + export LDFLAGS="-lpython3 -L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages} -L${pkgs.cudatoolkit}/lib" export LD_LIBRARY_PATH="${pkgs.linuxPackages.nvidia_x11}/lib" export PYTHONPATH="$PYTHONPATH:build:build/torchRL/mcts:build/torchRL/tube" export OMP_NUM_THREADS=1 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index fd1847f1..b8b0e779 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -46,7 +46,6 @@ add_executable( polygames-tests target_link_libraries( polygames-tests libpolygames - python3 ${CMAKE_THREAD_LIBS_INIT} ${GTEST_LIBRARIES} ${JNI_LIBRARIES} From d9c0bb6322f5d60105e1c5c2d54c60cd3e57a155 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 16:27:46 +0100 Subject: [PATCH 06/11] CI clone pytorch 1.2.0 --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 40927ad6..9ef33664 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -72,7 +72,7 @@ jobs: wget -O ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh bash ~/miniconda.sh -b -p /opt/conda rm ~/miniconda.sh - git clone --recursive https://github.com/pytorch/pytorch --branch=v1.1.0 ~/pytorch + git clone --recursive https://github.com/pytorch/pytorch --branch=v1.2.0 ~/pytorch fi - save_cache: From 752b96c2a044215f9d653bdba4b345dc89fe77e3 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 16:51:45 +0100 Subject: [PATCH 07/11] disable CI cache --- .circleci/config.yml | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9ef33664..e60bc2c7 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,11 +32,11 @@ jobs: #- sudo apt-get update #- sudo apt-get install -y singularity-container - - restore_cache: - keys: - - tag3-conda-{{ checksum "singularity/environment.yml" }}-pytorch - - tag3-conda-{{ checksum "singularity/environment.yml" }} - - tag1-conda + #- restore_cache: + # keys: + # - tag3-conda-{{ checksum "singularity/environment.yml" }}-pytorch + # - tag3-conda-{{ checksum "singularity/environment.yml" }} + # - tag1-conda - run: name: Install miniconda and clone pytorch From 1f2a75635ca5d0912f8d198e402d85aa8026bd1a Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 17:52:17 +0100 Subject: [PATCH 08/11] fix tests --- .circleci/config.yml | 10 +- src/games/havannah_state.h | 6 + tests/havannah-state-tests.cc | 251 +++++++++++++++++++++++++++++++--- 3 files changed, 241 insertions(+), 26 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index e60bc2c7..9ef33664 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -32,11 +32,11 @@ jobs: #- sudo apt-get update #- sudo apt-get install -y singularity-container - #- restore_cache: - # keys: - # - tag3-conda-{{ checksum "singularity/environment.yml" }}-pytorch - # - tag3-conda-{{ checksum "singularity/environment.yml" }} - # - tag1-conda + - restore_cache: + keys: + - tag3-conda-{{ checksum "singularity/environment.yml" }}-pytorch + - tag3-conda-{{ checksum "singularity/environment.yml" }} + - tag1-conda - run: name: Install miniconda and clone pytorch diff --git a/src/games/havannah_state.h b/src/games/havannah_state.h index 5a2407c5..a666aed9 100644 --- a/src/games/havannah_state.h +++ b/src/games/havannah_state.h @@ -37,6 +37,7 @@ template class State : public core::State { std::string actionsDescription() const override; int parseAction(const std::string& str) const override; virtual int getCurrentPlayerColor() const override; + virtual int getNumPlayerColors() const override; }; } // namespace Havannah @@ -303,3 +304,8 @@ template int Havannah ::State::getCurrentPlayerColor() const { return _board.getCurrentColor(); } + +template +int Havannah::State::getNumPlayerColors() const { + return 2; +} diff --git a/tests/havannah-state-tests.cc b/tests/havannah-state-tests.cc index cfcd8c44..37502769 100644 --- a/tests/havannah-state-tests.cc +++ b/tests/havannah-state-tests.cc @@ -37,25 +37,24 @@ namespace Havannah { // unit tests /////////////////////////////////////////////////////////////////////////////// -/* -TODO - TEST(HavannahStateGroup, init_0) { const int size = 5; const int history = 0; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); // features std::vector expectedFeatures(nbChannels*fullsize*fullsize, 0.f); const std::vector boardFeatures = { + 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, @@ -64,7 +63,18 @@ TEST(HavannahStateGroup, init_0) { 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, - 1, 1, 1, 1, 1, 0, 0, 0, 0 + 1, 1, 1, 1, 1, 0, 0, 0, 0, + + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1 + }; const int f2 = fullsize*fullsize; std::copy(boardFeatures.begin(), boardFeatures.end(), expectedFeatures.begin() + 2*f2); @@ -91,16 +101,99 @@ TEST(HavannahStateGroup, init_1) { const int history = 2; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); // features - std::vector expectedFeatures(nbChannels*fullsize*fullsize, 0.f); - const std::vector boardFeatures = { + const std::vector expectedFeatures = { + + // 1 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + + // 2 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, @@ -115,12 +208,92 @@ TEST(HavannahStateGroup, init_1) { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, - 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0 + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + + // 3 + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + + 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, + 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, + + // + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, + + // + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 + }; - const int f2 = fullsize*fullsize; - std::copy(boardFeatures.begin(), boardFeatures.end(), expectedFeatures.begin() + 2*f2); - std::copy(boardFeatures.begin(), boardFeatures.end(), expectedFeatures.begin() + 5*f2); - std::copy(boardFeatures.begin(), boardFeatures.end(), expectedFeatures.begin() + 8*f2); // DEBUG // std::cout << "*** expected ***" << std::endl; @@ -148,6 +321,7 @@ TEST(HavannahStateGroup, init_2) { const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); ASSERT_EQ(GameStatus::player0Turn, state.GetStatus()); @@ -193,11 +367,11 @@ TEST(HavannahStateGroup, init_2) { auto action = state.GetLegalActions()[k]; int i = expectedAction.first; int j = expectedAction.second; - int h = i*fullsize + j; + // int h = i*fullsize + j; ASSERT_EQ(0, action.GetX()); ASSERT_EQ(i, action.GetY()); ASSERT_EQ(j, action.GetZ()); - ASSERT_EQ(h, action.GetHash()); + ASSERT_EQ(0, action.GetHash()); ASSERT_EQ(k, action.GetIndex()); } @@ -208,9 +382,11 @@ TEST(HavannahStateGroup, clone_1) { try { Havannah::State<4, true, false> state(0); + state.Initialize(); auto clone = state.clone(); auto ptrClone = dynamic_cast *>(clone.get()); + ASSERT_NE(nullptr, ptrClone); ASSERT_NE(&state, ptrClone); ASSERT_EQ(37, state.GetLegalActions().size()); ASSERT_EQ(37, ptrClone->GetLegalActions().size()); @@ -249,6 +425,7 @@ TEST(HavannahStateGroup, clone_2) { try { Havannah::State<4, true, false> state(0); + state.Initialize(); auto clone = state.clone(); auto ptrClone = dynamic_cast *>(clone.get()); @@ -275,10 +452,11 @@ TEST(HavannahStateGroup, features_1_pie) { const int history = 2; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); // apply actions @@ -386,6 +564,13 @@ TEST(HavannahStateGroup, features_1_pie) { 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, + // + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + // turn 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, @@ -412,10 +597,11 @@ TEST(HavannahStateGroup, features_1_nopie) { const int history = 2; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); // apply actions @@ -507,6 +693,13 @@ TEST(HavannahStateGroup, features_1_nopie) { 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, + // + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, + // turn 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, @@ -533,10 +726,11 @@ TEST(HavannahStateGroup, features_2_nopie) { const int history = 2; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); // apply actions @@ -627,6 +821,13 @@ TEST(HavannahStateGroup, features_2_nopie) { 1.f, 1.f, 1.f, 1.f, 0.f, 1.f, 1.f, 1.f, 0.f, 0.f, + // + 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, + 1.f, 1.f, 1.f, 1.f, 1.f, + // turn 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, @@ -658,10 +859,11 @@ TEST(HavannahStateGroup, features_3_nopie) { const int history = 2; const bool turnFeatures = true; const int fullsize = 2*size - 1; - const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0); + const int nbChannels = 3*(1+history) + (turnFeatures ? 1 : 0) + 1; const int nbActions = fullsize*fullsize - size*(size-1); Havannah::StateTest state(0, history, turnFeatures); + state.Initialize(); // apply actions @@ -776,6 +978,15 @@ TEST(HavannahStateGroup, features_3_nopie) { 1.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 1.f, 1.f, 1.f, 1.f, 0.f, 0.f, 0.f, + // + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, 0.f, + // turn 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, 1.f, @@ -820,5 +1031,3 @@ TEST(HavannahStateGroup, features_3_nopie) { } -*/ - From 607a6d90d20f80f8a8aaa8d32026462e95ecf4a4 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 19:23:27 +0100 Subject: [PATCH 09/11] find pytorch version automatically --- CMakeLists.txt | 11 +++-------- README.md | 2 +- nix/README.md | 2 +- src/core/actor.h | 2 +- src/core/model_manager.cc | 6 +++--- tests/CMakeLists.txt | 11 +++-------- 6 files changed, 12 insertions(+), 22 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 614f945f..88c05c77 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,20 +11,15 @@ set(CMAKE_CXX_FLAGS set(CMAKE_POSITION_INDEPENDENT_CODE ON) -OPTION(PYTORCH12 "Is PyTorch >= 1.2" OFF) -OPTION(PYTORCH15 "Is PyTorch >= 1.5" OFF) -IF(PYTORCH15) - ADD_DEFINITIONS(-DPYTORCH15 -DPYTORCH12) -ELSEIF(PYTORCH12) - ADD_DEFINITIONS(-DPYTORCH12) -ENDIF() - execute_process( COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__), end='')" OUTPUT_VARIABLE TorchPath ) set(CMAKE_PREFIX_PATH ${TorchPath}) find_package(Torch REQUIRED) +message(STATUS "Torch VERSION: ${Torch_VERSION}") +add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) +add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) find_package(Boost COMPONENTS system) if( Boost_FOUND ) diff --git a/README.md b/README.md index 347daaba..326d017d 100644 --- a/README.md +++ b/README.md @@ -47,7 +47,7 @@ cd polygames mkdir build cd build -cmake .. -DCMAKE_BUILD_TYPE=relwithdebinfo -DPYTORCH15=ON +cmake .. -DCMAKE_BUILD_TYPE=relwithdebinfo make -j ``` diff --git a/nix/README.md b/nix/README.md index 6298a6ca..ff52adda 100644 --- a/nix/README.md +++ b/nix/README.md @@ -88,7 +88,7 @@ the [Dockerfile](./Dockerfile). ``` mkdir build cd build - cmake -DPYTORCH12=ON .. + cmake .. make -j4 cd .. ``` diff --git a/src/core/actor.h b/src/core/actor.h index 6906bfe8..ed05a16b 100644 --- a/src/core/actor.h +++ b/src/core/actor.h @@ -116,7 +116,7 @@ class Actor { torch::Tensor policy; if (useValue_ && resultsAreValid) { if (logitValue_) { -#ifdef PYTORCH12 +#if TORCH_VERSION_MINOR >= 2 float* begin = value_->data.data(); #else float* begin = value_->data.data_ptr(); diff --git a/src/core/model_manager.cc b/src/core/model_manager.cc index 7d8af5df..b1121020 100644 --- a/src/core/model_manager.cc +++ b/src/core/model_manager.cc @@ -21,7 +21,7 @@ std::unordered_map convertIValueToMap( std::unordered_map map; auto dict = value.toGenericDict(); -#ifdef PYTORCH12 +#if TORCH_VERSION_MINOR >= 2 for (auto& name2tensor : dict) { auto name = name2tensor.key().toString(); torch::Tensor tensor = name2tensor.value().toTensor(); @@ -189,7 +189,7 @@ class ModelManagerImpl { "train", trainChannelNumSlots, trainChannelTimeoutMs); actChannel_ = std::make_shared("act", actBatchsize, -1); -#ifdef PYTORCH12 +#if TORCH_VERSION_MINOR >= 2 model_ = std::make_shared(torch::jit::load(jitModel_, device)); #else @@ -333,7 +333,7 @@ class ModelManagerImpl { } } -#ifdef PYTORCH15 +#if TORCH_VERSION_MINOR >= 5 void loadModelStateDict( TorchJitModel& model, const std::unordered_map& stateDict) { diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b8b0e779..f39c17c8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -2,14 +2,6 @@ cmake_minimum_required( VERSION 3.3 ) project( polygames-tests ) set(CMAKE_CXX_STANDARD 17) -OPTION(PYTORCH12 "Is PyTorch >= 1.2" OFF) -OPTION(PYTORCH15 "Is PyTorch >= 1.5" OFF) -IF(PYTORCH15) - ADD_DEFINITIONS(-DPYTORCH15 -DPYTORCH12) -ELSEIF(PYTORCH12) - ADD_DEFINITIONS(-DPYTORCH12) -ENDIF() - execute_process( COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__), end='')" OUTPUT_VARIABLE TorchPath @@ -17,6 +9,9 @@ execute_process( set(CMAKE_PREFIX_PATH ${TorchPath}) find_package(Torch REQUIRED) include_directories(${TORCH_INCLUDE_DIRS}) +message(STATUS "Torch VERSION: ${Torch_VERSION}") +add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) +add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) find_package( PythonInterp 3.7 REQUIRED ) find_package( PythonLibs 3.7 REQUIRED ) From f075bd55b1ad9440a4f3e9198cc03061370d0a71 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Thu, 21 Jan 2021 19:59:14 +0100 Subject: [PATCH 10/11] check pytorch version --- .circleci/config.yml | 4 ++-- CMakeLists.txt | 2 +- src/core/actor.h | 6 ++++-- src/core/model_manager.cc | 22 ++++++++++++++-------- tests/CMakeLists.txt | 4 ++-- 5 files changed, 23 insertions(+), 15 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 9ef33664..cbd69b48 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -138,7 +138,7 @@ jobs: conda activate pypg mkdir build cd build - cmake -DPYTORCH12=ON .. + cmake .. make -j 2 - run: @@ -150,7 +150,7 @@ jobs: wget -P ludii https://ludii.games/downloads/Ludii.jar mkdir tests/build cd tests/build - cmake -DPYTORCH12=ON .. + cmake .. make -j 2 - run: diff --git a/CMakeLists.txt b/CMakeLists.txt index 88c05c77..5cdd3ef9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -16,7 +16,7 @@ execute_process( OUTPUT_VARIABLE TorchPath ) set(CMAKE_PREFIX_PATH ${TorchPath}) -find_package(Torch REQUIRED) +find_package(Torch 1.2 REQUIRED) message(STATUS "Torch VERSION: ${Torch_VERSION}") add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) diff --git a/src/core/actor.h b/src/core/actor.h index ed05a16b..1f76f8bd 100644 --- a/src/core/actor.h +++ b/src/core/actor.h @@ -116,10 +116,12 @@ class Actor { torch::Tensor policy; if (useValue_ && resultsAreValid) { if (logitValue_) { -#if TORCH_VERSION_MINOR >= 2 +#if TORCH_VERSION_MINOR >= 5 + float* begin = value_->data.data_ptr(); +#elif TORCH_VERSION_MINOR >= 2 float* begin = value_->data.data(); #else - float* begin = value_->data.data_ptr(); +#error UNSUPPORTED PYTORCH VERSION #endif float* end = begin + 3; softmax_(begin, end); diff --git a/src/core/model_manager.cc b/src/core/model_manager.cc index b1121020..cf0bf4ca 100644 --- a/src/core/model_manager.cc +++ b/src/core/model_manager.cc @@ -21,15 +21,17 @@ std::unordered_map convertIValueToMap( std::unordered_map map; auto dict = value.toGenericDict(); -#if TORCH_VERSION_MINOR >= 2 - for (auto& name2tensor : dict) { - auto name = name2tensor.key().toString(); - torch::Tensor tensor = name2tensor.value().toTensor(); -#else +#if TORCH_VERSION_MINOR >= 5 auto ivalMap = dict->elements(); for (auto& name2tensor : ivalMap) { auto name = name2tensor.first.toString(); torch::Tensor tensor = name2tensor.second.toTensor(); +#elif TORCH_VERSION_MINOR >= 2 + for (auto& name2tensor : dict) { + auto name = name2tensor.key().toString(); + torch::Tensor tensor = name2tensor.value().toTensor(); +#else +#error UNSUPPORTED PYTORCH VERSION #endif tensor = tensor.detach(); @@ -189,11 +191,13 @@ class ModelManagerImpl { "train", trainChannelNumSlots, trainChannelTimeoutMs); actChannel_ = std::make_shared("act", actBatchsize, -1); -#if TORCH_VERSION_MINOR >= 2 +#if TORCH_VERSION_MINOR >= 5 + model_ = torch::jit::load(jitModel_, device); +#elif TORCH_VERSION_MINOR >= 2 model_ = std::make_shared(torch::jit::load(jitModel_, device)); #else - model_ = torch::jit::load(jitModel_, device); +#error UNSUPPORTED PYTORCH VERSION #endif model_->eval(); @@ -378,7 +382,7 @@ class ModelManagerImpl { } model.eval(); } -#else +#elif TORCH_VERSION_MINOR >= 2 void loadModelStateDict( TorchJitModel& model, const std::unordered_map& stateDict) { @@ -422,6 +426,8 @@ class ModelManagerImpl { } model.eval(); } +#else +#error UNSUPPORTED PYTORCH VERSION #endif void updateModel( diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index f39c17c8..3a083306 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -7,11 +7,11 @@ execute_process( OUTPUT_VARIABLE TorchPath ) set(CMAKE_PREFIX_PATH ${TorchPath}) -find_package(Torch REQUIRED) -include_directories(${TORCH_INCLUDE_DIRS}) +find_package(Torch 1.2 REQUIRED) message(STATUS "Torch VERSION: ${Torch_VERSION}") add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) +include_directories(${TORCH_INCLUDE_DIRS}) find_package( PythonInterp 3.7 REQUIRED ) find_package( PythonLibs 3.7 REQUIRED ) From 8e5768c800c0698e53cc08327969ff004eabac86 Mon Sep 17 00:00:00 2001 From: Julien Dehos Date: Fri, 22 Jan 2021 13:08:03 +0100 Subject: [PATCH 11/11] cmake configs: link to python libs --- CMakeLists.txt | 5 ++++- nix/shell-cpu.nix | 2 +- nix/shell-cuda.nix | 2 +- tests/CMakeLists.txt | 3 +-- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 5cdd3ef9..0dd13481 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -21,6 +21,8 @@ message(STATUS "Torch VERSION: ${Torch_VERSION}") add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) +find_package( PythonLibs 3.7 REQUIRED ) + find_package(Boost COMPONENTS system) if( Boost_FOUND ) include_directories( ${Boost_INCLUDE_DIRS}) @@ -50,7 +52,8 @@ add_subdirectory(src/games/minesweeper_csp_vkms) # tests add_executable(test_state src/core/test_state.cc src/core/state.cc) -target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES}) +target_link_libraries(test_state PUBLIC + _tube _mcts _games ${JNI_LIBRARIES} ${PYTHON_LIBRARIES}) enable_testing() diff --git a/nix/shell-cpu.nix b/nix/shell-cpu.nix index d761a3e9..c1cc3a5d 100644 --- a/nix/shell-cpu.nix +++ b/nix/shell-cpu.nix @@ -37,7 +37,7 @@ in pkgs.mkShell { shellHook = '' export CFLAGS="-I${pybind11}/include -I${pytorch}/${python.sitePackages}/torch/include -I${pytorch}/${python.sitePackages}/torch/include/torch/csrc/api/include" export CXXFLAGS=$CFLAGS - export LDFLAGS="-lpython3 -L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages}" + export LDFLAGS="-L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages}" export PYTHONPATH="$PYTHONPATH:build:build/torchRL/mcts:build/torchRL/tube" export OMP_NUM_THREADS=1 ''; diff --git a/nix/shell-cuda.nix b/nix/shell-cuda.nix index a27f90d6..95af1df5 100644 --- a/nix/shell-cuda.nix +++ b/nix/shell-cuda.nix @@ -73,7 +73,7 @@ in pkgs.mkShell { shellHook = '' export CFLAGS="-I${pybind11}/include -I${pytorch}/${python.sitePackages}/torch/include -I${pytorch}/${python.sitePackages}/torch/include/torch/csrc/api/include" export CXXFLAGS=$CFLAGS - export LDFLAGS="-lpython3 -L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages} -L${pkgs.cudatoolkit}/lib" + export LDFLAGS="-L${pytorch}/${python.sitePackages}/torch/lib -L$out/${python.sitePackages} -L${pkgs.cudatoolkit}/lib" export LD_LIBRARY_PATH="${pkgs.linuxPackages.nvidia_x11}/lib" export PYTHONPATH="$PYTHONPATH:build:build/torchRL/mcts:build/torchRL/tube" export OMP_NUM_THREADS=1 diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 3a083306..46eaf4ac 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -13,9 +13,7 @@ add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR}) add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR}) include_directories(${TORCH_INCLUDE_DIRS}) -find_package( PythonInterp 3.7 REQUIRED ) find_package( PythonLibs 3.7 REQUIRED ) -include_directories( ${PYTHON_INCLUDE_DIRS} ) find_package (Threads) @@ -44,6 +42,7 @@ target_link_libraries( polygames-tests ${CMAKE_THREAD_LIBS_INIT} ${GTEST_LIBRARIES} ${JNI_LIBRARIES} + ${PYTHON_LIBRARIES} ${TORCH_LIBRARIES} )