Skip to content

Commit 975305a

Browse files
shivasurya and claude authored
fix: resolve module-level classmethod aliases in stdlib type inference (shivasurya#616)
Python stdlib modules like tarfile expose classmethods at module level (tarfile.open is actually TarFile.open). The type inference engine's Phase A resolver only checked GetFunction and GetClass, missing these aliases entirely. Variables assigned from such calls stayed unresolved as "call:tarfile.open", breaking downstream type-aware rule matching. Add FindClassMethodAlias to StdlibRegistryRemote, which searches all classes in a module for a matching method name. It uses GetClassMethod internally so that inherited methods are also checked, and it prefers the class whose name matches the module (e.g., TarFile in tarfile) for determinism. It also resolves "Self" return types (e.g., tarfile.Self → tarfile.TarFile) so Phase B can chain method calls on the resolved type. Before: 1 detection (only the tarfile.TarFile constructor path). After: 4 detections (module alias + constructor + extract + safe_extract). Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent ad7b078 commit 975305a

4 files changed

Lines changed: 291 additions & 0 deletions

File tree

sast-engine/graph/callgraph/builder/builder.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1934,6 +1934,25 @@ func resolveStdlibVariableBindings(typeEngine *resolution.TypeInferenceEngine, l
19341934
scope.Variables[varName][i].AssignedFrom = funcName
19351935
continue
19361936
}
1937+
1938+
// Try as module-level alias to a classmethod.
1939+
// Handles patterns like tarfile.open which is TarFile.open.
1940+
if method, className := loader.FindClassMethodAlias(moduleName, name, logger); method != nil {
1941+
returnType := method.ReturnType
1942+
if returnType != "" && returnType != "unknown" {
1943+
// Resolve "Self" return types to the owning class
1944+
if returnType == "Self" || returnType == moduleName+".Self" {
1945+
returnType = moduleName + "." + className
1946+
}
1947+
scope.Variables[varName][i].Type = &core.TypeInfo{
1948+
TypeFQN: returnType,
1949+
Confidence: binding.Type.Confidence * method.Confidence * 0.95,
1950+
Source: "stdlib",
1951+
}
1952+
scope.Variables[varName][i].AssignedFrom = funcName
1953+
continue
1954+
}
1955+
}
19371956
}
19381957

19391958
}

sast-engine/graph/callgraph/builder/stdlib_resolve_test.go

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,3 +309,115 @@ func TestResolveStdlibVariableBindings_NoKnownModule(t *testing.T) {
309309
assert.NotNil(t, binding)
310310
assert.Equal(t, "call:unknownmod.func1", binding.Type.TypeFQN) // unchanged
311311
}
312+
313+
func TestResolveStdlibVariableBindings_PhaseA_ClassMethodAlias(t *testing.T) {
314+
// tarfile.open is a module-level alias for TarFile.open (classmethod).
315+
// GetFunction("tarfile", "open") returns nil because it's not in Functions.
316+
// GetClass("tarfile", "open") returns nil because "open" is not a class.
317+
// FindClassMethodAlias should find TarFile.open and resolve its return type.
318+
loader := newTestStdlibLoader(map[string]*core.StdlibModule{
319+
"tarfile": {
320+
Module: "tarfile",
321+
Functions: map[string]*core.StdlibFunction{},
322+
Classes: map[string]*core.StdlibClass{
323+
"TarFile": {
324+
Type: "class",
325+
Methods: map[string]*core.StdlibFunction{
326+
"open": {
327+
ReturnType: "tarfile.Self",
328+
Confidence: 0.95,
329+
},
330+
"extractall": {
331+
ReturnType: "builtins.NoneType",
332+
Confidence: 0.95,
333+
},
334+
},
335+
},
336+
},
337+
},
338+
})
339+
340+
typeEngine := resolution.NewTypeInferenceEngine(nil)
341+
typeEngine.StdlibRemote = loader
342+
343+
scope := resolution.NewFunctionScope("app.extract_archive")
344+
typeEngine.Scopes["app.extract_archive"] = scope
345+
scope.AddVariable(&resolution.VariableBinding{
346+
VarName: "tar",
347+
Type: &core.TypeInfo{
348+
TypeFQN: "call:tarfile.open",
349+
Confidence: 0.5,
350+
},
351+
})
352+
353+
logger := output.NewLogger(output.VerbosityDefault)
354+
resolveStdlibVariableBindings(typeEngine, logger)
355+
356+
binding := scope.GetVariable("tar")
357+
assert.NotNil(t, binding)
358+
// "tarfile.Self" should be resolved to "tarfile.TarFile"
359+
assert.Equal(t, "tarfile.TarFile", binding.Type.TypeFQN)
360+
assert.Equal(t, "tarfile.open", binding.AssignedFrom)
361+
assert.Equal(t, "stdlib", binding.Type.Source)
362+
}
363+
364+
func TestResolveStdlibVariableBindings_PhaseA_ClassMethodAlias_ChainToPhaseB(t *testing.T) {
365+
// Full chain: tarfile.open() → TarFile, then tar.extractall() via Phase B.
366+
loader := newTestStdlibLoader(map[string]*core.StdlibModule{
367+
"tarfile": {
368+
Module: "tarfile",
369+
Functions: map[string]*core.StdlibFunction{},
370+
Classes: map[string]*core.StdlibClass{
371+
"TarFile": {
372+
Type: "class",
373+
Methods: map[string]*core.StdlibFunction{
374+
"open": {
375+
ReturnType: "tarfile.Self",
376+
Confidence: 0.95,
377+
},
378+
"extractall": {
379+
ReturnType: "builtins.NoneType",
380+
Confidence: 0.95,
381+
},
382+
},
383+
},
384+
},
385+
},
386+
})
387+
388+
typeEngine := resolution.NewTypeInferenceEngine(nil)
389+
typeEngine.StdlibRemote = loader
390+
391+
scope := resolution.NewFunctionScope("app.extract_archive")
392+
typeEngine.Scopes["app.extract_archive"] = scope
393+
394+
// tar = tarfile.open(...)
395+
scope.AddVariable(&resolution.VariableBinding{
396+
VarName: "tar",
397+
Type: &core.TypeInfo{
398+
TypeFQN: "call:tarfile.open",
399+
Confidence: 0.5,
400+
},
401+
})
402+
// result = tar.extractall(...)
403+
scope.AddVariable(&resolution.VariableBinding{
404+
VarName: "result",
405+
Type: &core.TypeInfo{
406+
TypeFQN: "call:tar.extractall",
407+
Confidence: 0.5,
408+
},
409+
})
410+
411+
logger := output.NewLogger(output.VerbosityDefault)
412+
resolveStdlibVariableBindings(typeEngine, logger)
413+
414+
// Phase A: tar → tarfile.TarFile
415+
tarBinding := scope.GetVariable("tar")
416+
assert.Equal(t, "tarfile.TarFile", tarBinding.Type.TypeFQN)
417+
418+
// Phase B: result → builtins.NoneType (via TarFile.extractall)
419+
resultBinding := scope.GetVariable("result")
420+
assert.NotNil(t, resultBinding)
421+
assert.Equal(t, "builtins.NoneType", resultBinding.Type.TypeFQN)
422+
assert.Equal(t, "tarfile.TarFile.extractall", resultBinding.AssignedFrom)
423+
}

