Source file src/cmd/compile/internal/ssa/poset.go

Documentation: cmd/compile/internal/ssa

     1  // Copyright 2018 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package ssa
     6  
     7  import (
     8  	"errors"
     9  	"fmt"
    10  	"os"
    11  )
    12  
    13  const uintSize = 32 << (^uint(0) >> 32 & 1) // 32 or 64
    14  
    15  // bitset is a bit array for dense indexes.
    16  type bitset []uint
    17  
    18  func newBitset(n int) bitset {
    19  	return make(bitset, (n+uintSize-1)/uintSize)
    20  }
    21  
    22  func (bs bitset) Reset() {
    23  	for i := range bs {
    24  		bs[i] = 0
    25  	}
    26  }
    27  
    28  func (bs bitset) Set(idx uint32) {
    29  	bs[idx/uintSize] |= 1 << (idx % uintSize)
    30  }
    31  
    32  func (bs bitset) Clear(idx uint32) {
    33  	bs[idx/uintSize] &^= 1 << (idx % uintSize)
    34  }
    35  
    36  func (bs bitset) Test(idx uint32) bool {
    37  	return bs[idx/uintSize]&(1<<(idx%uintSize)) != 0
    38  }
    39  
// undoType is the kind of an undo pass, stored in posetUndo.typ.
// The values are assigned by iota, so the order of the constants matters.
type undoType uint8

const (
	// undoInvalid is the zero value; presumably it marks an uninitialized
	// entry (NOTE(review): confirm no code pushes it on purpose).
	undoInvalid    undoType = iota
	undoCheckpoint          // a checkpoint to group undo passes
	undoSetChl              // change back left child of undo.idx to undo.edge
	undoSetChr              // change back right child of undo.idx to undo.edge
	undoNonEqual            // forget that SSA value undo.ID is non-equal to undo.idx (another ID)
	undoNewNode             // remove new node created for SSA value undo.ID (ID is 0 for a dummy node)
	undoAliasNode           // unalias SSA value undo.ID so that it points back to node index undo.idx (0 if it had no node)
	undoNewRoot             // remove node undo.idx from root list
	undoChangeRoot          // remove node undo.idx from root list, and put back undo.edge.Target instead
	undoMergeRoot           // remove node undo.idx from root list, and put back its children instead
)

// posetUndo represents an undo pass to be performed.
// It is a union of fields that can be used to store information,
// and typ is the discriminant, that specifies which kind
// of operation must be performed. Not all fields are always used.
type posetUndo struct {
	typ  undoType  // kind of pass
	idx  uint32    // node index (for undoNonEqual: the lower of the two SSA IDs)
	ID   ID        // SSA value ID (for undoNonEqual: the higher of the two)
	edge posetEdge // edge to restore, for passes that need one
}

const (
	// Make poset handle constants as unsigned numbers.
	posetFlagUnsigned = 1 << iota
)
    70  
    71  // A poset edge. The zero value is the null/empty edge.
    72  // Packs target node index (31 bits) and strict flag (1 bit).
    73  type posetEdge uint32
    74  
    75  func newedge(t uint32, strict bool) posetEdge {
    76  	s := uint32(0)
    77  	if strict {
    78  		s = 1
    79  	}
    80  	return posetEdge(t<<1 | s)
    81  }
    82  func (e posetEdge) Target() uint32 { return uint32(e) >> 1 }
    83  func (e posetEdge) Strict() bool   { return uint32(e)&1 != 0 }
    84  func (e posetEdge) String() string {
    85  	s := fmt.Sprint(e.Target())
    86  	if e.Strict() {
    87  		s += "*"
    88  	}
    89  	return s
    90  }
    91  
    92  // posetNode is a node of a DAG within the poset.
    93  type posetNode struct {
    94  	l, r posetEdge
    95  }
    96  
