k-anonymous sequence - poj 3709

#dp #cht #convex hull trick

問題
http://poj.org/problem?id=3709



O(N^3)
a[j + 1] ~ a[i] を a[j] と同じにする

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define F first
#define S second
#define pii pair<int, int>
#define eb emplace_back
#define all(v) v.begin(), v.end()
#define rep(i, n) for (int i = 0; i < (n); ++i)
#define rep3(i, l, n) for (int i = l; i < (n); ++i)
#define sz(v) (int)v.size()
const int inf = 1e9 + 7;
const ll INF = 1e18;
const int mod = 1000000007;
#define abs(x) (x >= 0 ? x : -(x))
#define lb(v, x) (int)(lower_bound(all(v), x) - v.begin())
#define ub(v, x) (int)(upper_bound(all(v), x) - v.begin())
template<typename T1, typename T2> inline bool chmin(T1 &a, T2 b) { if (a > b) { a = b; return 1; } return 0; }
template<typename T1, typename T2> inline bool chmax(T1 &a, T2 b) { if (a < b) { a = b; return 1; } return 0; }
template<typename T> T gcd(T a, T b) { if (b == 0) return a; return gcd(b, a % b); }
template<typename T> T lcm(T a, T b) { return a / gcd(a, b) * b; }
template<typename T> T pow(T a, int b) { return b ? pow(a * a, b / 2) * (b % 2 ? a : 1) : 1; }
ll modpow(ll a, int b, int _mod) { return b ? modpow(a * a % _mod, b / 2, _mod) * (b % 2 ? a : 1) % _mod : 1; }
template<class T> ostream& operator<<(ostream& os, const vector<T>& vec) { for (auto &vi: vec) os << vi << " "; return os; }
template<class T, class U> ostream& operator<<(ostream& os, const pair<T, U>& p) { os << p.F << " " << p.S; return os; }
template<class T> inline void add(T &a, int b) { a += b; if (a >= mod) a -= mod; }



void solve();

int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);
    cout.tie(0);
    cout << fixed << setprecision(10);

    int T;
    // cin >> T;
    T = 1;

    while (T--) {
        solve();
    }
}

void solve() {
  int n, k;
  cin >> n >> k;
  vector<int> a(n);
  rep(i, n) cin >> a[i];
  vector<ll> dp(n + 1, INF);
  // 最初の i 項だけ見て, ある値と同じ値が他に k-1 個は存在するような
  // 変更の最小回数
  dp[0] = 0; // 空
  rep(i, n) {
    // 最初 j 項が条件を満たす
    // 最初 i 項が条件を満たすには?
    rep(j, i - (k - 1) + 1) {
      ll tmp = dp[j];
      rep3(l, j + 1, i + 1) {
        tmp += a[l] - a[j];
      }
      chmin(dp[i + 1], tmp);
    }
  }
  cout << dp << endl;
}



O(N^2)
累積和を使う

rep(j, i - (k - 1) + 1) {
  ll tmp = dp[j];
  tmp -= a[j] * (i - (j + 1) + 1);
  tmp += acc[i + 1] - acc[j + 1];
  chmin(dp[i + 1], tmp);
}

ちょっと変形

rep(j, i - (k - 1) + 1) {
  ll tmp = - a[j] * i + dp[j] + a[j] * j - acc[j + 1]; // i に関する一次式
  chmin(dp[i + 1], tmp);
}
dp[i + 1] += acc[i + 1];

ここで, dp[i], dp[i + 1] の計算の違いを考える
i の一次式で, i がずれる
また, rep(j, i - (k - 1) + 1) で, i が1 つ大きくなるので, 直線が 1 つ増える

i を走査すると, 直線が 1 つずつ増え, 最小値を見る x 座標が 1 ずつずれる
尺取りのように, 今後最小値をとる可能性がある直線集合を持つ



int n, k;
vector<ll> a;
vector<ll> dp;
vector<ll> acc;

ll f(int j, int x) { // 直線 f_j の x における値
  return -a[j] * x + dp[j] - acc[j] + a[j] * j;
}

bool check(int j1, int j2, int j3) { // j = j2 の直線 f2 が最小値をとる可能性があるか
  ll a1 = -a[j1], b1 = dp[j1] - acc[j1] + a[j1] * j1; // f の式で, j = j1 を代入
  ll a2 = -a[j2], b2 = dp[j2] - acc[j2] + a[j2] * j2;
  ll a3 = -a[j3], b3 = dp[j3] - acc[j3] + a[j3] * j3;
  return (a2 - a1) * (b3 - b2) >= (b2 - b1) * (a3 - a2);
}

void solve() {
  cin >> n >> k;
  a.resize(n);
  rep(i, n) cin >> a[i];
  acc.resize(n + 1);
  rep(i, n) acc[i + 1] += acc[i] + a[i];

  dp.resize(n + 1);
  dp[0] = 0;

  deque<int> deq;
  rep3(i, k, n + 1) {
    if (i - k >= k) { // 末尾から最小値をとる可能性がない直線を取り除く
      while (1) {
        if (sz(deq) < 2) break;
        int l1 = deq.back();
        deq.pop_back();
        int l2 = deq.back();
        if (check(l2, l1, i - k) == 0) {
          deq.push_back(l1);
          break;
        }
      }
      deq.push_back(i - k);
    }
    while (1) {
      if (sz(deq) < 2) break;
      int s1 = deq.front();
      deq.pop_front();
      int s2 = deq.front();
      if (f(s1, i) < f(s2, i)) { // 先頭が最小値
        deq.push_front(s1);
        break;
      }
    }
    dp[i] = acc[i] + f(deq.front(), i);
  }
  cout << dp << endl;
}

check 関数の説明をする
まず, f1, f2, f3 の順で傾きが負に大きい (-a[j1] のようにマイナスがつく)
f1, f3 の交点の x 座標で, f2 がその y 座標より負に大きければ f2 は最小値に今後なりえない

 a_1 x + b_1 = a_3 x + b_3 \\
x = \frac{b_3 - b_1}{a_3 - a_1}
このとき,
 y = a_1 \frac{b_3 - b_1}{a_3 - a_1} + b_1 \\ \\
f2(x) = a_2 \frac{b_3 - b_1}{a_3 - a_1} + b_2 \geq a_1 \frac{b_3 - b_1}{a_3 - a_1} + b_1 = y
(次の式変形で見やすく)
 (a_2 - a_1) \frac{b_3 - b_1}{a_3 - a_1} \geq (b_1 - b_2)
(添字の 1, 3 は入れ替えてもいい, y の計算で f1 に代入するか f3 に代入するか)

取り除く直線を末尾から見るのは, 末尾に新しい直線が入るから

参考
蟻本 p304