sast-engine/graph/callgraph/registry/stdlib_remote.go

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,42 @@ func (r *StdlibRegistryRemote) GetClassMethod(moduleName, className, methodName
295295
return nil
296296
}
297297

298+
// FindClassMethodAlias searches all classes in a module for a method matching
299+
// the given name. This handles module-level aliases like tarfile.open which is
300+
// actually TarFile.open (a classmethod exposed at module level).
301+
//
302+
// When multiple classes have a method with the same name, the class whose name
303+
// matches the module (e.g., TarFile in tarfile) is preferred for determinism.
304+
// Also checks inherited methods via GetClassMethod.
305+
//
306+
// Returns the matching StdlibFunction and the owning class name, or nil/"" if
307+
// no match is found.
308+
func (r *StdlibRegistryRemote) FindClassMethodAlias(moduleName, functionName string, logger *output.Logger) (*core.StdlibFunction, string) {
309+
module, err := r.GetModule(moduleName, logger)
310+
if err != nil || module == nil {
311+
return nil, ""
312+
}
313+
314+
var bestMethod *core.StdlibFunction
315+
var bestClassName string
316+
317+
for className := range module.Classes {
318+
method := r.GetClassMethod(moduleName, className, functionName, logger)
319+
if method == nil {
320+
continue
321+
}
322+
if bestMethod == nil {
323+
bestMethod = method
324+
bestClassName = className
325+
} else if strings.EqualFold(className, moduleName) {
326+
// Prefer the class matching the module name (e.g., TarFile in tarfile)
327+
bestMethod = method
328+
bestClassName = className
329+
}
330+
}
331+
return bestMethod, bestClassName
332+
}
333+
298334
// ModuleCount returns the number of modules in the manifest.
299335
//
300336
// Returns:

