|
2 | 2 | import ast
|
3 | 3 | import re
|
4 | 4 |
|
| 5 | +def deindent_docstring(doc): |
| 6 | + if doc: |
| 7 | + # Find the indent to remove from the doctring. We consider the following possibilities: |
| 8 | + # Option 1: |
| 9 | + # """This is the first line |
| 10 | + # This is the second line |
| 11 | + # """ |
| 12 | + # Option 2: |
| 13 | + # """ |
| 14 | + # This is the first line |
| 15 | + # This is the second line |
| 16 | + # """ |
| 17 | + # Option 3: |
| 18 | + # """ |
| 19 | + # This is the first line |
| 20 | + # This is the second line |
| 21 | + # """ |
| 22 | + # |
| 23 | + # In all cases, we can find the indent to remove by doing the following: |
| 24 | + # - Check the first non-empty line, if it has an indent, use that as the base indent |
| 25 | + # - If it does not have an indent and there is a second line, check the indent of the |
| 26 | + # second line and use that |
| 27 | + saw_first_line = False |
| 28 | + matched_indent = None |
| 29 | + for line in doc.splitlines(): |
| 30 | + if line: |
| 31 | + matched_indent = re.match('[\t ]+', line) |
| 32 | + if matched_indent is not None or saw_first_line: |
| 33 | + break |
| 34 | + saw_first_line = True |
| 35 | + if matched_indent: |
| 36 | + return re.sub(r'\n' + matched_indent.group(), '\n', doc).strip() |
| 37 | + else: |
| 38 | + return doc |
| 39 | + else: |
| 40 | + return '' |
| 41 | + |
5 | 42 | class DAGNode(object):
|
6 | 43 | def __init__(self, func_ast, decos, doc):
|
7 | 44 | self.name = func_ast.name
|
8 | 45 | self.func_lineno = func_ast.lineno
|
9 | 46 | self.decorators = decos
|
10 |
| - self.doc = doc.rstrip() |
| 47 | + self.doc = deindent_docstring(doc) |
11 | 48 |
|
12 | 49 | # these attributes are populated by _parse
|
13 | 50 | self.tail_next_lineno = 0
|
@@ -111,46 +148,32 @@ def __str__(self):
|
111 | 148 |
|
112 | 149 | class StepVisitor(ast.NodeVisitor):
|
113 | 150 |
|
114 |
| - def __init__(self, nodes): |
| 151 | + def __init__(self, nodes, flow): |
115 | 152 | self.nodes = nodes
|
| 153 | + self.flow = flow |
116 | 154 | super(StepVisitor, self).__init__()
|
117 | 155 |
|
118 | 156 | def visit_FunctionDef(self, node):
|
119 |
| - decos = [d.func.id if isinstance(d, ast.Call) else d.id |
120 |
| - for d in node.decorator_list] |
121 |
| - if 'step' in decos: |
122 |
| - doc = ast.get_docstring(node) |
123 |
| - self.nodes[node.name] = DAGNode(node, decos, doc if doc else '') |
| 157 | + func = getattr(self.flow, node.name) |
| 158 | + if hasattr(func, 'is_step'): |
| 159 | + self.nodes[node.name] = DAGNode(node, func.decorators, func.__doc__) |
124 | 160 |
|
125 | 161 | class FlowGraph(object):
|
126 | 162 |
|
127 |
| - def __init__(self, flow=None, source=None, name=None): |
128 |
| - if flow: |
129 |
| - module = __import__(flow.__module__) |
130 |
| - source = inspect.getsource(module) |
131 |
| - self.name = flow.__name__ |
132 |
| - else: |
133 |
| - self.name = name |
134 |
| - |
135 |
| - self.nodes = self._create_nodes(source) |
| 163 | + def __init__(self, flow): |
| 164 | + self.name = flow.__name__ |
| 165 | + self.nodes = self._create_nodes(flow) |
| 166 | + self.doc = deindent_docstring(flow.__doc__) |
136 | 167 | self._traverse_graph()
|
137 | 168 | self._postprocess()
|
138 | 169 |
|
139 |
| - def _create_nodes(self, source): |
140 |
| - def _flow(n): |
141 |
| - if isinstance(n, ast.ClassDef): |
142 |
| - bases = [b.id for b in n.bases] |
143 |
| - if 'FlowSpec' in bases: |
144 |
| - return self.name is None or n.name == self.name |
145 |
| - |
146 |
| - # NOTE: this will fail if a file has multiple FlowSpec classes |
147 |
| - # and no name is specified |
148 |
| - [root] = list(filter(_flow, ast.parse(source).body)) |
149 |
| - self.name = root.name |
150 |
| - doc = ast.get_docstring(root) |
151 |
| - self.doc = doc if doc else '' |
| 170 | + def _create_nodes(self, flow): |
| 171 | + module = __import__(flow.__module__) |
| 172 | + tree = ast.parse(inspect.getsource(module)).body |
| 173 | + root = [n for n in tree\ |
| 174 | + if isinstance(n, ast.ClassDef) and n.name == self.name][0] |
152 | 175 | nodes = {}
|
153 |
| - StepVisitor(nodes).visit(root) |
| 176 | + StepVisitor(nodes, flow).visit(root) |
154 | 177 | return nodes
|
155 | 178 |
|
156 | 179 | def _postprocess(self):
|
|
0 commit comments