1
1
use std:: borrow:: Cow ;
2
2
3
3
use bitflags:: bitflags;
4
- use ruff_python_ast:: { ExprConstant , Ranged } ;
4
+ use ruff_python_ast:: node:: AnyNodeRef ;
5
+ use ruff_python_ast:: { self as ast, ExprConstant , ExprJoinedStr , Ranged } ;
5
6
use ruff_python_parser:: lexer:: { lex_starts_at, LexicalError , LexicalErrorType } ;
6
7
use ruff_python_parser:: { Mode , Tok } ;
8
+ use ruff_source_file:: Locator ;
7
9
use ruff_text_size:: { TextLen , TextRange , TextSize } ;
8
10
9
11
use ruff_formatter:: { format_args, write, FormatError } ;
@@ -13,11 +15,62 @@ use crate::comments::{leading_comments, trailing_comments};
13
15
use crate :: expression:: parentheses:: {
14
16
in_parentheses_only_group, in_parentheses_only_soft_line_break_or_space,
15
17
} ;
18
+ use crate :: expression:: Expr ;
16
19
use crate :: prelude:: * ;
17
20
use crate :: QuoteStyle ;
18
21
22
+ #[ derive( Copy , Clone ) ]
23
+ enum Quoting {
24
+ CanChange ,
25
+ Preserve ,
26
+ }
27
+
28
+ pub ( super ) enum AnyString < ' a > {
29
+ Constant ( & ' a ExprConstant ) ,
30
+ JoinedStr ( & ' a ExprJoinedStr ) ,
31
+ }
32
+
33
+ impl < ' a > AnyString < ' a > {
34
+ fn quoting ( & self , locator : & Locator ) -> Quoting {
35
+ match self {
36
+ Self :: Constant ( _) => Quoting :: CanChange ,
37
+ Self :: JoinedStr ( joined_str) => {
38
+ if joined_str. values . iter ( ) . any ( |value| match value {
39
+ Expr :: FormattedValue ( ast:: ExprFormattedValue { range, .. } ) => {
40
+ let string_content = locator. slice ( * range) ;
41
+ string_content. contains ( [ '"' , '\'' ] )
42
+ }
43
+ _ => false ,
44
+ } ) {
45
+ Quoting :: Preserve
46
+ } else {
47
+ Quoting :: CanChange
48
+ }
49
+ }
50
+ }
51
+ }
52
+ }
53
+
54
+ impl Ranged for AnyString < ' _ > {
55
+ fn range ( & self ) -> TextRange {
56
+ match self {
57
+ Self :: Constant ( expr) => expr. range ( ) ,
58
+ Self :: JoinedStr ( expr) => expr. range ( ) ,
59
+ }
60
+ }
61
+ }
62
+
63
+ impl < ' a > From < & AnyString < ' a > > for AnyNodeRef < ' a > {
64
+ fn from ( value : & AnyString < ' a > ) -> Self {
65
+ match value {
66
+ AnyString :: Constant ( expr) => AnyNodeRef :: ExprConstant ( expr) ,
67
+ AnyString :: JoinedStr ( expr) => AnyNodeRef :: ExprJoinedStr ( expr) ,
68
+ }
69
+ }
70
+ }
71
+
19
72
pub ( super ) struct FormatString < ' a > {
20
- constant : & ' a ExprConstant ,
73
+ string : & ' a AnyString < ' a > ,
21
74
layout : StringLayout ,
22
75
}
23
76
@@ -30,10 +83,12 @@ pub enum StringLayout {
30
83
}
31
84
32
85
impl < ' a > FormatString < ' a > {
33
- pub ( super ) fn new ( constant : & ' a ExprConstant ) -> Self {
34
- debug_assert ! ( constant. value. is_str( ) || constant. value. is_bytes( ) ) ;
86
+ pub ( super ) fn new ( string : & ' a AnyString ) -> Self {
87
+ if let AnyString :: Constant ( constant) = string {
88
+ debug_assert ! ( constant. value. is_str( ) || constant. value. is_bytes( ) ) ;
89
+ }
35
90
Self {
36
- constant ,
91
+ string ,
37
92
layout : StringLayout :: Default ,
38
93
}
39
94
}
@@ -48,40 +103,43 @@ impl<'a> Format<PyFormatContext<'_>> for FormatString<'a> {
48
103
fn fmt ( & self , f : & mut Formatter < PyFormatContext < ' _ > > ) -> FormatResult < ( ) > {
49
104
match self . layout {
50
105
StringLayout :: Default => {
51
- let string_range = self . constant . range ( ) ;
106
+ let string_range = self . string . range ( ) ;
52
107
let string_content = f. context ( ) . locator ( ) . slice ( string_range) ;
53
108
54
109
if is_implicit_concatenation ( string_content) {
55
- in_parentheses_only_group ( & FormatStringContinuation :: new ( self . constant ) ) . fmt ( f)
110
+ in_parentheses_only_group ( & FormatStringContinuation :: new ( self . string ) ) . fmt ( f)
56
111
} else {
57
- FormatStringPart :: new ( string_range) . fmt ( f)
112
+ FormatStringPart :: new ( string_range, self . string . quoting ( & f. context ( ) . locator ( ) ) )
113
+ . fmt ( f)
58
114
}
59
115
}
60
116
StringLayout :: ImplicitConcatenatedBinaryLeftSide => {
61
- FormatStringContinuation :: new ( self . constant ) . fmt ( f)
117
+ FormatStringContinuation :: new ( self . string ) . fmt ( f)
62
118
}
63
119
}
64
120
}
65
121
}
66
122
67
123
struct FormatStringContinuation < ' a > {
68
- constant : & ' a ExprConstant ,
124
+ string : & ' a AnyString < ' a > ,
69
125
}
70
126
71
127
impl < ' a > FormatStringContinuation < ' a > {
72
- fn new ( constant : & ' a ExprConstant ) -> Self {
73
- debug_assert ! ( constant. value. is_str( ) || constant. value. is_bytes( ) ) ;
74
- Self { constant }
128
+ fn new ( string : & ' a AnyString < ' a > ) -> Self {
129
+ if let AnyString :: Constant ( constant) = string {
130
+ debug_assert ! ( constant. value. is_str( ) || constant. value. is_bytes( ) ) ;
131
+ }
132
+ Self { string }
75
133
}
76
134
}
77
135
78
136
impl Format < PyFormatContext < ' _ > > for FormatStringContinuation < ' _ > {
79
137
fn fmt ( & self , f : & mut Formatter < PyFormatContext < ' _ > > ) -> FormatResult < ( ) > {
80
138
let comments = f. context ( ) . comments ( ) . clone ( ) ;
81
139
let locator = f. context ( ) . locator ( ) ;
82
- let mut dangling_comments = comments. dangling_comments ( self . constant ) ;
140
+ let mut dangling_comments = comments. dangling_comments ( self . string ) ;
83
141
84
- let string_range = self . constant . range ( ) ;
142
+ let string_range = self . string . range ( ) ;
85
143
let string_content = locator. slice ( string_range) ;
86
144
87
145
// The AST parses implicit concatenation as a single string.
@@ -155,7 +213,7 @@ impl Format<PyFormatContext<'_>> for FormatStringContinuation<'_> {
155
213
joiner. entry ( & format_args ! [
156
214
line_suffix_boundary( ) ,
157
215
leading_comments( leading_part_comments) ,
158
- FormatStringPart :: new( token_range) ,
216
+ FormatStringPart :: new( token_range, self . string . quoting ( & locator ) ) ,
159
217
trailing_comments( trailing_part_comments)
160
218
] ) ;
161
219
@@ -178,11 +236,15 @@ impl Format<PyFormatContext<'_>> for FormatStringContinuation<'_> {
178
236
179
237
struct FormatStringPart {
180
238
part_range : TextRange ,
239
+ quoting : Quoting ,
181
240
}
182
241
183
242
impl FormatStringPart {
184
- const fn new ( range : TextRange ) -> Self {
185
- Self { part_range : range }
243
+ const fn new ( range : TextRange , quoting : Quoting ) -> Self {
244
+ Self {
245
+ part_range : range,
246
+ quoting,
247
+ }
186
248
}
187
249
}
188
250
@@ -204,10 +266,15 @@ impl Format<PyFormatContext<'_>> for FormatStringPart {
204
266
205
267
let raw_content = & string_content[ relative_raw_content_range] ;
206
268
let is_raw_string = prefix. is_raw_string ( ) ;
207
- let preferred_quotes = if is_raw_string {
208
- preferred_quotes_raw ( raw_content, quotes, f. options ( ) . quote_style ( ) )
209
- } else {
210
- preferred_quotes ( raw_content, quotes, f. options ( ) . quote_style ( ) )
269
+ let preferred_quotes = match self . quoting {
270
+ Quoting :: Preserve => quotes,
271
+ Quoting :: CanChange => {
272
+ if is_raw_string {
273
+ preferred_quotes_raw ( raw_content, quotes, f. options ( ) . quote_style ( ) )
274
+ } else {
275
+ preferred_quotes ( raw_content, quotes, f. options ( ) . quote_style ( ) )
276
+ }
277
+ }
211
278
} ;
212
279
213
280
write ! ( f, [ prefix, preferred_quotes] ) ?;
0 commit comments