Counting Valid Binary Strings Under Reduction Rules
Given a binary string composed of '0', '1', and '?' characters, and a lookup table f that maps each 3-bit binary number (from 0 to 7) to either 0 or 1, determine the number of ways (counted modulo 998244353) to replace all '?' characters with '0' or '1' such that the resulting string can be reduced to "1" using the following operation:
- Choose an odd index i where 3 ≤ i ≤ len(S).
- Split the string into a prefix A = S[0:i] and a suffix B = S[i:].
- Repeatedly replace the last three bits of A with f(value_of_last_three_bits) until A becomes a single bit.
- Concatenate the resulting single-bit A with B to form a new string.
This process is repeated until the entire string reduces to "1".
The solution uses dynamic programming with state compression. Define an inner DP state inner[a][b] indicating whether it's possible to reduce the processed prefix ending in bit a to final result b. The outer DP tracks transitions over the input string while accounting for unknown ('?') positions.
Each state is encoded as a 5-bit integer:
- Bit 0: the last character of the current prefix (s[i]).
- Bits 4, 3, 2 and 1: the values of inner[0][0], inner[0][1], inner[1][0] and inner[1][1], respectively.
For every position i (processed in steps of 2 starting from 3), and for each valid prior state, the algorithm enumerates possible values for s[i−1] and s[i] (respecting fixed characters if not '?'), computes the next inner DP state, and updates the outer DP accordingly.
At the end, sum over all terminal states where the full string reduces to "1", i.e., inner[last_char][1] == true.
#include <bits/stdc++.h>
using namespace std;
// Modulus applied to every count.
const int MOD = 998244353;
// Capacity of the per-position arrays. str[] and dp[] are indexed up to n,
// so this assumes the input length n < MAXN (not checked by the code).
const int MAXN = 100005;
// T: number of test cases; n: length of the current input string.
int T, n;
// trans[4x + 2y + z] is the table value f for the bit triple (x, y, z);
// filled from the rule string in main.
int trans[8];
// 1-indexed input: str[i] is 0 or 1 for a fixed character, -1 for '?'.
int str[MAXN];
// dp[i][state]: number of assignments of positions 1..i (only odd i are
// filled) that end in the 5-bit packed state, taken modulo MOD.
long long dp[MAXN][32];
// Packs four inner-DP flags and the last character into one integer.
// Layout (high to low): a00 at bit 4, a01 at bit 3, a10 at bit 2,
// a11 at bit 1, last at bit 0. Each argument is expected to be 0 or 1.
int encodeState(int a00, int a01, int a10, int a11, int last) {
    int packed = a00;
    packed = (packed << 1) | a01;
    packed = (packed << 1) | a10;
    packed = (packed << 1) | a11;
    packed = (packed << 1) | last;
    return packed;
}
// Packs a bit triple (x, y, z) into the index 4x | 2y | z used for trans[].
int getStateFromIndex(int x, int y, int z) {
    const int high = x * 4;    // bit 2
    const int middle = y * 2;  // bit 1
    return high | middle | z;  // z lands in bit 0
}
// Advances the DP by two characters: given the packed state of a prefix and
// the next two characters (mid, cur), returns the packed state of the longer
// prefix. Bit 0 of a state is the prefix's last character; bits 4..1 hold the
// flags inner[0][0], inner[0][1], inner[1][0], inner[1][1].
// The new flag for (x, y) is set when either
//   - folding (last, mid, x) through trans gives a bit t with inner[t][y], or
//   - some z has inner[last][z] and folding (z, mid, x) through trans gives y.
// NOTE(review): this matches reading inner[t][y] as "the prefix would reduce
// to y if its last bit were t"; that meaning is not stated in code - confirm.
int nextState(int prevState, int mid, int cur) {
    const int last = prevState & 1;

    // Flag (x, y) lives at bit 4 - 2x - y.
    bool reach[2][2];
    for (int x = 0; x < 2; ++x)
        for (int y = 0; y < 2; ++y)
            reach[x][y] = (prevState >> (4 - 2 * x - y)) & 1;

    bool out[2][2];
    for (int x = 0; x < 2; ++x) {
        // Value obtained by folding the newest triple first.
        const int folded = trans[getStateFromIndex(last, mid, x)];
        for (int y = 0; y < 2; ++y) {
            bool ok = reach[folded][y];
            // Otherwise collapse the old prefix to some z, then fold (z, mid, x).
            for (int z = 0; z < 2 && !ok; ++z)
                ok = reach[last][z] && trans[getStateFromIndex(z, mid, x)] == y;
            out[x][y] = ok;
        }
    }
    return encodeState(out[0][0], out[0][1], out[1][0], out[1][1], cur);
}
int main() {
ios::sync_with_stdio(false);
cin.tie(nullptr);
cin >> T;
while (T--) {
string rule;
cin >> rule;
// Map f(000), f(100), ..., f(111) to trans[0..7]
trans[0] = rule[0] - '0'; // 000
trans[4] = rule[1] - '0'; // 100
trans[2] = rule[2] - '0'; // 010
trans[6] = rule[3] - '0'; // 110
trans[1] = rule[4] - '0'; // 001
trans[5] = rule[5] - '0'; // 101
trans[3] = rule[6] - '0'; // 011
trans[7] = rule[7] - '0'; // 111
string s;
cin >> s;
n = s.size();
for (int i = 0; i < n; ++i) {
if (s[i] == '?') str[i + 1] = -1;
else str[i + 1] = s[i] - '0';
}
memset(dp, 0, sizeof(dp));
if (str[1] != -1) {
int st = encodeState(0, 0, 0, 0, str[1]);
st |= (1LL << (4 - str[1] * 2 - str[1])); // set inner[str[1]][str[1]] = 1
// Actually simpler: directly build correct initial state
int initInner[2][2] = {};
initInner[str[1]][str[1]] = 1;
dp[1][encodeState(initInner[0][0], initInner[0][1], initInner[1][0], initInner[1][1], str[1])] = 1;
} else {
for (int v = 0; v < 2; ++v) {
int initInner[2][2] = {};
initInner[v][v] = 1;
dp[1][encodeState(initInner[0][0], initInner[0][1], initInner[1][0], initInner[1][1], v)] = 1;
}
}
for (int i = 3; i <= n; i += 2) {
for (int prevState = 0; prevState < 32; ++prevState) {
if (!dp[i - 2][prevState]) continue;
int prevLast = prevState & 1;
for (int midVal = 0; midVal < 2; ++midVal) {
if (str[i - 1] != -1 && str[i - 1] != midVal) continue;
for (int curVal = 0; curVal < 2; ++curVal) {
if (str[i] != -1 && str[i] != curVal) continue;
int newState = nextState(prevState, midVal, curVal);
dp[i][newState] = (dp[i][newState] + dp[i - 2][prevState]) % MOD;
}
}
}
}
long long ans = 0;
for (int state = 0; state < 32; ++state) {
int lastChar = state & 1;
if (str[n] != -1 && str[n] != lastChar) continue;
bool inner[2][2];
inner[0][0] = (state >> 4) & 1;
inner[0][1] = (state >> 3) & 1;
inner[1][0] = (state >> 2) & 1;
inner[1][1] = (state >> 1) & 1;
if (inner[lastChar][1]) {
ans = (ans + dp[n][state]) % MOD;
}
}
cout << ans << '\n';
}
return 0;
}