Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 116 additions & 7 deletions src/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use serde::{Serialize, ser::SerializeMap};
use unicase::Ascii;

use crate::{
Attrs, Binary, Expr, Field, FunArgs, Query, Raw, Source, SourceKind, Type, Value,
App, Attrs, Binary, Expr, Field, FunArgs, Query, Raw, Source, SourceKind, Type, Value,
error::AnalysisError, token::Operator,
};

Expand All @@ -36,6 +36,9 @@ pub struct Typed {
/// including bindings from FROM clauses and their associated types.
#[serde(skip)]
pub scope: Scope,

/// Indicates if the query uses aggregate functions.
pub aggregate: bool,
}

/// Result type for static analysis operations.
Expand Down Expand Up @@ -528,6 +531,9 @@ pub struct AnalysisContext {
/// Set to `true` to allow aggregate functions, `false` to reject them.
/// Defaults to `false`.
pub allow_agg_func: bool,

/// Indicates if the query uses aggregate functions.
pub use_agg_funcs: bool,
}

/// A type checker and static analyzer for EventQL expressions.
Expand Down Expand Up @@ -631,7 +637,7 @@ impl<'a> Analysis<'a> {
}

if let Some(expr) = &query.predicate {
self.analyze_expr(&ctx, expr, Type::Bool)?;
self.analyze_expr(&mut ctx, expr, Type::Bool)?;
}

if let Some(group_by) = &query.group_by {
Expand All @@ -642,10 +648,10 @@ impl<'a> Analysis<'a> {
));
}

self.analyze_expr(&ctx, &group_by.expr, Type::Unspecified)?;
self.analyze_expr(&mut ctx, &group_by.expr, Type::Unspecified)?;

if let Some(expr) = &group_by.predicate {
self.analyze_expr(&ctx, expr, Type::Bool)?;
self.analyze_expr(&mut ctx, expr, Type::Bool)?;
}
}

Expand All @@ -656,7 +662,7 @@ impl<'a> Analysis<'a> {
order_by.expr.attrs.pos.col,
));
}
self.analyze_expr(&ctx, &order_by.expr, Type::Unspecified)?;
self.analyze_expr(&mut ctx, &order_by.expr, Type::Unspecified)?;
}

let project = self.analyze_projection(&mut ctx, &query.projection)?;
Expand All @@ -671,7 +677,11 @@ impl<'a> Analysis<'a> {
limit: query.limit,
projection: query.projection,
distinct: query.distinct,
meta: Typed { project, scope },
meta: Typed {
project,
scope,
aggregate: ctx.use_agg_funcs,
},
})
}

Expand Down Expand Up @@ -732,6 +742,20 @@ impl<'a> Analysis<'a> {
Ok(tpe)
}

Value::App(app) => {
ctx.allow_agg_func = true;

let tpe = self.analyze_expr(ctx, expr, Type::Unspecified)?;

if ctx.use_agg_funcs {
self.check_projection_on_field_expr(&mut CheckContext::default(), expr)?;
} else {
self.reject_constant_func(&expr.attrs, app)?;
}

Ok(tpe)
}

