Source file src/cmd/compile/internal/compare/compare.go

     1  // Copyright 2022 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package compare contains code for generating comparison
     6  // routines for structs, strings and interfaces.
     7  package compare
     8  
     9  import (
    10  	"cmd/compile/internal/base"
    11  	"cmd/compile/internal/ir"
    12  	"cmd/compile/internal/typecheck"
    13  	"cmd/compile/internal/types"
    14  	"fmt"
    15  	"math/bits"
    16  	"sort"
    17  )
    18  
    19  // IsRegularMemory reports whether t can be compared/hashed as regular memory.
    20  func IsRegularMemory(t *types.Type) bool {
    21  	a, _ := types.AlgType(t)
    22  	return a == types.AMEM
    23  }
    24  
    25  // Memrun finds runs of struct fields for which memory-only algs are appropriate.
    26  // t is the parent struct type, and start is the field index at which to start the run.
    27  // size is the length in bytes of the memory included in the run.
    28  // next is the index just after the end of the memory run.
    29  func Memrun(t *types.Type, start int) (size int64, next int) {
    30  	next = start
    31  	for {
    32  		next++
    33  		if next == t.NumFields() {
    34  			break
    35  		}
    36  		// Stop run after a padded field.
    37  		if types.IsPaddedField(t, next-1) {
    38  			break
    39  		}
    40  		// Also, stop before a blank or non-memory field.
    41  		if f := t.Field(next); f.Sym.IsBlank() || !IsRegularMemory(f.Type) {
    42  			break
    43  		}
    44  		// For issue 46283, don't combine fields if the resulting load would
    45  		// require a larger alignment than the component fields.
    46  		if base.Ctxt.Arch.Alignment > 1 {
    47  			align := t.Alignment()
    48  			if off := t.Field(start).Offset; off&(align-1) != 0 {
    49  				// Offset is less aligned than the containing type.
    50  				// Use offset to determine alignment.
    51  				align = 1 << uint(bits.TrailingZeros64(uint64(off)))
    52  			}
    53  			size := t.Field(next).End() - t.Field(start).Offset
    54  			if size > align {
    55  				break
    56  			}
    57  		}
    58  	}
    59  	return t.Field(next-1).End() - t.Field(start).Offset, next
    60  }
    61  
    62  // EqCanPanic reports whether == on type t could panic (has an interface somewhere).
    63  // t must be comparable.
    64  func EqCanPanic(t *types.Type) bool {
    65  	switch t.Kind() {
    66  	default:
    67  		return false
    68  	case types.TINTER:
    69  		return true
    70  	case types.TARRAY:
    71  		return EqCanPanic(t.Elem())
    72  	case types.TSTRUCT:
    73  		for _, f := range t.Fields() {
    74  			if !f.Sym.IsBlank() && EqCanPanic(f.Type) {
    75  				return true
    76  			}
    77  		}
    78  		return false
    79  	}
    80  }
    81  
    82  // EqStructCost returns the cost of an equality comparison of two structs.
    83  //
    84  // The cost is determined using an algorithm which takes into consideration
    85  // the size of the registers in the current architecture and the size of the
    86  // memory-only fields in the struct.
    87  func EqStructCost(t *types.Type) int64 {
    88  	cost := int64(0)
    89  
    90  	for i, fields := 0, t.Fields(); i < len(fields); {
    91  		f := fields[i]
    92  
    93  		// Skip blank-named fields.
    94  		if f.Sym.IsBlank() {
    95  			i++
    96  			continue
    97  		}
    98  
    99  		n, _, next := eqStructFieldCost(t, i)
   100  
   101  		cost += n
   102  		i = next
   103  	}
   104  
   105  	return cost
   106  }
   107  
   108  // eqStructFieldCost returns the cost of an equality comparison of two struct fields.
   109  // t is the parent struct type, and i is the index of the field in the parent struct type.
   110  // eqStructFieldCost may compute the cost of several adjacent fields at once. It returns
   111  // the cost, the size of the set of fields it computed the cost for (in bytes), and the
   112  // index of the first field not part of the set of fields for which the cost
   113  // has already been calculated.
   114  func eqStructFieldCost(t *types.Type, i int) (int64, int64, int) {
   115  	var (
   116  		cost    = int64(0)
   117  		regSize = int64(types.RegSize)
   118  
   119  		size int64
   120  		next int
   121  	)
   122  
   123  	if base.Ctxt.Arch.CanMergeLoads {
   124  		// If we can merge adjacent loads then we can calculate the cost of the
   125  		// comparison using the size of the memory run and the size of the registers.
   126  		size, next = Memrun(t, i)
   127  		cost = size / regSize
   128  		if size%regSize != 0 {
   129  			cost++
   130  		}
   131  		return cost, size, next
   132  	}
   133  
   134  	// If we cannot merge adjacent loads then we have to use the size of the
   135  	// field and take into account the type to determine how many loads and compares
   136  	// are needed.
   137  	ft := t.Field(i).Type
   138  	size = ft.Size()
   139  	next = i + 1
   140  
   141  	return calculateCostForType(ft), size, next
   142  }
   143  
   144  func calculateCostForType(t *types.Type) int64 {
   145  	var cost int64
   146  	switch t.Kind() {
   147  	case types.TSTRUCT:
   148  		return EqStructCost(t)
   149  	case types.TSLICE:
   150  		// Slices are not comparable.
   151  		base.Fatalf("calculateCostForType: unexpected slice type")
   152  	case types.TARRAY:
   153  		elemCost := calculateCostForType(t.Elem())
   154  		cost = t.NumElem() * elemCost
   155  	case types.TSTRING, types.TINTER, types.TCOMPLEX64, types.TCOMPLEX128:
   156  		cost = 2
   157  	case types.TINT64, types.TUINT64:
   158  		cost = 8 / int64(types.RegSize)
   159  	default:
   160  		cost = 1
   161  	}
   162  	return cost
   163  }
   164  
   165  // EqStruct compares two structs np and nq for equality.
   166  // It works by building a list of boolean conditions to satisfy.
   167  // Conditions must be evaluated in the returned order and
   168  // properly short-circuited by the caller.
   169  // The first return value is the flattened list of conditions,
   170  // the second value is a boolean indicating whether any of the
   171  // comparisons could panic.
   172  func EqStruct(t *types.Type, np, nq ir.Node) ([]ir.Node, bool) {
   173  	// The conditions are a list-of-lists. Conditions are reorderable
   174  	// within each inner list. The outer lists must be evaluated in order.
   175  	var conds [][]ir.Node
   176  	conds = append(conds, []ir.Node{})
   177  	and := func(n ir.Node) {
   178  		i := len(conds) - 1
   179  		conds[i] = append(conds[i], n)
   180  	}
   181  
   182  	// Walk the struct using memequal for runs of AMEM
   183  	// and calling specific equality tests for the others.
   184  	for i, fields := 0, t.Fields(); i < len(fields); {
   185  		f := fields[i]
   186  
   187  		// Skip blank-named fields.
   188  		if f.Sym.IsBlank() {
   189  			i++
   190  			continue
   191  		}
   192  
   193  		typeCanPanic := EqCanPanic(f.Type)
   194  
   195  		// Compare non-memory fields with field equality.
   196  		if !IsRegularMemory(f.Type) {
   197  			if typeCanPanic {
   198  				// Enforce ordering by starting a new set of reorderable conditions.
   199  				conds = append(conds, []ir.Node{})
   200  			}
   201  			switch {
   202  			case f.Type.IsString():
   203  				p := typecheck.DotField(base.Pos, typecheck.Expr(np), i)
   204  				q := typecheck.DotField(base.Pos, typecheck.Expr(nq), i)
   205  				eqlen, eqmem := EqString(p, q)
   206  				and(eqlen)
   207  				and(eqmem)
   208  			default:
   209  				and(eqfield(np, nq, i))
   210  			}
   211  			if typeCanPanic {
   212  				// Also enforce ordering after something that can panic.
   213  				conds = append(conds, []ir.Node{})
   214  			}
   215  			i++
   216  			continue
   217  		}
   218  
   219  		cost, size, next := eqStructFieldCost(t, i)
   220  		if cost <= 4 {
   221  			// Cost of 4 or less: use plain field equality.
   222  			for j := i; j < next; j++ {
   223  				and(eqfield(np, nq, j))
   224  			}
   225  		} else {
   226  			// Higher cost: use memequal.
   227  			cc := eqmem(np, nq, i, size)
   228  			and(cc)
   229  		}
   230  		i = next
   231  	}
   232  
   233  	// Sort conditions to put runtime calls last.
   234  	// Preserve the rest of the ordering.
   235  	var flatConds []ir.Node
   236  	for _, c := range conds {
   237  		isCall := func(n ir.Node) bool {
   238  			return n.Op() == ir.OCALL || n.Op() == ir.OCALLFUNC
   239  		}
   240  		sort.SliceStable(c, func(i, j int) bool {
   241  			return !isCall(c[i]) && isCall(c[j])
   242  		})
   243  		flatConds = append(flatConds, c...)
   244  	}
   245  	return flatConds, len(conds) > 1
   246  }
   247  
   248  // EqString returns the nodes
   249  //
   250  //	len(s) == len(t)
   251  //
   252  // and
   253  //
   254  //	memequal(s.ptr, t.ptr, len(s))
   255  //
   256  // which can be used to construct string equality comparison.
   257  // eqlen must be evaluated before eqmem, and shortcircuiting is required.
   258  func EqString(s, t ir.Node) (eqlen *ir.BinaryExpr, eqmem *ir.CallExpr) {
   259  	s = typecheck.Conv(s, types.Types[types.TSTRING])
   260  	t = typecheck.Conv(t, types.Types[types.TSTRING])
   261  	sptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, s)
   262  	tptr := ir.NewUnaryExpr(base.Pos, ir.OSPTR, t)
   263  	slen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, s), types.Types[types.TUINTPTR])
   264  	tlen := typecheck.Conv(ir.NewUnaryExpr(base.Pos, ir.OLEN, t), types.Types[types.TUINTPTR])
   265  
   266  	// Pick the 3rd arg to memequal. Both slen and tlen are fine to use, because we short
   267  	// circuit the memequal call if they aren't the same. But if one is a constant some
   268  	// memequal optimizations are easier to apply.
   269  	probablyConstant := func(n ir.Node) bool {
   270  		if n.Op() == ir.OCONVNOP {
   271  			n = n.(*ir.ConvExpr).X
   272  		}
   273  		if n.Op() == ir.OLITERAL {
   274  			return true
   275  		}
   276  		if n.Op() != ir.ONAME {
   277  			return false
   278  		}
   279  		name := n.(*ir.Name)
   280  		if name.Class != ir.PAUTO {
   281  			return false
   282  		}
   283  		if def := name.Defn; def == nil {
   284  			// n starts out as the empty string
   285  			return true
   286  		} else if def.Op() == ir.OAS && (def.(*ir.AssignStmt).Y == nil || def.(*ir.AssignStmt).Y.Op() == ir.OLITERAL) {
   287  			// n starts out as a constant string
   288  			return true
   289  		}
   290  		return false
   291  	}
   292  	cmplen := slen
   293  	if probablyConstant(t) && !probablyConstant(s) {
   294  		cmplen = tlen
   295  	}
   296  
   297  	fn := typecheck.LookupRuntime("memequal", types.Types[types.TUINT8], types.Types[types.TUINT8])
   298  	call := typecheck.Call(base.Pos, fn, []ir.Node{sptr, tptr, ir.Copy(cmplen)}, false).(*ir.CallExpr)
   299  
   300  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, slen, tlen)
   301  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   302  	cmp.SetType(types.Types[types.TBOOL])
   303  	return cmp, call
   304  }
   305  
   306  // EqInterface returns the nodes
   307  //
   308  //	s.tab == t.tab (or s.typ == t.typ, as appropriate)
   309  //
   310  // and
   311  //
   312  //	ifaceeq(s.tab, s.data, t.data) (or efaceeq(s.typ, s.data, t.data), as appropriate)
   313  //
   314  // which can be used to construct interface equality comparison.
   315  // eqtab must be evaluated before eqdata, and shortcircuiting is required.
   316  func EqInterface(s, t ir.Node) (eqtab *ir.BinaryExpr, eqdata *ir.CallExpr) {
   317  	if !types.Identical(s.Type(), t.Type()) {
   318  		base.Fatalf("EqInterface %v %v", s.Type(), t.Type())
   319  	}
   320  	// func ifaceeq(tab *uintptr, x, y unsafe.Pointer) (ret bool)
   321  	// func efaceeq(typ *uintptr, x, y unsafe.Pointer) (ret bool)
   322  	var fn ir.Node
   323  	if s.Type().IsEmptyInterface() {
   324  		fn = typecheck.LookupRuntime("efaceeq")
   325  	} else {
   326  		fn = typecheck.LookupRuntime("ifaceeq")
   327  	}
   328  
   329  	stab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s)
   330  	ttab := ir.NewUnaryExpr(base.Pos, ir.OITAB, t)
   331  	sdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s)
   332  	tdata := ir.NewUnaryExpr(base.Pos, ir.OIDATA, t)
   333  	sdata.SetType(types.Types[types.TUNSAFEPTR])
   334  	tdata.SetType(types.Types[types.TUNSAFEPTR])
   335  	sdata.SetTypecheck(1)
   336  	tdata.SetTypecheck(1)
   337  
   338  	call := typecheck.Call(base.Pos, fn, []ir.Node{stab, sdata, tdata}, false).(*ir.CallExpr)
   339  
   340  	cmp := ir.NewBinaryExpr(base.Pos, ir.OEQ, stab, ttab)
   341  	cmp = typecheck.Expr(cmp).(*ir.BinaryExpr)
   342  	cmp.SetType(types.Types[types.TBOOL])
   343  	return cmp, call
   344  }
   345  
   346  // eqfield returns the node
   347  //
   348  //	p.field == q.field
   349  func eqfield(p, q ir.Node, field int) ir.Node {
   350  	nx := typecheck.DotField(base.Pos, typecheck.Expr(p), field)
   351  	ny := typecheck.DotField(base.Pos, typecheck.Expr(q), field)
   352  	return typecheck.Expr(ir.NewBinaryExpr(base.Pos, ir.OEQ, nx, ny))
   353  }
   354  
   355  // eqmem returns the node
   356  //
   357  //	memequal(&p.field, &q.field, size)
   358  func eqmem(p, q ir.Node, field int, size int64) ir.Node {
   359  	nx := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, p, field)))
   360  	ny := typecheck.Expr(typecheck.NodAddr(typecheck.DotField(base.Pos, q, field)))
   361  
   362  	fn, needsize := eqmemfunc(size, nx.Type().Elem())
   363  	call := ir.NewCallExpr(base.Pos, ir.OCALL, fn, nil)
   364  	call.Args.Append(nx)
   365  	call.Args.Append(ny)
   366  	if needsize {
   367  		call.Args.Append(ir.NewInt(base.Pos, size))
   368  	}
   369  
   370  	return call
   371  }
   372  
   373  func eqmemfunc(size int64, t *types.Type) (fn *ir.Name, needsize bool) {
   374  	if !base.Ctxt.Arch.CanMergeLoads && t.Alignment() < int64(base.Ctxt.Arch.Alignment) && t.Alignment() < t.Size() {
   375  		// We can't use larger comparisons if the value might not be aligned
   376  		// enough for the larger comparison. See issues 46283 and 67160.
   377  		size = 0
   378  	}
   379  	switch size {
   380  	case 1, 2, 4, 8, 16:
   381  		buf := fmt.Sprintf("memequal%d", int(size)*8)
   382  		return typecheck.LookupRuntime(buf, t, t), false
   383  	}
   384  
   385  	return typecheck.LookupRuntime("memequal", t, t), true
   386  }
   387  

View as plain text