Skip to content

Commit 20454c2

Browse files
authored
Merge pull request #10 from lambda-feedback/feature/cases
Feature/cases
2 parents 056debc + 0947463 commit 20454c2

10 files changed

Lines changed: 883 additions & 30 deletions

File tree

.github/workflows/build.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,13 +71,13 @@ jobs:
7171
run: go mod download
7272

7373
- name: Run Tests
74-
run: go test -json ./... > TestResults.json
74+
run: go test -json ./...
7575

76-
- name: Upload test results
77-
uses: actions/upload-artifact@v4
78-
with:
79-
name: Go-results
80-
path: TestResults.json
76+
# - name: Upload test results
77+
# uses: actions/upload-artifact@v4
78+
# with:
79+
# name: Go-results
80+
# path: TestResults.json
8181

8282
build_docker:
8383
name: Build Docker Image

Dockerfile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM --platform=$BUILDPLATFORM golang:1.22 as builder
1+
FROM --platform=$BUILDPLATFORM golang:1.24 as builder
22

33
WORKDIR /app
44

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/lambda-feedback/shimmy
22

3-
go 1.22.0
3+
go 1.24.5
44

55
require (
66
github.com/aws/aws-lambda-go v1.46.0

internal/execution/worker/worker_test.go

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package worker_test
33
import (
44
"bytes"
55
"context"
6+
"github.com/stretchr/testify/require"
67
"io"
78
"strings"
89
"syscall"
@@ -66,12 +67,15 @@ func TestWorker_TerminatesIfContextCancelled(t *testing.T) {
6667
// cancel the worker context
6768
cancel()
6869

69-
evt, err := w.Wait(context.Background())
70-
assert.NoError(t, err)
70+
var evt worker.ExitEvent
71+
var waitError error
72+
require.Eventually(t, func() bool {
73+
evt, waitError = w.Wait(context.Background())
74+
return waitError == nil && evt.Signal != nil
75+
}, time.Second, 10*time.Millisecond)
7176

72-
// the process should have been terminated w/ a sigkill in the background
73-
assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal))
74-
assert.Nil(t, evt.Code)
77+
require.NoError(t, waitError)
78+
require.NotNil(t, evt)
7579
}
7680

7781
func TestWorker_CapturesStderr(t *testing.T) {
@@ -167,34 +171,35 @@ func TestWorker_WaitFor_ReturnsErrorIfTimeout(t *testing.T) {
167171
}
168172

169173
func TestWorker_Kill_KillsProcess(t *testing.T) {
170-
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop())
174+
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop())
171175

172176
err := w.Start(context.Background())
173177
assert.NoError(t, err)
174178

175179
w.Kill()
176180

177-
evt, err := w.Wait(context.Background())
178-
assert.NoError(t, err)
179-
180-
// the process should have been terminated w/ a sigkill in the background
181-
assert.Equal(t, syscall.SIGKILL, syscall.Signal(*evt.Signal))
182-
assert.Nil(t, evt.Code)
183-
184-
// the process should not be alive
185-
assert.Equal(t, false, util.IsProcessAlive(w.Pid()))
181+
var evt worker.ExitEvent
182+
var waitError error
183+
require.Eventually(t, func() bool {
184+
evt, waitError = w.Wait(context.Background())
185+
return waitError == nil && evt.Signal != nil
186+
}, time.Second, 10*time.Millisecond)
186187
}
187188

188189
func TestWorker_Terminate_TerminatesProcess(t *testing.T) {
189-
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "cat"}, zap.NewNop())
190+
w := worker.NewProcessWorker(context.Background(), worker.StartConfig{Cmd: "sleep", Args: []string{"10"}}, zap.NewNop())
190191

191192
err := w.Start(context.Background())
192193
assert.NoError(t, err)
193194

194195
w.Stop()
195196

196-
evt, err := w.Wait(context.Background())
197-
assert.NoError(t, err)
197+
var evt worker.ExitEvent
198+
var waitError error
199+
require.Eventually(t, func() bool {
200+
evt, waitError = w.Wait(context.Background())
201+
return waitError == nil && evt.Signal != nil
202+
}, time.Second, 10*time.Millisecond)
198203

199204
// the process should have been terminated w/ a sigterm in the background
200205
assert.Equal(t, syscall.SIGTERM, syscall.Signal(*evt.Signal))

