Skip to content
This repository was archived by the owner on Mar 2, 2022. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ jobs:
wget -O ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
bash ~/miniconda.sh -b -p /opt/conda
rm ~/miniconda.sh
git clone --recursive https://github.com/pytorch/pytorch --branch=v1.1.0 ~/pytorch
git clone --recursive https://github.com/pytorch/pytorch --branch=v1.2.0 ~/pytorch
fi

- save_cache:
Expand Down Expand Up @@ -161,7 +161,7 @@ jobs:
- run:
name: Test polygames-tests (unit tests)
command: |
./tests/build/polygames-tests
./tests/build/polygames-tests ludii/Ludii.jar

- run:
name: Test Mcts
Expand Down
18 changes: 8 additions & 10 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,17 @@ set(CMAKE_CXX_FLAGS

set(CMAKE_POSITION_INDEPENDENT_CODE ON)

OPTION(PYTORCH12 "Is PyTorch >= 1.2" OFF)
OPTION(PYTORCH15 "Is PyTorch >= 1.5" OFF)
IF(PYTORCH15)
ADD_DEFINITIONS(-DPYTORCH15 -DPYTORCH12)
ELSEIF(PYTORCH12)
ADD_DEFINITIONS(-DPYTORCH12)
ENDIF()

execute_process(
COMMAND python -c "import torch; import os; print(os.path.dirname(torch.__file__), end='')"
OUTPUT_VARIABLE TorchPath
)
set(CMAKE_PREFIX_PATH ${TorchPath})
find_package(Torch REQUIRED)
find_package(Torch 1.2 REQUIRED)
message(STATUS "Torch VERSION: ${Torch_VERSION}")
add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR})
add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR})

find_package( PythonLibs 3.7 REQUIRED )

find_package(Boost COMPONENTS system)
if( Boost_FOUND )
Expand Down Expand Up @@ -55,7 +52,8 @@ add_subdirectory(src/games/minesweeper_csp_vkms)

# tests
add_executable(test_state src/core/test_state.cc src/core/state.cc)
target_link_libraries(test_state PUBLIC _tube _mcts _games ${JNI_LIBRARIES})
target_link_libraries(test_state PUBLIC
_tube _mcts _games ${JNI_LIBRARIES} ${PYTHON_LIBRARIES})

enable_testing()

Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ cd polygames
mkdir build
cd build

cmake .. -DCMAKE_BUILD_TYPE=relwithdebinfo -DPYTORCH15=ON
cmake .. -DCMAKE_BUILD_TYPE=relwithdebinfo
make -j

