Black Lives Matter. Support the Equal Justice Initiative.

Source file src/cmd/compile/internal/ssa/poset.go

Documentation: cmd/compile/internal/ssa

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"fmt"
     9  	"os"
    10  )
    11  
// debugPoset enables extra self-checks for debugging: when true, poset
// operations verify the internal integrity of the poset through
// CheckIntegrity (for instance, Ordered defers a call to it).
var debugPoset = false
    14  
    15  const uintSize = 32 << (^uint(0) >> 32 & 1) // 32 or 64
    16  
    17  // bitset is a bit array for dense indexes.
    18  type bitset []uint
    19  
    20  func newBitset(n int) bitset {
    21  	return make(bitset, (n+uintSize-1)/uintSize)
    22  }
    23  
    24  func (bs bitset) Reset() {
    25  	for i := range bs {
    26  		bs[i] = 0
    27  	}
    28  }
    29  
    30  func (bs bitset) Set(idx uint32) {
    31  	bs[idx/uintSize] |= 1 << (idx % uintSize)
    32  }
    33  
    34  func (bs bitset) Clear(idx uint32) {
    35  	bs[idx/uintSize] &^= 1 << (idx % uintSize)
    36  }
    37  
    38  func (bs bitset) Test(idx uint32) bool {
    39  	return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    40  }
    41  
// undoType identifies the kind of operation recorded by a posetUndo entry.
// Its zero value is undoInvalid.
type undoType uint8

const (
	undoInvalid     undoType = iota
	undoCheckpoint           // a checkpoint to group undo passes
	undoSetChl               // change back left child of undo.idx to undo.edge
	undoSetChr               // change back right child of undo.idx to undo.edge
	undoNonEqual             // forget that node index undo.ID is non-equal to node index undo.idx (upushneq stores an index in undo.ID)
	undoNewNode              // remove new node undo.idx created for SSA value undo.ID (upushnew uses ID 0 for dummy nodes)
	undoNewConstant          // remove the constant node undo.idx from the constants map (upushconst stores the previous index in undo.ID)
	undoAliasNode            // unalias SSA value undo.ID so that it points back to node index undo.idx
	undoNewRoot              // remove node undo.idx from root list
	undoChangeRoot           // remove node undo.idx from root list, and put back undo.edge.Target instead
	undoMergeRoot            // remove node undo.idx from root list, and put back its children instead
)
    57  
// posetUndo represents an undo pass to be performed.
// It is a union of fields that can be used to store information,
// and typ is the discriminant that specifies which kind
// of operation must be performed. Not all fields are always used.
type posetUndo struct {
	typ  undoType  // kind of undo pass; selects which of the fields below are meaningful
	idx  uint32    // a node index
	ID   ID        // an SSA value ID; upushneq and upushconst store a node index here instead
	edge posetEdge // an edge (e.g. the previous child edge to put back)
}
    68  
// Bit flags stored in poset.flags.
const (
	// Make poset handle constants as unsigned numbers.
	// Toggled through SetUnsigned.
	posetFlagUnsigned = 1 << iota
)
    73  
    74  // A poset edge. The zero value is the null/empty edge.
    75  // Packs target node index (31 bits) and strict flag (1 bit).
    76  type posetEdge uint32
    77  
    78  func newedge(t uint32, strict bool) posetEdge {
    79  	s := uint32(0)
    80  	if strict {
    81  		s = 1
    82  	}
    83  	return posetEdge(t<<1 | s)
    84  }
    85  func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    86  func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    87  func (e posetEdge) String() string {
    88  	s := fmt.Sprint(e.Target())
    89  	if e.Strict() {
    90  		s += "*"
    91  	}
    92  	return s
    93  }
    94  
// posetNode is a node of a DAG within the poset.
// It holds up to two outgoing edges: l (left child) and r (right child).
// A zero edge means that the corresponding child is absent.
type posetNode struct {
	l, r posetEdge
}
    99  
