diff --git a/CHANGELOG.md b/CHANGELOG.md index 172804820..8858336aa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -220,9 +220,9 @@ https://github.com/elixir-explorer/explorer/compare/v0.10.1...v0.11.0 - Add initial support for SQL queries. - The `Explorer.DataFrame.sql/3` is a function that accepts a dataframe and - a SQL query. The SQL is not validated by Explorer, so the queries will be - backend dependent. Right now we have only Polars as the backend. + The `Explorer.DataFrame.sql/2` is a function that accepts a map of table names + to DataFrames and a SQL query. The SQL is not validated by Explorer, so the + queries will be backend dependent. Right now we have only Polars as the backend. - Add support for remote series and dataframes. diff --git a/lib/explorer/backend/data_frame.ex b/lib/explorer/backend/data_frame.ex index 5dd78c83d..80da68af6 100644 --- a/lib/explorer/backend/data_frame.ex +++ b/lib/explorer/backend/data_frame.ex @@ -258,15 +258,18 @@ defmodule Explorer.Backend.DataFrame do @callback concat_columns([df], out_df :: df()) :: df @callback concat_rows([df], out_df :: df()) :: df + # SQL + + @callback sql_execute( + tables :: [{binary(), df}], + sql_string :: binary() + ) :: df + # Groups @callback summarise_with(df, out_df :: df(), aggregations :: [{column_name(), lazy_series()}]) :: df - # SQL - - @callback sql(df, sql_string :: binary(), table_name :: binary()) :: df() - # Functions alias Explorer.{DataFrame, Series} diff --git a/lib/explorer/data_frame.ex b/lib/explorer/data_frame.ex index 709bab9d9..c74a4de04 100644 --- a/lib/explorer/data_frame.ex +++ b/lib/explorer/data_frame.ex @@ -6884,7 +6884,10 @@ defmodule Explorer.DataFrame do # SQL @doc """ - Create a dataframe from the result of a SQL query on an existing dataframe. + Execute a SQL query on one or more DataFrames. + + Pass a map of table names to DataFrames and a SQL query string. + Each DataFrame is registered as a table with its corresponding name. > ### SQL Query is Unvalidated {: .warning} > @@ -6892,39 +6895,51 @@ defmodule Explorer.DataFrame do > it directly to the backend. As such, the SQL dialect will be backend > dependent and any errors will come directly from the backend. - ## `from` Clause - - The `from` clause of the SQL query should reference a chosen table name. The - default name is `"df"`. See the examples for a custom table name. - ## Examples - Basic example: + Single DataFrame: iex> df = Explorer.DataFrame.new(a: [1, 2, 3], b: ["x", "y", "y"]) - iex> Explorer.DataFrame.sql(df, "select ARRAY_AGG(a), b from df group by b order by b") + iex> Explorer.DataFrame.sql(%{df: df}, "select ARRAY_AGG(a), b from df group by b order by b") #Explorer.DataFrame< Polars[2 x 2] a list[s64] [[1], [2, 3]] b string ["x", "y"] > - Custom table name: + Multiple DataFrames: - iex> df = Explorer.DataFrame.new(a: [1, 2, 3]) - iex> Explorer.DataFrame.sql(df, "select a + 1 from my_table", table_name: "my_table") + iex> df1 = Explorer.DataFrame.new(id: [1, 2, 3], name: ["Alice", "Bob", "Charlie"]) + iex> df2 = Explorer.DataFrame.new(id: [1, 2, 4], age: [25, 30, 35]) + iex> Explorer.DataFrame.sql(%{users: df1, ages: df2}, "SELECT users.name, ages.age FROM users JOIN ages ON users.id = ages.id") |> Explorer.DataFrame.collect() #Explorer.DataFrame< - Polars[3 x 1] - a s64 [2, 3, 4] + Polars[2 x 2] + name string ["Alice", "Bob"] + age s64 [25, 30] > """ @doc type: :single - @spec sql(df :: DataFrame.t(), sql_string :: binary(), opts :: Keyword.t()) :: + @spec sql(tables :: %{required(atom() | binary()) => DataFrame.t()}, sql_string :: binary()) :: df :: DataFrame.t() - def sql(%__MODULE__{} = df, sql_string, opts \\ []) - when is_binary(sql_string) and is_list(opts) do - [table_name: table_name] = Keyword.validate!(opts, table_name: "df") - Shared.apply_dataframe(df, :sql, [sql_string, table_name]) + def sql(tables, sql_string) + when is_map(tables) and not is_struct(tables) and is_binary(sql_string) do + tables_list = + tables + |> Enum.map(fn {name, %__MODULE__{} = df} -> + {to_string(name), df} + end) + + impl = + case tables_list do + [] -> + backend = Explorer.Backend.get() + :"#{backend}.DataFrame" + + [{_, %__MODULE__{data: %impl_mod{}}} | _] -> + impl_mod + end + + apply(impl, :sql_execute, [tables_list, sql_string]) end # Helpers diff --git a/lib/explorer/polars_backend/data_frame.ex b/lib/explorer/polars_backend/data_frame.ex index d690332ea..6eb46d5c5 100644 --- a/lib/explorer/polars_backend/data_frame.ex +++ b/lib/explorer/polars_backend/data_frame.ex @@ -959,6 +959,23 @@ defmodule Explorer.PolarsBackend.DataFrame do %{out_df | data: out_data} end + # SQL + + @impl true + def sql_execute(tables, sql_string) do + tables_with_df = + Enum.map(tables, fn {name, df} -> + {name, df.data} + end) + + with {:ok, polars_ldf} <- Native.sql_execute(tables_with_df, sql_string), + {:ok, polars_df} <- Native.lf_compute(polars_ldf) do + Shared.create_dataframe!(polars_df) + else + {:error, error} -> raise error + end + end + # Groups @impl true @@ -976,16 +993,6 @@ defmodule Explorer.PolarsBackend.DataFrame do Explorer.Backend.DataFrame.inspect(df, "Polars", n_rows(df), opts) end - # SQL - - @impl true - def sql(%DataFrame{} = df, sql_string, table_name) do - df - |> lazy() - |> LazyFrame.sql(sql_string, table_name) - |> LazyFrame.collect() - end - @impl true def re_dtype(regex_as_string) when is_binary(regex_as_string) do case Explorer.PolarsBackend.Native.df_re_dtype(regex_as_string) do diff --git a/lib/explorer/polars_backend/lazy_frame.ex b/lib/explorer/polars_backend/lazy_frame.ex index 9feb81e9b..fe5e740b4 100644 --- a/lib/explorer/polars_backend/lazy_frame.ex +++ b/lib/explorer/polars_backend/lazy_frame.ex @@ -644,14 +644,21 @@ defmodule Explorer.PolarsBackend.LazyFrame do %{out_df | data: polars_df} end + # SQL + @impl true - def sql(ldf, sql_string, table_name) do - with {:ok, polars_lf} <- Native.lf_sql(ldf.data, sql_string, table_name), - {:ok, names} <- Native.lf_names(polars_lf), - {:ok, dtypes} <- Native.lf_dtypes(polars_lf) do - Explorer.Backend.DataFrame.new(polars_lf, names, dtypes) + def sql_execute(tables, sql_string) do + tables_with_df = + Enum.map(tables, fn {name, df} -> + {name, df.data} + end) + + with {:ok, polars_ldf} <- Native.sql_execute(tables_with_df, sql_string), + {:ok, names} <- Native.lf_names(polars_ldf), + {:ok, dtypes} <- Native.lf_dtypes(polars_ldf) do + Explorer.Backend.DataFrame.new(polars_ldf, names, dtypes) else - {:error, polars_error} -> raise polars_error + {:error, error} -> raise error end end diff --git a/lib/explorer/polars_backend/native.ex b/lib/explorer/polars_backend/native.ex index 9adbcfbaa..c6e8a2bd8 100644 --- a/lib/explorer/polars_backend/native.ex +++ b/lib/explorer/polars_backend/native.ex @@ -308,8 +308,6 @@ defmodule Explorer.PolarsBackend.Native do def lf_to_ipc(_df, _filename, _compression, _streaming), do: err() def lf_to_ipc_cloud(_df, _cloud_entry, _compression), do: err() def lf_to_csv(_df, _filename, _header, _delimiter, _quote_style, _streaming), do: err() - def lf_sql(_df, _sql_string, _table_name), do: err() - # Series def s_as_str(_s), do: err() def s_add(_s, _other), do: err() @@ -520,11 +518,13 @@ defmodule Explorer.PolarsBackend.Native do def s_join(_s, _separator), do: err() def s_lengths(_s), do: err() def s_member(_s, _value, _inner_dtype), do: err() - def s_field(_s, _name), do: err() def s_json_decode(_s, _dtype), do: err() def s_json_path_match(_s, _json_path), do: err() def s_index_of(_s, _v), do: err() + # SQL + def sql_execute(_tables, _query), do: err() + defp err, do: :erlang.nif_error(:nif_not_loaded) end diff --git a/native/explorer/src/lazyframe.rs b/native/explorer/src/lazyframe.rs index b779c1a3c..cbfb90cfb 100644 --- a/native/explorer/src/lazyframe.rs +++ b/native/explorer/src/lazyframe.rs @@ -477,18 +477,22 @@ pub fn lf_concat_columns(ldfs: Vec) -> Result, + query: String, ) -> Result { - let mut ctx = polars::sql::SQLContext::new(); + use polars::prelude::IntoLazy; + use polars::sql::SQLContext; - let lf = lf.clone_inner(); - ctx.register(table_name, lf); + let mut ctx = SQLContext::new(); - match ctx.execute(sql_string) { - Ok(lf_sql) => Ok(ExLazyFrame::new(lf_sql)), + for (name, df) in tables { + let ldf = df.clone_inner().lazy(); + ctx.register(&name, ldf); + } + + match ctx.execute(&query) { + Ok(lazy_frame) => Ok(ExLazyFrame::new(lazy_frame)), Err(polars_error) => Err(ExplorerError::Polars(polars_error)), } } diff --git a/test/explorer/data_frame_sql_test.exs b/test/explorer/data_frame_sql_test.exs new file mode 100644 index 000000000..71f8edd06 --- /dev/null +++ b/test/explorer/data_frame_sql_test.exs @@ -0,0 +1,233 @@ +defmodule Explorer.DataFrameSQLTest do + use ExUnit.Case, async: true + + alias Explorer.DataFrame, as: DF + + describe "sql/2 with single DataFrame" do + test "executes SQL query with table name" do + df = DF.new(a: [1, 2, 3], b: ["x", "y", "y"]) + + result = DF.sql(%{df: df}, "select ARRAY_AGG(a), b from df group by b order by b") + + assert result != nil + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.names(result) == ["a", "b"] + end + + test "executes SQL query with custom table name" do + df = DF.new(a: [1, 2, 3]) + + result = DF.sql(%{my_table: df}, "select a + 1 from my_table") + + assert result != nil + result = DF.collect(result) + assert DF.n_rows(result) == 3 + assert DF.names(result) == ["a"] + end + + test "executes SQL query with WHERE clause" do + df = DF.new(id: [1, 2, 3, 4, 5], value: [10, 20, 30, 40, 50]) + + result = DF.sql(%{df: df}, "select id, value from df where id > 2") + + result = DF.collect(result) + assert DF.n_rows(result) == 3 + assert DF.to_columns(result, atom_keys: true) == %{id: [3, 4, 5], value: [30, 40, 50]} + end + + test "executes SQL query with ORDER BY clause" do + df = DF.new(name: ["Alice", "Bob", "Charlie"], age: [30, 25, 35]) + + result = DF.sql(%{df: df}, "select name, age from df order by age") + + result = DF.collect(result) + assert DF.n_rows(result) == 3 + + assert DF.to_columns(result, atom_keys: true) == %{ + name: ["Bob", "Alice", "Charlie"], + age: [25, 30, 35] + } + end + end + + describe "sql/2 with multiple DataFrames" do + test "executes SQL query on single registered DataFrame" do + df1 = DF.new(column_a: [1, 2, 3]) + + result = + DF.sql(%{t1: df1}, "select 2 * t.column_a as column_2a from t1 as t where t.column_a < 3") + + assert result != nil + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.to_columns(result, atom_keys: true) == %{column_2a: [2, 4]} + end + + test "executes SQL query with JOIN between two DataFrames" do + df1 = DF.new(id: [1, 2, 3], name: ["Alice", "Bob", "Charlie"]) + df2 = DF.new(id: [1, 2, 4], age: [25, 30, 35]) + + result = + DF.sql( + %{users: df1, ages: df2}, + "SELECT users.name, ages.age FROM users JOIN ages ON users.id = ages.id" + ) + + assert result != nil + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.names(result) == ["name", "age"] + assert DF.to_columns(result, atom_keys: true) == %{name: ["Alice", "Bob"], age: [25, 30]} + end + + test "executes SQL query with LEFT JOIN" do + df1 = DF.new(id: [1, 2, 3], name: ["Alice", "Bob", "Charlie"]) + df2 = DF.new(id: [1, 2], age: [25, 30]) + + result = + DF.sql( + %{users: df1, ages: df2}, + "SELECT users.name, ages.age FROM users LEFT JOIN ages ON users.id = ages.id ORDER BY users.id" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 3 + end + + test "executes SQL query with multiple table references" do + df1 = DF.new(a: [1, 2, 3]) + df2 = DF.new(b: [1, 2, 3]) + df3 = DF.new(c: [1, 2, 3]) + + result = + DF.sql( + %{t1: df1, t2: df2, t3: df3}, + "SELECT t1.a, t2.b, t3.c FROM t1 JOIN t2 ON t1.a = t2.b JOIN t3 ON t1.a = t3.c" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 3 + assert DF.names(result) == ["a", "b", "c"] + end + + test "executes SQL query with aggregation across tables" do + df1 = DF.new(category: ["A", "A", "B", "B"], value: [10, 20, 30, 40]) + df2 = DF.new(category: ["A", "B"], multiplier: [2, 3]) + + result = + DF.sql( + %{data: df1, factors: df2}, + "SELECT data.category, SUM(data.value * factors.multiplier) as total FROM data JOIN factors ON data.category = factors.category GROUP BY data.category ORDER BY data.category" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.names(result) == ["category", "total"] + end + + test "executes SQL query with string table names (not atoms)" do + df1 = DF.new(id: [1, 2, 3], name: ["Alice", "Bob", "Charlie"]) + df2 = DF.new(id: [1, 2, 4], age: [25, 30, 35]) + + result = + DF.sql( + %{"users" => df1, "ages" => df2}, + "SELECT users.name FROM users JOIN ages ON users.id = ages.id" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.names(result) == ["name"] + end + end + + describe "sql/2 without registered tables" do + test "executes SQL query without any DataFrame registered" do + result = DF.sql(%{}, "select 1 as column_a union all select 2 as column_a") + + assert result != nil + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.to_columns(result, atom_keys: true) == %{column_a: [1, 2]} + end + + test "executes SQL query with only literal values" do + result = DF.sql(%{}, "select 1 + 2 as sum, 'hello' as greeting") + + result = DF.collect(result) + assert DF.n_rows(result) == 1 + assert DF.to_columns(result, atom_keys: true) == %{sum: [3], greeting: ["hello"]} + end + end + + describe "error handling" do + test "raises error for invalid SQL syntax" do + df = DF.new(a: [1, 2, 3]) + + assert_raise RuntimeError, fn -> + DF.sql(%{t1: df}, "select from invalid syntax") + end + end + + test "raises error when referencing non-existent table" do + df = DF.new(a: [1, 2, 3]) + + assert_raise RuntimeError, fn -> + DF.sql(%{t1: df}, "select * from nonexistent_table") + end + end + + test "raises error when referencing non-existent column" do + df = DF.new(a: [1, 2, 3]) + + assert_raise RuntimeError, fn -> + DF.sql(%{t1: df}, "select nonexistent_column from t1") + end + end + end + + describe "complex SQL operations" do + test "executes SQL query with WHERE clause and comparison" do + df = DF.new(id: [1, 2, 3, 4, 5], value: [10, 20, 30, 40, 50]) + + result = + DF.sql( + %{data: df}, + "SELECT id, value FROM data WHERE value > 30" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 2 + assert DF.to_columns(result, atom_keys: true) == %{id: [4, 5], value: [40, 50]} + end + + test "executes SQL query with CASE expression" do + df = DF.new(value: [10, 20, 30, 40, 50]) + + result = + DF.sql( + %{data: df}, + "SELECT value, CASE WHEN value < 20 THEN 'low' WHEN value < 40 THEN 'medium' ELSE 'high' END as category FROM data" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 5 + assert DF.names(result) == ["value", "category"] + end + + test "executes SQL query with UNION" do + df1 = DF.new(id: [1, 2], type: ["A", "A"]) + df2 = DF.new(id: [3, 4], type: ["B", "B"]) + + result = + DF.sql( + %{a: df1, b: df2}, + "SELECT id, type FROM a UNION ALL SELECT id, type FROM b ORDER BY id" + ) + + result = DF.collect(result) + assert DF.n_rows(result) == 4 + end + end +end