算法学习专栏——线段树——线段树的多类操作(有问题!!!debug中.......)

【模板】线段树 2

原题

题目描述

如题,已知一个数列,你需要进行下面三种操作:

  • 将某区间每一个数乘上 $x$;
  • 将某区间每一个数加上 $x$;
  • 求出某区间每一个数的和。

输入格式

第一行包含三个整数 $n,q,m$,分别表示该数列数字的个数、操作的总个数和模数。

第二行包含 $n$ 个用空格分隔的整数,其中第 $i$ 个数字表示数列第 $i$ 项的初始值。

接下来 $q$ 行每行包含若干个整数,表示一个操作,具体如下:

操作 $1$: 格式:1 x y k 含义:将区间 $[x,y]$ 内每个数乘上 $k$

操作 $2$: 格式:2 x y k 含义:将区间 $[x,y]$ 内每个数加上 $k$

操作 $3$: 格式:3 x y 含义:输出区间 $[x,y]$ 内每个数的和对 $m$ 取模所得的结果

输出格式

输出包含若干行整数,即为所有操作 $3$ 的结果。

样例 #1

样例输入 #1

1
2
3
4
5
6
7
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4

样例输出 #1

1
2
17
2

提示

【数据范围】

对于 $30%$ 的数据:$n \le 8$,$q \le 10$。
对于 $70%$ 的数据:$n \le 10^3 $,$q \le 10^4$。
对于 $100%$ 的数据:$1 \le n \le 10^5$,$1 \le q \le 10^5$。

除样例外,$m = 571373$。

(数据已经过加强 ^_^)

样例说明:

故输出应为 $17$、$2$($40 \bmod 38 = 2$)。

思路

​ 持续更新….

代码(待debug)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#include <iostream>
#include <cstring>
#include <algorithm>
#include <queue>
#include <map>
#include <unordered_map>
#include <string>
#include <cmath>
#include <set>
#include <stack>
#include <vector>
#include <deque>
#include <bitset>
#include <cstdio>
#include <iomanip>

// #define int long long
#define ull unsigned long long
#define ed '\n'
#define fi first
#define se second
#define fore(i, l, r) for(int i = (int)(l); i <=(int)r; i++)
#define debug(x) cout << "#x = " << x << ed;
#define PI acos(-1)
#define E exp(1)
#define IOS ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
#define me0(st) memset(st, 0, sizeof st)
#define me3f(st) memset(st, 0x3f, sizeof st)
#define pdf(x) printf("%lf", double(x))
#define pif(x) printf("%d ", int(x))
#define scf(x) printf("%d", int(x))
#define re return 0
#define max(a, b) ((a) > (b) ? (a) : (b))
#define min(a, b) ((a) < (b) ? (a) : (b))
#define out(x, k) cout << fixed << setprecision(k) << x << ed

using namespace std;

typedef pair<int, int> PII;
typedef long long LL;
typedef double db;


const int INF = 1e9;
const int N = 1e5 + 10;

struct Node {
int l, r;
LL sum, add, mul;
}tr[4 * N];
int n, q, m;
LL w[N];

void pushdown(int u) {
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];

// add类型
if (root.add)
left.add += root.add, left.sum += LL(root.add * (left.r - left.l + 1) % m);
right.add += root.add, right.sum += LL(root.add * (right.r - right.l + 1) % m);
代码
// mul类型
left.mul *= root.mul, left.sum = LL(left.sum * root.mul + root.add * left.sum % m);
right.mul *= root.mul, right.sum = LL(right.sum * root.mul + root.add * right.sum % m);

root.mul = 1;
root.add = 0;
}

void pushup(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % m;
}

void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, w[r], 0, 1};
else {
tr[u] = {l, r, 0, 0, 1};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}

// 和
void modify1(int u,int l, int r, int d) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum = (LL)((tr[u].sum + (tr[u].r - tr[u].l + 1) * d) % m);
tr[u].add += d;
} else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify1(u << 1, l, r, d);
if (r > mid) modify1(u << 1 | 1, l, r, d);
pushup(u);
}
}

// 积
void modify2(int u, int l, int r, int d) {
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].sum = (tr[u].sum * d) % m;
tr[u].add = (tr[u].add * d) % m;
tr[u].mul =(tr[u].mul * d) % m;
} else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify2(u << 1, l, r, d);
if (r > mid) modify2(u << 1 | 1, l, r, d);
pushup(u);
}
}

LL query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) {
return tr[u].sum;
}
else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
LL sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
}

void solve() {
scanf("%d %d %d", &n, &q, &m);
for (int i = 1; i <= n; i++) scanf("%lld", &w[i]);
build(1, 1, n);

int l, r, d;
char op[2];
while (q--) {
scanf("%s %d %d", op, &l, &r);
if (*op == '1') {
scanf("%d",&d);
modify1(1, l, r, d);
} else if (*op == '2') {
scanf("%d",&d);
modify2(1, l, r, d);
} else {
printf("%lld\n", query(1, l, r));
}
}
}
int main()
{
IOS;
int _ = 1;
// cin >> _;

while(_--) {
solve();
}

re;
}