// poset is a union-find data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (eg: '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// non-equality relations are not transitive, but they can still be useful: for instance
// if we know that A<=B and later we learn that A!=B, we can deduce that A<B.
// NonEqual can be used to check whether it is known that the nodes are different, either
// because SetNonEqual was called before, or because we know that they are strictly ordered.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// poset is implemented as a forest of DAGs; in each DAG, if there is a path (directed)
// from node A to B, it means that A<B (or A<=B). Equality is represented by mapping
// two SSA values to the same DAG node; when a new equality relation is recorded
// between two existing nodes, the nodes are merged, adjusting incoming and outgoing edges.
//
// Constants are specially treated. When a constant is added to the poset, it is
// immediately linked to other constants already present; so for instance if the
// poset knows that x<=3, and then x is tested against 5, 5 is first added and linked
// to 3 (using 3<5), so that the poset knows that x<=3<5; at that point, it is able
// to answer x<5 correctly. This means that all constants are always within the same
// DAG; as an implementation detail, we enforce that the DAG containing the constants
// is always the first in the forest.
//
// poset is designed to be memory efficient and do few allocations during normal usage.
// Most internal data structures are pre-allocated and flat, so for instance adding a
// new relation does not cause any allocation. For performance reasons,
// each node has only up to two outgoing edges (like a binary tree), so intermediate
// "dummy" nodes are required to represent more than two relations. For instance,
// to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
// following DAG:
//
//         A
//        / \
//       I  dummy
//           /  \
//          J    K
//
type poset struct {
	lastidx   uint32            // last generated dense index
	flags     uint8             // internal flags
	values    map[ID]uint32     // map SSA values to dense indexes
	constants map[int64]uint32  // record SSA constants together with their value
	nodes     []posetNode       // nodes (in all DAGs)
	roots     []uint32          // list of root nodes (forest)
	noneq     map[uint32]bitset // non-equal relations
	undo      []posetUndo       // undo chain
}
   159  
   160  func newPoset() *poset {
   161  	return &poset{
   162  		values:    make(map[ID]uint32),
   163  		constants: make(map[int64]uint32, 8),
   164  		nodes:     make([]posetNode, 1, 16),
   165  		roots:     make([]uint32, 0, 4),
   166  		noneq:     make(map[uint32]bitset),
   167  		undo:      make([]posetUndo, 0, 4),
   168  	}
   169  }
   170  
   171  func (po *poset) SetUnsigned(uns bool) {
   172  	if uns {
   173  		po.flags |= posetFlagUnsigned
   174  	} else {
   175  		po.flags &^= posetFlagUnsigned
   176  	}
   177  }
   178  
   179  // Handle children
   180  func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   181  func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   182  func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   183  func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   184  func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   185  	return po.nodes[i].l, po.nodes[i].r
   186  }
   187  
   188  // upush records a new undo step. It can be used for simple
   189  // undo passes that record up to one index and one edge.
   190  func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   191  	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   192  }
   193  
   194  // upushnew pushes an undo pass for a new node
   195  func (po *poset) upushnew(id ID, idx uint32) {
   196  	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   197  }
   198  
   199  // upushneq pushes a new undo pass for a nonequal relation
   200  func (po *poset) upushneq(idx1 uint32, idx2 uint32) {
   201  	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: ID(idx1), idx: idx2})
   202  }
   203  
   204  // upushalias pushes a new undo pass for aliasing two nodes
   205  func (po *poset) upushalias(id ID, i2 uint32) {
   206  	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   207  }
   208  
   209  // upushconst pushes a new undo pass for a new constant
   210  func (po *poset) upushconst(idx uint32, old uint32) {
   211  	po.undo = append(po.undo, posetUndo{typ: undoNewConstant, idx: idx, ID: ID(old)})
   212  }
   213  
   214  // addchild adds i2 as direct child of i1.
   215  func (po *poset) addchild(i1, i2 uint32, strict bool) {
   216  	i1l, i1r := po.children(i1)
   217  	e2 := newedge(i2, strict)
   218  
   219  	if i1l == 0 {
   220  		po.setchl(i1, e2)
   221  		po.upush(undoSetChl, i1, 0)
   222  	} else if i1r == 0 {
   223  		po.setchr(i1, e2)
   224  		po.upush(undoSetChr, i1, 0)
   225  	} else {
   226  		// If n1 already has two children, add an intermediate dummy
   227  		// node to record the relation correctly (without relating
   228  		// n2 to other existing nodes). Use a non-deterministic value
   229  		// to decide whether to append on the left or the right, to avoid
   230  		// creating degenerated chains.
   231  		//
   232  		//      n1
   233  		//     /  \
   234  		//   i1l  dummy
   235  		//        /   \
   236  		//      i1r   n2
   237  		//
   238  		dummy := po.newnode(nil)
   239  		if (i1^i2)&1 != 0 { // non-deterministic
   240  			po.setchl(dummy, i1r)
   241  			po.setchr(dummy, e2)
   242  			po.setchr(i1, newedge(dummy, false))
   243  			po.upush(undoSetChr, i1, i1r)
   244  		} else {
   245  			po.setchl(dummy, i1l)
   246  			po.setchr(dummy, e2)
   247  			po.setchl(i1, newedge(dummy, false))
   248  			po.upush(undoSetChl, i1, i1l)
   249  		}
   250  	}
   251  }
   252  
   253  // newnode allocates a new node bound to SSA value n.
   254  // If n is nil, this is a dummy node (= only used internally).
   255  func (po *poset) newnode(n *Value) uint32 {
   256  	i := po.lastidx + 1
   257  	po.lastidx++
   258  	po.nodes = append(po.nodes, posetNode{})
   259  	if n != nil {
   260  		if po.values[n.ID] != 0 {
   261  			panic("newnode for Value already inserted")
   262  		}
   263  		po.values[n.ID] = i
   264  		po.upushnew(n.ID, i)
   265  	} else {
   266  		po.upushnew(0, i)
   267  	}
   268  	return i
   269  }
   270  
   271  // lookup searches for a SSA value into the forest of DAGS, and return its node.
   272  // Constants are materialized on the fly during lookup.
   273  func (po *poset) lookup(n *Value) (uint32, bool) {
   274  	i, f := po.values[n.ID]
   275  	if !f && n.isGenericIntConst() {
   276  		po.newconst(n)
   277  		i, f = po.values[n.ID]
   278  	}
   279  	return i, f
   280  }
   281  
