Skip to content

Commit deba550

Browse files
wsmosesvchuravy
andauthored
Expose Subtransfer and additional C API functions (rust-lang#384)
* [CAPI] Expose subtransfer * Update * Fixup * Update enzyme/Enzyme/CApi.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> * Update enzyme/Enzyme/CApi.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> * Update enzyme/Enzyme/CApi.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> * Update enzyme/Enzyme/CApi.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> * Update enzyme/Enzyme/CApi.cpp Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com> Co-authored-by: Valentin Churavy <vchuravy@users.noreply.github.com>
1 parent cbe56be commit deba550

File tree

4 files changed

+234
-179
lines changed

4 files changed

+234
-179
lines changed

enzyme/Enzyme/AdjointGenerator.h

Lines changed: 20 additions & 179 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,178 +2166,6 @@ class AdjointGenerator
21662166
}
21672167
}
21682168

2169-
void subTransferHelper(Type *secretty, BasicBlock *parent,
2170-
Intrinsic::ID intrinsic, unsigned dstalign,
2171-
unsigned srcalign, unsigned offset, Value *orig_dst,
2172-
Value *orig_src, Value *length, Value *isVolatile,
2173-
llvm::CallInst *MTI, bool allowForward = true) {
2174-
// TODO offset
2175-
if (secretty) {
2176-
// no change to forward pass if represents floats
2177-
if (Mode == DerivativeMode::ReverseModeGradient ||
2178-
Mode == DerivativeMode::ReverseModeCombined) {
2179-
IRBuilder<> Builder2(parent);
2180-
getReverseBuilder(Builder2);
2181-
2182-
// If the src is constant simply zero d_dst and don't propagate to d_src
2183-
// (which thus == src and may be illegal)
2184-
if (gutils->isConstantValue(orig_src)) {
2185-
SmallVector<Value *, 4> args;
2186-
args.push_back(
2187-
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2));
2188-
if (args[0]->getType()->isIntegerTy())
2189-
args[0] = Builder2.CreateIntToPtr(
2190-
args[0], Type::getInt8PtrTy(MTI->getContext()));
2191-
args.push_back(
2192-
ConstantInt::get(Type::getInt8Ty(parent->getContext()), 0));
2193-
args.push_back(lookup(length, Builder2));
2194-
#if LLVM_VERSION_MAJOR <= 6
2195-
args.push_back(ConstantInt::get(
2196-
Type::getInt32Ty(parent->getContext()), max(1U, dstalign)));
2197-
#endif
2198-
args.push_back(ConstantInt::getFalse(parent->getContext()));
2199-
2200-
Type *tys[] = {args[0]->getType(), args[2]->getType()};
2201-
auto memsetIntr = Intrinsic::getDeclaration(
2202-
parent->getParent()->getParent(), Intrinsic::memset, tys);
2203-
auto cal = Builder2.CreateCall(memsetIntr, args);
2204-
cal->setCallingConv(memsetIntr->getCallingConv());
2205-
if (dstalign != 0) {
2206-
#if LLVM_VERSION_MAJOR >= 10
2207-
cal->addParamAttr(0, Attribute::getWithAlignment(
2208-
parent->getContext(), Align(dstalign)));
2209-
#else
2210-
cal->addParamAttr(
2211-
0, Attribute::getWithAlignment(parent->getContext(), dstalign));
2212-
#endif
2213-
}
2214-
2215-
} else {
2216-
SmallVector<Value *, 4> args;
2217-
auto dsto =
2218-
lookup(gutils->invertPointerM(orig_dst, Builder2), Builder2);
2219-
if (dsto->getType()->isIntegerTy())
2220-
dsto = Builder2.CreateIntToPtr(
2221-
dsto, Type::getInt8PtrTy(dsto->getContext()));
2222-
unsigned dstaddr =
2223-
cast<PointerType>(dsto->getType())->getAddressSpace();
2224-
auto secretpt = PointerType::get(secretty, dstaddr);
2225-
if (offset != 0)
2226-
dsto = Builder2.CreateConstInBoundsGEP1_64(dsto, offset);
2227-
args.push_back(Builder2.CreatePointerCast(dsto, secretpt));
2228-
auto srco =
2229-
lookup(gutils->invertPointerM(orig_src, Builder2), Builder2);
2230-
if (srco->getType()->isIntegerTy())
2231-
srco = Builder2.CreateIntToPtr(
2232-
srco, Type::getInt8PtrTy(srco->getContext()));
2233-
unsigned srcaddr =
2234-
cast<PointerType>(srco->getType())->getAddressSpace();
2235-
secretpt = PointerType::get(secretty, srcaddr);
2236-
if (offset != 0)
2237-
srco = Builder2.CreateConstInBoundsGEP1_64(srco, offset);
2238-
args.push_back(Builder2.CreatePointerCast(srco, secretpt));
2239-
args.push_back(Builder2.CreateUDiv(
2240-
lookup(length, Builder2),
2241-
2242-
ConstantInt::get(length->getType(),
2243-
Builder2.GetInsertBlock()
2244-
->getParent()
2245-
->getParent()
2246-
->getDataLayout()
2247-
.getTypeAllocSizeInBits(secretty) /
2248-
8)));
2249-
2250-
auto dmemcpy = ((intrinsic == Intrinsic::memcpy)
2251-
? getOrInsertDifferentialFloatMemcpy
2252-
: getOrInsertDifferentialFloatMemmove)(
2253-
*parent->getParent()->getParent(), secretty, dstalign, srcalign,
2254-
dstaddr, srcaddr);
2255-
Builder2.CreateCall(dmemcpy, args);
2256-
}
2257-
}
2258-
} else {
2259-
2260-
// if represents pointer or integer type then only need to modify forward
2261-
// pass with the copy
2262-
if (allowForward && (Mode == DerivativeMode::ReverseModePrimal ||
2263-
Mode == DerivativeMode::ReverseModeCombined)) {
2264-
2265-
// It is questionable how the following case would even occur, but if
2266-
// the dst is constant, we shouldn't do anything extra
2267-
if (gutils->isConstantValue(orig_dst)) {
2268-
return;
2269-
}
2270-
2271-
SmallVector<Value *, 4> args;
2272-
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(MTI));
2273-
2274-
// If src is inactive, then we should copy from the regular pointer
2275-
// (i.e. suppose we are copying constant memory representing dimensions
2276-
// into a tensor)
2277-
// to ensure that the differential tensor is well formed for use
2278-
// OUTSIDE the derivative generation (as enzyme doesn't need this), we
2279-
// should also perform the copy onto the differential. Future
2280-
// Optimization (not implemented): If dst can never escape Enzyme code,
2281-
// we may omit this copy.
2282-
// no need to update pointers, even if dst is active
2283-
auto dsto = gutils->invertPointerM(orig_dst, BuilderZ);
2284-
if (dsto->getType()->isIntegerTy())
2285-
dsto = BuilderZ.CreateIntToPtr(dsto,
2286-
Type::getInt8PtrTy(MTI->getContext()));
2287-
if (offset != 0)
2288-
dsto = BuilderZ.CreateConstInBoundsGEP1_64(dsto, offset);
2289-
args.push_back(dsto);
2290-
auto srco = gutils->invertPointerM(orig_src, BuilderZ);
2291-
if (srco->getType()->isIntegerTy())
2292-
srco = BuilderZ.CreateIntToPtr(srco,
2293-
Type::getInt8PtrTy(MTI->getContext()));
2294-
if (offset != 0)
2295-
srco = BuilderZ.CreateConstInBoundsGEP1_64(srco, offset);
2296-
args.push_back(srco);
2297-
2298-
args.push_back(length);
2299-
#if LLVM_VERSION_MAJOR <= 6
2300-
args.push_back(ConstantInt::get(Type::getInt32Ty(parent->getContext()),
2301-
max(1U, min(srcalign, dstalign))));
2302-
#endif
2303-
args.push_back(isVolatile);
2304-
2305-
//#if LLVM_VERSION_MAJOR >= 7
2306-
Type *tys[] = {args[0]->getType(), args[1]->getType(),
2307-
args[2]->getType()};
2308-
//#else
2309-
// Type *tys[] = {args[0]->getType(), args[1]->getType(),
2310-
// args[2]->getType(), args[3]->getType()}; #endif
2311-
2312-
auto memtransIntr = Intrinsic::getDeclaration(
2313-
gutils->newFunc->getParent(), intrinsic, tys);
2314-
auto cal = BuilderZ.CreateCall(memtransIntr, args);
2315-
cal->setAttributes(MTI->getAttributes());
2316-
cal->setCallingConv(memtransIntr->getCallingConv());
2317-
cal->setTailCallKind(MTI->getTailCallKind());
2318-
2319-
if (dstalign != 0) {
2320-
#if LLVM_VERSION_MAJOR >= 10
2321-
cal->addParamAttr(0, Attribute::getWithAlignment(parent->getContext(),
2322-
Align(dstalign)));
2323-
#else
2324-
cal->addParamAttr(
2325-
0, Attribute::getWithAlignment(parent->getContext(), dstalign));
2326-
#endif
2327-
}
2328-
if (srcalign != 0) {
2329-
#if LLVM_VERSION_MAJOR >= 10
2330-
cal->addParamAttr(1, Attribute::getWithAlignment(parent->getContext(),
2331-
Align(srcalign)));
2332-
#else
2333-
cal->addParamAttr(
2334-
1, Attribute::getWithAlignment(parent->getContext(), srcalign));
2335-
#endif
2336-
}
2337-
}
2338-
}
2339-
}
2340-
23412169
void visitMemTransferInst(llvm::MemTransferInst &MTI) {
23422170
#if LLVM_VERSION_MAJOR >= 7
23432171
Value *isVolatile = gutils->getNewFromOriginal(MTI.getOperand(3));
@@ -2352,16 +2180,20 @@ class AdjointGenerator
23522180
auto dstAlign = MTI.getDestAlignment();
23532181
#endif
23542182
visitMemTransferCommon(MTI.getIntrinsicID(), srcAlign, dstAlign, MTI,
2183+
MTI.getOperand(0), MTI.getOperand(1),
2184+
gutils->getNewFromOriginal(MTI.getOperand(2)),
23552185
isVolatile);
23562186
}
23572187

