题目描述

如果一个字符串可以被拆分为$\text{AABB}$的形式,其中$\text{A}$和$\text{B}$是任意非空字符串,则我们称该字符串的这种拆分是优秀的。
例如,对于字符串$\texttt{aabaabaa}$,如果令$\text{A}=\texttt{aab}$,$\text{B}=\texttt{a}$,我们就找到了这个字符串拆分成$\text{AABB}$的一种方式。

一个字符串可能没有优秀的拆分,也可能存在不止一种优秀的拆分。
比如我们令$\text{A}=\texttt{a}$,$\text{B}=\texttt{baa}$,也可以用$\text{AABB}$表示出上述字符串;但是,字符串$\texttt{abaabaa}$就没有优秀的拆分。

现在给出一个长度为$n$的字符串$S$,我们需要求出,在它所有子串的所有拆分方式中,优秀拆分的总个数。这里的子串是指字符串中连续的一段。

以下事项需要注意:

  1. 出现在不同位置的相同子串,我们认为是不同的子串,它们的优秀拆分均会被记入答案。
  2. 在一个拆分中,允许出现$\text{A}=\text{B}$。例如$\texttt{cccc}$存在拆分$\text{A}=\text{B}=\texttt{c}$。
  3. 字符串本身也是它的一个子串。

输入格式

每个输入文件包含多组数据。
输入文件的第一行只有一个整数$T$,表示数据的组数。
接下来$T$行,每行包含一个仅由英文小写字母构成的字符串$S$,意义如题所述。

输出格式

输出$T$行,每行包含一个整数,表示字符串$S$所有子串的所有拆分中,总共有多少个是优秀的拆分。

$n\le 30000$

题解

太良心了
$85%$的点$n\le 500$,直接$O(n^3)$暴力枚举区间+断点用哈希判断

然后只要稍微动动脑子:设$a[i]$表示以$i$结尾的$\text{AA}$串个数,$b[i]$表示以$i$开头的$\text{AA}$串个数,那么答案就是$\sum\limits_{i=1}^{n-1}
a[i]*b[i+1]$

$O(n^2)$95分到手

虽然正解和字符串哈希无关,但是我觉得还是应该给这道鉴题加一个字符串哈希的tag

最后五分如果想不出来不拿也感觉无所谓。。。最后五分确实不好想

所以开始说正解:

上面的95分解法问题就在于$a[N],\ b[N]$,我们需要$O(n^2)$的时间求出来,考虑怎么样求得更快

我们枚举一个$len$表示我们现在想找到那些长度为$2*len$的$\text{AA}$串

然后在原串上每隔$len$放一个断点

我们枚举相邻的两个断点$i,j$,现在我们想要知道 以$i$开头的后缀与以$j$开头的后缀的最长公共前缀(LCP) 和 以$i$结尾的前缀与以$j$结尾的前缀的最长公共后缀(LCS)

LCP可以用后缀数组求;LCS也可以,把原数组翻转之后就变成后缀的LCP了,所以这两个都是可以用ST表$O(1)$求出的

那么现在我们求出了这两个值

情况1

对于这种情况,即$LCP+LCS-1<len$,我们是找不出$\text{AA}$串的

情况2

用脚画图 不愧是我

$LCP+LCS-1<len$,这个时候就有很多的长为$2*len$的$\text{AA}$串了,图中画出的$\text{AA},\ \text{BB}$就是最靠左和最靠右的两个这样的串

实际上,我画了”OK”的那个橙色区间的每一个点都是一个长为$2*len$的$\text{AA}$串的开头

如何找哪一段是合法$\text{AA}$串的结尾也同理

所以实际上每次就是把$a[N]$和$b[N]$的某一段全部加一 用差分来维护一下就行了

最后来看一下时间复杂度

后缀数组+ST表是$O(n\log n)$

$\frac{n}{1}+\frac{n}{2}+\frac{n}{3}+\dots+\frac{n}{n}$我记得差不多就是$O(n \log n)$吧。。。可能要稍微大一点

总之$n\le 30000$的数据是完全没有压力的

