Source file src/cmd/vendor/golang.org/x/tools/go/ast/astutil/rewrite.go

     1  // Copyright 2017 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package astutil
     6  
     7  import (
     8  	"fmt"
     9  	"go/ast"
    10  	"reflect"
    11  	"sort"
    12  )
    13  
    14  // An ApplyFunc is invoked by Apply for each node n, even if n is nil,
    15  // before and/or after the node's children, using a Cursor describing
    16  // the current node and providing operations on it.
    17  //
    18  // The return value of ApplyFunc controls the syntax tree traversal.
    19  // See Apply for details.
    20  type ApplyFunc func(*Cursor) bool
    21  
    22  // Apply traverses a syntax tree recursively, starting with root,
    23  // and calling pre and post for each node as described below.
    24  // Apply returns the syntax tree, possibly modified.
    25  //
    26  // If pre is not nil, it is called for each node before the node's
    27  // children are traversed (pre-order). If pre returns false, no
    28  // children are traversed, and post is not called for that node.
    29  //
    30  // If post is not nil, and a prior call of pre didn't return false,
    31  // post is called for each node after its children are traversed
    32  // (post-order). If post returns false, traversal is terminated and
    33  // Apply returns immediately.
    34  //
    35  // Only fields that refer to AST nodes are considered children;
    36  // i.e., token.Pos, Scopes, Objects, and fields of basic types
    37  // (strings, etc.) are ignored.
    38  //
    39  // Children are traversed in the order in which they appear in the
    40  // respective node's struct definition. A package's files are
    41  // traversed in the filenames' alphabetical order.
    42  func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
    43  	parent := &struct{ ast.Node }{root}
    44  	defer func() {
    45  		if r := recover(); r != nil && r != abort {
    46  			panic(r)
    47  		}
    48  		result = parent.Node
    49  	}()
    50  	a := &application{pre: pre, post: post}
    51  	a.apply(parent, "Node", nil, root)
    52  	return
    53  }
    54  
    55  var abort = new(int) // singleton, to signal termination of Apply
    56  
    57  // A Cursor describes a node encountered during Apply.
    58  // Information about the node and its parent is available
    59  // from the Node, Parent, Name, and Index methods.
    60  //
    61  // If p is a variable of type and value of the current parent node
    62  // c.Parent(), and f is the field identifier with name c.Name(),
    63  // the following invariants hold:
    64  //
    65  //	p.f            == c.Node()  if c.Index() <  0
    66  //	p.f[c.Index()] == c.Node()  if c.Index() >= 0
    67  //
    68  // The methods Replace, Delete, InsertBefore, and InsertAfter
    69  // can be used to change the AST without disrupting Apply.
    70  type Cursor struct {
    71  	parent ast.Node
    72  	name   string
    73  	iter   *iterator // valid if non-nil
    74  	node   ast.Node
    75  }
    76  
    77  // Node returns the current Node.
    78  func (c *Cursor) Node() ast.Node { return c.node }
    79  
    80  // Parent returns the parent of the current Node.
    81  func (c *Cursor) Parent() ast.Node { return c.parent }
    82  
    83  // Name returns the name of the parent Node field that contains the current Node.
    84  // If the parent is a *ast.Package and the current Node is a *ast.File, Name returns
    85  // the filename for the current Node.
    86  func (c *Cursor) Name() string { return c.name }
    87  
    88  // Index reports the index >= 0 of the current Node in the slice of Nodes that
    89  // contains it, or a value < 0 if the current Node is not part of a slice.
    90  // The index of the current node changes if InsertBefore is called while
    91  // processing the current node.
    92  func (c *Cursor) Index() int {
    93  	if c.iter != nil {
    94  		return c.iter.index
    95  	}
    96  	return -1
    97  }
    98  
    99  // field returns the current node's parent field value.
   100  func (c *Cursor) field() reflect.Value {
   101  	return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
   102  }
   103  
   104  // Replace replaces the current Node with n.
   105  // The replacement node is not walked by Apply.
   106  func (c *Cursor) Replace(n ast.Node) {
   107  	if _, ok := c.node.(*ast.File); ok {
   108  		file, ok := n.(*ast.File)
   109  		if !ok {
   110  			panic("attempt to replace *ast.File with non-*ast.File")
   111  		}
   112  		c.parent.(*ast.Package).Files[c.name] = file
   113  		return
   114  	}
   115  
   116  	v := c.field()
   117  	if i := c.Index(); i >= 0 {
   118  		v = v.Index(i)
   119  	}
   120  	v.Set(reflect.ValueOf(n))
   121  }
   122  
   123  // Delete deletes the current Node from its containing slice.
   124  // If the current Node is not part of a slice, Delete panics.
   125  // As a special case, if the current node is a package file,
   126  // Delete removes it from the package's Files map.
   127  func (c *Cursor) Delete() {
   128  	if _, ok := c.node.(*ast.File); ok {
   129  		delete(c.parent.(*ast.Package).Files, c.name)
   130  		return
   131  	}
   132  
   133  	i := c.Index()
   134  	if i < 0 {
   135  		panic("Delete node not contained in slice")
   136  	}
   137  	v := c.field()
   138  	l := v.Len()
   139  	reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
   140  	v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
   141  	v.SetLen(l - 1)
   142  	c.iter.step--
   143  }
   144  
   145  // InsertAfter inserts n after the current Node in its containing slice.
   146  // If the current Node is not part of a slice, InsertAfter panics.
   147  // Apply does not walk n.
   148  func (c *Cursor) InsertAfter(n ast.Node) {
   149  	i := c.Index()
   150  	if i < 0 {
   151  		panic("InsertAfter node not contained in slice")
   152  	}
   153  	v := c.field()
   154  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   155  	l := v.Len()
   156  	reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
   157  	v.Index(i + 1).Set(reflect.ValueOf(n))
   158  	c.iter.step++
   159  }
   160  
   161  // InsertBefore inserts n before the current Node in its containing slice.
   162  // If the current Node is not part of a slice, InsertBefore panics.
   163  // Apply will not walk n.
   164  func (c *Cursor) InsertBefore(n ast.Node) {
   165  	i := c.Index()
   166  	if i < 0 {
   167  		panic("InsertBefore node not contained in slice")
   168  	}
   169  	v := c.field()
   170  	v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
   171  	l := v.Len()
   172  	reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
   173  	v.Index(i).Set(reflect.ValueOf(n))
   174  	c.iter.index++
   175  }
   176  
   177  // application carries all the shared data so we can pass it around cheaply.
   178  type application struct {
   179  	pre, post ApplyFunc
   180  	cursor    Cursor
   181  	iter      iterator
   182  }
   183  
   184  func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
   185  	// convert typed nil into untyped nil
   186  	if v := reflect.ValueOf(n); v.Kind() == reflect.Ptr && v.IsNil() {
   187  		n = nil
   188  	}
   189  
   190  	// avoid heap-allocating a new cursor for each apply call; reuse a.cursor instead
   191  	saved := a.cursor
   192  	a.cursor.parent = parent
   193  	a.cursor.name = name
   194  	a.cursor.iter = iter
   195  	a.cursor.node = n
   196  
   197  	if a.pre != nil && !a.pre(&a.cursor) {
   198  		a.cursor = saved
   199  		return
   200  	}
   201  
   202  	// walk children
   203  	// (the order of the cases matches the order of the corresponding node types in go/ast)
   204  	switch n := n.(type) {
   205  	case nil:
   206  		// nothing to do
   207  
   208  	// Comments and fields
   209  	case *ast.Comment:
   210  		// nothing to do
   211  
   212  	case *ast.CommentGroup:
   213  		if n != nil {
   214  			a.applyList(n, "List")
   215  		}
   216  
   217  	case *ast.Field:
   218  		a.apply(n, "Doc", nil, n.Doc)
   219  		a.applyList(n, "Names")
   220  		a.apply(n, "Type", nil, n.Type)
   221  		a.apply(n, "Tag", nil, n.Tag)
   222  		a.apply(n, "Comment", nil, n.Comment)
   223  
   224  	case *ast.FieldList:
   225  		a.applyList(n, "List")
   226  
   227  	// Expressions
   228  	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
   229  		// nothing to do
   230  
   231  	case *ast.Ellipsis:
   232  		a.apply(n, "Elt", nil, n.Elt)
   233  
   234  	case *ast.FuncLit:
   235  		a.apply(n, "Type", nil, n.Type)
   236  		a.apply(n, "Body", nil, n.Body)
   237  
   238  	case *ast.CompositeLit:
   239  		a.apply(n, "Type", nil, n.Type)
   240  		a.applyList(n, "Elts")
   241  
   242  	case *ast.ParenExpr:
   243  		a.apply(n, "X", nil, n.X)
   244  
   245  	case *ast.SelectorExpr:
   246  		a.apply(n, "X", nil, n.X)
   247  		a.apply(n, "Sel", nil, n.Sel)
   248  
   249  	case *ast.IndexExpr:
   250  		a.apply(n, "X", nil, n.X)
   251  		a.apply(n, "Index", nil, n.Index)
   252  
   253  	case *ast.IndexListExpr:
   254  		a.apply(n, "X", nil, n.X)
   255  		a.applyList(n, "Indices")
   256  
   257  	case *ast.SliceExpr:
   258  		a.apply(n, "X", nil, n.X)
   259  		a.apply(n, "Low", nil, n.Low)
   260  		a.apply(n, "High", nil, n.High)
   261  		a.apply(n, "Max", nil, n.Max)
   262  
   263  	case *ast.TypeAssertExpr:
   264  		a.apply(n, "X", nil, n.X)
   265  		a.apply(n, "Type", nil, n.Type)
   266  
   267  	case *ast.CallExpr:
   268  		a.apply(n, "Fun", nil, n.Fun)
   269  		a.applyList(n, "Args")
   270  
   271  	case *ast.StarExpr:
   272  		a.apply(n, "X", nil, n.X)
   273  
   274  	case *ast.UnaryExpr:
   275  		a.apply(n, "X", nil, n.X)
   276  
   277  	case *ast.BinaryExpr:
   278  		a.apply(n, "X", nil, n.X)
   279  		a.apply(n, "Y", nil, n.Y)
   280  
   281  	case *ast.KeyValueExpr:
   282  		a.apply(n, "Key", nil, n.Key)
   283  		a.apply(n, "Value", nil, n.Value)
   284  
   285  	// Types
   286  	case *ast.ArrayType:
   287  		a.apply(n, "Len", nil, n.Len)
   288  		a.apply(n, "Elt", nil, n.Elt)
   289  
   290  	case *ast.StructType:
   291  		a.apply(n, "Fields", nil, n.Fields)
   292  
   293  	case *ast.FuncType:
   294  		if tparams := n.TypeParams; tparams != nil {
   295  			a.apply(n, "TypeParams", nil, tparams)
   296  		}
   297  		a.apply(n, "Params", nil, n.Params)
   298  		a.apply(n, "Results", nil, n.Results)
   299  
   300  	case *ast.InterfaceType:
   301  		a.apply(n, "Methods", nil, n.Methods)
   302  
   303  	case *ast.MapType:
   304  		a.apply(n, "Key", nil, n.Key)
   305  		a.apply(n, "Value", nil, n.Value)
   306  
   307  	case *ast.ChanType:
   308  		a.apply(n, "Value", nil, n.Value)
   309  
   310  	// Statements
   311  	case *ast.BadStmt:
   312  		// nothing to do
   313  
   314  	case *ast.DeclStmt:
   315  		a.apply(n, "Decl", nil, n.Decl)
   316  
   317  	case *ast.EmptyStmt:
   318  		// nothing to do
   319  
   320  	case *ast.LabeledStmt:
   321  		a.apply(n, "Label", nil, n.Label)
   322  		a.apply(n, "Stmt", nil, n.Stmt)
   323  
   324  	case *ast.ExprStmt:
   325  		a.apply(n, "X", nil, n.X)
   326  
   327  	case *ast.SendStmt:
   328  		a.apply(n, "Chan", nil, n.Chan)
   329  		a.apply(n, "Value", nil, n.Value)
   330  
   331  	case *ast.IncDecStmt:
   332  		a.apply(n, "X", nil, n.X)
   333  
   334  	case *ast.AssignStmt:
   335  		a.applyList(n, "Lhs")
   336  		a.applyList(n, "Rhs")
   337  
   338  	case *ast.GoStmt:
   339  		a.apply(n, "Call", nil, n.Call)
   340  
   341  	case *ast.DeferStmt:
   342  		a.apply(n, "Call", nil, n.Call)
   343  
   344  	case *ast.ReturnStmt:
   345  		a.applyList(n, "Results")
   346  
   347  	case *ast.BranchStmt:
   348  		a.apply(n, "Label", nil, n.Label)
   349  
   350  	case *ast.BlockStmt:
   351  		a.applyList(n, "List")
   352  
   353  	case *ast.IfStmt:
   354  		a.apply(n, "Init", nil, n.Init)
   355  		a.apply(n, "Cond", nil, n.Cond)
   356  		a.apply(n, "Body", nil, n.Body)
   357  		a.apply(n, "Else", nil, n.Else)
   358  
   359  	case *ast.CaseClause:
   360  		a.applyList(n, "List")
   361  		a.applyList(n, "Body")
   362  
   363  	case *ast.SwitchStmt:
   364  		a.apply(n, "Init", nil, n.Init)
   365  		a.apply(n, "Tag", nil, n.Tag)
   366  		a.apply(n, "Body", nil, n.Body)
   367  
   368  	case *ast.TypeSwitchStmt:
   369  		a.apply(n, "Init", nil, n.Init)
   370  		a.apply(n, "Assign", nil, n.Assign)
   371  		a.apply(n, "Body", nil, n.Body)
   372  
   373  	case *ast.CommClause:
   374  		a.apply(n, "Comm", nil, n.Comm)
   375  		a.applyList(n, "Body")
   376  
   377  	case *ast.SelectStmt:
   378  		a.apply(n, "Body", nil, n.Body)
   379  
   380  	case *ast.ForStmt:
   381  		a.apply(n, "Init", nil, n.Init)
   382  		a.apply(n, "Cond", nil, n.Cond)
   383  		a.apply(n, "Post", nil, n.Post)
   384  		a.apply(n, "Body", nil, n.Body)
   385  
   386  	case *ast.RangeStmt:
   387  		a.apply(n, "Key", nil, n.Key)
   388  		a.apply(n, "Value", nil, n.Value)
   389  		a.apply(n, "X", nil, n.X)
   390  		a.apply(n, "Body", nil, n.Body)
   391  
   392  	// Declarations
   393  	case *ast.ImportSpec:
   394  		a.apply(n, "Doc", nil, n.Doc)
   395  		a.apply(n, "Name", nil, n.Name)
   396  		a.apply(n, "Path", nil, n.Path)
   397  		a.apply(n, "Comment", nil, n.Comment)
   398  
   399  	case *ast.ValueSpec:
   400  		a.apply(n, "Doc", nil, n.Doc)
   401  		a.applyList(n, "Names")
   402  		a.apply(n, "Type", nil, n.Type)
   403  		a.applyList(n, "Values")
   404  		a.apply(n, "Comment", nil, n.Comment)
   405  
   406  	case *ast.TypeSpec:
   407  		a.apply(n, "Doc", nil, n.Doc)
   408  		a.apply(n, "Name", nil, n.Name)
   409  		if tparams := n.TypeParams; tparams != nil {
   410  			a.apply(n, "TypeParams", nil, tparams)
   411  		}
   412  		a.apply(n, "Type", nil, n.Type)
   413  		a.apply(n, "Comment", nil, n.Comment)
   414  
   415  	case *ast.BadDecl:
   416  		// nothing to do
   417  
   418  	case *ast.GenDecl:
   419  		a.apply(n, "Doc", nil, n.Doc)
   420  		a.applyList(n, "Specs")
   421  
   422  	case *ast.FuncDecl:
   423  		a.apply(n, "Doc", nil, n.Doc)
   424  		a.apply(n, "Recv", nil, n.Recv)
   425  		a.apply(n, "Name", nil, n.Name)
   426  		a.apply(n, "Type", nil, n.Type)
   427  		a.apply(n, "Body", nil, n.Body)
   428  
   429  	// Files and packages
   430  	case *ast.File:
   431  		a.apply(n, "Doc", nil, n.Doc)
   432  		a.apply(n, "Name", nil, n.Name)
   433  		a.applyList(n, "Decls")
   434  		// Don't walk n.Comments; they have either been walked already if
   435  		// they are Doc comments, or they can be easily walked explicitly.
   436  
   437  	case *ast.Package:
   438  		// collect and sort names for reproducible behavior
   439  		var names []string
   440  		for name := range n.Files {
   441  			names = append(names, name)
   442  		}
   443  		sort.Strings(names)
   444  		for _, name := range names {
   445  			a.apply(n, name, nil, n.Files[name])
   446  		}
   447  
   448  	default:
   449  		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
   450  	}
   451  
   452  	if a.post != nil && !a.post(&a.cursor) {
   453  		panic(abort)
   454  	}
   455  
   456  	a.cursor = saved
   457  }
   458  
   459  // An iterator controls iteration over a slice of nodes.
   460  type iterator struct {
   461  	index, step int
   462  }
   463  
   464  func (a *application) applyList(parent ast.Node, name string) {
   465  	// avoid heap-allocating a new iterator for each applyList call; reuse a.iter instead
   466  	saved := a.iter
   467  	a.iter.index = 0
   468  	for {
   469  		// must reload parent.name each time, since cursor modifications might change it
   470  		v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
   471  		if a.iter.index >= v.Len() {
   472  			break
   473  		}
   474  
   475  		// element x may be nil in a bad AST - be cautious
   476  		var x ast.Node
   477  		if e := v.Index(a.iter.index); e.IsValid() {
   478  			x = e.Interface().(ast.Node)
   479  		}
   480  
   481  		a.iter.step = 1
   482  		a.apply(parent, name, &a.iter, x)
   483  		a.iter.index += a.iter.step
   484  	}
   485  	a.iter = saved
   486  }
   487  

View as plain text