Project Euler Solution 92: Square Digit Chains

In Project Euler Problem 92 we are once again asked to play with the digits of numbers.

When we take a number like 44, we can decompose it into the digits 4 and 4, square each (4² = 16) and add the squares up. We end with 32. Doing it again gives 3² + 2² = 13, then 1² + 3² = 10 and finally 1² + 0² = 1. According to the problem statement, every starting number eventually arrives at either 1 or 89, and from there on the chain repeats: 1 is a fixed point because 1² = 1, and 89 lies on the cycle 89 → 145 → 42 → 20 → 4 → 16 → 37 → 58 → 89.
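
To check the chains mentioned above, here is a small sketch. The helpers digit_square_sum and chain are my own names for illustration; they are not part of the problem statement.

```python
def digit_square_sum(number: int) -> int:
    """Return the sum of the squares of the decimal digits of `number`."""
    return sum(int(digit) ** 2 for digit in str(number))


def chain(start: int) -> list[int]:
    """Follow the chain from `start` until it reaches 1 or 89 (inclusive)."""
    numbers = [start]
    while numbers[-1] not in (1, 89):
        numbers.append(digit_square_sum(numbers[-1]))
    return numbers


print(chain(44))  # [44, 32, 13, 10, 1]
print(chain(85))  # [85, 89]

# 1 maps to itself, while 89 sits on a cycle of eight numbers.
cycle = [89]
while (following := digit_square_sum(cycle[-1])) != 89:
    cycle.append(following)
print(digit_square_sum(1), cycle)  # 1 [89, 145, 42, 20, 4, 16, 37, 58]
```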

The question now is how many numbers below ten million end up at 89.

My approach is brute force with caching. We create a dictionary that maps each number to the number its chain terminates in (1 or 89). Once a chain has reached either 1 or 89, we also store that result for all the intermediate numbers, so that later chains passing through them can stop early.

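# Maps a number to the end of its chain (1 or 89).
terminators = {}


def get_terminator(number: int) -> int:
    passed = []  # numbers visited on the way, cached at the end
    while True:
        if number == 1 or number == 89:
            break
        if number in terminators:
            # Known number: jump straight to its terminator.
            number = terminators[number]
            break
        passed.append(number)
        number = sum(int(digit) ** 2 for digit in str(number))
    # Every number on this chain ends where this one ends.
    for p in passed:
        terminators[p] = number
    return number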

Then we just need to go through all the numbers and count how many of them end at 89:

def solution() -> int:
    terminates_in_89 = 0
    for number in range(1, 10_000_000):
        terminator = get_terminator(number)
        if terminator == 89:
            terminates_in_89 += 1
    return terminates_in_89

This solution takes 42 s to run, so that's not very good.

Incremental improvements

The conversions from number to string and back might be slow. Let's stick to integer arithmetic and replace the line that computes the next number with this:

        digits = number
        number = 0
        while digits:
            number += (digits % 10) ** 2
            digits //= 10

This makes it a little faster: it now takes 36 s.
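
A sketch to check that both ways of computing the digit square sum agree, and to time them on your own machine. The names square_sum_str and square_sum_int are hypothetical, and the divmod-based loop is a slight variation of the snippet above; whether it beats separate % and // operations is something to measure rather than assume.

```python
import timeit


def square_sum_str(number: int) -> int:
    # Convert to a string and back for every digit.
    return sum(int(digit) ** 2 for digit in str(number))


def square_sum_int(number: int) -> int:
    # Peel off the last digit with divmod; no string conversion involved.
    total = 0
    while number:
        number, digit = divmod(number, 10)
        total += digit * digit
    return total


# Both variants must agree; the timings depend on the machine and Python version.
assert all(square_sum_str(n) == square_sum_int(n) for n in range(100_000))
print(timeit.timeit(lambda: square_sum_str(9_876_543), number=100_000))
print(timeit.timeit(lambda: square_sum_int(9_876_543), number=100_000))
```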

Next we can replace the dictionary with a list of fixed size. Then we don't need dictionary lookups but can index the list directly. This works because every number we ever visit is below ten million:

terminators = [None] * 10_000_000  # None marks numbers whose terminator is unknown


def get_terminator(number: int) -> int:
    passed = []
    while True:
        if number == 1 or number == 89:
            break
        if terminators[number]:
            number = terminators[number]
            break
        passed.append(number)
        digits = number
        number = 0
        while digits:
            number += (digits % 10) ** 2
            digits //= 10
    for p in passed:
        terminators[p] = number
    return number

That brings it down to 32 s, which is only a little faster and shows how well optimized the Python dictionary already is.

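One can also write the function tail-recursively and use the cache from functools. Python does not optimize tail calls, but the chains are short, so the recursion depth stays small: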

import functools


@functools.cache
def get_terminator(number: int) -> int:
    if number == 1 or number == 89:
        return number
    else:
        digits = number
        number = 0
        while digits:
            number += (digits % 10) ** 2
            digits //= 10
        return get_terminator(number)

This takes 31 s, so there is no noticeable difference from the manual implementation.

Removing permutations

Once one has the correct result, one can look into the problem's discussion thread and see what ideas other people came up with. One insight there is that permutations of a number's digits don't matter: the sum of the squared digits only depends on which digits occur and how often.

Instead of running the chain for every number, we can first normalize each number by sorting its digits, so that all permutations map to the same representative:

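def normalize_number(number: int) -> int:
    # Sort the digits, e.g. 4301 -> "0134" -> 134. Dropping the leading zeros
    # is harmless because 0² = 0.
    return int("".join(sorted(str(number))))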

This way most lookups hit the cache, because all permutations of a number share one entry, and the run time drops to only 11 s. So this really makes a difference.
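
The text does not show where normalize_number is called. One plausible wiring, assuming it is applied right before the cached lookup, is sketched below; count_89 and digit_square_sum are hypothetical names, and get_terminator is the recursive, functools-cached version from above.

```python
import functools


def digit_square_sum(number: int) -> int:
    """Sum of the squared decimal digits, using integer arithmetic only."""
    total = 0
    while number:
        total += (number % 10) ** 2
        number //= 10
    return total


def normalize_number(number: int) -> int:
    # All permutations of the digits map to the same sorted representative.
    return int("".join(sorted(str(number))))


@functools.cache
def get_terminator(number: int) -> int:
    if number == 1 or number == 89:
        return number
    return get_terminator(digit_square_sum(number))


def count_89(limit: int) -> int:
    """Count n in [1, limit) whose chain ends at 89, normalizing n before the lookup."""
    return sum(1 for n in range(1, limit) if get_terminator(normalize_number(n)) == 89)


print(count_89(10_000))
```

With this wiring, count_89(10_000_000) performs the same count as the earlier solution(), while the cache only ever stores one entry per digit multiset plus the small intermediate values.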

Enumerating all digit contents

One can go one step further. We are concerned with all numbers smaller than ten million. If we pad them with leading zeros, these are all 7-digit strings made up of the digits 0 to 9, except for 0000000.

Let's say that a given number $n$ is made up of $k_d$ occurrences of the digit $d$. Since every number has 7 digits (counting leading zeros), the counts must satisfy $\sum_{d = 0}^9 k_d = 7$. This means that we cannot choose all $k_d$ independently, but that's okay.
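
To make the digit content concrete, here is a sketch (digit_content is a hypothetical helper) that pads a number to 7 digits and counts each digit. It also shows that the digit square sum can be computed from the counts $k_d$ alone.

```python
from collections import Counter


def digit_content(number: int, width: int = 7) -> list[int]:
    """Return [k_0, ..., k_9] for `number` padded with leading zeros to `width` digits."""
    counts = Counter(f"{number:0{width}d}")
    return [counts[str(d)] for d in range(10)]


content = digit_content(44)  # 44 is treated as "0000044"
print(content)  # [5, 0, 0, 0, 2, 0, 0, 0, 0, 0]

# The digit square sum only needs the counts, not the order of the digits.
square_sum = sum(k * d**2 for d, k in enumerate(content))
print(square_sum)  # 32
```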

The largest digit square sum comes from seven nines: $7 \cdot 9^2 = 567$. So after the first step, every chain continues from a number between 1 and 567, which is a rather small table to build.
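
Since every first step lands between 1 and 567, the terminators of all those values fit into a plain list. A minimal sketch, with build_table as a hypothetical name, that fills such a table by walking each chain:

```python
def build_table(max_value: int = 7 * 9**2) -> list[int]:
    """Return a list whose entry s (1 <= s <= max_value) is 1 or 89.

    Entry 0 is unused and left at 0, since 0 is not a valid starting number.
    """
    table = [0] * (max_value + 1)
    for start in range(1, max_value + 1):
        number = start
        # Walk the chain until it reaches 1 or 89.
        while number not in (1, 89):
            number = sum(int(digit) ** 2 for digit in str(number))
        table[start] = number
    return table


table = build_table()
print(len(table), table[44], table[145])  # 568 1 89
```

A number $n \geq 1$ then ends at table[s], where s is its digit square sum, so after the first step no further chain walking or caching is needed.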

For each digit content, there are a bunch of permutations that yield different numbers but also permutations that yield the same number. We need to count only the permutations that yield different numbers, otherwise we would overcount. As we're looking at seven-digit numbers, there are $7!$ permutations of the positions. But if the digit $d$ occurs $k_d$ times, the $k_d!$ permutations among its occurrences don't change the number. Therefore the number of distinct numbers $p$ for a given digit content is the multinomial coefficient $$ p = \frac{7!}{\prod_{d=0}^9 k_d!} \,. $$ For example, 44 has the content $k_0 = 5$, $k_4 = 2$, so there are $7!/(5! \, 2!) = 21$ seven-digit strings with this content.
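
The multiplicity formula can be checked by brute force on a smaller case. This sketch (multinomial is a hypothetical helper) uses 3 digits instead of 7, groups all strings 000 to 999 by their digit content, and compares each group size with the formula.

```python
import itertools
import math
from collections import Counter


def multinomial(counts: list[int]) -> int:
    """Number of distinct arrangements of a multiset with the given multiplicities."""
    result = math.factorial(sum(counts))
    for k in counts:
        result //= math.factorial(k)  # stays an exact integer at every step
    return result


# Group all 1000 strings "000" to "999" by their sorted digits (the digit content).
groups = Counter(tuple(sorted(s)) for s in itertools.product("0123456789", repeat=3))
for content, size in groups.items():
    counts = [content.count(str(d)) for d in range(10)]
    assert size == multinomial(counts)
print(len(groups), sum(groups.values()))  # 220 1000
```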

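This means that we can drastically change the code. First we split the terminator computation into two functions to make it a little easier to read. Then we need to make the terminator function handle an input of 0: the content of seven zeros has a digit square sum of 0, and since 0 maps to itself, the recursion would otherwise never end.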

import functools


def get_digit_square_sum(number: int) -> int:
    digits = number
    number = 0
    while digits:
        number += (digits % 10) ** 2
        digits //= 10
    return number


@functools.cache
def get_terminator(number: int) -> int:
    # 0 only arises from the all-zero content, which is not counted.
    if number == 0:
        return 0
    if number == 1 or number == 89:
        return number
    else:
        return get_terminator(get_digit_square_sum(number))

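I use a recursive generator to produce all possible digit counts. It takes the list of counts $k_d$ chosen so far and the number of digits used so far, then tries every possible count for the next digit with the digits that remain. The last count, $k_9$, is fixed so that the total is exactly 7. The elements yielded by this generator are lists of 10 elements, denoting the multiplicities of the digits 0 to 9. Note that the same list object is yielded every time and modified afterwards, so each element has to be used before the generator resumes.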

from collections.abc import Iterator


def iter_digits(counts: list[int], used: int) -> Iterator[list[int]]:
    if len(counts) == 10:
        yield counts
    else:
        # The last digit (9) takes all remaining places, so the total is 7.
        lower = 0 if len(counts) < 9 else 7 - used
        upper = 8 - used
        for k in range(lower, upper):
            counts.append(k)
            yield from iter_digits(counts, used + k)
            counts.pop()
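
As an alternative to the hand-written recursive generator, the standard library's itertools.combinations_with_replacement already enumerates all multisets of 7 digits, each as a sorted tuple. This is a swapped-in technique, not the one used in this post, and iter_digit_contents is a hypothetical name. The sketch checks that there are $\binom{16}{7} = 11440$ digit contents and that their multiplicities add up to $10^7$, the number of 7-character strings over the digits 0 to 9.

```python
import itertools
import math
from collections.abc import Iterator


def iter_digit_contents(num_digits: int = 7) -> Iterator[list[int]]:
    """Yield [k_0, ..., k_9] for every multiset of `num_digits` decimal digits."""
    for digits in itertools.combinations_with_replacement(range(10), num_digits):
        counts = [0] * 10  # a fresh list per content
        for d in digits:
            counts[d] += 1
        yield counts


contents = list(iter_digit_contents())
multiplicities = [
    math.factorial(7) // math.prod(math.factorial(k) for k in ks) for ks in contents
]
print(len(contents), sum(multiplicities))  # 11440 10000000
```

Unlike iter_digits, this generator yields a fresh list each time, so collecting the results with list(...) is safe; doing the same with iter_digits would produce many references to one list that is empty again once the generator is exhausted.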

The actual solution is then just a matter of assembling these pieces. We iterate through all the digit contents that the generator gives us and compute the multiplicity using the factorial formula above. Finally we compute the digit square sum directly from the counts and use our cached, tail-recursive function get_terminator to work out the terminator.

import math


def solution() -> int:
    terminates_in_89 = 0
    for ks in iter_digits([], 0):
        # Number of distinct 7-digit strings with this digit content.
        multiplicity = math.factorial(7)
        for kd in ks:
            multiplicity //= math.factorial(kd)
        # Digit square sum computed directly from the counts.
        terminator = get_terminator(sum(kd * d**2 for d, kd in enumerate(ks)))
        if terminator == 89:
            terminates_in_89 += multiplicity

    return terminates_in_89

This now runs in 112 ms, which is much faster than before. It feels like the right solution to this problem because it is fast and uses some non-trivial insights.
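
To gain confidence in the counting argument, the same method can be generalized to any number of digits and cross-checked against brute force on small ranges. This is a sketch under my own naming (count_89, count_89_brute, terminator); it uses itertools.combinations_with_replacement instead of iter_digits and skips the all-zero content instead of mapping 0 to 0.

```python
import functools
import itertools
import math


def digit_square_sum(number: int) -> int:
    total = 0
    while number:
        number, digit = divmod(number, 10)
        total += digit * digit
    return total


@functools.cache
def terminator(number: int) -> int:
    # Callers never pass 0, so every chain reaches 1 or 89.
    while number not in (1, 89):
        number = digit_square_sum(number)
    return number


def count_89(num_digits: int) -> int:
    """Count n in [1, 10**num_digits) whose chain ends at 89, via digit contents."""
    total = 0
    for digits in itertools.combinations_with_replacement(range(10), num_digits):
        square_sum = sum(d * d for d in digits)
        if square_sum == 0:  # the all-zero content, i.e. n = 0
            continue
        if terminator(square_sum) == 89:
            multiplicity = math.factorial(num_digits)
            for d in range(10):
                multiplicity //= math.factorial(digits.count(d))
            total += multiplicity
    return total


def count_89_brute(limit: int) -> int:
    return sum(1 for n in range(1, limit) if terminator(n) == 89)


print(count_89(4), count_89_brute(10**4))
```

Here count_89(7) performs the same count as solution() above.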