2
2
3
3
from __future__ import annotations
4
4
5
- from typing import TYPE_CHECKING
5
+ from typing import TYPE_CHECKING , Protocol
6
6
from unittest .mock import patch
7
7
from contextlib import contextmanager
8
8
import copy
9
+ import fnmatch
9
10
10
11
from e3 .os .process import Run , to_cmd_lines
11
12
@@ -49,6 +50,57 @@ def mock_run(config: MockRunConfig | None = None) -> Iterator[MockRun]:
49
50
yield run
50
51
51
52
53
+ class ArgumentChecker (Protocol ):
54
+ """Argument checker."""
55
+
56
+ def check (self , arg : str ) -> bool :
57
+ """Check an argument.
58
+
59
+ :param arg: the argument
60
+ :return: if the argument is valid
61
+ """
62
+ ...
63
+
64
+ def __repr__ (self ) -> str :
65
+ """Return a textual representation of the expected argument."""
66
+ ...
67
+
68
+
69
+ class GlobChecker (ArgumentChecker ):
70
+ """Check an argument against a glob."""
71
+
72
+ def __init__ (self , pattern : str ) -> None :
73
+ """Initialize GlobChecker.
74
+
75
+ :param pattern: the glob pattern
76
+ """
77
+ self .pattern = pattern
78
+
79
+ def check (self , arg : str ) -> bool :
80
+ """See ArgumentChecker."""
81
+ return fnmatch .fnmatch (arg , self .pattern )
82
+
83
+ def __repr__ (self ) -> str :
84
+ """See ArgumentChecker."""
85
+ return self .pattern .__repr__ ()
86
+
87
+
88
+ class SideEffect (Protocol ):
89
+ """Function to be called when a mocked command is called."""
90
+
91
+ def __call__ (
92
+ self , result : CommandResult , cmd : list [str ], * args : Any , ** kwargs : Any
93
+ ) -> None :
94
+ """Run when the mocked command is called.
95
+
96
+ :param result: the mocked command
97
+ :param cmd: actual arguments of the command
98
+ :param args: additional arguments for Run
99
+ :param kwargs: additional keyword arguments for Run
100
+ """
101
+ ...
102
+
103
+
52
104
class CommandResult :
53
105
"""Result of a command.
54
106
@@ -58,22 +110,25 @@ class CommandResult:
58
110
59
111
def __init__ (
60
112
self ,
61
- cmd : list [str ],
113
+ cmd : list [str | ArgumentChecker ],
62
114
status : int | None = None ,
63
115
raw_out : bytes = b"" ,
64
116
raw_err : bytes = b"" ,
117
+ side_effect : SideEffect | None = None ,
65
118
) -> None :
66
119
"""Initialize CommandResult.
67
120
68
121
:param cmd: expected arguments of the command
69
122
:param status: status code
70
123
:param raw_out: raw output log
71
124
:param raw_err: raw error log
125
+ :param side_effect: a function to be called when the command is called
72
126
"""
73
127
self .cmd = cmd
74
128
self .status = status if status is not None else 0
75
129
self .raw_out = raw_out
76
130
self .raw_err = raw_err
131
+ self .side_effect = side_effect
77
132
78
133
def check (self , cmd : list [str ]) -> None :
79
134
"""Check that cmd matches the expected arguments.
@@ -86,10 +141,16 @@ def check(self, cmd: list[str]) -> None:
86
141
)
87
142
88
143
for i , arg in enumerate (cmd ):
89
- if arg != self .cmd [i ] and self .cmd [i ] != "*" :
90
- raise UnexpectedCommandError (
91
- f"unexpected arguments { cmd } , expected { self .cmd } "
92
- )
144
+ checker = self .cmd [i ]
145
+ if isinstance (checker , str ):
146
+ if arg == checker or checker == "*" :
147
+ continue
148
+ elif checker .check (arg ):
149
+ continue
150
+
151
+ raise UnexpectedCommandError (
152
+ f"unexpected arguments { cmd } , expected { self .cmd } "
153
+ )
93
154
94
155
def __call__ (self , cmd : list [str ], * args : Any , ** kwargs : Any ) -> None :
95
156
"""Allow to run code to emulate the command.
@@ -101,7 +162,8 @@ def __call__(self, cmd: list[str], *args: Any, **kwargs: Any) -> None:
101
162
:param args: additional arguments for Run
102
163
:param kwargs: additional keyword arguments for Run
103
164
"""
104
- pass
165
+ if self .side_effect :
166
+ self .side_effect (self , cmd , * args , ** kwargs )
105
167
106
168
107
169
class MockRun (Run ):
0 commit comments