@@ -2166,178 +2166,6 @@ class AdjointGenerator
2166
2166
}
2167
2167
}
2168
2168
2169
- void subTransferHelper (Type *secretty, BasicBlock *parent,
2170
- Intrinsic::ID intrinsic, unsigned dstalign,
2171
- unsigned srcalign, unsigned offset, Value *orig_dst,
2172
- Value *orig_src, Value *length, Value *isVolatile,
2173
- llvm::CallInst *MTI, bool allowForward = true ) {
2174
- // TODO offset
2175
- if (secretty) {
2176
- // no change to forward pass if represents floats
2177
- if (Mode == DerivativeMode::ReverseModeGradient ||
2178
- Mode == DerivativeMode::ReverseModeCombined) {
2179
- IRBuilder<> Builder2 (parent);
2180
- getReverseBuilder (Builder2);
2181
-
2182
- // If the src is constant simply zero d_dst and don't propagate to d_src
2183
- // (which thus == src and may be illegal)
2184
- if (gutils->isConstantValue (orig_src)) {
2185
- SmallVector<Value *, 4 > args;
2186
- args.push_back (
2187
- lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2));
2188
- if (args[0 ]->getType ()->isIntegerTy ())
2189
- args[0 ] = Builder2.CreateIntToPtr (
2190
- args[0 ], Type::getInt8PtrTy (MTI->getContext ()));
2191
- args.push_back (
2192
- ConstantInt::get (Type::getInt8Ty (parent->getContext ()), 0 ));
2193
- args.push_back (lookup (length, Builder2));
2194
- #if LLVM_VERSION_MAJOR <= 6
2195
- args.push_back (ConstantInt::get (
2196
- Type::getInt32Ty (parent->getContext ()), max (1U , dstalign)));
2197
- #endif
2198
- args.push_back (ConstantInt::getFalse (parent->getContext ()));
2199
-
2200
- Type *tys[] = {args[0 ]->getType (), args[2 ]->getType ()};
2201
- auto memsetIntr = Intrinsic::getDeclaration (
2202
- parent->getParent ()->getParent (), Intrinsic::memset , tys);
2203
- auto cal = Builder2.CreateCall (memsetIntr, args);
2204
- cal->setCallingConv (memsetIntr->getCallingConv ());
2205
- if (dstalign != 0 ) {
2206
- #if LLVM_VERSION_MAJOR >= 10
2207
- cal->addParamAttr (0 , Attribute::getWithAlignment (
2208
- parent->getContext (), Align (dstalign)));
2209
- #else
2210
- cal->addParamAttr (
2211
- 0 , Attribute::getWithAlignment (parent->getContext (), dstalign));
2212
- #endif
2213
- }
2214
-
2215
- } else {
2216
- SmallVector<Value *, 4 > args;
2217
- auto dsto =
2218
- lookup (gutils->invertPointerM (orig_dst, Builder2), Builder2);
2219
- if (dsto->getType ()->isIntegerTy ())
2220
- dsto = Builder2.CreateIntToPtr (
2221
- dsto, Type::getInt8PtrTy (dsto->getContext ()));
2222
- unsigned dstaddr =
2223
- cast<PointerType>(dsto->getType ())->getAddressSpace ();
2224
- auto secretpt = PointerType::get (secretty, dstaddr);
2225
- if (offset != 0 )
2226
- dsto = Builder2.CreateConstInBoundsGEP1_64 (dsto, offset);
2227
- args.push_back (Builder2.CreatePointerCast (dsto, secretpt));
2228
- auto srco =
2229
- lookup (gutils->invertPointerM (orig_src, Builder2), Builder2);
2230
- if (srco->getType ()->isIntegerTy ())
2231
- srco = Builder2.CreateIntToPtr (
2232
- srco, Type::getInt8PtrTy (srco->getContext ()));
2233
- unsigned srcaddr =
2234
- cast<PointerType>(srco->getType ())->getAddressSpace ();
2235
- secretpt = PointerType::get (secretty, srcaddr);
2236
- if (offset != 0 )
2237
- srco = Builder2.CreateConstInBoundsGEP1_64 (srco, offset);
2238
- args.push_back (Builder2.CreatePointerCast (srco, secretpt));
2239
- args.push_back (Builder2.CreateUDiv (
2240
- lookup (length, Builder2),
2241
-
2242
- ConstantInt::get (length->getType (),
2243
- Builder2.GetInsertBlock ()
2244
- ->getParent ()
2245
- ->getParent ()
2246
- ->getDataLayout ()
2247
- .getTypeAllocSizeInBits (secretty) /
2248
- 8 )));
2249
-
2250
- auto dmemcpy = ((intrinsic == Intrinsic::memcpy )
2251
- ? getOrInsertDifferentialFloatMemcpy
2252
- : getOrInsertDifferentialFloatMemmove)(
2253
- *parent->getParent ()->getParent (), secretty, dstalign, srcalign,
2254
- dstaddr, srcaddr);
2255
- Builder2.CreateCall (dmemcpy, args);
2256
- }
2257
- }
2258
- } else {
2259
-
2260
- // if represents pointer or integer type then only need to modify forward
2261
- // pass with the copy
2262
- if (allowForward && (Mode == DerivativeMode::ReverseModePrimal ||
2263
- Mode == DerivativeMode::ReverseModeCombined)) {
2264
-
2265
- // It is questionable how the following case would even occur, but if
2266
- // the dst is constant, we shouldn't do anything extra
2267
- if (gutils->isConstantValue (orig_dst)) {
2268
- return ;
2269
- }
2270
-
2271
- SmallVector<Value *, 4 > args;
2272
- IRBuilder<> BuilderZ (gutils->getNewFromOriginal (MTI));
2273
-
2274
- // If src is inactive, then we should copy from the regular pointer
2275
- // (i.e. suppose we are copying constant memory representing dimensions
2276
- // into a tensor)
2277
- // to ensure that the differential tensor is well formed for use
2278
- // OUTSIDE the derivative generation (as enzyme doesn't need this), we
2279
- // should also perform the copy onto the differential. Future
2280
- // Optimization (not implemented): If dst can never escape Enzyme code,
2281
- // we may omit this copy.
2282
- // no need to update pointers, even if dst is active
2283
- auto dsto = gutils->invertPointerM (orig_dst, BuilderZ);
2284
- if (dsto->getType ()->isIntegerTy ())
2285
- dsto = BuilderZ.CreateIntToPtr (dsto,
2286
- Type::getInt8PtrTy (MTI->getContext ()));
2287
- if (offset != 0 )
2288
- dsto = BuilderZ.CreateConstInBoundsGEP1_64 (dsto, offset);
2289
- args.push_back (dsto);
2290
- auto srco = gutils->invertPointerM (orig_src, BuilderZ);
2291
- if (srco->getType ()->isIntegerTy ())
2292
- srco = BuilderZ.CreateIntToPtr (srco,
2293
- Type::getInt8PtrTy (MTI->getContext ()));
2294
- if (offset != 0 )
2295
- srco = BuilderZ.CreateConstInBoundsGEP1_64 (srco, offset);
2296
- args.push_back (srco);
2297
-
2298
- args.push_back (length);
2299
- #if LLVM_VERSION_MAJOR <= 6
2300
- args.push_back (ConstantInt::get (Type::getInt32Ty (parent->getContext ()),
2301
- max (1U , min (srcalign, dstalign))));
2302
- #endif
2303
- args.push_back (isVolatile);
2304
-
2305
- // #if LLVM_VERSION_MAJOR >= 7
2306
- Type *tys[] = {args[0 ]->getType (), args[1 ]->getType (),
2307
- args[2 ]->getType ()};
2308
- // #else
2309
- // Type *tys[] = {args[0]->getType(), args[1]->getType(),
2310
- // args[2]->getType(), args[3]->getType()}; #endif
2311
-
2312
- auto memtransIntr = Intrinsic::getDeclaration (
2313
- gutils->newFunc ->getParent (), intrinsic, tys);
2314
- auto cal = BuilderZ.CreateCall (memtransIntr, args);
2315
- cal->setAttributes (MTI->getAttributes ());
2316
- cal->setCallingConv (memtransIntr->getCallingConv ());
2317
- cal->setTailCallKind (MTI->getTailCallKind ());
2318
-
2319
- if (dstalign != 0 ) {
2320
- #if LLVM_VERSION_MAJOR >= 10
2321
- cal->addParamAttr (0 , Attribute::getWithAlignment (parent->getContext (),
2322
- Align (dstalign)));
2323
- #else
2324
- cal->addParamAttr (
2325
- 0 , Attribute::getWithAlignment (parent->getContext (), dstalign));
2326
- #endif
2327
- }
2328
- if (srcalign != 0 ) {
2329
- #if LLVM_VERSION_MAJOR >= 10
2330
- cal->addParamAttr (1 , Attribute::getWithAlignment (parent->getContext (),
2331
- Align (srcalign)));
2332
- #else
2333
- cal->addParamAttr (
2334
- 1 , Attribute::getWithAlignment (parent->getContext (), srcalign));
2335
- #endif
2336
- }
2337
- }
2338
- }
2339
- }
2340
-
2341
2169
void visitMemTransferInst (llvm::MemTransferInst &MTI) {
2342
2170
#if LLVM_VERSION_MAJOR >= 7
2343
2171
Value *isVolatile = gutils->getNewFromOriginal (MTI.getOperand (3 ));
@@ -2352,16 +2180,20 @@ class AdjointGenerator
2352
2180
auto dstAlign = MTI.getDestAlignment ();
2353
2181
#endif
2354
2182
visitMemTransferCommon (MTI.getIntrinsicID (), srcAlign, dstAlign, MTI,
2183
+ MTI.getOperand (0 ), MTI.getOperand (1 ),
2184
+ gutils->getNewFromOriginal (MTI.getOperand (2 )),
2355
2185
isVolatile);
2356
2186
}
2357
2187
2358
2188
#if LLVM_VERSION_MAJOR >= 10
2359
2189
void visitMemTransferCommon (Intrinsic::ID ID, MaybeAlign srcAlign,
2360
2190
MaybeAlign dstAlign, llvm::CallInst &MTI,
2191
+ Value *orig_dst, Value *orig_src, Value *new_size,
2361
2192
Value *isVolatile)
2362
2193
#else
2363
2194
void visitMemTransferCommon (Intrinsic::ID ID, unsigned srcAlign,
2364
2195
unsigned dstAlign, llvm::CallInst &MTI,
2196
+ Value *orig_dst, Value *orig_src, Value *new_size,
2365
2197
Value *isVolatile)
2366
2198
#endif
2367
2199
{
@@ -2375,10 +2207,6 @@ class AdjointGenerator
2375
2207
return ;
2376
2208
}
2377
2209
2378
- Value *orig_dst = MTI.getOperand (0 );
2379
- Value *orig_src = MTI.getOperand (1 );
2380
- Value *new_size = gutils->getNewFromOriginal (MTI.getOperand (2 ));
2381
-
2382
2210
// copying into nullptr is invalid (not sure why it exists here), but we
2383
2211
// shouldn't do it in reverse pass or shadow
2384
2212
if (isa<ConstantPointerNull>(orig_dst) ||
@@ -2529,8 +2357,17 @@ class AdjointGenerator
2529
2357
srcalign = 1 ;
2530
2358
}
2531
2359
}
2532
- subTransferHelper (dt.isFloat (), MTI.getParent (), ID, subdstalign,
2533
- subsrcalign, /* offset*/ start, orig_dst, orig_src,
2360
+ IRBuilder<> BuilderZ (gutils->getNewFromOriginal (&MTI));
2361
+ Value *shadow_dst = gutils->isConstantValue (orig_dst)
2362
+ ? gutils->getNewFromOriginal (orig_dst)
2363
+ : gutils->invertPointerM (orig_dst, BuilderZ);
2364
+ Value *shadow_src = gutils->isConstantValue (orig_src)
2365
+ ? gutils->getNewFromOriginal (orig_src)
2366
+ : gutils->invertPointerM (orig_src, BuilderZ);
2367
+ SubTransferHelper (gutils, Mode, dt.isFloat (), ID, subdstalign,
2368
+ subsrcalign, /* offset*/ start,
2369
+ gutils->isConstantValue (orig_dst), shadow_dst,
2370
+ gutils->isConstantValue (orig_src), shadow_src,
2534
2371
/* length*/ length, /* volatile*/ isVolatile, &MTI);
2535
2372
2536
2373
if (nextStart == size)
@@ -7741,10 +7578,14 @@ class AdjointGenerator
7741
7578
#if LLVM_VERSION_MAJOR >= 10
7742
7579
visitMemTransferCommon (ID, /* srcAlign*/ MaybeAlign (1 ),
7743
7580
/* dstAlign*/ MaybeAlign (1 ), *orig,
7581
+ orig->getArgOperand (0 ), orig->getArgOperand (1 ),
7582
+ gutils->getNewFromOriginal (orig->getArgOperand (2 )),
7744
7583
ConstantInt::getFalse (orig->getContext ()));
7745
7584
#else
7746
7585
visitMemTransferCommon (ID, /* srcAlign*/ 1 ,
7747
- /* dstAlign*/ 1 , *orig,
7586
+ /* dstAlign*/ 1 , *orig, orig->getArgOperand (0 ),
7587
+ orig->getArgOperand (1 ),
7588
+ gutils->getNewFromOriginal (orig->getArgOperand (2 )),
7748
7589
ConstantInt::getFalse (orig->getContext ()));
7749
7590
#endif
7750
7591
return ;
0 commit comments