Digit Dynamic Programming

Introduction

Counting problems are my absolute nemesis. They usually involve some combination of fancy counting logic and dynamic programming. This post aims to tackle questions of the format:

Count the number of positive integers between L and R that satisfy some condition

Some examples of this include:

The condition usually isn’t very complicated and can be checked for a single number n by iterating over its digits in O(log n) time. In the brute-force approach, we would perform this check for every number up to R, leading to a runtime of O(R log R). The challenge arises when looking at the constraints for R, which typically range from 10^9 to 10^15. Anything linear in R is not going to suffice; we want a runtime that depends on the number of digits of R, which is O(log R), rather than on R itself.
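For reference, the brute-force approach described above could look like the sketch below. The names `brute_force_count` and `has_distinct_digits` are my own, picked for illustration. It is only practical for small R, but it is handy for cross-checking a faster solution.

```python
def brute_force_count(lo, hi, condition):
    """Count integers n in [lo, hi] for which condition(n) is true."""
    return sum(1 for n in range(lo, hi + 1) if condition(n))


def has_distinct_digits(n):
    """Example condition (an assumption for illustration): no digit repeats."""
    s = str(n)
    return len(set(s)) == len(s)


# Only 11 repeats a digit in the range 1..20
print(brute_force_count(1, 20, has_distinct_digits))  # 19
```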

Premise

So instead of iterating over the actual numbers, why don’t we build the number digit by digit? We know that the number has at most some length N, so we can gradually build it up to that length using dynamic programming.

As we add each digit, we’ll see how it affects the subproblem. We’ll track these changes in some kind of state, and key our dynamic programming cache on both the digit position we are on and the current number’s ‘state’. More formally:

dp[(N, S)] = dp[(0(N-1), S')] + dp[(1(N-1), S')] + ... + dp[(9(N-1), S')]

In the above statement, I’m using the notation X(N-1) to indicate that we place the digit X and still have N-1 digits left to fill, S to indicate the state before adding X, and S' to indicate the state after adding it (so S' depends on X). A key thing to note is that we can add the digit 0 as a prefix. This is equivalent to adding a blank space at the front of the number. Otherwise, how are we supposed to get numbers of length less than N?

This is best illustrated with an example and throughout this post I’ll be using two classic problems as examples: Number of Digit One and Count Special Integers.

For the former, the state would be the number of 1’s accumulated so far and for the latter the state would maintain digits we have used so far so we don’t repeat them.

Basic Scenario - No Defined Upper Bound

Let’s assume that, for now, we don’t have a limit R, but instead want to count all the numbers with at most N digits that satisfy the condition. Using this assumption, we can write some skeleton code:

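N = 3

def go(i, state, path):
    # Base case: all N digits are placed, so we have built one valid number
    if i == N:
        print(path)
        return 1

    res = 0

    # Try adding each individual digit
    limit = 9
    for d in range(limit + 1):
        # Placeholder: replace True with the problem's condition on (state, d)
        if True:
            # Placeholder: pass the updated state after adding d instead of state
            res += go(i + 1, state, path + str(d))
    return res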

The path variable above is a debugging trick I use often to see the ‘path’ my dynamic programming took. In the above case (with no condition), I would print out an output like so:

000
001
002
...
100
101
...
901
902
...
999

This enumerates all numbers from 0 to 999, with leading zeros acting as padding.

This basic template is pretty universal. We simply need to tack on the condition. For the problem counting the number of ones, the state is the running count of 1’s, the base case returns that count instead of 1, and we have the following code inside the for loop:

for d in range(limit + 1):
    res += go(i + 1, count + 1 if d == 1 else count, path + str(d))
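To see the whole thing run end to end, here is a minimal sketch for the ones-counting problem without an upper bound. The function name `count_ones_up_to_length` is hypothetical, and I dropped the `path` argument so that the results can be memoized. Each of the N positions holds a 1 in exactly 10^(N-1) of the numbers, so the answer should be N * 10^(N-1).

```python
from functools import lru_cache


def count_ones_up_to_length(n_digits):
    """Total number of '1' digits written across 0 .. 10**n_digits - 1."""

    @lru_cache(maxsize=None)
    def go(i, count):
        # Base case: the number is complete, so report how many 1s it contains
        if i == n_digits:
            return count
        res = 0
        for d in range(10):
            res += go(i + 1, count + 1 if d == 1 else count)
        return res

    return go(0, 0)


print(count_ones_up_to_length(2))  # 20
```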

For the repeated digit problem, we would have the following code:

for d in range(limit + 1):
    # Create a one-bit mask for d and check that it does not overlap the used digits
    m = 1 << d
    if mask & m == 0:
        # Mark d as used by OR-ing its bit into the mask
        res += go(i + 1, mask | m, path + str(d))
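One caveat with padding numbers with leading zeros: for this problem a leading 0 should not count as a used digit, otherwise 7 (built as 007) would look like it repeats 0. Below is a minimal sketch of one way to handle that. The `started` flag and the function name are my own assumptions for illustration.

```python
from functools import lru_cache


def count_distinct_up_to_length(n_digits):
    """Count integers in 0 .. 10**n_digits - 1 whose digits are all distinct."""

    @lru_cache(maxsize=None)
    def go(i, mask, started):
        if i == n_digits:
            return 1
        res = 0
        for d in range(10):
            if not started and d == 0:
                # Still in the leading padding: do not mark 0 as used
                res += go(i + 1, mask, False)
                continue
            m = 1 << d
            if mask & m == 0:
                res += go(i + 1, mask | m, True)
        return res

    return go(0, 0, False)


# 0..99 has 100 numbers, and 11, 22, ..., 99 repeat a digit
print(count_distinct_up_to_length(2))  # 91
```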

Now, we can count the number of non-negative values with at most N digits that satisfy any condition.

Adding the bound constraints

Ok, but how can we ensure that our result does not go over the limit R? To achieve this, we can use an additional state variable called tight to check if we are under a ‘tight constraint’. I like to think of it as ‘are we hitting our head against the ceiling?’.

Let’s say R is 68543 and we are about to add the first digit to our number. There are 3 cases to consider:

1) Adding any digit greater than 6. This would create a number of the form 7_ _ _ _ , which is for sure greater than R. So we enforce a strict limit of 6.

2) Adding any digit less than 6. Obviously, the resulting number would be less than R, so we can do this. Additionally, we are under no tight constraint after adding this digit. In other words, after adding a digit less than 6, we can add any digit in the range 0-9 since the number is already guaranteed to be less than R.

3) Adding 6 itself. We can do this, but we must ensure that the successive adds will not lead to a number greater than R. For example, after adding 6, we cannot add 9 like in case 2. We must ensure that the next digit is less than or equal to 8. Moreover, after adding 8, we need to make sure that the next digit is less than or equal to 5. And so on and so forth. You get the idea.

Basically, tight will maintain whether we are at the bleeding edge of the upper bound. If we are currently tight, then we can only add digits less than or equal to the digit at that index of R. Otherwise, we can add any digit.

Additionally, if we are currently tight and we add the digit at that current index of R, successive adds must be tight as well.
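The three cases can be checked with a tiny helper. This is just a sketch: `allowed_digits` is a hypothetical name, and it assumes the prefix passed in was itself a legal choice (it never exceeded R while tight).

```python
def allowed_digits(r_digits, prefix):
    """Return (max_digit, tight) for the position right after `prefix`."""
    tight = True
    for i, ch in enumerate(prefix):
        # We stay tight only while every digit so far matches R exactly
        tight = tight and ch == r_digits[i]
    limit = int(r_digits[len(prefix)]) if tight else 9
    return limit, tight


R = "68543"
print(allowed_digits(R, ""))    # (6, True)
print(allowed_digits(R, "6"))   # (8, True)
print(allowed_digits(R, "68"))  # (5, True)
print(allowed_digits(R, "5"))   # (9, False)
```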

So now, we just have to add this state variable and check whether we are tight and adjust accordingly:

digits = '68543'
N = len(digits)

def go(i, state, tight):
    # Base case: all N digits are placed, so we have built one valid number
    if i == N:
        return 1

    res = 0

    # Adjust limit according to tight
    limit = int(digits[i]) if tight else 9
    for d in range(limit + 1):
        # Placeholder: replace True with the problem's condition on (state, d)
        if True:
            # Placeholder: pass the updated state after adding d instead of state.
            # We stay tight only if we were tight AND d is at the limit
            res += go(i + 1, state, tight and d == int(digits[i]))
    return res

Now, we have our digit DP function! Note that we will initialize the variable tight to True since we are technically under a tight constraint when we have added no digits to our solution.
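A quick sanity check for the tight logic: with no condition at all, the function should count every integer from 0 to R, which is R + 1 numbers. The sketch below does exactly that. The name `count_up_to` is hypothetical, and the state argument is dropped since there is no condition to track.

```python
from functools import lru_cache


def count_up_to(r):
    """Count the integers in [0, r] by building them digit by digit."""
    digits = [int(c) for c in str(r)]

    @lru_cache(maxsize=None)
    def go(i, tight):
        if i == len(digits):
            return 1
        limit = digits[i] if tight else 9
        # Stay tight only if we were tight and picked the limiting digit
        return sum(go(i + 1, tight and d == limit) for d in range(limit + 1))

    return go(0, True)


print(count_up_to(68543))  # 68544
```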

One final tidbit is how to enforce a lower bound. We can reuse the above dp function and subtract the count of all values less than our lower bound. So we make 2 calls to the dp function: one to count the numbers up to the upper bound R and another to count the numbers up to L - 1, just below the lower bound. Subtracting the two gives us the answer:

dp([L, R]) = dp(R) - dp(L - 1)
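Here is a minimal sketch of that subtraction, using the ones-counting problem as the condition. The names `ones_up_to` and `ones_in_range` are my own, and treating a negative bound as a count of 0 is an assumption I added so that L = 0 works.

```python
from functools import lru_cache


def ones_up_to(n):
    """Total number of '1' digits written across all integers in [0, n]."""
    if n < 0:
        return 0  # assumption: an empty range contributes nothing
    digits = [int(c) for c in str(n)]

    @lru_cache(maxsize=None)
    def go(i, count, tight):
        if i == len(digits):
            return count
        limit = digits[i] if tight else 9
        return sum(
            go(i + 1, count + (d == 1), tight and d == limit)
            for d in range(limit + 1)
        )

    return go(0, 0, True)


def ones_in_range(lo, hi):
    """Apply the subtraction: answer for [lo, hi] = f(hi) - f(lo - 1)."""
    return ones_up_to(hi) - ones_up_to(lo - 1)


# 10..19 has ten 1s in the tens place plus the extra 1 in 11
print(ones_in_range(10, 19))  # 11
```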

To cap it off, I want to give solutions to the two problems stated above as examples:

Count Special Integers

from functools import lru_cache

class Solution:
    def countSpecialNumbers(self, n: int) -> int:
        digits = [int(i) for i in str(n)]

        @lru_cache(None)
        def go(i, mask, is_limit):
            if i == len(digits):
                return 1
            res = 0
            limit = digits[i] if is_limit else 9
            for d in range(limit + 1):
                m = 1 << d
                if mask & m == 0:
                    # A leading zero is padding, so it is not marked as used
                    new_mask = mask if mask == 0 and d == 0 else mask | m
                    res += go(i + 1, new_mask, is_limit and d == digits[i])
            return res

        # Subtract 1 to exclude 0, since only positive integers count
        return go(0, 0, True) - 1

Number of Digit One

from functools import lru_cache

class Solution:
    def countDigitOne(self, n: int) -> int:
        digits = [int(c) for c in str(n)]

        @lru_cache(None)
        def go(i, count, tight):
            # Base case: return the number of 1s in the number we just built
            if i == len(digits):
                return count

            res = 0
            limit = digits[i] if tight else 9
            for d in range(limit + 1):
                if d == 1:
                    res += go(i + 1, count + 1, tight and d == digits[i])
                else:
                    res += go(i + 1, count, tight and d == digits[i])
            return res

        return go(0, 0, True)