Skip to content

(WIP) Implement Generics API #7424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 26 commits into from
May 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
f3ff534
Implement Generics API
jinzhu Apr 17, 2025
ded38c5
Merge branch 'master' into generics
jinzhu Apr 17, 2025
1bd31b0
Merge branch 'master' into generics
jinzhu Apr 17, 2025
0fbe4f6
Add more generics tests
jinzhu Apr 17, 2025
ba27874
Add more tests and Take method
jinzhu Apr 17, 2025
2d6d7f9
use delayed‑ops pipeline for generics API
jinzhu Apr 18, 2025
3de6d0b
fix generics tests for mysql
jinzhu Apr 18, 2025
797a557
Support SubQuery for Generics
jinzhu Apr 20, 2025
4fcd909
Add clause.JoinTable helper method
jinzhu Apr 20, 2025
7095605
Fix golangci-lint error
jinzhu Apr 21, 2025
05925b2
Complete the design and implementation of generic version Join
jinzhu May 9, 2025
d073805
improve generics version Joins support
jinzhu May 20, 2025
ba94e4e
allow configuring select/omit columns for joins via subqueries
jinzhu May 20, 2025
4694673
finish generic version Preload
jinzhu May 20, 2025
e330694
handle error of generics Joins/Preload
jinzhu May 20, 2025
9b1ce2b
fix tests
jinzhu May 20, 2025
91eb947
Merge branch 'master' into generics
jinzhu May 20, 2025
6307f69
Add LimitPerRecord for generic version Preload
jinzhu May 21, 2025
304baab
fix tests for mysql 5.7
jinzhu May 21, 2025
774d957
test for nested generic version Join/Preload
jinzhu May 22, 2025
ddaee81
Add WithResult support for generics API
jinzhu May 22, 2025
8ced549
test reuse generics db conditions
jinzhu May 22, 2025
0305e0d
fix data race
jinzhu May 22, 2025
4ee59e1
remove ExampleLRU test
jinzhu May 22, 2025
4db3fde
Add default transaction timeout support
jinzhu May 23, 2025
4f189a7
fix test
jinzhu May 25, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions callbacks/create.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ func Create(config *Config) func(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, mode)

if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}

return
Expand All @@ -103,6 +107,12 @@ func Create(config *Config) func(db *gorm.DB) {
}

db.RowsAffected, _ = result.RowsAffected()

if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}

if db.RowsAffected == 0 {
return
}
Expand Down
10 changes: 10 additions & 0 deletions callbacks/delete.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,15 +157,25 @@ func Delete(config *Config) func(db *gorm.DB) {
ok, mode := hasReturning(db, supportReturning)
if !ok {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()

if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}

return
}

if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil {
gorm.Scan(rows, db, mode)

if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
db.AddError(rows.Close())
}
}
Expand Down
8 changes: 7 additions & 1 deletion callbacks/preload.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues)

if len(values) != 0 {
tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values})

for _, cond := range conds {
if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok {
tx = fc(tx)
Expand All @@ -283,7 +285,11 @@ func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preload
}
}

if err := tx.Where(clause.IN{Column: column, Values: values}).Find(reflectResults.Addr().Interface(), inlineConds...).Error; err != nil {
if len(inlineConds) > 0 {
tx = tx.Where(inlineConds[0], inlineConds[1:]...)
}

if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil {
return err
}
}
Expand Down
49 changes: 30 additions & 19 deletions callbacks/query.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ func Query(db *gorm.DB) {
db.AddError(rows.Close())
}()
gorm.Scan(rows, db, 0)

if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
Expand Down Expand Up @@ -110,7 +114,7 @@ func BuildQuerySQL(db *gorm.DB) {
}
}