Value::Id(id) => {
if let Some(tpe) = self.scope.entries.get(id.as_str()).cloned() {
Ok(tpe)
Expand Down Expand Up @@ -994,6 +1018,87 @@ impl<'a> Analysis<'a> {
}
}

fn reject_constant_func(&self, attrs: &Attrs, app: &App) -> AnalysisResult<()> {
if app.args.is_empty() {
return Err(AnalysisError::ConstantExprInProjectIntoClause(
attrs.pos.line,
attrs.pos.col,
));
}

let mut errored = None;
for arg in &app.args {
if let Err(e) = self.reject_constant_expr(arg) {
if errored.is_none() {
errored = Some(e);
}

continue;
}

// if at least one arg is sourced-bound is ok
return Ok(());
}

Err(errored.expect("to be defined at that point"))
}

fn reject_constant_expr(&self, expr: &Expr) -> AnalysisResult<()> {
match &expr.value {
Value::Id(id) if self.scope.entries.contains_key(id.as_str()) => Ok(()),

Value::Array(exprs) => {
let mut errored = None;
for expr in exprs {
if let Err(e) = self.reject_constant_expr(expr) {
if errored.is_none() {
errored = Some(e);
}

continue;
}

// if at least one arg is sourced-bound is ok
return Ok(());
}

Err(errored.expect("to be defined at that point"))
}

Value::Record(fields) => {
let mut errored = None;
for field in fields {
if let Err(e) = self.reject_constant_expr(&field.value) {
if errored.is_none() {
errored = Some(e);
}

continue;
}

// if at least one arg is sourced-bound is ok
return Ok(());
}

Err(errored.expect("to be defined at that point"))
}

Value::Binary(binary) => self
.reject_constant_expr(&binary.lhs)
.or_else(|e| self.reject_constant_expr(&binary.rhs).map_err(|_| e)),

Value::Access(access) => self.reject_constant_expr(access.target.as_ref()),
Value::App(app) => self.reject_constant_func(&expr.attrs, app),
Value::Unary(unary) => self.reject_constant_expr(&unary.expr),
Value::Group(expr) => self.reject_constant_expr(expr),

_ => Err(AnalysisError::ConstantExprInProjectIntoClause(
expr.attrs.pos.line,
expr.attrs.pos.col,
)),
}
}

/// Analyzes an expression and checks it against an expected type.
///
/// This method performs type checking on an expression, verifying that all operations
Expand Down Expand Up @@ -1025,7 +1130,7 @@ impl<'a> Analysis<'a> {
/// ```
pub fn analyze_expr(
&mut self,
ctx: &AnalysisContext,
ctx: &mut AnalysisContext,
expr: &Expr,
mut expect: Type,
) -> AnalysisResult<Type> {
Expand Down Expand Up @@ -1147,6 +1252,10 @@ impl<'a> Analysis<'a> {
));
}

if *aggregate && ctx.allow_agg_func {
ctx.use_agg_funcs = true;
}

for (arg, tpe) in app.args.iter().zip(args.values.iter().cloned()) {
self.analyze_expr(ctx, arg, tpe)?;
}
Expand Down
28 changes: 28 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,34 @@ pub enum AnalysisError {
/// ```
#[error("{0}:{1}: aggregate functions arguments must be source-bound fields")]
ExpectSourceBoundProperty(u32, u32),

/// A constant expression is used in PROJECT INTO clause.
///
/// Fields: `(line, column)`
///
/// # Example
///
/// Invalid usage:
/// ```eql
/// FROM e IN events
/// // Error: NOW() is constant value
/// PROJECT INTO NOW()
///
/// ```
/// Invalid usage:
/// ```eql
/// FROM e IN events
/// // Error: DAY(NOW()) is also constant value
/// PROJECT INTO DAY(NOW())
/// ```
///
/// Valid usage:
/// ```eql
/// FROM e IN events
/// PROJECT INTO DAY(e.data.date)
/// ```
#[error("{0}:{1}: constant expressions are forbidden in PROJECT INTO clause")]
ConstantExprInProjectIntoClause(u32, u32),
}

impl From<LexerError> for Error {
Expand Down
26 changes: 22 additions & 4 deletions src/tests/analysis.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn test_typecheck_datetime_contravariance_1() {

// `e.time` is a `Type::DateTime` but it will typecheck if a `Type::Date` is expected
insta::assert_yaml_snapshot!(analysis.analyze_expr(
&AnalysisContext::default(),
&mut AnalysisContext::default(),
&expr,
Type::Date
));
Expand All @@ -142,7 +142,7 @@ fn test_typecheck_datetime_contravariance_2() {

// `NOW()` is a `Type::DateTime` but it will typecheck if a `Type::Time` is expected
insta::assert_yaml_snapshot!(analysis.analyze_expr(
&AnalysisContext::default(),
&mut AnalysisContext::default(),
&expr,
Type::Time
));
Expand All @@ -156,7 +156,7 @@ fn test_typecheck_datetime_contravariance_3() {
let mut analysis = Analysis::new(&options);

insta::assert_yaml_snapshot!(analysis.analyze_expr(
&AnalysisContext::default(),
&mut AnalysisContext::default(),
&expr,
Type::Number
));
Expand All @@ -170,7 +170,7 @@ fn test_typecheck_datetime_contravariance_4() {
let mut analysis = Analysis::new(&options);

insta::assert_yaml_snapshot!(analysis.analyze_expr(
&AnalysisContext::default(),
&mut AnalysisContext::default(),
&expr,
Type::Number
));
Expand Down Expand Up @@ -199,3 +199,21 @@ fn test_analyze_lowercase_function() {
let query = parse_query(include_str!("./resources/lowercase_function.eql")).unwrap();
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
}

#[test]
fn test_analyze_project_agg_value() {
let query = parse_query(include_str!("./resources/project_agg_value.eql")).unwrap();
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
}

#[test]
fn test_analyze_reject_constant_expr_in_project_into_clause() {
let query = parse_query(include_str!("./resources/reject_constant_expr.eql")).unwrap();
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
}

#[test]
fn test_analyze_allow_constant_agg_func() {
let query = parse_query(include_str!("./resources/allow_constant_agg_func.eql")).unwrap();
insta::assert_yaml_snapshot!(query.run_static_analysis(&Default::default()));
}
2 changes: 2 additions & 0 deletions src/tests/resources/allow_constant_agg_func.eql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM e IN events
PROJECT INTO count()
2 changes: 2 additions & 0 deletions src/tests/resources/project_agg_value.eql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM e IN events
PROJECT INTO count(e.data.price > 10)
2 changes: 2 additions & 0 deletions src/tests/resources/reject_constant_expr.eql
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
FROM e IN events
PROJECT INTO day(now())
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
---
source: src/tests/analysis.rs
expression: "query.run_static_analysis(&Default::default())"
---
Ok:
attrs:
pos:
line: 1
col: 1
sources:
- binding:
name: e
pos:
line: 1
col: 6
kind:
Name: events
predicate: ~
group_by: ~
order_by: ~
limit: ~
projection:
attrs:
pos:
line: 2
col: 14
value:
App:
func: count
args: []
distinct: false
meta:
project: Number
aggregate: true
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,4 @@ Ok:
distinct: false
meta:
project: String
aggregate: false
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,4 @@ Ok:
traceparent: String
tracestate: String
type: String
aggregate: false
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@ Ok:
project:
Record:
voters: Number
aggregate: true
Loading