forked from SciSharp/BotSharp
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSqlValidateFn.cs
More file actions
82 lines (70 loc) · 3.17 KB
/
SqlValidateFn.cs
File metadata and controls
82 lines (70 loc) · 3.17 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
using BotSharp.Abstraction.Instructs.Options;
using BotSharp.Abstraction.Models;
namespace BotSharp.Plugin.SqlDriver.Functions;
/// <summary>
/// Function callback exposed under the name <c>validate_sql</c>.
/// It extracts the SQL from a <c>```sql</c> fenced block in the message content, wraps it in a
/// database-specific "check only" statement, runs that through the <c>execute_sql</c> function,
/// and, when a <see cref="DbException"/> comes back, asks the instruct service for a corrected statement.
/// </summary>
public class SqlValidateFn : IFunctionCallback
{
    public string Name => "validate_sql";
    public string Indication => "Performing data validate operation.";

    // A "--" line comment, up to (not including) the end of the line.
    // NOTE(review): this also matches "--" inside string literals; unchanged from the original behavior.
    private static readonly Regex _lineCommentRegex = new(@"--.*");

    // A ```sql ... ``` fenced block; group 1 is the SQL without surrounding whitespace.
    private static readonly Regex _sqlFenceRegex = new(@"```sql\s*([\s\S]*?)\s*```");

    // A standalone "SET " keyword. The word boundary keeps words that merely end in "set"
    // (OFFSET, RESET, a column called "asset", ...) from being rewritten.
    private static readonly Regex _setKeywordRegex =
        new(@"\bSET ", RegexOptions.IgnoreCase | RegexOptions.CultureInvariant);

    private readonly IServiceProvider _services;
    private readonly ILogger _logger;

    public SqlValidateFn(IServiceProvider services, ILogger<SqlValidateFn> logger)
    {
        _services = services;
        _logger = logger;
    }

    /// <summary>
    /// Validates the SQL contained in <paramref name="message"/> and, if the validation run reports a
    /// <see cref="DbException"/>, replaces <c>message.Content</c> with a corrected statement.
    /// </summary>
    /// <param name="message">
    /// Dialog message whose <c>Content</c> holds a <c>```sql</c> fenced block. On success the content is
    /// overwritten with the extracted SQL (or with the corrected SQL when validation failed).
    /// </param>
    /// <returns>
    /// <c>false</c> when the content is empty or contains no <c>```sql</c> block (content is left untouched);
    /// otherwise <c>true</c>. The corrected SQL is not validated a second time.
    /// </returns>
    /// <exception cref="NotImplementedException">The database type reported by the hook is not supported.</exception>
    public async Task<bool> Execute(RoleDialogModel message)
    {
        // Nothing to validate; treat like "no SQL block found" instead of letting Regex throw on null.
        if (string.IsNullOrWhiteSpace(message.Content))
        {
            return false;
        }

        // remove comments start with "--"
        var withoutComments = _lineCommentRegex.Replace(message.Content, string.Empty);

        var fenced = _sqlFenceRegex.Match(withoutComments);
        if (!fenced.Success)
        {
            return false;
        }

        var sql = fenced.Groups[1].Value;
        message.Content = sql;

        var dbHook = _services.GetRequiredService<IText2SqlHook>();
        var dbType = dbHook.GetDatabaseType(message);

        // Invariant lower-casing so the match does not depend on the current culture (e.g. Turkish "I").
        var validateSql = dbType.ToLowerInvariant() switch
        {
            "mysql" => BuildMySqlExplain(sql),
            "sqlserver" or "mssql" => $"SET PARSEONLY ON;\r\n{sql}\r\nSET PARSEONLY OFF;",
            "redshift" => $"explain\r\n{sql}",
            _ => throw new NotImplementedException($"Database type {dbType} is not supported.")
        };

        // Run the validation statement on a copy so the caller's message keeps the plain SQL.
        var msgCopy = RoleDialogModel.From(message);
        msgCopy.FunctionArgs = JsonSerializer.Serialize(new ExecuteQueryArgs
        {
            SqlStatements = [validateSql]
        });

        var fn = _services.GetRequiredService<IRoutingService>();
        await fn.InvokeFunction("execute_sql", msgCopy);

        // Presumably "execute_sql" stores the database exception in Data on failure; verify against that function.
        if (msgCopy.Data is DbException ex)
        {
            var settingService = _services.GetRequiredService<ISettingService>();
            var instructService = _services.GetRequiredService<IInstructService>();
            var states = _services.GetRequiredService<IConversationStateService>();

            var query = "Correct SQL Statement and keep the comments/explanations";
            var ddl = states.GetState("table_ddls");

            var correctedSql = await instructService.Instruct<string>(query,
                new InstructOptions
                {
                    Provider = "openai",
                    Model = settingService.GetUpgradeModel(Gpt4xModelConstants.GPT_4o),
                    AgentId = BuiltInAgentId.SqlDriver,
                    TemplateName = "sql_statement_correctness",
                    Data = new Dictionary<string, object>
                    {
                        { "original_sql", message.Content },
                        { "error_message", ex.Message },
                        { "table_structure", ddl }
                    }
                });

            message.Content = correctedSql;
        }

        return true;
    }

    /// <summary>
    /// Builds the MySQL validation script: every statement is prefixed with <c>EXPLAIN</c> and
    /// standalone <c>SET</c> keywords are turned into line comments.
    /// </summary>
    /// <param name="sql">The extracted SQL, one or more statements separated by <c>;</c>.</param>
    /// <returns>The script to pass to <c>execute_sql</c>.</returns>
    private static string BuildMySqlExplain(string sql)
    {
        const string danglingExplain = " EXPLAIN ";

        var statements = _setKeywordRegex
            .Replace(sql, "-- SET ")
            .Replace(";", "; EXPLAIN ");

        // The replacement above leaves a dangling " EXPLAIN " after a final ';'. Remove exactly that suffix.
        // TrimEnd(char[]) must not be used here: it strips any trailing E/X/P/L/A/I/N/space characters and
        // would corrupt SQL such as "... IS NULL" when there is no trailing semicolon.
        if (statements.EndsWith(";" + danglingExplain, StringComparison.Ordinal))
        {
            statements = statements[..^danglingExplain.Length];
        }

        return $"EXPLAIN\r\n{statements}";
    }
}