// poset is a union-find data structure that can represent a partially ordered set
// of SSA values. Given a binary relation that creates a partial order (e.g. '<'),
// clients can record relations between SSA values using SetOrder, and later
// check relations (in the transitive closure) with Ordered. For instance,
// if SetOrder is called to record that A<B and B<C, Ordered will later confirm
// that A<C.
//
// It is possible to record equality relations between SSA values with SetEqual and check
// equality with Equal. Equality propagates into the transitive closure for the partial
// order so that if we know that A<B<C and later learn that A==D, Ordered will return
// true for D<C.
//
// poset will refuse to record new relations that contradict existing relations:
// for instance if A<B<C, calling SetOrder for C<A will fail returning false; also
// calling SetEqual for C==A will fail.
//
// It is also possible to record inequality relations between nodes with SetNonEqual;
// given that non-equality is not transitive, the only effect is that a later call
// to SetEqual for the same values will fail. NonEqual checks whether it is known that
// the nodes are different, either because SetNonEqual was called before, or because
// we know that they are strictly ordered.
//
// It is implemented as a forest of DAGs; in each DAG, if node A dominates B,
// it means that A<=B, or A<B when at least one edge on the path from A to B
// is strict. Equality is represented by mapping two SSA values to the same
// DAG node; when a new equality relation is recorded between two existing nodes,
// the nodes are merged, adjusting incoming and outgoing edges.
//
// Constants are specially treated. When a constant is added to the poset, it is
// immediately linked to other constants already present; so for instance if the
// poset knows that x<=3, and then x is tested against 5, 5 is first added and linked
// to 3 (using 3<5), so that the poset knows that x<=3<5; at that point, it is able
// to answer x<5 correctly.
//
// poset is designed to be memory efficient and do few allocations during normal usage.
// Most internal data structures are pre-allocated and flat, so for instance adding a
// new relation does not cause any allocation. For performance reasons,
// each node has only up to two outgoing edges (like a binary tree), so intermediate
// "dummy" nodes are required to represent more than two relations. For instance,
// to record that A<I, A<J, A<K (with no known relation between I,J,K), we create the
// following DAG:
//
//         A
//        / \
//       I  dummy
//           /  \
//          J    K
//
type poset struct {
	lastidx   uint32        // last generated dense index (index 0 is the null node)
	flags     uint8         // internal flags (posetFlag* bits)
	values    map[ID]uint32 // map SSA values to dense indexes (aliases share an index)
	constants []*Value      // record SSA constants together with their value
	nodes     []posetNode   // nodes (in all DAGs), indexed by dense index
	roots     []uint32      // list of root nodes (forest)
	noneq     map[ID]bitset // non-equal relations, keyed by the higher of the two IDs
	undo      []posetUndo   // undo chain
}
   154  
   155  func newPoset() *poset {
   156  	return &poset{
   157  		values:    make(map[ID]uint32),
   158  		constants: make([]*Value, 0, 8),
   159  		nodes:     make([]posetNode, 1, 16),
   160  		roots:     make([]uint32, 0, 4),
   161  		noneq:     make(map[ID]bitset),
   162  		undo:      make([]posetUndo, 0, 4),
   163  	}
   164  }
   165  
   166  func (po *poset) SetUnsigned(uns bool) {
   167  	if uns {
   168  		po.flags |= posetFlagUnsigned
   169  	} else {
   170  		po.flags &^= posetFlagUnsigned
   171  	}
   172  }
   173  
   174  // Handle children
   175  func (po *poset) setchl(i uint32, l posetEdge) { po.nodes[i].l = l }
   176  func (po *poset) setchr(i uint32, r posetEdge) { po.nodes[i].r = r }
   177  func (po *poset) chl(i uint32) uint32          { return po.nodes[i].l.Target() }
   178  func (po *poset) chr(i uint32) uint32          { return po.nodes[i].r.Target() }
   179  func (po *poset) children(i uint32) (posetEdge, posetEdge) {
   180  	return po.nodes[i].l, po.nodes[i].r
   181  }
   182  
   183  // upush records a new undo step. It can be used for simple
   184  // undo passes that record up to one index and one edge.
   185  func (po *poset) upush(typ undoType, p uint32, e posetEdge) {
   186  	po.undo = append(po.undo, posetUndo{typ: typ, idx: p, edge: e})
   187  }
   188  
   189  // upushnew pushes an undo pass for a new node
   190  func (po *poset) upushnew(id ID, idx uint32) {
   191  	po.undo = append(po.undo, posetUndo{typ: undoNewNode, ID: id, idx: idx})
   192  }
   193  
   194  // upushneq pushes a new undo pass for a nonequal relation
   195  func (po *poset) upushneq(id1 ID, id2 ID) {
   196  	po.undo = append(po.undo, posetUndo{typ: undoNonEqual, ID: id1, idx: uint32(id2)})
   197  }
   198  
   199  // upushalias pushes a new undo pass for aliasing two nodes
   200  func (po *poset) upushalias(id ID, i2 uint32) {
   201  	po.undo = append(po.undo, posetUndo{typ: undoAliasNode, ID: id, idx: i2})
   202  }
   203  
   204  // addchild adds i2 as direct child of i1.
   205  func (po *poset) addchild(i1, i2 uint32, strict bool) {
   206  	i1l, i1r := po.children(i1)
   207  	e2 := newedge(i2, strict)
   208  
   209  	if i1l == 0 {
   210  		po.setchl(i1, e2)
   211  		po.upush(undoSetChl, i1, 0)
   212  	} else if i1r == 0 {
   213  		po.setchr(i1, e2)
   214  		po.upush(undoSetChr, i1, 0)
   215  	} else {
   216  		// If n1 already has two children, add an intermediate dummy
   217  		// node to record the relation correctly (without relating
   218  		// n2 to other existing nodes). Use a non-deterministic value
   219  		// to decide whether to append on the left or the right, to avoid
   220  		// creating degenerated chains.
   221  		//
   222  		//      n1
   223  		//     /  \
   224  		//   i1l  dummy
   225  		//        /   \
   226  		//      i1r   n2
   227  		//
   228  		dummy := po.newnode(nil)
   229  		if (i1^i2)&1 != 0 { // non-deterministic
   230  			po.setchl(dummy, i1r)
   231  			po.setchr(dummy, e2)
   232  			po.setchr(i1, newedge(dummy, false))
   233  			po.upush(undoSetChr, i1, i1r)
   234  		} else {
   235  			po.setchl(dummy, i1l)
   236  			po.setchr(dummy, e2)
   237  			po.setchl(i1, newedge(dummy, false))
   238  			po.upush(undoSetChl, i1, i1l)
   239  		}
   240  	}
   241  }
   242  
   243  // newnode allocates a new node bound to SSA value n.
   244  // If n is nil, this is a dummy node (= only used internally).
   245  func (po *poset) newnode(n *Value) uint32 {
   246  	i := po.lastidx + 1
   247  	po.lastidx++
   248  	po.nodes = append(po.nodes, posetNode{})
   249  	if n != nil {
   250  		if po.values[n.ID] != 0 {
   251  			panic("newnode for Value already inserted")
   252  		}
   253  		po.values[n.ID] = i
   254  		po.upushnew(n.ID, i)
   255  	} else {
   256  		po.upushnew(0, i)
   257  	}
   258  	return i
   259  }
   260  
   261  // lookup searches for a SSA value into the forest of DAGS, and return its node.
   262  // Constants are materialized on the fly during lookup.
   263  func (po *poset) lookup(n *Value) (uint32, bool) {
   264  	i, f := po.values[n.ID]
   265  	if !f && n.isGenericIntConst() {
   266  		po.newconst(n)
   267  		i, f = po.values[n.ID]
   268  	}
   269  	return i, f
   270  }
   271  
   272  // newconst creates a node for a constant. It links it to other constants, so
   273  // that n<=5 is detected true when n<=3 is known to be true.
   274  // TODO: this is O(N), fix it.
   275  func (po *poset) newconst(n *Value) {
   276  	if !n.isGenericIntConst() {
   277  		panic("newconst on non-constant")
   278  	}
   279  
   280  	// If this is the first constant, put it into a new root, as
   281  	// we can't record an existing connection so we don't have
   282  	// a specific DAG to add it to.
   283  	if len(po.constants) == 0 {
   284  		i := po.newnode(n)
   285  		po.roots = append(po.roots, i)
   286  		po.upush(undoNewRoot, i, 0)
   287  		po.constants = append(po.constants, n)
   288  		return
   289  	}
   290  
   291  	// Find the lower and upper bound among existing constants. That is,
   292  	// find the higher constant that is lower than the one that we're adding,
   293  	// and the lower constant that is higher.
   294  	// The loop is duplicated to handle signed and unsigned comparison,
   295  	// depending on how the poset was configured.
   296  	var lowerptr, higherptr *Value
   297  
   298  	if po.flags&posetFlagUnsigned != 0 {
   299  		var lower, higher uint64
   300  		val1 := n.AuxUnsigned()
   301  		for _, ptr := range po.constants {
   302  			val2 := ptr.AuxUnsigned()
   303  			if val1 == val2 {
   304  				po.aliasnode(ptr, n)
   305  				return
   306  			}
   307  			if val2 < val1 && (lowerptr == nil || val2 > lower) {
   308  				lower = val2
   309  				lowerptr = ptr
   310  			} else if val2 > val1 && (higherptr == nil || val2 < higher) {
   311  				higher = val2
   312  				higherptr = ptr
   313  			}
   314  		}
   315  	} else {
   316  		var lower, higher int64
   317  		val1 := n.AuxInt
   318  		for _, ptr := range po.constants {
   319  			val2 := ptr.AuxInt
   320  			if val1 == val2 {
   321  				po.aliasnode(ptr, n)
   322  				return
   323  			}
   324  			if val2 < val1 && (lowerptr == nil || val2 > lower) {
   325  				lower = val2
   326  				lowerptr = ptr
   327  			} else if val2 > val1 && (higherptr == nil || val2 < higher) {
   328  				higher = val2
   329  				higherptr = ptr
   330  			}
   331  		}
   332  	}
   333  
   334  	if lowerptr == nil && higherptr == nil {
   335  		// This should not happen, as at least one
   336  		// other constant must exist if we get here.
   337  		panic("no constant found")
   338  	}
   339  
   340  	// Create the new node and connect it to the bounds, so that
   341  	// lower < n < higher. We could have found both bounds or only one
   342  	// of them, depending on what other constants are present in the poset.
   343  	// Notice that we always link constants together, so they
   344  	// are always part of the same DAG.
   345  	i := po.newnode(n)
   346  	switch {
   347  	case lowerptr != nil && higherptr != nil:
   348  		// Both bounds are present, record lower < n < higher.
   349  		po.addchild(po.values[lowerptr.ID], i, true)
   350  		po.addchild(i, po.values[higherptr.ID], true)
   351  
   352  	case lowerptr != nil:
   353  		// Lower bound only, record lower < n.
   354  		po.addchild(po.values[lowerptr.ID], i, true)
   355  
   356  	case higherptr != nil:
   357  		// Higher bound only. To record n < higher, we need
   358  		// a dummy root:
   359  		//
   360  		//        dummy
   361  		//        /   \
   362  		//      root   \
   363  		//       /      n
   364  		//     ....    /
   365  		//       \    /
   366  		//       higher
   367  		//
   368  		i2 := po.values[higherptr.ID]
   369  		r2 := po.findroot(i2)
   370  		dummy := po.newnode(nil)
   371  		po.changeroot(r2, dummy)
   372  		po.upush(undoChangeRoot, dummy, newedge(r2, false))
   373  		po.addchild(dummy, r2, false)
   374  		po.addchild(dummy, i, false)
   375  		po.addchild(i, i2, true)
   376  	}
   377  
   378  	po.constants = append(po.constants, n)
   379  }
   380  