注意多组数据初始化数组!注意多组数据初始化数组!注意多组数据初始化数组!

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
#include <bits/stdc++.h>
#define N 60005
using namespace std;

int t, n, nn;
char s[N];
int a[N], b[N];
int sa[N], sa2[N], rnk[N], sum[N], key[N], height[N], ST[N][21];

inline bool check(int *num, int aa, int bb, int l) {
if (aa + l > n || bb + l > n) return false; //多组数据,一定要加!
return num[aa] == num[bb] && num[aa+l] == num[bb+l];
}

void DA() {
int i, j, p, m = 128;
for (i = 1; i <= m; i++) sum[i] = 0;
for (i = 1; i <= n; i++) sum[rnk[i]=s[i]]++;
for (i = 2; i <= m; i++) sum[i] += sum[i-1];
for (i = n; i; i--) sa[sum[rnk[i]]--] = i;
for (j = 1; j <= n; j <<= 1, m = p) {
for (p = 0, i = n - j + 1; i <= n; i++) sa2[++p] = i;
for (i = 1; i <= n; i++) if (sa[i] - j > 0) sa2[++p] = sa[i] - j;
for (i = 1; i <= n; i++) key[i] = rnk[sa2[i]];
for (i = 1; i <= m; i++) sum[i] = 0;
for (i = 1; i <= n; i++) sum[key[i]]++;
for (i = 2; i <= m; i++) sum[i] += sum[i-1];
for (i = n; i; i--) sa[sum[key[i]]--] = sa2[i];
for (swap(sa2, rnk), p = 2, rnk[sa[1]] = 1, i = 2; i <= n; i++) {
rnk[sa[i]] = check(sa2, sa[i-1], sa[i], j) ? p - 1 : p++;
}
}
}

void geth() {
int p = 0;
for (int i = 1; i <= n; i++) rnk[sa[i]] = i;
for (int i = 1; i <= n; i++) {
if (p) p--;
int j = sa[rnk[i]-1];
while (s[i+p] == s[j+p] && i + p <= n && j + p <= n) p++; //多组数据,一定要加!
height[rnk[i]] = p;
}
}

void preST() {
for (int i = 1; i <= n; i++) ST[i][0] = height[i];
for (int l = 1; (1 << l) <= n; l++) {
for (int i = 1; i + (1<<l) - 1 <= n; i++) {
ST[i][l] = min(ST[i][l-1], ST[i+(1<<(l-1))][l-1]);
}
}
}

inline int QST(int x, int y) {
if (x > y) swap(x, y); x++;
int l = log2(y - x + 1);
return min(ST[x][l], ST[y-(1<<l)+1][l]);
}

inline int LCP(int x, int y) { return QST(rnk[x], rnk[y]); }
inline int LCS(int x, int y) { return QST(rnk[n-x+1], rnk[n-y+1]); }

void Solve() {
for (int l = 1; l * 2 <= nn; l++) {
for (int i = 1, j = i + 1; j * l <= nn; i++, j++) {
int lcp = min(LCP(i*l, j*l), l), lcs = min(LCS(i*l, j*l), l);
if (lcp + lcs - 1 >= l) {
a[j*l+l-lcs]++; a[j*l+lcp]--;
b[i*l-lcs+1]++; b[i*l-l+lcp+1]--;
}
}
}
for (int i = 1; i <= nn; i++) {
a[i] += a[i-1];
b[i] += b[i-1];
}
long long ans = 0;
for (int i = 1; i < nn; i++) {
ans += 1ll * a[i] * b[i+1];
}
printf("%lld\n", ans);
}

int main() {
scanf("%d", &t);
while (t--) {
memset(a, 0, sizeof(a)); memset(b, 0, sizeof(b));
scanf("%s", s + 1);
n = strlen(s + 1);
s[n+1] = '$';
for (int i = n + 2; i <= 2 * n + 1; i++) {
s[i] = s[2 * n - i + 2];
}
nn = n;
n = n * 2 + 1;
DA(); geth(); preST();
Solve();
}
return 0;
}

评论