23582188
#if LLVM_VERSION_MAJOR >= 10
23592189
void visitMemTransferCommon(Intrinsic::ID ID, MaybeAlign srcAlign,
23602190
MaybeAlign dstAlign, llvm::CallInst &MTI,
2191+
Value *orig_dst, Value *orig_src, Value *new_size,
23612192
Value *isVolatile)
23622193
#else
23632194
void visitMemTransferCommon(Intrinsic::ID ID, unsigned srcAlign,
23642195
unsigned dstAlign, llvm::CallInst &MTI,
2196+
Value *orig_dst, Value *orig_src, Value *new_size,
23652197
Value *isVolatile)
23662198
#endif
23672199
{
@@ -2375,10 +2207,6 @@ class AdjointGenerator
23752207
return;
23762208
}
23772209

2378-
Value *orig_dst = MTI.getOperand(0);
2379-
Value *orig_src = MTI.getOperand(1);
2380-
Value *new_size = gutils->getNewFromOriginal(MTI.getOperand(2));
2381-
23822210
// copying into nullptr is invalid (not sure why it exists here), but we
23832211
// shouldn't do it in reverse pass or shadow
23842212
if (isa<ConstantPointerNull>(orig_dst) ||
@@ -2529,8 +2357,17 @@ class AdjointGenerator
25292357
srcalign = 1;
25302358
}
25312359
}
2532-
subTransferHelper(dt.isFloat(), MTI.getParent(), ID, subdstalign,
2533-
subsrcalign, /*offset*/ start, orig_dst, orig_src,
2360+
IRBuilder<> BuilderZ(gutils->getNewFromOriginal(&MTI));
2361+
Value *shadow_dst = gutils->isConstantValue(orig_dst)
2362+
? gutils->getNewFromOriginal(orig_dst)
2363+
: gutils->invertPointerM(orig_dst, BuilderZ);
2364+
Value *shadow_src = gutils->isConstantValue(orig_src)
2365+
? gutils->getNewFromOriginal(orig_src)
2366+
: gutils->invertPointerM(orig_src, BuilderZ);
2367+
SubTransferHelper(gutils, Mode, dt.isFloat(), ID, subdstalign,
2368+
subsrcalign, /*offset*/ start,
2369+
gutils->isConstantValue(orig_dst), shadow_dst,
2370+
gutils->isConstantValue(orig_src), shadow_src,
25342371
/*length*/ length, /*volatile*/ isVolatile, &MTI);
25352372

25362373
if (nextStart == size)
@@ -7741,10 +7578,14 @@ class AdjointGenerator
77417578
#if LLVM_VERSION_MAJOR >= 10
77427579
visitMemTransferCommon(ID, /*srcAlign*/ MaybeAlign(1),
77437580
/*dstAlign*/ MaybeAlign(1), *orig,
7581+
orig->getArgOperand(0), orig->getArgOperand(1),
7582+
gutils->getNewFromOriginal(orig->getArgOperand(2)),
77447583
ConstantInt::getFalse(orig->getContext()));
77457584
#else
77467585
visitMemTransferCommon(ID, /*srcAlign*/ 1,
7747-
/*dstAlign*/ 1, *orig,
7586+
/*dstAlign*/ 1, *orig, orig->getArgOperand(0),
7587+
orig->getArgOperand(1),
7588+
gutils->getNewFromOriginal(orig->getArgOperand(2)),
77487589
ConstantInt::getFalse(orig->getContext()));
77497590
#endif
77507591
return;

enzyme/Enzyme/CApi.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,10 @@ LLVMValueRef EnzymeGradientUtilsNewFromOriginal(GradientUtils *gutils,
282282
return wrap(gutils->getNewFromOriginal(unwrap(val)));
283283
}
284284

285+
CDerivativeMode EnzymeGradientUtilsGetMode(GradientUtils *gutils) {
286+
return (CDerivativeMode)gutils->mode;
287+
}
288+
285289
void EnzymeGradientUtilsSetDebugLocFromOriginal(GradientUtils *gutils,
286290
LLVMValueRef val,
287291
LLVMValueRef orig) {
@@ -331,6 +335,31 @@ LLVMBasicBlockRef EnzymeGradientUtilsAllocationBlock(GradientUtils *gutils) {
331335
return wrap(gutils->inversionAllocs);
332336
}
333337

338+
CTypeTreeRef EnzymeGradientUtilsAllocAndGetTypeTree(GradientUtils *gutils,
339+
LLVMValueRef val) {
340+
auto v = unwrap(val);
341+
assert(gutils->my_TR);
342+
TypeTree TT = gutils->my_TR->query(v);
343+
TypeTree *pTT = new TypeTree(TT);
344+
return (CTypeTreeRef)pTT;
345+
}
346+
347+
void EnzymeGradientUtilsSubTransferHelper(
348+
GradientUtils *gutils, CDerivativeMode mode, LLVMTypeRef secretty,
349+
uint64_t intrinsic, uint64_t dstAlign, uint64_t srcAlign, uint64_t offset,
350+
uint8_t dstConstant, LLVMValueRef shadow_dst, uint8_t srcConstant,
351+
LLVMValueRef shadow_src, LLVMValueRef length, LLVMValueRef isVolatile,
352+
LLVMValueRef MTI, uint8_t allowForward) {
353+
auto orig = unwrap(MTI);
354+
assert(orig);
355+
SubTransferHelper(gutils, (DerivativeMode)mode, unwrap(secretty),
356+
(Intrinsic::ID)intrinsic, (unsigned)dstAlign,
357+
(unsigned)srcAlign, (unsigned)offset, (bool)dstConstant,
358+
unwrap(shadow_dst), (bool)srcConstant, unwrap(shadow_src),
359+
unwrap(length), unwrap(isVolatile), cast<CallInst>(orig),
360+
(bool)allowForward);
361+
}
362+
334363
LLVMValueRef EnzymeCreateForwardDiff(
335364
EnzymeLogicRef Logic, LLVMValueRef todiff, CDIFFE_TYPE retType,
336365
CDIFFE_TYPE *constant_args, size_t constant_args_size,
@@ -467,6 +496,15 @@ void EnzymeTypeTreeOnlyEq(CTypeTreeRef CTT, int64_t x) {
467496
void EnzymeTypeTreeData0Eq(CTypeTreeRef CTT) {
468497
*(TypeTree *)CTT = ((TypeTree *)CTT)->Data0();
469498
}
499+
500+
void EnzymeTypeTreeLookupEq(CTypeTreeRef CTT, int64_t size, const char *dl) {
501+
*(TypeTree *)CTT = ((TypeTree *)CTT)->Lookup(size, DataLayout(dl));
502+
}
503+
504+
CConcreteType EnzymeTypeTreeInner0(CTypeTreeRef CTT) {
505+
return ewrap(((TypeTree *)CTT)->Inner0());
506+
}
507+
470508
void EnzymeTypeTreeShiftIndiciesEq(CTypeTreeRef CTT, const char *datalayout,
471509
int64_t offset, int64_t maxSize,
472510
uint64_t addOffset) {

0 commit comments

Comments
 (0)