// aliasnode records that n2 is an alias of n1: from now on both SSA values
// map to n1's node i1. n1 must already have a node, otherwise aliasnode panics.
//
// If n2 already has its own node i2, every edge pointing to i2 from a node
// other than i1 is redirected to i1, and every SSA value mapped to i2
// (including n2) is remapped to i1. The edges of i1 itself and the children
// of i2 are left untouched.
// NOTE(review): if i1 has no edge to i2, nothing points to i2 afterwards and
// i2's children become unreachable; callers appear to rely on i2 being a
// child of i1 (as collapsepath checks) — confirm.
//
// Every change is pushed on the undo chain.
func (po *poset) aliasnode(n1, n2 *Value) {
	i1 := po.values[n1.ID]
	if i1 == 0 {
		panic("aliasnode for non-existing node")
	}

	i2 := po.values[n2.ID]
	if i2 != 0 {
		// Rename all references to i2 into i1
		// (do not touch i1 itself, otherwise we can create useless self-loops)
		for idx, n := range po.nodes {
			if uint32(idx) != i1 {
				l, r := n.l, n.r
				if l.Target() == i2 {
					po.setchl(uint32(idx), newedge(i1, l.Strict()))
					po.upush(undoSetChl, uint32(idx), l)
				}
				if r.Target() == i2 {
					po.setchr(uint32(idx), newedge(i1, r.Strict()))
					po.upush(undoSetChr, uint32(idx), r)
				}
			}
		}

		// Reassign all existing IDs that point to i2 to i1.
		// This includes n2.ID.
		for k, v := range po.values {
			if v == i2 {
				po.values[k] = i1
				po.upushalias(k, i2)
			}
		}
	} else {
		// n2.ID wasn't seen before, so record it as alias to i1.
		// The undo pass stores 0 as the previous index, meaning "no node".
		po.values[n2.ID] = i1
		po.upushalias(n2.ID, 0)
	}
}
   420  
   421  func (po *poset) isroot(r uint32) bool {
   422  	for i := range po.roots {
   423  		if po.roots[i] == r {
   424  			return true
   425  		}
   426  	}
   427  	return false
   428  }
   429  
   430  func (po *poset) changeroot(oldr, newr uint32) {
   431  	for i := range po.roots {
   432  		if po.roots[i] == oldr {
   433  			po.roots[i] = newr
   434  			return
   435  		}
   436  	}
   437  	panic("changeroot on non-root")
   438  }
   439  
   440  func (po *poset) removeroot(r uint32) {
   441  	for i := range po.roots {
   442  		if po.roots[i] == r {
   443  			po.roots = append(po.roots[:i], po.roots[i+1:]...)
   444  			return
   445  		}
   446  	}
   447  	panic("removeroot on non-root")
   448  }
   449  
   450  // dfs performs a depth-first search within the DAG whose root is r.
   451  // f is the visit function called for each node; if it returns true,
   452  // the search is aborted and true is returned. The root node is
   453  // visited too.
   454  // If strict, ignore edges across a path until at least one
   455  // strict edge is found. For instance, for a chain A<=B<=C<D<=E<F,
   456  // a strict walk visits D,E,F.
   457  // If the visit ends, false is returned.
   458  func (po *poset) dfs(r uint32, strict bool, f func(i uint32) bool) bool {
   459  	closed := newBitset(int(po.lastidx + 1))
   460  	open := make([]uint32, 1, 64)
   461  	open[0] = r
   462  
   463  	if strict {
   464  		// Do a first DFS; walk all paths and stop when we find a strict
   465  		// edge, building a "next" list of nodes reachable through strict
   466  		// edges. This will be the bootstrap open list for the real DFS.
   467  		next := make([]uint32, 0, 64)
   468  
   469  		for len(open) > 0 {
   470  			i := open[len(open)-1]
   471  			open = open[:len(open)-1]
   472  
   473  			// Don't visit the same node twice. Notice that all nodes
   474  			// across non-strict paths are still visited at least once, so
   475  			// a non-strict path can never obscure a strict path to the
   476  			// same node.
   477  			if !closed.Test(i) {
   478  				closed.Set(i)
   479  
   480  				l, r := po.children(i)
   481  				if l != 0 {
   482  					if l.Strict() {
   483  						next = append(next, l.Target())
   484  					} else {
   485  						open = append(open, l.Target())
   486  					}
   487  				}
   488  				if r != 0 {
   489  					if r.Strict() {
   490  						next = append(next, r.Target())
   491  					} else {
   492  						open = append(open, r.Target())
   493  					}
   494  				}
   495  			}
   496  		}
   497  		open = next
   498  		closed.Reset()
   499  	}
   500  
   501  	for len(open) > 0 {
   502  		i := open[len(open)-1]
   503  		open = open[:len(open)-1]
   504  
   505  		if !closed.Test(i) {
   506  			if f(i) {
   507  				return true
   508  			}
   509  			closed.Set(i)
   510  			l, r := po.children(i)
   511  			if l != 0 {
   512  				open = append(open, l.Target())
   513  			}
   514  			if r != 0 {
   515  				open = append(open, r.Target())
   516  			}
   517  		}
   518  	}
   519  	return false
   520  }
   521  
   522  // Returns true if i1 dominates i2.
   523  // If strict ==  true: if the function returns true, then i1 <  i2.
   524  // If strict == false: if the function returns true, then i1 <= i2.
   525  // If the function returns false, no relation is known.
   526  func (po *poset) dominates(i1, i2 uint32, strict bool) bool {
   527  	return po.dfs(i1, strict, func(n uint32) bool {
   528  		return n == i2
   529  	})
   530  }
   531  
   532  // findroot finds i's root, that is which DAG contains i.
   533  // Returns the root; if i is itself a root, it is returned.
   534  // Panic if i is not in any DAG.
   535  func (po *poset) findroot(i uint32) uint32 {
   536  	// TODO(rasky): if needed, a way to speed up this search is
   537  	// storing a bitset for each root using it as a mini bloom filter
   538  	// of nodes present under that root.
   539  	for _, r := range po.roots {
   540  		if po.dominates(r, i, false) {
   541  			return r
   542  		}
   543  	}
   544  	panic("findroot didn't find any root")
   545  }
   546  
   547  // mergeroot merges two DAGs into one DAG by creating a new dummy root
   548  func (po *poset) mergeroot(r1, r2 uint32) uint32 {
   549  	r := po.newnode(nil)
   550  	po.setchl(r, newedge(r1, false))
   551  	po.setchr(r, newedge(r2, false))
   552  	po.changeroot(r1, r)
   553  	po.removeroot(r2)
   554  	po.upush(undoMergeRoot, r, 0)
   555  	return r
   556  }
   557  
