AK-dream

对于字符串$x$和$y$,假设字符串$y$在字符串$x$中若干的匹配位置,我们用$(l_i,r_i)$来表示。二元组代表$x$中$x_{l_i}\sim x_{r_i}$的字符串和$y_1\sim y_{len(y)}$ 完全一致。这些二元组根据第一关键字从小到大排序后形成一个序列,定义一个函数$F(x,y)$的值为该序列的非空连续序列的数量。以$F(babbabbababbab, babb) = 6$为例子,匹配的二元组序列为:
$(1, 4), (4, 7), (9, 12)$

非空连续序列为
$(1, 4)$
$(4, 7)$
$(9, 12)$
$(1, 4), (4, 7)$
$(4, 7), (9, 12)$
$(1, 4), (4, 7), (9, 12)$

现在给你一个字符串$s$,请求出$F(s,x)$的和,其中$x$为$s$的全部子串。

【输入格式】
输入一个字符串$s$。

【输出格式】
输出题目要求的$F$值。

题解

模板更模板

建好后缀自动机 我们把代表原串前缀的那几个节点称作关键节点(即建自动机的时候每次添加一个新字符得到的那个新节点)

那么对于一个子串$t$,我们把$t$在后缀自动机中所在的节点称作$x$,在parent树中,$x$为根的子树中的 关键节点的数量 就是$t$在原串中出现的次数 把它记为$cnt[x]$

这个从每个关键节点开始暴力向上跳father 沿路数量++ 就能统计出来

然后这个“非空连续序列数量”其实就是$\frac{a(a+1)}{2}$

然后自动机中每个节点$x$含有的原串的子串数量就是$len[x]-len[fa[x]]$

那每个节点的答案就是$cnt[x]*(cnt[x]+1)/2*(len[x]-len[fa[x]])$ 把所有节点的答案加起来就可以了

时间复杂度$O(n*26)$

代码


1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

ll n, key[100005];
char s[100005];
ll ans;

struct SAM{
struct node{
ll nxt[30], link, len, cnt;
} tr[2000010];
ll tot, lst;

inline void init() {
tot = lst = 0;
tr[0].link = -1; tr[0].len = 0;
}

inline void build(char *str, ll len) {
init();
for (ll i = 1; i <= len; i++) {
ll ind = ++tot, c = str[i] - 'a';
tr[ind].len = tr[lst].len + 1;
ll p = lst;
while (p != -1 && !tr[p].nxt[c]) {
tr[p].nxt[c] = ind;
p = tr[p].link;
}
if (p == -1) {
tr[ind].link = 0;
} else {
ll q = tr[p].nxt[c];
if (tr[p].len + 1 == tr[q].len) {
tr[ind].link = q;
} else {
ll clone = ++tot;
tr[clone].len = tr[p].len + 1;
tr[clone].link = tr[q].link;
for (ll j = 0; j < 26; j++) tr[clone].nxt[j] = tr[q].nxt[j];
while (p != -1 && tr[p].nxt[c] == q) {
tr[p].nxt[c] = clone;
p = tr[p].link;
}
tr[q].link = tr[ind].link = clone;
}
}
lst = ind;
key[i] = ind; //ind是一个关键节点
}
}

inline void getans() {
for (int i = 1; i <= n; i++) {
int p = key[i];
while (p != -1) {
tr[p].cnt++;
p = tr[p].link;
}
}
for (ll i = 1; i <= tot; i++) {
ll x = tr[i].len - tr[tr[i].link].len, y = tr[i].cnt;
ans += y * (y + 1) / 2 * x;
}
}
}T;


int main() {
scanf("%s", s + 1);
n = strlen(s + 1);
T.build(s, n);
T.getans();
printf("%lld\n", ans);
return 0;
}


 评论