// newconst creates a node for the constant n. It links it to the other
// constants already in the poset, so that x<=5 is detected true when
// x<=3 is known to be true. If the same constant value is already present
// through a different Value, n is just aliased to the existing node.
// newconst panics if n is not a generic integer constant.
// The bounds are searched by scanning all the entries of po.constants.
// TODO: this is O(N), fix it.
func (po *poset) newconst(n *Value) {
	if !n.isGenericIntConst() {
		panic("newconst on non-constant")
	}

	// If the same constant is already present in the poset through a different
	// Value, just alias to it without allocating a new node.
	// The key in po.constants is always an int64; in unsigned mode it is the
	// unsigned value converted to int64.
	val := n.AuxInt
	if po.flags&posetFlagUnsigned != 0 {
		val = int64(n.AuxUnsigned())
	}
	if c, found := po.constants[val]; found {
		po.values[n.ID] = c
		po.upushalias(n.ID, 0)
		return
	}

	// Create the new node for this constant
	i := po.newnode(n)

	// If this is the first constant, put it as a new root, as
	// we can't record an existing connection so we don't have
	// a specific DAG to add it to. Notice that we want all
	// constants to be in root #0, so make sure the new root
	// goes there: the root is appended and then swapped with
	// whatever is in slot 0.
	if len(po.constants) == 0 {
		idx := len(po.roots)
		po.roots = append(po.roots, i)
		po.roots[0], po.roots[idx] = po.roots[idx], po.roots[0]
		po.upush(undoNewRoot, i, 0)
		po.constants[val] = i
		po.upushconst(i, 0)
		return
	}

	// Find the lower and upper bound among existing constants. That is,
	// find the highest constant that is lower than the one that we're adding,
	// and the lowest constant that is higher.
	// The loop is duplicated to handle signed and unsigned comparison,
	// depending on how the poset was configured.
	// lowerptr/higherptr are the node indexes of the bounds; they stay 0
	// as long as the corresponding bound has not been found.
	var lowerptr, higherptr uint32

	if po.flags&posetFlagUnsigned != 0 {
		var lower, higher uint64
		val1 := n.AuxUnsigned()
		for val2, ptr := range po.constants {
			val2 := uint64(val2)
			if val1 == val2 {
				// An equal constant would have been aliased above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	} else {
		var lower, higher int64
		val1 := n.AuxInt
		for val2, ptr := range po.constants {
			if val1 == val2 {
				// An equal constant would have been aliased above.
				panic("unreachable")
			}
			if val2 < val1 && (lowerptr == 0 || val2 > lower) {
				lower = val2
				lowerptr = ptr
			} else if val2 > val1 && (higherptr == 0 || val2 < higher) {
				higher = val2
				higherptr = ptr
			}
		}
	}

	if lowerptr == 0 && higherptr == 0 {
		// This should not happen, as at least one
		// other constant must exist if we get here.
		panic("no constant found")
	}

	// Connect the new node to the bounds, so that
	// lower < n < higher. We could have found both bounds or only one
	// of them, depending on what other constants are present in the poset.
	// Notice that we always link constants together, so they
	// are always part of the same DAG.
	switch {
	case lowerptr != 0 && higherptr != 0:
		// Both bounds are present, record lower < n < higher.
		po.addchild(lowerptr, i, true)
		po.addchild(i, higherptr, true)

	case lowerptr != 0:
		// Lower bound only, record lower < n.
		po.addchild(lowerptr, i, true)

	case higherptr != 0:
		// Higher bound only. To record n < higher, we need
		// a dummy root:
		//
		//        dummy
		//        /   \
		//      root   \
		//       /      n
		//     ....    /
		//       \    /
		//       higher
		//
		i2 := higherptr
		r2 := po.findroot(i2)
		if r2 != po.roots[0] { // all constants should be in root #0
			panic("constant not in root #0")
		}
		dummy := po.newnode(nil)
		po.changeroot(r2, dummy)
		po.upush(undoChangeRoot, dummy, newedge(r2, false))
		po.addchild(dummy, r2, false)
		po.addchild(dummy, i, false)
		po.addchild(i, i2, true)
	}

	// Register the constant only after it has been linked into the DAG.
	po.constants[val] = i
	po.upushconst(i, 0)
}
   409  
   410  // aliasnewnode records that a single node n2 (not in the poset yet) is an alias
   411  // of the master node n1.
   412  func (po *poset) aliasnewnode(n1, n2 *Value) {
   413  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   414  	if i1 == 0 || i2 != 0 {
   415  		panic("aliasnewnode invalid arguments")
   416  	}
   417  
   418  	po.values[n2.ID] = i1
   419  	po.upushalias(n2.ID, 0)
   420  }
   421  
// aliasnodes records that all the nodes i2s are aliases of a single master node n1.
// aliasnodes takes care of rearranging the DAG, changing references of parent/children
// of nodes in i2s, so that they point to n1 instead.
// It panics if n1 is not in the poset, or if i2s contains the node of n1.
// Complexity is O(n) (with n being the total number of nodes in the poset, not just
// the number of nodes being aliased).
func (po *poset) aliasnodes(n1 *Value, i2s bitset) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}
	if i2s.Test(i1) {
		panic("aliasnode i2s contains n1 node")
	}

	// Go through all the nodes to adjust parent/children of nodes in i2s
	for idx, n := range po.nodes {
		// Do not touch i1 itself, otherwise we can create useless self-loops
		if uint32(idx) == i1 {
			continue
		}
		// l and r are the edges of the node as they were before any change
		// made in this iteration; they are also what the undo passes restore.
		l, r := n.l, n.r

		// Rename all references to i2s into i1
		if i2s.Test(l.Target()) {
			po.setchl(uint32(idx), newedge(i1, l.Strict()))
			po.upush(undoSetChl, uint32(idx), l)
		}
		if i2s.Test(r.Target()) {
			po.setchr(uint32(idx), newedge(i1, r.Strict()))
			po.upush(undoSetChr, uint32(idx), r)
		}

		// Connect all children of i2s to i1 (unless those children
		// are in i2s as well, in which case it would be useless)
		if i2s.Test(uint32(idx)) {
			if l != 0 && !i2s.Test(l.Target()) {
				po.addchild(i1, l.Target(), l.Strict())
			}
			if r != 0 && !i2s.Test(r.Target()) {
				po.addchild(i1, r.Target(), r.Strict())
			}
			// The aliased node itself is left without children.
			po.setchl(uint32(idx), 0)
			po.setchr(uint32(idx), 0)
			po.upush(undoSetChl, uint32(idx), l)
			po.upush(undoSetChr, uint32(idx), r)
		}
	}

	// Reassign all existing IDs that point to a node in i2s to i1.
	for k, v := range po.values {
		if i2s.Test(v) {
			po.values[k] = i1
			po.upushalias(k, v)
		}
	}

	// If one of the aliased nodes is a constant, then make sure
	// po.constants is updated to point to the master node.
	for val, idx := range po.constants {
		if i2s.Test(idx) {
			po.constants[val] = i1
			po.upushconst(i1, idx)
		}
	}
}
   488  
   489  func (po *poset) isroot(r uint32) bool {
   490  	for i := range po.roots {
   491  		if po.roots[i] == r {
   492  			return true
   493  		}
   494  	}
   495  	return false
   496  }
   497  
   498  func (po *poset) changeroot(oldr, newr uint32) {
   499  	for i := range po.roots {
   500  		if po.roots[i] == oldr {
   501  			po.roots[i] = newr
   502  			return
   503  		}
   504  	}
   505  	panic("changeroot on non-root")
   506  }
   507  
   508  func (po *poset) removeroot(r uint32) {
   509  	for i := range po.roots {
   510  		if po.roots[i] == r {
   511  			po.roots = append(po.roots[:i], po.roots[i+1:]...)
   512  			return
   513  		}
   514  	}
   515  	panic("removeroot on non-root")
   516  }
   517  
   518  // dfs performs a depth-first search within the DAG whose root is r.
   519  // f is the visit function called for each node; if it returns true,
   520  // the search is aborted and true is returned. The root node is
   521  // visited too.
   522  // If strict, ignore edges across a path until at least one
   523  // strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
   524  // a strict walk visits D,E,F.
   525  // If the visit ends, false is returned.
   526  func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
   527  	closed := newBitset(int(po.lastidx + 1))
   528  	open := make([]uint32, 1, 64)
   529  	open[0] = r
   530  
   531  	if strict {
   532  		// Do a first DFS; walk all paths and stop when we find a strict
   533  		// edge, building a "next" list of nodes reachable through strict
   534  		// edges. This will be the bootstrap open list for the real DFS.
   535  		next := make([]uint32, 0, 64)
   536  
   537  		for len(open) > 0 {
   538  			i := open[len(open)-1]
   539  			open = open[:len(open)-1]
   540  
   541  			// Don't visit the same node twice. Notice that all nodes
   542  			// across non-strict paths are still visited at least once, so
   543  			// a non-strict path can never obscure a strict path to the
   544  			// same node.
   545  			if !closed.Test(i) {
   546  				closed.Set(i)
   547  
   548  				l, r := po.children(i)
   549  				if l != 0 {
   550  					if l.Strict() {
   551  						next = append(next, l.Target())
   552  					} else {
   553  						open = append(open, l.Target())
   554  					}
   555  				}
   556  				if r != 0 {
   557  					if r.Strict() {
   558  						next = append(next, r.Target())
   559  					} else {
   560  						open = append(open, r.Target())
   561  					}
   562  				}
   563  			}
   564  		}
   565  		open = next
   566  		closed.Reset()
   567  	}
   568  
   569  	for len(open) > 0 {
   570  		i := open[len(open)-1]
   571  		open = open[:len(open)-1]
   572  
   573  		if !closed.Test(i) {
   574  			if f(i) {
   575  				return true
   576  			}
   577  			closed.Set(i)
   578  			l, r := po.children(i)
   579  			if l != 0 {
   580  				open = append(open, l.Target())
   581  			}
   582  			if r != 0 {
   583  				open = append(open, r.Target())
   584  			}
   585  		}
   586  	}
   587  	return false
   588  }
   589  
   590  // Returns true if there is a path from i1 to i2.
   591  // If strict ==  true: if the function returns true, then i1 <  i2.
   592  // If strict == false: if the function returns true, then i1 <= i2.
   593  // If the function returns false, no relation is known.
   594  func (po *poset) reaches(i1, i2 uint32, strict bool) bool {
   595  	return po.dfs(i1, strict, func(n uint32) bool {
   596  		return n == i2
   597  	})
   598  }
   599  
   600  // findroot finds i's root, that is which DAG contains i.
   601  // Returns the root; if i is itself a root, it is returned.
   602  // Panic if i is not in any DAG.
   603  func (po *poset) findroot(i uint32) uint32 {
   604  	// TODO(rasky): if needed, a way to speed up this search is
   605  	// storing a bitset for each root using it as a mini bloom filter
   606  	// of nodes present under that root.
   607  	for _, r := range po.roots {
   608  		if po.reaches(r, i, false) {
   609  			return r
   610  		}
   611  	}
   612  	panic("findroot didn't find any root")
   613  }
   614  
   615  // mergeroot merges two DAGs into one DAG by creating a new dummy root
   616  func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   617  	// Root #0 is special as it contains all constants. Since mergeroot
   618  	// discards r2 as root and keeps r1, make sure that r2 is not root #0,
   619  	// otherwise constants would move to a different root.
   620  	if r2 == po.roots[0] {
   621  		r1, r2 = r2, r1
   622  	}
   623  	r := po.newnode(nil)
   624  	po.setchl(r, newedge(r1, false))
   625  	po.setchr(r, newedge(r2, false))
   626  	po.changeroot(r1, r)
   627  	po.removeroot(r2)
   628  	po.upush(undoMergeRoot, r, 0)
   629  	return r
   630  }
   631  
   632  // collapsepath marks n1 and n2 as equal and collapses as equal all
   633  // nodes across all paths between n1 and n2. If a strict edge is
   634  // found, the function does not modify the DAG and returns false.
   635  // Complexity is O(n).
   636  func (po *poset) collapsepath(n1, n2 *Value) bool {
   637  	i1, i2 := po.values[n1.ID], po.values[n2.ID]
   638  	if po.reaches(i1, i2, true) {
   639  		return false
   640  	}
   641  
   642  	// Find all the paths from i1 to i2
   643  	paths := po.findpaths(i1, i2)
   644  	// Mark all nodes in all the paths as aliases of n1
   645  	// (excluding n1 itself)
   646  	paths.Clear(i1)
   647  	po.aliasnodes(n1, paths)
   648  	return true
   649  }
   650  
   651  // findpaths is a recursive function that calculates all paths from cur to dst
   652  // and return them as a bitset (the index of a node is set in the bitset if
   653  // that node is on at least one path from cur to dst).
   654  // We do a DFS from cur (stopping going deep any time we reach dst, if ever),
   655  // and mark as part of the paths any node that has a children which is already
   656  // part of the path (or is dst itself).
   657  func (po *poset) findpaths(cur, dst uint32) bitset {
   658  	seen := newBitset(int(po.lastidx + 1))
   659  	path := newBitset(int(po.lastidx + 1))
   660  	path.Set(dst)
   661  	po.findpaths1(cur, dst, seen, path)
   662  	return path
   663  }
   664  
   665  func (po *poset) findpaths1(cur, dst uint32, seen bitset, path bitset) {
   666  	if cur == dst {
   667  		return
   668  	}
   669  	seen.Set(cur)
   670  	l, r := po.chl(cur), po.chr(cur)
   671  	if !seen.Test(l) {
   672  		po.findpaths1(l, dst, seen, path)
   673  	}
   674  	if !seen.Test(r) {
   675  		po.findpaths1(r, dst, seen, path)
   676  	}
   677  	if path.Test(l) || path.Test(r) {
   678  		path.Set(cur)
   679  	}
   680  }
   681  
   682  // Check whether it is recorded that i1!=i2
   683  func (po *poset) isnoneq(i1, i2 uint32) bool {
   684  	if i1 == i2 {
   685  		return false
   686  	}
   687  	if i1 < i2 {
   688  		i1, i2 = i2, i1
   689  	}
   690  
   691  	// Check if we recorded a non-equal relation before
   692  	if bs, ok := po.noneq[i1]; ok && bs.Test(i2) {
   693  		return true
   694  	}
   695  	return false
   696  }
   697  
   698  // Record that i1!=i2
   699  func (po *poset) setnoneq(n1, n2 *Value) {
   700  	i1, f1 := po.lookup(n1)
   701  	i2, f2 := po.lookup(n2)
   702  
   703  	// If any of the nodes do not exist in the poset, allocate them. Since
   704  	// we don't know any relation (in the partial order) about them, they must
   705  	// become independent roots.
   706  	if !f1 {
   707  		i1 = po.newnode(n1)
   708  		po.roots = append(po.roots, i1)
   709  		po.upush(undoNewRoot, i1, 0)
   710  	}
   711  	if !f2 {
   712  		i2 = po.newnode(n2)
   713  		po.roots = append(po.roots, i2)
   714  		po.upush(undoNewRoot, i2, 0)
   715  	}
   716  
   717  	if i1 == i2 {
   718  		panic("setnoneq on same node")
   719  	}
   720  	if i1 < i2 {
   721  		i1, i2 = i2, i1
   722  	}
   723  	bs := po.noneq[i1]
   724  	if bs == nil {
   725  		// Given that we record non-equality relations using the
   726  		// higher index as a key, the bitsize will never change size.
   727  		// TODO(rasky): if memory is a problem, consider allocating
   728  		// a small bitset and lazily grow it when higher indices arrive.
   729  		bs = newBitset(int(i1))
   730  		po.noneq[i1] = bs
   731  	} else if bs.Test(i2) {
   732  		// Already recorded
   733  		return
   734  	}
   735  	bs.Set(i2)
   736  	po.upushneq(i1, i2)
   737  }
   738  
   739  // CheckIntegrity verifies internal integrity of a poset. It is intended
   740  // for debugging purposes.
   741  func (po *poset) CheckIntegrity() {
   742  	// Record which index is a constant
   743  	constants := newBitset(int(po.lastidx + 1))
   744  	for _, c := range po.constants {
   745  		constants.Set(c)
   746  	}
   747  
   748  	// Verify that each node appears in a single DAG, and that
   749  	// all constants are within the first DAG
   750  	seen := newBitset(int(po.lastidx + 1))
   751  	for ridx, r := range po.roots {
   752  		if r == 0 {
   753  			panic("empty root")
   754  		}
   755  
   756  		po.dfs(r, false, func(i uint32) bool {
   757  			if seen.Test(i) {
   758  				panic("duplicate node")
   759  			}
   760  			seen.Set(i)
   761  			if constants.Test(i) {
   762  				if ridx != 0 {
   763  					panic("constants not in the first DAG")
   764  				}
   765  			}
   766  			return false
   767  		})
   768  	}
   769  
   770  	// Verify that values contain the minimum set
   771  	for id, idx := range po.values {
   772  		if !seen.Test(idx) {
   773  			panic(fmt.Errorf("spurious value [%d]=%d", id, idx))
   774  		}
   775  	}
   776  
   777  	// Verify that only existing nodes have non-zero children
   778  	for i, n := range po.nodes {
   779  		if n.l|n.r != 0 {
   780  			if !seen.Test(uint32(i)) {
   781  				panic(fmt.Errorf("children of unknown node %d->%v", i, n))
   782  			}
   783  			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   784  				panic(fmt.Errorf("self-loop on node %d", i))
   785  			}
   786  		}
   787  	}
   788  }
   789  
   790  // CheckEmpty checks that a poset is completely empty.
   791  // It can be used for debugging purposes, as a poset is supposed to
   792  // be empty after it's fully rolled back through Undo.
   793  func (po *poset) CheckEmpty() error {
   794  	if len(po.nodes) != 1 {
   795  		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   796  	}
   797  	if len(po.values) != 0 {
   798  		return fmt.Errorf("non-empty value map: %v", po.values)
   799  	}
   800  	if len(po.roots) != 0 {
   801  		return fmt.Errorf("non-empty root list: %v", po.roots)
   802  	}
   803  	if len(po.constants) != 0 {
   804  		return fmt.Errorf("non-empty constants: %v", po.constants)
   805  	}
   806  	if len(po.undo) != 0 {
   807  		return fmt.Errorf("non-empty undo list: %v", po.undo)
   808  	}
   809  	if po.lastidx != 0 {
   810  		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   811  	}
   812  	for _, bs := range po.noneq {
   813  		for _, x := range bs {
   814  			if x != 0 {
   815  				return fmt.Errorf("non-empty noneq map")
   816  			}
   817  		}
   818  	}
   819  	return nil
   820  }
   821  
   822  // DotDump dumps the poset in graphviz format to file fn, with the specified title.
   823  func (po *poset) DotDump(fn string, title string) error {
   824  	f, err := os.Create(fn)
   825  	if err != nil {
   826  		return err
   827  	}
   828  	defer f.Close()
   829  
   830  	// Create reverse index mapping (taking aliases into account)
   831  	names := make(map[uint32]string)
   832  	for id, i := range po.values {
   833  		s := names[i]
   834  		if s == "" {
   835  			s = fmt.Sprintf("v%d", id)
   836  		} else {
   837  			s += fmt.Sprintf(", v%d", id)
   838  		}
   839  		names[i] = s
   840  	}
   841  
   842  	// Create reverse constant mapping
   843  	consts := make(map[uint32]int64)
   844  	for val, idx := range po.constants {
   845  		consts[idx] = val
   846  	}
   847  
   848  	fmt.Fprintf(f, "digraph poset {\n")
   849  	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   850  	for ridx, r := range po.roots {
   851  		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   852  		po.dfs(r, false, func(i uint32) bool {
   853  			if val, ok := consts[i]; ok {
   854  				// Constant
   855  				var vals string
   856  				if po.flags&posetFlagUnsigned != 0 {
   857  					vals = fmt.Sprint(uint64(val))
   858  				} else {
   859  					vals = fmt.Sprint(int64(val))
   860  				}
   861  				fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
   862  					i, vals, names[i], i)
   863  			} else {
   864  				// Normal SSA value
   865  				fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   866  			}
   867  			chl, chr := po.children(i)
   868  			for _, ch := range []posetEdge{chl, chr} {
   869  				if ch != 0 {
   870  					if ch.Strict() {
   871  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   872  					} else {
   873  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   874  					}
   875  				}
   876  			}
   877  			return false
   878  		})
   879  		fmt.Fprintf(f, "\t}\n")
   880  	}
   881  	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   882  	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   883  	fmt.Fprintf(f, "\tlabel=%q\n", title)
   884  	fmt.Fprintf(f, "}\n")
   885  	return nil
   886  }
   887  
   888  // Ordered reports whether n1<n2. It returns false either when it is
   889  // certain that n1<n2 is false, or if there is not enough information
   890  // to tell.
   891  // Complexity is O(n).
   892  func (po *poset) Ordered(n1, n2 *Value) bool {
   893  	if debugPoset {
   894  		defer po.CheckIntegrity()
   895  	}
   896  	if n1.ID == n2.ID {
   897  		panic("should not call Ordered with n1==n2")
   898  	}
   899  
   900  	i1, f1 := po.lookup(n1)
   901  	i2, f2 := po.lookup(n2)
   902  	if !f1 || !f2 {
   903  		return false
   904  	}
   905  
   906  	return i1 != i2 && po.reaches(i1, i2, true)
   907  }
   908  
   909  // Ordered reports whether n1<=n2. It returns false either when it is
   910  // certain that n1<=n2 is false, or if there is not enough information
   911  // to tell.
   912  // Complexity is O(n).
   913  func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   914  	if debugPoset {
   915  		defer po.CheckIntegrity()
   916  	}
   917  	if n1.ID == n2.ID {
   918  		panic("should not call Ordered with n1==n2")
   919  	}
   920  
   921  	i1, f1 := po.lookup(n1)
   922  	i2, f2 := po.lookup(n2)
   923  	if !f1 || !f2 {
   924  		return false
   925  	}
   926  
   927  	return i1 == i2 || po.reaches(i1, i2, false)
   928  }
   929  
   930  // Equal reports whether n1==n2. It returns false either when it is
   931  // certain that n1==n2 is false, or if there is not enough information
   932  // to tell.
   933  // Complexity is O(1).
   934  func (po *poset) Equal(n1, n2 *Value) bool {
   935  	if debugPoset {
   936  		defer po.CheckIntegrity()
   937  	}
   938  	if n1.ID == n2.ID {
   939  		panic("should not call Equal with n1==n2")
   940  	}
   941  
   942  	i1, f1 := po.lookup(n1)
   943  	i2, f2 := po.lookup(n2)
   944  	return f1 && f2 && i1 == i2
   945  }
   946  
   947  // NonEqual reports whether n1!=n2. It returns false either when it is
   948  // certain that n1!=n2 is false, or if there is not enough information
   949  // to tell.
   950  // Complexity is O(n) (because it internally calls Ordered to see if we
   951  // can infer n1!=n2 from n1<n2 or n2<n1).
   952  func (po *poset) NonEqual(n1, n2 *Value) bool {
   953  	if debugPoset {
   954  		defer po.CheckIntegrity()
   955  	}
   956  	if n1.ID == n2.ID {
   957  		panic("should not call NonEqual with n1==n2")
   958  	}
   959  
   960  	// If we never saw the nodes before, we don't
   961  	// have a recorded non-equality.
   962  	i1, f1 := po.lookup(n1)
   963  	i2, f2 := po.lookup(n2)
   964  	if !f1 || !f2 {
   965  		return false
   966  	}
   967  
   968  	// Check if we recored inequality
   969  	if po.isnoneq(i1, i2) {
   970  		return true
   971  	}
   972  
   973  	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   974  	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   975  		return true
   976  	}
   977  
   978  	return false
   979  }
   980  