// collapsepath records that n1 and n2 are equal, collapsing as equal all the
// nodes on the paths between their nodes i1 and i2. If i1 reaches i2 through
// a strict edge (i1 < i2), the DAG is left untouched and false is returned.
//
// Only the simple case where i2 is a direct child of i1 is implemented so far
// (see the TODO below). In every other case the DAG is not modified, but true
// is still returned. NOTE(review): callers may read true as "collapsed" —
// confirm this is intended.
func (po *poset) collapsepath(n1, n2 *Value) bool {
	i1, i2 := po.values[n1.ID], po.values[n2.ID]
	if po.dominates(i1, i2, true) {
		return false
	}

	// TODO: for now, only handle the simple case of i2 being child of i1
	l, r := po.children(i1)
	if l.Target() == i2 || r.Target() == i2 {
		po.aliasnode(n1, n2)
		// aliasnode does not rewrite i1's own edges, so i1 still points to
		// i2, which no SSA value maps to anymore. This adds one more
		// non-strict edge i1->i2. NOTE(review): looks redundant — confirm.
		po.addchild(i1, i2, false)
		return true
	}
	return true
}
   576  
   577  // Check whether it is recorded that id1!=id2
   578  func (po *poset) isnoneq(id1, id2 ID) bool {
   579  	if id1 < id2 {
   580  		id1, id2 = id2, id1
   581  	}
   582  
   583  	// Check if we recorded a non-equal relation before
   584  	if bs, ok := po.noneq[id1]; ok && bs.Test(uint32(id2)) {
   585  		return true
   586  	}
   587  	return false
   588  }
   589  
   590  // Record that id1!=id2
   591  func (po *poset) setnoneq(id1, id2 ID) {
   592  	if id1 < id2 {
   593  		id1, id2 = id2, id1
   594  	}
   595  	bs := po.noneq[id1]
   596  	if bs == nil {
   597  		// Given that we record non-equality relations using the
   598  		// higher ID as a key, the bitsize will never change size.
   599  		// TODO(rasky): if memory is a problem, consider allocating
   600  		// a small bitset and lazily grow it when higher IDs arrive.
   601  		bs = newBitset(int(id1))
   602  		po.noneq[id1] = bs
   603  	} else if bs.Test(uint32(id2)) {
   604  		// Already recorded
   605  		return
   606  	}
   607  	bs.Set(uint32(id2))
   608  	po.upushneq(id1, id2)
   609  }
   610  
   611  // CheckIntegrity verifies internal integrity of a poset. It is intended
   612  // for debugging purposes.
   613  func (po *poset) CheckIntegrity() (err error) {
   614  	// Record which index is a constant
   615  	constants := newBitset(int(po.lastidx + 1))
   616  	for _, c := range po.constants {
   617  		if idx, ok := po.values[c.ID]; !ok {
   618  			err = errors.New("node missing for constant")
   619  			return err
   620  		} else {
   621  			constants.Set(idx)
   622  		}
   623  	}
   624  
   625  	// Verify that each node appears in a single DAG, and that
   626  	// all constants are within the same DAG
   627  	var croot uint32
   628  	seen := newBitset(int(po.lastidx + 1))
   629  	for _, r := range po.roots {
   630  		if r == 0 {
   631  			err = errors.New("empty root")
   632  			return
   633  		}
   634  
   635  		po.dfs(r, false, func(i uint32) bool {
   636  			if seen.Test(i) {
   637  				err = errors.New("duplicate node")
   638  				return true
   639  			}
   640  			seen.Set(i)
   641  			if constants.Test(i) {
   642  				if croot == 0 {
   643  					croot = r
   644  				} else if croot != r {
   645  					err = errors.New("constants are in different DAGs")
   646  					return true
   647  				}
   648  			}
   649  			return false
   650  		})
   651  		if err != nil {
   652  			return
   653  		}
   654  	}
   655  
   656  	// Verify that values contain the minimum set
   657  	for id, idx := range po.values {
   658  		if !seen.Test(idx) {
   659  			err = fmt.Errorf("spurious value [%d]=%d", id, idx)
   660  			return
   661  		}
   662  	}
   663  
   664  	// Verify that only existing nodes have non-zero children
   665  	for i, n := range po.nodes {
   666  		if n.l|n.r != 0 {
   667  			if !seen.Test(uint32(i)) {
   668  				err = fmt.Errorf("children of unknown node %d->%v", i, n)
   669  				return
   670  			}
   671  			if n.l.Target() == uint32(i) || n.r.Target() == uint32(i) {
   672  				err = fmt.Errorf("self-loop on node %d", i)
   673  				return
   674  			}
   675  		}
   676  	}
   677  
   678  	return
   679  }
   680  
   681  // CheckEmpty checks that a poset is completely empty.
   682  // It can be used for debugging purposes, as a poset is supposed to
   683  // be empty after it's fully rolled back through Undo.
   684  func (po *poset) CheckEmpty() error {
   685  	if len(po.nodes) != 1 {
   686  		return fmt.Errorf("non-empty nodes list: %v", po.nodes)
   687  	}
   688  	if len(po.values) != 0 {
   689  		return fmt.Errorf("non-empty value map: %v", po.values)
   690  	}
   691  	if len(po.roots) != 0 {
   692  		return fmt.Errorf("non-empty root list: %v", po.roots)
   693  	}
   694  	if len(po.constants) != 0 {
   695  		return fmt.Errorf("non-empty constants: %v", po.constants)
   696  	}
   697  	if len(po.undo) != 0 {
   698  		return fmt.Errorf("non-empty undo list: %v", po.undo)
   699  	}
   700  	if po.lastidx != 0 {
   701  		return fmt.Errorf("lastidx index is not zero: %v", po.lastidx)
   702  	}
   703  	for _, bs := range po.noneq {
   704  		for _, x := range bs {
   705  			if x != 0 {
   706  				return fmt.Errorf("non-empty noneq map")
   707  			}
   708  		}
   709  	}
   710  	return nil
   711  }
   712  
   713  // DotDump dumps the poset in graphviz format to file fn, with the specified title.
   714  func (po *poset) DotDump(fn string, title string) error {
   715  	f, err := os.Create(fn)
   716  	if err != nil {
   717  		return err
   718  	}
   719  	defer f.Close()
   720  
   721  	// Create reverse index mapping (taking aliases into account)
   722  	names := make(map[uint32]string)
   723  	for id, i := range po.values {
   724  		s := names[i]
   725  		if s == "" {
   726  			s = fmt.Sprintf("v%d", id)
   727  		} else {
   728  			s += fmt.Sprintf(", v%d", id)
   729  		}
   730  		names[i] = s
   731  	}
   732  
   733  	// Create constant mapping
   734  	consts := make(map[uint32]int64)
   735  	for _, v := range po.constants {
   736  		idx := po.values[v.ID]
   737  		if po.flags&posetFlagUnsigned != 0 {
   738  			consts[idx] = int64(v.AuxUnsigned())
   739  		} else {
   740  			consts[idx] = v.AuxInt
   741  		}
   742  	}
   743  
   744  	fmt.Fprintf(f, "digraph poset {\n")
   745  	fmt.Fprintf(f, "\tedge [ fontsize=10 ]\n")
   746  	for ridx, r := range po.roots {
   747  		fmt.Fprintf(f, "\tsubgraph root%d {\n", ridx)
   748  		po.dfs(r, false, func(i uint32) bool {
   749  			if val, ok := consts[i]; ok {
   750  				// Constant
   751  				var vals string
   752  				if po.flags&posetFlagUnsigned != 0 {
   753  					vals = fmt.Sprint(uint64(val))
   754  				} else {
   755  					vals = fmt.Sprint(int64(val))
   756  				}
   757  				fmt.Fprintf(f, "\t\tnode%d [shape=box style=filled fillcolor=cadetblue1 label=<%s <font point-size=\"6\">%s [%d]</font>>]\n",
   758  					i, vals, names[i], i)
   759  			} else {
   760  				// Normal SSA value
   761  				fmt.Fprintf(f, "\t\tnode%d [label=<%s <font point-size=\"6\">[%d]</font>>]\n", i, names[i], i)
   762  			}
   763  			chl, chr := po.children(i)
   764  			for _, ch := range []posetEdge{chl, chr} {
   765  				if ch != 0 {
   766  					if ch.Strict() {
   767  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <\" color=\"red\"]\n", i, ch.Target())
   768  					} else {
   769  						fmt.Fprintf(f, "\t\tnode%d -> node%d [label=\" <=\" color=\"green\"]\n", i, ch.Target())
   770  					}
   771  				}
   772  			}
   773  			return false
   774  		})
   775  		fmt.Fprintf(f, "\t}\n")
   776  	}
   777  	fmt.Fprintf(f, "\tlabelloc=\"t\"\n")
   778  	fmt.Fprintf(f, "\tlabeldistance=\"3.0\"\n")
   779  	fmt.Fprintf(f, "\tlabel=%q\n", title)
   780  	fmt.Fprintf(f, "}\n")
   781  	return nil
   782  }
   783  
   784  // Ordered reports whether n1<n2. It returns false either when it is
   785  // certain that n1<n2 is false, or if there is not enough information
   786  // to tell.
   787  // Complexity is O(n).
   788  func (po *poset) Ordered(n1, n2 *Value) bool {
   789  	if n1.ID == n2.ID {
   790  		panic("should not call Ordered with n1==n2")
   791  	}
   792  
   793  	i1, f1 := po.lookup(n1)
   794  	i2, f2 := po.lookup(n2)
   795  	if !f1 || !f2 {
   796  		return false
   797  	}
   798  
   799  	return i1 != i2 && po.dominates(i1, i2, true)
   800  }
   801  
   802  // Ordered reports whether n1<=n2. It returns false either when it is
   803  // certain that n1<=n2 is false, or if there is not enough information
   804  // to tell.
   805  // Complexity is O(n).
   806  func (po *poset) OrderedOrEqual(n1, n2 *Value) bool {
   807  	if n1.ID == n2.ID {
   808  		panic("should not call Ordered with n1==n2")
   809  	}
   810  
   811  	i1, f1 := po.lookup(n1)
   812  	i2, f2 := po.lookup(n2)
   813  	if !f1 || !f2 {
   814  		return false
   815  	}
   816  
   817  	return i1 == i2 || po.dominates(i1, i2, false) ||
   818  		(po.dominates(i2, i1, false) && !po.dominates(i2, i1, true))
   819  }
   820  
   821  // Equal reports whether n1==n2. It returns false either when it is
   822  // certain that n1==n2 is false, or if there is not enough information
   823  // to tell.
   824  // Complexity is O(1).
   825  func (po *poset) Equal(n1, n2 *Value) bool {
   826  	if n1.ID == n2.ID {
   827  		panic("should not call Equal with n1==n2")
   828  	}
   829  
   830  	i1, f1 := po.lookup(n1)
   831  	i2, f2 := po.lookup(n2)
   832  	return f1 && f2 && i1 == i2
   833  }
   834  
   835  // NonEqual reports whether n1!=n2. It returns false either when it is
   836  // certain that n1!=n2 is false, or if there is not enough information
   837  // to tell.
   838  // Complexity is O(n) (because it internally calls Ordered to see if we
   839  // can infer n1!=n2 from n1<n2 or n2<n1).
   840  func (po *poset) NonEqual(n1, n2 *Value) bool {
   841  	if n1.ID == n2.ID {
   842  		panic("should not call Equal with n1==n2")
   843  	}
   844  	if po.isnoneq(n1.ID, n2.ID) {
   845  		return true
   846  	}
   847  
   848  	// Check if n1<n2 or n2<n1, in which case we can infer that n1!=n2
   849  	if po.Ordered(n1, n2) || po.Ordered(n2, n1) {
   850  		return true
   851  	}
   852  
   853  	return false
   854  }
   855  
