k-anonymous sequence - poj 3709
#dp #cht #convex hull trick
問題
http://poj.org/problem?id=3709
O(N^3)
a[j + 1] ~ a[i] を a[j] と同じにする
#include<bits/stdc++.h> using namespace std; typedef long long ll; #define F first #define S second #define pii pair<int, int> #define eb emplace_back #define all(v) v.begin(), v.end() #define rep(i, n) for (int i = 0; i < (n); ++i) #define rep3(i, l, n) for (int i = l; i < (n); ++i) #define sz(v) (int)v.size() const int inf = 1e9 + 7; const ll INF = 1e18; const int mod = 1000000007; #define abs(x) (x >= 0 ? x : -(x)) #define lb(v, x) (int)(lower_bound(all(v), x) - v.begin()) #define ub(v, x) (int)(upper_bound(all(v), x) - v.begin()) template<typename T1, typename T2> inline bool chmin(T1 &a, T2 b) { if (a > b) { a = b; return 1; } return 0; } template<typename T1, typename T2> inline bool chmax(T1 &a, T2 b) { if (a < b) { a = b; return 1; } return 0; } template<typename T> T gcd(T a, T b) { if (b == 0) return a; return gcd(b, a % b); } template<typename T> T lcm(T a, T b) { return a / gcd(a, b) * b; } template<typename T> T pow(T a, int b) { return b ? pow(a * a, b / 2) * (b % 2 ? a : 1) : 1; } ll modpow(ll a, int b, int _mod) { return b ? modpow(a * a % _mod, b / 2, _mod) * (b % 2 ? a : 1) % _mod : 1; } template<class T> ostream& operator<<(ostream& os, const vector<T>& vec) { for (auto &vi: vec) os << vi << " "; return os; } template<class T, class U> ostream& operator<<(ostream& os, const pair<T, U>& p) { os << p.F << " " << p.S; return os; } template<class T> inline void add(T &a, int b) { a += b; if (a >= mod) a -= mod; } void solve(); int main() { ios::sync_with_stdio(false); cin.tie(0); cout.tie(0); cout << fixed << setprecision(10); int T; // cin >> T; T = 1; while (T--) { solve(); } } void solve() { int n, k; cin >> n >> k; vector<int> a(n); rep(i, n) cin >> a[i]; vector<ll> dp(n + 1, INF); // 最初の i 項だけ見て, ある値と同じ値が他に k-1 個は存在するような // 変更の最小回数 dp[0] = 0; // 空 rep(i, n) { // 最初 j 項が条件を満たす // 最初 i 項が条件を満たすには? rep(j, i - (k - 1) + 1) { ll tmp = dp[j]; rep3(l, j + 1, i + 1) { tmp += a[l] - a[j]; } chmin(dp[i + 1], tmp); } } cout << dp << endl; }
O(N^2)
累積和を使う
rep(j, i - (k - 1) + 1) { ll tmp = dp[j]; tmp -= a[j] * (i - (j + 1) + 1); tmp += acc[i + 1] - acc[j + 1]; chmin(dp[i + 1], tmp); }
ちょっと変形
rep(j, i - (k - 1) + 1) { ll tmp = - a[j] * i + dp[j] + a[j] * j - acc[j + 1]; // i に関する一次式 chmin(dp[i + 1], tmp); } dp[i + 1] += acc[i + 1];
ここで, dp[i], dp[i + 1] の計算の違いを考える
i の一次式で, i がずれる
また, rep(j, i - (k - 1) + 1) で, i が1 つ大きくなるので, 直線が 1 つ増える
i を走査すると, 直線が 1 つずつ増え, 最小値を見る x 座標が 1 ずつずれる
尺取りのように, 今後最小値をとる可能性がある直線集合を持つ
int n, k; vector<ll> a; vector<ll> dp; vector<ll> acc; ll f(int j, int x) { // 直線 f_j の x における値 return -a[j] * x + dp[j] - acc[j] + a[j] * j; } bool check(int j1, int j2, int j3) { // j = j2 の直線 f2 が最小値をとる可能性があるか ll a1 = -a[j1], b1 = dp[j1] - acc[j1] + a[j1] * j1; // f の式で, j = j1 を代入 ll a2 = -a[j2], b2 = dp[j2] - acc[j2] + a[j2] * j2; ll a3 = -a[j3], b3 = dp[j3] - acc[j3] + a[j3] * j3; return (a2 - a1) * (b3 - b2) >= (b2 - b1) * (a3 - a2); } void solve() { cin >> n >> k; a.resize(n); rep(i, n) cin >> a[i]; acc.resize(n + 1); rep(i, n) acc[i + 1] += acc[i] + a[i]; dp.resize(n + 1); dp[0] = 0; deque<int> deq; rep3(i, k, n + 1) { if (i - k >= k) { // 末尾から最小値をとる可能性がない直線を取り除く while (1) { if (sz(deq) < 2) break; int l1 = deq.back(); deq.pop_back(); int l2 = deq.back(); if (check(l2, l1, i - k) == 0) { deq.push_back(l1); break; } } deq.push_back(i - k); } while (1) { if (sz(deq) < 2) break; int s1 = deq.front(); deq.pop_front(); int s2 = deq.front(); if (f(s1, i) < f(s2, i)) { // 先頭が最小値 deq.push_front(s1); break; } } dp[i] = acc[i] + f(deq.front(), i); } cout << dp << endl; }
check 関数の説明をする
まず, f1, f2, f3 の順で傾きが負に大きい (-a[j1] のようにマイナスがつく)
f1, f3 の交点の x 座標で, f2 がその y 座標より負に大きければ f2 は最小値に今後なりえない
このとき,
(次の式変形で見やすく)
(添字の 1, 3 は入れ替えてもいい, y の計算で f1 に代入するか f3 に代入するか)
取り除く直線を末尾から見るのは, 末尾に新しい直線が入るから
参考
蟻本 p304