Skip to content

Commit c324a9a

Browse files
aisksbinet
andauthored
py: harden __import__ argument handling
Fixes #204. Co-authored-by: Sebastien Binet <binet@cern.ch>
1 parent 530fdbd commit c324a9a

File tree

2 files changed

+54
-4
lines changed

2 files changed

+54
-4
lines changed

py/import.go

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,9 @@ func ImportModuleLevelObject(ctx Context, name string, globals, locals StringDic
108108
}
109109

110110
if fromFile, ok := globals["__file__"]; ok {
111-
opts.CurDir = filepath.Dir(string(fromFile.(String)))
111+
if fromFileStr, ok := fromFile.(String); ok {
112+
opts.CurDir = filepath.Dir(string(fromFileStr))
113+
}
112114
}
113115

114116
module, err := RunFile(ctx, srcPathname, opts, name)
@@ -344,14 +346,42 @@ func BuiltinImport(ctx Context, self Object, args Tuple, kwargs StringDict, curr
344346
var globals Object = currentGlobal
345347
var locals Object = NewStringDict()
346348
var fromlist Object = Tuple{}
349+
var fromlistTuple Tuple
347350
var level Object = Int(0)
348351

349352
err := ParseTupleAndKeywords(args, kwargs, "U|OOOi:__import__", kwlist, &name, &globals, &locals, &fromlist, &level)
350353
if err != nil {
351354
return nil, err
352355
}
353-
if fromlist == None {
354-
fromlist = Tuple{}
356+
levelObj, ok := level.(Int)
357+
if !ok {
358+
return nil, ExceptionNewf(TypeError, "__import__() argument 5 must be int, not %s", level.Type().Name)
359+
}
360+
levelInt, err := levelObj.GoInt()
361+
if err != nil {
362+
return nil, err
363+
}
364+
365+
globalsDict, ok := globals.(StringDict)
366+
if !ok {
367+
if levelInt > 0 {
368+
return nil, ExceptionNewf(TypeError, "globals must be a dict")
369+
}
370+
globalsDict = StringDict{}
355371
}
356-
return ImportModuleLevelObject(ctx, string(name.(String)), globals.(StringDict), locals.(StringDict), fromlist.(Tuple), int(level.(Int)))
372+
373+
localsDict, ok := locals.(StringDict)
374+
if !ok {
375+
localsDict = StringDict{}
376+
}
377+
378+
fromlistTuple = Tuple{}
379+
if fromlist != None {
380+
fromlistTuple, err = SequenceTuple(fromlist)
381+
if err != nil {
382+
return nil, err
383+
}
384+
}
385+
386+
return ImportModuleLevelObject(ctx, string(name.(String)), globalsDict, localsDict, fromlistTuple, levelInt)
357387
}

stdlib/builtin/tests/builtin.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,26 @@ class C: pass
501501
assert lib.libfn() == 42
502502
assert lib.libvar == 43
503503
assert lib.libclass().method() == 44
504+
lib = __import__("lib", {}, {}, [""])
505+
assert lib.libfn() == 42
506+
ok = False
507+
try:
508+
__import__("lib", {}, {}, 1)
509+
except TypeError:
510+
ok = True
511+
assert ok, "TypeError not raised"
512+
lib = __import__("lib", 1, {}, [""])
513+
assert lib.libfn() == 42
514+
ok = False
515+
try:
516+
__import__("lib", 1, {}, [""], 1)
517+
except TypeError as e:
518+
if e.args[0] != "globals must be a dict":
519+
raise
520+
ok = True
521+
assert ok, "TypeError not raised"
522+
lib = __import__("lib", {"__file__": 1}, {}, [""])
523+
assert lib.libfn() == 42
504524

505525
doc="input"
506526
import sys

0 commit comments

Comments
 (0)