@@ -9239,6 +9239,323 @@ class AdjointGenerator
9239
9239
return ;
9240
9240
}
9241
9241
9242
+ if (funcName == " __mulsc3" || funcName == " __muldc3" ||
9243
+ funcName == " __multc3" || funcName == " __mulxc3" ) {
9244
+ if (gutils->knownRecomputeHeuristic .find (orig) !=
9245
+ gutils->knownRecomputeHeuristic .end ()) {
9246
+ if (!gutils->knownRecomputeHeuristic [orig]) {
9247
+ gutils->cacheForReverse (BuilderZ, newCall,
9248
+ getIndex (orig, CacheType::Self));
9249
+ }
9250
+ }
9251
+
9252
+ eraseIfUnused (*orig);
9253
+ if (gutils->isConstantInstruction (orig))
9254
+ return ;
9255
+
9256
+ Value *orig_op0 = call.getOperand (0 );
9257
+ Value *orig_op1 = call.getOperand (1 );
9258
+ Value *orig_op2 = call.getOperand (2 );
9259
+ Value *orig_op3 = call.getOperand (3 );
9260
+
9261
+ bool constantval0 = gutils->isConstantValue (orig_op0);
9262
+ bool constantval1 = gutils->isConstantValue (orig_op1);
9263
+ bool constantval2 = gutils->isConstantValue (orig_op2);
9264
+ bool constantval3 = gutils->isConstantValue (orig_op3);
9265
+
9266
+ Value *prim[4 ] = {gutils->getNewFromOriginal (orig_op0),
9267
+ gutils->getNewFromOriginal (orig_op1),
9268
+ gutils->getNewFromOriginal (orig_op2),
9269
+ gutils->getNewFromOriginal (orig_op3)};
9270
+
9271
+ auto mul = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9272
+ funcName, called->getFunctionType (), called->getAttributes ());
9273
+
9274
+ switch (Mode) {
9275
+ case DerivativeMode::ForwardMode:
9276
+ case DerivativeMode::ForwardModeSplit: {
9277
+ IRBuilder<> Builder2 (&call);
9278
+ getForwardBuilder (Builder2);
9279
+
9280
+ Value *diff[4 ] = {
9281
+ constantval0 ? Constant::getNullValue (orig_op0->getType ())
9282
+ : diffe (orig_op0, Builder2),
9283
+ constantval1 ? Constant::getNullValue (orig_op1->getType ())
9284
+ : diffe (orig_op1, Builder2),
9285
+ constantval2 ? Constant::getNullValue (orig_op2->getType ())
9286
+ : diffe (orig_op2, Builder2),
9287
+ constantval3 ? Constant::getNullValue (orig_op3->getType ())
9288
+ : diffe (orig_op3, Builder2)};
9289
+
9290
+ auto cal1 =
9291
+ Builder2.CreateCall (mul, {diff[0 ], diff[1 ], prim[2 ], prim[3 ]});
9292
+ auto cal2 =
9293
+ Builder2.CreateCall (mul, {prim[0 ], prim[1 ], diff[2 ], diff[3 ]});
9294
+
9295
+ Value *resReal =
9296
+ Builder2.CreateFAdd (Builder2.CreateExtractValue (cal1, {0 }),
9297
+ Builder2.CreateExtractValue (cal2, {0 }));
9298
+ Value *resImag =
9299
+ Builder2.CreateFAdd (Builder2.CreateExtractValue (cal1, {1 }),
9300
+ Builder2.CreateExtractValue (cal2, {1 }));
9301
+
9302
+ Value *res = Builder2.CreateInsertValue (
9303
+ UndefValue::get (call.getType ()), resReal, {0 });
9304
+ res = Builder2.CreateInsertValue (res, resImag, {1 });
9305
+
9306
+ setDiffe (&call, res, Builder2);
9307
+ return ;
9308
+ }
9309
+ case DerivativeMode::ReverseModeGradient:
9310
+ case DerivativeMode::ReverseModeCombined: {
9311
+ IRBuilder<> Builder2 (call.getParent ());
9312
+ getReverseBuilder (Builder2);
9313
+
9314
+ Value *idiff = diffe (&call, Builder2);
9315
+ Value *idiffReal = Builder2.CreateExtractValue (idiff, {0 });
9316
+ Value *idiffImag = Builder2.CreateExtractValue (idiff, {1 });
9317
+
9318
+ Value *diff0 = nullptr ;
9319
+ Value *diff1 = nullptr ;
9320
+
9321
+ if (!constantval0 || !constantval1)
9322
+ diff0 = Builder2.CreateCall (mul, {idiffReal, idiffImag,
9323
+ lookup (prim[2 ], Builder2),
9324
+ lookup (prim[3 ], Builder2)});
9325
+
9326
+ if (!constantval2 || !constantval3)
9327
+ diff1 = Builder2.CreateCall (mul, {lookup (prim[0 ], Builder2),
9328
+ lookup (prim[1 ], Builder2),
9329
+ idiffReal, idiffImag});
9330
+
9331
+ if (diff0 || diff1)
9332
+ setDiffe (&call, Constant::getNullValue (call.getType ()), Builder2);
9333
+
9334
+ if (diff0) {
9335
+ addToDiffe (orig_op0, Builder2.CreateExtractValue (diff0, {0 }),
9336
+ Builder2, orig_op0->getType ());
9337
+ addToDiffe (orig_op1, Builder2.CreateExtractValue (diff0, {1 }),
9338
+ Builder2, orig_op1->getType ());
9339
+ }
9340
+
9341
+ if (diff1) {
9342
+ addToDiffe (orig_op2, Builder2.CreateExtractValue (diff1, {0 }),
9343
+ Builder2, orig_op2->getType ());
9344
+ addToDiffe (orig_op3, Builder2.CreateExtractValue (diff1, {1 }),
9345
+ Builder2, orig_op3->getType ());
9346
+ }
9347
+
9348
+ return ;
9349
+ }
9350
+ case DerivativeMode::ReverseModePrimal:
9351
+ return ;
9352
+ }
9353
+ }
9354
+
9355
+ if (funcName == " __divsc3" || funcName == " __divdc3" ||
9356
+ funcName == " __divtc3" || funcName == " __divxc3" ) {
9357
+ if (gutils->knownRecomputeHeuristic .find (orig) !=
9358
+ gutils->knownRecomputeHeuristic .end ()) {
9359
+ if (!gutils->knownRecomputeHeuristic [orig]) {
9360
+ gutils->cacheForReverse (BuilderZ, newCall,
9361
+ getIndex (orig, CacheType::Self));
9362
+ }
9363
+ }
9364
+
9365
+ if (gutils->isConstantInstruction (orig))
9366
+ return ;
9367
+
9368
+ StringMap<StringRef> map = {
9369
+ {" __divsc3" , " __mulsc3" },
9370
+ {" __divdc3" , " __muldc3" },
9371
+ {" __divtc3" , " __multc3" },
9372
+ {" __divxc3" , " __mulxc3" },
9373
+ };
9374
+
9375
+ auto mul = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9376
+ map[funcName], called->getFunctionType (), called->getAttributes ());
9377
+
9378
+ auto div = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9379
+ funcName, called->getFunctionType (), called->getAttributes ());
9380
+
9381
+ Value *orig_op0 = call.getOperand (0 );
9382
+ Value *orig_op1 = call.getOperand (1 );
9383
+ Value *orig_op2 = call.getOperand (2 );
9384
+ Value *orig_op3 = call.getOperand (3 );
9385
+
9386
+ bool constantval0 = gutils->isConstantValue (orig_op0);
9387
+ bool constantval1 = gutils->isConstantValue (orig_op1);
9388
+ bool constantval2 = gutils->isConstantValue (orig_op2);
9389
+ bool constantval3 = gutils->isConstantValue (orig_op3);
9390
+
9391
+ Value *prim[4 ] = {gutils->getNewFromOriginal (orig_op0),
9392
+ gutils->getNewFromOriginal (orig_op1),
9393
+ gutils->getNewFromOriginal (orig_op2),
9394
+ gutils->getNewFromOriginal (orig_op3)};
9395
+
9396
+ switch (Mode) {
9397
+ case DerivativeMode::ForwardMode:
9398
+ case DerivativeMode::ForwardModeSplit: {
9399
+ IRBuilder<> Builder2 (&call);
9400
+ getForwardBuilder (Builder2);
9401
+
9402
+ Value *diff[4 ] = {
9403
+ constantval0 ? Constant::getNullValue (orig_op0->getType ())
9404
+ : diffe (orig_op0, Builder2),
9405
+ constantval1 ? Constant::getNullValue (orig_op1->getType ())
9406
+ : diffe (orig_op1, Builder2),
9407
+ constantval2 ? Constant::getNullValue (orig_op2->getType ())
9408
+ : diffe (orig_op2, Builder2),
9409
+ constantval3 ? Constant::getNullValue (orig_op3->getType ())
9410
+ : diffe (orig_op3, Builder2)};
9411
+
9412
+ auto mul1 =
9413
+ Builder2.CreateCall (mul, {diff[0 ], diff[1 ], prim[2 ], prim[3 ]});
9414
+ auto mul2 =
9415
+ Builder2.CreateCall (mul, {prim[0 ], prim[1 ], diff[2 ], diff[3 ]});
9416
+ auto sq1 =
9417
+ Builder2.CreateCall (mul, {prim[2 ], prim[3 ], prim[2 ], prim[3 ]});
9418
+
9419
+ Value *subReal =
9420
+ Builder2.CreateFSub (Builder2.CreateExtractValue (mul1, {0 }),
9421
+ Builder2.CreateExtractValue (mul2, {0 }));
9422
+ Value *subImag =
9423
+ Builder2.CreateFSub (Builder2.CreateExtractValue (mul1, {1 }),
9424
+ Builder2.CreateExtractValue (mul2, {1 }));
9425
+
9426
+ auto div1 = Builder2.CreateCall (
9427
+ div, {subReal, subImag, Builder2.CreateExtractValue (sq1, {0 }),
9428
+ Builder2.CreateExtractValue (sq1, {1 })});
9429
+
9430
+ setDiffe (&call, div1, Builder2);
9431
+
9432
+ eraseIfUnused (*orig);
9433
+
9434
+ return ;
9435
+ }
9436
+ case DerivativeMode::ReverseModeGradient:
9437
+ case DerivativeMode::ReverseModeCombined: {
9438
+ IRBuilder<> Builder2 (call.getParent ());
9439
+ getReverseBuilder (Builder2);
9440
+
9441
+ Value *idiff = diffe (&call, Builder2);
9442
+ Value *idiffReal = Builder2.CreateExtractValue (idiff, {0 });
9443
+ Value *idiffImag = Builder2.CreateExtractValue (idiff, {1 });
9444
+
9445
+ Value *diff0 = nullptr ;
9446
+ Value *diff1 = nullptr ;
9447
+
9448
+ if (!constantval0 || !constantval1)
9449
+ diff0 = Builder2.CreateCall (div, {idiffReal, idiffImag,
9450
+ lookup (prim[2 ], Builder2),
9451
+ lookup (prim[3 ], Builder2)});
9452
+
9453
+ if (!constantval2 || !constantval3) {
9454
+ auto fdiv = Builder2.CreateCall (div, {idiffReal, idiffImag,
9455
+ lookup (prim[1 ], Builder2),
9456
+ lookup (prim[2 ], Builder2)});
9457
+
9458
+ Value *newcall = gutils->getNewFromOriginal (&call);
9459
+
9460
+ diff1 = Builder2.CreateCall (
9461
+ mul,
9462
+ {Builder2.CreateFNeg (Builder2.CreateExtractValue (newcall, {0 })),
9463
+ Builder2.CreateFNeg (Builder2.CreateExtractValue (newcall, {1 })),
9464
+ Builder2.CreateExtractValue (fdiv, {0 }),
9465
+ Builder2.CreateExtractValue (fdiv, {1 })});
9466
+ }
9467
+
9468
+ if (diff0 || diff1)
9469
+ setDiffe (&call, Constant::getNullValue (call.getType ()), Builder2);
9470
+
9471
+ if (diff0) {
9472
+ addToDiffe (orig_op0, Builder2.CreateExtractValue (diff0, {0 }),
9473
+ Builder2, orig_op0->getType ());
9474
+ addToDiffe (orig_op1, Builder2.CreateExtractValue (diff0, {1 }),
9475
+ Builder2, orig_op1->getType ());
9476
+ }
9477
+
9478
+ if (diff1) {
9479
+ addToDiffe (orig_op2, Builder2.CreateExtractValue (diff1, {0 }),
9480
+ Builder2, orig_op2->getType ());
9481
+ addToDiffe (orig_op3, Builder2.CreateExtractValue (diff1, {1 }),
9482
+ Builder2, orig_op3->getType ());
9483
+ }
9484
+
9485
+ if (constantval2 && constantval3)
9486
+ eraseIfUnused (*orig);
9487
+
9488
+ return ;
9489
+ }
9490
+ case DerivativeMode::ReverseModePrimal:;
9491
+ return ;
9492
+ }
9493
+ }
9494
+
9495
+ if (funcName == " scalbn" || funcName == " scalbnf" ||
9496
+ funcName == " scalbnl" || funcName == " scalbln" ||
9497
+ funcName == " scalblnf" || funcName == " scalblnl" ) {
9498
+ eraseIfUnused (*orig);
9499
+
9500
+ Value *orig_op0 = call.getOperand (0 );
9501
+ Value *orig_op1 = call.getOperand (1 );
9502
+
9503
+ bool constantval0 = gutils->isConstantValue (orig_op0);
9504
+
9505
+ if (gutils->isConstantInstruction (orig) || constantval0)
9506
+ return ;
9507
+
9508
+ Value *op0 = gutils->getNewFromOriginal (orig_op0);
9509
+ Value *op1 = gutils->getNewFromOriginal (orig_op1);
9510
+
9511
+ auto scal = gutils->oldFunc ->getParent ()->getOrInsertFunction (
9512
+ funcName, called->getFunctionType (), called->getAttributes ());
9513
+
9514
+ switch (Mode) {
9515
+ case DerivativeMode::ForwardMode:
9516
+ case DerivativeMode::ForwardModeSplit: {
9517
+ IRBuilder<> Builder2 (&call);
9518
+ getForwardBuilder (Builder2);
9519
+
9520
+ Value *diff0 = diffe (orig_op0, Builder2);
9521
+
9522
+ auto cal1 = Builder2.CreateCall (scal, {op0, op1});
9523
+ auto cal2 = Builder2.CreateCall (scal, {diff0, op1});
9524
+
9525
+ Value *diff = Builder2.CreateFMul (
9526
+ cal1, ConstantFP::get (call.getType (), 0.3010299957 ));
9527
+ diff = Builder2.CreateFAdd (diff, cal2);
9528
+
9529
+ setDiffe (&call, diff, Builder2);
9530
+ return ;
9531
+ }
9532
+ case DerivativeMode::ReverseModeGradient:
9533
+ case DerivativeMode::ReverseModeCombined: {
9534
+ IRBuilder<> Builder2 (call.getParent ());
9535
+ getReverseBuilder (Builder2);
9536
+
9537
+ Value *idiff = diffe (&call, Builder2);
9538
+
9539
+ if (idiff && !constantval0) {
9540
+ op1 = lookup (op1, Builder2);
9541
+
9542
+ auto cal1 = Builder2.CreateCall (scal, {op0, op1});
9543
+ auto cal2 = Builder2.CreateCall (scal, {idiff, op1});
9544
+
9545
+ Value *diff = Builder2.CreateFMul (
9546
+ cal1, ConstantFP::get (call.getType (), 0.3010299957 ));
9547
+ diff = Builder2.CreateFAdd (diff, cal2);
9548
+
9549
+ addToDiffe (orig_op0, diff, Builder2, call.getType ());
9550
+ }
9551
+
9552
+ return ;
9553
+ }
9554
+ case DerivativeMode::ReverseModePrimal:;
9555
+ return ;
9556
+ }
9557
+ }
9558
+
9242
9559
if (called) {
9243
9560
if (funcName == " erf" || funcName == " erfi" || funcName == " erfc" ||
9244
9561
funcName == " Faddeeva_erf" || funcName == " Faddeeva_erfi" ||
0 commit comments