runtime/handler.go

Lines changed: 239 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"fmt"
8+
"github.com/ethereum/go-ethereum/log"
79
"net/http"
810
"strings"
911

@@ -37,6 +39,17 @@ type HandlerParams struct {
3739
Log *zap.Logger
3840
}
3941

42+
type CaseWarning struct {
43+
Message string `json:"message"`
44+
Case int `json:"case"`
45+
}
46+
47+
type CaseResult struct {
48+
IsCorrect bool
49+
Feedback string
50+
Warning *CaseWarning
51+
}
52+
4053
// Handler is the interface for handling runtime requests.
4154
type Handler interface {
4255
Handle(ctx context.Context, request Request) Response
@@ -111,6 +124,75 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error
111124
return nil, errInvalidCommand
112125
}
113126

127+
resData, err := SendCommand(req, command, h, ctx)
128+
if err != nil {
129+
log.Debug("unable to send command")
130+
return nil, err
131+
}
132+
133+
var reqBody map[string]any
134+
err = json.Unmarshal(req.Body, &reqBody)
135+
if err != nil {
136+
log.Error("failed to unmarshal request data", zap.Error(err))
137+
return nil, err
138+
}
139+
140+
var respBody map[string]any
141+
err = json.Unmarshal(resData, &respBody)
142+
result, ok := respBody["result"].(map[string]interface{})
143+
if !ok {
144+
log.Error("failed to unmarshal response data", zap.Error(err))
145+
return nil, err
146+
}
147+
148+
if command == "eval" {
149+
ProcessEval(reqBody, result, req, command, h, ctx)
150+
}
151+
152+
resData, err = json.Marshal(respBody)
153+
if err != nil {
154+
log.Error("failed to marshal response data", zap.Error(err))
155+
return nil, err
156+
}
157+
158+
// Return the response data
159+
return resData, nil
160+
}
161+
162+
func ProcessEval(reqBody map[string]any, result map[string]any, req Request, command Command,
163+
h *RuntimeHandler, ctx context.Context) {
164+
165+
params, ok := reqBody["params"].(map[string]interface{})
166+
cases, ok := params["cases"].([]interface{})
167+
168+
if result["is_correct"] == false {
169+
170+
if ok && len(cases) > 0 {
171+
match, warnings := GetCaseFeedback(params, params["cases"].([]interface{}), req, command, h, ctx)
172+
173+
if warnings != nil {
174+
result["warnings"] = warnings
175+
}
176+
177+
if match != nil {
178+
result["feedback"] = match["feedback"]
179+
result["matched_case"] = match["id"]
180+
181+
mark, exists := match["mark"].(float64)
182+
if exists {
183+
if int(mark) == 1 {
184+
result["is_correct"] = true
185+
} else {
186+
result["is_correct"] = false
187+
}
188+
189+
}
190+
}
191+
}
192+
}
193+
}
194+
195+
func SendCommand(req Request, command Command, h *RuntimeHandler, ctx context.Context) ([]byte, error) {
114196
var reqData map[string]any
115197

116198
// Parse the request data into a map
@@ -146,10 +228,166 @@ func (h *RuntimeHandler) handle(ctx context.Context, req Request) ([]byte, error
146228
return nil, err
147229
}
148230

149-
// Return the response data
150231
return resData, nil
151232
}
152233