// setOrder records that n1<n2 or n1<=n2 (depending on strict). Returns false
// if this is a contradiction.
// Implements SetOrder() and SetOrderOrEqual().
//
// The work is split in four cases depending on which of n1 and n2 are
// already found by lookup. Every root change made here is paired with an
// upush call.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation: n1 becomes a new root
		// and n2 its child.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// dummy root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			// n1 takes over the root slot of n2, and n2 becomes its child.
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  dummy
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		// The dummy node is created with a nil value and linked to both r
		// and i1 through non-strict edges; only the i1->i2 edge carries
		// the requested strictness.
		dummy := po.newnode(nil)
		po.changeroot(r, dummy)
		po.upush(undoChangeRoot, dummy, newedge(r, false))
		po.addchild(dummy, r, false)
		po.addchild(dummy, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// If we are trying to record n1<=n2 but we learned that n1!=n2,
		// record n1<n2, as it provides more information.
		if !strict && po.isnoneq(i1, i2) {
			strict = true
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow reaches n2
		if po.reaches(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2: a strict order is requested but
			// no strict path exists yet.
			if strict && !po.reaches(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3 or #4: nothing to do
			return true
		}

		// Check if n2 somehow reaches n1
		if po.reaches(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	// Every path that falls out of the switch recorded the relation
	// successfully.
	return true
}
  1109  
  1110  // SetOrder records that n1<n2. Returns false if this is a contradiction
  1111  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
  1112  func (po *poset) SetOrder(n1, n2 *Value) bool {
  1113  	if debugPoset {
  1114  		defer po.CheckIntegrity()
  1115  	}
  1116  	if n1.ID == n2.ID {
  1117  		panic("should not call SetOrder with n1==n2")
  1118  	}
  1119  	return po.setOrder(n1, n2, true)
  1120  }
  1121  
  1122  // SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
  1123  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
  1124  func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
  1125  	if debugPoset {
  1126  		defer po.CheckIntegrity()
  1127  	}
  1128  	if n1.ID == n2.ID {
  1129  		panic("should not call SetOrder with n1==n2")
  1130  	}
  1131  	return po.setOrder(n1, n2, false)
  1132  }
  1133  
  1134  // SetEqual records that n1==n2. Returns false if this is a contradiction
  1135  // (that is, if it is already recorded that n1<n2 or n2<n1).
  1136  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
  1137  func (po *poset) SetEqual(n1, n2 *Value) bool {
  1138  	if debugPoset {
  1139  		defer po.CheckIntegrity()
  1140  	}
  1141  	if n1.ID == n2.ID {
  1142  		panic("should not call Add with n1==n2")
  1143  	}
  1144  
  1145  	i1, f1 := po.lookup(n1)
  1146  	i2, f2 := po.lookup(n2)
  1147  
  1148  	switch {
  1149  	case !f1 && !f2:
  1150  		i1 = po.newnode(n1)
  1151  		po.roots = append(po.roots, i1)
  1152  		po.upush(undoNewRoot, i1, 0)
  1153  		po.aliasnewnode(n1, n2)
  1154  	case f1 && !f2:
  1155  		po.aliasnewnode(n1, n2)
  1156  	case !f1 && f2:
  1157  		po.aliasnewnode(n2, n1)
  1158  	case f1 && f2:
  1159  		if i1 == i2 {
  1160  			// Already aliased, ignore
  1161  			return true
  1162  		}
  1163  
  1164  		// If we recorded that n1!=n2, this is a contradiction.
  1165  		if po.isnoneq(i1, i2) {
  1166  			return false
  1167  		}
  1168  
  1169  		// If we already knew that n1<=n2, we can collapse the path to
  1170  		// record n1==n2 (and viceversa).
  1171  		if po.reaches(i1, i2, false) {
  1172  			return po.collapsepath(n1, n2)
  1173  		}
  1174  		if po.reaches(i2, i1, false) {
  1175  			return po.collapsepath(n2, n1)
  1176  		}
  1177  
  1178  		r1 := po.findroot(i1)
  1179  		r2 := po.findroot(i2)
  1180  		if r1 != r2 {
  1181  			// Merge the two DAGs so we can record relations between the nodes
  1182  			po.mergeroot(r1, r2)
  1183  		}
  1184  
  1185  		// Set n2 as alias of n1. This will also update all the references
  1186  		// to n2 to become references to n1
  1187  		i2s := newBitset(int(po.lastidx) + 1)
  1188  		i2s.Set(i2)
  1189  		po.aliasnodes(n1, i2s)
  1190  	}
  1191  	return true
  1192  }
  1193  
  1194  // SetNonEqual records that n1!=n2. Returns false if this is a contradiction
  1195  // (that is, if it is already recorded that n1==n2).
  1196  // Complexity is O(n).
  1197  func (po *poset) SetNonEqual(n1, n2 *Value) bool {
  1198  	if debugPoset {
  1199  		defer po.CheckIntegrity()
  1200  	}
  1201  	if n1.ID == n2.ID {
  1202  		panic("should not call SetNonEqual with n1==n2")
  1203  	}
  1204  
  1205  	// Check whether the nodes are already in the poset
  1206  	i1, f1 := po.lookup(n1)
  1207  	i2, f2 := po.lookup(n2)
  1208  
  1209  	// If either node wasn't present, we just record the new relation
  1210  	// and exit.
  1211  	if !f1 || !f2 {
  1212  		po.setnoneq(n1, n2)
  1213  		return true
  1214  	}
  1215  
  1216  	// See if we already know this, in which case there's nothing to do.
  1217  	if po.isnoneq(i1, i2) {
  1218  		return true
  1219  	}
  1220  
  1221  	// Check if we're contradicting an existing equality relation
  1222  	if po.Equal(n1, n2) {
  1223  		return false
  1224  	}
  1225  
  1226  	// Record non-equality
  1227  	po.setnoneq(n1, n2)
  1228  
  1229  	// If we know that i1<=i2 but not i1<i2, learn that as we
  1230  	// now know that they are not equal. Do the same for i2<=i1.
  1231  	// Do this check only if both nodes were already in the DAG,
  1232  	// otherwise there cannot be an existing relation.
  1233  	if po.reaches(i1, i2, false) && !po.reaches(i1, i2, true) {
  1234  		po.addchild(i1, i2, true)
  1235  	}
  1236  	if po.reaches(i2, i1, false) && !po.reaches(i2, i1, true) {
  1237  		po.addchild(i2, i1, true)
  1238  	}
  1239  
  1240  	return true
  1241  }
  1242  
  1243  // Checkpoint saves the current state of the DAG so that it's possible
  1244  // to later undo this state.
  1245  // Complexity is O(1).
  1246  func (po *poset) Checkpoint() {
  1247  	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1248  }
  1249  
// Undo restores the state of the poset to the previous checkpoint.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
//
// Undo pops passes from the end of the undo stack, most recent first,
// reverting each one, and returns as soon as it pops a checkpoint.
// It panics if the undo stack is empty on entry, if a pass is found to be
// inconsistent with the current state, or if a pass has an unknown type.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}
	if debugPoset {
		defer po.CheckIntegrity()
	}

	for len(po.undo) > 0 {
		// Pop the most recent pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached the checkpoint: everything after it has been reverted.
			return

		case undoSetChl:
			// Restore the previous left child edge of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Restore the previous right child edge of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Forget the recorded non-equality: clear bit pass.idx in the
			// bitset stored under pass.ID.
			po.noneq[uint32(pass.ID)].Clear(pass.idx)

		case undoNewNode:
			// Nodes must be removed in reverse creation order, so the node
			// being removed has to be the last one.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// A zero ID means that no entry in po.values has to be removed.
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			// Clear both child edges, then drop the node from the list.
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

		case undoNewConstant:
			// Find the constant value that currently maps to pass.idx.
			// FIXME: remove this O(n) loop
			var val int64
			var i uint32
			for val, i = range po.constants {
				if i == pass.idx {
					break
				}
			}
			// If the loop ended without a match, i holds the last index
			// visited (or zero for an empty map).
			if i != pass.idx {
				panic("constant not found in undo pass")
			}
			if pass.ID == 0 {
				// No previous index was saved: remove the constant altogether.
				delete(po.constants, val)
			} else {
				// Restore previous index as constant node
				// (also restoring the invariant on correct bounds)
				oldidx := uint32(pass.ID)
				po.constants[val] = oldidx
			}

		case undoAliasNode:
			// pass.idx is the node index that ID pointed to before being
			// aliased; zero means that there was none.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				// The value must currently point somewhere else, otherwise
				// there is nothing to revert.
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// A root can only be removed once it has no children left.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Root i must have no children left; the target of the saved
			// edge takes its place back in the root list.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Split root i back into two roots: its left child replaces it
			// in the root list and its right child is appended as a root.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}

	// Reached only when the undo stack ran out without meeting a checkpoint;
	// in debug mode, verify that CheckEmpty reports no error.
	if debugPoset && po.CheckEmpty() != nil {
		panic("poset not empty at the end of undo")
	}
}
  1360  

View as plain text