-
Notifications
You must be signed in to change notification settings - Fork 64
Expand file tree
/
Copy pathmain.rs
More file actions
208 lines (197 loc) · 5.11 KB
/
main.rs
File metadata and controls
208 lines (197 loc) · 5.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
//! This module contains the main entrypoint to the tangram cli.
use clap::Parser;
use colored::Colorize;
use std::path::PathBuf;
use tracing_subscriber::prelude::*;
#[cfg(feature = "tangram_app")]
mod app;
#[cfg(feature = "tangram_app")]
mod migrate;
#[cfg(feature = "train")]
mod predict;
#[cfg(feature = "serve")]
mod serve;
#[cfg(feature = "train")]
mod train;
// Top-level command line arguments for the `tangram` binary.
//
// Plain `//` comments are used on clap-derived items on purpose: `///` doc
// comments would be picked up by clap and alter the generated help text.
#[derive(Parser)]
#[clap(
    // `concat!` around a single `env!` was a no-op; `env!` alone yields the
    // same `&'static str` crate version.
    version = env!("CARGO_PKG_VERSION"),
    about = "Train and deploy a machine learning model in minutes.",
    disable_help_subcommand = true,
)]
struct Args {
    // The selected subcommand; its variants are gated on cargo features.
    #[clap(subcommand)]
    subcommand: Subcommand,
}
// The subcommands accepted by the cli. Each variant is compiled in only when
// its cargo feature (`train`, `tangram_app`, `serve`) is enabled, and `main`
// gates its match arms on the same features.
//
// Plain `//` comments are used here because `///` doc comments on a clap
// derive would become help text.
//
// NOTE(review): the argument structs are boxed, presumably to keep the enum
// small regardless of which payload is largest — TODO confirm.
#[derive(Parser)]
enum Subcommand {
    // `tangram train`: train a model from csv data.
    #[cfg(feature = "train")]
    #[clap(name = "train")]
    Train(Box<TrainArgs>),
    // `tangram predict`: make predictions with an existing model.
    #[cfg(feature = "train")]
    #[clap(name = "predict")]
    Predict(Box<PredictArgs>),
    // `tangram app`: run the app.
    #[cfg(feature = "tangram_app")]
    #[clap(name = "app")]
    App(Box<AppArgs>),
    // `tangram migrate`: migrate the app database.
    #[cfg(feature = "tangram_app")]
    #[clap(name = "migrate")]
    Migrate(Box<MigrateArgs>),
    // `tangram serve`: serve predictions over HTTP.
    #[cfg(feature = "serve")]
    #[clap(name = "serve")]
    Serve(Box<ServeArgs>),
}
// Arguments for `tangram train`.
//
// Training data can be supplied via `--file`, via a `--file-train` /
// `--file-test` pair, or via `--stdin`. `--file` conflicts with both split
// options, and `--file-train` / `--file-test` each require the other. The ids
// used in `conflicts_with_all` / `requires` are clap's kebab-case argument
// names, so they must be kept in sync with the field names below.
//
// NOTE(review): `--stdin` is not declared as conflicting with the file
// options, so the parser accepts e.g. `--file x --stdin` — confirm how the
// `train` module resolves that combination.
#[cfg(feature = "train")]
#[derive(Parser)]
#[clap(
    about = "Train a model.",
    long_about = "Train a model from a csv file."
)]
pub struct TrainArgs {
    // A single csv file; mutually exclusive with the train/test pair.
    #[clap(
        short,
        long,
        help = "the path to your .csv file",
        conflicts_with_all=&["file-train", "file-test"],
    )]
    file: Option<PathBuf>,
    // Training half of an explicit train/test split.
    #[clap(
        long,
        help = "the path to your .csv file used for training",
        requires = "file-test"
    )]
    file_train: Option<PathBuf>,
    // Testing half of an explicit train/test split.
    #[clap(
        long,
        help = "the path to your .csv file used for testing",
        requires = "file-train"
    )]
    file_test: Option<PathBuf>,
    // Read the training data from stdin instead of a file.
    #[clap(long, help = "Pass the training data via stdin.")]
    stdin: bool,
    // Required: the column the model will learn to predict.
    #[clap(short, long, help = "the name of the column to predict")]
    target: String,
    #[clap(short, long, help = "the path to a config file")]
    config: Option<PathBuf>,
    #[clap(short, long, help = "the path to write the .tangram file to")]
    output: Option<PathBuf>,
    // Inverted flag: the field is `true` by default, and passing
    // `--no-progress` yields `false` because `Not::not` is applied to the
    // flag's presence.
    #[clap(
        long = "no-progress",
        help = "disable the cli progress view",
        parse(from_flag = std::ops::Not::not),
    )]
    progress: bool,
}
// Arguments for `tangram predict`.
//
// Only `--model` is required. Per the help texts, examples are read from stdin
// when `--file` is omitted, and predictions are written to stdout when
// `--output` is omitted.
#[cfg(feature = "train")]
#[derive(Parser)]
#[clap(
    about = "Make predictions with a model.",
    long_about = "Make predictions with a model on the command line from a csv file."
)]
pub struct PredictArgs {
    #[clap(
        short,
        long,
        help = "the path to read examples from, defaults to stdin"
    )]
    file: Option<PathBuf>,
    #[clap(short, long, help = "the path to the model to make predictions with")]
    model: PathBuf,
    #[clap(
        short,
        long,
        help = "the path to write the predictions to, defaults to stdout"
    )]
    output: Option<PathBuf>,
    // NOTE(review): as an `Option<bool>` (not a plain `bool`), clap appears to
    // treat this as an option taking an explicit value (`-p true`) rather than
    // a bare flag — confirm this is intended.
    #[clap(
        short,
        long,
        help = "output probabilities instead of class labels, only relevant for classifier models"
    )]
    probabilities: Option<bool>,
    #[clap(short, long, help = "The threshold value to use for predictions.")]
    threshold: Option<f32>,
}
// Arguments for `tangram app`.
//
// Plain `//` comments are used because `///` doc comments on a clap derive
// become help text.
#[cfg(feature = "tangram_app")]
#[derive(Parser)]
#[clap(about = "Run the app.", long_about = "Run the app.")]
pub struct AppArgs {
    // Optional config file path; how it is interpreted is up to the `app`
    // module. The help text mirrors `TrainArgs::config`, and `long` already
    // derives `--config` from the field name.
    #[clap(short, long, help = "the path to a config file")]
    config: Option<PathBuf>,
}
// Arguments for `tangram migrate`.
//
// Plain `//` comments are used because `///` doc comments on a clap derive
// become help text.
#[cfg(feature = "tangram_app")]
#[derive(Parser)]
#[clap(
    about = "Migrate your app database.",
    long_about = "Migrate your app database to the latest version."
)]
pub struct MigrateArgs {
    // Exposed as `--database-url`. Optional; what happens when it is omitted
    // is decided by the `migrate` module.
    #[clap(long, help = "the url of the database to migrate")]
    database_url: Option<String>,
}
// Arguments for `tangram serve`.
//
// Only `--model` is required. The server binds to `--address` (default
// `127.0.0.1`) on `--port` (default `8080`).
#[cfg(feature = "serve")]
#[derive(Parser)]
#[clap(
    about = "Serve predictions via HTTP",
    long_about = "Create HTTP server exposing an endpoint for running predictions against a Tangram model"
)]
pub struct ServeArgs {
    // Kept as a `String`, so the address is not validated at parse time.
    #[clap(
        short,
        long,
        default_value = "127.0.0.1",
        help = "Host IP at which to bind the server"
    )]
    address: String,
    #[clap(
        short,
        long,
        help = "Path to the `.tangram` file containing the model to serve"
    )]
    model: PathBuf,
    // Parsed as `u16`, so out-of-range ports are rejected by clap.
    #[clap(short, long, default_value = "8080", help = "Port to listen on")]
    port: u16,
}
fn main() {
setup_tracing();
let args = Args::parse();
let result = match args.subcommand {
#[cfg(feature = "train")]
Subcommand::Train(args) => self::train::train(*args),
#[cfg(feature = "train")]
Subcommand::Predict(args) => self::predict::predict(*args),
#[cfg(feature = "tangram_app")]
Subcommand::App(args) => self::app::app(*args),
#[cfg(feature = "tangram_app")]
Subcommand::Migrate(args) => self::migrate::migrate(*args),
#[cfg(feature = "serve")]
Subcommand::Serve(args) => self::serve::serve(*args),
};
if let Err(error) = result {
eprintln!("{}: {:#}", "error".red().bold(), error);
std::process::exit(1);
}
}
fn setup_tracing() {
let env_layer = tracing_subscriber::EnvFilter::try_from_env("TANGRAM_TRACING");
let env_layer = if cfg!(debug_assertions) {
Some(env_layer.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("[]=info")))
} else {
env_layer.ok()
};
if let Some(env_layer) = env_layer {
if cfg!(debug_assertions) {
let format_layer = tracing_subscriber::fmt::layer().pretty();
let subscriber = tracing_subscriber::registry()
.with(env_layer)
.with(format_layer);
subscriber.init();
} else {
let json_layer = tracing_subscriber::fmt::layer().json();
let subscriber = tracing_subscriber::registry()
.with(env_layer)
.with(json_layer);
subscriber.init();
}
}
}