Fading Coder

Optimizing Large Knapsack Problems with CDQ Divide and Conquer

Layering CDQ divide and conquer on top of meet-in-the-middle search can significantly speed up the solution. In practice, this approach passes all test cases within 10 ms at the given data scale.

For background on meet-in-the-middle search, refer to existing solutions on the topic.

Problem Analysis

This is a classic large knapsack problem, for which the standard approach is meet-in-the-middle search: split the items into two halves, enumerate every feasible state of each half, and then combine the two lists to find the optimal solution. A common weakness is that the enumeration produces many dominated states, that is, combinations that are heavy but offer little value, and these can never contribute to the final answer.

A viable optimization is to layer a divide-and-conquer strategy on top of the basic meet-in-the-middle approach. By recursively splitting the item range and pruning dominated states at each merge step, we keep those states out of every subsequent combination.
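The pruning step described above can be sketched as follows. This is a minimal illustration, not the article's exact code: `SketchState` and `pruneDominated` are hypothetical names introduced here, and weights are assumed to be non-negative.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical state type for this sketch: total weight and value of one subset.
struct SketchState {
    long long weight;
    long long value;
};

// Keep only states that are not matched or beaten in value by a state that is
// no heavier. The result has strictly increasing weight and strictly
// increasing value, which is what later binary searches rely on.
std::vector<SketchState> pruneDominated(std::vector<SketchState> states) {
    std::sort(states.begin(), states.end(),
              [](const SketchState &a, const SketchState &b) {
                  if (a.weight != b.weight) return a.weight < b.weight;
                  return a.value > b.value;  // on equal weight, best value first
              });
    std::vector<SketchState> frontier;
    for (const SketchState &s : states) {
        if (frontier.empty() || s.value > frontier.back().value)
            frontier.push_back(s);
    }
    return frontier;
}
```

For example, given the states (0,0), (5,3), (4,10), (7,10), (4,2), (9,12) as (weight, value) pairs, only (0,0), (4,10), and (9,12) survive; every other state is at least as heavy as a survivor without being more valuable.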

Benchmarking shows that for typical data, the divide-and-conquer version completes in milliseconds. For worst-case data (where no state is dominated, so nothing can be pruned), it may run slightly slower than a standard implementation.

Time Complexity

The asymptotic time complexity remains the same as that of meet-in-the-middle search, O(n · 2^(n/2)) for n items, but with a smaller constant factor on typical data.
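A rough worst-case accounting supports this claim. It is a sketch under simplifying assumptions: n is a power of two, no state is ever pruned, and depth d of the recursion has 2^d segments of n/2^d items each. The top level performs up to 2^(n/2) binary searches, and each deeper segment sorts the full product of its two child lists.

```latex
\begin{aligned}
T_{\text{MITM}}(n) &= O\!\left(2^{n/2}\log 2^{n/2}\right) = O\!\left(n\,2^{n/2}\right) \\
T_{\text{CDQ}}(n)  &= \underbrace{O\!\left(n\,2^{n/2}\right)}_{\text{top-level binary searches}}
                     + \sum_{d \ge 1} \underbrace{2^{d} \cdot O\!\left(\tfrac{n}{2^{d}}\,2^{n/2^{d}}\right)}_{\text{sorting at depth } d}
                     = O\!\left(n\,2^{n/2}\right)
\end{aligned}
```

The sum is dominated by its d = 1 term, so the bound matches plain meet-in-the-middle; any speedup comes from pruning shrinking the lists in practice.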

Algorithm Explanation

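The core idea is to use CDQ divide and conquer to merge the state lists of the left and right segments at each level, discarding dominated states as they appear. At the top level, the two final lists are not merged; instead, for each state in the left list, a binary search finds the heaviest state in the right list that still fits within the capacity. Because the lists shrink before they are combined, the approach is particularly effective on large inputs.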
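The final combination step can be sketched with `std::upper_bound`. This is an illustrative version under stated assumptions rather than the article's code: `Entry` and `bestCombination` are hypothetical names, and both lists are assumed to be sorted by weight with dominated states already removed, so value grows with weight.

```cpp
#include <algorithm>
#include <cassert>
#include <iterator>
#include <utility>
#include <vector>

// Hypothetical list entry for this sketch: (weight, value).
using Entry = std::pair<long long, long long>;

// Best total value from pairing one left state with one right state without
// exceeding the capacity. Assumes `right` is sorted by increasing weight and
// that its values increase along with its weights.
long long bestCombination(const std::vector<Entry> &left,
                          const std::vector<Entry> &right,
                          long long capacity) {
    long long best = 0;
    for (const Entry &l : left) {
        if (l.first > capacity) continue;
        long long remaining = capacity - l.first;
        // First right entry whose weight exceeds the remaining capacity.
        auto it = std::upper_bound(
            right.begin(), right.end(), remaining,
            [](long long limit, const Entry &e) { return limit < e.first; });
        if (it == right.begin()) continue;  // no right state fits
        best = std::max(best, l.second + std::prev(it)->second);
    }
    return best;
}
```

Because value increases with weight in a pruned list, the heaviest right state that fits is also the most valuable one, so a single binary search per left state is enough.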

CDQ Divide and Conquer Implementation

#include <iostream>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;

int itemCount, capacity, selectedMask;
struct Item {
    int weight;
    int profit;
} items[55];

struct State {
    int totalWeight;
    int totalValue;
    int mask;
};

vector<State> segmentStates[55], mergedList;

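// Order by weight; on equal weight put the higher value first, so the pruning
// pass after sorting also drops equal-weight states with a lower value.
bool compareStates(const State &a, const State &b) {
    if (a.totalWeight != b.totalWeight) return a.totalWeight < b.totalWeight;
    return a.totalValue > b.totalValue;
}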

int cdqDivide(int left, int right) {
    if (left == right) {
        segmentStates[left].clear();
        segmentStates[left].push_back({0, 0, 0});
        if (items[left].weight <= capacity)
            segmentStates[left].push_back({items[left].weight, items[left].profit, 1LL << left});
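        if (left == 0 && right == itemCount - 1) {
            // Single-item input: take the item only if it fits. Its mask is 0
            // when the list holds just the empty state.
            selectedMask = segmentStates[left].back().mask;
            return segmentStates[left].back().totalValue;
        }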
        return 0;
    }
    int mid = (left + right) / 2;
    cdqDivide(left, mid);
    cdqDivide(mid + 1, right);

    if (itemCount == right - left + 1) {
        int bestValue = 0;
        auto &rightSegment = segmentStates[mid + 1];
        for (int i = 0; i < segmentStates[left].size(); i++) {
            // Binary search for the heaviest right-hand state that still fits.
            // The right list is sorted by weight and starts with a zero-weight
            // state, so low - 1 is a valid index once the loop ends.
            int low = 0, high = rightSegment.size() - 1;
            while (low <= high) {
                int middle = (low + high) / 2;
                if (segmentStates[left][i].totalWeight + rightSegment[middle].totalWeight <= capacity) low = middle + 1;
                else high = middle - 1;
            }
            if (segmentStates[left][i].totalValue + rightSegment[low - 1].totalValue > bestValue) {
                bestValue = segmentStates[left][i].totalValue + rightSegment[low - 1].totalValue;
                // The two halves cover disjoint items, so OR combines the masks.
                selectedMask = segmentStates[left][i].mask | rightSegment[low - 1].mask;
            }
        }
        return bestValue;
    }

    for (int i = 0; i < segmentStates[left].size(); i++) {
        for (int j = 0; j < segmentStates[mid + 1].size(); j++) {
            if (segmentStates[left][i].totalWeight + segmentStates[mid + 1][j].totalWeight <= capacity) {
                int combinedWeight = segmentStates[left][i].totalWeight + segmentStates[mid + 1][j].totalWeight;
                int combinedValue = segmentStates[left][i].totalValue + segmentStates[mid + 1][j].totalValue;
                int combinedMask = segmentStates[left][i].mask + segmentStates[mid + 1][j].mask;
                mergedList.push_back({combinedWeight, combinedValue, combinedMask});
            } else break;
        }
    }
    sort(mergedList.begin(), mergedList.end(), compareStates);
    segmentStates[left].clear();
    int currentMaxValue = -1;
    for (int i = 0; i < mergedList.size(); i++) {
        if (mergedList[i].totalValue > currentMaxValue) {
            currentMaxValue = mergedList[i].totalValue;
            segmentStates[left].push_back(mergedList[i]);
        }
    }
    mergedList.clear();
    return 0;
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> itemCount >> capacity;
    for (int i = 0; i < itemCount; i++)
        cin >> items[i].weight >> items[i].profit;

    cout << cdqDivide(0, itemCount - 1) << " ";
    vector<int> chosenItems;
    for (int i = 0; i < itemCount; i++) {
        if (selectedMask >> i & 1) {
            chosenItems.push_back(i + 1);
        }
    }
    cout << chosenItems.size() << endl;
    for (auto idx : chosenItems) cout << idx << " ";
    return 0;
}

Meet-in-the-Middle Search Implementation

#include <iostream>
#include <algorithm>
#include <vector>
#define int long long
using namespace std;

int n, m, splitPoint;
int weight[55], value[55];
struct Node {
    int w;
    int val;
    int mask;
};
vector<Node> firstHalf, secondHalf;

bool comp(Node a, Node b) {
    if (a.w == b.w) return a.val < b.val;
    return a.w < b.w;
}

// Enumerate every subset of items start..end (1-indexed) and append the ones
// that fit within capacity m to `out`.
void explore(int start, int end, int curWeight, int curValue, int curMask, vector<Node> &out) {
    if (curWeight > m) return;
    if (start > end) {
        out.push_back({curWeight, curValue, curMask});
        return;
    }
    explore(start + 1, end, curWeight, curValue, curMask, out);
    explore(start + 1, end, curWeight + weight[start], curValue + value[start], curMask | (1LL << (start - 1)), out);
}

signed main() {
    ios::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);
    cin >> n >> m;
    for (int i = 1; i <= n; i++)
        cin >> weight[i] >> value[i];

    splitPoint = n / 2;
    explore(1, splitPoint, 0, 0, 0, firstHalf);
    explore(splitPoint + 1, n, 0, 0, 0, secondHalf);
    sort(firstHalf.begin(), firstHalf.end(), comp);

    vector<Node> filtered;
    filtered.push_back({-1, -1, -1});
    for (int i = 0; i < firstHalf.size(); i++) {
        if (filtered.back().val < firstHalf[i].val) {
            filtered.push_back(firstHalf[i]);
        }
    }

    int maxProfit = 0, resultMask = 0;
    for (int i = 0; i < secondHalf.size(); i++) {
        int remaining = m - secondHalf[i].w;
        int left = 1, right = filtered.size() - 1;
        int bestIdx = 0;
        while (left <= right) {
            int mid = (left + right) / 2;
            if (filtered[mid].w <= remaining) {
                left = mid + 1;
                bestIdx = mid;
            } else right = mid - 1;
        }
        if (bestIdx != 0 && filtered[bestIdx].val + secondHalf[i].val > maxProfit) {
            maxProfit = filtered[bestIdx].val + secondHalf[i].val;
            resultMask = secondHalf[i].mask + filtered[bestIdx].mask;
        }
    }

    vector<int> answer;
    for (int i = 0; i < n; i++) {
        if (resultMask >> i & 1) {
            answer.push_back(i + 1);
        }
    }
    cout << maxProfit << " " << answer.size() << endl;
    for (auto elem : answer) cout << elem << " ";
    return 0;
}
Tags: algorithms