234+
func GetCaseFeedback(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler,
235+
ctx context.Context) (map[string]any, []CaseWarning) {
236+
237+
// Simulate find_first_matching_case
238+
matches, feedback, warnings := FindFirstMatchingCase(params, cases, req, command, h, ctx)
239+
240+
if len(matches) == 0 {
241+
return nil, warnings
242+
}
243+
244+
matchID := matches[0]
245+
match := cases[matchID].(map[string]interface{})
246+
match["id"] = matchID
247+
248+
matchParams, ok := match["params"].(map[string]any)
249+
if ok && matchParams["override_eval_feedback"] == true {
250+
matchFeedback := match["feedback"].(string)
251+
evalFeedback := feedback[0]
252+
match["feedback"] = matchFeedback + "<br />" + evalFeedback
253+
}
254+
255+
if len(matches) > 1 {
256+
ids := make([]string, len(matches))
257+
for i, id := range matches {
258+
ids[i] = fmt.Sprintf("%d", id)
259+
}
260+
warning := CaseWarning{
261+
Message: fmt.Sprintf("Cases %s were matched. Only the first one's feedback was returned", strings.Join(ids, ", ")),
262+
}
263+
warnings = append(warnings, warning)
264+
}
265+
266+
return match, warnings
267+
}
268+
269+
func FindFirstMatchingCase(params map[string]any, cases []interface{}, req Request, command Command, h *RuntimeHandler,
270+
ctx context.Context) ([]int, []string, []CaseWarning) {
271+
272+
var matches []int
273+
var feedback []string
274+
var warnings []CaseWarning
275+
276+
for index, c := range cases {
277+
result := EvaluateCase(params, c.(map[string]interface{}), index, req, command, h, ctx)
278+
279+
if result.Warning != nil {
280+
warnings = append(warnings, *result.Warning)
281+
}
282+
283+
if result.IsCorrect {
284+
matches = append(matches, index)
285+
feedback = append(feedback, result.Feedback)
286+
break
287+
}
288+
}
289+
290+
return matches, feedback, warnings
291+
}
292+
293+
func EvaluateCase(params map[string]any, caseData map[string]any, index int, req Request, command Command,
294+
h *RuntimeHandler, ctx context.Context) CaseResult {
295+
// Check for required fields
296+
if _, hasAnswer := caseData["answer"]; !hasAnswer {
297+
return CaseResult{
298+
Warning: &CaseWarning{
299+
Case: index,
300+
Message: "Missing answer field",
301+
},
302+
}
303+
}
304+
if _, hasFeedback := caseData["feedback"]; !hasFeedback {
305+
return CaseResult{
306+
Warning: &CaseWarning{
307+
Case: index,
308+
Message: "Missing feedback field",
309+
},
310+
}
311+
}
312+
313+
// Merge params with case-specific params
314+
combinedParams := make(map[string]any)
315+
for k, v := range params {
316+
combinedParams[k] = v
317+
}
318+
if caseParams, ok := caseData["params"].(map[string]any); ok {
319+
for k, v := range caseParams {
320+
combinedParams[k] = v
321+
}
322+
}
323+
324+
// Try evaluation
325+
defer func() {
326+
if r := recover(); r != nil {
327+
// Catch panic as generic error
328+
caseData["warning"] = &CaseWarning{
329+
Case: index,
330+
Message: "An exception was raised while executing the evaluation function.",
331+
}
332+
}
333+
}()
334+
335+
var reqBody map[string]interface{}
336+
err := json.Unmarshal(req.Body, &reqBody)
337+
if err != nil {
338+
return CaseResult{
339+
Warning: &CaseWarning{
340+
Case: index,
341+
Message: err.Error(),
342+
},
343+
}
344+
}
345+
346+
reqBody["answer"] = caseData["answer"]
347+
reqBody["params"] = combinedParams
348+
349+
req.Body, err = json.Marshal(reqBody)
350+
if err != nil {
351+
return CaseResult{
352+
Warning: &CaseWarning{
353+
Case: index,
354+
Message: err.Error(),
355+
},
356+
}
357+
}
358+
359+
resData, err := SendCommand(req, command, h, ctx)
360+
if err != nil {
361+
return CaseResult{
362+
Warning: &CaseWarning{
363+
Case: index,
364+
Message: err.Error(),
365+
},
366+
}
367+
}
368+
369+
var respBody map[string]any
370+
err = json.Unmarshal(resData, &respBody)
371+
result, ok := respBody["result"].(map[string]interface{})
372+
if !ok {
373+
log.Error("failed to unmarshal response data", zap.Error(err))
374+
return CaseResult{
375+
Warning: &CaseWarning{
376+
Case: index,
377+
Message: "failed to unmarshal response data",
378+
},
379+
}
380+
}
381+
382+
isCorrect, _ := result["is_correct"].(bool)
383+
feedback, _ := result["feedback"].(string)
384+
385+
return CaseResult{
386+
IsCorrect: isCorrect,
387+
Feedback: feedback,
388+
}
389+
}
390+
153391
// getCommand tries to extract the command from the request.
154392
func (s *RuntimeHandler) getCommand(req Request) (string, bool) {
155393
if commandStr := req.Header.Get("command"); commandStr != "" {

0 commit comments

Comments
 (0)