Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,582 changes: 797 additions & 785 deletions proto/gen/rill/runtime/v1/api.pb.go

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions proto/gen/rill/runtime/v1/api.pb.validate.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions proto/gen/rill/runtime/v1/runtime.swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5432,6 +5432,8 @@ definitions:
properties:
message:
$ref: '#/definitions/v1Message'
result:
$ref: '#/definitions/v1Message'
title: Response message for RuntimeService.GetAIMessage
v1GetConversationResponse:
type: object
Expand Down
1 change: 1 addition & 0 deletions proto/rill/runtime/v1/api.proto
Original file line number Diff line number Diff line change
Expand Up @@ -1216,6 +1216,7 @@ message GetAIMessageRequest {
// Response message for RuntimeService.GetAIMessage
message GetAIMessageResponse {
Message message = 1;
Message result = 2;
}

// **********
Expand Down
39 changes: 24 additions & 15 deletions runtime/ai/metrics_view_get.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,39 +53,48 @@ func (t *GetMetricsView) CheckAccess(ctx context.Context) (bool, error) {
func (t *GetMetricsView) Handler(ctx context.Context, args *GetMetricsViewArgs) (*GetMetricsViewResult, error) {
session := GetSession(ctx)

ctrl, err := t.Runtime.Controller(ctx, session.InstanceID())
mvSpec, err := resolveMetricsView(ctx, t.Runtime, session, args.MetricsView)
if err != nil {
return nil, err
}

r, err := ctrl.Get(ctx, &runtimev1.ResourceName{Kind: runtime.ResourceKindMetricsView, Name: args.MetricsView}, false)
specJSON, err := protojson.Marshal(mvSpec)
if err != nil {
return nil, err
}

r, access, err := t.Runtime.ApplySecurityPolicy(ctx, session.InstanceID(), session.Claims(), r)
var specMap map[string]any
err = json.Unmarshal(specJSON, &specMap)
if err != nil {
return nil, err
}
if !access {
return nil, fmt.Errorf("resource not found")
}

if r.GetMetricsView().State.ValidSpec == nil {
return nil, fmt.Errorf("metrics view %q is invalid", args.MetricsView)
return &GetMetricsViewResult{
Spec: specMap,
}, nil
}

func resolveMetricsView(ctx context.Context, rt *runtime.Runtime, session *Session, metricsView string) (*runtimev1.MetricsViewSpec, error) {
ctrl, err := rt.Controller(ctx, session.InstanceID())
if err != nil {
return nil, err
}

specJSON, err := protojson.Marshal(r.GetMetricsView().State.ValidSpec)
r, err := ctrl.Get(ctx, &runtimev1.ResourceName{Kind: runtime.ResourceKindMetricsView, Name: metricsView}, false)
if err != nil {
return nil, err
}
var specMap map[string]any
err = json.Unmarshal(specJSON, &specMap)

r, access, err := rt.ApplySecurityPolicy(ctx, session.InstanceID(), session.Claims(), r)
if err != nil {
return nil, err
}
if !access {
return nil, fmt.Errorf("resource not found")
}

return &GetMetricsViewResult{
Spec: specMap,
}, nil
if r.GetMetricsView().State.ValidSpec == nil {
return nil, fmt.Errorf("metrics view %q is invalid", metricsView)
}

return r.GetMetricsView().State.ValidSpec, nil
}
103 changes: 96 additions & 7 deletions runtime/ai/metrics_view_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@ import (

"github.com/google/jsonschema-go/jsonschema"
"github.com/modelcontextprotocol/go-sdk/mcp"
runtimev1 "github.com/rilldata/rill/proto/gen/rill/runtime/v1"
"github.com/rilldata/rill/runtime"
"github.com/rilldata/rill/runtime/metricsview"
"github.com/rilldata/rill/runtime/pkg/mapstructureutil"
"github.com/rilldata/rill/runtime/pkg/rilltime"
"github.com/rilldata/rill/runtime/pkg/timeutil"
"github.com/rilldata/rill/runtime/queries"
)

const QueryMetricsViewName = "query_metrics_view"
Expand All @@ -25,10 +30,11 @@ var _ Tool[QueryMetricsViewArgs, *QueryMetricsViewResult] = (*QueryMetricsView)(
type QueryMetricsViewArgs map[string]any

type QueryMetricsViewResult struct {
Schema []SchemaField `json:"schema"`
Data [][]any `json:"data"`
OpenURL string `json:"open_url,omitempty"`
TruncationWarning string `json:"truncation_warning,omitempty"`
Schema []SchemaField `json:"schema"`
Data [][]any `json:"data"`
ResolvedTimeRanges []*metricsview.TimeRange `json:"resolved_time_ranges,omitempty"`
OpenURL string `json:"open_url,omitempty"`
TruncationWarning string `json:"truncation_warning,omitempty"`
}

func (t *QueryMetricsView) Spec() *mcp.Tool {
Expand Down Expand Up @@ -273,6 +279,12 @@ func (t *QueryMetricsView) Handler(ctx context.Context, args QueryMetricsViewArg
return nil, err
}

// Resolve time ranges and store them in the result to record the exact resolved time ranges for this tool call
resolvedTimeRanges, err := t.resolveTimeRanges(ctx, session, args)
if err != nil {
return nil, err
}

// Generate an open URL for the query
openURL, err := t.generateOpenURL(ctx, session.InstanceID(), session.ID(), session.ParentID)
if err != nil {
Expand All @@ -281,9 +293,10 @@ func (t *QueryMetricsView) Handler(ctx context.Context, args QueryMetricsViewArg

// Build the result
result := &QueryMetricsViewResult{
Schema: schema,
Data: data,
OpenURL: openURL,
Schema: schema,
Data: data,
OpenURL: openURL,
ResolvedTimeRanges: resolvedTimeRanges,
}
if isSystemLimit && int64(len(data)) >= limit { // Add a warning if we hit the system limit
msg := fmt.Sprintf("The system truncated the result to %d rows", limit)
Expand Down Expand Up @@ -321,3 +334,79 @@ func (t *QueryMetricsView) generateOpenURL(ctx context.Context, instanceID, sess

return openURL.String(), nil
}

func (t *QueryMetricsView) resolveTimeRanges(ctx context.Context, session *Session, args map[string]any) ([]*metricsview.TimeRange, error) {
qry := &metricsview.Query{}
if err := mapstructureutil.WeakDecode(args, qry); err != nil {
return nil, err
}

if qry.TimeRange == nil {
return nil, nil
}

mvSpec, err := resolveMetricsView(ctx, t.Runtime, session, qry.MetricsView)
if err != nil {
return nil, err
}

ts, err := queries.ResolveTimestampResult(ctx, t.Runtime, session.InstanceID(), qry.MetricsView, qry.TimeRange.TimeDimension, session.Claims(), 0)
if err != nil {
return nil, err
}

var tz *time.Location
if qry.TimeZone != "" {
tz, err = time.LoadLocation(qry.TimeZone)
if err != nil {
return nil, err
}
}

var resolvedTimeRanges []*metricsview.TimeRange
tr, err := evalTimeRangeExpr(mvSpec, qry.TimeRange, ts, tz)
if err != nil {
return nil, err
}
if tr != nil {
resolvedTimeRanges = append(resolvedTimeRanges, tr)
}

ctr, err := evalTimeRangeExpr(mvSpec, qry.ComparisonTimeRange, ts, tz)
if err != nil {
return nil, err
}
if ctr != nil {
resolvedTimeRanges = append(resolvedTimeRanges, ctr)
}

return resolvedTimeRanges, nil
}

func evalTimeRangeExpr(mvSpec *runtimev1.MetricsViewSpec, timeRange *metricsview.TimeRange, ts metricsview.TimestampsResult, tz *time.Location) (*metricsview.TimeRange, error) {
if timeRange == nil || timeRange.Expression == "" {
return nil, nil
}

expr, err := rilltime.Parse(timeRange.Expression, rilltime.ParseOptions{
TimeZoneOverride: tz,
SmallestGrain: timeutil.TimeGrainFromAPI(mvSpec.SmallestTimeGrain),
})
if err != nil {
return nil, fmt.Errorf("error parsing time range %s: %w", timeRange.Expression, err)
}

start, end, _ := expr.Eval(rilltime.EvalOptions{
Now: time.Now(),
MinTime: ts.Min,
MaxTime: ts.Max,
Watermark: ts.Watermark,
FirstDay: int(mvSpec.FirstDayOfWeek),
FirstMonth: int(mvSpec.FirstMonthOfYear),
})
return &metricsview.TimeRange{
Expression: timeRange.Expression,
Start: start,
End: end,
}, nil
}
79 changes: 79 additions & 0 deletions runtime/ai/metrics_view_query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,3 +175,82 @@ cache:
require.Equal(t, nil, res.Data[0][0])
require.Equal(t, nil, res.Data[0][1])
}

func TestMetricsViewQueryResolvedTimeRanges(t *testing.T) {
// Setup a metrics view with a time dimension. The watermark defaults to the max event_time, i.e. 2025-05-13T00:00:00Z.
rt, instanceID := testruntime.NewInstanceWithOptions(t, testruntime.InstanceOptions{
Files: map[string]string{
"test_data.sql": `
SELECT '2025-05-10T00:00:00Z'::TIMESTAMP AS event_time, 'US' AS country, 100 AS revenue
UNION ALL
SELECT '2025-05-11T00:00:00Z'::TIMESTAMP AS event_time, 'US' AS country, 200 AS revenue
UNION ALL
SELECT '2025-05-12T00:00:00Z'::TIMESTAMP AS event_time, 'US' AS country, 300 AS revenue
UNION ALL
SELECT '2025-05-13T00:00:00Z'::TIMESTAMP AS event_time, 'US' AS country, 400 AS revenue
`,
"test_metrics.yaml": `
type: metrics_view
model: test_data
timeseries: event_time
dimensions:
- column: country
measures:
- name: total_revenue
expression: SUM(revenue)
explore:
skip: true
`,
},
Variables: map[string]string{
"rill.ai.require_time_range": "false",
},
})
testruntime.RequireReconcileState(t, rt, instanceID, 3, 0, 0)

s := newSession(t, rt, instanceID)

baseArgs := func() ai.QueryMetricsViewArgs {
return ai.QueryMetricsViewArgs{
"metrics_view": "test_metrics",
"dimensions": []map[string]any{{"name": "country"}},
"measures": []map[string]any{{"name": "total_revenue"}},
}
}

t.Run("no time range", func(t *testing.T) {
var res *ai.QueryMetricsViewResult
_, err := s.CallTool(t.Context(), ai.RoleUser, ai.QueryMetricsViewName, &res, baseArgs())
require.NoError(t, err)
require.Nil(t, res.ResolvedTimeRanges)
})

t.Run("literal start/end time range", func(t *testing.T) {
args := baseArgs()
args["time_range"] = map[string]any{
"start": "2025-05-10T00:00:00Z",
"end": "2025-05-14T00:00:00Z",
}
var res *ai.QueryMetricsViewResult
_, err := s.CallTool(t.Context(), ai.RoleUser, ai.QueryMetricsViewName, &res, args)
require.NoError(t, err)
// Only expression-based time ranges are recorded in ResolvedTimeRanges.
require.Nil(t, res.ResolvedTimeRanges)
})

t.Run("expression time range and comparison", func(t *testing.T) {
args := baseArgs()
args["time_range"] = map[string]any{"expression": "1D as of watermark/D"}
args["comparison_time_range"] = map[string]any{"expression": "1D as of watermark/D offset -1D"}
var res *ai.QueryMetricsViewResult
_, err := s.CallTool(t.Context(), ai.RoleUser, ai.QueryMetricsViewName, &res, args)
require.NoError(t, err)
require.Len(t, res.ResolvedTimeRanges, 2)
require.Equal(t, "1D as of watermark/D", res.ResolvedTimeRanges[0].Expression)
require.Equal(t, parseTestTime(t, "2025-05-12T00:00:00Z"), res.ResolvedTimeRanges[0].Start)
require.Equal(t, parseTestTime(t, "2025-05-13T00:00:00Z"), res.ResolvedTimeRanges[0].End)
require.Equal(t, "1D as of watermark/D offset -1D", res.ResolvedTimeRanges[1].Expression)
require.Equal(t, parseTestTime(t, "2025-05-11T00:00:00Z"), res.ResolvedTimeRanges[1].Start)
require.Equal(t, parseTestTime(t, "2025-05-12T00:00:00Z"), res.ResolvedTimeRanges[1].End)
})
}
11 changes: 11 additions & 0 deletions runtime/server/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,19 @@ func (s *Server) GetAIMessage(ctx context.Context, req *runtimev1.GetAIMessageRe
return nil, status.Errorf(codes.Internal, "failed to convert message to protobuf: %v", err)
}

var resPbMsg *runtimev1.Message
resMsg, ok := session.Message(ai.FilterByParent(req.MessageId), ai.FilterByType(ai.MessageTypeResult))
// Don't throw if there is no result message.
if ok {
resPbMsg, err = messageToPB(session, resMsg)
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to convert message to protobuf: %v", err)
}
}

return &runtimev1.GetAIMessageResponse{
Message: pbMsg,
Result: resPbMsg,
}, nil
}

Expand Down
1 change: 1 addition & 0 deletions runtime/server/chat_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ measures:
})
require.NoError(t, err)
require.Equal(t, res1.Messages[0], msgRes.Message)
require.Equal(t, res1.Messages[5], msgRes.Result)

// Check it errors if completing a conversation that doesn't exist
_, err = srv.Complete(fooCtx, &runtimev1.CompleteRequest{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,14 @@ import { fetchMessage } from "@rilldata/web-common/features/chat/core/citation-u
export async function load({ params: { conversationId, messageId }, parent }) {
const { runtime } = await parent();

const message = await fetchMessage(runtime, conversationId, messageId);
const { message, result } = await fetchMessage(
runtime,
conversationId,
messageId,
);

return {
message,
result,
};
}
Loading
Loading