AK-dream

【题目描述】
给定一个长度为 $n$ 的字符串 $S$ ,令 $T_i$ 表示它从第 $i$ 个字符开始的后缀,求:

$\sum\limits_{1\le i<j\le n}\operatorname{len}(T_i)+\operatorname{len}(T_j)-2*\operatorname{lcp}(T_i,T_j)$

其中,$\operatorname{len}(a)$表示字符串$a$的长度,$\operatorname{lcp}(a,b)$表示字符串 $a$ 和字符串 $b$ 的最长公共前缀。

【输入格式】
一行,一个字符串 $S$。

【输出格式】
一行,一个整数,表示所求值。

题解

解法一:后缀数组


$\sum\limits_{1\le i<j\le n}\operatorname{len}(T_i)+\operatorname{len}(T_j)$这个东西就等于$\frac{n(n-1)(n+1)}{2}$,丢一边就行了

求出$height$数组 然后相当于询问$[2,n]$内所有区间的 区间内$height$最小值 之和

用单调栈+DP来求解

设$f(l,r)=\min_{i=l}^{r}height[i]$
设$dp[i]$表示$\sum\limits_{j=1}^{i-1}f(j,i)$

维护这个单调栈来每次找到 第一个$height[p]<height[i]$的$p$的位置

那么转移方程是$dp[i]=dp[p]+(i-p)*height[i]$

因为对于$j\in [1,p]$ 显然$f(j,p)=f(j,i)$; 而对于$j\in [p+1,i-1]$,有$f(j,i)=height[i]$

最后把上面算的那个$\frac{n(n-1)(n+1)}{2}$减去$2*\sum\limits_{i=1}^{n}f[i]$就是答案

时间复杂度$O(n\log n)$

代码


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

ll n;
char s[500005];
ll sa[500005], rnk[500005], sum[500005], sa2[500005], key[500005], height[500005];
ll ans, f[500005];
ll stk[500005], top;

inline bool check(ll *num, ll a, ll b, ll l) {
return num[a] == num[b] && num[a+l] == num[b+l];
}

inline void DA(ll m) {
ll i, j, p;
for (i = 1; i <= m; i++) sum[i] = 0;
for (i = 1; i <= n; i++) sum[rnk[i]=s[i]]++;
for (i = 2; i <= m; i++) sum[i] += sum[i-1];
for (i = n; i >= 1; i--) sa[sum[rnk[i]]--] = i;
for (j = 1, p = 0; j <= n; j <<= 1, m = p) {
p = 0; for (i = n - j + 1; i <= n; i++) sa2[++p] = i;
for (i = 1; i <= n; i++) if (sa[i] > j) sa2[++p] = sa[i] - j;
for (i = 1; i <= n; i++) key[i] = rnk[sa2[i]];
for (i = 1; i <= m; i++) sum[i] = 0;
for (i = 1; i <= n; i++) sum[key[i]]++;
for (i = 2; i <= m; i++) sum[i] += sum[i-1];
for (i = n; i >= 1; i--) sa[sum[key[i]]--] = sa2[i];
for (swap(rnk, sa2), p = 2, rnk[sa[1]] = 1, i = 2; i <= n; i++) {
rnk[sa[i]] = check(sa2, sa[i-1], sa[i], j) ? p - 1 : p++;
}
}
}

inline void geth() {
ll p = 0;
for (ll i = 1; i <= n; i++) rnk[sa[i]] = i;
for (ll i = 1; i <= n; i++) {
if (p) p--;
ll j = sa[rnk[i]-1];
while (s[i + p] == s[j + p]) p++;
height[rnk[i]] = p;
}
}

int main() {
scanf("%s", s+1);
n = strlen(s+1);
DA(128);
geth();
stk[top=1] = 1;
ans = 1ll * n * (n + 1) * (n - 1) / 2;
for (ll i = 2; i <= n; i++) {
while (top && height[stk[top]] > height[i]) top--;
f[i] = f[stk[top]] + 1ll * (i - stk[top]) * height[i];
ans -= f[i] * 2;
stk[++top] = i;
}
printf("%lld\n", ans);
return 0;
}

解法二:后缀自动机


要求的是每个后缀两两之间的最长公共前缀 那么我们把原串翻转一下就变成求每个前缀两两之间的最长公共后缀 这个显然可以用SAM解决

两个前缀的最长公共后缀就是那两个前缀代表的节点在parent树上的LCA节点的len

注意到这个式子$\operatorname{len}(T_i)+\operatorname{len}(T_j)-2*\operatorname{lcp}(T_i,T_j)$ 是不是像求树上最短路径的式子?

实际上 我们把parent树上每条从$x$到$fa[x]$的边长度设为$len[x]-len[fa[x]]$

那么这个式子就是两个节点树上的距离

我们把$n$个代表着前缀的节点称为“关键节点” 那么答案就是关键节点之间两两距离的总和

考虑每条边被经过了多少次 设$cnt[x]$代表$x$为根的子树中有多少个关键节点 因为关键节点一共$n$个 所以$x->fa[x]$这条边被经过$cnt[x]*(n-cnt[x])$次

答案就是$\sum\limits_{x\in V} cnt[x]*(n-cnt[x])*(len[x]-len[fa[x]])$

吐槽一下 压根就不用翻转原串也能AC 谁来证明一下每个后缀两两之间的最长公共前缀等于每个前缀两两之间的最长公共后缀。。。

时间复杂度$O(n*26)$

代码


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#include <bits/stdc++.h>
using namespace std;

char s[1000005];
int n;
long long ans;

struct SAM{
struct node{
int len, link, ch[26], cnt;
} tr[1000010];
int tot, lst;

inline void extend(int c) {
int ind = ++tot;
tr[ind].len = tr[lst].len + 1;
tr[ind].cnt = 1;
int p = lst;
while (p && !tr[p].ch[c]) {
tr[p].ch[c] = ind;
p = tr[p].link;
}
if (!p) {
tr[ind].link = 1;
} else {
int q = tr[p].ch[c];
if (tr[q].len == tr[p].len + 1) {
tr[ind].link = q;
} else {
int clone = ++tot;
tr[clone].link = tr[q].link;
for (int j = 0; j < 26; j++) tr[clone].ch[j] = tr[q].ch[j];
tr[clone].len = tr[p].len + 1;
while (p && tr[p].ch[c] == q) {
tr[p].ch[c] = clone;
p = tr[p].link;
}
tr[ind].link = tr[q].link = clone;
}
}
lst = ind;
}

int b[1000010], cc[1000010];

inline void calc() {
for (int i = 1; i <= tot; i++) cc[tr[i].len]++;
for (int i = 1; i <= tot; i++) cc[i] += cc[i-1];
for (int i = 1; i <= tot; i++) b[cc[tr[i].len]--] = i;
for (int i = tot; i >= 1; i--) {
int x = b[i];
tr[tr[x].link].cnt += tr[x].cnt;
ans += 1ll * (tr[x].len - tr[tr[x].link].len) * tr[x].cnt * (n - tr[x].cnt);
}
}
}T;

int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
reverse(s + 1, s + n + 1);
T.tot = T.lst = 1;
for (int i = 1; i <= n; i++) T.extend(s[i] - 'a');
T.calc();
printf("%lld\n", ans);
return 0;
}


 评论