Creating a Tree Search Library

I have recently simulated a few games using Python code; see the articles with the Game Simulation Toolbox tag. All the code is in one repository, but I haven't really written good code. It was mostly just experimenting, and I did not work as cleanly as I otherwise do. My latest addition, the backtracking with Railroad Ink, ended up as completely hacky code: control flow on the module level, no clear separation, huge functions, little use of classes. It is a total mess.

Eventually I had everything in place to write that article and generate the video. But it turned out that the backtracking algorithm just isn't the best choice for this problem. I then wanted to try a random walk to generate funny videos. I also wanted to try Monte Carlo Tree Search (MCTS). In my sandbox I have already implemented MCTS for Nidavellir and also wanted to use it for Tic-Tac-Toe at some point. The random walk is trivial to implement, and I had implemented beam search for Scythe. But I could not just mix and match the algorithms with the different games.

Prototyping without having to spend time on working cleanly can be fun. For a while. Now it is getting on my nerves, so I will refactor it. The basic idea is that tree search algorithms and games are independent of each other, and the design of the software should reflect this. Therefore I propose to refactor it like this:

So instead of having to implement every algorithm for every game, I just have to implement each game once and each algorithm once. From there on, I can combine them as I like. But how do we get there? We need interfaces for the tree search and for the iteration that fulfill all the respective needs. Perhaps it is best to start with a few use cases to figure out what we really need.

Use cases

Taking the existing blog posts, we can extract a few use cases. The latest one is using backtracking to explore the state space of the Railroad Ink game. It was graph exploration, but what exactly was I searching for? For the video, I just wanted to go through all the combinations depth-first. It was a simple tree traversal with the side effect that the iterator would render an image whenever it moved. From this we can gather a few requirements:

  • The iterator needs to present the possible child nodes to the tree search algorithm.
  • Either the iterator allows going back to the parent, or it allows storing iterators so that one can return to any position in the tree. A backtracking algorithm would then keep a stack of the parents and restore the state to a parent whenever it needs to backtrack.
  • The iterator must be notified when it goes back to the parent such that it can create a picture for that too.
  • Somehow the iterator needs to signal that it is on a terminal node. This could also be done by returning an empty list of children.

Then for MCTS with Nidavellir I need to specify a function that assigns “win” or “loss” to terminal nodes. This will then be used to grade the nodes of the tree that have already been explored. One essential trait of MCTS is that it can restart from any node it has already explored, so the iterator needs to be settable to any state and continue from there. This gives us more requirements:

  • The iterator needs to support random access to the states of the system, not just going back to a parent.
  • We need a binary value function for terminal nodes, which we can use to compute values for all nodes.
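To make the second requirement concrete, here is a minimal sketch of how binary terminal outcomes could be turned into values for inner nodes, in the style of MCTS backpropagation. The `Node` class and the `backpropagate` function are hypothetical names, not part of the library described here; a node's value is simply the fraction of wins among all playouts that passed through it.

```python
from typing import Optional


class Node:
    """Hypothetical search-tree node that accumulates playout statistics."""

    def __init__(self, name: str, parent: Optional["Node"] = None):
        self.name = name
        self.parent = parent
        self.wins = 0
        self.visits = 0

    def value(self) -> float:
        # Fraction of playouts through this node that ended in a win.
        return self.wins / self.visits if self.visits > 0 else 0.0


def backpropagate(leaf: Node, won: bool) -> None:
    """Propagate a binary terminal outcome from a leaf up to the root."""
    node: Optional[Node] = leaf
    while node is not None:
        node.visits += 1
        if won:
            node.wins += 1
        node = node.parent
```

With two playouts through a child `A` (one win, one loss) and one winning playout through a child `B`, `A` gets the value 0.5, `B` gets 1.0, and the root gets 2/3. For a multi-player game like Nidavellir the outcome would additionally have to be taken from the perspective of the player to move, which this sketch ignores.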

For a beam search with Scythe I need some idea of which beams to favor. This means that I need a number-valued value function that assigns a number to each state. Having this function learned automatically would be better, but that is not easy for such a game.

  • General value functions must also be supported. They must not only rank the children of a node, but be able to rank all nodes on a given level of the tree. With MCTS it would also be good to have an absolute ranking across all levels.
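A value function interface along these lines might look as follows. This is a sketch under assumptions: `ValueFunction`, `rank_level`, and the toy `NameLength` function are hypothetical, and states are plain strings so that the example runs on its own. The point is that the ranking takes an arbitrary collection of states from one level, not just the children of a single node.

```python
import abc
from typing import Iterable, List


class ValueFunction(abc.ABC):
    """Hypothetical interface: assigns a number to any state, higher is better."""

    @abc.abstractmethod
    def evaluate(self, state: str) -> float:
        ...


class NameLength(ValueFunction):
    """Toy value function for demonstration only: longer names score higher."""

    def evaluate(self, state: str) -> float:
        return float(len(state))


def rank_level(states: Iterable[str], value_function: ValueFunction) -> List[str]:
    """Order all states of one tree level from best to worst."""
    return sorted(states, key=value_function.evaluate, reverse=True)
```

For example, `rank_level(["Other", "Breadth-First", "Depth-First"], NameLength())` returns `["Breadth-First", "Depth-First", "Other"]`. Because the value function is a separate object, a beam search or MCTS could be handed a different one without any change to the search code.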

The full search would be easy: either set the beam width to infinity to get a breadth-first search, or let backtracking run until the tree is exhausted, which is a depth-first search. Backtracking would also benefit from a value function, for instance to decide which child to try first.
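The claim that an unlimited beam is a breadth-first search can be checked with a small sketch. It is not the Scythe beam search: it works on the nested-dict tree used in the backtracking test below, and the function `beam_search` and its `(name, children)` state tuples are assumptions made for this example. With `beam_width=None` every node on a level is kept, so the levels come out exactly as in a breadth-first traversal.

```python
from typing import Any, Callable, Dict, List, Optional, Tuple

# A state is a (name, children) pair, mirroring the nested dict used in the test.
State = Tuple[str, Dict[str, Any]]


def beam_search(
    root: State,
    value: Callable[[State], float],
    beam_width: Optional[int] = None,
) -> List[List[str]]:
    """Return the names kept on each level; beam_width=None keeps all (BFS)."""
    levels: List[List[str]] = []
    beam = [root]
    while beam:
        levels.append([name for name, _ in beam])
        # Expand every state in the current beam into its children.
        candidates = [
            (name, children)
            for _, parent_children in beam
            for name, children in sorted(parent_children.items())
        ]
        # Keep only the best candidates of this level if a width is given.
        if beam_width is not None:
            candidates = sorted(candidates, key=value, reverse=True)[:beam_width]
        beam = candidates
    return levels
```

With a width of 1 and the name length as a toy value function, the search follows only "Algorithms" and then "Breadth-First", which is a leaf, so it stops after three levels.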

More design elements

The value function seems to be independent of the traversal. I would like to use different ones. Therefore that needs to be done with an interface as well.

I also need something to which interesting states can be passed. For the video use case I want to observe the traversal. This could be done as a side effect of the iterator. But perhaps I should instead pass a callback function, a traversal observer? The observer could then determine whether a state is worth recording. Perhaps collecting good episodes and just observing the traversal are two different things. Maybe those should be separate elements.
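If the two concerns are split, a search algorithm could still take a single observer, and a small composite could fan the states out to both. This is only a sketch of that idea: `TraversalLog`, `EpisodeCollector`, and `CompositeObserver` are hypothetical names, the `Observer` class is a minimal stand-in, and states are plain strings to keep it self-contained.

```python
from typing import Callable, List, Sequence


class Observer:
    """Minimal stand-in for the observer interface: receives every visited state."""

    def observe(self, state: str) -> None:
        raise NotImplementedError


class TraversalLog(Observer):
    """Watches the traversal: records every state, e.g. to render a frame."""

    def __init__(self) -> None:
        self.states: List[str] = []

    def observe(self, state: str) -> None:
        self.states.append(state)


class EpisodeCollector(Observer):
    """Keeps only the states that a predicate considers worth recording."""

    def __init__(self, is_interesting: Callable[[str], bool]) -> None:
        self.is_interesting = is_interesting
        self.kept: List[str] = []

    def observe(self, state: str) -> None:
        if self.is_interesting(state):
            self.kept.append(state)


class CompositeObserver(Observer):
    """Forwards each state to several independent observers."""

    def __init__(self, observers: Sequence[Observer]) -> None:
        self.observers = list(observers)

    def observe(self, state: str) -> None:
        for observer in self.observers:
            observer.observe(state)
```

The search algorithm only ever sees one `Observer`, so whether the two roles end up separate or merged does not leak into the traversal code.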

This is how I think of the design at this point:
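The three interfaces that the following code relies on (`TreeIterator`, `Observer`, `TreeSearch`) could be declared as abstract base classes roughly like this. The method names are the ones the later code calls; the default `is_terminal()` that checks for an empty child list is an assumption taken from the requirement list above, not necessarily how the library does it.

```python
import abc
from typing import Iterable


class TreeIterator(abc.ABC):
    """A position in the game tree that can enumerate its successors."""

    @abc.abstractmethod
    def get_children(self) -> Iterable["TreeIterator"]:
        """Yield the states reachable in one move; no children means terminal."""

    def is_terminal(self) -> bool:
        # Assumed default: a state is terminal when it has no children.
        return next(iter(self.get_children()), None) is None


class Observer(abc.ABC):
    """Receives states as a tree search visits them."""

    @abc.abstractmethod
    def observe(self, state: TreeIterator) -> None:
        ...


class TreeSearch(abc.ABC):
    """A traversal strategy that drives a TreeIterator and reports to an Observer."""

    @abc.abstractmethod
    def run(self, start: TreeIterator) -> None:
        ...
```

Games only implement `TreeIterator`, algorithms only implement `TreeSearch`, and neither needs to know about the other.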

Implementing backtracking

I will start by implementing one of the algorithms, backtracking, and try to do that in a test-driven way. So what do I need for that? First we need a test tree that we can go through. Let us take this example:

When we go through that with backtracking, it will go deeper until it is at a leaf. Then it goes back up, and tries the next child. Using a mock observer, we can then write the following test case:

def test_backtracking() -> None:
    tree = {
        "Algorithms": {
            "Breadth-First": {},
            "Depth-First": {"Backtracking": {}, "Beam Search": {}},
            "Other": {"Monte Carlo Tree Search": {}},
        }
    }

    iterator = DictTreeIterator("Root", tree)
    observer = MockObserver()
    tree_search = Backtracking(observer)
    tree_search.run(iterator)

    assert [it.name for it in observer.states] == [
        "Root",
        "Algorithms",
        "Breadth-First",
        "Algorithms",
        "Depth-First",
        "Backtracking",
        "Depth-First",
        "Beam Search",
        "Depth-First",
        "Algorithms",
        "Other",
        "Monte Carlo Tree Search",
        "Other",
        "Algorithms",
        "Root",
    ]

The mock objects are not very fancy. The dictionary tree iterator looks like this:

from typing import Any, Dict, Generator


class DictTreeIterator(TreeIterator):
    def __init__(self, name: str, children: Dict[str, Any]):
        # The name is a plain attribute; the test reads it as `it.name`.
        self.name = name
        self.children = children

    def is_terminal(self) -> bool:
        # A node without children is a leaf.
        return len(self.children) == 0

    def get_children(self) -> Generator["DictTreeIterator", None, None]:
        # Sorted keys make the traversal order deterministic for the test.
        for key, value in sorted(self.children.items()):
            yield DictTreeIterator(key, value)

And the mock observer is trivial; it simply records all the iterators that were ever passed to it:

class MockObserver(Observer):
    def __init__(self):
        self.states = []

    def observe(self, state: TreeIterator) -> None:
        self.states.append(state)

The Backtracking class is just a stub at this moment:

class Backtracking(TreeSearch):
    def __init__(self, observer: Observer):
        self.observer = observer

    def run(self, start: TreeIterator) -> None:
        pass

And unsurprisingly the test fails, because the observer never observed anything, so the list is still empty. Pytest words this like so:

E       AssertionError: assert [] == ['Root', 'Alg...racking', ...]
E         Right contains 15 more items, first extra item: 'Root'
E         Use -v to get the full diff

Now we can start to write the algorithm. And it turns out that the algorithm is really simple when written in terms of these iterators and observers:

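class StackEntry:
    """A state on the stack together with an iterator over its remaining children."""

    def __init__(self, state: TreeIterator):
        self.state = state
        self.child_iterator = iter(state.get_children())


class Backtracking(TreeSearch):
    def __init__(self, observer: Observer):
        self.observer = observer

    def run(self, start: TreeIterator) -> None:
        stack = [StackEntry(start)]

        while len(stack) > 0:
            current = stack[-1]
            # Report the state every time we arrive at it, also when returning from a child.
            self.observer.observe(current.state)
            try:
                # Descend into the next unvisited child.
                next_child = next(current.child_iterator)
                stack.append(StackEntry(next_child))
            except StopIteration:
                # All children are done: backtrack to the parent.
                stack.pop()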

We just need to keep a stack of states, each with an iterator over its children. When there are children left, we dive deeper into the tree. When there are no more children, we back out one level. We do that until we have exhausted the whole tree. Done.
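For comparison, here is the same traversal written with recursion, over the same nested-dict tree as the test above. It is only an illustration: the function `backtrack` is hypothetical and records names into a list instead of calling an observer. It produces exactly the order the test expects, because each node is recorded on entry and again after returning from each child.

```python
from typing import Any, Dict, List


def backtrack(name: str, children: Dict[str, Any], visited: List[str]) -> None:
    """Recursive counterpart of the stack-based Backtracking.run()."""
    visited.append(name)  # observe the node when entering it
    for child_name, grandchildren in sorted(children.items()):
        backtrack(child_name, grandchildren, visited)
        visited.append(name)  # observe the node again after returning from a child
```

The explicit stack in `Backtracking.run()` plays the role of the call stack here. Its advantage is that it is not bounded by Python's recursion limit, which matters for deep game trees.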

It is one of those instances where I am pretty dazzled by how easy programming becomes with a decent design and when working test-driven, or at least with testability in mind.

Refactoring Railroad Ink code

Now that we have the backtracking algorithm in an isolated fashion, we can go ahead and refactor the existing code from the blog post about Railroad Ink. Let us have a look at the search code there:

def do_step(i, j):
    if i <= 0 or j <= 0 or i >= board_size - 1 or j >= board_size - 1:
        return
    if board[i][j] is not None:
        return

    for tile in available_tiles:
        fits = [
            tile.fits(board[i][j - 1], Direction.LEFT),
            tile.fits(board[i][j + 1], Direction.RIGHT),
            tile.fits(board[i - 1][j], Direction.UP),
            tile.fits(board[i + 1][j], Direction.DOWN),
        ]

        if all(fit != FitType.INCOMPATIBLE for fit in fits) and any(
            fit == FitType.MATCHES for fit in fits
        ):
            board[i][j] = tile
            img_board[i * 100 : (i + 1) * 100, j * 100 : (j + 1) * 100] = tile.image
            open_spots = get_open(board)
            print_board()
            for ii, jj in sorted(open_spots):
                do_step(ii, jj)
            board[i][j] = None
            img_board[i * 100 : (i + 1) * 100, j * 100 : (j + 1) * 100] = img_empty_tile
            get_open(board)
            print_board()

It is backtracking implemented with recursion, so there is no explicit stack. There are two external functions that are called. The get_open() function returns the open spots, which are the children in a graph view. It also updates img_board as a side effect. That function is here:

def get_open(board):
    result = set()
    for i in range(1, board_size - 1):
        for j in range(1, board_size - 1):
            if board[i][j] is not None:
                continue
            if (
                has_direction(board[i - 1][j], Direction.DOWN)
                or has_direction(board[i][j + 1], Direction.LEFT)
                or has_direction(board[i][j - 1], Direction.RIGHT)
                or has_direction(board[i + 1][j], Direction.UP)
            ):
                result.add((i, j))

            img_board[i * 100 : (i + 1) * 100, j * 100 : (j + 1) * 100] = (
                img_open if (i, j) in result else img_empty_tile
            )
    return result

The has_direction() function is just a little helper, which we will keep. The print_board() from above just saves the current image to a file. There is some other setup code, but that is not really important.

There is one subtlety that I would like to point out: mingling the tree search with the image generation allows me to switch out only the changed parts. This way the image does not need to be built up from scratch every time a new state is encountered. You can see in the code that the calls to do_step() are interleaved with slice assignments to img_board and calls to print_board(). It is a nice performance optimization that I would like to keep.

Yet it is not clear how to keep it once we move that part into an “observer”. The observer would only get the current state, so it would need to remember the previous state in order to compute the difference between the two. For backtracking we know that consecutive states differ in only one tile. But with MCTS they could be completely different. And with a random walk it would need to start with a fresh image after a reset. Computing this difference may be about as expensive as just copying the 100×100 pixel arrays. For the time being I have simply cut it.
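For reference, the difference computation itself could be as simple as the following sketch. Boards are assumed to be lists of rows holding tile identifiers (strings here, `None` for an empty cell), and `changed_cells`, `update_image`, and the `tile_images` lookup are hypothetical names. Whether this beats re-rendering depends on how cheap the comparison is relative to the 100×100 copies, which is exactly the open question above; the sketch only shows that after a reset, when there is no previous board, it falls back to redrawing everything.

```python
from typing import Dict, List, Optional, Set, Tuple

import numpy as np

TILE = 100  # tile size in pixels, as in the article

Grid = List[List[Optional[str]]]  # hypothetical: tile names, None for empty cells


def changed_cells(previous: Optional[Grid], current: Grid) -> Set[Tuple[int, int]]:
    """Positions that differ between two boards; all positions if there is no previous board."""
    return {
        (i, j)
        for i, row in enumerate(current)
        for j, tile in enumerate(row)
        if previous is None or previous[i][j] != tile
    }


def update_image(
    img: np.ndarray,
    current: Grid,
    cells: Set[Tuple[int, int]],
    tile_images: Dict[Optional[str], np.ndarray],
) -> None:
    """Redraw only the given cells of an existing board image, in place."""
    for i, j in cells:
        img[i * TILE : (i + 1) * TILE, j * TILE : (j + 1) * TILE] = tile_images[current[i][j]]
```

Since `Board.replace()` reuses the tile objects of untouched cells, a real implementation could compare with `is not` instead of `!=`, which makes the per-cell check very cheap.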

The code looks much better now, and each part has a clearer purpose.

There is a Board class which holds the state. It can create copies of itself. That may be a little wasteful, but it will enable me to do random jumps without having to worry about carefully undoing steps like I did with the old code.

import copy
from typing import Generator, List, Optional, Tuple


class Board:
    def __init__(self):
        self.available_tiles = available_tiles
        self.board_size = 9
        self.board: List[List[Optional[Tile]]] = []
        for i in range(self.board_size):
            self.board.append([None] * self.board_size)

        self.board[2][0] = exit_rail_right
        self.board[4][0] = exit_road_right
        self.board[6][0] = exit_rail_right

        self.board[0][2] = exit_road_up
        self.board[0][4] = exit_rail_up
        self.board[0][6] = exit_road_up

        self.board[2][-1] = exit_rail_left
        self.board[4][-1] = exit_road_left
        self.board[6][-1] = exit_rail_left

        self.board[-1][2] = exit_road_down
        self.board[-1][4] = exit_rail_down
        self.board[-1][6] = exit_road_down

    def get_open_positions(self) -> Generator[Tuple[int, int], None, None]:
        for i in range(1, self.board_size - 1):
            for j in range(1, self.board_size - 1):
                if self.board[i][j] is not None:
                    continue
                if (
                    has_direction(self.board[i - 1][j], Direction.DOWN)
                    or has_direction(self.board[i][j + 1], Direction.LEFT)
                    or has_direction(self.board[i][j - 1], Direction.RIGHT)
                    or has_direction(self.board[i + 1][j], Direction.UP)
                ):
                    yield i, j

    def replace(self, i: int, j: int, tile: Tile) -> "Board":
        result = copy.copy(self)
        result.board = [copy.copy(row) for row in self.board]
        result.board[i][j] = tile
        return result
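The two-step copy in `replace()` matters: `copy.copy(self)` alone would share the list of rows, so setting a tile would also change the parent state. A reduced sketch (the `MiniBoard` class is hypothetical) shows that the original stays untouched, which is what makes random jumps safe. The tile objects themselves are shared between boards, which is fine as long as they are treated as immutable.

```python
import copy
from typing import List, Optional


class MiniBoard:
    """Hypothetical stripped-down board to show the copy semantics of replace()."""

    def __init__(self, size: int) -> None:
        # One list per row; [[None] * size] * size would alias all rows.
        self.board: List[List[Optional[str]]] = [[None] * size for _ in range(size)]

    def replace(self, i: int, j: int, tile: str) -> "MiniBoard":
        result = copy.copy(self)  # new object, but result.board is still the same list
        result.board = [copy.copy(row) for row in self.board]  # copy each row
        result.board[i][j] = tile
        return result
```

After `b = a.replace(1, 1, "X")`, `a.board[1][1]` is still `None`, while `b.board[1][1]` is `"X"`.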

The matching iterator is pretty simple as well. The children are constructed from all free spots and all available tiles which match.

class RailroadInkIterator(TreeIterator):
    def __init__(self, board: Board):
        self.board = board

    def get_children(self) -> Generator["RailroadInkIterator", None, None]:
        board = self.board.board
        for i, j in self.board.get_open_positions():
            for tile in self.board.available_tiles:
                fits = [
                    tile.fits(board[i][j - 1], Direction.LEFT),
                    tile.fits(board[i][j + 1], Direction.RIGHT),
                    tile.fits(board[i - 1][j], Direction.UP),
                    tile.fits(board[i + 1][j], Direction.DOWN),
                ]

                if all(fit != FitType.INCOMPATIBLE for fit in fits) and any(
                    fit == FitType.MATCHES for fit in fits
                ):
                    yield RailroadInkIterator(self.board.replace(i, j, tile))

The image-generating part is moved into the video observer. It renders each image from scratch, but this allowed me to decouple this part.

import pathlib

import numpy as np
import PIL.Image


class VideoObserver(Observer):
    def __init__(self):
        self.steps = 0

        self.img_empty = np.ones((100, 100), np.uint8) * 255
        self.img_empty_tile = np.array(
            PIL.Image.open(pathlib.Path(__file__).parent / "tiles" / "empty tile.png")
        )
        self.img_open = np.array(
            PIL.Image.open(pathlib.Path(__file__).parent / "tiles" / "open.png")
        )

    def observe(self, state: TreeIterator) -> None:
        assert isinstance(state, RailroadInkIterator)

        board_size = state.board.board_size
        open_positions = set(state.board.get_open_positions())
        img_board = np.zeros((board_size * 100, board_size * 100), np.uint8)
        for i in range(board_size):
            for j in range(board_size):
                if state.board.board[i][j] is None:
                    if i == 0 or j == 0 or i == board_size - 1 or j == board_size - 1:
                        img = self.img_empty
                    elif (i, j) in open_positions:
                        img = self.img_open
                    else:
                        img = self.img_empty_tile
                else:
                    img = state.board.board[i][j].image
                img_board[i * 100 : (i + 1) * 100, j * 100 : (j + 1) * 100] = img

        pil_image = PIL.Image.fromarray(img_board)
        pil_image.save(f"railroad/{self.steps:06d}.png", "PNG")

        self.steps += 1
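        if self.steps >= 100:
            # Crude way to end the search after 100 frames: abort via an exception.
            raise RuntimeError("Frame limit of 100 reached")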
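Raising a `RuntimeError` after 100 frames works, but since `main()` does not catch it, the program ends with a traceback, and the stop is indistinguishable from a real error. A possible alternative is a dedicated exception that the caller catches. This is a sketch, not part of the library: `StopSearch`, `FrameLimit`, and `run_until_stopped` are hypothetical names.

```python
from typing import Callable


class StopSearch(Exception):
    """Signals that the search should end early; not an error."""


class FrameLimit:
    """Hypothetical observer that stops the search after a fixed number of frames."""

    def __init__(self, max_frames: int) -> None:
        self.max_frames = max_frames
        self.frames = 0

    def observe(self, state: object) -> None:
        self.frames += 1  # a real observer would render the frame here
        if self.frames >= self.max_frames:
            raise StopSearch()


def run_until_stopped(run: Callable[[], None]) -> bool:
    """Run a search; return True if it finished, False if an observer stopped it."""
    try:
        run()
    except StopSearch:
        return False
    return True
```

In the main function this would be used as `run_until_stopped(lambda: tree_search.run(iterator))`, and genuine errors would still propagate as before.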

The main function is now really simple and a joy to read:

def main():
    pathlib.Path("railroad").mkdir(exist_ok=True)
    board = Board()
    iterator = RailroadInkIterator(board)
    observer = VideoObserver()
    tree_search = Backtracking(observer)
    tree_search.run(iterator)

That's it! It is much cleaner. And now I can implement more tree search algorithms and generate more videos.

Adding a random walk

At the beginning of this article I claimed that the random walk strategy is trivial. I meant the graph traversal; the graph itself may be hard. But that part is already done, so implementing the random walk actually is trivial.

All I need to do is to pick one of the children at random. That's it. And this is the code:

import random


class RandomWalk(TreeSearch):
    def __init__(self, observer: Observer):
        self.observer = observer

    def run(self, start: TreeIterator) -> None:
        state = start
        while True:
            self.observer.observe(state)
            children = list(state.get_children())
            if len(children) == 0:
                return
            state = random.choice(children)
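A single walk ends as soon as it reaches a terminal state, so a longer video needs the walk to be restarted until enough frames exist. The loop can be sketched like this. It is a self-contained stand-in rather than the actual main module: the walk operates on the nested-dict tree from the backtracking test, frames are just state names, and `random_walk` and `collect_frames` are hypothetical helpers. Passing a seeded `random.Random` makes a video reproducible.

```python
import random
from typing import Any, Dict, List, Tuple

State = Tuple[str, Dict[str, Any]]  # (name, children) as in the nested-dict test tree


def random_walk(start: State, rng: random.Random, frames: List[str]) -> None:
    """One walk from the start state to a terminal state, recording each frame."""
    name, children = start
    while True:
        frames.append(name)
        if not children:
            return
        name, children = rng.choice(sorted(children.items()))


def collect_frames(start: State, budget: int, seed: int = 0) -> List[str]:
    """Restart the walk until at least `budget` frames exist, then truncate."""
    rng = random.Random(seed)
    frames: List[str] = []
    while len(frames) < budget:
        random_walk(start, rng, frames)
    return frames[:budget]
```

Every restart begins again at the root, so in the frame sequence each root frame after the first is preceded by a leaf.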

All I have to change in the main module is to create an instance of RandomWalk instead of Backtracking. Done. I let that run repeatedly until I had 1000 frames of video. This is the result:

One can see that the random walk just randomly selects pieces which fit. This usually leads to empty spots that cannot be filled with the available pieces, like this final configuration:

There is just no piece to fill the gap. Also the top corners are empty. There is a fairly large network that connects 7 of the 12 exits. That's okay, but it is not good. There are five exits that are not connected to any other one, and these would not give any points in the game.

Another example is this final configuration, which only connects three exits in the largest cluster:

So it is fun to watch, but it doesn't optimize anything.
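Counting how many exits the largest network connects would be one way to turn these observations into the number-valued value function discussed earlier. A sketch with union-find, assuming the board has already been translated into a list of connections between exits and tile positions; that translation, the function name `largest_exit_cluster`, and the node labels are hypothetical, and the rail/road distinction is ignored.

```python
from typing import Dict, Hashable, Iterable, Set, Tuple


def largest_exit_cluster(
    edges: Iterable[Tuple[Hashable, Hashable]], exits: Set[Hashable]
) -> int:
    """Number of exits in the largest connected component, via union-find."""
    parent: Dict[Hashable, Hashable] = {}

    def find(x: Hashable) -> Hashable:
        parent.setdefault(x, x)
        while parent[x] != x:
            parent[x] = parent[parent[x]]  # path halving keeps the trees flat
            x = parent[x]
        return x

    for a, b in edges:
        parent[find(a)] = find(b)  # merge the two components

    counts: Dict[Hashable, int] = {}
    for exit_node in exits:
        root = find(exit_node)
        counts[root] = counts.get(root, 0) + 1
    return max(counts.values(), default=0)
```

For the first final configuration above, such a function would report 7.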

Conclusion and outlook

I have now refactored the Railroad Ink code so that it uses my new tree search library, in which I have implemented two tree search algorithms: backtracking and random walk.

Going forward, I can also refactor the code for the other games that I have implemented, and I can add more algorithms to the library. Maybe my next step in this project will involve Monte Carlo Tree Search, an algorithm I have wanted to try out for a long time now.