Skip to content

Commit 0be2244

Browse files
committed
Add vtable.get_virtual_fn_addr and vtable.get_vptr
1 parent 5c5c802 commit 0be2244

24 files changed

+367
-118
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 103 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2580,7 +2580,7 @@ def CIR_GetGlobalOp : CIR_Op<"get_global", [
25802580
// VTableAddrPointOp
25812581
//===----------------------------------------------------------------------===//
25822582

2583-
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
2583+
def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point", [
25842584
Pure, DeclareOpInterfaceMethods<SymbolUserOpInterface>
25852585
]> {
25862586
let summary = "Get the vtable (global variable) address point";
@@ -2589,17 +2589,18 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
25892589
(address point) of a C++ virtual table. An object internal `__vptr`
25902590
gets initializated on top of the value returned by this operation.
25912591

2592-
`address_point.index` (vtable index) provides the appropriate vtable within the vtable group
2593-
(as specified by Itanium ABI), and `address_point.offset` (address point index) the actual address
2594-
point within that vtable.
2592+
`address_point.index` (vtable index) provides the appropriate vtable within
2593+
the vtable group (as specified by Itanium ABI), and `address_point.offset`
2594+
(address point index) the actual address point within that vtable.
25952595

2596-
The return type is always a `!cir.ptr<!cir.vtable>`.
2596+
The return type is always `!cir.ptr<!cir.vptr>`.
25972597

25982598
Example:
25992599
```mlir
26002600
cir.global linkonce_odr @_ZTV1B = ...
26012601
...
2602-
%3 = cir.vtable.address_point(@_ZTV1B, address_point = <index = 0, offset = 2>) : !cir.vtable_ptr
2602+
%3 = cir.vtable.address_point(@_ZTV1B,
2603+
address_point = <index = 0, offset = 2>) : !cir.ptr<!cir.vptr>
26032604
```
26042605
}];
26052606

@@ -2609,7 +2610,7 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
26092610
CIR_AddressPointAttr:$address_point
26102611
);
26112612

2612-
let results = (outs Res<CIR_PointerType, "", []>:$addr);
2613+
let results = (outs Res<CIR_PtrToVPtr, "", []>:$addr);
26132614