// setOrder records that n1<n2 or n1<=n2 (depending on strict).
// Implements SetOrder() (strict=true) and SetOrderOrEqual() (strict=false).
// It returns false if the new relation contradicts what is already known
// (for example a strict order between aliased nodes, or a strict order
// against an existing opposite path), and true otherwise.
func (po *poset) setOrder(n1, n2 *Value, strict bool) bool {
	// If we are trying to record n1<=n2 but we learned that n1!=n2,
	// record n1<n2, as it provides more information.
	if !strict && po.isnoneq(n1.ID, n2.ID) {
		strict = true
	}

	i1, f1 := po.lookup(n1)
	i2, f2 := po.lookup(n2)

	switch {
	case !f1 && !f2:
		// Neither n1 nor n2 are in the poset, so they are not related
		// in any way to existing nodes.
		// Create a new DAG to record the relation.
		i1, i2 = po.newnode(n1), po.newnode(n2)
		po.roots = append(po.roots, i1)
		po.upush(undoNewRoot, i1, 0)
		po.addchild(i1, i2, strict)

	case f1 && !f2:
		// n1 is in one of the DAGs, while n2 is not. Add n2 as a child
		// of n1.
		i2 = po.newnode(n2)
		po.addchild(i1, i2, strict)

	case !f1 && f2:
		// n1 is not in any DAG but n2 is. If n2 is a root, we can put
		// n1 in its place as a root; otherwise, we need to create a new
		// dummy root to record the relation.
		i1 = po.newnode(n1)

		if po.isroot(i2) {
			// Replace root i2 with i1 and hang i2 below it. The undo pass
			// is pushed before addchild so that, on Undo, the child edge is
			// reverted first and the root swap second.
			po.changeroot(i2, i1)
			po.upush(undoChangeRoot, i1, newedge(i2, strict))
			po.addchild(i1, i2, strict)
			return true
		}

		// Search for i2's root; this requires a O(n) search on all
		// DAGs
		r := po.findroot(i2)

		// Re-parent as follows:
		//
		//                  dummy
		//     r            /   \
		//      \   ===>   r    i1
		//      i2          \   /
		//                    i2
		//
		// The edges from dummy are non-strict: dummy carries no SSA value
		// and only serves to join r and i1 under a common root.
		dummy := po.newnode(nil)
		po.changeroot(r, dummy)
		po.upush(undoChangeRoot, dummy, newedge(r, false))
		po.addchild(dummy, r, false)
		po.addchild(dummy, i1, false)
		po.addchild(i1, i2, strict)

	case f1 && f2:
		// If the nodes are aliased, fail only if we're setting a strict order
		// (that is, we cannot set n1<n2 if n1==n2).
		if i1 == i2 {
			return !strict
		}

		// Both n1 and n2 are in the poset. This is the complex part of the algorithm
		// as we need to find many different cases and DAG shapes.

		// Check if n1 somehow dominates n2
		if po.dominates(i1, i2, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG          New      Action
			//      ---------------------------------------------------
			// #1:  N1<=X<=N2 |  N1<=N2 | do nothing
			// #2:  N1<=X<=N2 |  N1<N2  | add strict edge (N1<N2)
			// #3:  N1<X<N2   |  N1<=N2 | do nothing (we already know more)
			// #4:  N1<X<N2   |  N1<N2  | do nothing

			// Check if we're in case #2
			if strict && !po.dominates(i1, i2, true) {
				po.addchild(i1, i2, true)
				return true
			}

			// Case #1, #3 or #4: nothing to do
			return true
		}

		// Check if n2 somehow dominates n1
		if po.dominates(i2, i1, false) {
			// This is the table of all cases we need to handle:
			//
			//      DAG           New      Action
			//      ---------------------------------------------------
			// #5:  N2<=X<=N1  |  N1<=N2 | collapse path (learn that N1=X=N2)
			// #6:  N2<=X<=N1  |  N1<N2  | contradiction
			// #7:  N2<X<N1    |  N1<=N2 | contradiction in the path
			// #8:  N2<X<N1    |  N1<N2  | contradiction

			if strict {
				// Cases #6 and #8: contradiction
				return false
			}

			// We're in case #5 or #7. Try to collapse path, and that will
			// fail if it realizes that we are in case #7.
			return po.collapsepath(n2, n1)
		}

		// We don't know of any existing relation between n1 and n2. They could
		// be part of the same DAG or not.
		// Find their roots to check whether they are in the same DAG.
		r1, r2 := po.findroot(i1), po.findroot(i2)
		if r1 != r2 {
			// We need to merge the two DAGs to record a relation between the nodes
			po.mergeroot(r1, r2)
		}

		// Connect n1 and n2
		po.addchild(i1, i2, strict)
	}

	return true
}
   983  
   984  // SetOrder records that n1<n2. Returns false if this is a contradiction
   985  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   986  func (po *poset) SetOrder(n1, n2 *Value) bool {
   987  	if n1.ID == n2.ID {
   988  		panic("should not call SetOrder with n1==n2")
   989  	}
   990  	return po.setOrder(n1, n2, true)
   991  }
   992  
   993  // SetOrderOrEqual records that n1<=n2. Returns false if this is a contradiction
   994  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
   995  func (po *poset) SetOrderOrEqual(n1, n2 *Value) bool {
   996  	if n1.ID == n2.ID {
   997  		panic("should not call SetOrder with n1==n2")
   998  	}
   999  	return po.setOrder(n1, n2, false)
  1000  }
  1001  
  1002  // SetEqual records that n1==n2. Returns false if this is a contradiction
  1003  // (that is, if it is already recorded that n1<n2 or n2<n1).
  1004  // Complexity is O(1) if n2 was never seen before, or O(n) otherwise.
  1005  func (po *poset) SetEqual(n1, n2 *Value) bool {
  1006  	if n1.ID == n2.ID {
  1007  		panic("should not call Add with n1==n2")
  1008  	}
  1009  
  1010  	// If we recorded that n1!=n2, this is a contradiction.
  1011  	if po.isnoneq(n1.ID, n2.ID) {
  1012  		return false
  1013  	}
  1014  
  1015  	i1, f1 := po.lookup(n1)
  1016  	i2, f2 := po.lookup(n2)
  1017  
  1018  	switch {
  1019  	case !f1 && !f2:
  1020  		i1 = po.newnode(n1)
  1021  		po.roots = append(po.roots, i1)
  1022  		po.upush(undoNewRoot, i1, 0)
  1023  		po.aliasnode(n1, n2)
  1024  	case f1 && !f2:
  1025  		po.aliasnode(n1, n2)
  1026  	case !f1 && f2:
  1027  		po.aliasnode(n2, n1)
  1028  	case f1 && f2:
  1029  		if i1 == i2 {
  1030  			// Already aliased, ignore
  1031  			return true
  1032  		}
  1033  
  1034  		// If we already knew that n1<=n2, we can collapse the path to
  1035  		// record n1==n2 (and viceversa).
  1036  		if po.dominates(i1, i2, false) {
  1037  			return po.collapsepath(n1, n2)
  1038  		}
  1039  		if po.dominates(i2, i1, false) {
  1040  			return po.collapsepath(n2, n1)
  1041  		}
  1042  
  1043  		r1 := po.findroot(i1)
  1044  		r2 := po.findroot(i2)
  1045  		if r1 != r2 {
  1046  			// Merge the two DAGs so we can record relations between the nodes
  1047  			po.mergeroot(r1, r2)
  1048  		}
  1049  
  1050  		// Set n2 as alias of n1. This will also update all the references
  1051  		// to n2 to become references to n1
  1052  		po.aliasnode(n1, n2)
  1053  
  1054  		// Connect i2 (now dummy) as child of i1. This allows to keep the correct
  1055  		// order with its children.
  1056  		po.addchild(i1, i2, false)
  1057  	}
  1058  	return true
  1059  }
  1060  
  1061  // SetNonEqual records that n1!=n2. Returns false if this is a contradiction
  1062  // (that is, if it is already recorded that n1==n2).
  1063  // Complexity is O(n).
  1064  func (po *poset) SetNonEqual(n1, n2 *Value) bool {
  1065  	if n1.ID == n2.ID {
  1066  		panic("should not call Equal with n1==n2")
  1067  	}
  1068  
  1069  	// See if we already know this
  1070  	if po.isnoneq(n1.ID, n2.ID) {
  1071  		return true
  1072  	}
  1073  
  1074  	// Check if we're contradicting an existing relation
  1075  	if po.Equal(n1, n2) {
  1076  		return false
  1077  	}
  1078  
  1079  	// Record non-equality
  1080  	po.setnoneq(n1.ID, n2.ID)
  1081  
  1082  	// If we know that i1<=i2 but not i1<i2, learn that as we
  1083  	// now know that they are not equal. Do the same for i2<=i1.
  1084  	i1, f1 := po.lookup(n1)
  1085  	i2, f2 := po.lookup(n2)
  1086  	if f1 && f2 {
  1087  		if po.dominates(i1, i2, false) && !po.dominates(i1, i2, true) {
  1088  			po.addchild(i1, i2, true)
  1089  		}
  1090  		if po.dominates(i2, i1, false) && !po.dominates(i2, i1, true) {
  1091  			po.addchild(i2, i1, true)
  1092  		}
  1093  	}
  1094  
  1095  	return true
  1096  }
  1097  
  1098  // Checkpoint saves the current state of the DAG so that it's possible
  1099  // to later undo this state.
  1100  // Complexity is O(1).
  1101  func (po *poset) Checkpoint() {
  1102  	po.undo = append(po.undo, posetUndo{typ: undoCheckpoint})
  1103  }
  1104  
