Python源码分析6 – 从CST到AST的转化
Introduction上篇文章解释了Python是如何使用PyParser生成CST的。回顾一下,Python执行代码要经过如下过程:
1. Tokenizer进行词法分析,把源程序分解为Token
2. Parser根据Token创建CST
3. CST被转换为AST
4. AST被编译为字节码
5. 执行字节码
当执行Python代码的时候,以代码存放在文件中的情况为例,Python会调用PyParser_ASTFromFile函数将文件的代码内容转换为AST:
mod_ty
PyParser_ASTFromFile(FILE *fp, const char *filename, int start, char *ps1,
char *ps2, PyCompilerFlags *flags, int *errcode,
PyArena *arena)
{
mod_ty mod;
perrdetail err;
node *n = PyParser_ParseFileFlags(fp, filename, &_PyParser_Grammar,
start, ps1, ps2, &err, PARSER_FLAGS(flags));
if (n) {
mod = PyAST_FromNode(n, flags, filename, arena);
PyNode_Free(n);
return mod;
}
else {
err_input(&err);
if (errcode)
*errcode = err.error;
return NULL;
}
}
在PyParser_ParseFileFlags把文件转换成CST之后,PyAST_FromNode函数会把CST转换成AST。此函数定义在include\ast.h中:
PyAPI_FUNC(mod_ty) PyAST_FromNode(const node *, PyCompilerFlags *flags,
const char *, PyArena *);
在分析此函数之前,我们先来看一下有关AST的一些基本的类型定义。
AST Types
AST所用到的类型均定义在Python_ast.h中,以stmt_ty类型为例:
enum _stmt_kind {FunctionDef_kind=1, ClassDef_kind=2, Return_kind=3,
Delete_kind=4, Assign_kind=5, AugAssign_kind=6, Print_kind=7,
For_kind=8, While_kind=9, If_kind=10, With_kind=11,
Raise_kind=12, TryExcept_kind=13, TryFinally_kind=14,
Assert_kind=15, Import_kind=16, ImportFrom_kind=17,
Exec_kind=18, Global_kind=19, Expr_kind=20, Pass_kind=21,
Break_kind=22, Continue_kind=23};
struct _stmt {
enum _stmt_kind kind;
union {
struct {
identifier name;
arguments_ty args;
asdl_seq *body;
asdl_seq *decorators;
} FunctionDef;
struct {
identifier name;
asdl_seq *bases;
asdl_seq *body;
} ClassDef;
struct {
expr_ty value;
} Return;
// ... 过长,中间从略
struct {
expr_ty value;
} Expr;
} v;
int lineno;
int col_offset;
};
typedef struct _stmt *stmt_ty;
stmt_ty是语句结点类型,实际上是_stmt结构的指针。_stmt结构比较长,但有着很清晰的Pattern:
1. 第一个Field为kind,代表语句的类型。_stmt_kind定义了_stmt的所有可能的语句类型,从函数定义语句,类定义语句直到Continue语句共有23种类型。
2. 接下来是一个union v,每个成员均为一个struct,分别对应_stmt_kind中的一种类型,如_stmt.v.FunctionDef对应了_stmt_kind枚举中的FunctionDef_Kind,也就是说,当_stmt.kind == FunctionDef_Kind时,_stmt.v.FunctionDef中保存的就是对应的函数定义语句的具体内容。
3. 其他数据,如lineno和col_offset
大部分AST结点类型均是按照类似的pattern来定义的,不再赘述。除此之外,另外有一种比较简单的AST类型如operator_ty,expr_context_ty等,由于这些类型仍以_ty结尾,因此也可以认为是AST的结点,但实际上,这些类型只是简单的枚举类型,并非指针。因此在以后的文章中,并不把此类AST类型作为结点看待,而是作为简单的枚举处理。
由于每个AST类型会在union中引用其他的AST,这样层层引用,最后便形成了一颗AST树,试举例如下:
这颗AST树代表的是单条语句a+1。
与AST类型对应,在Python_ast.h / .c中定义了大量用于创建AST结点的函数,可以看作是AST结点的构造函数。以BinOp函数为例:
expr_ty
BinOp(expr_ty left, operator_ty op, expr_ty right, int lineno, int col_offset,
PyArena *arena)
{
expr_ty p;
if (!left) {
PyErr_SetString(PyExc_ValueError,
"field left is required for BinOp");
return NULL;
}
if (!op) {
PyErr_SetString(PyExc_ValueError,
"field op is required for BinOp");
return NULL;
}
if (!right) {
PyErr_SetString(PyExc_ValueError,
"field right is required for BinOp");
return NULL;
}
p = (expr_ty)PyArena_Malloc(arena, sizeof(*p));
if (!p) {
PyErr_NoMemory();
return NULL;
}
p->kind = BinOp_kind;
p->v.BinOp.left = left;
p->v.BinOp.op = op;
p->v.BinOp.right = right;
p->lineno = lineno;
p->col_offset = col_offset;
return p;
}
此函数只是根据传入的参数做一些简单的错误检查,分配内存,初始化对应的expr_ty类型,并返回指针。
adsl_seq & adsl_int_seq
在上面的stmt_ty定义中,如果稍微注意的话,可以发现其中大量用到了adsl_seq类型。类似在python_ast.h中其他AST类型中还会用到adsl_int_seq类型。adsl_seq & adsl_int_seq简单来说,是一个动态构造出的定长数组。Adsl_seq是void *的数组:
typedef struct {
int size;
void *elements;
} asdl_seq;
而adsl_int_seq则是int类型的数组:
typedef struct {
int size;
int elements;
} asdl_int_seq;
Size是数组长度,elements则是数组的元素。注意这些类型在定义elements时使用了一点技巧,定义的elements数组长度为1,而在动态分配内存的时候则是按照实际长度sizeof(adsl_seq) + size - 1来分配:
asdl_seq *
asdl_seq_new(int size, PyArena *arena)
{
asdl_seq *seq = NULL;
size_t n = sizeof(asdl_seq) +
(size ? (sizeof(void *) * (size - 1)) : 0);
seq = (asdl_seq *)PyArena_Malloc(arena, n);
if (!seq) {
PyErr_NoMemory();
return NULL;
}
memset(seq, 0, n);
seq->size = size;
return seq;
}
这样既可以动态分配数组元素,也可以很方便的用elements来访问数组元素。
用如下的宏和函数可以操作adsl_seq / adsl_int_seq :
asdl_seq *asdl_seq_new(int size, PyArena *arena);
asdl_int_seq *asdl_int_seq_new(int size, PyArena *arena);
#define asdl_seq_GET(S, I) (S)->elements[(I)]
#define asdl_seq_LEN(S) ((S) == NULL ? 0 : (S)->size)
#ifdef Py_DEBUG
#define asdl_seq_SET(S, I, V) { \
int _asdl_i = (I); \
assert((S) && _asdl_i < (S)->size); \
(S)->elements = (V); \
}
#else
#define asdl_seq_SET(S, I, V) (S)->elements = (V)
#endif
需要说明的是adsl_seq / adsl_int_seq均是从PyArena中分配出,PyArena会在以后的文章中详细分析,目前我们可以暂时把PyArena简单看作一个分配内存用的堆。
From CST to AST
如前所述,PyAST_FromNode负责从CST到AST的转换。简单来说,此函数会深度遍历整棵CST,过滤掉CST中的多余信息,只是将有意义的CST子树转换成AST结点构造出AST树。
PyAst_FromNode函数的大致代码如下:
mod_ty
PyAST_FromNode(const node *n, PyCompilerFlags *flags, const char *filename,
PyArena *arena)
{
...
switch (TYPE(n)) {
case file_input:
stmts = asdl_seq_new(num_stmts(n), arena);
if (!stmts)
return NULL;
for (i = 0; i < NCH(n) - 1; i++) {
ch = CHILD(n, i);
if (TYPE(ch) == NEWLINE)
continue;
REQ(ch, stmt);
num = num_stmts(ch);
if (num == 1) {
s = ast_for_stmt(&c, ch);
if (!s)
goto error;
asdl_seq_SET(stmts, k++, s);
}
else {
ch = CHILD(ch, 0);
REQ(ch, simple_stmt);
for (j = 0; j < num; j++) {
s = ast_for_stmt(&c, CHILD(ch, j * 2));
if (!s)
goto error;
asdl_seq_SET(stmts, k++, s);
}
}
}
return Module(stmts, arena);
case eval_input: {
...
}
case single_input: {
...
}
default:
goto error;
}
可以看到PyAst_FromNode根据N的类型作了不同处理,以file_input为例,file_input的产生式(在Grammar文件中定义)如下:File_input : (NEWLINE | stmt)* ENDMARKER,对应的PyAst_FromNode的代码作了如下事情:
1. 调用num_stmts(n)计算出所有顶层语句的个数,并创建出合适大小的adsl_seq结构以存放这些语句
2. 对于file_input结点的所有子结点作如下处理: file_input: ( NEW_LINE | stmt )* ENDMARKER
a. 忽略掉NEW_LINE,换行无需处理
b. REQ(ch, stmt)断言ch的类型必定为stmt,从产生式可以得出此结论
c. 计算出子结点stmt的语句条数n:
i. N == 1,说明stmt对应单条语句,调用ast_for_stmt遍历stmt对应得CST子树,生成对应的AST子树,并调用adsl_seq_SET设置到数组之中。这样AST的根结点mod_ty便可以知道有哪些顶层的语句(stmt),这些语句结点便是根结点mod_ty的子结点。
ii. N > 1,说明stmt对应多条语句。根据Grammar文件中定义的如下产生式可以推知此时ch的子结点必然为simple_stmt。
stmt: simple_stmt | compound_stmt
simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE
small_stmt: (expr_stmt | print_stmt | del_stmt | pass_stmt | flow_stmt |
import_stmt | global_stmt | exec_stmt | assert_stmt)
compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef
由于simple_stmt的定义中small_stmt和’;’总是成对出现,因此index为偶数的CST结点便是所需的单条顶层语句的结点,对于每个这样的结点调用adsl_seq_SET设置到数组之中
3. 最后,调用Module函数从stmts数组生成mod_ty结点,也就是AST的根结点
上面的过程中用到了两个关键函数:num_stmts和ast_for_stmt。先来看num_stmts函数:
static int
num_stmts(const node *n)
{
int i, l;
node *ch;
switch (TYPE(n)) {
case single_input:
if (TYPE(CHILD(n, 0)) == NEWLINE)
return 0;
else
return num_stmts(CHILD(n, 0));
case file_input:
l = 0;
for (i = 0; i < NCH(n); i++) {
ch = CHILD(n, i);
if (TYPE(ch) == stmt)
l += num_stmts(ch);
}
return l;
case stmt:
return num_stmts(CHILD(n, 0));
case compound_stmt:
return 1;
case simple_stmt:
return NCH(n) / 2; /* Divide by 2 to remove count of semi-colons */
case suite:
if (NCH(n) == 1)
return num_stmts(CHILD(n, 0));
else {
l = 0;
for (i = 2; i < (NCH(n) - 1); i++)
l += num_stmts(CHILD(n, i));
return l;
}
default: {
char buf;
sprintf(buf, "Non-statement found: %d %d\n",
TYPE(n), NCH(n));
Py_FatalError(buf);
}
}
assert(0);
return 0;
}
此函数比较简单,根据结点类型和产生式递归计算顶层语句的个数。所谓顶层语句,也就是把复合语句(compound_stmt)看作单条语句,复合语句中的内部的语句不做计算,当然普通的简单语句(small_stmt) 也是算1条语句。下面根据不同结点类型分析此函数:
1. Single_input
代表单条交互语句,对应的产生式:single_input: NEWLINE | simple_stmt | compound_stmt NEWLINE
如果single_input的第一个子结点为NEW_LINE,说明无语句,返回0,否则说明是simple_stmt或者compound_stmt NEWLINE,可以直接递归调用num_stmts处理
2. File_input
代表整个代码文件,对应的产生式:file_input: (NEWLINE | stmt)* ENDMARKER
只需要反复对每个子结点调用num_stmts既可。
3. Stmt
代表语句,对应的产生式:stmt: simple_stmt | compound_stmt
对第一个子结点调用num_stmts既可。
4. Compound_stmt
代表复合语句,对应的产生式:compound_stmt: if_stmt | while_stmt | for_stmt | try_stmt | with_stmt | funcdef | classdef
compound_stmt只可能有单个子结点,而且必然代表单条顶层的语句,因此无需继续遍历,直接返回1既可。
5. Simple_stmt
代表简单语句(非复合语句)的集合,对应的产生式:simple_stmt: small_stmt (';' small_stmt)* [';'] NEWLINE
可以看到顶层语句数=子结点数/2 (去掉多余的分号和NEWLINE)
6. Suite
代表复合语句中的语句块,也就是冒号之后的部分(如:classdef: 'class' NAME ['(' ')'] ':' suite),类似于C/C++大括号中的内容,对应的产生式如下:suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT
子结点数为1,说明必然是simple_stmt,可以直接调用num_stmts处理,否则,说明是多个stmt的集合,遍历所有子结点调用num_stmts并累加既可
可以看到,num_stmts基本上是和语句有关的产生式是一一对应的。
接下来分析ast_for_stmts的内容:
static stmt_ty
ast_for_stmt(struct compiling *c, const node *n)
{
if (TYPE(n) == stmt) {
assert(NCH(n) == 1);
n = CHILD(n, 0);
}
<p c
页:
[1]