26142615
let assemblyFormat = [{
26152616
`(`
@@ -2624,6 +2625,101 @@ def CIR_VTableAddrPointOp : CIR_Op<"vtable.address_point",[
26242625
let hasVerifier = 1;
26252626
}
26262627

2628+
//===----------------------------------------------------------------------===//
2629+
// VTableGetVptr
2630+
//===----------------------------------------------------------------------===//
2631+
2632+
def CIR_VTableGetVptrOp : CIR_Op<"vtable.get_vptr", [Pure]> {
2633+
let summary = "Get a the address of the vtable pointer for an object";
2634+
let description = [{
2635+
The `vtable.get_vptr` operation retrieves the address of the vptr for a
2636+
C++ object. This operation requires that the object pointer points to
2637+
the start of a complete object. (TODO: Describe how we get that).
2638+
The vptr will always be at offset zero in the object, but this operation
2639+
is more explicit about what is being retrieved than a direct bitcast.
2640+
2641+
The return type is always `!cir.ptr<!cir.vptr>`.
2642+
2643+
Example:
2644+
```mlir
2645+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
2646+
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
2647+
```
2648+
}];
2649+
2650+
let arguments = (ins
2651+
Arg<CIR_PointerType, "the vptr address", [MemRead]>:$src);
2652+
2653+
let results = (outs CIR_PtrToVPtr:$vptr_ty);
2654+
2655+
2656+
let assemblyFormat = [{
2657+
$src `:` qualified(type($src)) `->` qualified(type($vptr_ty)) attr-dict
2658+
}];
2659+
2660+
}
2661+
2662+
//===----------------------------------------------------------------------===//
2663+
// VTableGetVirtualFnAddrOp
2664+
//===----------------------------------------------------------------------===//
2665+
2666+
def CIR_VTableGetVirtualFnAddrOp : CIR_Op<"vtable.get_virtual_fn_addr", [
2667+
Pure
2668+
]> {
2669+
let summary = "Get a the address of a virtual function pointer";
2670+
let description = [{
2671+
The `vtable.get_virtual_fn_addr` operation retrieves the address of a
2672+
virtual function pointer from an object's vtable (__vptr).
2673+
This is an abstraction to perform the basic pointer arithmetic to get
2674+
the address of the virtual function pointer, which can then be loaded and
2675+
called.
2676+
2677+
The return type is a pointer-to-pointer to the function type.
2678+
2679+
Example:
2680+
```mlir
2681+
%2 = cir.load %0 : !cir.ptr<!cir.ptr<!rec_C>>, !cir.ptr<!rec_C>
2682+
%3 = cir.vtable.get_vptr %2 : !cir.ptr<!rec_C> -> !cir.ptr<!cir.vptr>
2683+
%4 = cir.load %3 : !cir.ptr<!cir.vptr>, !cir.vptr
2684+
%5 = cir.vtable.get_virtual_fn_addr(%4, index = 2) : !cir.vptr
2685+
-> !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>>
2686+
%6 = cir.load align(8) %5 : !cir.ptr<!cir.ptr<!cir.func<(!cir.ptr<!rec_C>)
2687+
-> !s32i>>>,
2688+
!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>
2689+
%7 = cir.call %6(%2) : (!cir.ptr<!cir.func<(!cir.ptr<!rec_C>) -> !s32i>>,
2690+
!cir.ptr<!rec_C>) -> !s32i
2691+
```
2692+
}];
2693+
2694+
let arguments = (ins
2695+
Arg<CIR_VPtrType, "vptr", [MemRead]>:$vptr,
2696+
IndexAttr:$index_attr);
2697+
2698+
let results = (outs CIR_PointerType:$vfptr_ty);
2699+
2700+
let assemblyFormat = [{
2701+
`(`
2702+
$vptr `,` `index` `=` $index_attr
2703+
`)`
2704+
`:` qualified(type($vptr)) `,` qualified(type($vfptr_ty)) attr-dict
2705+
}];
2706+
2707+
let builders = [
2708+
OpBuilder<(ins "mlir::Type":$type,
2709+
"mlir::Value":$value,
2710+
"unsigned":$index),
2711+
[{
2712+
mlir::APInt fnIdx(64, index);
2713+
build($_builder, $_state, type, value, fnIdx);
2714+
}]>
2715+
];
2716+
2717+
let extraClassDeclaration = [{
2718+
/// Return the index of the record member being accessed.
2719+
uint64_t getIndex() { return getIndexAttr().getZExtValue(); }
2720+
}];
2721+
}
2722+
26272723
//===----------------------------------------------------------------------===//
26282724
// VTTAddrPointOp
26292725
//===----------------------------------------------------------------------===//

clang/include/clang/CIR/Dialect/IR/CIRTypeConstraints.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,20 +264,21 @@ def CIR_AnyDataMemberType : CIR_TypeBase<"::cir::DataMemberType",
264264
"data member type">;
265265

266266
//===----------------------------------------------------------------------===//
267-
// VTable type predicates
267+
// VPtr type predicates
268268
//===----------------------------------------------------------------------===//
269269

270-
def CIR_AnyVTableType : CIR_TypeBase<"::cir::VTableType",
271-
"vtable type">;
270+
def CIR_AnyVPtrType : CIR_TypeBase<"::cir::VPtrType",
271+
"vptr type">;
272272

273+
def CIR_PtrToVPtr : CIR_PtrToType<CIR_AnyVPtrType>;
273274

274275
//===----------------------------------------------------------------------===//
275276
// Scalar Type predicates
276277
//===----------------------------------------------------------------------===//
277278

278279
defvar CIR_ScalarTypes = [
279280
CIR_AnyBoolType, CIR_AnyIntType, CIR_AnyFloatType, CIR_AnyPtrType,
280-
CIR_AnyDataMemberType, CIR_AnyVTableType
281+
CIR_AnyDataMemberType, CIR_AnyVPtrType
281282
];
282283

283284
def CIR_AnyScalarType : AnyTypeOf<CIR_ScalarTypes, "cir scalar type"> {

clang/include/clang/CIR/Dialect/IR/CIRTypes.td

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,17 +344,31 @@ def CIR_DataMemberType : CIR_Type<"DataMember", "data_member",
344344
}
345345

346346
//===----------------------------------------------------------------------===//
347-
// CIR_VTableType
347+
// CIR_VPtrType
348348
//===----------------------------------------------------------------------===//
349349

350-
def CIR_VTableType : CIR_Type<"VTable", "vtable",
350+
def CIR_VPtrType : CIR_Type<"VPtr", "vptr",
351351
[DeclareTypeInterfaceMethods<DataLayoutTypeInterface>]> {
352352

353-
let summary = "CIR type that is used for pointers that point to a C++ vtable";
353+
let summary = "CIR type that is used for the vptr member of C++ objects";
354354
let description = [{
355-
`cir.vtable` is a special type used as the pointee type for pointers to
356-
vtables. This avoids using arbitrary pointer types to declare vtable
357-
pointer values.
355+
`cir.vptr` is a special type used as the type for the vptr member of a C++
356+
object. This avoids using arbitrary pointer types to declare vptr values
357+
and allows stronger type-based checking for operations that use or provide
358+
access to the vptr.
359+
360+
This type will be the element type of the 'vptr' member of structures that
361+
require a vtable pointer. A pointer to this type is returned by the
362+
`cir.vtable.address_point` and `cir.vtable.get_vptr` operations, and this
363+
pointer may be passed to the `cir.vtable.get_virtual_fn_addr` operation to
364+
get the address of a virtual function pointer.
365+
366+
The pointer may also be cast to other pointer types in order to perform
367+
pointer arithmetic based on information encoded in the AST layout to get
368+
the offset from a pointer to a dynamic object to the base object pointer,
369+
the base object offset value from the vtable, or the type information
370+
entry for an object.
371+
TODO: We should have special operations to do that too.
358372
}];
359373
}
360374

@@ -768,7 +782,7 @@ def CIR_AnyType : AnyTypeOf<[
768782
CIR_IntType, CIR_PointerType, CIR_DataMemberType, CIR_MethodType,
769783
CIR_BoolType, CIR_ArrayType, CIR_VectorType, CIR_FuncType, CIR_VoidType,
770784
CIR_RecordType, CIR_ExceptionType, CIR_AnyFloatType, CIR_ComplexType,
771-
CIR_VTableType
785+
CIR_VPtrType
772786
]>;
773787

774788
#endif // MLIR_CIR_DIALECT_CIR_TYPES

clang/lib/CIR/CodeGen/CIRGenBuilder.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,8 +424,8 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
424424
llvm_unreachable("unsupported long double format");
425425
}
426426

427-
mlir::Type getVirtualFnPtrType() {
428-
return cir::PointerType::get(cir::VTableType::get(getContext()));
427+
mlir::Type getPtrToVPtrType() {
428+
return getPointerTo(cir::VPtrType::get(getContext()));
429429
}
430430

431431
cir::FuncType getFuncType(llvm::ArrayRef<mlir::Type> params, mlir::Type retTy,

clang/lib/CIR/CodeGen/CIRGenClass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,10 +1704,12 @@ void CIRGenFunction::emitTypeMetadataCodeForVCall(const CXXRecordDecl *RD,
17041704
}
17051705

17061706
mlir::Value CIRGenFunction::getVTablePtr(mlir::Location Loc, Address This,
1707-
mlir::Type VTableTy,
17081707
const CXXRecordDecl *RD) {
1709-
Address VTablePtrSrc = builder.createElementBitCast(Loc, This, VTableTy);
1710-
auto VTable = builder.createLoad(Loc, VTablePtrSrc);
1708+
auto VTablePtr = builder.create<cir::VTableGetVptrOp>(
1709+
Loc, builder.getPtrToVPtrType(), This.getPointer());
1710+
Address VTablePtrAddr = Address(VTablePtr, This.getAlignment());
1711+
1712+
auto VTable = builder.createLoad(Loc, VTablePtrAddr);
17111713
assert(!cir::MissingFeatures::tbaa());
17121714

17131715
if (CGM.getCodeGenOpts().OptimizationLevel > 0 &&

clang/lib/CIR/CodeGen/CIRGenFunction.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,6 @@ class CIRGenFunction : public CIRGenTypeCache {
957957
VisitedVirtualBasesSetTy &VBases, VPtrsVector &vptrs);
958958
/// Return the Value of the vtable pointer member pointed to by This.
959959
mlir::Value getVTablePtr(mlir::Location Loc, Address This,
960-
mlir::Type VTableTy,
961960
const CXXRecordDecl *VTableClass);
962961

963962
/// Returns whether we should perform a type checked load when loading a

clang/lib/CIR/CodeGen/CIRGenItaniumCXXABI.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -931,8 +931,7 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
931931
auto loc = CGF.getLoc(Loc);
932932
auto TyPtr = CGF.getBuilder().getPointerTo(Ty);
933933
auto *MethodDecl = cast<CXXMethodDecl>(GD.getDecl());
934-
auto VTable = CGF.getVTablePtr(
935-
loc, This, CGF.getBuilder().getPointerTo(TyPtr), MethodDecl->getParent());
934+
auto VTable = CGF.getVTablePtr(loc, This, MethodDecl->getParent());
936935

937936
uint64_t VTableIndex = CGM.getItaniumVTableContext().getMethodVTableIndex(GD);
938937
mlir::Value VFunc{};
@@ -945,13 +944,9 @@ CIRGenCallee CIRGenItaniumCXXABI::getVirtualFunctionPointer(
945944
if (CGM.getItaniumVTableContext().isRelativeLayout()) {
946945
llvm_unreachable("NYI");
947946
} else {
948-
VTable = CGF.getBuilder().createBitcast(
949-
loc, VTable, CGF.getBuilder().getPointerTo(TyPtr));
950-
auto VTableSlotPtr = CGF.getBuilder().create<cir::VTableAddrPointOp>(
951-
loc, CGF.getBuilder().getPointerTo(TyPtr),
952-
::mlir::FlatSymbolRefAttr{}, VTable,
953-
cir::AddressPointAttr::get(CGF.getBuilder().getContext(), 0,
954-
VTableIndex));
947+
auto VTableSlotPtr =
948+
CGF.getBuilder().create<cir::VTableGetVirtualFnAddrOp>(
949+
loc, CGF.getBuilder().getPointerTo(TyPtr), VTable, VTableIndex);
955950
VFuncLoad = CGF.getBuilder().createAlignedLoad(loc, TyPtr, VTableSlotPtr,
956951
CGF.getPointerAlign());
957952
}
@@ -1007,7 +1002,7 @@ CIRGenItaniumCXXABI::getVTableAddressPoint(BaseSubobject Base,
10071002
.getAddressPoint(Base);
10081003

10091004
auto &builder = CGM.getBuilder();
1010-
auto vtablePtrTy = builder.getVirtualFnPtrType();
1005+
auto vtablePtrTy = builder.getPtrToVPtrType();
10111006

10121007
return builder.create<cir::VTableAddrPointOp>(
10131008
CGM.getLoc(VTableClass->getSourceRange()), vtablePtrTy,
@@ -2377,14 +2372,16 @@ void CIRGenItaniumCXXABI::emitThrow(CIRGenFunction &CGF,
23772372
mlir::Value CIRGenItaniumCXXABI::getVirtualBaseClassOffset(
23782373
mlir::Location loc, CIRGenFunction &CGF, Address This,
23792374
const CXXRecordDecl *ClassDecl, const CXXRecordDecl *BaseClassDecl) {
2380-
auto VTablePtr = CGF.getVTablePtr(loc, This, CGM.UInt8PtrTy, ClassDecl);
2375+
auto VTablePtr = CGF.getVTablePtr(loc, This, ClassDecl);
2376+
auto VTableBytePtr =
2377+
CGF.getBuilder().createBitcast(VTablePtr, CGM.UInt8PtrTy);
23812378
CharUnits VBaseOffsetOffset =
23822379
CGM.getItaniumVTableContext().getVirtualBaseOffsetOffset(ClassDecl,
23832380
BaseClassDecl);
23842381
mlir::Value OffsetVal =
23852382
CGF.getBuilder().getSInt64(VBaseOffsetOffset.getQuantity(), loc);
23862383
auto VBaseOffsetPtr = CGF.getBuilder().create<cir::PtrStrideOp>(
2387-
loc, VTablePtr.getType(), VTablePtr,
2384+
loc, CGM.UInt8PtrTy, VTableBytePtr,
23882385
OffsetVal); // vbase.offset.ptr
23892386

23902387
mlir::Value VBaseOffset;

clang/lib/CIR/CodeGen/CIRRecordLayoutBuilder.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ void CIRRecordLowering::accumulateVPtrs() {
488488
}
489489

490490
mlir::Type CIRRecordLowering::getVFPtrType() {
491-
return builder.getVirtualFnPtrType();
491+
return cir::VPtrType::get(builder.getContext());
492492
}
493493

494494
void CIRRecordLowering::fillOutputFields() {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,6 +561,12 @@ LogicalResult cir::CastOp::verify() {
561561
return success();
562562
}
563563

564+
// Allow casting cir.vptr to pointer types.
565+
// TODO: Add operations to get object offset and type info and remove this.
566+
if (mlir::isa<cir::VPtrType>(srcType) &&
567+
mlir::dyn_cast<cir::PointerType>(resType))
568+
return success();
569+
564570
// Handle the data member pointer types.
565571
if (mlir::isa<cir::DataMemberType>(srcType) &&
566572
mlir::isa<cir::DataMemberType>(resType))
@@ -2423,7 +2429,7 @@ LogicalResult cir::VTableAddrPointOp::verify() {
24232429
return success();
24242430

24252431
auto resultType = getAddr().getType();
2426-
auto resTy = cir::PointerType::get(cir::VTableType::get(getContext()));
2432+
auto resTy = cir::PointerType::get(cir::VPtrType::get(getContext()));
24272433

24282434
if (resultType != resTy)
24292435
return emitOpError("result type must be '")

clang/lib/CIR/Dialect/IR/CIRTypes.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -408,15 +408,15 @@ DataMemberType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
408408
}
409409

410410
llvm::TypeSize
411-
VTableType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
412-
::mlir::DataLayoutEntryListRef params) const {
411+
VPtrType::getTypeSizeInBits(const ::mlir::DataLayout &dataLayout,
412+
::mlir::DataLayoutEntryListRef params) const {
413413
// FIXME: consider size differences under different ABIs
414414
return llvm::TypeSize::getFixed(64);
415415
}
416416

417417
uint64_t
418-
VTableType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
419-
::mlir::DataLayoutEntryListRef params) const {
418+
VPtrType::getABIAlignment(const ::mlir::DataLayout &dataLayout,
419+
::mlir::DataLayoutEntryListRef params) const {
420420
// FIXME: consider alignment differences under different ABIs
421421
return 8;
422422
}

0 commit comments

Comments
 (0)