// Undo restores the state of the poset to the previous checkpoint.
// Complexity depends on the type of operations that were performed
// since the last checkpoint; each Set* operation creates an undo
// pass which Undo has to revert with a worst-case complexity of O(n).
//
// Passes are popped from po.undo in LIFO order and reverted one by one;
// the first undoCheckpoint pass encountered is consumed and stops the
// rollback. If no checkpoint remains, every recorded pass is reverted.
// Undo panics if the undo stack is empty on entry, if a pass is
// inconsistent with the current state, or if a pass has an unknown type.
func (po *poset) Undo() {
	if len(po.undo) == 0 {
		panic("empty undo stack")
	}

	for len(po.undo) > 0 {
		// Pop the most recently recorded pass.
		pass := po.undo[len(po.undo)-1]
		po.undo = po.undo[:len(po.undo)-1]

		switch pass.typ {
		case undoCheckpoint:
			// Reached the saved state.
			return

		case undoSetChl:
			// Restore the previous left-child edge of node pass.idx.
			po.setchl(pass.idx, pass.edge)

		case undoSetChr:
			// Restore the previous right-child edge of node pass.idx.
			po.setchr(pass.idx, pass.edge)

		case undoNonEqual:
			// Clear bit pass.idx in the non-equality bitset of SSA value pass.ID.
			po.noneq[pass.ID].Clear(pass.idx)

		case undoNewNode:
			// Nodes are removed in reverse creation order, so only the most
			// recently allocated index may be undone here.
			if pass.idx != po.lastidx {
				panic("invalid newnode index")
			}
			// NOTE(review): a zero ID presumably marks a dummy node created
			// without an SSA value, which has no po.values entry — confirm.
			if pass.ID != 0 {
				if po.values[pass.ID] != pass.idx {
					panic("invalid newnode undo pass")
				}
				delete(po.values, pass.ID)
			}
			// Clear the node's edges and shrink the node table by one.
			po.setchl(pass.idx, 0)
			po.setchr(pass.idx, 0)
			po.nodes = po.nodes[:pass.idx]
			po.lastidx--

			// If it was the last inserted constant, remove it
			nc := len(po.constants)
			if nc > 0 && po.constants[nc-1].ID == pass.ID {
				po.constants = po.constants[:nc-1]
			}

		case undoAliasNode:
			// pass.idx holds the node index the value mapped to before it
			// was aliased; 0 means it had no mapping at all.
			ID, prev := pass.ID, pass.idx
			cur := po.values[ID]
			if prev == 0 {
				// Born as an alias, die as an alias
				delete(po.values, ID)
			} else {
				// The alias must have moved the value to a different node.
				if cur == prev {
					panic("invalid aliasnode undo pass")
				}
				// Give it back previous value
				po.values[ID] = prev
			}

		case undoNewRoot:
			// Any edges added below this root were recorded later and have
			// therefore already been reverted; it must be childless now.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo newroot")
			}
			po.removeroot(i)

		case undoChangeRoot:
			// Put the previous root (pass.edge.Target()) back in place of
			// node pass.idx, which must be childless by now.
			i := pass.idx
			l, r := po.children(i)
			if l|r != 0 {
				panic("non-empty root in undo changeroot")
			}
			po.changeroot(i, pass.edge.Target())

		case undoMergeRoot:
			// Node pass.idx joined two DAGs: its left child becomes a root
			// again in its place, and its right child is re-added as a
			// separate root.
			i := pass.idx
			l, r := po.children(i)
			po.changeroot(i, l.Target())
			po.roots = append(po.roots, r.Target())

		default:
			panic(pass.typ)
		}
	}
}
  1193  

View as plain text