5
5
BaseDatabaseSchemaEditor , logger , _is_relevant_relation , _related_non_m2m_objects ,
6
6
)
7
7
from django .db .backends .ddl_references import (
8
- Statement ,
8
+ Columns , IndexName , Statement as DjStatement , Table ,
9
9
)
10
10
from django .db .models import Index
11
11
from django .db .models .fields import AutoField , BigAutoField
12
12
from django .db .transaction import TransactionManagementError
13
13
from django .utils .encoding import force_text
14
14
15
15
16
+ class Statement (DjStatement ):
17
+ def __hash__ (self ):
18
+ return hash ((self .template , str (self .parts ['name' ])))
19
+
20
+ def __eq__ (self , other ):
21
+ return self .template == other .template and str (self .parts ['name' ]) == str (other .parts ['name' ])
22
+
23
+
16
24
class DatabaseSchemaEditor (BaseDatabaseSchemaEditor ):
17
25
18
26
_sql_check_constraint = " CONSTRAINT %(name)s CHECK (%(check)s)"
@@ -174,16 +182,51 @@ def _alter_many_to_many(self, model, old_field, new_field, strict):
174
182
175
183
return super ()._alter_many_to_many (model , old_field , new_field , strict )
176
184
185
+ def _db_table_constraint_names (self , db_table , column_names = None , unique = None ,
186
+ primary_key = None , index = None , foreign_key = None ,
187
+ check = None , type_ = None , exclude = None ):
188
+ """Return all constraint names matching the columns and conditions."""
189
+ if column_names is not None :
190
+ column_names = [
191
+ self .connection .introspection .identifier_converter (name )
192
+ for name in column_names
193
+ ]
194
+ with self .connection .cursor () as cursor :
195
+ constraints = self .connection .introspection .get_constraints (cursor , db_table )
196
+ result = []
197
+ for name , infodict in constraints .items ():
198
+ if column_names is None or column_names == infodict ['columns' ]:
199
+ if unique is not None and infodict ['unique' ] != unique :
200
+ continue
201
+ if primary_key is not None and infodict ['primary_key' ] != primary_key :
202
+ continue
203
+ if index is not None and infodict ['index' ] != index :
204
+ continue
205
+ if check is not None and infodict ['check' ] != check :
206
+ continue
207
+ if foreign_key is not None and not infodict ['foreign_key' ]:
208
+ continue
209
+ if type_ is not None and infodict ['type' ] != type_ :
210
+ continue
211
+ if not exclude or name not in exclude :
212
+ result .append (name )
213
+ return result
214
+
215
+ def _db_table_delete_constraint_sql (self , template , db_table , name ):
216
+ return Statement (
217
+ template ,
218
+ table = Table (db_table , self .quote_name ),
219
+ name = self .quote_name (name ),
220
+ )
221
+
177
222
def alter_db_table (self , model , old_db_table , new_db_table ):
178
- index_names = self ._constraint_names ( model , index = True )
223
+ index_names = self ._db_table_constraint_names ( old_db_table , index = True )
179
224
for index_name in index_names :
180
- self .execute (self ._delete_constraint_sql (self .sql_delete_index , model , index_name ))
225
+ self .execute (self ._db_table_delete_constraint_sql (self .sql_delete_index , old_db_table , index_name ))
181
226
182
- model ._meta .db_table = old_db_table
183
- index_names = self ._constraint_names (model , index = True )
227
+ index_names = self ._db_table_constraint_names (new_db_table , index = True )
184
228
for index_name in index_names :
185
- self .execute (self ._delete_constraint_sql (self .sql_delete_index , model , index_name ))
186
- model ._meta .db_table = new_db_table
229
+ self .execute (self ._db_table_delete_constraint_sql (self .sql_delete_index , new_db_table , index_name ))
187
230
188
231
return super ().alter_db_table (model , old_db_table , new_db_table )
189
232
@@ -627,6 +670,61 @@ def add_field(self, model, field):
627
670
if self .connection .features .connection_persists_old_columns :
628
671
self .connection .close ()
629
672
673
+ def _create_unique_sql (self , model , columns , name = None , condition = None ):
674
+ def create_unique_name (* args , ** kwargs ):
675
+ return self .quote_name (self ._create_index_name (* args , ** kwargs ))
676
+
677
+ table = Table (model ._meta .db_table , self .quote_name )
678
+ if name is None :
679
+ name = IndexName (model ._meta .db_table , columns , '_uniq' , create_unique_name )
680
+ else :
681
+ name = self .quote_name (name )
682
+ columns = Columns (table , columns , self .quote_name )
683
+ if condition :
684
+ return Statement (
685
+ self .sql_create_unique_index ,
686
+ table = table ,
687
+ name = name ,
688
+ columns = columns ,
689
+ condition = ' WHERE ' + condition ,
690
+ ) if self .connection .features .supports_partial_indexes else None
691
+ else :
692
+ return Statement (
693
+ self .sql_create_unique ,
694
+ table = table ,
695
+ name = name ,
696
+ columns = columns ,
697
+ )
698
+
699
+ def _create_index_sql (self , model , fields , * , name = None , suffix = '' , using = '' ,
700
+ db_tablespace = None , col_suffixes = (), sql = None , opclasses = (),
701
+ condition = None ):
702
+ """
703
+ Return the SQL statement to create the index for one or several fields.
704
+ `sql` can be specified if the syntax differs from the standard (GIS
705
+ indexes, ...).
706
+ """
707
+ tablespace_sql = self ._get_index_tablespace_sql (model , fields , db_tablespace = db_tablespace )
708
+ columns = [field .column for field in fields ]
709
+ sql_create_index = sql or self .sql_create_index
710
+ table = model ._meta .db_table
711
+
712
+ def create_index_name (* args , ** kwargs ):
713
+ nonlocal name
714
+ if name is None :
715
+ name = self ._create_index_name (* args , ** kwargs )
716
+ return self .quote_name (name )
717
+
718
+ return Statement (
719
+ sql_create_index ,
720
+ table = Table (table , self .quote_name ),
721
+ name = IndexName (table , columns , suffix , create_index_name ),
722
+ using = using ,
723
+ columns = self ._index_columns (table , columns , col_suffixes , opclasses ),
724
+ extra = tablespace_sql ,
725
+ condition = (' WHERE ' + condition ) if condition else '' ,
726
+ )
727
+
630
728
def create_model (self , model ):
631
729
"""
632
730
Takes a model and creates a table for it in the database.
@@ -684,6 +782,13 @@ def create_model(self, model):
684
782
if autoinc_sql :
685
783
self .deferred_sql .extend (autoinc_sql )
686
784
785
+ # Add any unique_togethers (always deferred, as some fields might be
786
+ # created afterwards, like geometry fields with some backends)
787
+ for fields in model ._meta .unique_together :
788
+ columns = [model ._meta .get_field (field ).column for field in fields ]
789
+ condition = ' AND ' .join (["[%s] IS NOT NULL" % col for col in columns ])
790
+ self .deferred_sql .append (self ._create_unique_sql (model , columns , condition = condition ))
791
+
687
792
# Make the table
688
793
sql = self .sql_create_table % {
689
794
"table" : self .quote_name (model ._meta .db_table ),
@@ -698,6 +803,7 @@ def create_model(self, model):
698
803
699
804
# Add any field index and index_together's (deferred as SQLite3 _remake_table needs it)
700
805
self .deferred_sql .extend (self ._model_indexes_sql (model ))
806
+ self .deferred_sql = list (set (self .deferred_sql ))
701
807
702
808
# Make M2M tables
703
809
for field in model ._meta .local_many_to_many :
0 commit comments