Project Euler Solution 51: Prime digit replacements

In Problem 51: Prime digit replacements we have to find families of primes that agree in all digits except a chosen set of positions, where every member holds the same replaced digit.

By replacing the 1st digit of the 2-digit number *3, it turns out that six of the nine possible values: 13, 23, 43, 53, 73, and 83, are all prime.

By replacing the 3rd and 4th digits of 56**3 with the same digit, this 5-digit number is the first example having seven primes among the ten generated numbers, yielding the family: 56003, 56113, 56333, 56443, 56663, 56773, and 56993. Consequently 56003, being the first member of this family, is the smallest prime with this property.

Find the smallest prime which, by replacing part of the number (not necessarily adjacent digits) with the same digit, is part of an eight prime value family.

First we encode the example from the problem statement as a test. We want a function get_prime_family that accepts a list of digits, a mask and a set of primes, and returns all the primes in that family.

def test_get_prime_family() -> None:
    primes = prime_sieve(100000)
    prime_set = set(primes)
    assert get_prime_family(
        list("56003"), (False, False, True, True, False), prime_set
    ) == [56003, 56113, 56333, 56443, 56663, 56773, 56993]
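The test, like the rest of the code here, relies on a prime_sieve helper that is not shown in this post. To run the snippets on their own, a minimal Sieve of Eratosthenes can stand in for it. This version assumes that prime_sieve(limit) returns the primes below limit in ascending order; the original helper may differ in details.

```python
import math


def prime_sieve(limit: int) -> list[int]:
    """Hypothetical stand-in: return all primes below `limit`, ascending."""
    if limit < 2:
        return []
    is_prime = [True] * limit
    is_prime[0] = is_prime[1] = False
    # Only factors up to sqrt(limit) need to be sieved.
    for n in range(2, math.isqrt(limit) + 1):
        if is_prime[n]:
            # Cross out every multiple of n, starting at n * n.
            for multiple in range(n * n, limit, n):
                is_prime[multiple] = False
    return [n for n, flag in enumerate(is_prime) if flag]


print(prime_sieve(30))
```

A list is returned so the primes can be iterated in order, while the callers build a set from it for fast membership tests.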

We can implement this function now. It replaces every digit selected by the mask with the same value, trying all values from 0 to 9. If the mask includes the first digit, we skip 0 so that no number gets a leading zero.

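def get_prime_family(digits: list[str], mask: tuple[bool, ...], prime_set: set[int]) -> list[int]:
    # Replace every masked digit with the same replacement digit. A masked
    # leading digit starts at 1 so that no candidate has a leading zero.
    new_numbers = [
        int("".join(str(replacement) if m else digit for digit, m in zip(digits, mask)))
        for replacement in range(1 if mask[0] else 0, 10)
    ]
    # Keep only the candidates that are prime.
    return [number for number in new_numbers if number in prime_set]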
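As a quick usage example, the function reproduces the first family from the problem statement. The digits and mask below encode *3: the leading digit is replaced, so the candidates run from 13 to 93 and 03 is never generated. The small prime set built by trial division is an assumption standing in for the sieve.

```python
def get_prime_family(digits: list[str], mask: tuple[bool, ...], prime_set: set[int]) -> list[int]:
    # Same logic as above, repeated so this snippet runs on its own.
    new_numbers = [
        int("".join(str(replacement) if m else digit for digit, m in zip(digits, mask)))
        for replacement in range(1 if mask[0] else 0, 10)
    ]
    return [number for number in new_numbers if number in prime_set]


# Primes below 100 by plain trial division, enough for two-digit numbers.
primes_below_100 = {n for n in range(2, 100) if all(n % d for d in range(2, n))}

# "*3": the first digit is replaced, the last digit 3 is kept.
family = get_prime_family(list("13"), (True, False), primes_below_100)
print(family)  # [13, 23, 43, 53, 73, 83]
```

The three candidates 33, 63 and 93 are dropped because they are divisible by 3.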

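We can then wrap this in a function that tries every possible mask for a given prime and collects all non-empty families. The last digit is never replaced: if it were, only the values ending in 1, 3, 7 or 9 could be prime, so such a family could never reach eight members.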

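import itertools


def get_prime_families(prime: int, prime_set: set[int]) -> list[list[int]]:
    digits = list(str(prime))
    families = [
        # The last digit is never replaced, so it is always masked out.
        get_prime_family(digits, mask + (False,), prime_set)
        for mask in itertools.product((True, False), repeat=len(digits) - 1)
        # Skip the all-False mask, which would only repeat the prime itself.
        if any(mask)
    ]
    # Drop empty families; lists sort lexicographically, so by smallest member first.
    result = [family for family in families if family]
    result.sort()
    return result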
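To see which masks get_prime_families tries, here is the enumeration on its own for the three-digit prime 113. It uses the repeat argument of itertools.product, which is equivalent to unpacking a list of (True, False) pairs. An n-digit number yields 2^(n-1) masks, and the last one, all False, replaces no digit at all.

```python
import itertools

digits = list("113")
# One True/False choice for every digit except the last one,
# which is always appended as False (never replaced).
masks = [
    mask + (False,)
    for mask in itertools.product((True, False), repeat=len(digits) - 1)
]
for mask in masks:
    print(mask)
# (True, True, False)
# (True, False, False)
# (False, True, False)
# (False, False, False)
```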

In the solution we iterate through the primes in ascending order and return the smallest member of the first family with exactly eight elements.

def solution() -> int:
    primes = prime_sieve(1000000)
    prime_set = set(primes)
    for prime in primes:
        families = get_prime_families(prime, prime_set)
        for family in families:
            if len(family) == 8:
                return family[0]
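The solution only looks for eight-member families. As a sanity check, a hypothetical generalization, smallest_prime_in_family, takes the family size as a parameter so it can be run against the two examples from the problem statement: size 6 should give 13 and size 7 should give 56003. It bundles a stand-in prime_sieve (assumed to return the primes below limit) and copies of the helpers so it runs on its own; the any(mask) filter that skips the mask replacing no digits is an addition here.

```python
import itertools
import math
from typing import Optional


def prime_sieve(limit: int) -> list[int]:
    # Hypothetical stand-in: Sieve of Eratosthenes, primes below `limit`.
    if limit < 2:
        return []
    is_prime = [True] * limit
    is_prime[0] = is_prime[1] = False
    for n in range(2, math.isqrt(limit) + 1):
        if is_prime[n]:
            for multiple in range(n * n, limit, n):
                is_prime[multiple] = False
    return [n for n, flag in enumerate(is_prime) if flag]


def get_prime_family(digits: list[str], mask: tuple[bool, ...], prime_set: set[int]) -> list[int]:
    # Replace all masked digits with the same value; no leading zero.
    new_numbers = [
        int("".join(str(replacement) if m else digit for digit, m in zip(digits, mask)))
        for replacement in range(1 if mask[0] else 0, 10)
    ]
    return [number for number in new_numbers if number in prime_set]


def get_prime_families(prime: int, prime_set: set[int]) -> list[list[int]]:
    digits = list(str(prime))
    families = [
        get_prime_family(digits, mask + (False,), prime_set)
        for mask in itertools.product((True, False), repeat=len(digits) - 1)
        if any(mask)
    ]
    return sorted(family for family in families if family)


def smallest_prime_in_family(size: int, limit: int) -> Optional[int]:
    # Same search as solution(), with the family size as a parameter.
    primes = prime_sieve(limit)
    prime_set = set(primes)
    for prime in primes:
        for family in get_prime_families(prime, prime_set):
            if len(family) == size:
                return family[0]
    # No family of that size among the primes below `limit`.
    return None


print(smallest_prime_in_family(6, 100))     # 13
print(smallest_prime_in_family(7, 100000))  # 56003
```

Returning None instead of falling off the end makes it explicit when the sieve limit was too small.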

This takes about 2.5 minutes to find the correct answer, which still feels acceptable.