shell/platform/unix/
umask.rs

1use nix::sys::stat::Mode;
2use std::io;
3
4pub use nix::libc::mode_t;
5
6static NIX_PERMISSIONS: &[Mode] = &[
7    Mode::S_IRUSR,
8    Mode::S_IWUSR,
9    Mode::S_IXUSR,
10    Mode::S_IRGRP,
11    Mode::S_IWGRP,
12    Mode::S_IXGRP,
13    Mode::S_IROTH,
14    Mode::S_IWOTH,
15    Mode::S_IXOTH,
16];
17
18fn get_class(str: &str) -> Result<mode_t, io::Error> {
19    if str.is_empty() {
20        Ok(0b111111111)
21    } else {
22        let next = str.chars().find(|x| !is_user_access_token(*x));
23        if next.is_some() {
24            let msg =
25                "symbolic mode string before the '+' can only contain u, g, o, or a.".to_string();
26            Err(io::Error::other(msg))
27        } else {
28            let mut class: mode_t = 0;
29            for c in str.chars() {
30                class |= match c {
31                    'u' => 0b111000000,
32                    'g' => 0b000111000,
33                    'o' => 0b000000111,
34                    'a' => 0b111111111,
35                    c if c.is_whitespace() => 0b111111111,
36                    _ => 0,
37                }
38            }
39            Ok(class)
40        }
41    }
42}
43
44fn get_perms(str: &str) -> Result<mode_t, io::Error> {
45    if str.is_empty() {
46        Ok(0b111111111)
47    } else {
48        let next = str.chars().find(|x| !is_permission_token(*x));
49        if next.is_some() {
50            let msg =
51                "symbolic mode string before the '+' can only contain r, w, or x.".to_string();
52            Err(io::Error::other(msg))
53        } else {
54            let mut class: mode_t = 0;
55            for c in str.chars() {
56                class |= match c {
57                    'r' => 0b100100100,
58                    'w' => 0b010010010,
59                    'x' => 0b001001001,
60                    _ => 0,
61                }
62            }
63            Ok(class)
64        }
65    }
66}
67
68fn decode_symbolic_mode_string(str: &str, split_char: char) -> Result<(mode_t, mode_t), io::Error> {
69    let mode_strings = str.split(split_char).collect::<Vec<&str>>();
70    if mode_strings.len() == 2 {
71        if let (Some(c), Some(p)) = (mode_strings.first(), mode_strings.get(1)) {
72            if c.is_empty() && p.is_empty() {
73                let msg = format!(
74                    "symbolic mode string must have a valid character before and/or \
75                        after the '{}' character.",
76                    split_char,
77                );
78                Err(io::Error::other(msg))
79            } else {
80                let class = get_class(c)?;
81                let perms = get_perms(p)?;
82                Ok((class, perms))
83            }
84        } else {
85            let msg = format!(
86                "symbolic mode string contains too many '{}' characters.",
87                split_char,
88            );
89            Err(io::Error::other(msg))
90        }
91    } else {
92        let msg = format!(
93            "symbolic mode string contains too many '{}' characters.",
94            split_char,
95        );
96        Err(io::Error::other(msg))
97    }
98}
99
/// The operator in a symbolic umask clause: the character that separates
/// the "who" letters from the permission letters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum PermissionOperator {
    /// `+`: grant the listed permissions (clear them from the umask).
    Plus,
    /// `-`: revoke the listed permissions (set them in the umask).
    Minus,
    /// `=`: grant exactly the listed permissions to the selected classes.
    Equal,
}
105
106struct MaskType {
107    class: mode_t,
108    perms: mode_t,
109    mask_type: PermissionOperator,
110}
111
112impl MaskType {
113    #[allow(clippy::unnecessary_cast)]
114    fn combine(&self, mode: mode_t) -> mode_t {
115        let m = match &self.mask_type {
116            PermissionOperator::Plus => !(self.class & self.perms) & mode,
117            PermissionOperator::Minus => mode | (self.class & self.perms),
118            PermissionOperator::Equal => {
119                ((self.class & self.perms) ^ 0o777) & ((mode & !self.class) ^ self.class)
120            }
121        };
122        to_mode(m).bits() as mode_t
123    }
124}
125
126fn to_mask_type(str: &str) -> Result<MaskType, io::Error> {
127    let decode = |split_char| -> Result<(mode_t, mode_t), io::Error> {
128        decode_symbolic_mode_string(str, split_char)
129    };
130    if str.contains('+') {
131        let (class, perms) = decode('+')?;
132        Ok(MaskType {
133            class,
134            perms,
135            mask_type: PermissionOperator::Plus,
136        })
137    } else if str.contains('-') {
138        let (class, perms) = decode('-')?;
139        Ok(MaskType {
140            class,
141            perms,
142            mask_type: PermissionOperator::Minus,
143        })
144    } else if str.contains('=') {
145        let (class, perms) = decode('=')?;
146        Ok(MaskType {
147            class,
148            perms,
149            mask_type: PermissionOperator::Equal,
150        })
151    } else {
152        let msg = "symbolic mode string must contain one of '+', '-', or '='.".to_string();
153        Err(io::Error::other(msg))
154    }
155}
156
157fn get_umask_tokens(str: &str) -> Result<Vec<MaskType>, io::Error> {
158    match str.lines().count() {
159        1 => {
160            let mut masks = vec![];
161            for x in str.split(',') {
162                let mask_type = to_mask_type(x)?;
163                masks.push(mask_type);
164            }
165            Ok(masks)
166        }
167        _ => {
168            let msg = "must supply only one line as input.".to_string();
169            Err(io::Error::other(msg))
170        }
171    }
172}
173
174fn with_umask_tokens(mut umask: mode_t, masks: Vec<MaskType>) -> mode_t {
175    for x in masks {
176        umask = x.combine(umask)
177    }
178    umask
179}
180
/// Normalizes an octal umask string to exactly four characters by left-padding
/// it with `'0'`, e.g. `"22"` becomes `"0022"`.
///
/// Digits are not validated here; only the length and the leading character of
/// a four-character input are checked.
///
/// # Errors
/// Fails if the input is empty, is longer than four bytes, or is exactly four
/// bytes long and does not start with `'0'`.
fn make_parsable_octal_string(str: &str) -> Result<String, io::Error> {
    if str.is_empty() {
        return Err(io::Error::other("no input.".to_string()));
    }
    if str.len() > 4 {
        // Written on one line: a `\` line continuation would strip the space after "e.g.".
        let msg = "no more than 4 characters can be used to specify a umask, e.g. 644 or 0222.";
        return Err(io::Error::other(msg.to_string()));
    }
    if str.len() == 4 && !str.starts_with('0') {
        let msg = "most significant octal character can only be 0.".to_string();
        return Err(io::Error::other(msg));
    }
    // Pad by byte length (as before) so the result is always four bytes long.
    Ok(format!("{}{}", "0".repeat(4 - str.len()), str))
}
202
203fn build_mask(to_shift: usize, c: char) -> Result<mode_t, io::Error> {
204    let apply = |m| Ok((m << (to_shift * 3)) as mode_t);
205    match c {
206        '0' => apply(0b000),
207        '1' => apply(0b001),
208        '2' => apply(0b010),
209        '3' => apply(0b011),
210        '4' => apply(0b100),
211        '5' => apply(0b101),
212        '6' => apply(0b110),
213        '7' => apply(0b111),
214        _ => {
215            let msg = "octal format can only take on values between 0 and 7 inclusive.".to_string();
216            Err(io::Error::other(msg))
217        }
218    }
219}
220
221fn octal_string_to_mode_t(str: &str) -> Result<mode_t, io::Error> {
222    let mut val = 0;
223    let mut err = false;
224    for (usize, c) in str.chars().rev().enumerate() {
225        match usize {
226            0..=2 => val |= build_mask(usize, c)?,
227            3 => {}
228            _ => {
229                err = true;
230                break;
231            }
232        }
233    }
234    if err {
235        let msg = "failed to parse provided octal.".to_string();
236        Err(io::Error::other(msg))
237    } else {
238        Ok(val)
239    }
240}
241
242#[allow(clippy::unnecessary_cast)]
243fn to_mode(i: mode_t) -> Mode {
244    NIX_PERMISSIONS.iter().fold(Mode::empty(), |acc, x| {
245        if (x.bits() as mode_t & i) == x.bits() as mode_t {
246            acc | *x
247        } else {
248            acc
249        }
250    })
251}
252
253fn octal_string_to_mode(str: &str) -> Result<Mode, io::Error> {
254    let str = make_parsable_octal_string(str)?;
255    let val = octal_string_to_mode_t(&str)?;
256    Ok(to_mode(val))
257}
258
/// Returns `true` if `ch` is one of the permission letters `r`, `w` or `x`.
fn is_permission_token(ch: char) -> bool {
    "rwx".contains(ch)
}
262
/// Returns `true` if `ch` is one of the class letters `u`, `g`, `o` or `a`.
fn is_user_access_token(ch: char) -> bool {
    "ugoa".contains(ch)
}
266
/// Returns `true` if `ch` is an ASCII decimal digit (`'0'..='9'`).
fn is_digit(ch: char) -> bool {
    ch.is_ascii_digit()
}
273
274/// If mask_string is a mode string then merge it with umask and set the current umask.
275/// If mask_string is an int then treat it as a umask and set the current umask (no merge)).
276pub fn merge_and_set_umask(current_umask: mode_t, mask_string: &str) -> Result<mode_t, io::Error> {
277    if mask_string.parse::<u32>().is_ok() {
278        let mode = octal_string_to_mode(mask_string)?;
279        nix::sys::stat::umask(mode);
280        Ok(mode.bits())
281    } else if !mask_string.is_empty() {
282        #[allow(clippy::unnecessary_cast)]
283        let mode = if is_digit(mask_string.chars().next().unwrap()) {
284            octal_string_to_mode(mask_string)?.bits() as mode_t
285        } else {
286            let masks = get_umask_tokens(mask_string)?;
287            with_umask_tokens(current_umask, masks)
288        };
289        if let Some(umask) = Mode::from_bits(mode) {
290            nix::sys::stat::umask(umask);
291            Ok(mode)
292        } else {
293            Err(io::Error::other("invalid umask".to_string()))
294        }
295    } else {
296        let msg = "no input.".to_string();
297        Err(io::Error::other(msg))
298    }
299}
300
301/// Clears the current umask and returns the previous umask.
302#[allow(clippy::unnecessary_cast)]
303pub fn get_and_clear_umask() -> mode_t {
304    nix::sys::stat::umask(Mode::empty()).bits() as mode_t
305}
306
307/// Set current umask to umask.
308#[allow(clippy::unnecessary_cast)]
309pub fn set_umask(umask: mode_t) -> Result<(), io::Error> {
310    if let Some(umask) = Mode::from_bits(umask) {
311        nix::sys::stat::umask(umask);
312        Ok(())
313    } else {
314        Err(io::Error::other("invalid mode"))
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_perms() {
        let cases: &[(&str, mode_t)] = &[
            ("rwx", 0b111111111),
            ("r", 0b100100100),
            ("w", 0b010010010),
            ("x", 0b001001001),
            ("rw", 0b110110110),
        ];
        for &(input, expected) in cases {
            assert_eq!(expected, get_perms(input).unwrap(), "perms {input:?}");
        }
    }

    #[test]
    fn test_parse_class() {
        let cases: &[(&str, mode_t)] = &[
            ("ugo", 0b111111111),
            ("a", 0b111111111),
            ("u", 0b111000000),
            ("g", 0b000111000),
            ("o", 0b000000111),
            ("uo", 0b111000111),
        ];
        for &(input, expected) in cases {
            assert_eq!(expected, get_class(input).unwrap(), "class {input:?}");
        }
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    fn test_umask_octal() {
        let bits: mode_t = 0b001001001;
        assert_eq!(bits, to_mode(bits).bits());

        let cases: &[(&str, u32)] = &[
            ("0522", 338),
            ("522", 338),
            ("713", 0b111001011),
            ("466", 0b100110110),
            ("0", 0b000000000),
            ("45", 0b000100101),
        ];
        for &(input, expected) in cases {
            let parsed = octal_string_to_mode(input).unwrap();
            assert_eq!(expected, parsed.bits() as u32, "octal {input:?}");
        }

        for input in ["a+n", "1111", "11111", "0S11"] {
            assert!(octal_string_to_mode(input).is_err(), "expected {input:?} to fail");
        }
    }

    #[test]
    #[allow(clippy::unnecessary_cast)]
    fn test_umask_parser() {
        let umask = to_mode(0o022).bits() as mode_t;

        let cases: &[(&str, mode_t)] = &[
            ("go+rx", 0o022),
            ("+w", 0o0),
            ("a-rw", 0o666),
            ("g-rw", 0o062),
            ("ug=rw", 0o0112),
            ("a=r,ug=rw", 0o0113),
            ("a=r,g+w", 0o0313),
            ("a=r,a-r", 0o0777),
            ("ugo+x", 0o0022),
            ("+x", 0o0022),
            ("a+rw", 0o0),
            ("a=r", 0o333),
            ("ug+rwx", 0o002),
            ("ug+", 0o002),
            ("o-rwx", 0o0027),
            ("u-x,g=r,o+w", 0o0130),
        ];
        for &(spec, expected) in cases {
            let tokens = get_umask_tokens(spec).unwrap();
            assert_eq!(expected, with_umask_tokens(umask, tokens), "mask spec {spec:?}");
        }

        for spec in ["glo+rx", "go+nrx", "+n", "a+n", "ar", "+a+r", "a++r", "+ar+", ""] {
            assert!(get_umask_tokens(spec).is_err(), "expected {spec:?} to be rejected");
        }
    }
}