Source file src/cmd/compile/internal/rangefunc/rewrite.go

     1  // Copyright 2023 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  /*
     6  Package rangefunc rewrites range-over-func to code that doesn't use range-over-funcs.
     7  Rewriting the construct in the front end, before noder, means the functions generated during
     8  the rewrite are available in a noder-generated representation for inlining by the back end.
     9  
    10  # Theory of Operation
    11  
    12  The basic idea is to rewrite
    13  
    14  	for x := range f {
    15  		...
    16  	}
    17  
    18  into
    19  
    20  	f(func(x T) bool {
    21  		...
    22  	})
    23  
    24  But it's not usually that easy.
    25  
    26  # Range variables
    27  
    28  For a range not using :=, the assigned variables cannot be function parameters
    29  in the generated body function. Instead, we allocate fake parameters and
    30  start the body with an assignment. For example:
    31  
    32  	for expr1, expr2 = range f {
    33  		...
    34  	}
    35  
    36  becomes
    37  
    38  	f(func(#p1 T1, #p2 T2) bool {
    39  		expr1, expr2 = #p1, #p2
    40  		...
    41  	})
    42  
    43  (All the generated variables have a # at the start to signal that they
    44  are internal variables when looking at the generated code in a
    45  debugger. Because variables have all been resolved to the specific
    46  objects they represent, there is no danger of using plain "p1" and
    47  colliding with a Go variable named "p1"; the # is just nice to have,
    48  not for correctness.)
    49  
    50  It can also happen that there are fewer range variables than function
    51  arguments, in which case we end up with something like
    52  
    53  	f(func(x T1, _ T2) bool {
    54  		...
    55  	})
    56  
    57  or
    58  
    59  	f(func(#p1 T1, #p2 T2, _ T3) bool {
    60  		expr1, expr2 = #p1, #p2
    61  		...
    62  	})
    63  
    64  # Return
    65  
    66  If the body contains a "break", that break turns into "return false",
    67  to tell f to stop. And if the body contains a "continue", that turns
    68  into "return true", to tell f to proceed with the next value.
    69  Those are the easy cases.
    70  
    71  If the body contains a return or a break/continue/goto L, then we need
    72  to rewrite that into code that breaks out of the loop and then
    73  triggers that control flow. In general we rewrite
    74  
    75  	for x := range f {
    76  		...
    77  	}
    78  
    79  into
    80  
    81  	{
    82  		var #next int
    83  		f(func(x T1) bool {
    84  			...
    85  			return true
    86  		})
    87  		... check #next ...
    88  	}
    89  
    90  The variable #next is an integer code that says what to do when f
    91  returns. Each difficult statement sets #next and then returns false to
    92  stop f.
    93  
    94  A plain "return" rewrites to {#next = -1; return false}.
    95  The return false breaks the loop. Then when f returns, the "check
    96  #next" section includes
    97  
    98  	if #next == -1 { return }
    99  
   100  which causes the return we want.
   101  
   102  Return with arguments is more involved. We need somewhere to store the
   103  arguments while we break out of f, so we add them to the var
   104  declaration, like:
   105  
   106  	{
   107  		var (
   108  			#next int
   109  			#r1 type1
   110  			#r2 type2
   111  		)
   112  		f(func(x T1) bool {
   113  			...
   114  			{
   115  				// return a, b
   116  				#r1, #r2 = a, b
   117  				#next = -2
   118  				return false
   119  			}
   120  			...
   121  			return true
   122  		})
   123  		if #next == -2 { return #r1, #r2 }
   124  	}
   125  
   126  TODO: What about:
   127  
   128  	func f() (x bool) {
   129  		for range g(&x) {
   130  			return true
   131  		}
   132  	}
   133  
   134  	func g(p *bool) func(func() bool) {
   135  		return func(yield func() bool) {
   136  			yield()
   137  			// Is *p true or false here?
   138  		}
   139  	}
   140  
   141  With this rewrite the "return true" is not visible after yield returns,
   142  but maybe it should be?
   143  
   144  # Checking
   145  
   146  To permit checking that an iterator is well-behaved -- that is, that
   147  it does not call the loop body again after it has returned false or
   148  after the entire loop has exited (it might retain a copy of the body
   149  function, or pass it to another goroutine) -- each generated loop has
   150  its own #exitK flag that is checked before each iteration, and set both
   151  at any early exit and after the whole loop completes (that is, after f returns).
   152  
   153  For example:
   154  
   155  	for x := range f {
   156  		...
   157  		if ... { break }
   158  		...
   159  	}
   160  
   161  becomes
   162  
   163  	{
   164  		var #exit1 bool
   165  		f(func(x T1) bool {
   166  			if #exit1 { runtime.panicrangeexit() }
   167  			...
   168  			if ... { #exit1 = true ; return false }
   169  			...
   170  			return true
   171  		})
   172  		#exit1 = true
   173  	}
   174  
   175  # Nested Loops
   176  
   177  So far we've only considered a single loop. If a function contains a
   178  sequence of loops, each can be translated individually. But loops can
   179  be nested. It would work to translate the innermost loop and then
   180  translate the loop around it, and so on, except that there'd be a lot
   181  of rewriting of rewritten code and the overall traversals could end up
   182  taking time quadratic in the depth of the nesting. To avoid all that,
   183  we use a single rewriting pass that handles a top-most range-over-func
   184  loop and all the range-over-func loops it contains at the same time.
   185  
   186  If we need to return from inside a doubly-nested loop, the rewrites
   187  above stay the same, but the check after the inner loop only says
   188  
   189  	if #next < 0 { return false }
   190  
   191  to stop the outer loop so it can do the actual return. That is,
   192  
   193  	for range f {
   194  		for range g {
   195  			...
   196  			return a, b
   197  			...
   198  		}
   199  	}
   200  
   201  becomes
   202  
   203  	{
   204  		var (
   205  			#next int
   206  			#r1 type1
   207  			#r2 type2
   208  		)
   209  		var #exit1 bool
   210  		f(func() bool {
   211  			if #exit1 { runtime.panicrangeexit() }
   212  			var #exit2 bool
   213  			g(func() bool {
   214  				if #exit2 { runtime.panicrangeexit() }
   215  				...
   216  				{
   217  					// return a, b
   218  					#r1, #r2 = a, b
   219  					#next = -2
   220  					#exit1, #exit2 = true, true
   221  					return false
   222  				}
   223  				...
   224  				return true
   225  			})
   226  			#exit2 = true
   227  			if #next < 0 {
   228  				return false
   229  			}
   230  			return true
   231  		})
   232  		#exit1 = true
   233  		if #next == -2 {
   234  			return #r1, #r2
   235  		}
   236  	}
   237  
   238  Note that the #next < 0 after the inner loop handles both kinds of
   239  return with a single check.
   240  
   241  # Labeled break/continue of range-over-func loops
   242  
   243  For a labeled break or continue of an outer range-over-func, we
   244  use positive #next values. Any such labeled break or continue
   245  really means "do N breaks" or "do N breaks and 1 continue".
   246  We encode that as perLoopStep*N or perLoopStep*N+1 respectively.
   247  
   248  Loops that might need to propagate a labeled break or continue
   249  add one or both of these to the #next checks:
   250  
   251  	if #next >= 2 {
   252  		#next -= 2
   253  		return false
   254  	}
   255  
   256  	if #next == 1 {
   257  		#next = 0
   258  		return true
   259  	}
   260  
   261  For example
   262  
   263  	F: for range f {
   264  		for range g {
   265  			for range h {
   266  				...
   267  				break F
   268  				...
   269  				...
   270  				continue F
   271  				...
   272  			}
   273  		}
   274  		...
   275  	}
   276  
   277  becomes
   278  
   279  	{
   280  		var #next int
   281  		var #exit1 bool
   282  		f(func() bool {
   283  			if #exit1 { runtime.panicrangeexit() }
   284  			var #exit2 bool
   285  			g(func() bool {
   286  				if #exit2 { runtime.panicrangeexit() }
   287  				var #exit3 bool
   288  				h(func() bool {
   289  					if #exit3 { runtime.panicrangeexit() }
   290  					...
   291  					{
   292  						// break F
   293  						#next = 4
   294  						#exit1, #exit2, #exit3 = true, true, true
   295  						return false
   296  					}
   297  					...
   298  					{
   299  						// continue F
   300  						#next = 3
   301  						#exit2, #exit3 = true, true
   302  						return false
   303  					}
   304  					...
   305  					return true
   306  				})
   307  				#exit3 = true
   308  				if #next >= 2 {
   309  					#next -= 2
   310  					return false
   311  				}
   312  				return true
   313  			})
   314  			#exit2 = true
   315  			if #next >= 2 {
   316  				#next -= 2
   317  				return false
   318  			}
   319  			if #next == 1 {
   320  				#next = 0
   321  				return true
   322  			}
   323  			...
   324  			return true
   325  		})
   326  		#exit1 = true
   327  	}
   328  
   329  Note that the post-h checks only consider a break,
   330  since no generated code tries to continue g.
   331  
   332  # Gotos and other labeled break/continue
   333  
   334  The final control flow translations are goto and break/continue of a
   335  non-range-over-func statement. In both cases, we may need to break out
   336  of one or more range-over-func loops before we can do the actual
   337  control flow statement. Each such break/continue/goto L statement is
   338  assigned a unique negative #next value (below -2, since -1 and -2 are
   339  for the two kinds of return). Then the post-checks for a given loop
   340  test for the specific codes that refer to labels directly targetable
   341  from that block. Otherwise, the generic
   342  
   343  	if #next < 0 { return false }
   344  
   345  check handles stopping the next loop to get one step closer to the label.
   346  
   347  For example
   348  
   349  	Top: print("start\n")
   350  	for range f {
   351  		for range g {
   352  			...
   353  			for range h {
   354  				...
   355  				goto Top
   356  				...
   357  			}
   358  		}
   359  	}
   360  
   361  becomes
   362  
   363  	Top: print("start\n")
   364  	{
   365  		var #next int
   366  		var #exit1 bool
   367  		f(func() bool {
   368  			if #exit1 { runtime.panicrangeexit() }
   369  			var #exit2 bool
   370  			g(func() bool {
   371  				if #exit2 { runtime.panicrangeexit() }
   372  				...
   373  				var #exit3 bool
   374  				h(func() bool {
   375  					if #exit3 { runtime.panicrangeexit() }
   376  					...
   377  					{
   378  						// goto Top
   379  						#next = -3
   380  						#exit1, #exit2, #exit3 = true, true, true
   381  						return false
   382  					}
   383  					...
   384  					return true
   385  				})
   386  				#exit3 = true
   387  				if #next < 0 {
   388  					return false
   389  				}
   390  				return true
   391  			})
   392  			#exit2 = true
   393  			if #next < 0 {
   394  				return false
   395  			}
   396  			return true
   397  		})
   398  		#exit1 = true
   399  		if #next == -3 {
   400  			#next = 0
   401  			goto Top
   402  		}
   403  	}
   404  
   405  Labeled break/continue to non-range-over-funcs are handled the same
   406  way as goto.
   407  
   408  # Defers
   409  
   410  The last wrinkle is handling defer statements. If we have
   411  
   412  	for range f {
   413  		defer print("A")
   414  	}
   415  
   416  we cannot rewrite that into
   417  
   418  	f(func() {
   419  		defer print("A")
   420  	})
   421  
   422  because the deferred code will run at the end of the iteration, not
   423  the end of the containing function. To fix that, the runtime provides
   424  a special hook that lets us obtain a defer "token" representing the
   425  outer function and then use it in a later defer to attach the deferred
   426  code to that outer function.
   427  
   428  Normally,
   429  
   430  	defer print("A")
   431  
   432  compiles to
   433  
   434  	runtime.deferproc(func() { print("A") })
   435  
   436  This changes in a range-over-func. For example:
   437  
   438  	for range f {
   439  		defer print("A")
   440  	}
   441  
   442  compiles to
   443  
   444  	var #defers = runtime.deferrangefunc()
   445  	f(func() {
   446  		runtime.deferprocat(func() { print("A") }, #defers)
   447  	})
   448  
   449  For this rewriting phase, we insert the explicit initialization of
   450  #defers and then attach the #defers variable to the CallStmt
   451  representing the defer. That variable will be propagated to the
   452  backend and will cause the backend to compile the defer using
   453  deferprocat instead of an ordinary deferproc.
   454  
   455  TODO: Could call runtime.deferrangefuncend after f.
   456  */
   457  package rangefunc
   458  
   459  import (
   460  	"cmd/compile/internal/base"
   461  	"cmd/compile/internal/syntax"
   462  	"cmd/compile/internal/types2"
   463  	"fmt"
   464  	"go/constant"
   465  	"os"
   466  )
   467  
   468  // nopos is the zero syntax.Pos.
   469  var nopos syntax.Pos
   470  
   471  // A rewriter implements rewriting the range-over-funcs in a given function.
   472  type rewriter struct {
   473  	pkg   *types2.Package
   474  	info  *types2.Info
   475  	outer *syntax.FuncType
   476  	body  *syntax.BlockStmt
   477  
   478  	// References to important types and values.
   479  	any   types2.Object
   480  	bool  types2.Object
   481  	int   types2.Object
   482  	true  types2.Object
   483  	false types2.Object
   484  
   485  	// Branch numbering, computed as needed.
   486  	branchNext map[branch]int             // branch -> #next value
   487  	labelLoop  map[string]*syntax.ForStmt // label -> innermost rangefunc loop it is declared inside (nil for no loop)
   488  
   489  	// Stack of nodes being visited.
   490  	stack    []syntax.Node // all nodes
   491  	forStack []*forLoop    // range-over-func loops
   492  
   493  	rewritten map[*syntax.ForStmt]syntax.Stmt
   494  
   495  	// Declared variables in generated code for outermost loop.
   496  	declStmt     *syntax.DeclStmt
   497  	nextVar      types2.Object
   498  	retVars      []types2.Object
   499  	defers       types2.Object
   500  	exitVarCount int // exitvars are referenced from their respective loops
   501  }
   502  
   503  // A branch is a single labeled branch.
   504  type branch struct {
   505  	tok   syntax.Token
   506  	label string
   507  }
   508  
   509  // A forLoop describes a single range-over-func loop being processed.
   510  type forLoop struct {
   511  	nfor         *syntax.ForStmt // actual syntax
   512  	exitFlag     *types2.Var     // #exit variable for this loop
   513  	exitFlagDecl *syntax.VarDecl
   514  
   515  	checkRet      bool     // add check for "return" after loop
   516  	checkRetArgs  bool     // add check for "return args" after loop
   517  	checkBreak    bool     // add check for "break" after loop
   518  	checkContinue bool     // add check for "continue" after loop
   519  	checkBranch   []branch // add check for labeled branch after loop
   520  }
   521  
   522  // Rewrite rewrites all the range-over-funcs in the files.
   523  func Rewrite(pkg *types2.Package, info *types2.Info, files []*syntax.File) {
   524  	for _, file := range files {
   525  		syntax.Inspect(file, func(n syntax.Node) bool {
   526  			switch n := n.(type) {
   527  			case *syntax.FuncDecl:
   528  				rewriteFunc(pkg, info, n.Type, n.Body)
   529  				return false
   530  			case *syntax.FuncLit:
   531  				rewriteFunc(pkg, info, n.Type, n.Body)
   532  				return false
   533  			}
   534  			return true
   535  		})
   536  	}
   537  }
   538  
   539  // rewriteFunc rewrites all the range-over-funcs in a single function (a top-level func or a func literal).
   540  // The typ and body are the function's type and body.
   541  func rewriteFunc(pkg *types2.Package, info *types2.Info, typ *syntax.FuncType, body *syntax.BlockStmt) {
   542  	if body == nil {
   543  		return
   544  	}
   545  	r := &rewriter{
   546  		pkg:   pkg,
   547  		info:  info,
   548  		outer: typ,
   549  		body:  body,
   550  	}
   551  	syntax.Inspect(body, r.inspect)
   552  	if (base.Flag.W != 0) && r.forStack != nil {
   553  		syntax.Fdump(os.Stderr, body)
   554  	}
   555  }
   556  
   557  // checkFuncMisuse reports whether to check for misuse of iterator callbacks functions.
   558  func (r *rewriter) checkFuncMisuse() bool {
   559  	return base.Debug.RangeFuncCheck != 0
   560  }
   561  
   562  // inspect is a callback for syntax.Inspect that drives the actual rewriting.
   563  // If it sees a func literal, it kicks off a separate rewrite for that literal.
   564  // Otherwise, it maintains a stack of range-over-func loops and
   565  // converts each in turn.
   566  func (r *rewriter) inspect(n syntax.Node) bool {
   567  	switch n := n.(type) {
   568  	case *syntax.FuncLit:
   569  		rewriteFunc(r.pkg, r.info, n.Type, n.Body)
   570  		return false
   571  
   572  	default:
   573  		// Push n onto stack.
   574  		r.stack = append(r.stack, n)
   575  		if nfor, ok := forRangeFunc(n); ok {
   576  			loop := &forLoop{nfor: nfor}
   577  			r.forStack = append(r.forStack, loop)
   578  			r.startLoop(loop)
   579  		}
   580  
   581  	case nil:
   582  		// n == nil signals that we are done visiting
   583  		// the top-of-stack node's children. Find it.
   584  		n = r.stack[len(r.stack)-1]
   585  
   586  		// If we are inside a range-over-func,
   587  		// take this moment to replace any break/continue/goto/return
   588  		// statements directly contained in this node.
   589  		// Also replace any converted for statements
   590  		// with the rewritten block.
   591  		switch n := n.(type) {
   592  		case *syntax.BlockStmt:
   593  			for i, s := range n.List {
   594  				n.List[i] = r.editStmt(s)
   595  			}
   596  		case *syntax.CaseClause:
   597  			for i, s := range n.Body {
   598  				n.Body[i] = r.editStmt(s)
   599  			}
   600  		case *syntax.CommClause:
   601  			for i, s := range n.Body {
   602  				n.Body[i] = r.editStmt(s)
   603  			}
   604  		case *syntax.LabeledStmt:
   605  			n.Stmt = r.editStmt(n.Stmt)
   606  		}
   607  
   608  		// Pop n.
   609  		if len(r.forStack) > 0 && r.stack[len(r.stack)-1] == r.forStack[len(r.forStack)-1].nfor {
   610  			r.endLoop(r.forStack[len(r.forStack)-1])
   611  			r.forStack = r.forStack[:len(r.forStack)-1]
   612  		}
   613  		r.stack = r.stack[:len(r.stack)-1]
   614  	}
   615  	return true
   616  }
   617  
   618  // startLoop sets up for converting a range-over-func loop.
   619  func (r *rewriter) startLoop(loop *forLoop) {
   620  	// For first loop in function, allocate syntax for any, bool, int, true, and false.
   621  	if r.any == nil {
   622  		r.any = types2.Universe.Lookup("any")
   623  		r.bool = types2.Universe.Lookup("bool")
   624  		r.int = types2.Universe.Lookup("int")
   625  		r.true = types2.Universe.Lookup("true")
   626  		r.false = types2.Universe.Lookup("false")
   627  		r.rewritten = make(map[*syntax.ForStmt]syntax.Stmt)
   628  	}
   629  	if r.checkFuncMisuse() {
   630  		// declare the exit flag for this loop's body
   631  		loop.exitFlag, loop.exitFlagDecl = r.exitVar(loop.nfor.Pos())
   632  	}
   633  }
   634  
   635  // editStmt returns the replacement for the statement x,
   636  // or x itself if it should be left alone.
   637  // This includes the for loops we are converting,
   638  // as left in x.rewritten by r.endLoop.
   639  func (r *rewriter) editStmt(x syntax.Stmt) syntax.Stmt {
   640  	if x, ok := x.(*syntax.ForStmt); ok {
   641  		if s := r.rewritten[x]; s != nil {
   642  			return s
   643  		}
   644  	}
   645  
   646  	if len(r.forStack) > 0 {
   647  		switch x := x.(type) {
   648  		case *syntax.BranchStmt:
   649  			return r.editBranch(x)
   650  		case *syntax.CallStmt:
   651  			if x.Tok == syntax.Defer {
   652  				return r.editDefer(x)
   653  			}
   654  		case *syntax.ReturnStmt:
   655  			return r.editReturn(x)
   656  		}
   657  	}
   658  
   659  	return x
   660  }
   661  
   662  // editDefer returns the replacement for the defer statement x.
   663  // See the "Defers" section in the package doc comment above for more context.
   664  func (r *rewriter) editDefer(x *syntax.CallStmt) syntax.Stmt {
   665  	if r.defers == nil {
   666  		// Declare and initialize the #defers token.
   667  		init := &syntax.CallExpr{
   668  			Fun: runtimeSym(r.info, "deferrangefunc"),
   669  		}
   670  		tv := syntax.TypeAndValue{Type: r.any.Type()}
   671  		tv.SetIsValue()
   672  		init.SetTypeInfo(tv)
   673  		r.defers = r.declVar("#defers", r.any.Type(), init)
   674  	}
   675  
   676  	// Attach the token as an "extra" argument to the defer.
   677  	x.DeferAt = r.useVar(r.defers)
   678  	setPos(x.DeferAt, x.Pos())
   679  	return x
   680  }
   681  
   682  func (r *rewriter) exitVar(pos syntax.Pos) (*types2.Var, *syntax.VarDecl) {
   683  	r.exitVarCount++
   684  
   685  	name := fmt.Sprintf("#exit%d", r.exitVarCount)
   686  	typ := r.bool.Type()
   687  	obj := types2.NewVar(pos, r.pkg, name, typ)
   688  	n := syntax.NewName(pos, name)
   689  	setValueType(n, typ)
   690  	r.info.Defs[n] = obj
   691  
   692  	return obj, &syntax.VarDecl{NameList: []*syntax.Name{n}}
   693  }
   694  
   695  // editReturn returns the replacement for the return statement x.
   696  // See the "Return" section in the package doc comment above for more context.
   697  func (r *rewriter) editReturn(x *syntax.ReturnStmt) syntax.Stmt {
   698  	// #next = -1 is return with no arguments; -2 is return with arguments.
   699  	var next int
   700  	if x.Results == nil {
   701  		next = -1
   702  		r.forStack[0].checkRet = true
   703  	} else {
   704  		next = -2
   705  		r.forStack[0].checkRetArgs = true
   706  	}
   707  
   708  	// Tell the loops along the way to check for a return.
   709  	for _, loop := range r.forStack[1:] {
   710  		loop.checkRet = true
   711  	}
   712  
   713  	// Assign results, set #next, and return false.
   714  	bl := &syntax.BlockStmt{}
   715  	if x.Results != nil {
   716  		if r.retVars == nil {
   717  			for i, a := range r.outer.ResultList {
   718  				obj := r.declVar(fmt.Sprintf("#r%d", i+1), a.Type.GetTypeInfo().Type, nil)
   719  				r.retVars = append(r.retVars, obj)
   720  			}
   721  		}
   722  		bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.useList(r.retVars), Rhs: x.Results})
   723  	}
   724  	bl.List = append(bl.List, &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)})
   725  	if r.checkFuncMisuse() {
   726  		// mark all enclosing loop bodies as exited
   727  		for i := 0; i < len(r.forStack); i++ {
   728  			bl.List = append(bl.List, r.setExitedAt(i))
   729  		}
   730  	}
   731  	bl.List = append(bl.List, &syntax.ReturnStmt{Results: r.useVar(r.false)})
   732  	setPos(bl, x.Pos())
   733  	return bl
   734  }
   735  
   736  // perLoopStep is part of the encoding of loop-spanning control flow
   737  // for function range iterators.  Each multiple of two encodes a "return false"
   738  // passing control to an enclosing iterator; a terminal value of 1 encodes
   739  // "return true" (i.e., local continue) from the body function, and a terminal
   740  // value of 0 encodes executing the remainder of the body function.
   741  const perLoopStep = 2
   742  
   743  // editBranch returns the replacement for the branch statement x,
   744  // or x itself if it should be left alone.
   745  // See the package doc comment above for more context.
   746  func (r *rewriter) editBranch(x *syntax.BranchStmt) syntax.Stmt {
   747  	if x.Tok == syntax.Fallthrough {
   748  		// Fallthrough is unaffected by the rewrite.
   749  		return x
   750  	}
   751  
   752  	// Find target of break/continue/goto in r.forStack.
   753  	// (The target may not be in r.forStack at all.)
   754  	targ := x.Target
   755  	i := len(r.forStack) - 1
   756  	if x.Label == nil && r.forStack[i].nfor != targ {
   757  		// Unlabeled break or continue that's not nfor must be inside nfor. Leave alone.
   758  		return x
   759  	}
   760  	for i >= 0 && r.forStack[i].nfor != targ {
   761  		i--
   762  	}
   763  	// exitFrom is the index of the loop interior to the target of the control flow,
   764  	// if such a loop exists (it does not if i == len(r.forStack) - 1)
   765  	exitFrom := i + 1
   766  
   767  	// Compute the value to assign to #next and the specific return to use.
   768  	var next int
   769  	var ret *syntax.ReturnStmt
   770  	if x.Tok == syntax.Goto || i < 0 {
   771  		// goto Label
   772  		// or break/continue of labeled non-range-over-func loop.
   773  		// We may be able to leave it alone, or we may have to break
   774  		// out of one or more nested loops and then use #next to signal
   775  		// to complete the break/continue/goto.
   776  		// Figure out which range-over-func loop contains the label.
   777  		r.computeBranchNext()
   778  		nfor := r.forStack[len(r.forStack)-1].nfor
   779  		label := x.Label.Value
   780  		targ := r.labelLoop[label]
   781  		if nfor == targ {
   782  			// Label is in the innermost range-over-func loop; use it directly.
   783  			return x
   784  		}
   785  
   786  		// Set #next to the code meaning break/continue/goto label.
   787  		next = r.branchNext[branch{x.Tok, label}]
   788  
   789  		// Break out of nested loops up to targ.
   790  		i := len(r.forStack) - 1
   791  		for i >= 0 && r.forStack[i].nfor != targ {
   792  			i--
   793  		}
   794  		exitFrom = i + 1
   795  
   796  		// Mark loop we exit to get to targ to check for that branch.
   797  		// When i==-1 that's the outermost func body
   798  		top := r.forStack[i+1]
   799  		top.checkBranch = append(top.checkBranch, branch{x.Tok, label})
   800  
   801  		// Mark loops along the way to check for a plain return, so they break.
   802  		for j := i + 2; j < len(r.forStack); j++ {
   803  			r.forStack[j].checkRet = true
   804  		}
   805  
   806  		// In the innermost loop, use a plain "return false".
   807  		ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
   808  	} else {
   809  		// break/continue of labeled range-over-func loop.
   810  		depth := len(r.forStack) - 1 - i
   811  
   812  		// For continue of innermost loop, use "return true".
   813  		// Otherwise we are breaking the innermost loop, so "return false".
   814  
   815  		if depth == 0 && x.Tok == syntax.Continue {
   816  			ret = &syntax.ReturnStmt{Results: r.useVar(r.true)}
   817  			setPos(ret, x.Pos())
   818  			return ret
   819  		}
   820  		ret = &syntax.ReturnStmt{Results: r.useVar(r.false)}
   821  
   822  		// If this is a simple break, mark this loop as exited and return false.
   823  		// No adjustments to #next.
   824  		if depth == 0 {
   825  			var stmts []syntax.Stmt
   826  			if r.checkFuncMisuse() {
   827  				stmts = []syntax.Stmt{r.setExited(), ret}
   828  			} else {
   829  				stmts = []syntax.Stmt{ret}
   830  			}
   831  			bl := &syntax.BlockStmt{
   832  				List: stmts,
   833  			}
   834  			setPos(bl, x.Pos())
   835  			return bl
   836  		}
   837  
   838  		// The loop inside the one we are break/continue-ing
   839  		// needs to make that happen when we break out of it.
   840  		if x.Tok == syntax.Continue {
   841  			r.forStack[exitFrom].checkContinue = true
   842  		} else {
   843  			exitFrom = i
   844  			r.forStack[exitFrom].checkBreak = true
   845  		}
   846  
   847  		// The loops along the way just need to break.
   848  		for j := exitFrom + 1; j < len(r.forStack); j++ {
   849  			r.forStack[j].checkBreak = true
   850  		}
   851  
   852  		// Set next to break the appropriate number of times;
   853  		// the final time may be a continue, not a break.
   854  		next = perLoopStep * depth
   855  		if x.Tok == syntax.Continue {
   856  			next--
   857  		}
   858  	}
   859  
   860  	// Assign #next = next and do the return.
   861  	as := &syntax.AssignStmt{Lhs: r.next(), Rhs: r.intConst(next)}
   862  	bl := &syntax.BlockStmt{
   863  		List: []syntax.Stmt{as},
   864  	}
   865  
   866  	if r.checkFuncMisuse() {
   867  		// Set #exitK for this loop and those exited by the control flow.
   868  		for i := exitFrom; i < len(r.forStack); i++ {
   869  			bl.List = append(bl.List, r.setExitedAt(i))
   870  		}
   871  	}
   872  
   873  	bl.List = append(bl.List, ret)
   874  	setPos(bl, x.Pos())
   875  	return bl
   876  }
   877  
   878  // computeBranchNext computes the branchNext numbering
   879  // and determines which labels end up inside which range-over-func loop bodies.
   880  func (r *rewriter) computeBranchNext() {
   881  	if r.labelLoop != nil {
   882  		return
   883  	}
   884  
   885  	r.labelLoop = make(map[string]*syntax.ForStmt)
   886  	r.branchNext = make(map[branch]int)
   887  
   888  	var labels []string
   889  	var stack []syntax.Node
   890  	var forStack []*syntax.ForStmt
   891  	forStack = append(forStack, nil)
   892  	syntax.Inspect(r.body, func(n syntax.Node) bool {
   893  		if n != nil {
   894  			stack = append(stack, n)
   895  			if nfor, ok := forRangeFunc(n); ok {
   896  				forStack = append(forStack, nfor)
   897  			}
   898  			if n, ok := n.(*syntax.LabeledStmt); ok {
   899  				l := n.Label.Value
   900  				labels = append(labels, l)
   901  				f := forStack[len(forStack)-1]
   902  				r.labelLoop[l] = f
   903  			}
   904  		} else {
   905  			n := stack[len(stack)-1]
   906  			stack = stack[:len(stack)-1]
   907  			if n == forStack[len(forStack)-1] {
   908  				forStack = forStack[:len(forStack)-1]
   909  			}
   910  		}
   911  		return true
   912  	})
   913  
   914  	// Assign numbers to all the labels we observed.
   915  	used := -2
   916  	for _, l := range labels {
   917  		used -= 3
   918  		r.branchNext[branch{syntax.Break, l}] = used
   919  		r.branchNext[branch{syntax.Continue, l}] = used + 1
   920  		r.branchNext[branch{syntax.Goto, l}] = used + 2
   921  	}
   922  }
   923  
   924  // endLoop finishes the conversion of a range-over-func loop.
   925  // We have inspected and rewritten the body of the loop and can now
   926  // construct the body function and rewrite the for loop into a call
   927  // bracketed by any declarations and checks it requires.
   928  func (r *rewriter) endLoop(loop *forLoop) {
   929  	// Pick apart for range X { ... }
   930  	nfor := loop.nfor
   931  	start, end := nfor.Pos(), nfor.Body.Rbrace // start, end position of for loop
   932  	rclause := nfor.Init.(*syntax.RangeClause)
   933  	rfunc := types2.CoreType(rclause.X.GetTypeInfo().Type).(*types2.Signature) // type of X - func(func(...)bool)
   934  	if rfunc.Params().Len() != 1 {
   935  		base.Fatalf("invalid typecheck of range func")
   936  	}
   937  	ftyp := types2.CoreType(rfunc.Params().At(0).Type()).(*types2.Signature) // func(...) bool
   938  	if ftyp.Results().Len() != 1 {
   939  		base.Fatalf("invalid typecheck of range func")
   940  	}
   941  
   942  	// Build X(bodyFunc)
   943  	call := &syntax.ExprStmt{
   944  		X: &syntax.CallExpr{
   945  			Fun: rclause.X,
   946  			ArgList: []syntax.Expr{
   947  				r.bodyFunc(nfor.Body.List, syntax.UnpackListExpr(rclause.Lhs), rclause.Def, ftyp, start, end),
   948  			},
   949  		},
   950  	}
   951  	setPos(call, start)
   952  
   953  	// Build checks based on #next after X(bodyFunc)
   954  	checks := r.checks(loop, end)
   955  
   956  	// Rewrite for vars := range X { ... } to
   957  	//
   958  	//	{
   959  	//		r.declStmt
   960  	//		call
   961  	//		checks
   962  	//	}
   963  	//
   964  	// The r.declStmt can be added to by this loop or any inner loop
   965  	// during the creation of r.bodyFunc; it is only emitted in the outermost
   966  	// converted range loop.
   967  	block := &syntax.BlockStmt{Rbrace: end}
   968  	setPos(block, start)
   969  	if len(r.forStack) == 1 && r.declStmt != nil {
   970  		setPos(r.declStmt, start)
   971  		block.List = append(block.List, r.declStmt)
   972  	}
   973  
   974  	// declare the exitFlag here so it has proper scope and zeroing
   975  	if r.checkFuncMisuse() {
   976  		exitFlagDecl := &syntax.DeclStmt{DeclList: []syntax.Decl{loop.exitFlagDecl}}
   977  		block.List = append(block.List, exitFlagDecl)
   978  	}
   979  
   980  	// iteratorFunc(bodyFunc)
   981  	block.List = append(block.List, call)
   982  
   983  	if r.checkFuncMisuse() {
   984  		// iteratorFunc has exited, mark the exit flag for the body
   985  		block.List = append(block.List, r.setExited())
   986  	}
   987  	block.List = append(block.List, checks...)
   988  
   989  	if len(r.forStack) == 1 { // ending an outermost loop
   990  		r.declStmt = nil
   991  		r.nextVar = nil
   992  		r.retVars = nil
   993  		r.defers = nil
   994  	}
   995  
   996  	r.rewritten[nfor] = block
   997  }
   998  
   999  func (r *rewriter) setExited() *syntax.AssignStmt {
  1000  	return r.setExitedAt(len(r.forStack) - 1)
  1001  }
  1002  
  1003  func (r *rewriter) setExitedAt(index int) *syntax.AssignStmt {
  1004  	loop := r.forStack[index]
  1005  	return &syntax.AssignStmt{
  1006  		Lhs: r.useVar(loop.exitFlag),
  1007  		Rhs: r.useVar(r.true),
  1008  	}
  1009  }
  1010  
  1011  // bodyFunc converts the loop body (control flow has already been updated)
  1012  // to a func literal that can be passed to the range function.
  1013  //
  1014  // vars is the range variables from the range statement.
  1015  // def indicates whether this is a := range statement.
  1016  // ftyp is the type of the function we are creating
  1017  // start and end are the syntax positions to use for new nodes
  1018  // that should be at the start or end of the loop.
  1019  func (r *rewriter) bodyFunc(body []syntax.Stmt, lhs []syntax.Expr, def bool, ftyp *types2.Signature, start, end syntax.Pos) *syntax.FuncLit {
  1020  	// Starting X(bodyFunc); build up bodyFunc first.
  1021  	var params, results []*types2.Var
  1022  	results = append(results, types2.NewVar(start, nil, "", r.bool.Type()))
  1023  	bodyFunc := &syntax.FuncLit{
  1024  		// Note: Type is ignored but needs to be non-nil to avoid panic in syntax.Inspect.
  1025  		Type: &syntax.FuncType{},
  1026  		Body: &syntax.BlockStmt{
  1027  			List:   []syntax.Stmt{},
  1028  			Rbrace: end,
  1029  		},
  1030  	}
  1031  	setPos(bodyFunc, start)
  1032  
  1033  	for i := 0; i < ftyp.Params().Len(); i++ {
  1034  		typ := ftyp.Params().At(i).Type()
  1035  		var paramVar *types2.Var
  1036  		if i < len(lhs) && def {
  1037  			// Reuse range variable as parameter.
  1038  			x := lhs[i]
  1039  			paramVar = r.info.Defs[x.(*syntax.Name)].(*types2.Var)
  1040  		} else {
  1041  			// Declare new parameter and assign it to range expression.
  1042  			paramVar = types2.NewVar(start, r.pkg, fmt.Sprintf("#p%d", 1+i), typ)
  1043  			if i < len(lhs) {
  1044  				x := lhs[i]
  1045  				as := &syntax.AssignStmt{Lhs: x, Rhs: r.useVar(paramVar)}
  1046  				as.SetPos(x.Pos())
  1047  				setPos(as.Rhs, x.Pos())
  1048  				bodyFunc.Body.List = append(bodyFunc.Body.List, as)
  1049  			}
  1050  		}
  1051  		params = append(params, paramVar)
  1052  	}
  1053  
  1054  	tv := syntax.TypeAndValue{
  1055  		Type: types2.NewSignatureType(nil, nil, nil,
  1056  			types2.NewTuple(params...),
  1057  			types2.NewTuple(results...),
  1058  			false),
  1059  	}
  1060  	tv.SetIsValue()
  1061  	bodyFunc.SetTypeInfo(tv)
  1062  
  1063  	loop := r.forStack[len(r.forStack)-1]
  1064  
  1065  	if r.checkFuncMisuse() {
  1066  		bodyFunc.Body.List = append(bodyFunc.Body.List, r.assertNotExited(start, loop))
  1067  	}
  1068  
  1069  	// Original loop body (already rewritten by editStmt during inspect).
  1070  	bodyFunc.Body.List = append(bodyFunc.Body.List, body...)
  1071  
  1072  	// return true to continue at end of loop body
  1073  	ret := &syntax.ReturnStmt{Results: r.useVar(r.true)}
  1074  	ret.SetPos(end)
  1075  	bodyFunc.Body.List = append(bodyFunc.Body.List, ret)
  1076  
  1077  	return bodyFunc
  1078  }
  1079  
  1080  // checks returns the post-call checks that need to be done for the given loop.
  1081  func (r *rewriter) checks(loop *forLoop, pos syntax.Pos) []syntax.Stmt {
  1082  	var list []syntax.Stmt
  1083  	if len(loop.checkBranch) > 0 {
  1084  		did := make(map[branch]bool)
  1085  		for _, br := range loop.checkBranch {
  1086  			if did[br] {
  1087  				continue
  1088  			}
  1089  			did[br] = true
  1090  			doBranch := &syntax.BranchStmt{Tok: br.tok, Label: &syntax.Name{Value: br.label}}
  1091  			list = append(list, r.ifNext(syntax.Eql, r.branchNext[br], doBranch))
  1092  		}
  1093  	}
  1094  	if len(r.forStack) == 1 {
  1095  		if loop.checkRetArgs {
  1096  			list = append(list, r.ifNext(syntax.Eql, -2, retStmt(r.useList(r.retVars))))
  1097  		}
  1098  		if loop.checkRet {
  1099  			list = append(list, r.ifNext(syntax.Eql, -1, retStmt(nil)))
  1100  		}
  1101  	} else {
  1102  		if loop.checkRetArgs || loop.checkRet {
  1103  			// Note: next < 0 also handles gotos handled by outer loops.
  1104  			// We set checkRet in that case to trigger this check.
  1105  			list = append(list, r.ifNext(syntax.Lss, 0, retStmt(r.useVar(r.false))))
  1106  		}
  1107  		if loop.checkBreak {
  1108  			list = append(list, r.ifNext(syntax.Geq, perLoopStep, retStmt(r.useVar(r.false))))
  1109  		}
  1110  		if loop.checkContinue {
  1111  			list = append(list, r.ifNext(syntax.Eql, perLoopStep-1, retStmt(r.useVar(r.true))))
  1112  		}
  1113  	}
  1114  
  1115  	for _, j := range list {
  1116  		setPos(j, pos)
  1117  	}
  1118  	return list
  1119  }
  1120  
  1121  // retStmt returns a return statement returning the given return values.
  1122  func retStmt(results syntax.Expr) *syntax.ReturnStmt {
  1123  	return &syntax.ReturnStmt{Results: results}
  1124  }
  1125  
  1126  // ifNext returns the statement:
  1127  //
  1128  //	if #next op c { adjust; then }
  1129  //
  1130  // When op is >=, adjust is #next -= c.
  1131  // When op is == and c is not -1 or -2, adjust is #next = 0.
  1132  // Otherwise adjust is omitted.
  1133  func (r *rewriter) ifNext(op syntax.Operator, c int, then syntax.Stmt) syntax.Stmt {
  1134  	nif := &syntax.IfStmt{
  1135  		Cond: &syntax.Operation{Op: op, X: r.next(), Y: r.intConst(c)},
  1136  		Then: &syntax.BlockStmt{
  1137  			List: []syntax.Stmt{then},
  1138  		},
  1139  	}
  1140  	tv := syntax.TypeAndValue{Type: r.bool.Type()}
  1141  	tv.SetIsValue()
  1142  	nif.Cond.SetTypeInfo(tv)
  1143  
  1144  	if op == syntax.Geq {
  1145  		sub := &syntax.AssignStmt{
  1146  			Op:  syntax.Sub,
  1147  			Lhs: r.next(),
  1148  			Rhs: r.intConst(c),
  1149  		}
  1150  		nif.Then.List = []syntax.Stmt{sub, then}
  1151  	}
  1152  	if op == syntax.Eql && c != -1 && c != -2 {
  1153  		clr := &syntax.AssignStmt{
  1154  			Lhs: r.next(),
  1155  			Rhs: r.intConst(0),
  1156  		}
  1157  		nif.Then.List = []syntax.Stmt{clr, then}
  1158  	}
  1159  
  1160  	return nif
  1161  }
  1162  
  1163  // setValueType marks x as a value with type typ.
  1164  func setValueType(x syntax.Expr, typ syntax.Type) {
  1165  	tv := syntax.TypeAndValue{Type: typ}
  1166  	tv.SetIsValue()
  1167  	x.SetTypeInfo(tv)
  1168  }
  1169  
  1170  // assertNotExited returns the statement:
  1171  //
  1172  //	if #exitK { runtime.panicrangeexit() }
  1173  //
  1174  // where #exitK is the exit guard for loop.
  1175  func (r *rewriter) assertNotExited(start syntax.Pos, loop *forLoop) syntax.Stmt {
  1176  	callPanicExpr := &syntax.CallExpr{
  1177  		Fun: runtimeSym(r.info, "panicrangeexit"),
  1178  	}
  1179  	setValueType(callPanicExpr, nil) // no result type
  1180  
  1181  	callPanic := &syntax.ExprStmt{X: callPanicExpr}
  1182  
  1183  	nif := &syntax.IfStmt{
  1184  		Cond: r.useVar(loop.exitFlag),
  1185  		Then: &syntax.BlockStmt{
  1186  			List: []syntax.Stmt{callPanic},
  1187  		},
  1188  	}
  1189  	setPos(nif, start)
  1190  	return nif
  1191  }
  1192  
  1193  // next returns a reference to the #next variable.
  1194  func (r *rewriter) next() *syntax.Name {
  1195  	if r.nextVar == nil {
  1196  		r.nextVar = r.declVar("#next", r.int.Type(), nil)
  1197  	}
  1198  	return r.useVar(r.nextVar)
  1199  }
  1200  
  1201  // forRangeFunc checks whether n is a range-over-func.
  1202  // If so, it returns n.(*syntax.ForStmt), true.
  1203  // Otherwise it returns nil, false.
  1204  func forRangeFunc(n syntax.Node) (*syntax.ForStmt, bool) {
  1205  	nfor, ok := n.(*syntax.ForStmt)
  1206  	if !ok {
  1207  		return nil, false
  1208  	}
  1209  	nrange, ok := nfor.Init.(*syntax.RangeClause)
  1210  	if !ok {
  1211  		return nil, false
  1212  	}
  1213  	_, ok = types2.CoreType(nrange.X.GetTypeInfo().Type).(*types2.Signature)
  1214  	if !ok {
  1215  		return nil, false
  1216  	}
  1217  	return nfor, true
  1218  }
  1219  
  1220  // intConst returns syntax for an integer literal with the given value.
  1221  func (r *rewriter) intConst(c int) *syntax.BasicLit {
  1222  	lit := &syntax.BasicLit{
  1223  		Value: fmt.Sprint(c),
  1224  		Kind:  syntax.IntLit,
  1225  	}
  1226  	tv := syntax.TypeAndValue{Type: r.int.Type(), Value: constant.MakeInt64(int64(c))}
  1227  	tv.SetIsValue()
  1228  	lit.SetTypeInfo(tv)
  1229  	return lit
  1230  }
  1231  
  1232  // useVar returns syntax for a reference to decl, which should be its declaration.
  1233  func (r *rewriter) useVar(obj types2.Object) *syntax.Name {
  1234  	n := syntax.NewName(nopos, obj.Name())
  1235  	tv := syntax.TypeAndValue{Type: obj.Type()}
  1236  	tv.SetIsValue()
  1237  	n.SetTypeInfo(tv)
  1238  	r.info.Uses[n] = obj
  1239  	return n
  1240  }
  1241  
  1242  // useList is useVar for a list of decls.
  1243  func (r *rewriter) useList(vars []types2.Object) syntax.Expr {
  1244  	var new []syntax.Expr
  1245  	for _, obj := range vars {
  1246  		new = append(new, r.useVar(obj))
  1247  	}
  1248  	if len(new) == 1 {
  1249  		return new[0]
  1250  	}
  1251  	return &syntax.ListExpr{ElemList: new}
  1252  }
  1253  
  1254  // declVar declares a variable with a given name type and initializer value.
  1255  func (r *rewriter) declVar(name string, typ types2.Type, init syntax.Expr) *types2.Var {
  1256  	if r.declStmt == nil {
  1257  		r.declStmt = &syntax.DeclStmt{}
  1258  	}
  1259  	stmt := r.declStmt
  1260  	obj := types2.NewVar(stmt.Pos(), r.pkg, name, typ)
  1261  	n := syntax.NewName(stmt.Pos(), name)
  1262  	tv := syntax.TypeAndValue{Type: typ}
  1263  	tv.SetIsValue()
  1264  	n.SetTypeInfo(tv)
  1265  	r.info.Defs[n] = obj
  1266  	stmt.DeclList = append(stmt.DeclList, &syntax.VarDecl{
  1267  		NameList: []*syntax.Name{n},
  1268  		// Note: Type is ignored
  1269  		Values: init,
  1270  	})
  1271  	return obj
  1272  }
  1273  
  1274  // declType declares a type with the given name and type.
  1275  // This is more like "type name = typ" than "type name typ".
  1276  func declType(pos syntax.Pos, name string, typ types2.Type) *syntax.Name {
  1277  	n := syntax.NewName(pos, name)
  1278  	n.SetTypeInfo(syntax.TypeAndValue{Type: typ})
  1279  	return n
  1280  }
  1281  
  1282  // runtimePkg is a fake runtime package that contains what we need to refer to in package runtime.
  1283  var runtimePkg = func() *types2.Package {
  1284  	var nopos syntax.Pos
  1285  	pkg := types2.NewPackage("runtime", "runtime")
  1286  	anyType := types2.Universe.Lookup("any").Type()
  1287  
  1288  	// func deferrangefunc() unsafe.Pointer
  1289  	obj := types2.NewFunc(nopos, pkg, "deferrangefunc", types2.NewSignatureType(nil, nil, nil, nil, types2.NewTuple(types2.NewParam(nopos, pkg, "extra", anyType)), false))
  1290  	pkg.Scope().Insert(obj)
  1291  
  1292  	// func panicrangeexit()
  1293  	obj = types2.NewFunc(nopos, pkg, "panicrangeexit", types2.NewSignatureType(nil, nil, nil, nil, nil, false))
  1294  	pkg.Scope().Insert(obj)
  1295  
  1296  	return pkg
  1297  }()
  1298  
  1299  // runtimeSym returns a reference to a symbol in the fake runtime package.
  1300  func runtimeSym(info *types2.Info, name string) *syntax.Name {
  1301  	obj := runtimePkg.Scope().Lookup(name)
  1302  	n := syntax.NewName(nopos, "runtime."+name)
  1303  	tv := syntax.TypeAndValue{Type: obj.Type()}
  1304  	tv.SetIsValue()
  1305  	tv.SetIsRuntimeHelper()
  1306  	n.SetTypeInfo(tv)
  1307  	info.Uses[n] = obj
  1308  	return n
  1309  }
  1310  
  1311  // setPos walks the top structure of x that has no position assigned
  1312  // and assigns it all to have position pos.
  1313  // When setPos encounters a syntax node with a position assigned,
  1314  // setPos does not look inside that node.
  1315  // setPos only needs to handle syntax we create in this package;
  1316  // all other syntax should have positions assigned already.
  1317  func setPos(x syntax.Node, pos syntax.Pos) {
  1318  	if x == nil {
  1319  		return
  1320  	}
  1321  	syntax.Inspect(x, func(n syntax.Node) bool {
  1322  		if n == nil || n.Pos() != nopos {
  1323  			return false
  1324  		}
  1325  		n.SetPos(pos)
  1326  		switch n := n.(type) {
  1327  		case *syntax.BlockStmt:
  1328  			if n.Rbrace == nopos {
  1329  				n.Rbrace = pos
  1330  			}
  1331  		}
  1332  		return true
  1333  	})
  1334  }
  1335  

View as plain text