```
Expand Down
2 changes: 1 addition & 1 deletion nix/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ the [Dockerfile](./Dockerfile).
```
mkdir build
cd build
cmake -DPYTORCH12=ON ..
cmake ..
make -j4
cd ..
```
Expand Down
6 changes: 6 additions & 0 deletions src/core/actor.h
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,13 @@ class Actor {
torch::Tensor policy;
if (useValue_ && resultsAreValid) {
if (logitValue_) {
#if TORCH_VERSION_MINOR >= 5
float* begin = value_->data.data_ptr<float>();
#elif TORCH_VERSION_MINOR >= 2
float* begin = value_->data.data<float>();
#else
#error UNSUPPORTED PYTORCH VERSION
#endif
float* end = begin + 3;
softmax_(begin, end);
}
Expand Down
24 changes: 15 additions & 9 deletions src/core/model_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,17 @@ std::unordered_map<std::string, at::Tensor> convertIValueToMap(
std::unordered_map<std::string, torch::Tensor> map;
auto dict = value.toGenericDict();

#ifdef PYTORCH12
for (auto& name2tensor : dict) {
auto name = name2tensor.key().toString();
torch::Tensor tensor = name2tensor.value().toTensor();
#else
#if TORCH_VERSION_MINOR >= 5
auto ivalMap = dict->elements();
for (auto& name2tensor : ivalMap) {
auto name = name2tensor.first.toString();
torch::Tensor tensor = name2tensor.second.toTensor();
#elif TORCH_VERSION_MINOR >= 2
for (auto& name2tensor : dict) {
auto name = name2tensor.key().toString();
torch::Tensor tensor = name2tensor.value().toTensor();
#else
#error UNSUPPORTED PYTORCH VERSION
#endif

tensor = tensor.detach();
Expand Down Expand Up @@ -189,11 +191,13 @@ class ModelManagerImpl {
"train", trainChannelNumSlots, trainChannelTimeoutMs);
actChannel_ = std::make_shared<tube::DataChannel>("act", actBatchsize, -1);

#ifdef PYTORCH12
#if TORCH_VERSION_MINOR >= 5
model_ = torch::jit::load(jitModel_, device);
#elif TORCH_VERSION_MINOR >= 2
model_ =
std::make_shared<TorchJitModel>(torch::jit::load(jitModel_, device));
#else
model_ = torch::jit::load(jitModel_, device);
#error UNSUPPORTED PYTORCH VERSION
#endif
model_->eval();

Expand Down Expand Up @@ -333,7 +337,7 @@ class ModelManagerImpl {
}
}

#ifdef PYTORCH15
#if TORCH_VERSION_MINOR >= 5
void loadModelStateDict(
TorchJitModel& model,
const std::unordered_map<std::string, torch::Tensor>& stateDict) {
Expand Down Expand Up @@ -378,7 +382,7 @@ class ModelManagerImpl {
}
model.eval();
}
#else
#elif TORCH_VERSION_MINOR >= 2
void loadModelStateDict(
TorchJitModel& model,
const std::unordered_map<std::string, torch::Tensor>& stateDict) {
Expand Down Expand Up @@ -422,6 +426,8 @@ class ModelManagerImpl {
}
model.eval();
}
#else
#error UNSUPPORTED PYTORCH VERSION
#endif

void updateModel(
Expand Down
1 change: 1 addition & 0 deletions src/distributed/distributed.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

#include <cstring>
#include <functional>
#include <list>
#include <optional>
#include <random>
#include <string>
Expand Down
6 changes: 6 additions & 0 deletions src/games/havannah_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ template <int SIZE, bool PIE, bool EXTENDED> class State : public core::State {
std::string actionsDescription() const override;
int parseAction(const std::string& str) const override;
virtual int getCurrentPlayerColor() const override;
virtual int getNumPlayerColors() const override;
};
} // namespace Havannah

Expand Down Expand Up @@ -303,3 +304,8 @@ template <int SIZE, bool PIE, bool EXTENDED>
int Havannah ::State<SIZE, PIE, EXTENDED>::getCurrentPlayerColor() const {
return _board.getCurrentColor();
}

template <int SIZE, bool PIE, bool EXTENDED>
int Havannah::State<SIZE, PIE, EXTENDED>::getNumPlayerColors() const {
return 2;
}
26 changes: 7 additions & 19 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ execute_process(
OUTPUT_VARIABLE TorchPath
)
set(CMAKE_PREFIX_PATH ${TorchPath})
find_package(Torch REQUIRED)
find_package(Torch 1.2 REQUIRED)
message(STATUS "Torch VERSION: ${Torch_VERSION}")
add_definitions(-DTORCH_VERSION_MAJOR=${Torch_VERSION_MAJOR})
add_definitions(-DTORCH_VERSION_MINOR=${Torch_VERSION_MINOR})
include_directories(${TORCH_INCLUDE_DIRS})

find_package( PythonInterp 3.7 REQUIRED )
find_package( PythonLibs 3.7 REQUIRED )
include_directories( ${PYTHON_INCLUDE_DIRS} )

find_package (Threads)

Expand All @@ -22,20 +23,9 @@ include_directories( ${GTEST_INCLUDE_DIRS} )
find_package(JNI REQUIRED)
include_directories( ${JNI_INCLUDE_DIRS})

include_directories(
../games
../torchRL
../torchRL/third_party/fmt/include
../torchRL/tube/src_cpp
)
add_subdirectory(../src build-libpolygames)

add_executable( polygames-tests
../core/game.cc
../core/state.cc
../torchRL/mcts/mcts.cc
../torchRL/mcts/node.cc
../torchRL/tube/src_cpp/data_channel.cc
../torchRL/tube/src_cpp/replay_buffer.cc
tests.cc

# Include your tests here.
Expand All @@ -44,17 +34,15 @@ add_executable( polygames-tests
havannah-tests.cc
hex-state-tests.cc
hex-tests.cc

ludii-game-tests.cc
../games/ludii/jni_utils.cc
../games/ludii/ludii_game_wrapper.cc
../games/ludii/ludii_state_wrapper.cc
)

target_link_libraries( polygames-tests
libpolygames
${CMAKE_THREAD_LIBS_INIT}
${GTEST_LIBRARIES}
${JNI_LIBRARIES}
${PYTHON_LIBRARIES}
${TORCH_LIBRARIES}
)

Expand Down
30 changes: 14 additions & 16 deletions tests/connectfour-tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

#include <gtest/gtest.h>
#include "utils.h"
#include <connectfour.h>
#include <games/connectfour.h>

TEST(Connectfour, init_1) {

Expand All @@ -23,12 +23,12 @@ TEST(Connectfour, init_1) {
ASSERT_EQ(GameStatus::player0Turn, GameStatus(state.getCurrentPlayer()));

for (int i=0; i<7; ++i) {
auto a_i = std::dynamic_pointer_cast<ActionForConnectFour>(state.GetLegalActions()[i]);
ASSERT_EQ(i, a_i->GetX());
ASSERT_EQ(0, a_i->GetY());
ASSERT_EQ(0, a_i->GetZ());
ASSERT_EQ(i, a_i->GetHash());
ASSERT_EQ(i, a_i->GetIndex());
auto a_i = state.GetLegalActions()[i];
ASSERT_EQ(i, a_i.GetX());
ASSERT_EQ(0, a_i.GetY());
ASSERT_EQ(0, a_i.GetZ());
ASSERT_EQ(0, a_i.GetHash());
ASSERT_EQ(i, a_i.GetIndex());
}

std::vector<float> expectedFeatures {
Expand Down Expand Up @@ -71,26 +71,25 @@ TEST(Connectfour, init_1) {

}


TEST(Connectfour, play_1) {

StateForConnectFour state(0);
state.Initialize();

ActionForConnectFour action(1, 7);
_Action action(0, 1, 0, 0);
state.ApplyAction(action);

ASSERT_EQ((std::vector<int64_t>{3, 6, 7}), state.GetFeatureSize());
ASSERT_EQ((std::vector<int64_t>{7, 1, 1}), state.GetActionSize());
ASSERT_EQ(GameStatus::player1Turn, GameStatus(state.getCurrentPlayer()));

for (int i=0; i<7; ++i) {
auto a_i = std::dynamic_pointer_cast<ActionForConnectFour>(state.GetLegalActions()[i]);
ASSERT_EQ(i, a_i->GetX());
ASSERT_EQ(0, a_i->GetY());
ASSERT_EQ(0, a_i->GetZ());
ASSERT_EQ(i, a_i->GetHash());
ASSERT_EQ(i, a_i->GetIndex());
auto a_i = state.GetLegalActions()[i];
ASSERT_EQ(i, a_i.GetX());
ASSERT_EQ(0, a_i.GetY());
ASSERT_EQ(0, a_i.GetZ());
ASSERT_EQ(0, a_i.GetHash());
ASSERT_EQ(i, a_i.GetIndex());
}

std::vector<float> expectedFeatures {
Expand Down Expand Up @@ -131,4 +130,3 @@ TEST(Connectfour, play_1) {

}


Loading