specifiedRelationsName := make(map[string]interface{})
specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable}
for _, join := range db.Statement.Joins {
if db.Statement.Schema != nil {
var isRelations bool // is relations or raw sql
Expand All @@ -124,12 +128,12 @@ func BuildQuerySQL(db *gorm.DB) {
nestedJoinNames := strings.Split(join.Name, ".")
if len(nestedJoinNames) > 1 {
isNestedJoin := true
gussNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames))
currentRelations := db.Statement.Schema.Relationships.Relations
for _, relname := range nestedJoinNames {
// incomplete match, only treated as raw sql
if relation, ok = currentRelations[relname]; ok {
gussNestedRelations = append(gussNestedRelations, relation)
guessNestedRelations = append(guessNestedRelations, relation)
currentRelations = relation.FieldSchema.Relationships.Relations
} else {
isNestedJoin = false
Expand All @@ -139,18 +143,13 @@ func BuildQuerySQL(db *gorm.DB) {

if isNestedJoin {
isRelations = true
relations = gussNestedRelations
relations = guessNestedRelations
}
}
}

if isRelations {
genJoinClause := func(joinType clause.JoinType, parentTableName string, relation *schema.Relationship) clause.Join {
tableAliasName := relation.Name
if parentTableName != clause.CurrentTable {
tableAliasName = utils.NestedRelationName(parentTableName, tableAliasName)
}

genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join {
columnStmt := gorm.Statement{
Table: tableAliasName, DB: db, Schema: relation.FieldSchema,
Selects: join.Selects, Omits: join.Omits,
Expand All @@ -167,6 +166,13 @@ func BuildQuerySQL(db *gorm.DB) {
}
}

if join.Expression != nil {
return clause.Join{
Type: join.JoinType,
Expression: join.Expression,
}
}

exprs := make([]clause.Expression, len(relation.References))
for idx, ref := range relation.References {
if ref.OwnPrimaryKey {
Expand Down Expand Up @@ -226,19 +232,24 @@ func BuildQuerySQL(db *gorm.DB) {
}

parentTableName := clause.CurrentTable
for _, rel := range relations {
for idx, rel := range relations {
// joins table alias like "Manager, Company, Manager__Company"
nestedAlias := utils.NestedRelationName(parentTableName, rel.Name)
if _, ok := specifiedRelationsName[nestedAlias]; !ok {
fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, parentTableName, rel))
specifiedRelationsName[nestedAlias] = nil
curAliasName := rel.Name
if parentTableName != clause.CurrentTable {
curAliasName = utils.NestedRelationName(parentTableName, curAliasName)
}

if parentTableName != clause.CurrentTable {
parentTableName = utils.NestedRelationName(parentTableName, rel.Name)
} else {
parentTableName = rel.Name
if _, ok := specifiedRelationsName[curAliasName]; !ok {
aliasName := curAliasName
if idx == len(relations)-1 && join.Alias != "" {
aliasName = join.Alias
}

fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel))
specifiedRelationsName[curAliasName] = aliasName
}

parentTableName = curAliasName
}
} else {
fromClause.Joins = append(fromClause.Joins, clause.Join{
Expand Down
5 changes: 5 additions & 0 deletions callbacks/raw.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,10 @@ func RawExec(db *gorm.DB) {
}

db.RowsAffected, _ = result.RowsAffected()

if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
9 changes: 9 additions & 0 deletions callbacks/update.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,22 @@ func Update(config *Config) func(db *gorm.DB) {
gorm.Scan(rows, db, mode)
db.Statement.Dest = dest
db.AddError(rows.Close())

if db.Statement.Result != nil {
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
} else {
result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...)

if db.AddError(err) == nil {
db.RowsAffected, _ = result.RowsAffected()
}

if db.Statement.Result != nil {
db.Statement.Result.Result = result
db.Statement.Result.RowsAffected = db.RowsAffected
}
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions chainable_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,9 +448,10 @@ func (db *DB) Assign(attrs ...interface{}) (tx *DB) {
// Unscoped allows queries to include records marked as deleted,
// overriding the soft deletion behavior.
// Example:
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
//
// var users []User
// db.Unscoped().Find(&users)
// // Retrieves all users, including deleted ones.
func (db *DB) Unscoped() (tx *DB) {
tx = db.getInstance()
tx.Statement.Unscoped = true
Expand Down
32 changes: 32 additions & 0 deletions clause/joins.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
package clause

import "gorm.io/gorm/utils"

type JoinType string

const (
Expand All @@ -9,6 +11,30 @@ const (
RightJoin JoinType = "RIGHT"
)

type JoinTarget struct {
Type JoinType
Association string
Subquery Expression
Table string
}

func Has(name string) JoinTarget {
return JoinTarget{Type: InnerJoin, Association: name}
}

func (jt JoinType) Association(name string) JoinTarget {
return JoinTarget{Type: jt, Association: name}
}

func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget {
return JoinTarget{Type: jt, Association: name, Subquery: subquery}
}

func (jt JoinTarget) As(name string) JoinTarget {
jt.Table = name
return jt
}

// Join clause for from
type Join struct {
Type JoinType
Expand All @@ -18,6 +44,12 @@ type Join struct {
Expression Expression
}

func JoinTable(names ...string) Table {
return Table{
Name: utils.JoinNestedRelationNames(names),
}
}

func (join Join) Build(builder Builder) {
if join.Expression != nil {
join.Expression.Build(builder)
Expand Down
12 changes: 10 additions & 2 deletions finisher_api.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorm

import (
"context"
"database/sql"
"errors"
"fmt"
Expand Down Expand Up @@ -673,11 +674,18 @@
opt = opts[0]
}

ctx := tx.Statement.Context
if _, ok := ctx.Deadline(); !ok {
if db.Config.DefaultTransactionTimeout > 0 {
ctx, _ = context.WithTimeout(ctx, db.Config.DefaultTransactionTimeout)

Check failure on line 680 in finisher_api.go

View workflow job for this annotation

GitHub Actions / lint

lostcancel: the cancel function returned by context.WithTimeout should be called, not discarded, to avoid a context leak (govet)
}
}

switch beginner := tx.Statement.ConnPool.(type) {
case TxBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
case ConnPoolBeginner:
tx.Statement.ConnPool, err = beginner.BeginTx(tx.Statement.Context, opt)
tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt)
default:
err = ErrInvalidTransaction
}
Expand Down
Loading
Loading