Optimizing Large Knapsack Problems with CDQ Divide and Conquer
Applying CDQ divide and conquer can significantly speed up the solution. In practice, this approach passes all test cases within 10 ms under the problem's constraints.
For background on meet-in-the-middle search, refer to existing solutions on the topic.
Problem Analysis
This is a classic large knapsack problem, and the standard approach is meet-in-the-middle search: the items are split into two halves, a separate search records every reachable state (total weight, total value) of each half, and the two lists are then combined to find the optimal solution. However, a common issue is the generation of many inefficient states—combinations that are heavy but offer little value. Such a state is dominated: another state is no heavier yet at least as valuable, so the dominated state is never needed for the final answer.
A viable optimization is to integrate a divide-and-conquer strategy on top of the basic meet-in-the-middle approach. By recursively splitting the item range and, at each merge step, discarding both over-capacity combinations and dominated states, we prevent these states from being used in subsequent combinations.
Benchmarking shows that for typical data, the divide-and-conquer version completes in milliseconds. For worst-case data (where no state is dominated, so nothing can be pruned), it may be slightly slower than a standard implementation.
Time Complexity
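The worst-case asymptotic time complexity is the same as that of meet-in-the-middle search, $O(n \cdot 2^{n/2})$. The speed-up comes from pruning, which in practice greatly reduces the number of states that have to be combined; when nothing can be pruned, the extra merge work can make it slightly slower, as noted above.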
Algorithm Explanation
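The core idea uses CDQ divide and conquer. At every level, the states of the left and right segments are combined pairwise, sorted by weight, and pruned to the non-dominated ones. Only at the top level, where the segment covers all items, is binary search used: it pairs each left-half state with the heaviest right-half state that still fits, which is also the most valuable one. Pruning early keeps the state lists short, which is what makes the approach effective on large inputs.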
CDQ Divide and Conquer Implementation
#include <iostream>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;
// Problem input and the reported answer.
int itemCount;    // number of items read from input
int capacity;     // knapsack capacity
int selectedMask; // bit i set => item i (0-indexed) is part of the reported solution

// One input item.
struct Item {
    int weight;
    int profit;
};
Item items[55];

// A partial selection of items: its summed weight and value, plus a bitmask
// in which bit i marks item i (0-indexed) as taken.
struct State {
    int totalWeight;
    int totalValue;
    int mask;
};

// segmentStates[l] holds the state list computed for the segment starting at item l.
vector<State> segmentStates[55];
// Scratch buffer reused while merging two segments.
vector<State> mergedList;

// Ordering used by sort(): lighter states first; among equally heavy states,
// the less valuable one comes first.
bool compareStates(State a, State b) {
    return a.totalWeight == b.totalWeight ? a.totalValue < b.totalValue
                                          : a.totalWeight < b.totalWeight;
}
int cdqDivide(int left, int right) {
if (left == right) {
segmentStates[left].clear();
segmentStates[left].push_back({0, 0, 0});
if (items[left].weight <= capacity)
segmentStates[left].push_back({items[left].weight, items[left].profit, 1LL << left});
if (left == 0 && right == itemCount - 1) {
selectedMask = 1;
return segmentStates[left].back().totalValue;
}
return 0;
}
int mid = (left + right) / 2;
cdqDivide(left, mid);
cdqDivide(mid + 1, right);
if (itemCount == right - left + 1) {
int bestValue = 0;
auto &rightSegment = segmentStates[mid + 1];
for (int i = 0; i < segmentStates[left].size(); i++) {
int low = 0, high = rightSegment.size() - 1;
while (low <= high) {
int middle = (low + high) / 2;
if (segmentStates[left][i].totalWeight + rightSegment[middle].totalWeight <= capacity) low = middle + 1;
else high = middle - 1;
}
if (segmentStates[left][i].totalValue + rightSegment[low - 1].totalValue > bestValue) {
bestValue = segmentStates[left][i].totalValue + rightSegment[low - 1].totalValue;
selectedMask = segmentStates[left][i].mask + rightSegment[low - 1].mask;
}
bestValue = max(bestValue, segmentStates[left][i].totalValue + rightSegment[low - 1].totalValue);
}
return bestValue;
}
for (int i = 0; i < segmentStates[left].size(); i++) {
for (int j = 0; j < segmentStates[mid + 1].size(); j++) {
if (segmentStates[left][i].totalWeight + segmentStates[mid + 1][j].totalWeight <= capacity) {
int combinedWeight = segmentStates[left][i].totalWeight + segmentStates[mid + 1][j].totalWeight;
int combinedValue = segmentStates[left][i].totalValue + segmentStates[mid + 1][j].totalValue;
int combinedMask = segmentStates[left][i].mask + segmentStates[mid + 1][j].mask;
mergedList.push_back({combinedWeight, combinedValue, combinedMask});
} else break;
}
}
sort(mergedList.begin(), mergedList.end(), compareStates);
segmentStates[left].clear();
int currentMaxValue = -1;
for (int i = 0; i < mergedList.size(); i++) {
if (mergedList[i].totalValue > currentMaxValue) {
currentMaxValue = mergedList[i].totalValue;
segmentStates[left].push_back(mergedList[i]);
}
}
mergedList.clear();
return 0;
}
// Program entry: reads the item count and capacity, then one (weight, profit)
// pair per item. Prints the best total profit and the number of chosen items
// on the first line, then the 1-based indices of those items, each followed
// by a space.
signed main() {
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    cout.tie(nullptr);

    cin >> itemCount >> capacity;
    for (int idx = 0; idx < itemCount; ++idx) {
        cin >> items[idx].weight;
        cin >> items[idx].profit;
    }

    const int bestProfit = cdqDivide(0, itemCount - 1);
    cout << bestProfit << " ";

    // Decode selectedMask: bit b set means item b + 1 (1-based) was chosen.
    vector<int> picks;
    for (int bit = 0; bit < itemCount; ++bit) {
        if ((selectedMask >> bit) & 1) picks.push_back(bit + 1);
    }

    cout << picks.size() << endl;
    for (size_t pos = 0; pos < picks.size(); ++pos) cout << picks[pos] << " ";
    return 0;
}
Meet-in-the-Middle Search Implementation
#include <iostream>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;
// Input size and capacity; splitPoint is the number of items in the first
// half (items 1..splitPoint), the rest form the second half.
int n;
int m;
int splitPoint;
// Per-item weight and value, indexed from 1.
int weight[55];
int value[55];

// A subset of one half: summed weight, summed value, and a bitmask in which
// bit (i - 1) marks item i as taken.
struct Node {
    int w;
    int val;
    int mask;
};

// Subsets enumerated from the first and second half respectively.
vector<Node> firstHalf;
vector<Node> secondHalf;

// Ordering for sort(): ascending weight, ties broken by ascending value.
bool comp(Node a, Node b) {
    return a.w != b.w ? a.w < b.w : a.val < b.val;
}
// Depth-first enumeration of every subset of items start..end (1-indexed)
// whose total weight never exceeds the capacity m. Each complete subset is
// appended to `out` as {weight, value, mask}, where bit (i - 1) of mask marks
// item i. A branch is cut as soon as its weight exceeds m.
// Shared by exploreFirst/exploreSecond, which previously duplicated this body.
template <typename StateList>
void enumerateSubsets(StateList &out, int start, int end, int curWeight, int curValue, int curMask) {
    if (curWeight > m) return;
    if (start > end) {
        out.push_back({curWeight, curValue, curMask});
        return;
    }
    enumerateSubsets(out, start + 1, end, curWeight, curValue, curMask); // skip item `start`
    // Take item `start`; bit (start - 1) is not yet set, so | equals +.
    enumerateSubsets(out, start + 1, end, curWeight + weight[start], curValue + value[start],
                     curMask | (1LL << (start - 1)));
}

// Enumerates the feasible subsets of items start..end into firstHalf.
void exploreFirst(int start, int end, int curWeight, int curValue, int curMask) {
    enumerateSubsets(firstHalf, start, end, curWeight, curValue, curMask);
}

// Enumerates the feasible subsets of items start..end into secondHalf.
void exploreSecond(int start, int end, int curWeight, int curValue, int curMask) {
    enumerateSubsets(secondHalf, start, end, curWeight, curValue, curMask);
}
signed main() {
ios::sync_with_stdio(0);
cin.tie(0); cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i++)
cin >> weight[i] >> value[i];
splitPoint = n / 2;
exploreFirst(1, splitPoint, 0, 0, 0);
exploreSecond(splitPoint + 1, n, 0, 0, 0);
sort(firstHalf.begin(), firstHalf.end(), comp);
vector<Node> filtered;
filtered.push_back({-1, -1, -1});
for (int i = 0; i < firstHalf.size(); i++) {
if (filtered.back().val < firstHalf[i].val) {
filtered.push_back(firstHalf[i]);
}
}
int maxProfit = 0, resultMask = 0;
for (int i = 0; i < secondHalf.size(); i++) {
int remaining = m - secondHalf[i].w;
int left = 1, right = filtered.size() - 1;
int bestIdx = 0;
while (left <= right) {
int mid = (left + right) / 2;
if (filtered[mid].w <= remaining) {
left = mid + 1;
bestIdx = mid;
} else right = mid - 1;
}
if (bestIdx != 0 && filtered[bestIdx].val + secondHalf[i].val > maxProfit) {
maxProfit = filtered[bestIdx].val + secondHalf[i].val;
resultMask = secondHalf[i].mask + filtered[bestIdx].mask;
}
}
vector<int> answer;
for (int i = 0; i < n; i++) {
if (resultMask >> i & 1) {
answer.push_back(i + 1);
}
}
cout << maxProfit << " " << answer.size() << endl;
for (auto elem : answer) cout << elem << " ";
return 0;
}