sast-engine/graph/callgraph/registry/stdlib_remote_test.go

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -857,3 +857,127 @@ func TestStdlibRegistryRemote_GetModule_HTTPError(t *testing.T) {
857857
assert.Contains(t, err.Error(), "module download failed with status: 404")
858858
assert.Nil(t, module)
859859
}
860+
861+
func TestStdlibRegistryRemote_FindClassMethodAlias(t *testing.T) {
862+
// tarfile.open is a module-level alias for TarFile.open (classmethod)
863+
remote := &StdlibRegistryRemote{
864+
ModuleCache: map[string]*core.StdlibModule{
865+
"tarfile": {
866+
Module: "tarfile",
867+
Functions: map[string]*core.StdlibFunction{},
868+
Classes: map[string]*core.StdlibClass{
869+
"TarFile": {
870+
Type: "class",
871+
Methods: map[string]*core.StdlibFunction{
872+
"open": {ReturnType: "tarfile.Self", Confidence: 0.95},
873+
"extractall": {ReturnType: "builtins.NoneType", Confidence: 0.95},
874+
},
875+
},
876+
"TarInfo": {
877+
Type: "class",
878+
Methods: map[string]*core.StdlibFunction{
879+
"isdir": {ReturnType: "builtins.bool", Confidence: 0.9},
880+
},
881+
},
882+
},
883+
},
884+
},
885+
}
886+
logger := newTestLogger()
887+
888+
// Found: tarfile.open → TarFile.open
889+
method, className := remote.FindClassMethodAlias("tarfile", "open", logger)
890+
require.NotNil(t, method, "should find TarFile.open as alias for tarfile.open")
891+
assert.Equal(t, "TarFile", className)
892+
assert.Equal(t, "tarfile.Self", method.ReturnType)
893+
assert.InDelta(t, 0.95, method.Confidence, 0.001)
894+
895+
// Found: tarfile.extractall → TarFile.extractall
896+
method2, className2 := remote.FindClassMethodAlias("tarfile", "extractall", logger)
897+
require.NotNil(t, method2)
898+
assert.Equal(t, "TarFile", className2)
899+
assert.Equal(t, "builtins.NoneType", method2.ReturnType)
900+
901+
// Found: tarfile.isdir → TarInfo.isdir
902+
method3, className3 := remote.FindClassMethodAlias("tarfile", "isdir", logger)
903+
require.NotNil(t, method3)
904+
assert.Equal(t, "TarInfo", className3)
905+
906+
// Not found: no class has "nonexistent"
907+
method4, className4 := remote.FindClassMethodAlias("tarfile", "nonexistent", logger)
908+
assert.Nil(t, method4)
909+
assert.Equal(t, "", className4)
910+
911+
// Not found: module not in cache
912+
method5, className5 := remote.FindClassMethodAlias("unknown", "open", logger)
913+
assert.Nil(t, method5)
914+
assert.Equal(t, "", className5)
915+
}
916+
917+
func TestStdlibRegistryRemote_FindClassMethodAlias_PrefersModuleNameClass(t *testing.T) {
918+
// When multiple classes have the same method name, prefer the class
919+
// whose name matches the module (case-insensitive).
920+
// Use many non-matching classes to increase chance the preferred one
921+
// is not the first iterated (Go map order is random).
922+
classes := map[string]*core.StdlibClass{
923+
"MyMod": {
924+
Type: "class",
925+
Methods: map[string]*core.StdlibFunction{
926+
"connect": {ReturnType: "mymod.MyMod", Confidence: 0.95},
927+
},
928+
},
929+
}
930+
// Add many non-matching classes with the same method to force iteration
931+
for _, name := range []string{"Alpha", "Beta", "Gamma", "Delta", "Epsilon", "Zeta", "Eta", "Theta"} {
932+
classes[name] = &core.StdlibClass{
933+
Type: "class",
934+
Methods: map[string]*core.StdlibFunction{
935+
"connect": {ReturnType: "mymod." + name, Confidence: 0.8},
936+
},
937+
}
938+
}
939+
940+
remote := &StdlibRegistryRemote{
941+
ModuleCache: map[string]*core.StdlibModule{
942+
"mymod": {
943+
Module: "mymod",
944+
Functions: map[string]*core.StdlibFunction{},
945+
Classes: classes,
946+
},
947+
},
948+
}
949+
950+
// Run multiple times to exercise different map iteration orders
951+
for i := 0; i < 20; i++ {
952+
method, className := remote.FindClassMethodAlias("mymod", "connect", newTestLogger())
953+
require.NotNil(t, method, "iteration %d: should find connect", i)
954+
assert.Equal(t, "MyMod", className, "iteration %d: should always prefer class matching module name", i)
955+
assert.Equal(t, "mymod.MyMod", method.ReturnType, "iteration %d", i)
956+
}
957+
}
958+
959+
func TestStdlibRegistryRemote_FindClassMethodAlias_Inherited(t *testing.T) {
960+
// FindClassMethodAlias uses GetClassMethod which checks inherited methods too.
961+
remote := &StdlibRegistryRemote{
962+
ModuleCache: map[string]*core.StdlibModule{
963+
"mymod": {
964+
Module: "mymod",
965+
Functions: map[string]*core.StdlibFunction{},
966+
Classes: map[string]*core.StdlibClass{
967+
"Base": {
968+
Type: "class",
969+
Methods: map[string]*core.StdlibFunction{},
970+
InheritedMethods: map[string]*core.InheritedMember{
971+
"read": {ReturnType: "builtins.bytes", Confidence: 0.85, Source: "io.IOBase"},
972+
},
973+
},
974+
},
975+
},
976+
},
977+
}
978+
979+
method, className := remote.FindClassMethodAlias("mymod", "read", newTestLogger())
980+
require.NotNil(t, method, "should find inherited method via GetClassMethod")
981+
assert.Equal(t, "Base", className)
982+
assert.Equal(t, "builtins.bytes", method.ReturnType)
983+
}

0 commit comments

Comments
 (0)