diff --git a/btf/core.go b/btf/core.go index a3d311a06..4bc644474 100644 --- a/btf/core.go +++ b/btf/core.go @@ -127,10 +127,11 @@ const ( reloTypeSize /* type size in bytes */ reloEnumvalExists /* enum value existence in target kernel */ reloEnumvalValue /* enum value integer value */ + reloTypeMatches /* type matches kernel type */ ) func (k coreKind) checksForExistence() bool { - return k == reloEnumvalExists || k == reloTypeExists || k == reloFieldExists + return k == reloEnumvalExists || k == reloTypeExists || k == reloFieldExists || k == reloTypeMatches } func (k coreKind) String() string { @@ -159,8 +160,10 @@ func (k coreKind) String() string { return "enumval_exists" case reloEnumvalValue: return "enumval_value" + case reloTypeMatches: + return "type_matches" default: - return "unknown" + return fmt.Sprintf("unknown (%d)", k) } } @@ -369,6 +372,21 @@ func coreCalculateFixup(relo *CORERelocation, target Type, targetID TypeID, bo b local := relo.typ switch relo.kind { + case reloTypeMatches: + if len(relo.accessor) > 1 || relo.accessor[0] != 0 { + return zero, fmt.Errorf("unexpected accessor %v", relo.accessor) + } + + err := coreTypesMatch(local, target, false, nil) + if errors.Is(err, errIncompatibleTypes) { + return poison() + } + if err != nil { + return zero, err + } + + return fixup(1, 1) + case reloTypeIDTarget, reloTypeSize, reloTypeExists: if len(relo.accessor) > 1 || relo.accessor[0] != 0 { return zero, fmt.Errorf("unexpected accessor %v", relo.accessor) @@ -1016,19 +1034,6 @@ func coreAreMembersCompatible(localType Type, targetType Type) error { localType = UnderlyingType(localType) targetType = UnderlyingType(targetType) - doNamesMatch := func(a, b string) error { - if a == "" || b == "" { - // allow anonymous and named type to match - return nil - } - - if newEssentialName(a) == newEssentialName(b) { - return nil - } - - return fmt.Errorf("names don't match: %w", errImpossibleRelocation) - } - _, lok := localType.(composite) _, tok := targetType.(composite) if lok && tok { @@ -1045,13 +1050,295 @@ func coreAreMembersCompatible(localType Type, targetType Type) error { case *Enum: tv := targetType.(*Enum) - return doNamesMatch(lv.Name, tv.Name) + if !coreNamesMatch(lv.Name, tv.Name) { + return fmt.Errorf("names %q and %q don't match: %w", lv.Name, tv.Name, errImpossibleRelocation) + } + + return nil case *Fwd: tv := targetType.(*Fwd) - return doNamesMatch(lv.Name, tv.Name) + if !coreNamesMatch(lv.Name, tv.Name) { + return fmt.Errorf("names %q and %q don't match: %w", lv.Name, tv.Name, errImpossibleRelocation) + } + + return nil default: return fmt.Errorf("type %s: %w", localType, ErrNotSupported) } } + +func coreNamesMatch(a, b string) bool { + if a == "" || b == "" { + // allow anonymous and named type to match + return true + } + + return newEssentialName(a) == newEssentialName(b) +} + +/* The comment below is from __bpf_core_types_match in relo_core.c: + * + * Check that two types "match". This function assumes that root types were + * already checked for name match. + * + * The matching relation is defined as follows: + * - modifiers and typedefs are stripped (and, hence, effectively ignored) + * - generally speaking types need to be of same kind (struct vs. struct, union + * vs. union, etc.) + * - exceptions are struct/union behind a pointer which could also match a + * forward declaration of a struct or union, respectively, and enum vs. + * enum64 (see below) + * Then, depending on type: + * - integers: + * - match if size and signedness match + * - arrays & pointers: + * - target types are recursively matched + * - structs & unions: + * - local members need to exist in target with the same name + * - for each member we recursively check match unless it is already behind a + * pointer, in which case we only check matching names and compatible kind + * - enums: + * - local variants have to have a match in target by symbolic name (but not + * numeric value) + * - size has to match (but enum may match enum64 and vice versa) + * - function pointers: + * - number and position of arguments in local type has to match target + * - for each argument and the return value we recursively check match + */ +func coreTypesMatch(localType Type, targetType Type, behindPtr bool, visited map[pair]struct{}) error { + localType = UnderlyingType(localType) + targetType = UnderlyingType(targetType) + + if !coreNamesMatch(localType.TypeName(), targetType.TypeName()) { + return fmt.Errorf("type name %q don't match %q: %w", localType.TypeName(), targetType.TypeName(), errIncompatibleTypes) + } + + if _, ok := visited[pair{localType, targetType}]; ok { + return nil + } + if visited == nil { + visited = make(map[pair]struct{}) + } + visited[pair{localType, targetType}] = struct{}{} + + switch lv := (localType).(type) { + case *Void: + if _, ok := targetType.(*Void); !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + case *Fwd: + if behindPtr { + if tv, ok := targetType.(*Fwd); ok { + if lv.Kind != tv.Kind { + return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + return nil + } + + if _, ok := targetType.(*Struct); ok && lv.Kind == FwdStruct { + return nil + } + + if _, ok := targetType.(*Union); ok && lv.Kind == FwdUnion { + return nil + } + + return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if tv, ok := targetType.(*Fwd); ok && tv.Kind == lv.Kind { + return nil + } + + return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + + case *Enum: + tv, ok := targetType.(*Enum) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if err := coreEnumsMatch(lv, tv); err != nil { + return err + } + + case *Struct, *Union: + if behindPtr { + if reflect.TypeOf(localType) == reflect.TypeOf(targetType) { + return nil + } + + tv, ok := targetType.(*Fwd) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if _, ok := lv.(*Struct); ok && tv.Kind == FwdStruct { + return nil + } + + if _, ok := lv.(*Union); ok && tv.Kind == FwdUnion { + return nil + } + + return fmt.Errorf("fwd kind mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + lComp, ok := lv.(composite) + if !ok { + return fmt.Errorf("expected composite type, got %T", lv) + } + + tComp, ok := targetType.(composite) + if !ok { + return fmt.Errorf("expected composite type, got %T", targetType) + } + + if err := coreCompositesMatch(lComp, tComp, behindPtr, visited); err != nil { + return err + } + + case *Int: + tv, ok := targetType.(*Int) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if lv.Size != tv.Size || (lv.Encoding == Signed) != (tv.Encoding == Signed) { + return fmt.Errorf("int mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + case *Pointer: + tv, ok := targetType.(*Pointer) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + return coreTypesMatch(lv.Target, tv.Target, true, visited) + + case *Array: + tv, ok := targetType.(*Array) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if lv.Nelems != tv.Nelems { + return fmt.Errorf("array mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + return coreTypesMatch(lv.Type, tv.Type, behindPtr, visited) + + case *FuncProto: + tv, ok := targetType.(*FuncProto) + if !ok { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + if len(lv.Params) != len(tv.Params) { + return fmt.Errorf("function param mismatch: %w", errIncompatibleTypes) + } + + for i, lparam := range lv.Params { + if err := coreTypesMatch(lparam.Type, tv.Params[i].Type, behindPtr, visited); err != nil { + return err + } + } + + return coreTypesMatch(lv.Return, tv.Return, behindPtr, visited) + + default: + return fmt.Errorf("unsupported type %T", localType) + } + + return nil +} + +// coreEnumsMatch checks two enums match, which is considered to be the case if the following is true: +// - size has to match (but enum may match enum64 and vice versa) +// - local variants have to have a match in target by symbolic name (but not numeric value) +func coreEnumsMatch(local *Enum, target *Enum) error { + if local.Size != target.Size { + return fmt.Errorf("size mismatch between %v and %v: %w", local, target, errIncompatibleTypes) + } + + // If there are more values in the local than the target, there must be at least one value in the local + // that isn't in the target, and therefor the types are incompatible. + if len(local.Values) > len(target.Values) { + return fmt.Errorf("local has more values than target: %w", errIncompatibleTypes) + } + +outer: + for _, lv := range local.Values { + for _, rv := range target.Values { + if coreNamesMatch(lv.Name, rv.Name) { + continue outer + } + } + + return fmt.Errorf("no match for %v in %v: %w", lv, target, errIncompatibleTypes) + } + + return nil +} + +func coreCompositesMatch(localType, targetType composite, behindPtr bool, visited map[pair]struct{}) error { + if reflect.TypeOf(localType) != reflect.TypeOf(targetType) { + return fmt.Errorf("type mismatch between %v and %v: %w", localType, targetType, errIncompatibleTypes) + } + + var ( + localMembers []Member + targetMembers []Member + ) + + switch lv := localType.(type) { + case *Struct: + localMembers = lv.members() + case *Union: + localMembers = lv.members() + default: + return fmt.Errorf("expected coreCompositeMatch to be called with a composite type, got %T", localType) + } + + switch tv := targetType.(type) { + case *Struct: + targetMembers = tv.members() + case *Union: + targetMembers = tv.members() + default: + return fmt.Errorf("expected coreCompositeMatch to be called with a composite type, got %T", targetType) + } + + for _, localMember := range localMembers { + matches := false + for _, targetMember := range targetMembers { + if !coreNamesMatch(localMember.Name, targetMember.Name) { + continue + } + + err := coreTypesMatch(localMember.Type, targetMember.Type, behindPtr, visited) + if err != nil && !errors.Is(err, errIncompatibleTypes) { + return err + } + + if err == nil { + matches = true + break + } + } + + if !matches { + return fmt.Errorf("no match for %v in %v: %w", localMember, targetType, errIncompatibleTypes) + } + } + + return nil +} diff --git a/btf/core_test.go b/btf/core_test.go index f56d3d481..d60149d09 100644 --- a/btf/core_test.go +++ b/btf/core_test.go @@ -718,3 +718,176 @@ func BenchmarkCORESkBuff(b *testing.B) { }) } } + +func TestCORETypesMatch(t *testing.T) { + tests := []struct { + a, b Type + match bool + reversible bool + }{ + {&Void{}, &Void{}, true, true}, + {&Int{Size: 32}, &Int{Size: 32}, true, true}, + {&Int{Size: 64}, &Int{Size: 32}, false, true}, + {&Int{Size: 32}, &Int{Size: 32, Encoding: Signed}, false, true}, + {&Fwd{Name: "a"}, &Fwd{Name: "a"}, true, true}, + {&Fwd{Name: "a"}, &Fwd{Name: "b___new"}, false, true}, + {&Fwd{Name: "a"}, &Fwd{Name: "a___new"}, true, true}, + {&Fwd{Name: "a"}, &Struct{Name: "a___new"}, false, true}, + {&Fwd{Name: "a"}, &Union{Name: "a___new"}, false, true}, + {&Pointer{&Fwd{Name: "a", Kind: FwdStruct}}, &Pointer{&Struct{Name: "a___new"}}, true, true}, + {&Pointer{&Fwd{Name: "a", Kind: FwdUnion}}, &Pointer{&Union{Name: "a___new"}}, true, true}, + {&Pointer{&Fwd{Name: "a", Kind: FwdStruct}}, &Pointer{&Union{Name: "a___new"}}, false, true}, + {&Struct{Name: "a___new"}, &Union{Name: "a___new"}, false, true}, + {&Pointer{&Struct{Name: "a"}}, &Pointer{&Union{Name: "a___new"}}, false, true}, + { + &Struct{Name: "a", Members: []Member{ + {Name: "foo", Type: &Int{}}, + }}, + &Struct{Name: "a___new", Members: []Member{ + {Name: "foo", Type: &Int{}}, + }}, + true, + true, + }, + { + &Struct{Name: "a", Members: []Member{ + {Name: "foo", Type: &Int{}}, + }}, + &Struct{Name: "a___new", Members: []Member{ + {Name: "foo", Type: &Int{}}, + {Name: "bar", Type: &Int{}}, + }}, + true, + false, + }, + { + &Struct{Name: "a", Members: []Member{ + {Name: "foo", Type: &Int{}}, + {Name: "bar", Type: &Int{}}, + }}, + &Struct{Name: "a___new", Members: []Member{ + {Name: "foo", Type: &Int{}}, + }}, + false, + false, + }, + { + &Enum{Name: "a", Values: []EnumValue{ + {"foo", 1}, + }}, + &Enum{Name: "a___new", Values: []EnumValue{ + {"foo", 1}, + }}, + true, + true, + }, + { + &Enum{Name: "a", Values: []EnumValue{ + {"foo", 1}, + }}, + &Enum{Name: "a___new", Values: []EnumValue{ + {"foo", 1}, + {"bar", 2}, + }}, + true, + false, + }, + { + &Enum{Name: "a", Values: []EnumValue{ + {"foo", 1}, + {"bar", 2}, + }}, + &Enum{Name: "a___new", Values: []EnumValue{ + {"foo", 1}, + }}, + false, + false, + }, + { + &Array{Type: &Int{}, Nelems: 2}, + &Array{Type: &Int{}, Nelems: 2}, + true, + true, + }, + { + &Array{Type: &Int{}, Nelems: 3}, + &Array{Type: &Int{}, Nelems: 2}, + false, + true, + }, + { + &Array{Type: &Void{}, Nelems: 2}, + &Array{Type: &Int{}, Nelems: 2}, + false, + true, + }, + { + &FuncProto{Return: &Int{}, Params: []FuncParam{ + {Name: "foo", Type: &Int{}}, + }}, + &FuncProto{Return: &Int{}, Params: []FuncParam{ + {Name: "bar", Type: &Int{}}, + }}, + true, + true, + }, + { + &FuncProto{Return: &Int{}, Params: []FuncParam{ + {Name: "foo", Type: &Int{}}, + }}, + &FuncProto{Return: &Int{}, Params: []FuncParam{ + {Name: "bar", Type: &Int{}}, + {Name: "baz", Type: &Int{}}, + }}, + false, + true, + }, + { + &FuncProto{Return: &Void{}, Params: []FuncParam{ + {Name: "foo", Type: &Int{}}, + }}, + &FuncProto{Return: &Int{}, Params: []FuncParam{ + {Name: "bar", Type: &Int{}}, + }}, + false, + true, + }, + } + + for _, test := range tests { + err := coreTypesMatch(test.a, test.b, false, nil) + if test.match { + if err != nil { + t.Errorf("Expected types to match: %s\na = %#v\nb = %#v", err, test.a, test.b) + continue + } + } else { + if !errors.Is(err, errIncompatibleTypes) { + t.Errorf("Expected types to be incompatible: %s\na = %#v\nb = %#v", err, test.a, test.b) + continue + } + } + + if test.reversible { + err = coreTypesMatch(test.b, test.a, false, nil) + if test.match { + if err != nil { + t.Errorf("Expected reversed types to match: %s\na = %#v\nb = %#v", err, test.a, test.b) + } + } else { + if !errors.Is(err, errIncompatibleTypes) { + t.Errorf("Expected reversed types to be incompatible: %s\na = %#v\nb = %#v", err, test.a, test.b) + } + } + } + } + + for _, invalid := range []Type{&Var{}, &Datasec{}} { + err := coreTypesMatch(invalid, invalid, false, nil) + if errors.Is(err, errIncompatibleTypes) { + t.Errorf("Expected an error for %T, not errIncompatibleTypes", invalid) + } else if err == nil { + t.Errorf("Expected an error for %T", invalid) + } + } +} diff --git a/btf/testdata/bpf_core_read.h b/btf/testdata/bpf_core_read.h index 09ebe3db5..9b26f1e3d 100644 --- a/btf/testdata/bpf_core_read.h +++ b/btf/testdata/bpf_core_read.h @@ -27,8 +27,9 @@ enum bpf_type_id_kind { /* second argument to __builtin_preserve_type_info() built-in */ enum bpf_type_info_kind { - BPF_TYPE_EXISTS = 0, /* type existence in target kernel */ + BPF_TYPE_EXISTS = 0, /* type existence in target kernel */ BPF_TYPE_SIZE = 1, /* type size in target kernel */ + BPF_TYPE_MATCHES = 2, /* type match in target kernel */ }; /* second argument to __builtin_preserve_enum_value() built-in */ @@ -154,6 +155,17 @@ enum bpf_enum_value_kind { #define bpf_core_type_exists(type) \ __builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_EXISTS) +/* + * Convenience macro to check that provided named type + * (struct/union/enum/typedef) "matches" that in a target kernel. + * Returns: + * 1, if the type matches in the target kernel's BTF; + * 0, if the type does not match any in the target kernel + */ +#define bpf_core_type_matches(type) \ + __builtin_preserve_type_info(*(typeof(type) *)0, BPF_TYPE_MATCHES) + + /* * Convenience macro to get the byte size of a provided named type * (struct/union/enum/typedef) in a target kernel. diff --git a/btf/testdata/relocs-eb.elf b/btf/testdata/relocs-eb.elf index 7ae44dacd..7fc2336ac 100644 Binary files a/btf/testdata/relocs-eb.elf and b/btf/testdata/relocs-eb.elf differ diff --git a/btf/testdata/relocs-el.elf b/btf/testdata/relocs-el.elf index d3ef66622..f16f82965 100644 Binary files a/btf/testdata/relocs-el.elf and b/btf/testdata/relocs-el.elf differ diff --git a/btf/testdata/relocs.c b/btf/testdata/relocs.c index 799190134..137452ba1 100644 --- a/btf/testdata/relocs.c +++ b/btf/testdata/relocs.c @@ -97,6 +97,13 @@ __section("socket/type_ids") int type_ids() { } \ }) +#define type_matches(expr) \ + ({ \ + if (!bpf_core_type_matches(expr)) { \ + return __LINE__; \ + } \ + }) + __section("socket/types") int types() { type_exists(struct s); type_exists(s_t); @@ -125,6 +132,19 @@ __section("socket/types") int types() { type_size_matches(const u_t); type_size_matches(volatile u_t); + type_matches(struct s); + type_matches(s_t); + type_matches(const s_t); + type_matches(volatile s_t); + type_matches(enum e); + type_matches(e_t); + type_matches(const e_t); + type_matches(volatile e_t); + type_matches(union u); + type_matches(u_t); + type_matches(const u_t); + type_matches(volatile u_t); + return 0; }