1use nix::sys::stat::Mode;
2use std::io;
3
4pub use nix::libc::mode_t;
5
6static NIX_PERMISSIONS: &[Mode] = &[
7 Mode::S_IRUSR,
8 Mode::S_IWUSR,
9 Mode::S_IXUSR,
10 Mode::S_IRGRP,
11 Mode::S_IWGRP,
12 Mode::S_IXGRP,
13 Mode::S_IROTH,
14 Mode::S_IWOTH,
15 Mode::S_IXOTH,
16];
17
18fn get_class(str: &str) -> Result<mode_t, io::Error> {
19 if str.is_empty() {
20 Ok(0b111111111)
21 } else {
22 let next = str.chars().find(|x| !is_user_access_token(*x));
23 if next.is_some() {
24 let msg =
25 "symbolic mode string before the '+' can only contain u, g, o, or a.".to_string();
26 Err(io::Error::other(msg))
27 } else {
28 let mut class: mode_t = 0;
29 for c in str.chars() {
30 class |= match c {
31 'u' => 0b111000000,
32 'g' => 0b000111000,
33 'o' => 0b000000111,
34 'a' => 0b111111111,
35 c if c.is_whitespace() => 0b111111111,
36 _ => 0,
37 }
38 }
39 Ok(class)
40 }
41 }
42}
43
44fn get_perms(str: &str) -> Result<mode_t, io::Error> {
45 if str.is_empty() {
46 Ok(0b111111111)
47 } else {
48 let next = str.chars().find(|x| !is_permission_token(*x));
49 if next.is_some() {
50 let msg =
51 "symbolic mode string before the '+' can only contain r, w, or x.".to_string();
52 Err(io::Error::other(msg))
53 } else {
54 let mut class: mode_t = 0;
55 for c in str.chars() {
56 class |= match c {
57 'r' => 0b100100100,
58 'w' => 0b010010010,
59 'x' => 0b001001001,
60 _ => 0,
61 }
62 }
63 Ok(class)
64 }
65 }
66}
67
68fn decode_symbolic_mode_string(str: &str, split_char: char) -> Result<(mode_t, mode_t), io::Error> {
69 let mode_strings = str.split(split_char).collect::<Vec<&str>>();
70 if mode_strings.len() == 2 {
71 if let (Some(c), Some(p)) = (mode_strings.first(), mode_strings.get(1)) {
72 if c.is_empty() && p.is_empty() {
73 let msg = format!(
74 "symbolic mode string must have a valid character before and/or \
75 after the '{}' character.",
76 split_char,
77 );
78 Err(io::Error::other(msg))
79 } else {
80 let class = get_class(c)?;
81 let perms = get_perms(p)?;
82 Ok((class, perms))
83 }
84 } else {
85 let msg = format!(
86 "symbolic mode string contains too many '{}' characters.",
87 split_char,
88 );
89 Err(io::Error::other(msg))
90 }
91 } else {
92 let msg = format!(
93 "symbolic mode string contains too many '{}' characters.",
94 split_char,
95 );
96 Err(io::Error::other(msg))
97 }
98}
99
100enum PermissionOperator {
101 Plus,
102 Minus,
103 Equal,
104}
105
106struct MaskType {
107 class: mode_t,
108 perms: mode_t,
109 mask_type: PermissionOperator,
110}
111
112impl MaskType {
113 #[allow(clippy::unnecessary_cast)]
114 fn combine(&self, mode: mode_t) -> mode_t {
115 let m = match &self.mask_type {
116 PermissionOperator::Plus => !(self.class & self.perms) & mode,
117 PermissionOperator::Minus => mode | (self.class & self.perms),
118 PermissionOperator::Equal => {
119 ((self.class & self.perms) ^ 0o777) & ((mode & !self.class) ^ self.class)
120 }
121 };
122 to_mode(m).bits() as mode_t
123 }
124}
125
126fn to_mask_type(str: &str) -> Result<MaskType, io::Error> {
127 let decode = |split_char| -> Result<(mode_t, mode_t), io::Error> {
128 decode_symbolic_mode_string(str, split_char)
129 };
130 if str.contains('+') {
131 let (class, perms) = decode('+')?;
132 Ok(MaskType {
133 class,
134 perms,
135 mask_type: PermissionOperator::Plus,
136 })
137 } else if str.contains('-') {
138 let (class, perms) = decode('-')?;
139 Ok(MaskType {
140 class,
141 perms,
142 mask_type: PermissionOperator::Minus,
143 })
144 } else if str.contains('=') {
145 let (class, perms) = decode('=')?;
146 Ok(MaskType {
147 class,
148 perms,
149 mask_type: PermissionOperator::Equal,
150 })
151 } else {
152 let msg = "symbolic mode string must contain one of '+', '-', or '='.".to_string();
153 Err(io::Error::other(msg))
154 }
155}
156
157fn get_umask_tokens(str: &str) -> Result<Vec<MaskType>, io::Error> {
158 match str.lines().count() {
159 1 => {
160 let mut masks = vec![];
161 for x in str.split(',') {
162 let mask_type = to_mask_type(x)?;
163 masks.push(mask_type);
164 }
165 Ok(masks)
166 }
167 _ => {
168 let msg = "must supply only one line as input.".to_string();
169 Err(io::Error::other(msg))
170 }
171 }
172}
173
174fn with_umask_tokens(mut umask: mode_t, masks: Vec<MaskType>) -> mode_t {
175 for x in masks {
176 umask = x.combine(umask)
177 }
178 umask
179}
180
/// Validates an octal umask string and left-pads it with `0` to exactly four
/// characters (e.g. `"22"` becomes `"0022"`).
///
/// Lengths are measured in bytes (`str::len`). This function does not check
/// that the characters are octal digits.
///
/// # Errors
///
/// Returns an error if the string is empty, is longer than four bytes, or is
/// four bytes long without a leading `0`.
fn make_parsable_octal_string(str: &str) -> Result<String, io::Error> {
    if str.is_empty() {
        let msg = "no input.".to_string();
        return Err(io::Error::other(msg));
    }
    if str.len() > 4 {
        // The space before the line-continuation `\` is kept. Whitespace at the
        // start of the next line is stripped.
        let msg = "no more than 4 characters can be used to specify a umask, e.g. \
                   644 or 0222."
            .to_string();
        return Err(io::Error::other(msg));
    }
    if str.len() == 4 && !str.starts_with('0') {
        let msg = "most significant octal character can only be 0.".to_string();
        return Err(io::Error::other(msg));
    }
    // `str.len() <= 4` here, so the subtraction cannot underflow.
    Ok("0".repeat(4 - str.len()) + str)
}
202
203fn build_mask(to_shift: usize, c: char) -> Result<mode_t, io::Error> {
204 let apply = |m| Ok((m << (to_shift * 3)) as mode_t);
205 match c {
206 '0' => apply(0b000),
207 '1' => apply(0b001),
208 '2' => apply(0b010),
209 '3' => apply(0b011),
210 '4' => apply(0b100),
211 '5' => apply(0b101),
212 '6' => apply(0b110),
213 '7' => apply(0b111),
214 _ => {
215 let msg = "octal format can only take on values between 0 and 7 inclusive.".to_string();
216 Err(io::Error::other(msg))
217 }
218 }
219}
220
221fn octal_string_to_mode_t(str: &str) -> Result<mode_t, io::Error> {
222 let mut val = 0;
223 let mut err = false;
224 for (usize, c) in str.chars().rev().enumerate() {
225 match usize {
226 0..=2 => val |= build_mask(usize, c)?,
227 3 => {}
228 _ => {
229 err = true;
230 break;
231 }
232 }
233 }
234 if err {
235 let msg = "failed to parse provided octal.".to_string();
236 Err(io::Error::other(msg))
237 } else {
238 Ok(val)
239 }
240}
241
242#[allow(clippy::unnecessary_cast)]
243fn to_mode(i: mode_t) -> Mode {
244 NIX_PERMISSIONS.iter().fold(Mode::empty(), |acc, x| {
245 if (x.bits() as mode_t & i) == x.bits() as mode_t {
246 acc | *x
247 } else {
248 acc
249 }
250 })
251}
252
253fn octal_string_to_mode(str: &str) -> Result<Mode, io::Error> {
254 let str = make_parsable_octal_string(str)?;
255 let val = octal_string_to_mode_t(&str)?;
256 Ok(to_mode(val))
257}
258
/// Returns `true` if `ch` is a symbolic permission letter: `r`, `w` or `x`.
fn is_permission_token(ch: char) -> bool {
    "rwx".contains(ch)
}
262
/// Returns `true` if `ch` is a symbolic class letter: `u`, `g`, `o` or `a`.
fn is_user_access_token(ch: char) -> bool {
    "ugoa".contains(ch)
}
266
/// Returns `true` if `ch` is an ASCII decimal digit (`0` through `9`).
fn is_digit(ch: char) -> bool {
    ch.is_ascii_digit()
}
273
274pub fn merge_and_set_umask(current_umask: mode_t, mask_string: &str) -> Result<mode_t, io::Error> {
277 if mask_string.parse::<u32>().is_ok() {
278 let mode = octal_string_to_mode(mask_string)?;
279 nix::sys::stat::umask(mode);
280 Ok(mode.bits())
281 } else if !mask_string.is_empty() {
282 #[allow(clippy::unnecessary_cast)]
283 let mode = if is_digit(mask_string.chars().next().unwrap()) {
284 octal_string_to_mode(mask_string)?.bits() as mode_t
285 } else {
286 let masks = get_umask_tokens(mask_string)?;
287 with_umask_tokens(current_umask, masks)
288 };
289 if let Some(umask) = Mode::from_bits(mode) {
290 nix::sys::stat::umask(umask);
291 Ok(mode)
292 } else {
293 Err(io::Error::other("invalid umask".to_string()))
294 }
295 } else {
296 let msg = "no input.".to_string();
297 Err(io::Error::other(msg))
298 }
299}
300
301#[allow(clippy::unnecessary_cast)]
303pub fn get_and_clear_umask() -> mode_t {
304 nix::sys::stat::umask(Mode::empty()).bits() as mode_t
305}
306
307#[allow(clippy::unnecessary_cast)]
309pub fn set_umask(umask: mode_t) -> Result<(), io::Error> {
310 if let Some(umask) = Mode::from_bits(umask) {
311 nix::sys::stat::umask(umask);
312 Ok(())
313 } else {
314 Err(io::Error::other("invalid mode"))
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_perms() {
        let cases: &[(&str, mode_t)] = &[
            ("rwx", 0b111111111),
            ("r", 0b100100100),
            ("w", 0b010010010),
            ("x", 0b001001001),
            ("rw", 0b110110110),
        ];
        for &(perms_str, expected) in cases {
            assert_eq!(expected, get_perms(perms_str).unwrap(), "get_perms({perms_str:?})");
        }
    }

    #[test]
    fn test_parse_class() {
        let cases: &[(&str, mode_t)] = &[
            ("ugo", 0b111111111),
            ("a", 0b111111111),
            ("u", 0b111000000),
            ("g", 0b000111000),
            ("o", 0b000000111),
            ("uo", 0b111000111),
        ];
        for &(class_str, expected) in cases {
            assert_eq!(expected, get_class(class_str).unwrap(), "get_class({class_str:?})");
        }
    }

    #[test]
    #[allow(clippy::unnecessary_cast)] // `Mode::bits()` may already be `u32` on this target
    fn test_umask_octal() {
        let bits = 0b001001001;
        assert_eq!(bits, to_mode(bits).bits());

        let valid: &[(&str, u32)] = &[
            ("0522", 338),
            ("522", 338),
            ("713", 0b111001011),
            ("466", 0b100110110),
            ("0", 0b000000000),
            ("45", 0b000100101),
        ];
        for &(input, expected) in valid {
            let mode = octal_string_to_mode(input).unwrap();
            assert_eq!(expected, mode.bits() as u32, "octal_string_to_mode({input:?})");
        }

        let invalid: &[&str] = &["a+n", "1111", "11111", "0S11"];
        for &input in invalid {
            assert!(octal_string_to_mode(input).is_err(), "expected error for {input:?}");
        }
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    fn test_umask_parser() {
        let umask = to_mode(0o022).bits() as mode_t;

        let cases: &[(&str, mode_t)] = &[
            ("go+rx", 0o022),
            ("+w", 0o0),
            ("a-rw", 0o666),
            ("g-rw", 0o062),
            ("ug=rw", 0o0112),
            ("a=r,ug=rw", 0o0113),
            ("a=r,g+w", 0o0313),
            ("a=r,a-r", 0o0777),
            ("ugo+x", 0o0022),
            ("+x", 0o0022),
            ("a+rw", 0o0),
            ("a=r", 0o333),
            ("ug+rwx", 0o002),
            ("ug+", 0o002),
            ("o-rwx", 0o0027),
            ("u-x,g=r,o+w", 0o0130),
        ];
        for &(input, expected) in cases {
            let tokens = get_umask_tokens(input).unwrap();
            assert_eq!(expected, with_umask_tokens(umask, tokens), "umask after {input:?}");
        }

        let invalid: &[&str] = &[
            "glo+rx", "go+nrx", "+n", "a+n", "ar", "+a+r", "a++r", "+ar+", "",
        ];
        for &input in invalid {
            assert!(get_umask_tokens(input).is_err(), "expected error for {input:?}");
        }
    }
}
476}