|
2 | 2 | import ast
|
3 | 3 | import re
|
4 | 4 |
|
5 |
| -def deindent_docstring(doc): |
6 |
| - if doc: |
7 |
| - # Find the indent to remove from the doctring. We consider the following possibilities: |
8 |
| - # Option 1: |
9 |
| - # """This is the first line |
10 |
| - # This is the second line |
11 |
| - # """ |
12 |
| - # Option 2: |
13 |
| - # """ |
14 |
| - # This is the first line |
15 |
| - # This is the second line |
16 |
| - # """ |
17 |
| - # Option 3: |
18 |
| - # """ |
19 |
| - # This is the first line |
20 |
| - # This is the second line |
21 |
| - # """ |
22 |
| - # |
23 |
| - # In all cases, we can find the indent to remove by doing the following: |
24 |
| - # - Check the first non-empty line, if it has an indent, use that as the base indent |
25 |
| - # - If it does not have an indent and there is a second line, check the indent of the |
26 |
| - # second line and use that |
27 |
| - saw_first_line = False |
28 |
| - matched_indent = None |
29 |
| - for line in doc.splitlines(): |
30 |
| - if line: |
31 |
| - matched_indent = re.match('[\t ]+', line) |
32 |
| - if matched_indent is not None or saw_first_line: |
33 |
| - break |
34 |
| - saw_first_line = True |
35 |
| - if matched_indent: |
36 |
| - return re.sub(r'\n' + matched_indent.group(), '\n', doc).strip() |
37 |
| - else: |
38 |
| - return doc |
39 |
| - else: |
40 |
| - return '' |
41 |
| - |
42 | 5 | class DAGNode(object):
|
43 | 6 | def __init__(self, func_ast, decos, doc):
|
44 | 7 | self.name = func_ast.name
|
45 | 8 | self.func_lineno = func_ast.lineno
|
46 | 9 | self.decorators = decos
|
47 |
| - self.doc = deindent_docstring(doc) |
| 10 | + self.doc = doc.rstrip() |
48 | 11 |
|
49 | 12 | # these attributes are populated by _parse
|
50 | 13 | self.tail_next_lineno = 0
|
@@ -148,32 +111,46 @@ def __str__(self):
|
148 | 111 |
|
149 | 112 | class StepVisitor(ast.NodeVisitor):
|
150 | 113 |
|
151 |
| - def __init__(self, nodes, flow): |
| 114 | + def __init__(self, nodes): |
152 | 115 | self.nodes = nodes
|
153 |
| - self.flow = flow |
154 | 116 | super(StepVisitor, self).__init__()
|
155 | 117 |
|
156 | 118 | def visit_FunctionDef(self, node):
|
157 |
| - func = getattr(self.flow, node.name) |
158 |
| - if hasattr(func, 'is_step'): |
159 |
| - self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__) |
| 119 | + decos = [d.func.id if isinstance(d, ast.Call) else d.id |
| 120 | + for d in node.decorator_list] |
| 121 | + if 'step' in decos: |
| 122 | + doc = ast.get_docstring(node) |
| 123 | + self.nodes[node.name] = DAGNode(node, decos, doc if doc else '') |
160 | 124 |
|
161 | 125 | class FlowGraph(object):
|
162 | 126 |
|
163 |
| - def __init__(self, flow): |
164 |
| - self.name = flow.__name__ |
165 |
| - self.nodes = self._create_nodes(flow) |
166 |
| - self.doc = deindent_docstring(flow.__doc__) |
| 127 | + def __init__(self, flow=None, source=None, name=None): |
| 128 | + if flow: |
| 129 | + module = __import__(flow.__module__) |
| 130 | + source = inspect.getsource(module) |
| 131 | + self.name = flow.__name__ |
| 132 | + else: |
| 133 | + self.name = name |
| 134 | + |
| 135 | + self.nodes = self._create_nodes(source) |
167 | 136 | self._traverse_graph()
|
168 | 137 | self._postprocess()
|
169 | 138 |
|
170 |
| - def _create_nodes(self, flow): |
171 |
| - module = __import__(flow.__module__) |
172 |
| - tree = ast.parse(inspect.getsource(module)).body |
173 |
| - root = [n for n in tree\ |
174 |
| - if isinstance(n, ast.ClassDef) and n.name == self.name][0] |
| 139 | + def _create_nodes(self, source): |
| 140 | + def _flow(n): |
| 141 | + if isinstance(n, ast.ClassDef): |
| 142 | + bases = [b.id for b in n.bases] |
| 143 | + if 'FlowSpec' in bases: |
| 144 | + return self.name is None or n.name == self.name |
| 145 | + |
| 146 | + # NOTE: this will fail if a file has multiple FlowSpec classes |
| 147 | + # and no name is specified |
| 148 | + [root] = list(filter(_flow, ast.parse(source).body)) |
| 149 | + self.name = root.name |
| 150 | + doc = ast.get_docstring(root) |
| 151 | + self.doc = doc if doc else '' |
175 | 152 | nodes = {}
|
176 |
| - StepVisitor(nodes, flow).visit(root) |
| 153 | + StepVisitor(nodes).visit(root) |
177 | 154 | return nodes
|
178 | 155 |
|
179 | 156 | def _postprocess(self):
|
|
0 commit comments