Skip to content

Commit 5bd1108

Browse files
Merge pull request #5 from enveritas/fix/escape-constants
Fix/escape constants
2 parents 440baa6 + 1967bfa commit 5bd1108

File tree

2 files changed

+70
-23
lines changed

2 files changed

+70
-23
lines changed

internal/poet/reserved.go

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,25 +4,37 @@ import "slices"
44

55
// TODO(quentin@escape.tech): check if this is complete
66
var reservedKeywords = []string{
7-
"class",
8-
"if",
9-
"else",
10-
"elif",
11-
"not",
12-
"for",
137
"and",
14-
"in",
15-
"is",
16-
"or",
17-
"with",
188
"as",
199
"assert",
10+
"async",
11+
"await",
2012
"break",
13+
"class",
14+
"continue",
15+
"def",
16+
"del",
17+
"elif",
18+
"else",
2119
"except",
2220
"finally",
23-
"try",
21+
"for",
22+
"from",
23+
"global",
24+
"if",
25+
"import",
26+
"in",
27+
"is",
28+
"lambda",
29+
"nonlocal",
30+
"not",
31+
"or",
32+
"pass",
2433
"raise",
2534
"return",
35+
"try",
36+
"while",
37+
"with",
2638
"yield",
2739
}
2840

internal/printer/printer.go

Lines changed: 47 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,22 @@ func (w *writer) printIndent(indent int32) {
3737
}
3838
}
3939

40+
func (w *writer) printCommentText(text string, indent int32) {
41+
lines := strings.Split(text, "\n")
42+
for _, line := range lines {
43+
w.print("#")
44+
// trim right space which is usually unintended,
45+
// but leave left space untouched in case if it's intentionally formatted.
46+
trimmed := strings.TrimRight(line, " ")
47+
if trimmed != "" {
48+
w.print(" ")
49+
w.print(trimmed)
50+
}
51+
w.print("\n")
52+
w.printIndent(indent)
53+
}
54+
}
55+
4056
func (w *writer) printNode(node *ast.Node, indent int32) {
4157
switch n := node.Node.(type) {
4258

@@ -132,10 +148,7 @@ func (w *writer) printNode(node *ast.Node, indent int32) {
132148

133149
func (w *writer) printAnnAssign(aa *ast.AnnAssign, indent int32) {
134150
if aa.Comment != "" {
135-
w.print("# ")
136-
w.print(aa.Comment)
137-
w.print("\n")
138-
w.printIndent(indent)
151+
w.printCommentText(aa.Comment, indent)
139152
}
140153
w.printName(aa.Target, indent)
141154
w.print(": ")
@@ -255,10 +268,7 @@ func (w *writer) printClassDef(cd *ast.ClassDef, indent int32) {
255268
if i == 0 {
256269
if e, ok := node.Node.(*ast.Node_Expr); ok {
257270
if c, ok := e.Expr.Value.Node.(*ast.Node_Constant); ok {
258-
w.print(`""`)
259-
w.printConstant(c.Constant, indent)
260-
w.print(`""`)
261-
w.print("\n")
271+
w.printDocString(c.Constant, indent)
262272
continue
263273
}
264274
}
@@ -268,6 +278,33 @@ func (w *writer) printClassDef(cd *ast.ClassDef, indent int32) {
268278
}
269279
}
270280

281+
func (w *writer) printDocString(c *ast.Constant, indent int32) {
282+
switch n := c.Value.(type) {
283+
case *ast.Constant_Str:
284+
w.print(`"""`)
285+
lines := strings.Split(n.Str, "\n")
286+
printedN := 0
287+
for n, line := range lines {
288+
// trim right space which is usually unintended,
289+
// but leave left space untouched in case if it's intentionally formatted.
290+
trimmed := strings.TrimRight(line, " ")
291+
if trimmed == "" {
292+
continue
293+
}
294+
if printedN > 0 && n < len(lines)-1 {
295+
w.print("\n")
296+
w.printIndent(indent)
297+
}
298+
w.print(strings.ReplaceAll(trimmed, `"`, `\"`))
299+
printedN++
300+
}
301+
w.print(`"""`)
302+
w.print("\n")
303+
default:
304+
panic(n)
305+
}
306+
}
307+
271308
func (w *writer) printConstant(c *ast.Constant, indent int32) {
272309
switch n := c.Value.(type) {
273310
case *ast.Constant_Int:
@@ -282,7 +319,7 @@ func (w *writer) printConstant(c *ast.Constant, indent int32) {
282319
str = `"""`
283320
}
284321
w.print(str)
285-
w.print(n.Str)
322+
w.print(strings.ReplaceAll(n.Str, `"`, `\"`))
286323
w.print(str)
287324

288325
default:
@@ -291,9 +328,7 @@ func (w *writer) printConstant(c *ast.Constant, indent int32) {
291328
}
292329

293330
func (w *writer) printComment(c *ast.Comment, indent int32) {
294-
w.print("# ")
295-
w.print(c.Text)
296-
w.print("\n")
331+
w.printCommentText(c.Text, 0)
297332
}
298333

299334
func (w *writer) printCompare(c *ast.Compare, indent int32) {

0 commit comments

Comments
 (0)